// causal_hub/models/continuous_time_bayesian_network/categorical/parameters.rs

1use std::ops::{Add, AddAssign};
2
3use approx::{AbsDiffEq, RelativeEq, relative_eq};
4use itertools::Itertools;
5use ndarray::prelude::*;
6use serde::{
7    Deserialize, Deserializer, Serialize, Serializer,
8    de::{MapAccess, Visitor},
9    ser::SerializeMap,
10};
11
12use crate::{
13    datasets::CatSample,
14    impl_json_io,
15    models::{CIM, Labelled},
16    types::{EPSILON, Labels, Set, States},
17    utils::MI,
18};
19
20/// Sample (sufficient) statistics for a categorical CIM.
21#[derive(Clone, Debug)]
22pub struct CatCIMS {
23    /// Conditional counts |Z| x |X| x |X|.
24    n_xz: Array3<f64>,
25    /// Conditional times |Z| x |X|.
26    t_xz: Array2<f64>,
27    /// Sample size.
28    n: f64,
29}
30
31impl CatCIMS {
32    /// Creates a new sample (sufficient) statistics for the categorical CIM.
33    ///
34    /// # Arguments
35    ///
36    /// * `n_xz` - Conditional counts |Z| x |X| x |X|.
37    /// * `t_xz` - Conditional times |Z| x |X|.
38    /// * `n` - Sample size.
39    ///
40    /// # Returns
41    ///
42    /// A new sample (sufficient) statistics for the categorical CIM.
43    ///
44    #[inline]
45    pub fn new(n_xz: Array3<f64>, t_xz: Array2<f64>, n: f64) -> Self {
46        // Assert the dimensions are correct.
47        assert_eq!(
48            n_xz.shape()[1],
49            n_xz.shape()[2],
50            "The second and third dimensions of the conditional counts must be equal."
51        );
52        assert_eq!(
53            n_xz.shape()[0],
54            t_xz.shape()[0],
55            "The first dimension of the conditional counts must match \n
56            the first dimension of the conditional times."
57        );
58        assert_eq!(
59            n_xz.shape()[1],
60            t_xz.shape()[1],
61            "The second dimension of the conditional counts must match \n
62            the second dimension of the conditional times."
63        );
64        assert!(
65            n_xz.iter().all(|&x| x.is_finite() && x >= 0.),
66            "Conditional counts must be finite and non-negative."
67        );
68        assert!(
69            t_xz.iter().all(|&x| x.is_finite() && x >= 0.),
70            "Conditional times must be finite and non-negative."
71        );
72        assert!(
73            n.is_finite() && n >= 0.,
74            "Sample size must be finite and non-negative."
75        );
76
77        Self { n_xz, t_xz, n }
78    }
79
80    /// Returns the sample conditional counts |Z| x |X| x |X|.
81    ///
82    /// # Returns
83    ///
84    /// The sample conditional counts |Z| x |X| x |X|.
85    ///
86    #[inline]
87    pub const fn sample_conditional_counts(&self) -> &Array3<f64> {
88        &self.n_xz
89    }
90
91    /// Returns the sample conditional times |Z| x |X|.
92    ///
93    /// # Returns
94    ///
95    /// The sample conditional times |Z| x |X|.
96    ///
97    #[inline]
98    pub const fn sample_conditional_times(&self) -> &Array2<f64> {
99        &self.t_xz
100    }
101
102    /// Returns the sample size.
103    ///
104    /// # Returns
105    ///
106    /// The sample size.
107    ///
108    #[inline]
109    pub const fn sample_size(&self) -> f64 {
110        self.n
111    }
112}
113
114impl AddAssign for CatCIMS {
115    fn add_assign(&mut self, other: Self) {
116        // Add the counts and times.
117        self.n_xz += &other.n_xz;
118        self.t_xz += &other.t_xz;
119        self.n += other.n;
120    }
121}
122
123impl Add for CatCIMS {
124    type Output = Self;
125
126    fn add(mut self, other: Self) -> Self::Output {
127        self += other;
128        self
129    }
130}
131
132impl Serialize for CatCIMS {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: Serializer,
136    {
137        // Allocate the map.
138        let mut map = serializer.serialize_map(Some(3))?;
139
140        // Convert the sample conditional counts to a flat format.
141        let sample_conditional_counts: Vec<Vec<Vec<f64>>> = self
142            .n_xz
143            .outer_iter()
144            .map(|sample_conditional_counts| {
145                sample_conditional_counts
146                    .rows()
147                    .into_iter()
148                    .map(|x| x.to_vec())
149                    .collect()
150            })
151            .collect();
152
153        // Serialize sample conditional counts.
154        map.serialize_entry("sample_conditional_counts", &sample_conditional_counts)?;
155
156        // Convert the sample conditional times to a flat format.
157        let sample_conditional_times: Vec<Vec<f64>> =
158            self.t_xz.rows().into_iter().map(|x| x.to_vec()).collect();
159
160        // Serialize sample conditional times.
161        map.serialize_entry("sample_conditional_times", &sample_conditional_times)?;
162
163        // Serialize sample size.
164        map.serialize_entry("sample_size", &self.n)?;
165
166        // Finalize the map serialization.
167        map.end()
168    }
169}
170
171impl<'de> Deserialize<'de> for CatCIMS {
172    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173    where
174        D: Deserializer<'de>,
175    {
176        #[derive(Deserialize)]
177        #[serde(field_identifier, rename_all = "snake_case")]
178        #[allow(clippy::enum_variant_names)]
179        enum Field {
180            SampleConditionalCounts,
181            SampleConditionalTimes,
182            SampleSize,
183        }
184
185        struct CatCIMSVisitor;
186
187        impl<'de> Visitor<'de> for CatCIMSVisitor {
188            type Value = CatCIMS;
189
190            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
191                formatter.write_str("struct CatCIMS")
192            }
193
194            fn visit_map<V>(self, mut map: V) -> Result<CatCIMS, V::Error>
195            where
196                V: MapAccess<'de>,
197            {
198                use serde::de::Error as E;
199
200                // Allocate fields
201                let mut sample_conditional_counts = None;
202                let mut sample_conditional_times = None;
203                let mut sample_size = None;
204
205                // Parse the map.
206                while let Some(key) = map.next_key()? {
207                    match key {
208                        Field::SampleConditionalCounts => {
209                            if sample_conditional_counts.is_some() {
210                                return Err(E::duplicate_field("sample_conditional_counts"));
211                            }
212                            sample_conditional_counts = Some(map.next_value()?);
213                        }
214                        Field::SampleConditionalTimes => {
215                            if sample_conditional_times.is_some() {
216                                return Err(E::duplicate_field("sample_conditional_times"));
217                            }
218                            sample_conditional_times = Some(map.next_value()?);
219                        }
220                        Field::SampleSize => {
221                            if sample_size.is_some() {
222                                return Err(E::duplicate_field("sample_size"));
223                            }
224                            sample_size = Some(map.next_value()?);
225                        }
226                    }
227                }
228
229                // Check all fields are present.
230                let sample_conditional_counts = sample_conditional_counts
231                    .ok_or_else(|| E::missing_field("sample_conditional_counts"))?;
232                let sample_conditional_times = sample_conditional_times
233                    .ok_or_else(|| E::missing_field("sample_conditional_times"))?;
234                let sample_size = sample_size.ok_or_else(|| E::missing_field("sample_size"))?;
235
236                // Convert sample conditional counts to ndarray.
237                let sample_conditional_counts = {
238                    let counts: Vec<Vec<Vec<f64>>> = sample_conditional_counts;
239                    let shape = (counts.len(), counts[0].len(), counts[0][0].len());
240                    let counts = counts.into_iter().flatten().flatten();
241                    Array::from_iter(counts)
242                        .into_shape_with_order(shape)
243                        .map_err(|_| E::custom("Invalid sample conditional counts shape"))?
244                };
245
246                // Convert sample conditional times to ndarray.
247                let sample_conditional_times = {
248                    let times: Vec<Vec<f64>> = sample_conditional_times;
249                    let shape = (times.len(), times[0].len());
250                    let times = times.into_iter().flatten();
251                    Array::from_iter(times)
252                        .into_shape_with_order(shape)
253                        .map_err(|_| E::custom("Invalid sample conditional times shape"))?
254                };
255
256                Ok(CatCIMS::new(
257                    sample_conditional_counts,
258                    sample_conditional_times,
259                    sample_size,
260                ))
261            }
262        }
263
264        const FIELDS: &[&str] = &[
265            "sample_conditional_counts",
266            "sample_conditional_times",
267            "sample_size",
268        ];
269
270        deserializer.deserialize_struct("CatCIMS", FIELDS, CatCIMSVisitor)
271    }
272}
273
274/// A struct representing a categorical conditional intensity matrix.
275#[derive(Clone, Debug)]
276pub struct CatCIM {
277    // Labels of the conditioned variable.
278    labels: Labels,
279    states: States,
280    shape: Array1<usize>,
281    multi_index: MI,
282    // Labels of the conditioning variables.
283    conditioning_labels: Labels,
284    conditioning_states: States,
285    conditioning_shape: Array1<usize>,
286    conditioning_multi_index: MI,
287    // Parameters.
288    parameters: Array3<f64>,
289    parameters_size: usize,
290    // Sample (sufficient) statistics, if any.
291    sample_statistics: Option<CatCIMS>,
292    sample_log_likelihood: Option<f64>,
293}
294
295impl CatCIM {
296    /// Creates a new categorical conditional intensity matrix.
297    ///
298    /// # Arguments
299    ///
300    /// * `states` - The variables states.
301    /// * `parameters` - The intensity matrices of the states.
302    ///
303    /// # Panics
304    ///
305    /// * If the labels and conditioning labels are not disjoint.
306    /// * If the product of the shape of the states does not match the length of the second and third axis.
307    /// * If the product of the shape of the conditioning states does not match the length of the first axis.
308    /// * If the parameters are not valid intensity matrices, unless empty.
309    ///
310    /// # Returns
311    ///
312    /// A new `CatCIM` instance.
313    ///
314    pub fn new(states: States, conditioning_states: States, parameters: Array3<f64>) -> Self {
315        // Get the labels of the variables.
316        let labels: Set<_> = states.keys().cloned().collect();
317        // Get the labels of the variables.
318        let conditioning_labels: Set<_> = conditioning_states.keys().cloned().collect();
319
320        // Assert labels and conditioning labels are disjoint.
321        assert!(
322            labels.is_disjoint(&conditioning_labels),
323            "Labels and conditioning labels must be disjoint."
324        );
325
326        // Get the states shape.
327        let shape = Array::from_iter(states.values().map(Set::len));
328
329        // Check that the product of the shape matches the number of columns.
330        assert!(
331            parameters.is_empty() || parameters.shape()[1] == shape.product(),
332            "Product of the number of states must match the number of columns: \n\
333            \t expected:    parameters.shape[1] == {} , \n\
334            \t found:       parameters.shape[1] == {} .",
335            shape.product(),
336            parameters.shape()[1],
337        );
338
339        // Check that the product of the shape matches the number of columns.
340        assert!(
341            parameters.is_empty() || parameters.shape()[2] == shape.product(),
342            "Product of the number of states must match the number of columns: \n\
343            \t expected:    parameters.shape[2] == {} , \n\
344            \t found:       parameters.shape[2] == {} .",
345            shape.product(),
346            parameters.shape()[2],
347        );
348
349        // Get the shape of the set of states.
350        let conditioning_shape = Array::from_iter(conditioning_states.values().map(Set::len));
351
352        // Check that the product of the conditioning shape matches the number of rows.
353        assert!(
354            parameters.is_empty() || parameters.shape()[0] == conditioning_shape.product(),
355            "Product of the number of conditioning states must match the number of rows: \n\
356            \t expected:    parameters.shape[0] == {} , \n\
357            \t found:       parameters.shape[0] == {} .",
358            conditioning_shape.product(),
359            parameters.shape()[0],
360        );
361
362        // Check parameters validity.
363        parameters.outer_iter().for_each(|q| {
364            // Assert Q is square.
365            assert!(q.is_square(), "Q must be square.");
366            // Assert Q has finite values.
367            assert!(
368                q.iter().all(|&x| x.is_finite()),
369                "Q must have finite values."
370            );
371            // Assert Q has non-positive diagonal.
372            assert!(
373                q.diag().iter().all(|&x| x <= 0.),
374                "Q diagonal must be non-positive."
375            );
376            // Assert Q has non-negative off-diagonal.
377            assert!(
378                q.indexed_iter().all(|((i, j), &x)| i == j || x >= 0.),
379                "Q off-diagonal must be non-negative."
380            );
381            // Assert Q rows sum to zero.
382            assert!(
383                q.rows()
384                    .into_iter()
385                    .all(|x| relative_eq!(x.sum(), 0., epsilon = EPSILON)),
386                "Q rows must sum to zero."
387            );
388        });
389
390        // Make parameters mutable.
391        let mut parameters = parameters;
392
393        // Make states mutable.
394        let mut labels = labels;
395        let mut states = states;
396        let mut shape = shape;
397
398        // Check if states are sorted.
399        if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
400            // Compute the current states order.
401            let mut sorted_states_idx: Vec<_> = states.values().multi_cartesian_product().collect();
402            // Sort the labels.
403            let mut sorted_labels_idx: Vec<_> = (0..labels.len()).collect();
404            // Sort the labels.
405            sorted_labels_idx.sort_by_key(|&i| &labels[i]);
406            // Sort the states by the labels.
407            sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
408                *sorted_states_idx = sorted_labels_idx
409                    .iter()
410                    .map(|&i| sorted_states_idx[i])
411                    .collect();
412            });
413            // Initialize the sorted row indices.
414            let mut sorted_row_idx: Vec<_> = (0..parameters.shape()[1]).collect();
415            // Sort the row indices.
416            sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
417            // Sort the labels.
418            states.sort_keys();
419            states.values_mut().for_each(Set::sort);
420            labels = states.keys().cloned().collect();
421            shape = states.values().map(Set::len).collect();
422            // Allocate new parameters, for axis 1.
423            let mut new_parameters = parameters.clone();
424            // Sort the values by multi indices.
425            new_parameters.axis_iter_mut(Axis(1)).enumerate().for_each(
426                |(i, mut new_parameters_axis)| {
427                    // Assign the sorted values to the new values array.
428                    new_parameters_axis.assign(&parameters.index_axis(Axis(1), sorted_row_idx[i]));
429                },
430            );
431            // Update the values with the new sorted values.
432            parameters = new_parameters;
433            // Allocate new parameters, for axis 2.
434            let mut new_parameters = parameters.clone();
435            // Sort the values by multi indices.
436            new_parameters.axis_iter_mut(Axis(2)).enumerate().for_each(
437                |(i, mut new_parameters_axis)| {
438                    // Assign the sorted values to the new values array.
439                    new_parameters_axis.assign(&parameters.index_axis(Axis(2), sorted_row_idx[i]));
440                },
441            );
442            // Update the values with the new sorted values.
443            parameters = new_parameters;
444        }
445
446        // Make states immutable.
447        let labels = labels;
448        let states = states;
449        let shape = shape;
450
451        // Make conditioning states mutable.
452        let mut conditioning_labels = conditioning_labels;
453        let mut conditioning_states = conditioning_states;
454        let mut conditioning_shape = conditioning_shape;
455
456        // Check if conditioning states are sorted.
457        if !conditioning_states.keys().is_sorted()
458            || !conditioning_states.values().all(|x| x.iter().is_sorted())
459        {
460            // Compute the current states order.
461            let mut sorted_states_idx: Vec<_> = conditioning_states
462                .values()
463                .multi_cartesian_product()
464                .collect();
465            // Sort the conditioning labels.
466            let mut sorted_labels_idx: Vec<_> = (0..conditioning_labels.len()).collect();
467            // Sort the conditioning labels.
468            sorted_labels_idx.sort_by_key(|&i| &conditioning_labels[i]);
469            // Sort the conditioning states by the labels.
470            sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
471                *sorted_states_idx = sorted_labels_idx
472                    .iter()
473                    .map(|&i| sorted_states_idx[i])
474                    .collect();
475            });
476            // Initialize the sorted row indices.
477            let mut sorted_row_idx: Vec<_> = (0..parameters.shape()[0]).collect();
478            // Sort the row indices.
479            sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
480            // Sort the labels.
481            conditioning_states.sort_keys();
482            conditioning_states.values_mut().for_each(Set::sort);
483            conditioning_labels = conditioning_states.keys().cloned().collect();
484            conditioning_shape = conditioning_states.values().map(Set::len).collect();
485            // Allocate new parameters.
486            let mut new_parameters = parameters.clone();
487            // Sort the values by multi indices.
488            new_parameters.axis_iter_mut(Axis(0)).enumerate().for_each(
489                |(i, mut new_parameters_axis)| {
490                    // Assign the sorted values to the new values array.
491                    new_parameters_axis.assign(&parameters.index_axis(Axis(0), sorted_row_idx[i]));
492                },
493            );
494            // Update the values with the new sorted values.
495            parameters = new_parameters;
496        }
497
498        // Make conditioning states immutable.
499        let conditioning_labels = conditioning_labels;
500        let conditioning_states = conditioning_states;
501        let conditioning_shape = conditioning_shape;
502
503        // Make parameters immutable.
504        let parameters = parameters;
505
506        // Debug assert to check the sorting of the labels.
507        debug_assert!(labels.iter().is_sorted(), "Labels must be sorted.");
508        debug_assert!(states.keys().is_sorted(), "Labels must be sorted.");
509        debug_assert!(
510            states.values().all(|x| x.iter().is_sorted()),
511            "States must be sorted."
512        );
513        debug_assert!(
514            conditioning_labels.iter().is_sorted(),
515            "Conditioning labels must be sorted."
516        );
517        debug_assert!(
518            conditioning_states.keys().is_sorted(),
519            "Conditioning labels must be sorted."
520        );
521        debug_assert!(
522            conditioning_states.values().all(|x| x.iter().is_sorted()),
523            "Conditioning states must be sorted."
524        );
525
526        // Compute the multi index.
527        let multi_index = MI::new(shape.clone());
528        // Compute the conditioning multi index.
529        let conditioning_multi_index = MI::new(conditioning_shape.clone());
530
531        // Get the shape of the parameters.
532        let s = parameters.shape();
533        // Compute the parameters size.
534        let parameters_size = s[0] * s[1] * s[2].saturating_sub(1);
535
536        Self {
537            labels,
538            states,
539            shape,
540            multi_index,
541            conditioning_labels,
542            conditioning_states,
543            conditioning_shape,
544            conditioning_multi_index,
545            parameters,
546            parameters_size,
547            sample_statistics: None,
548            sample_log_likelihood: None,
549        }
550    }
551
552    /// Returns the states of the conditioned variable.
553    ///
554    /// # Returns
555    ///
556    /// The states of the conditioned variable.
557    ///
558    #[inline]
559    pub const fn states(&self) -> &States {
560        &self.states
561    }
562
563    /// Returns the shape of the conditioned variable.
564    ///
565    /// # Returns
566    ///
567    /// The shape of the conditioned variable.
568    ///
569    #[inline]
570    pub const fn shape(&self) -> &Array1<usize> {
571        &self.shape
572    }
573
574    /// Returns the ravel multi index of the conditioning variables.
575    ///
576    /// # Returns
577    ///
578    /// The ravel multi index of the conditioning variables.
579    ///
580    #[inline]
581    pub const fn multi_index(&self) -> &MI {
582        &self.multi_index
583    }
584
585    /// Returns the states of the conditioning variables.
586    ///
587    /// # Returns
588    ///
589    /// The states of the conditioning variables.
590    ///
591    #[inline]
592    pub const fn conditioning_states(&self) -> &States {
593        &self.conditioning_states
594    }
595
596    /// Returns the shape of the conditioning variables.
597    ///
598    /// # Returns
599    ///
600    /// The shape of the conditioning variables.
601    ///
602    #[inline]
603    pub const fn conditioning_shape(&self) -> &Array1<usize> {
604        &self.conditioning_shape
605    }
606
607    /// Returns the ravel multi index of the conditioning variables.
608    ///
609    /// # Returns
610    ///
611    /// The ravel multi index of the conditioning variables.
612    ///
613    #[inline]
614    pub const fn conditioning_multi_index(&self) -> &MI {
615        &self.conditioning_multi_index
616    }
617
618    /// Creates a new categorical conditional intensity matrix.
619    ///
620    /// # Arguments
621    ///
622    /// * `states` - The variables states.
623    /// * `parameters` - The intensity matrices of the states.
624    /// * `sample_statistics` - The sample statistics used to fit the distribution, if any.
625    /// * `sample_log_likelihood` - The sample log-likelihood given the distribution, if any.
626    ///
627    /// # Panics
628    ///
629    /// See `new` method for panics.
630    ///
631    /// # Returns
632    ///
633    /// A new `CatCIM` instance.
634    ///
635    pub fn with_optionals(
636        states: States,
637        conditioning_states: States,
638        parameters: Array3<f64>,
639        sample_statistics: Option<CatCIMS>,
640        sample_log_likelihood: Option<f64>,
641    ) -> Self {
642        // Assert the sample conditional counts are finite and non-negative, with same shape as parameters.
643        if let Some(sample_statistics) = &sample_statistics {
644            // Get the sample conditional counts.
645            let sample_conditional_counts = &sample_statistics.n_xz;
646            // Assert the sample conditional counts have the same shape as parameters.
647            assert!(
648                sample_conditional_counts.shape() == parameters.shape(),
649                "Sample conditional counts must have the same shape as parameters: \n\
650                \t expected:    sample_conditional_counts.shape() == {:?} , \n\
651                \t found:       sample_conditional_counts.shape() == {:?} .",
652                parameters.shape(),
653                sample_conditional_counts.shape(),
654            );
655        }
656        // Assert the sample log-likelihood is finite.
657        if let Some(sample_log_likelihood) = &sample_log_likelihood {
658            assert!(
659                sample_log_likelihood.is_finite(),
660                "Sample log-likelihood must be finite: \n\
661                \t expected: sample_ll is finite, \n\
662                \t found:    sample_ll is {sample_log_likelihood} ."
663            )
664        }
665
666        // Construct the CIM.
667        let mut cim = Self::new(states, conditioning_states, parameters);
668
669        // FIXME: Check labels alignment with optional fields.
670
671        // Set the sample statistics and log-likelihood.
672        cim.sample_statistics = sample_statistics;
673        cim.sample_log_likelihood = sample_log_likelihood;
674
675        cim
676    }
677}
678
679impl PartialEq for CatCIM {
680    fn eq(&self, other: &Self) -> bool {
681        // Check for equality, excluding the sample values.
682        self.labels.eq(&other.labels)
683            && self.states.eq(&other.states)
684            && self.shape.eq(&other.shape)
685            && self.conditioning_labels.eq(&other.conditioning_labels)
686            && self.conditioning_states.eq(&other.conditioning_states)
687            && self.conditioning_shape.eq(&other.conditioning_shape)
688            && self.multi_index.eq(&other.multi_index)
689            && self.parameters.eq(&other.parameters)
690    }
691}
692
693impl AbsDiffEq for CatCIM {
694    type Epsilon = f64;
695
696    fn default_epsilon() -> Self::Epsilon {
697        Self::Epsilon::default_epsilon()
698    }
699
700    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
701        // Check for equality, excluding the sample values.
702        self.labels.eq(&other.labels)
703            && self.states.eq(&other.states)
704            && self.shape.eq(&other.shape)
705            && self.conditioning_labels.eq(&other.conditioning_labels)
706            && self.conditioning_states.eq(&other.conditioning_states)
707            && self.conditioning_shape.eq(&other.conditioning_shape)
708            && self.multi_index.eq(&other.multi_index)
709            && self.parameters.abs_diff_eq(&other.parameters, epsilon)
710    }
711}
712
713impl RelativeEq for CatCIM {
714    fn default_max_relative() -> Self::Epsilon {
715        Self::Epsilon::default_max_relative()
716    }
717
718    fn relative_eq(
719        &self,
720        other: &Self,
721        epsilon: Self::Epsilon,
722        max_relative: Self::Epsilon,
723    ) -> bool {
724        // Check for equality, excluding the sample values.
725        self.labels.eq(&other.labels)
726            && self.states.eq(&other.states)
727            && self.shape.eq(&other.shape)
728            && self.conditioning_labels.eq(&other.conditioning_labels)
729            && self.conditioning_states.eq(&other.conditioning_states)
730            && self.conditioning_shape.eq(&other.conditioning_shape)
731            && self.multi_index.eq(&other.multi_index)
732            && self
733                .parameters
734                .relative_eq(&other.parameters, epsilon, max_relative)
735    }
736}
737
738impl Labelled for CatCIM {
739    #[inline]
740    fn labels(&self) -> &Labels {
741        &self.labels
742    }
743}
744
745impl CIM for CatCIM {
746    type Support = CatSample;
747    type Parameters = Array3<f64>;
748    type Statistics = CatCIMS;
749
750    #[inline]
751    fn conditioning_labels(&self) -> &Labels {
752        &self.conditioning_labels
753    }
754
755    #[inline]
756    fn parameters(&self) -> &Self::Parameters {
757        &self.parameters
758    }
759
760    #[inline]
761    fn parameters_size(&self) -> usize {
762        self.parameters_size
763    }
764
765    #[inline]
766    fn sample_statistics(&self) -> Option<&Self::Statistics> {
767        self.sample_statistics.as_ref()
768    }
769
770    #[inline]
771    fn sample_log_likelihood(&self) -> Option<f64> {
772        self.sample_log_likelihood
773    }
774}
775
776impl Serialize for CatCIM {
777    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
778    where
779        S: Serializer,
780    {
781        // Count the elements to serialize.
782        let mut size = 4;
783        size += self.sample_statistics.is_some() as usize;
784        size += self.sample_log_likelihood.is_some() as usize;
785
786        // Allocate the map.
787        let mut map = serializer.serialize_map(Some(size))?;
788
789        // Serialize states.
790        map.serialize_entry("states", &self.states)?;
791        // Serialize conditioning states.
792        map.serialize_entry("conditioning_states", &self.conditioning_states)?;
793
794        // Convert parameters to a flat format.
795        let parameters: Vec<Vec<Vec<f64>>> = self
796            .parameters
797            .outer_iter()
798            .map(|parameters| parameters.rows().into_iter().map(|x| x.to_vec()).collect())
799            .collect();
800
801        // Serialize parameters.
802        map.serialize_entry("parameters", &parameters)?;
803
804        // Serialize sample statistics, if any.
805        if let Some(sample_statistics) = &self.sample_statistics {
806            map.serialize_entry("sample_statistics", &sample_statistics)?;
807        }
808        // Serialize sample log likelihood, if any.
809        if let Some(sample_log_likelihood) = self.sample_log_likelihood {
810            map.serialize_entry("sample_log_likelihood", &sample_log_likelihood)?;
811        }
812
813        // Serialize type.
814        map.serialize_entry("type", "catcim")?;
815
816        // Finalize the map serialization.
817        map.end()
818    }
819}
820
821impl<'de> Deserialize<'de> for CatCIM {
822    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
823    where
824        D: Deserializer<'de>,
825    {
826        #[derive(Deserialize)]
827        #[serde(field_identifier, rename_all = "snake_case")]
828        enum Field {
829            States,
830            ConditioningStates,
831            Parameters,
832            SampleStatistics,
833            SampleLogLikelihood,
834            Type,
835        }
836
837        struct CatCIMVisitor;
838
839        impl<'de> Visitor<'de> for CatCIMVisitor {
840            type Value = CatCIM;
841
842            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
843                formatter.write_str("struct CatCIM")
844            }
845
846            fn visit_map<V>(self, mut map: V) -> Result<CatCIM, V::Error>
847            where
848                V: MapAccess<'de>,
849            {
850                use serde::de::Error as E;
851
852                // Allocate fields
853                let mut states = None;
854                let mut conditioning_states = None;
855                let mut parameters = None;
856                let mut sample_statistics = None;
857                let mut sample_log_likelihood = None;
858                let mut type_ = None;
859
860                // Parse the map.
861                while let Some(key) = map.next_key()? {
862                    match key {
863                        Field::States => {
864                            if states.is_some() {
865                                return Err(E::duplicate_field("states"));
866                            }
867                            states = Some(map.next_value()?);
868                        }
869                        Field::ConditioningStates => {
870                            if conditioning_states.is_some() {
871                                return Err(E::duplicate_field("conditioning_states"));
872                            }
873                            conditioning_states = Some(map.next_value()?);
874                        }
875                        Field::Parameters => {
876                            if parameters.is_some() {
877                                return Err(E::duplicate_field("parameters"));
878                            }
879                            parameters = Some(map.next_value()?);
880                        }
881                        Field::SampleStatistics => {
882                            if sample_statistics.is_some() {
883                                return Err(E::duplicate_field("sample_statistics"));
884                            }
885                            sample_statistics = Some(map.next_value()?);
886                        }
887                        Field::SampleLogLikelihood => {
888                            if sample_log_likelihood.is_some() {
889                                return Err(E::duplicate_field("sample_log_likelihood"));
890                            }
891                            sample_log_likelihood = Some(map.next_value()?);
892                        }
893                        Field::Type => {
894                            if type_.is_some() {
895                                return Err(E::duplicate_field("type"));
896                            }
897                            type_ = Some(map.next_value()?);
898                        }
899                    }
900                }
901
902                // Check required fields.
903                let states = states.ok_or_else(|| E::missing_field("states"))?;
904                let conditioning_states =
905                    conditioning_states.ok_or_else(|| E::missing_field("conditioning_states"))?;
906                let parameters = parameters.ok_or_else(|| E::missing_field("parameters"))?;
907
908                // Assert type is correct.
909                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
910                assert_eq!(type_, "catcim", "Invalid type for CatCIM.");
911
912                // Convert parameters to ndarray.
913                let parameters: Vec<Vec<Vec<f64>>> = parameters;
914                let shape = (
915                    parameters.len(),
916                    parameters[0].len(),
917                    parameters[0][0].len(),
918                );
919                let parameters = parameters.into_iter().flatten().flatten();
920                let parameters = Array::from_iter(parameters)
921                    .into_shape_with_order(shape)
922                    .map_err(|_| E::custom("Invalid parameters shape"))?;
923
924                Ok(CatCIM::with_optionals(
925                    states,
926                    conditioning_states,
927                    parameters,
928                    sample_statistics,
929                    sample_log_likelihood,
930                ))
931            }
932        }
933
934        const FIELDS: &[&str] = &[
935            "states",
936            "conditioning_states",
937            "parameters",
938            "sample_statistics",
939            "sample_log_likelihood",
940            "type",
941        ];
942
943        deserializer.deserialize_struct("CatCIM", FIELDS, CatCIMVisitor)
944    }
945}
946
// Implement `JsonIO` for `CatCIM` via the crate's `impl_json_io!` macro.
// NOTE(review): presumably the generated JSON I/O routes through the `Serialize` /
// `Deserialize` impls defined in this file — confirm in the macro definition.
impl_json_io!(CatCIM);