causal_hub/models/bayesian_network/categorical/
parameters.rs

1use std::{
2    fmt::Display,
3    ops::{Add, AddAssign},
4};
5
6use approx::{AbsDiffEq, RelativeEq, relative_eq};
7use itertools::Itertools;
8use ndarray::prelude::*;
9use rand::Rng;
10use rand_distr::{Distribution, weighted::WeightedIndex};
11use serde::{
12    Deserialize, Deserializer, Serialize, Serializer,
13    de::{MapAccess, Visitor},
14    ser::SerializeMap,
15};
16
17use crate::{
18    datasets::{CatSample, CatType},
19    impl_json_io,
20    models::{CPD, CatPhi, Labelled, Phi},
21    types::{EPSILON, Labels, Set, States},
22    utils::MI,
23};
24
/// Sample (sufficient) statistics for the categorical CPD.
///
/// Stores the conditional counts together with the sample size they were
/// computed from.
#[derive(Clone, Debug)]
pub struct CatCPDS {
    /// Conditional counts |Z| x |X|.
    n_xz: Array2<f64>,
    /// Sample size.
    n: f64,
}
33
34impl CatCPDS {
35    /// Creates a new sample (sufficient) statistics for the categorical CPD.
36    ///
37    /// # Arguments
38    ///
39    /// * `n_xz` - The conditional counts |Z| x |X|.
40    /// * `n` - The sample size.
41    ///
42    /// # Returns
43    ///
44    /// A new sample (sufficient) statistics instance.
45    ///
46    #[inline]
47    pub fn new(n_xz: Array2<f64>, n: f64) -> Self {
48        // Assert the counts are finite and non-negative.
49        assert!(
50            n_xz.iter().all(|&x| x.is_finite() && x >= 0.),
51            "Counts must be finite and non-negative."
52        );
53        assert!(
54            n.is_finite() && n >= 0.,
55            "Sample size must be finite and non-negative."
56        );
57
58        Self { n_xz, n }
59    }
60
61    /// Returns the sample conditional counts |Z| x |X|.
62    ///
63    /// # Returns
64    ///
65    /// The sample conditional counts.
66    ///
67    #[inline]
68    pub const fn sample_conditional_counts(&self) -> &Array2<f64> {
69        &self.n_xz
70    }
71
72    /// Returns the sample size.
73    ///
74    /// # Returns
75    ///
76    /// The sample size.
77    ///
78    #[inline]
79    pub const fn sample_size(&self) -> f64 {
80        self.n
81    }
82}
83
84impl AddAssign for CatCPDS {
85    fn add_assign(&mut self, other: Self) {
86        // Add the counts.
87        self.n_xz += &other.n_xz;
88        // Add the sample sizes.
89        self.n += other.n;
90    }
91}
92
93impl Add for CatCPDS {
94    type Output = Self;
95
96    #[inline]
97    fn add(mut self, rhs: Self) -> Self::Output {
98        self += rhs;
99        self
100    }
101}
102
103impl Serialize for CatCPDS {
104    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105    where
106        S: Serializer,
107    {
108        // Allocate the map.
109        let mut map = serializer.serialize_map(Some(2))?;
110        // Convert the sample conditional counts to a flat format.
111        let sample_conditional_counts: Vec<Vec<f64>> =
112            self.n_xz.rows().into_iter().map(|x| x.to_vec()).collect();
113        // Serialize sample conditional counts.
114        map.serialize_entry("sample_conditional_counts", &sample_conditional_counts)?;
115        // Serialize sample size.
116        map.serialize_entry("sample_size", &self.n)?;
117        // End the map.
118        map.end()
119    }
120}
121
122impl<'de> Deserialize<'de> for CatCPDS {
123    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
124    where
125        D: Deserializer<'de>,
126    {
127        #[derive(Deserialize)]
128        #[serde(field_identifier, rename_all = "snake_case")]
129        enum Field {
130            SampleConditionalCounts,
131            SampleSize,
132        }
133
134        struct CatCPDSVisitor;
135
136        impl<'de> Visitor<'de> for CatCPDSVisitor {
137            type Value = CatCPDS;
138
139            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
140                formatter.write_str("struct CatCPDS")
141            }
142
143            fn visit_map<V>(self, mut map: V) -> Result<CatCPDS, V::Error>
144            where
145                V: MapAccess<'de>,
146            {
147                use serde::de::Error as E;
148
149                // Allocate the fields.
150                let mut sample_conditional_counts = None;
151                let mut sample_size = None;
152
153                // Parse the map.
154                while let Some(key) = map.next_key()? {
155                    match key {
156                        Field::SampleConditionalCounts => {
157                            if sample_conditional_counts.is_some() {
158                                return Err(E::duplicate_field("sample_conditional_counts"));
159                            }
160                            sample_conditional_counts = Some(map.next_value()?);
161                        }
162                        Field::SampleSize => {
163                            if sample_size.is_some() {
164                                return Err(E::duplicate_field("sample_size"));
165                            }
166                            sample_size = Some(map.next_value()?);
167                        }
168                    }
169                }
170
171                // Extract the fields.
172                let sample_conditional_counts = sample_conditional_counts
173                    .ok_or_else(|| E::missing_field("sample_conditional_counts"))?;
174                let sample_size = sample_size.ok_or_else(|| E::missing_field("sample_size"))?;
175
176                // Convert sample conditional counts to ndarray.
177                let sample_conditional_counts = {
178                    let counts: Vec<Vec<f64>> = sample_conditional_counts;
179                    let shape = (counts.len(), counts[0].len());
180                    Array::from_iter(counts.into_iter().flatten())
181                        .into_shape_with_order(shape)
182                        .map_err(|_| E::custom("Invalid sample conditional counts shape"))?
183                };
184
185                Ok(CatCPDS::new(sample_conditional_counts, sample_size))
186            }
187        }
188
189        const FIELDS: &[&str] = &["sample_conditional_counts", "sample_size"];
190
191        deserializer.deserialize_struct("CatCPDS", FIELDS, CatCPDSVisitor)
192    }
193}
194
/// A categorical CPD.
///
/// The constructor stores labels and states in sorted order and arranges the
/// parameters accordingly.
#[derive(Clone, Debug)]
pub struct CatCPD {
    /// Labels of the conditioned variables.
    labels: Labels,
    /// States of the conditioned variables, keyed by label.
    states: States,
    /// Number of states of each conditioned variable.
    shape: Array1<usize>,
    /// Multi index built from `shape`.
    multi_index: MI,
    /// Labels of the conditioning variables.
    conditioning_labels: Labels,
    /// States of the conditioning variables, keyed by label.
    conditioning_states: States,
    /// Number of states of each conditioning variable.
    conditioning_shape: Array1<usize>,
    /// Multi index built from `conditioning_shape`.
    conditioning_multi_index: MI,
    /// Parameters: one row per conditioning configuration, one column per
    /// configuration of the conditioned variables.
    parameters: Array2<f64>,
    /// Parameters size, computed as `(ncols - 1) * nrows` of `parameters`.
    parameters_size: usize,
    /// Sample (sufficient) statistics, if any.
    sample_statistics: Option<CatCPDS>,
    /// Sample log-likelihood, if any.
    sample_log_likelihood: Option<f64>,
}
215
216impl CatCPD {
217    /// Creates a new categorical conditional probability distribution.
218    ///
219    /// # Arguments
220    ///
221    /// * `states` - The variable label and states.
222    /// * `conditioning_states` - The conditioning variables labels and states.
223    /// * `parameters` - The probabilities of the states.
224    ///
225    /// # Panics
226    ///
227    /// * If the labels and conditioning labels are not disjoint.
228    /// * If the product of the shape of the of states does not match the number of columns.
229    /// * If the product of the shape of the of conditioning states does not match the number of rows.
230    /// * If the parameters do not sum to one by row, unless empty.
231    ///
232    /// # Returns
233    ///
234    /// A new `CatCPD` instance.
235    ///
236    pub fn new(states: States, conditioning_states: States, parameters: Array2<f64>) -> Self {
237        // Get the labels of the variables.
238        let labels: Set<_> = states.keys().cloned().collect();
239        // Get the labels of the variables.
240        let conditioning_labels: Set<_> = conditioning_states.keys().cloned().collect();
241
242        // Assert labels and conditioning labels are disjoint.
243        assert!(
244            labels.is_disjoint(&conditioning_labels),
245            "Labels and conditioning labels must be disjoint."
246        );
247
248        // Get the states shape.
249        let shape = Array::from_iter(states.values().map(Set::len));
250
251        // Check that the product of the shape matches the number of columns.
252        assert!(
253            parameters.is_empty() || parameters.ncols() == shape.product(),
254            "Product of the number of states must match the number of columns: \n\
255            \t expected:    parameters.ncols() == {} , \n\
256            \t found:       parameters.ncols() == {} .",
257            shape.product(),
258            parameters.ncols(),
259        );
260
261        // Get the shape of the set of states.
262        let conditioning_shape = Array::from_iter(conditioning_states.values().map(Set::len));
263
264        // Check that the product of the conditioning shape matches the number of rows.
265        assert!(
266            parameters.is_empty() || parameters.nrows() == conditioning_shape.product(),
267            "Product of the number of conditioning states must match the number of rows: \n\
268            \t expected:    parameters.nrows() == {} , \n\
269            \t found:       parameters.nrows() == {} .",
270            conditioning_shape.product(),
271            parameters.nrows(),
272        );
273
274        // Check parameters validity.
275        parameters
276            .sum_axis(Axis(1))
277            .iter()
278            .enumerate()
279            .for_each(|(i, &x)| {
280                if !relative_eq!(x, 1.0, epsilon = EPSILON) {
281                    panic!("Failed to sum probability to one: {}.", parameters.row(i));
282                }
283            });
284
285        // Make parameters mutable.
286        let mut parameters = parameters;
287
288        // Make states mutable.
289        let mut labels = labels;
290        let mut states = states;
291        let mut shape = shape;
292
293        // Check if states are sorted.
294        if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
295            // Compute the current states order.
296            let mut sorted_states_idx: Vec<_> = states.values().multi_cartesian_product().collect();
297            // Sort the labels.
298            let mut sorted_labels_idx: Vec<_> = (0..labels.len()).collect();
299            // Sort the labels.
300            sorted_labels_idx.sort_by_key(|&i| &labels[i]);
301            // Sort the states by the labels.
302            sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
303                *sorted_states_idx = sorted_labels_idx
304                    .iter()
305                    .map(|&i| sorted_states_idx[i])
306                    .collect();
307            });
308            // Initialize the sorted row indices.
309            let mut sorted_row_idx: Vec<_> = (0..parameters.ncols()).collect();
310            // Sort the row indices.
311            sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
312            // Sort the labels.
313            states.sort_keys();
314            states.values_mut().for_each(Set::sort);
315            labels = states.keys().cloned().collect();
316            shape = states.values().map(|x| x.len()).collect();
317            // Allocate new parameters.
318            let mut new_parameters = parameters.clone();
319            // Sort the values by multi indices.
320            new_parameters
321                .columns_mut()
322                .into_iter()
323                .enumerate()
324                .for_each(|(i, mut new_parameters_col)| {
325                    // Assign the sorted values to the new values array.
326                    new_parameters_col.assign(&parameters.column(sorted_row_idx[i]));
327                });
328            // Update the values with the new sorted values.
329            parameters = new_parameters;
330        }
331
332        // Make states immutable.
333        let labels = labels;
334        let states = states;
335        let shape = shape;
336
337        // Make conditioning states mutable.
338        let mut conditioning_labels = conditioning_labels;
339        let mut conditioning_states = conditioning_states;
340        let mut conditioning_shape = conditioning_shape;
341
342        // Check if conditioning states are sorted.
343        if !conditioning_states.keys().is_sorted()
344            || !conditioning_states.values().all(|x| x.iter().is_sorted())
345        {
346            // Compute the current states order.
347            let mut sorted_states_idx: Vec<_> = conditioning_states
348                .values()
349                .multi_cartesian_product()
350                .collect();
351            // Sort the conditioning labels.
352            let mut sorted_labels_idx: Vec<_> = (0..conditioning_labels.len()).collect();
353            // Sort the conditioning labels.
354            sorted_labels_idx.sort_by_key(|&i| &conditioning_labels[i]);
355            // Sort the conditioning states by the labels.
356            sorted_states_idx.iter_mut().for_each(|sorted_states_idx| {
357                *sorted_states_idx = sorted_labels_idx
358                    .iter()
359                    .map(|&i| sorted_states_idx[i])
360                    .collect();
361            });
362            // Initialize the sorted row indices.
363            let mut sorted_row_idx: Vec<_> = (0..parameters.nrows()).collect();
364            // Sort the row indices.
365            sorted_row_idx.sort_by_key(|&i| &sorted_states_idx[i]);
366            // Sort the labels.
367            conditioning_states.sort_keys();
368            conditioning_states.values_mut().for_each(Set::sort);
369            conditioning_labels = conditioning_states.keys().cloned().collect();
370            conditioning_shape = conditioning_states.values().map(|x| x.len()).collect();
371            // Allocate new parameters.
372            let mut new_parameters = parameters.clone();
373            // Sort the values by multi indices.
374            new_parameters.rows_mut().into_iter().enumerate().for_each(
375                |(i, mut new_parameters_row)| {
376                    // Assign the sorted values to the new values array.
377                    new_parameters_row.assign(&parameters.row(sorted_row_idx[i]));
378                },
379            );
380            // Update the values with the new sorted values.
381            parameters = new_parameters;
382        }
383
384        // Make conditioning states immutable.
385        let conditioning_labels = conditioning_labels;
386        let conditioning_states = conditioning_states;
387        let conditioning_shape = conditioning_shape;
388
389        // Make parameters immutable.
390        let parameters = parameters;
391
392        // Debug assert to check the sorting of the labels.
393        debug_assert!(labels.iter().is_sorted(), "Labels must be sorted.");
394        debug_assert!(states.keys().is_sorted(), "Labels must be sorted.");
395        debug_assert!(
396            states.values().all(|x| x.iter().is_sorted()),
397            "States must be sorted."
398        );
399        debug_assert!(
400            conditioning_labels.iter().is_sorted(),
401            "Conditioning labels must be sorted."
402        );
403        debug_assert!(
404            conditioning_states.keys().is_sorted(),
405            "Conditioning labels must be sorted."
406        );
407        debug_assert!(
408            conditioning_states.values().all(|x| x.iter().is_sorted()),
409            "Conditioning states must be sorted."
410        );
411
412        // Compute the multi index.
413        let multi_index = MI::new(shape.clone());
414        // Compute the conditioning multi index.
415        let conditioning_multi_index = MI::new(conditioning_shape.clone());
416        // Compute the parameters size.
417        let parameters_size = parameters.ncols().saturating_sub(1) * parameters.nrows();
418
419        Self {
420            labels,
421            states,
422            shape,
423            multi_index,
424            conditioning_labels,
425            conditioning_states,
426            conditioning_shape,
427            conditioning_multi_index,
428            parameters,
429            parameters_size,
430            sample_statistics: None,
431            sample_log_likelihood: None,
432        }
433    }
434
    /// Returns the states of the conditioned variables, keyed by label.
    ///
    /// # Returns
    ///
    /// The states of the conditioned variables.
    ///
    #[inline]
    pub const fn states(&self) -> &States {
        &self.states
    }
