1use std::ops::{Add, AddAssign};
2
3use approx::{AbsDiffEq, RelativeEq, relative_eq};
4use itertools::Itertools;
5use ndarray::prelude::*;
6use serde::{
7 Deserialize, Deserializer, Serialize, Serializer,
8 de::{MapAccess, Visitor},
9 ser::SerializeMap,
10};
11
12use crate::{
13 datasets::CatSample,
14 impl_json_io,
15 models::{CIM, Labelled},
16 types::{EPSILON, Labels, Set, States},
17 utils::MI,
18};
19
/// Sample statistics for a categorical CIM (used as `CatCIM`'s `CIM::Statistics`).
///
/// Invariants are enforced by [`CatCIMS::new`].
#[derive(Clone, Debug)]
pub struct CatCIMS {
    /// Sample conditional counts; the second and third axes have equal length.
    n_xz: Array3<f64>,
    /// Sample conditional times; its two axes match the first two axes of `n_xz`.
    t_xz: Array2<f64>,
    /// Sample size (finite and non-negative).
    n: f64,
}
30
31impl CatCIMS {
32 #[inline]
45 pub fn new(n_xz: Array3<f64>, t_xz: Array2<f64>, n: f64) -> Self {
46 assert_eq!(
48 n_xz.shape()[1],
49 n_xz.shape()[2],
50 "The second and third dimensions of the conditional counts must be equal."
51 );
52 assert_eq!(
53 n_xz.shape()[0],
54 t_xz.shape()[0],
55 "The first dimension of the conditional counts must match \n
56 the first dimension of the conditional times."
57 );
58 assert_eq!(
59 n_xz.shape()[1],
60 t_xz.shape()[1],
61 "The second dimension of the conditional counts must match \n
62 the second dimension of the conditional times."
63 );
64 assert!(
65 n_xz.iter().all(|&x| x.is_finite() && x >= 0.),
66 "Conditional counts must be finite and non-negative."
67 );
68 assert!(
69 t_xz.iter().all(|&x| x.is_finite() && x >= 0.),
70 "Conditional times must be finite and non-negative."
71 );
72 assert!(
73 n.is_finite() && n >= 0.,
74 "Sample size must be finite and non-negative."
75 );
76
77 Self { n_xz, t_xz, n }
78 }
79
80 #[inline]
87 pub const fn sample_conditional_counts(&self) -> &Array3<f64> {
88 &self.n_xz
89 }
90
91 #[inline]
98 pub const fn sample_conditional_times(&self) -> &Array2<f64> {
99 &self.t_xz
100 }
101
102 #[inline]
109 pub const fn sample_size(&self) -> f64 {
110 self.n
111 }
112}
113
114impl AddAssign for CatCIMS {
115 fn add_assign(&mut self, other: Self) {
116 self.n_xz += &other.n_xz;
118 self.t_xz += &other.t_xz;
119 self.n += other.n;
120 }
121}
122
123impl Add for CatCIMS {
124 type Output = Self;
125
126 fn add(mut self, other: Self) -> Self::Output {
127 self += other;
128 self
129 }
130}
131
132impl Serialize for CatCIMS {
133 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134 where
135 S: Serializer,
136 {
137 let mut map = serializer.serialize_map(Some(3))?;
139
140 let sample_conditional_counts: Vec<Vec<Vec<f64>>> = self
142 .n_xz
143 .outer_iter()
144 .map(|sample_conditional_counts| {
145 sample_conditional_counts
146 .rows()
147 .into_iter()
148 .map(|x| x.to_vec())
149 .collect()
150 })
151 .collect();
152
153 map.serialize_entry("sample_conditional_counts", &sample_conditional_counts)?;
155
156 let sample_conditional_times: Vec<Vec<f64>> =
158 self.t_xz.rows().into_iter().map(|x| x.to_vec()).collect();
159
160 map.serialize_entry("sample_conditional_times", &sample_conditional_times)?;
162
163 map.serialize_entry("sample_size", &self.n)?;
165
166 map.end()
168 }
169}
170
171impl<'de> Deserialize<'de> for CatCIMS {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 #[derive(Deserialize)]
177 #[serde(field_identifier, rename_all = "snake_case")]
178 #[allow(clippy::enum_variant_names)]
179 enum Field {
180 SampleConditionalCounts,
181 SampleConditionalTimes,
182 SampleSize,
183 }
184
185 struct CatCIMSVisitor;
186
187 impl<'de> Visitor<'de> for CatCIMSVisitor {
188 type Value = CatCIMS;
189
190 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
191 formatter.write_str("struct CatCIMS")
192 }
193
194 fn visit_map<V>(self, mut map: V) -> Result<CatCIMS, V::Error>
195 where
196 V: MapAccess<'de>,
197 {
198 use serde::de::Error as E;
199
200 let mut sample_conditional_counts = None;
202 let mut sample_conditional_times = None;
203 let mut sample_size = None;
204
205 while let Some(key) = map.next_key()? {
207 match key {
208 Field::SampleConditionalCounts => {
209 if sample_conditional_counts.is_some() {
210 return Err(E::duplicate_field("sample_conditional_counts"));
211 }
212 sample_conditional_counts = Some(map.next_value()?);
213 }
214 Field::SampleConditionalTimes => {
215 if sample_conditional_times.is_some() {
216 return Err(E::duplicate_field("sample_conditional_times"));
217 }
218 sample_conditional_times = Some(map.next_value()?);
219 }
220 Field::SampleSize => {
221 if sample_size.is_some() {
222 return Err(E::duplicate_field("sample_size"));
223 }
224 sample_size = Some(map.next_value()?);
225 }
226 }
227 }
228
229 let sample_conditional_counts = sample_conditional_counts
231 .ok_or_else(|| E::missing_field("sample_conditional_counts"))?;
232 let sample_conditional_times = sample_conditional_times
233 .ok_or_else(|| E::missing_field("sample_conditional_times"))?;
234 let sample_size = sample_size.ok_or_else(|| E::missing_field("sample_size"))?;
235
236 let sample_conditional_counts = {
238 let counts: Vec<Vec<Vec<f64>>> = sample_conditional_counts;
239 let shape = (counts.len(), counts[0].len(), counts[0][0].len());
240 let counts = counts.into_iter().flatten().flatten();
241 Array::from_iter(counts)
242 .into_shape_with_order(shape)
243 .map_err(|_| E::custom("Invalid sample conditional counts shape"))?
244 };
245
246 let sample_conditional_times = {
248 let times: Vec<Vec<f64>> = sample_conditional_times;
249 let shape = (times.len(), times[0].len());
250 let times = times.into_iter().flatten();
251 Array::from_iter(times)
252 .into_shape_with_order(shape)
253 .map_err(|_| E::custom("Invalid sample conditional times shape"))?
254 };
255
256 Ok(CatCIMS::new(
257 sample_conditional_counts,
258 sample_conditional_times,
259 sample_size,
260 ))
261 }
262 }
263
264 const FIELDS: &[&str] = &[
265 "sample_conditional_counts",
266 "sample_conditional_times",
267 "sample_size",
268 ];
269
270 deserializer.deserialize_struct("CatCIMS", FIELDS, CatCIMSVisitor)
271 }
272}
273
/// Categorical CIM: one intensity (Q) matrix per configuration of the conditioning variables.
///
/// Each Q matrix is validated by [`CatCIM::new`]:
///
/// * it is square;
/// * its values are finite;
/// * its diagonal is non-positive;
/// * its off-diagonal entries are non-negative;
/// * its rows sum to zero.
#[derive(Clone, Debug)]
pub struct CatCIM {
    /// Labels of the variables, stored in sorted order.
    labels: Labels,
    /// States of each variable, keyed by label; keys and state sets are stored sorted.
    states: States,
    /// Number of states of each variable, in label order.
    shape: Array1<usize>,
    /// Multi-index built from `shape` (`MI::new(shape)`).
    multi_index: MI,
    /// Labels of the conditioning variables, stored in sorted order and disjoint from `labels`.
    conditioning_labels: Labels,
    /// States of each conditioning variable, keyed by label; stored sorted.
    conditioning_states: States,
    /// Number of states of each conditioning variable, in label order.
    conditioning_shape: Array1<usize>,
    /// Multi-index built from `conditioning_shape`.
    conditioning_multi_index: MI,
    /// Q matrices stacked along axis 0 (one per conditioning configuration).
    ///
    /// Unless empty, axis 0 has length `conditioning_shape.product()` and axes 1 and 2
    /// have length `shape.product()`.
    parameters: Array3<f64>,
    /// Number of free parameters: `dim0 * dim1 * (dim2 - 1)` (saturating at 0).
    parameters_size: usize,
    /// Optional sample statistics, shaped like `parameters`; set via `with_optionals`.
    sample_statistics: Option<CatCIMS>,
    /// Optional sample log-likelihood; finite when present (checked in `with_optionals`).
    sample_log_likelihood: Option<f64>,
}
294
295impl CatCIM {
296 pub fn new(states: States, conditioning_states: States, parameters: Array3<f64>) -> Self {
315 let labels: Set<_> = states.keys().cloned().collect();
317 let conditioning_labels: Set<_> = conditioning_states.keys().cloned().collect();
319
320 assert!(
322 labels.is_disjoint(&conditioning_labels),
323 "Labels and conditioning labels must be disjoint."
324 );
325
326 let shape = Array::from_iter(states.values().map(Set::len));
328
329 assert!(
331 parameters.is_empty() || parameters.shape()[1] == shape.product(),
332 "Product of the number of states must match the number of columns: \n\
333 \t expected: parameters.shape[1] == {} , \n\
334 \t found: parameters.shape[1] == {} .",
335 shape.product(),
336 parameters.shape()[1],
337 );
338
339 assert!(
341 parameters.is_empty() || parameters.shape()[2] == shape.product(),
342 "Product of the number of states must match the number of columns: \n\
343 \t expected: parameters.shape[2] == {} , \n\
344 \t found: parameters.shape[2] == {} .",
345 shape.product(),
346 parameters.shape()[2],
347 );
348
349 let conditioning_shape = Array::from_iter(conditioning_states.values().map(Set::len));
351
352 assert!(
354 parameters.is_empty() || parameters.shape()[0] == conditioning_shape.product(),
355 "Product of the number of conditioning states must match the number of rows: \n\
356 \t expected: parameters.shape[0] == {} , \n\
357 \t found: parameters.shape[0] == {} .",
358 conditioning_shape.product(),
359 parameters.shape()[0],
360 );
361
362 parameters.outer_iter().for_each(|q| {
364 assert!(q.is_square(), "Q must be square.");
366 assert!(
368 q.iter().all(|&x| x.is_finite()),
369 "Q must have finite values."
370 );
371 assert!(
373 q.diag().iter().all(|&x| x <= 0.),
374 "Q diagonal must be non-positive."
375 );
376 assert!(
378 q.indexed_iter().all(|((i, j), &x)| i == j || x >= 0.),
379 "Q off-diagonal must be non-negative."
380 );
381 assert!(
383 q.rows()
384 .into_iter()
385 .all(|x| relative_eq!(x.sum(), 0., epsilon = EPSILON)),
386 "Q rows must sum to zero."
387 );
388 });
389
390 let mut parameters = parameters;
392
393 let mut labels = labels;
395 let mut states = states;
396 let mut shape = shape;
397
398 if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
400 let mut sorted_states_idx: Vec<_> = states.values().multi_cartesian_product().collect();
402 let mut sorted_labels_idx: Vec<_> = (0..labels.len()).collect();
404 sorted_labels_idx.sort_by_key(|&i| &labels[i]);
406 sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
408 *sorted_states_idx = sorted_labels_idx
409 .iter()
410 .map(|&i| sorted_states_idx[i])
411 .collect();
412 });
413 let mut sorted_row_idx: Vec<_> = (0..parameters.shape()[1]).collect();
415 sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
417 states.sort_keys();
419 states.values_mut().for_each(Set::sort);
420 labels = states.keys().cloned().collect();
421 shape = states.values().map(Set::len).collect();
422 let mut new_parameters = parameters.clone();
424 new_parameters.axis_iter_mut(Axis(1)).enumerate().for_each(
426 |(i, mut new_parameters_axis)| {
427 new_parameters_axis.assign(¶meters.index_axis(Axis(1), sorted_row_idx[i]));
429 },
430 );
431 parameters = new_parameters;
433 let mut new_parameters = parameters.clone();
435 new_parameters.axis_iter_mut(Axis(2)).enumerate().for_each(
437 |(i, mut new_parameters_axis)| {
438 new_parameters_axis.assign(¶meters.index_axis(Axis(2), sorted_row_idx[i]));
440 },
441 );
442 parameters = new_parameters;
444 }
445
446 let labels = labels;
448 let states = states;
449 let shape = shape;
450
451 let mut conditioning_labels = conditioning_labels;
453 let mut conditioning_states = conditioning_states;
454 let mut conditioning_shape = conditioning_shape;
455
456 if !conditioning_states.keys().is_sorted()
458 || !conditioning_states.values().all(|x| x.iter().is_sorted())
459 {
460 let mut sorted_states_idx: Vec<_> = conditioning_states
462 .values()
463 .multi_cartesian_product()
464 .collect();
465 let mut sorted_labels_idx: Vec<_> = (0..conditioning_labels.len()).collect();
467 sorted_labels_idx.sort_by_key(|&i| &conditioning_labels[i]);
469 sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
471 *sorted_states_idx = sorted_labels_idx
472 .iter()
473 .map(|&i| sorted_states_idx[i])
474 .collect();
475 });
476 let mut sorted_row_idx: Vec<_> = (0..parameters.shape()[0]).collect();
478 sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
480 conditioning_states.sort_keys();
482 conditioning_states.values_mut().for_each(Set::sort);
483 conditioning_labels = conditioning_states.keys().cloned().collect();
484 conditioning_shape = conditioning_states.values().map(Set::len).collect();
485 let mut new_parameters = parameters.clone();
487 new_parameters.axis_iter_mut(Axis(0)).enumerate().for_each(
489 |(i, mut new_parameters_axis)| {
490 new_parameters_axis.assign(¶meters.index_axis(Axis(0), sorted_row_idx[i]));
492 },
493 );
494 parameters = new_parameters;
496 }
497
498 let conditioning_labels = conditioning_labels;
500 let conditioning_states = conditioning_states;
501 let conditioning_shape = conditioning_shape;
502
503 let parameters = parameters;
505
506 debug_assert!(labels.iter().is_sorted(), "Labels must be sorted.");
508 debug_assert!(states.keys().is_sorted(), "Labels must be sorted.");
509 debug_assert!(
510 states.values().all(|x| x.iter().is_sorted()),
511 "States must be sorted."
512 );
513 debug_assert!(
514 conditioning_labels.iter().is_sorted(),
515 "Conditioning labels must be sorted."
516 );
517 debug_assert!(
518 conditioning_states.keys().is_sorted(),
519 "Conditioning labels must be sorted."
520 );
521 debug_assert!(
522 conditioning_states.values().all(|x| x.iter().is_sorted()),
523 "Conditioning states must be sorted."
524 );
525
526 let multi_index = MI::new(shape.clone());
528 let conditioning_multi_index = MI::new(conditioning_shape.clone());
530
531 let s = parameters.shape();
533 let parameters_size = s[0] * s[1] * s[2].saturating_sub(1);
535
536 Self {
537 labels,
538 states,
539 shape,
540 multi_index,
541 conditioning_labels,
542 conditioning_states,
543 conditioning_shape,
544 conditioning_multi_index,
545 parameters,
546 parameters_size,
547 sample_statistics: None,
548 sample_log_likelihood: None,
549 }
550 }
551
552 #[inline]
559 pub const fn states(&self) -> &States {
560 &self.states
561 }
562
563 #[inline]
570 pub const fn shape(&self) -> &Array1<usize> {
571 &self.shape
572 }
573
574 #[inline]
581 pub const fn multi_index(&self) -> &MI {
582 &self.multi_index
583 }
584
585 #[inline]
592 pub const fn conditioning_states(&self) -> &States {
593 &self.conditioning_states
594 }
595
596 #[inline]
603 pub const fn conditioning_shape(&self) -> &Array1<usize> {
604 &self.conditioning_shape
605 }
606
607 #[inline]
614 pub const fn conditioning_multi_index(&self) -> &MI {
615 &self.conditioning_multi_index
616 }
617
618 pub fn with_optionals(
636 states: States,
637 conditioning_states: States,
638 parameters: Array3<f64>,
639 sample_statistics: Option<CatCIMS>,
640 sample_log_likelihood: Option<f64>,
641 ) -> Self {
642 if let Some(sample_statistics) = &sample_statistics {
644 let sample_conditional_counts = &sample_statistics.n_xz;
646 assert!(
648 sample_conditional_counts.shape() == parameters.shape(),
649 "Sample conditional counts must have the same shape as parameters: \n\
650 \t expected: sample_conditional_counts.shape() == {:?} , \n\
651 \t found: sample_conditional_counts.shape() == {:?} .",
652 parameters.shape(),
653 sample_conditional_counts.shape(),
654 );
655 }
656 if let Some(sample_log_likelihood) = &sample_log_likelihood {
658 assert!(
659 sample_log_likelihood.is_finite(),
660 "Sample log-likelihood must be finite: \n\
661 \t expected: sample_ll is finite, \n\
662 \t found: sample_ll is {sample_log_likelihood} ."
663 )
664 }
665
666 let mut cim = Self::new(states, conditioning_states, parameters);
668
669 cim.sample_statistics = sample_statistics;
673 cim.sample_log_likelihood = sample_log_likelihood;
674
675 cim
676 }
677}
678
679impl PartialEq for CatCIM {
680 fn eq(&self, other: &Self) -> bool {
681 self.labels.eq(&other.labels)
683 && self.states.eq(&other.states)
684 && self.shape.eq(&other.shape)
685 && self.conditioning_labels.eq(&other.conditioning_labels)
686 && self.conditioning_states.eq(&other.conditioning_states)
687 && self.conditioning_shape.eq(&other.conditioning_shape)
688 && self.multi_index.eq(&other.multi_index)
689 && self.parameters.eq(&other.parameters)
690 }
691}
692
693impl AbsDiffEq for CatCIM {
694 type Epsilon = f64;
695
696 fn default_epsilon() -> Self::Epsilon {
697 Self::Epsilon::default_epsilon()
698 }
699
700 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
701 self.labels.eq(&other.labels)
703 && self.states.eq(&other.states)
704 && self.shape.eq(&other.shape)
705 && self.conditioning_labels.eq(&other.conditioning_labels)
706 && self.conditioning_states.eq(&other.conditioning_states)
707 && self.conditioning_shape.eq(&other.conditioning_shape)
708 && self.multi_index.eq(&other.multi_index)
709 && self.parameters.abs_diff_eq(&other.parameters, epsilon)
710 }
711}
712
713impl RelativeEq for CatCIM {
714 fn default_max_relative() -> Self::Epsilon {
715 Self::Epsilon::default_max_relative()
716 }
717
718 fn relative_eq(
719 &self,
720 other: &Self,
721 epsilon: Self::Epsilon,
722 max_relative: Self::Epsilon,
723 ) -> bool {
724 self.labels.eq(&other.labels)
726 && self.states.eq(&other.states)
727 && self.shape.eq(&other.shape)
728 && self.conditioning_labels.eq(&other.conditioning_labels)
729 && self.conditioning_states.eq(&other.conditioning_states)
730 && self.conditioning_shape.eq(&other.conditioning_shape)
731 && self.multi_index.eq(&other.multi_index)
732 && self
733 .parameters
734 .relative_eq(&other.parameters, epsilon, max_relative)
735 }
736}
737
738impl Labelled for CatCIM {
739 #[inline]
740 fn labels(&self) -> &Labels {
741 &self.labels
742 }
743}
744
745impl CIM for CatCIM {
746 type Support = CatSample;
747 type Parameters = Array3<f64>;
748 type Statistics = CatCIMS;
749
750 #[inline]
751 fn conditioning_labels(&self) -> &Labels {
752 &self.conditioning_labels
753 }
754
755 #[inline]
756 fn parameters(&self) -> &Self::Parameters {
757 &self.parameters
758 }
759
760 #[inline]
761 fn parameters_size(&self) -> usize {
762 self.parameters_size
763 }
764
765 #[inline]
766 fn sample_statistics(&self) -> Option<&Self::Statistics> {
767 self.sample_statistics.as_ref()
768 }
769
770 #[inline]
771 fn sample_log_likelihood(&self) -> Option<f64> {
772 self.sample_log_likelihood
773 }
774}
775
776impl Serialize for CatCIM {
777 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
778 where
779 S: Serializer,
780 {
781 let mut size = 4;
783 size += self.sample_statistics.is_some() as usize;
784 size += self.sample_log_likelihood.is_some() as usize;
785
786 let mut map = serializer.serialize_map(Some(size))?;
788
789 map.serialize_entry("states", &self.states)?;
791 map.serialize_entry("conditioning_states", &self.conditioning_states)?;
793
794 let parameters: Vec<Vec<Vec<f64>>> = self
796 .parameters
797 .outer_iter()
798 .map(|parameters| parameters.rows().into_iter().map(|x| x.to_vec()).collect())
799 .collect();
800
801 map.serialize_entry("parameters", ¶meters)?;
803
804 if let Some(sample_statistics) = &self.sample_statistics {
806 map.serialize_entry("sample_statistics", &sample_statistics)?;
807 }
808 if let Some(sample_log_likelihood) = self.sample_log_likelihood {
810 map.serialize_entry("sample_log_likelihood", &sample_log_likelihood)?;
811 }
812
813 map.serialize_entry("type", "catcim")?;
815
816 map.end()
818 }
819}
820
821impl<'de> Deserialize<'de> for CatCIM {
822 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
823 where
824 D: Deserializer<'de>,
825 {
826 #[derive(Deserialize)]
827 #[serde(field_identifier, rename_all = "snake_case")]
828 enum Field {
829 States,
830 ConditioningStates,
831 Parameters,
832 SampleStatistics,
833 SampleLogLikelihood,
834 Type,
835 }
836
837 struct CatCIMVisitor;
838
839 impl<'de> Visitor<'de> for CatCIMVisitor {
840 type Value = CatCIM;
841
842 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
843 formatter.write_str("struct CatCIM")
844 }
845
846 fn visit_map<V>(self, mut map: V) -> Result<CatCIM, V::Error>
847 where
848 V: MapAccess<'de>,
849 {
850 use serde::de::Error as E;
851
852 let mut states = None;
854 let mut conditioning_states = None;
855 let mut parameters = None;
856 let mut sample_statistics = None;
857 let mut sample_log_likelihood = None;
858 let mut type_ = None;
859
860 while let Some(key) = map.next_key()? {
862 match key {
863 Field::States => {
864 if states.is_some() {
865 return Err(E::duplicate_field("states"));
866 }
867 states = Some(map.next_value()?);
868 }
869 Field::ConditioningStates => {
870 if conditioning_states.is_some() {
871 return Err(E::duplicate_field("conditioning_states"));
872 }
873 conditioning_states = Some(map.next_value()?);
874 }
875 Field::Parameters => {
876 if parameters.is_some() {
877 return Err(E::duplicate_field("parameters"));
878 }
879 parameters = Some(map.next_value()?);
880 }
881 Field::SampleStatistics => {
882 if sample_statistics.is_some() {
883 return Err(E::duplicate_field("sample_statistics"));
884 }
885 sample_statistics = Some(map.next_value()?);
886 }
887 Field::SampleLogLikelihood => {
888 if sample_log_likelihood.is_some() {
889 return Err(E::duplicate_field("sample_log_likelihood"));
890 }
891 sample_log_likelihood = Some(map.next_value()?);
892 }
893 Field::Type => {
894 if type_.is_some() {
895 return Err(E::duplicate_field("type"));
896 }
897 type_ = Some(map.next_value()?);
898 }
899 }
900 }
901
902 let states = states.ok_or_else(|| E::missing_field("states"))?;
904 let conditioning_states =
905 conditioning_states.ok_or_else(|| E::missing_field("conditioning_states"))?;
906 let parameters = parameters.ok_or_else(|| E::missing_field("parameters"))?;
907
908 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
910 assert_eq!(type_, "catcim", "Invalid type for CatCIM.");
911
912 let parameters: Vec<Vec<Vec<f64>>> = parameters;
914 let shape = (
915 parameters.len(),
916 parameters[0].len(),
917 parameters[0][0].len(),
918 );
919 let parameters = parameters.into_iter().flatten().flatten();
920 let parameters = Array::from_iter(parameters)
921 .into_shape_with_order(shape)
922 .map_err(|_| E::custom("Invalid parameters shape"))?;
923
924 Ok(CatCIM::with_optionals(
925 states,
926 conditioning_states,
927 parameters,
928 sample_statistics,
929 sample_log_likelihood,
930 ))
931 }
932 }
933
934 const FIELDS: &[&str] = &[
935 "states",
936 "conditioning_states",
937 "parameters",
938 "sample_statistics",
939 "sample_log_likelihood",
940 "type",
941 ];
942
943 deserializer.deserialize_struct("CatCIM", FIELDS, CatCIMVisitor)
944 }
945}
946
// Expands the crate's `impl_json_io` macro for `CatCIM`. This presumably provides
// JSON read/write helpers built on the `Serialize`/`Deserialize` impls in this file.
// TODO(review): confirm against the macro definition.
impl_json_io!(CatCIM);