1use std::{
2 fmt::Display,
3 ops::{Add, AddAssign},
4};
5
6use approx::{AbsDiffEq, RelativeEq, relative_eq};
7use itertools::Itertools;
8use ndarray::prelude::*;
9use rand::Rng;
10use rand_distr::{Distribution, weighted::WeightedIndex};
11use serde::{
12 Deserialize, Deserializer, Serialize, Serializer,
13 de::{MapAccess, Visitor},
14 ser::SerializeMap,
15};
16
17use crate::{
18 datasets::{CatSample, CatType},
19 impl_json_io,
20 models::{CPD, CatPhi, Labelled, Phi},
21 types::{EPSILON, Labels, Set, States},
22 utils::MI,
23};
24
/// Sample statistics of a categorical CPD.
#[derive(Clone, Debug)]
pub struct CatCPDS {
    /// Sample conditional counts, exposed through `sample_conditional_counts()`.
    ///
    /// `CatCPD::with_optionals` requires this matrix to have the same shape as the
    /// CPD parameters.
    // NOTE(review): presumably one row per conditioning configuration and one column
    // per target configuration, mirroring the parameters — confirm with the estimators.
    n_xz: Array2<f64>,
    /// Sample size, exposed through `sample_size()`.
    n: f64,
}
33
34impl CatCPDS {
35 #[inline]
47 pub fn new(n_xz: Array2<f64>, n: f64) -> Self {
48 assert!(
50 n_xz.iter().all(|&x| x.is_finite() && x >= 0.),
51 "Counts must be finite and non-negative."
52 );
53 assert!(
54 n.is_finite() && n >= 0.,
55 "Sample size must be finite and non-negative."
56 );
57
58 Self { n_xz, n }
59 }
60
61 #[inline]
68 pub const fn sample_conditional_counts(&self) -> &Array2<f64> {
69 &self.n_xz
70 }
71
72 #[inline]
79 pub const fn sample_size(&self) -> f64 {
80 self.n
81 }
82}
83
84impl AddAssign for CatCPDS {
85 fn add_assign(&mut self, other: Self) {
86 self.n_xz += &other.n_xz;
88 self.n += other.n;
90 }
91}
92
93impl Add for CatCPDS {
94 type Output = Self;
95
96 #[inline]
97 fn add(mut self, rhs: Self) -> Self::Output {
98 self += rhs;
99 self
100 }
101}
102
103impl Serialize for CatCPDS {
104 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105 where
106 S: Serializer,
107 {
108 let mut map = serializer.serialize_map(Some(2))?;
110 let sample_conditional_counts: Vec<Vec<f64>> =
112 self.n_xz.rows().into_iter().map(|x| x.to_vec()).collect();
113 map.serialize_entry("sample_conditional_counts", &sample_conditional_counts)?;
115 map.serialize_entry("sample_size", &self.n)?;
117 map.end()
119 }
120}
121
122impl<'de> Deserialize<'de> for CatCPDS {
123 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
124 where
125 D: Deserializer<'de>,
126 {
127 #[derive(Deserialize)]
128 #[serde(field_identifier, rename_all = "snake_case")]
129 enum Field {
130 SampleConditionalCounts,
131 SampleSize,
132 }
133
134 struct CatCPDSVisitor;
135
136 impl<'de> Visitor<'de> for CatCPDSVisitor {
137 type Value = CatCPDS;
138
139 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
140 formatter.write_str("struct CatCPDS")
141 }
142
143 fn visit_map<V>(self, mut map: V) -> Result<CatCPDS, V::Error>
144 where
145 V: MapAccess<'de>,
146 {
147 use serde::de::Error as E;
148
149 let mut sample_conditional_counts = None;
151 let mut sample_size = None;
152
153 while let Some(key) = map.next_key()? {
155 match key {
156 Field::SampleConditionalCounts => {
157 if sample_conditional_counts.is_some() {
158 return Err(E::duplicate_field("sample_conditional_counts"));
159 }
160 sample_conditional_counts = Some(map.next_value()?);
161 }
162 Field::SampleSize => {
163 if sample_size.is_some() {
164 return Err(E::duplicate_field("sample_size"));
165 }
166 sample_size = Some(map.next_value()?);
167 }
168 }
169 }
170
171 let sample_conditional_counts = sample_conditional_counts
173 .ok_or_else(|| E::missing_field("sample_conditional_counts"))?;
174 let sample_size = sample_size.ok_or_else(|| E::missing_field("sample_size"))?;
175
176 let sample_conditional_counts = {
178 let counts: Vec<Vec<f64>> = sample_conditional_counts;
179 let shape = (counts.len(), counts[0].len());
180 Array::from_iter(counts.into_iter().flatten())
181 .into_shape_with_order(shape)
182 .map_err(|_| E::custom("Invalid sample conditional counts shape"))?
183 };
184
185 Ok(CatCPDS::new(sample_conditional_counts, sample_size))
186 }
187 }
188
189 const FIELDS: &[&str] = &["sample_conditional_counts", "sample_size"];
190
191 deserializer.deserialize_struct("CatCPDS", FIELDS, CatCPDSVisitor)
192 }
193}
194
/// A categorical conditional probability distribution (CPD).
#[derive(Clone, Debug)]
pub struct CatCPD {
    /// Labels of the variables, kept sorted by `new`.
    labels: Labels,
    /// States of each variable, keyed by label; keys and states are kept sorted by `new`.
    states: States,
    /// Number of states of each variable, in label order.
    shape: Array1<usize>,
    /// Multi-index built from `shape`; used to ravel a joint configuration of the
    /// variables into a column of `parameters`.
    multi_index: MI,
    /// Labels of the conditioning variables, kept sorted by `new`.
    conditioning_labels: Labels,
    /// States of each conditioning variable, keyed by label; kept sorted by `new`.
    conditioning_states: States,
    /// Number of states of each conditioning variable, in label order.
    conditioning_shape: Array1<usize>,
    /// Multi-index built from `conditioning_shape`; used to ravel a joint configuration
    /// of the conditioning variables into a row of `parameters`.
    conditioning_multi_index: MI,
    /// Probabilities: one row per conditioning configuration and one column per
    /// configuration of the variables. `new` checks that each row sums to one.
    parameters: Array2<f64>,
    /// Number of parameters, computed as `(ncols - 1) * nrows` (saturating at zero columns).
    parameters_size: usize,
    /// Optional sample statistics the CPD was fitted from.
    sample_statistics: Option<CatCPDS>,
    /// Optional sample log-likelihood of the fitted CPD.
    sample_log_likelihood: Option<f64>,
}
215
216impl CatCPD {
217 pub fn new(states: States, conditioning_states: States, parameters: Array2<f64>) -> Self {
237 let labels: Set<_> = states.keys().cloned().collect();
239 let conditioning_labels: Set<_> = conditioning_states.keys().cloned().collect();
241
242 assert!(
244 labels.is_disjoint(&conditioning_labels),
245 "Labels and conditioning labels must be disjoint."
246 );
247
248 let shape = Array::from_iter(states.values().map(Set::len));
250
251 assert!(
253 parameters.is_empty() || parameters.ncols() == shape.product(),
254 "Product of the number of states must match the number of columns: \n\
255 \t expected: parameters.ncols() == {} , \n\
256 \t found: parameters.ncols() == {} .",
257 shape.product(),
258 parameters.ncols(),
259 );
260
261 let conditioning_shape = Array::from_iter(conditioning_states.values().map(Set::len));
263
264 assert!(
266 parameters.is_empty() || parameters.nrows() == conditioning_shape.product(),
267 "Product of the number of conditioning states must match the number of rows: \n\
268 \t expected: parameters.nrows() == {} , \n\
269 \t found: parameters.nrows() == {} .",
270 conditioning_shape.product(),
271 parameters.nrows(),
272 );
273
274 parameters
276 .sum_axis(Axis(1))
277 .iter()
278 .enumerate()
279 .for_each(|(i, &x)| {
280 if !relative_eq!(x, 1.0, epsilon = EPSILON) {
281 panic!("Failed to sum probability to one: {}.", parameters.row(i));
282 }
283 });
284
285 let mut parameters = parameters;
287
288 let mut labels = labels;
290 let mut states = states;
291 let mut shape = shape;
292
293 if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
295 let mut sorted_states_idx: Vec<_> = states.values().multi_cartesian_product().collect();
297 let mut sorted_labels_idx: Vec<_> = (0..labels.len()).collect();
299 sorted_labels_idx.sort_by_key(|&i| &labels[i]);
301 sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
303 *sorted_states_idx = sorted_labels_idx
304 .iter()
305 .map(|&i| sorted_states_idx[i])
306 .collect();
307 });
308 let mut sorted_row_idx: Vec<_> = (0..parameters.ncols()).collect();
310 sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
312 states.sort_keys();
314 states.values_mut().for_each(Set::sort);
315 labels = states.keys().cloned().collect();
316 shape = states.values().map(|x| x.len()).collect();
317 let mut new_parameters = parameters.clone();
319 new_parameters
321 .columns_mut()
322 .into_iter()
323 .enumerate()
324 .for_each(|(i, mut new_parameters_col)| {
325 new_parameters_col.assign(¶meters.column(sorted_row_idx[i]));
327 });
328 parameters = new_parameters;
330 }
331
332 let labels = labels;
334 let states = states;
335 let shape = shape;
336
337 let mut conditioning_labels = conditioning_labels;
339 let mut conditioning_states = conditioning_states;
340 let mut conditioning_shape = conditioning_shape;
341
342 if !conditioning_states.keys().is_sorted()
344 || !conditioning_states.values().all(|x| x.iter().is_sorted())
345 {
346 let mut sorted_states_idx: Vec<_> = conditioning_states
348 .values()
349 .multi_cartesian_product()
350 .collect();
351 let mut sorted_labels_idx: Vec<_> = (0..conditioning_labels.len()).collect();
353 sorted_labels_idx.sort_by_key(|&i| &conditioning_labels[i]);
355 sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
357 *sorted_states_idx = sorted_labels_idx
358 .iter()
359 .map(|&i| sorted_states_idx[i])
360 .collect();
361 });
362 let mut sorted_row_idx: Vec<_> = (0..parameters.nrows()).collect();
364 sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
366 conditioning_states.sort_keys();
368 conditioning_states.values_mut().for_each(Set::sort);
369 conditioning_labels = conditioning_states.keys().cloned().collect();
370 conditioning_shape = conditioning_states.values().map(|x| x.len()).collect();
371 let mut new_parameters = parameters.clone();
373 new_parameters.rows_mut().into_iter().enumerate().for_each(
375 |(i, mut new_parameters_row)| {
376 new_parameters_row.assign(¶meters.row(sorted_row_idx[i]));
378 },
379 );
380 parameters = new_parameters;
382 }
383
384 let conditioning_labels = conditioning_labels;
386 let conditioning_states = conditioning_states;
387 let conditioning_shape = conditioning_shape;
388
389 let parameters = parameters;
391
392 debug_assert!(labels.iter().is_sorted(), "Labels must be sorted.");
394 debug_assert!(states.keys().is_sorted(), "Labels must be sorted.");
395 debug_assert!(
396 states.values().all(|x| x.iter().is_sorted()),
397 "States must be sorted."
398 );
399 debug_assert!(
400 conditioning_labels.iter().is_sorted(),
401 "Conditioning labels must be sorted."
402 );
403 debug_assert!(
404 conditioning_states.keys().is_sorted(),
405 "Conditioning labels must be sorted."
406 );
407 debug_assert!(
408 conditioning_states.values().all(|x| x.iter().is_sorted()),
409 "Conditioning states must be sorted."
410 );
411
412 let multi_index = MI::new(shape.clone());
414 let conditioning_multi_index = MI::new(conditioning_shape.clone());
416 let parameters_size = parameters.ncols().saturating_sub(1) * parameters.nrows();
418
419 Self {
420 labels,
421 states,
422 shape,
423 multi_index,
424 conditioning_labels,
425 conditioning_states,
426 conditioning_shape,
427 conditioning_multi_index,
428 parameters,
429 parameters_size,
430 sample_statistics: None,
431 sample_log_likelihood: None,
432 }
433 }
434
435 #[inline]
442 pub const fn states(&self) -> &States {
443 &self.states
444 }
445
446 #[inline]
453 pub const fn shape(&self) -> &Array1<usize> {
454 &self.shape
455 }
456
457 #[inline]
464 pub const fn multi_index(&self) -> &MI {
465 &self.multi_index
466 }
467
468 #[inline]
475 pub const fn conditioning_states(&self) -> &States {
476 &self.conditioning_states
477 }
478
479 #[inline]
486 pub const fn conditioning_shape(&self) -> &Array1<usize> {
487 &self.conditioning_shape
488 }
489
490 #[inline]
497 pub const fn conditioning_multi_index(&self) -> &MI {
498 &self.conditioning_multi_index
499 }
500
501 pub fn marginalize(&self, x: &Set<usize>, z: &Set<usize>) -> Self {
513 if x.is_empty() && z.is_empty() {
515 return self.clone();
516 }
517 let labels_x = self.labels();
519 let labels_z = self.conditioning_labels();
520 let not_x = (0..labels_x.len()).filter(|i| !x.contains(i)).collect();
522 let not_z = (0..labels_z.len()).filter(|i| !z.contains(i)).collect();
523 let phi = self.clone().into_phi();
525 let x = phi.indices_from(x, labels_x);
527 let z = phi.indices_from(z, labels_z);
528 let phi = phi.marginalize(&(&x | &z));
530 let not_x = phi.indices_from(¬_x, labels_x);
532 let not_z = phi.indices_from(¬_z, labels_z);
533 phi.into_cpd(¬_x, ¬_z)
535 }
536
537 pub fn with_optionals(
554 state: States,
555 conditioning_states: States,
556 parameters: Array2<f64>,
557 sample_statistics: Option<CatCPDS>,
558 sample_log_likelihood: Option<f64>,
559 ) -> Self {
560 if let Some(sample_statistics) = &sample_statistics {
561 let sample_conditional_counts = &sample_statistics.n_xz;
563 assert!(
565 sample_conditional_counts.shape() == parameters.shape(),
566 "Sample conditional counts must have the same shape as parameters: \n\
567 \t expected: sample_conditional_counts.shape() == {:?} , \n\
568 \t found: sample_conditional_counts.shape() == {:?} .",
569 parameters.shape(),
570 sample_conditional_counts.shape(),
571 );
572 }
573 if let Some(sample_log_likelihood) = &sample_log_likelihood {
575 assert!(
576 sample_log_likelihood.is_finite() && *sample_log_likelihood <= 0.,
577 "Sample log-likelihood must be finite and non-positive: \n\
578 \t expected: sample_ll <= 0 , \n\
579 \t found: sample_ll == {sample_log_likelihood} ."
580 );
581 }
582
583 let mut cpd = Self::new(state, conditioning_states, parameters);
585
586 cpd.sample_statistics = sample_statistics;
590 cpd.sample_log_likelihood = sample_log_likelihood;
591
592 cpd
593 }
594
595 #[inline]
607 pub fn from_phi(phi: CatPhi, x: &Set<usize>, z: &Set<usize>) -> Self {
608 phi.into_cpd(x, z)
609 }
610
611 #[inline]
622 pub fn into_phi(self) -> CatPhi {
623 CatPhi::from_cpd(self)
624 }
625}
626
627impl Labelled for CatCPD {
628 #[inline]
629 fn labels(&self) -> &Labels {
630 &self.labels
631 }
632}
633
634impl PartialEq for CatCPD {
635 fn eq(&self, other: &Self) -> bool {
636 self.labels.eq(&other.labels)
638 && self.states.eq(&other.states)
639 && self.shape.eq(&other.shape)
640 && self.conditioning_labels.eq(&other.conditioning_labels)
641 && self.conditioning_states.eq(&other.conditioning_states)
642 && self.conditioning_shape.eq(&other.conditioning_shape)
643 && self.multi_index.eq(&other.multi_index)
644 && self.parameters.eq(&other.parameters)
645 }
646}
647
648impl AbsDiffEq for CatCPD {
649 type Epsilon = f64;
650
651 fn default_epsilon() -> Self::Epsilon {
652 Self::Epsilon::default_epsilon()
653 }
654
655 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
656 self.labels.eq(&other.labels)
658 && self.states.eq(&other.states)
659 && self.shape.eq(&other.shape)
660 && self.conditioning_labels.eq(&other.conditioning_labels)
661 && self.conditioning_states.eq(&other.conditioning_states)
662 && self.conditioning_shape.eq(&other.conditioning_shape)
663 && self.multi_index.eq(&other.multi_index)
664 && self.parameters.abs_diff_eq(&other.parameters, epsilon)
665 }
666}
667
668impl RelativeEq for CatCPD {
669 fn default_max_relative() -> Self::Epsilon {
670 Self::Epsilon::default_max_relative()
671 }
672
673 fn relative_eq(
674 &self,
675 other: &Self,
676 epsilon: Self::Epsilon,
677 max_relative: Self::Epsilon,
678 ) -> bool {
679 self.labels.eq(&other.labels)
681 && self.states.eq(&other.states)
682 && self.shape.eq(&other.shape)
683 && self.conditioning_labels.eq(&other.conditioning_labels)
684 && self.conditioning_states.eq(&other.conditioning_states)
685 && self.conditioning_shape.eq(&other.conditioning_shape)
686 && self.multi_index.eq(&other.multi_index)
687 && self
688 .parameters
689 .relative_eq(&other.parameters, epsilon, max_relative)
690 }
691}
692
693impl CPD for CatCPD {
694 type Support = CatSample;
695 type Parameters = Array2<f64>;
696 type Statistics = CatCPDS;
697
698 #[inline]
699 fn conditioning_labels(&self) -> &Labels {
700 &self.conditioning_labels
701 }
702
703 #[inline]
704 fn parameters(&self) -> &Self::Parameters {
705 &self.parameters
706 }
707
708 #[inline]
709 fn parameters_size(&self) -> usize {
710 self.parameters_size
711 }
712
713 #[inline]
714 fn sample_statistics(&self) -> Option<&Self::Statistics> {
715 self.sample_statistics.as_ref()
716 }
717
718 #[inline]
719 fn sample_log_likelihood(&self) -> Option<f64> {
720 self.sample_log_likelihood
721 }
722
723 fn pf(&self, x: &Self::Support, z: &Self::Support) -> f64 {
724 let n = self.labels.len();
726 let m = self.conditioning_labels.len();
728
729 assert_eq!(
731 x.len(),
732 n,
733 "Vector X must match number of variables: \n\
734 \t expected: |X| == {} , \n\
735 \t found: |X| == {} .",
736 n,
737 x.len(),
738 );
739 assert_eq!(
741 z.len(),
742 m,
743 "Vector Z must match number of conditioning variables: \n\
744 \t expected: |Z| == {} , \n\
745 \t found: |Z| == {} .",
746 m,
747 z.len(),
748 );
749
750 if n == 0 {
752 return 1.;
753 }
754
755 let x = match n {
757 1 => x[0] as usize,
759 _ => {
761 let x = x.iter().map(|&x| x as usize);
763 self.multi_index.ravel(x)
765 }
766 };
767
768 let z = match m {
770 0 => 0,
772 1 => z[0] as usize,
774 _ => {
776 let z = z.iter().map(|&z| z as usize);
778 self.conditioning_multi_index.ravel(z)
780 }
781 };
782
783 self.parameters[[z, x]]
785 }
786
787 fn sample<R: Rng>(&self, rng: &mut R, z: &Self::Support) -> Self::Support {
788 let n = self.labels.len();
790 let m = self.conditioning_labels.len();
792
793 assert_eq!(
795 z.len(),
796 m,
797 "Vector Z must match number of conditioning variables: \n\
798 \t expected: |Z| == {} , \n\
799 \t found: |Z| == {} .",
800 m,
801 z.len(),
802 );
803
804 if n == 0 {
806 return array![];
807 }
808
809 let z = match m {
811 0 => 0,
813 1 => z[0] as usize,
815 _ => {
817 let z = z.iter().map(|&z| z as usize);
819 self.conditioning_multi_index.ravel(z)
821 }
822 };
823
824 let p = self.parameters.row(z);
826 let s = WeightedIndex::new(&p).unwrap();
828 let x = s.sample(rng);
830
831 match n {
833 1 => array![x as CatType],
835 _ => {
837 let x = self.multi_index.unravel(x);
839 let x = x.iter().map(|&x| x as CatType);
841 x.collect()
843 }
844 }
845 }
846}
847
848impl Display for CatCPD {
849 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
850 assert_eq!(self.labels().len(), 1);
852
853 let n = std::iter::once(&self.labels()[0])
855 .chain(&self.states()[0])
856 .chain(self.conditioning_labels())
857 .chain(self.conditioning_states().values().flatten())
858 .map(|x| x.len())
859 .max()
860 .unwrap_or(0)
861 .max(8);
862 let z = self.conditioning_shape().len();
864 let s = self.shape()[0];
866
867 let hline = "-".repeat((n + 3) * (z + s) + 1);
869 writeln!(f, "{hline}")?;
870
871 let header = std::iter::repeat_n("", z) .chain([self.labels()[0].as_str()]) .chain(std::iter::repeat_n("", s.saturating_sub(1))) .map(|x| format!("{x:n$}")) .join(" | ");
877 writeln!(f, "| {header} |")?;
878
879 let separator = std::iter::repeat_n("-".repeat(n), z + s).join(" | ");
881 writeln!(f, "| {separator} |")?;
882
883 let header = self
885 .conditioning_labels()
886 .iter()
887 .chain(&self.states()[0]) .map(|x| format!("{x:n$}")) .join(" | ");
890 writeln!(f, "| {header} |")?;
891 writeln!(f, "| {separator} |")?;
892
893 for (states, values) in self
895 .conditioning_states()
896 .values()
897 .multi_cartesian_product()
898 .zip(self.parameters().rows())
899 {
900 let states = states.iter().map(|x| format!("{x:n$}"));
902 let values = values.iter().map(|x| format!("{x:n$.6}"));
904 let states_values = states.chain(values).join(" | ");
906 writeln!(f, "| {states_values} |")?;
907 }
908
909 writeln!(f, "{hline}")?;
911
912 Ok(())
913 }
914}
915
916impl Serialize for CatCPD {
917 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
918 where
919 S: Serializer,
920 {
921 let mut size = 4;
923 size += self.sample_statistics.is_some() as usize;
925 size += self.sample_log_likelihood.is_some() as usize;
926 let mut map = serializer.serialize_map(Some(size))?;
928
929 map.serialize_entry("states", &self.states)?;
931 map.serialize_entry("conditioning_states", &self.conditioning_states)?;
933
934 let parameters: Vec<Vec<f64>> = self
936 .parameters
937 .rows()
938 .into_iter()
939 .map(|x| x.to_vec())
940 .collect();
941 map.serialize_entry("parameters", ¶meters)?;
943
944 if let Some(sample_statistics) = &self.sample_statistics {
946 map.serialize_entry("sample_statistics", sample_statistics)?;
947 }
948 if let Some(sample_log_likelihood) = &self.sample_log_likelihood {
950 map.serialize_entry("sample_log_likelihood", sample_log_likelihood)?;
951 }
952
953 map.serialize_entry("type", "catcpd")?;
955
956 map.end()
958 }
959}
960
961impl<'de> Deserialize<'de> for CatCPD {
962 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
963 where
964 D: Deserializer<'de>,
965 {
966 #[derive(Deserialize)]
967 #[serde(field_identifier, rename_all = "snake_case")]
968 enum Field {
969 States,
970 ConditioningStates,
971 Parameters,
972 SampleStatistics,
973 SampleLogLikelihood,
974 Type,
975 }
976
977 struct CatCPDVisitor;
978
979 impl<'de> Visitor<'de> for CatCPDVisitor {
980 type Value = CatCPD;
981
982 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
983 formatter.write_str("struct CatCPD")
984 }
985
986 fn visit_map<V>(self, mut map: V) -> Result<CatCPD, V::Error>
987 where
988 V: MapAccess<'de>,
989 {
990 use serde::de::Error as E;
991
992 let mut states = None;
994 let mut conditioning_states = None;
995 let mut parameters = None;
996 let mut sample_statistics = None;
997 let mut sample_log_likelihood = None;
998 let mut type_ = None;
999
1000 while let Some(key) = map.next_key()? {
1002 match key {
1003 Field::States => {
1004 if states.is_some() {
1005 return Err(E::duplicate_field("states"));
1006 }
1007 states = Some(map.next_value()?);
1008 }
1009 Field::ConditioningStates => {
1010 if conditioning_states.is_some() {
1011 return Err(E::duplicate_field("conditioning_states"));
1012 }
1013 conditioning_states = Some(map.next_value()?);
1014 }
1015 Field::Parameters => {
1016 if parameters.is_some() {
1017 return Err(E::duplicate_field("parameters"));
1018 }
1019 parameters = Some(map.next_value()?);
1020 }
1021 Field::SampleStatistics => {
1022 if sample_statistics.is_some() {
1023 return Err(E::duplicate_field("sample_statistics"));
1024 }
1025 sample_statistics = Some(map.next_value()?);
1026 }
1027 Field::SampleLogLikelihood => {
1028 if sample_log_likelihood.is_some() {
1029 return Err(E::duplicate_field("sample_log_likelihood"));
1030 }
1031 sample_log_likelihood = Some(map.next_value()?);
1032 }
1033 Field::Type => {
1034 if type_.is_some() {
1035 return Err(E::duplicate_field("type"));
1036 }
1037 type_ = Some(map.next_value()?);
1038 }
1039 }
1040 }
1041
1042 let states = states.ok_or_else(|| E::missing_field("states"))?;
1044 let conditioning_states =
1045 conditioning_states.ok_or_else(|| E::missing_field("conditioning_states"))?;
1046 let parameters = parameters.ok_or_else(|| E::missing_field("parameters"))?;
1047
1048 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
1050 assert_eq!(type_, "catcpd", "Invalid type for CatCPD.");
1051
1052 let parameters: Vec<Vec<f64>> = parameters;
1054 let shape = (parameters.len(), parameters[0].len());
1055 let parameters = Array::from_iter(parameters.into_iter().flatten())
1056 .into_shape_with_order(shape)
1057 .map_err(|_| E::custom("Invalid parameters shape"))?;
1058
1059 Ok(CatCPD::with_optionals(
1060 states,
1061 conditioning_states,
1062 parameters,
1063 sample_statistics,
1064 sample_log_likelihood,
1065 ))
1066 }
1067 }
1068
1069 const FIELDS: &[&str] = &[
1070 "states",
1071 "conditioning_states",
1072 "parameters",
1073 "sample_statistics",
1074 "sample_log_likelihood",
1075 "type",
1076 ];
1077
1078 deserializer.deserialize_struct("CatCPD", FIELDS, CatCPDVisitor)
1079 }
1080}
1081
// Invokes the crate-level `impl_json_io!` macro for `CatCPD`.
// NOTE(review): presumably generates JSON read/write helpers on top of the serde
// implementations in this file — confirm against the macro definition.
impl_json_io!(CatCPD);