445
    /// Returns the shape of the conditioned variables, i.e. the number of
    /// states of each of them.
    ///
    /// # Returns
    ///
    /// The shape of the conditioned variables.
    ///
    #[inline]
    pub const fn shape(&self) -> &Array1<usize> {
        &self.shape
    }
456
    /// Returns the ravel multi index of the conditioned variables.
    ///
    /// # Returns
    ///
    /// The ravel multi index of the conditioned variables.
    ///
    #[inline]
    pub const fn multi_index(&self) -> &MI {
        &self.multi_index
    }
467
    /// Returns the states of the conditioning variables, keyed by label.
    ///
    /// # Returns
    ///
    /// The states of the conditioning variables.
    ///
    #[inline]
    pub const fn conditioning_states(&self) -> &States {
        &self.conditioning_states
    }
478
    /// Returns the shape of the conditioning variables, i.e. the number of
    /// states of each of them.
    ///
    /// # Returns
    ///
    /// The shape of the conditioning variables.
    ///
    #[inline]
    pub const fn conditioning_shape(&self) -> &Array1<usize> {
        &self.conditioning_shape
    }
489
    /// Returns the ravel multi index of the conditioning variables.
    ///
    /// # Returns
    ///
    /// The ravel multi index of the conditioning variables.
    ///
    #[inline]
    pub const fn conditioning_multi_index(&self) -> &MI {
        // Built by the constructor from the conditioning shape.
        &self.conditioning_multi_index
    }
500
501    /// Marginalizes the over the variables `X` and conditioning variables `Z`.
502    ///
503    /// # Arguments
504    ///
505    /// * `x` - The variables to marginalize over.
506    /// * `z` - The conditioning variables to marginalize over.
507    ///
508    /// # Returns
509    ///
510    /// A new instance with the marginalized variables.
511    ///
512    pub fn marginalize(&self, x: &Set<usize>, z: &Set<usize>) -> Self {
513        // Base case: if no variables to marginalize, return self clone.
514        if x.is_empty() && z.is_empty() {
515            return self.clone();
516        }
517        // Get labels.
518        let labels_x = self.labels();
519        let labels_z = self.conditioning_labels();
520        // Get indices to preserve.
521        let not_x = (0..labels_x.len()).filter(|i| !x.contains(i)).collect();
522        let not_z = (0..labels_z.len()).filter(|i| !z.contains(i)).collect();
523        // Convert to potential.
524        let phi = self.clone().into_phi();
525        // Map CPD indices to potential indices.
526        let x = phi.indices_from(x, labels_x);
527        let z = phi.indices_from(z, labels_z);
528        // Marginalize the potential.
529        let phi = phi.marginalize(&(&x | &z));
530        // Map CPD indices to potential indices.
531        let not_x = phi.indices_from(&not_x, labels_x);
532        let not_z = phi.indices_from(&not_z, labels_z);
533        // Convert back to CPD.
534        phi.into_cpd(&not_x, &not_z)
535    }
536
537    /// Creates a new categorical conditional probability distribution with optional fields.
538    ///
539    /// # Arguments
540    ///
541    /// * `states` - The variables states.
542    /// * `parameters` - The probabilities of the states.
543    /// * `statistics` - The sufficient statistics used to fit the distribution, if any.
544    ///
545    /// # Panics
546    ///
547    /// See `new` method for panics.
548    ///
549    /// # Returns
550    ///
551    /// A new `CatCPD` instance.
552    ///
553    pub fn with_optionals(
554        state: States,
555        conditioning_states: States,
556        parameters: Array2<f64>,
557        sample_statistics: Option<CatCPDS>,
558        sample_log_likelihood: Option<f64>,
559    ) -> Self {
560        if let Some(sample_statistics) = &sample_statistics {
561            // Get the sample conditional counts.
562            let sample_conditional_counts = &sample_statistics.n_xz;
563            // Assert the sample conditional counts have the same shape as parameters.
564            assert!(
565                sample_conditional_counts.shape() == parameters.shape(),
566                "Sample conditional counts must have the same shape as parameters: \n\
567                \t expected:    sample_conditional_counts.shape() == {:?} , \n\
568                \t found:       sample_conditional_counts.shape() == {:?} .",
569                parameters.shape(),
570                sample_conditional_counts.shape(),
571            );
572        }
573        // Assert the sample log-likelihood is finite and non-positive.
574        if let Some(sample_log_likelihood) = &sample_log_likelihood {
575            assert!(
576                sample_log_likelihood.is_finite() && *sample_log_likelihood <= 0.,
577                "Sample log-likelihood must be finite and non-positive: \n\
578                \t expected: sample_ll <= 0 , \n\
579                \t found:    sample_ll == {sample_log_likelihood} ."
580            );
581        }
582
583        // Construct the categorical CPD.
584        let mut cpd = Self::new(state, conditioning_states, parameters);
585
586        // FIXME: Check labels alignment with optional fields.
587
588        // Set the optionals.
589        cpd.sample_statistics = sample_statistics;
590        cpd.sample_log_likelihood = sample_log_likelihood;
591
592        cpd
593    }
594
    /// Converts a potential \phi(X \cup Z) to a CPD P(X | Z).
    ///
    /// # Arguments
    ///
    /// * `phi` - The potential to convert.
    /// * `x` - The set of variables.
    /// * `z` - The set of conditioning variables.
    ///
    /// # Returns
    ///
    /// The corresponding CPD.
    ///
    #[inline]
    pub fn from_phi(phi: CatPhi, x: &Set<usize>, z: &Set<usize>) -> Self {
        // The conversion itself is implemented on the potential.
        phi.into_cpd(x, z)
    }
610
    /// Converts a CPD P(X | Z) to a potential \phi(X \cup Z).
    ///
    /// The CPD is consumed by the conversion.
    ///
    /// # Returns
    ///
    /// The corresponding potential.
    ///
    #[inline]
    pub fn into_phi(self) -> CatPhi {
        // The conversion itself is implemented on the potential.
        CatPhi::from_cpd(self)
    }
625}
626
impl Labelled for CatCPD {
    /// Returns the labels of the conditioned variables.
    #[inline]
    fn labels(&self) -> &Labels {
        &self.labels
    }
}
633
634impl PartialEq for CatCPD {
635    fn eq(&self, other: &Self) -> bool {
636        // Check for equality, excluding the sample values.
637        self.labels.eq(&other.labels)
638            && self.states.eq(&other.states)
639            && self.shape.eq(&other.shape)
640            && self.conditioning_labels.eq(&other.conditioning_labels)
641            && self.conditioning_states.eq(&other.conditioning_states)
642            && self.conditioning_shape.eq(&other.conditioning_shape)
643            && self.multi_index.eq(&other.multi_index)
644            && self.parameters.eq(&other.parameters)
645    }
646}
647
648impl AbsDiffEq for CatCPD {
649    type Epsilon = f64;
650
651    fn default_epsilon() -> Self::Epsilon {
652        Self::Epsilon::default_epsilon()
653    }
654
655    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
656        // Check for equality, excluding the sample values.
657        self.labels.eq(&other.labels)
658            && self.states.eq(&other.states)
659            && self.shape.eq(&other.shape)
660            && self.conditioning_labels.eq(&other.conditioning_labels)
661            && self.conditioning_states.eq(&other.conditioning_states)
662            && self.conditioning_shape.eq(&other.conditioning_shape)
663            && self.multi_index.eq(&other.multi_index)
664            && self.parameters.abs_diff_eq(&other.parameters, epsilon)
665    }
666}
667
668impl RelativeEq for CatCPD {
669    fn default_max_relative() -> Self::Epsilon {
670        Self::Epsilon::default_max_relative()
671    }
672
673    fn relative_eq(
674        &self,
675        other: &Self,
676        epsilon: Self::Epsilon,
677        max_relative: Self::Epsilon,
678    ) -> bool {
679        // Check for equality, excluding the sample values.
680        self.labels.eq(&other.labels)
681            && self.states.eq(&other.states)
682            && self.shape.eq(&other.shape)
683            && self.conditioning_labels.eq(&other.conditioning_labels)
684            && self.conditioning_states.eq(&other.conditioning_states)
685            && self.conditioning_shape.eq(&other.conditioning_shape)
686            && self.multi_index.eq(&other.multi_index)
687            && self
688                .parameters
689                .relative_eq(&other.parameters, epsilon, max_relative)
690    }
691}
692
693impl CPD for CatCPD {
694    type Support = CatSample;
695    type Parameters = Array2<f64>;
696    type Statistics = CatCPDS;
697
698    #[inline]
699    fn conditioning_labels(&self) -> &Labels {
700        &self.conditioning_labels
701    }
702
703    #[inline]
704    fn parameters(&self) -> &Self::Parameters {
705        &self.parameters
706    }
707
708    #[inline]
709    fn parameters_size(&self) -> usize {
710        self.parameters_size
711    }
712
713    #[inline]
714    fn sample_statistics(&self) -> Option<&Self::Statistics> {
715        self.sample_statistics.as_ref()
716    }
717
718    #[inline]
719    fn sample_log_likelihood(&self) -> Option<f64> {
720        self.sample_log_likelihood
721    }
722
723    fn pf(&self, x: &Self::Support, z: &Self::Support) -> f64 {
724        // Get number of variables.
725        let n = self.labels.len();
726        // Get number of conditioning variables.
727        let m = self.conditioning_labels.len();
728
729        // Assert X matches number of variables.
730        assert_eq!(
731            x.len(),
732            n,
733            "Vector X must match number of variables: \n\
734            \t expected:    |X| == {} , \n\
735            \t found:       |X| == {} .",
736            n,
737            x.len(),
738        );
739        // Assert Z matches number of conditioning variables.
740        assert_eq!(
741            z.len(),
742            m,
743            "Vector Z must match number of conditioning variables: \n\
744            \t expected:    |Z| == {} , \n\
745            \t found:       |Z| == {} .",
746            m,
747            z.len(),
748        );
749
750        // No variables.
751        if n == 0 {
752            return 1.;
753        }
754
755        // Convert states to indices.
756        let x = match n {
757            // ... one variable.
758            1 => x[0] as usize,
759            // ... multiple variables.
760            _ => {
761                // Convert states to indices.
762                let x = x.iter().map(|&x| x as usize);
763                // Ravel the variables.
764                self.multi_index.ravel(x)
765            }
766        };
767
768        // Convert conditioning states to indices.
769        let z = match m {
770            // ... no conditioning variables.
771            0 => 0,
772            // ... one conditioning variable.
773            1 => z[0] as usize,
774            // ... multiple conditioning variables.
775            _ => {
776                // Convert conditioning states to indices.
777                let z = z.iter().map(|&z| z as usize);
778                // Ravel the conditioning variables.
779                self.conditioning_multi_index.ravel(z)
780            }
781        };
782
783        // Get the probability.
784        self.parameters[[z, x]]
785    }
786
787    fn sample<R: Rng>(&self, rng: &mut R, z: &Self::Support) -> Self::Support {
788        // Get number of variables.
789        let n = self.labels.len();
790        // Get number of conditioning variables.
791        let m = self.conditioning_labels.len();
792
793        // Assert Z matches number of conditioning variables.
794        assert_eq!(
795            z.len(),
796            m,
797            "Vector Z must match number of conditioning variables: \n\
798            \t expected:    |Z| == {} , \n\
799            \t found:       |Z| == {} .",
800            m,
801            z.len(),
802        );
803
804        // No variables.
805        if n == 0 {
806            return array![];
807        }
808
809        // Convert conditioning states to indices.
810        let z = match m {
811            // ... no conditioning variables.
812            0 => 0,
813            // ... one conditioning variable.
814            1 => z[0] as usize,
815            // ... multiple conditioning variables.
816            _ => {
817                // Convert conditioning states to indices.
818                let z = z.iter().map(|&z| z as usize);
819                // Ravel the conditioning variables.
820                self.conditioning_multi_index.ravel(z)
821            }
822        };
823
824        // Get the distribution of the vertex.
825        let p = self.parameters.row(z);
826        // Construct the sampler.
827        let s = WeightedIndex::new(&p).unwrap();
828        // Sample from the distribution.
829        let x = s.sample(rng);
830
831        // Convert indices to states.
832        match n {
833            // ... one variable.
834            1 => array![x as CatType],
835            // ... multiple variables.
836            _ => {
837                // Unravel the sample.
838                let x = self.multi_index.unravel(x);
839                // Convert indices to states.
840                let x = x.iter().map(|&x| x as CatType);
841                // Return the sample.
842                x.collect()
843            }
844        }
845    }
846}
847
848impl Display for CatCPD {
849    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
850        // FIXME: This assumes `x` has a single element.
851        assert_eq!(self.labels().len(), 1);
852
853        // Determine the maximum width for formatting based on the labels and states.
854        let n = std::iter::once(&self.labels()[0])
855            .chain(&self.states()[0])
856            .chain(self.conditioning_labels())
857            .chain(self.conditioning_states().values().flatten())
858            .map(|x| x.len())
859            .max()
860            .unwrap_or(0)
861            .max(8);
862        // Get the number of variables to condition on.
863        let z = self.conditioning_shape().len();
864        // Get the number of states for the first variable.
865        let s = self.shape()[0];
866
867        // Create a horizontal line for table formatting.
868        let hline = "-".repeat((n + 3) * (z + s) + 1);
869        writeln!(f, "{hline}")?;
870
871        // Create the header row for the table.
872        let header = std::iter::repeat_n("", z) // Empty columns for the conditioning variables.
873            .chain([self.labels()[0].as_str()]) // Label for the first variable.
874            .chain(std::iter::repeat_n("", s.saturating_sub(1))) // Empty columns for remaining states.
875            .map(|x| format!("{x:n$}")) // Format each column with fixed width.
876            .join(" | ");
877        writeln!(f, "| {header} |")?;
878
879        // Create a separator row for the table.
880        let separator = std::iter::repeat_n("-".repeat(n), z + s).join(" | ");
881        writeln!(f, "| {separator} |")?;
882
883        // Create the second header row with labels and states.
884        let header = self
885            .conditioning_labels()
886            .iter()
887            .chain(&self.states()[0]) // Include states of the first variable.
888            .map(|x| format!("{x:n$}")) // Format each column with fixed width.
889            .join(" | ");
890        writeln!(f, "| {header} |")?;
891        writeln!(f, "| {separator} |")?;
892
893        // Iterate over the Cartesian product of states and parameter rows.
894        for (states, values) in self
895            .conditioning_states()
896            .values()
897            .multi_cartesian_product()
898            .zip(self.parameters().rows())
899        {
900            // Format the states for the current row.
901            let states = states.iter().map(|x| format!("{x:n$}"));
902            // Format the parameter values for the current row.
903            let values = values.iter().map(|x| format!("{x:n$.6}"));
904            // Join the states and values for the current row.
905            let states_values = states.chain(values).join(" | ");
906            writeln!(f, "| {states_values} |")?;
907        }
908
909        // Write the closing horizontal line for the table.
910        writeln!(f, "{hline}")?;
911
912        Ok(())
913    }
914}
915
916impl Serialize for CatCPD {
917    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
918    where
919        S: Serializer,
920    {
921        // Count the elements to serialize.
922        let mut size = 4;
923        // Add optional fields, if any.
924        size += self.sample_statistics.is_some() as usize;
925        size += self.sample_log_likelihood.is_some() as usize;
926        // Allocate the map.
927        let mut map = serializer.serialize_map(Some(size))?;
928
929        // Serialize states.
930        map.serialize_entry("states", &self.states)?;
931        // Serialize conditioning states.
932        map.serialize_entry("conditioning_states", &self.conditioning_states)?;
933
934        // Convert parameters to a flat format.
935        let parameters: Vec<Vec<f64>> = self
936            .parameters
937            .rows()
938            .into_iter()
939            .map(|x| x.to_vec())
940            .collect();
941        // Serialize parameters.
942        map.serialize_entry("parameters", &parameters)?;
943
944        // Serialize the sufficient statistics, if any.
945        if let Some(sample_statistics) = &self.sample_statistics {
946            map.serialize_entry("sample_statistics", sample_statistics)?;
947        }
948        // Serialize the sample log-likelihood, if any.
949        if let Some(sample_log_likelihood) = &self.sample_log_likelihood {
950            map.serialize_entry("sample_log_likelihood", sample_log_likelihood)?;
951        }
952
953        // Serialize type.
954        map.serialize_entry("type", "catcpd")?;
955
956        // Finalize the map serialization.
957        map.end()
958    }
959}
960
961impl<'de> Deserialize<'de> for CatCPD {
962    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
963    where
964        D: Deserializer<'de>,
965    {
966        #[derive(Deserialize)]
967        #[serde(field_identifier, rename_all = "snake_case")]
968        enum Field {
969            States,
970            ConditioningStates,
971            Parameters,
972            SampleStatistics,
973            SampleLogLikelihood,
974            Type,
975        }
976
977        struct CatCPDVisitor;
978
979        impl<'de> Visitor<'de> for CatCPDVisitor {
980            type Value = CatCPD;
981
982            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
983                formatter.write_str("struct CatCPD")
984            }
985
986            fn visit_map<V>(self, mut map: V) -> Result<CatCPD, V::Error>
987            where
988                V: MapAccess<'de>,
989            {
990                use serde::de::Error as E;
991
992                // Allocate fields
993                let mut states = None;
994                let mut conditioning_states = None;
995                let mut parameters = None;
996                let mut sample_statistics = None;
997                let mut sample_log_likelihood = None;
998                let mut type_ = None;
999
1000                // Parse the map.
1001                while let Some(key) = map.next_key()? {
1002                    match key {
1003                        Field::States => {
1004                            if states.is_some() {
1005                                return Err(E::duplicate_field("states"));
1006                            }
1007                            states = Some(map.next_value()?);
1008                        }
1009                        Field::ConditioningStates => {
1010                            if conditioning_states.is_some() {
1011                                return Err(E::duplicate_field("conditioning_states"));
1012                            }
1013                            conditioning_states = Some(map.next_value()?);
1014                        }
1015                        Field::Parameters => {
1016                            if parameters.is_some() {
1017                                return Err(E::duplicate_field("parameters"));
1018                            }
1019                            parameters = Some(map.next_value()?);
1020                        }
1021                        Field::SampleStatistics => {
1022                            if sample_statistics.is_some() {
1023                                return Err(E::duplicate_field("sample_statistics"));
1024                            }
1025                            sample_statistics = Some(map.next_value()?);
1026                        }
1027                        Field::SampleLogLikelihood => {
1028                            if sample_log_likelihood.is_some() {
1029                                return Err(E::duplicate_field("sample_log_likelihood"));
1030                            }
1031                            sample_log_likelihood = Some(map.next_value()?);
1032                        }
1033                        Field::Type => {
1034                            if type_.is_some() {
1035                                return Err(E::duplicate_field("type"));
1036                            }
1037                            type_ = Some(map.next_value()?);
1038                        }
1039                    }
1040                }
1041
1042                // Check required fields.
1043                let states = states.ok_or_else(|| E::missing_field("states"))?;
1044                let conditioning_states =
1045                    conditioning_states.ok_or_else(|| E::missing_field("conditioning_states"))?;
1046                let parameters = parameters.ok_or_else(|| E::missing_field("parameters"))?;
1047
1048                // Assert type is correct.
1049                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
1050                assert_eq!(type_, "catcpd", "Invalid type for CatCPD.");
1051
1052                // Convert parameters to ndarray.
1053                let parameters: Vec<Vec<f64>> = parameters;
1054                let shape = (parameters.len(), parameters[0].len());
1055                let parameters = Array::from_iter(parameters.into_iter().flatten())
1056                    .into_shape_with_order(shape)
1057                    .map_err(|_| E::custom("Invalid parameters shape"))?;
1058
1059                Ok(CatCPD::with_optionals(
1060                    states,
1061                    conditioning_states,
1062                    parameters,
1063                    sample_statistics,
1064                    sample_log_likelihood,
1065                ))
1066            }
1067        }
1068
1069        const FIELDS: &[&str] = &[
1070            "states",
1071            "conditioning_states",
1072            "parameters",
1073            "sample_statistics",
1074            "sample_log_likelihood",
1075            "type",
1076        ];
1077
1078        deserializer.deserialize_struct("CatCPD", FIELDS, CatCPDVisitor)
1079    }
1080}
1081
// Implement `JsonIO` for `CatCPD`.
// NOTE(review): `impl_json_io!` is a crate-local macro whose expansion is not
// visible here; presumably it builds on the `Serialize` / `Deserialize`
// implementations of `CatCPD` -- confirm against the macro definition.
impl_json_io!(CatCPD);