causal_hub/models/bayesian_network/gaussian/
parameters.rs

1use approx::{AbsDiffEq, RelativeEq};
2use ndarray::prelude::*;
3use ndarray_linalg::{CholeskyInto, Determinant, UPLO};
4use rand::Rng;
5use rand_distr::{Distribution, StandardNormal};
6use serde::{
7    Deserialize, Deserializer, Serialize, Serializer,
8    de::{MapAccess, Visitor},
9    ser::SerializeMap,
10};
11
12use crate::{
13    datasets::GaussSample,
14    impl_json_io,
15    models::{CPD, GaussCPDS, GaussPhi, Labelled, Phi},
16    types::{EPSILON, LN_2_PI, Labels, Set},
17    utils::PseudoInverse,
18};
19
20/// Parameters of a Gaussian CPD.
21#[derive(Clone, Debug)]
22pub struct GaussCPDP {
23    /// Coefficient matrix |X| x |Z|.
24    a: Array2<f64>,
25    /// Intercept vector |X|.
26    b: Array1<f64>,
27    /// Covariance matrix |X| x |X|.
28    s: Array2<f64>,
29}
30
31impl GaussCPDP {
32    /// Creates a new `GaussCPDP` instance.
33    ///
34    /// # Arguments
35    ///
36    /// * `a` - Coefficient matrix |X| x |Z|.
37    /// * `b` - Intercept vector |X|.
38    /// * `s` - Covariance matrix |X| x |X|.
39    ///
40    /// # Panics
41    ///
42    /// * Panics if the number of rows of `a` does not match the size of `b`.
43    /// * Panics if the number of rows of `a` does not match the size of `s`.
44    /// * Panics if `s` is not square and symmetric.
45    /// * Panics if any of the values in `a`, `b`, or `s` are not finite.
46    ///
47    /// # Returns
48    ///
49    /// A new `GaussCPDP` instance.
50    ///
51    pub fn new(a: Array2<f64>, b: Array1<f64>, s: Array2<f64>) -> Self {
52        // Assert the dimensions are correct.
53        assert_eq!(
54            a.nrows(),
55            b.len(),
56            "Coefficient matrix rows must match intercept vector size."
57        );
58        assert_eq!(
59            a.nrows(),
60            s.nrows(),
61            "Coefficient matrix rows must match covariance matrix size."
62        );
63        assert!(s.is_square(), "Covariance matrix must be square.");
64        // Assert values are finite.
65        assert!(
66            a.iter().all(|&x| x.is_finite()),
67            "Coefficient matrix must have finite values."
68        );
69        assert!(
70            b.iter().all(|&x| x.is_finite()),
71            "Intercept vector must have finite values."
72        );
73        assert!(
74            s.iter().all(|&x| x.is_finite()),
75            "Covariance matrix must have finite values."
76        );
77        assert_eq!(s, s.t(), "Covariance matrix must be symmetric.");
78
79        Self { a, b, s }
80    }
81
82    /// Returns the coefficient matrix |X| x |Z|.
83    ///
84    /// # Returns
85    ///
86    /// A reference to the coefficient matrix.
87    ///
88    #[inline]
89    pub const fn coefficients(&self) -> &Array2<f64> {
90        &self.a
91    }
92
93    /// Returns the intercept vector |X|.
94    ///
95    /// # Returns
96    ///
97    /// A reference to the intercept vector.
98    ///
99    #[inline]
100    pub const fn intercept(&self) -> &Array1<f64> {
101        &self.b
102    }
103
104    /// Returns the covariance matrix |X| x |X|.
105    ///
106    /// # Returns
107    ///
108    /// A reference to the covariance matrix.
109    ///
110    #[inline]
111    pub const fn covariance(&self) -> &Array2<f64> {
112        &self.s
113    }
114}
115
116impl PartialEq for GaussCPDP {
117    fn eq(&self, other: &Self) -> bool {
118        self.a.eq(&other.a) && self.b.eq(&other.b) && self.s.eq(&other.s)
119    }
120}
121
122impl AbsDiffEq for GaussCPDP {
123    type Epsilon = f64;
124
125    fn default_epsilon() -> Self::Epsilon {
126        Self::Epsilon::default_epsilon()
127    }
128
129    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
130        self.a.abs_diff_eq(&other.a, epsilon)
131            && self.b.abs_diff_eq(&other.b, epsilon)
132            && self.s.abs_diff_eq(&other.s, epsilon)
133    }
134}
135
136impl RelativeEq for GaussCPDP {
137    fn default_max_relative() -> Self::Epsilon {
138        Self::Epsilon::default_max_relative()
139    }
140
141    fn relative_eq(
142        &self,
143        other: &Self,
144        epsilon: Self::Epsilon,
145        max_relative: Self::Epsilon,
146    ) -> bool {
147        self.a.relative_eq(&other.a, epsilon, max_relative)
148            && self.b.relative_eq(&other.b, epsilon, max_relative)
149            && self.s.relative_eq(&other.s, epsilon, max_relative)
150    }
151}
152
153impl Serialize for GaussCPDP {
154    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
155    where
156        S: Serializer,
157    {
158        // Allocate the map.
159        let mut map = serializer.serialize_map(Some(3))?;
160
161        // Convert the coefficient matrix to a flat format.
162        let coefficients: Vec<_> = self.a.rows().into_iter().map(|x| x.to_vec()).collect();
163        // Serialize coefficients.
164        map.serialize_entry("coefficients", &coefficients)?;
165
166        // Convert the intercept vector to a flat format.
167        let intercept = self.b.to_vec();
168        // Serialize intercept.
169        map.serialize_entry("intercept", &intercept)?;
170
171        // Convert the covariance matrix to a flat format.
172        let covariance: Vec<_> = self.s.rows().into_iter().map(|x| x.to_vec()).collect();
173        // Serialize covariance.
174        map.serialize_entry("covariance", &covariance)?;
175
176        // End the map.
177        map.end()
178    }
179}
180
181impl<'de> Deserialize<'de> for GaussCPDP {
182    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
183    where
184        D: Deserializer<'de>,
185    {
186        #[derive(Deserialize)]
187        #[serde(field_identifier, rename_all = "snake_case")]
188        enum Field {
189            Coefficients,
190            Intercept,
191            Covariance,
192        }
193
194        struct GaussCPDPVisitor;
195
196        impl<'de> Visitor<'de> for GaussCPDPVisitor {
197            type Value = GaussCPDP;
198
199            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
200                formatter.write_str("struct GaussCPDP")
201            }
202
203            fn visit_map<V>(self, mut map: V) -> Result<GaussCPDP, V::Error>
204            where
205                V: MapAccess<'de>,
206            {
207                use serde::de::Error as E;
208
209                // Allocate the fields.
210                let mut coefficients = None;
211                let mut intercept = None;
212                let mut covariance = None;
213
214                // Parse the map.
215                while let Some(key) = map.next_key()? {
216                    match key {
217                        Field::Coefficients => {
218                            if coefficients.is_some() {
219                                return Err(E::duplicate_field("coefficients"));
220                            }
221                            coefficients = Some(map.next_value()?);
222                        }
223                        Field::Intercept => {
224                            if intercept.is_some() {
225                                return Err(E::duplicate_field("intercept"));
226                            }
227                            intercept = Some(map.next_value()?);
228                        }
229                        Field::Covariance => {
230                            if covariance.is_some() {
231                                return Err(E::duplicate_field("covariance"));
232                            }
233                            covariance = Some(map.next_value()?);
234                        }
235                    }
236                }
237
238                // Extract the fields.
239                let coefficients = coefficients.ok_or_else(|| E::missing_field("coefficients"))?;
240                let intercept = intercept.ok_or_else(|| E::missing_field("intercept"))?;
241                let covariance = covariance.ok_or_else(|| E::missing_field("covariance"))?;
242
243                // Convert coefficients to array.
244                let coefficients = {
245                    let values: Vec<Vec<f64>> = coefficients;
246                    let shape = (values.len(), values[0].len());
247                    Array::from_iter(values.into_iter().flatten())
248                        .into_shape_with_order(shape)
249                        .map_err(|_| E::custom("Invalid coefficients shape"))?
250                };
251                // Convert intercept to array.
252                let intercept = Array1::from_vec(intercept);
253                // Convert covariance to array.
254                let covariance = {
255                    let values: Vec<Vec<f64>> = covariance;
256                    let shape = (values.len(), values[0].len());
257                    Array::from_iter(values.into_iter().flatten())
258                        .into_shape_with_order(shape)
259                        .map_err(|_| E::custom("Invalid covariance shape"))?
260                };
261
262                Ok(GaussCPDP::new(coefficients, intercept, covariance))
263            }
264        }
265
266        const FIELDS: &[&str] = &["coefficients", "intercept", "covariance"];
267
268        deserializer.deserialize_struct("GaussCPDP", FIELDS, GaussCPDPVisitor)
269    }
270}
271
272/// A Gaussian CPD.
273#[derive(Clone, Debug)]
274pub struct GaussCPD {
275    // Labels of the variables.
276    labels: Labels,
277    // Labels of the conditioning variables.
278    conditioning_labels: Labels,
279    // Parameters.
280    parameters: GaussCPDP,
281    // Sample (sufficient) statistics, if any.
282    sample_statistics: Option<GaussCPDS>,
283    // Sample log-likelihood, if any.
284    sample_log_likelihood: Option<f64>,
285}
286
287impl GaussCPD {
288    /// Creates a new Gaussian CPD instance.
289    ///
290    /// # Arguments
291    ///
292    /// * `labels` - Labels of the variables.
293    /// * `conditioning_labels` - Labels of the conditioning variables.
294    /// * `parameters` - Parameters of the CPD.
295    ///
296    /// # Panics
297    ///
298    /// * Panics if `labels` and `conditioning_labels` are not disjoint.
299    /// * Panics if the dimensions of `parameters` do not match the lengths of `labels` and `conditioning_labels`.
300    ///
301    /// # Returns
302    ///
303    /// A new Gaussian CPD instance.
304    ///
305    pub fn new(
306        mut labels: Labels,
307        mut conditioning_labels: Labels,
308        mut parameters: GaussCPDP,
309    ) -> Self {
310        // Assert labels and conditioning labels are disjoint.
311        assert!(
312            labels.is_disjoint(&conditioning_labels),
313            "Labels and conditioning labels must be disjoint."
314        );
315        // Assert parameters dimensions match labels and conditioning labels lengths.
316        assert_eq!(
317            parameters.a.nrows(),
318            labels.len(),
319            "Coefficient matrix rows must match labels length."
320        );
321        assert_eq!(
322            parameters.a.ncols(),
323            conditioning_labels.len(),
324            "Coefficient matrix columns must match conditioning labels length."
325        );
326        assert_eq!(
327            parameters.b.len(),
328            labels.len(),
329            "Intercept vector size must match labels length."
330        );
331        assert_eq!(
332            parameters.s.nrows(),
333            labels.len(),
334            "Covariance matrix rows must match labels length."
335        );
336        assert_eq!(
337            parameters.s.ncols(),
338            labels.len(),
339            "Covariance matrix columns must match labels length."
340        );
341
342        // Check if labels are sorted.
343        if !labels.is_sorted() {
344            // Allocate indices to sort labels.
345            let mut indices: Vec<usize> = (0..labels.len()).collect();
346            // Sort the indices by labels.
347            indices.sort_by_key(|&i| &labels[i]);
348            // Sort the labels.
349            labels.sort();
350            // Reorder the parameters.
351            let mut new_a = parameters.a.clone();
352            let mut new_b = parameters.b.clone();
353            let mut new_s = parameters.s.clone();
354            // Reorder rows of A.
355            for (i, &j) in indices.iter().enumerate() {
356                new_a.row_mut(i).assign(&parameters.a.row(j));
357            }
358            // Reorder b.
359            for (i, &j) in indices.iter().enumerate() {
360                new_b[i] = parameters.b[j];
361            }
362            // Reorder rows of S.
363            for (i, &j) in indices.iter().enumerate() {
364                new_s.row_mut(i).assign(&parameters.s.row(j));
365            }
366            // Allocate a temporary copy of S to reorder columns.
367            let _s = new_s.clone();
368            // Reorder columns of S.
369            for (i, &j) in indices.iter().enumerate() {
370                new_s.column_mut(i).assign(&_s.column(j));
371            }
372            // Update parameters.
373            parameters.a = new_a;
374            parameters.b = new_b;
375            parameters.s = new_s;
376        }
377
378        // Check if conditioning labels are sorted.
379        if !conditioning_labels.is_sorted() {
380            // Allocate indices to sort conditioning labels.
381            let mut indices: Vec<usize> = (0..conditioning_labels.len()).collect();
382            // Sort the indices by conditioning labels.
383            indices.sort_by_key(|&i| &conditioning_labels[i]);
384            // Sort the conditioning labels.
385            conditioning_labels.sort();
386            // Reorder the parameters.
387            let mut new_a = parameters.a.clone();
388            // Reorder columns of A.
389            for (i, &j) in indices.iter().enumerate() {
390                new_a.column_mut(i).assign(&parameters.a.column(j));
391            }
392            // Update parameters.
393            parameters.a = new_a;
394        }
395
396        Self {
397            labels,
398            conditioning_labels,
399            parameters,
400            sample_statistics: None,
401            sample_log_likelihood: None,
402        }
403    }
404
405    /// Marginalizes the over the variables `X` and conditioning variables `Z`.
406    ///
407    /// # Arguments
408    ///
409    /// * `x` - The variables to marginalize over.
410    /// * `z` - The conditioning variables to marginalize over.
411    ///
412    /// # Returns
413    ///
414    /// A new instance with the marginalized variables.
415    ///
416    pub fn marginalize(&self, x: &Set<usize>, z: &Set<usize>) -> Self {
417        // Base case: if no variables to marginalize, return self clone.
418        if x.is_empty() && z.is_empty() {
419            return self.clone();
420        }
421        // Get labels.
422        let labels_x = self.labels();
423        let labels_z = self.conditioning_labels();
424        // Get indices to preserve.
425        let not_x = (0..labels_x.len()).filter(|i| !x.contains(i)).collect();
426        let not_z = (0..labels_z.len()).filter(|i| !z.contains(i)).collect();
427        // Convert to potential.
428        let phi = self.clone().into_phi();
429        // Map CPD indices to potential indices.
430        let x = phi.indices_from(x, labels_x);
431        let z = phi.indices_from(z, labels_z);
432        // Marginalize the potential.
433        let phi = phi.marginalize(&(&x | &z));
434        // Map CPD indices to potential indices.
435        let not_x = phi.indices_from(&not_x, labels_x);
436        let not_z = phi.indices_from(&not_z, labels_z);
437        // Convert back to CPD.
438        phi.into_cpd(&not_x, &not_z)
439    }
440
441    /// Creates a new Gaussian CPD instance.
442    ///
443    /// # Arguments
444    ///
445    /// * `labels` - Labels of the variables.
446    /// * `conditioning_labels` - Labels of the conditioning variables.
447    /// * `parameters` - Parameters of the CPD.
448    /// * `sample_statistics` - Sample (sufficient) statistics, if any.
449    /// * `sample_log_likelihood` - Sample log-likelihood, if any.
450    ///
451    /// # Returns
452    ///
453    /// A new Gaussian CPD instance.
454    ///
455    pub fn with_optionals(
456        labels: Labels,
457        conditioning_labels: Labels,
458        parameters: GaussCPDP,
459        sample_statistics: Option<GaussCPDS>,
460        sample_log_likelihood: Option<f64>,
461    ) -> Self {
462        // FIXME: Check inputs.
463
464        // Create the CPD.
465        let mut cpd = Self::new(labels, conditioning_labels, parameters);
466
467        // FIXME: Check labels alignment with optional fields.
468
469        // Set the optional fields.
470        cpd.sample_statistics = sample_statistics;
471        cpd.sample_log_likelihood = sample_log_likelihood;
472
473        cpd
474    }
475
476    /// Converts a potential \phi(X \cup Z) to a CPD P(X | Z).
477    ///
478    /// # Arguments
479    ///
480    /// * `x` - The set of variables.
481    /// * `z` - The set of conditioning variables.
482    ///
483    /// # Returns
484    ///
485    /// The corresponding CPD.
486    ///
487    #[inline]
488    pub fn from_phi(phi: GaussPhi, x: &Set<usize>, z: &Set<usize>) -> Self {
489        phi.into_cpd(x, z)
490    }
491
492    /// Converts a CPD P(X | Z) to a potential \phi(X \cup Z).
493    ///
494    /// # Arguments
495    ///
496    /// * `cpd` - The CPD to convert.
497    ///
498    /// # Returns
499    ///
500    /// The corresponding potential.
501    ///
502    #[inline]
503    pub fn into_phi(self) -> GaussPhi {
504        GaussPhi::from_cpd(self)
505    }
506}
507
508impl Labelled for GaussCPD {
509    #[inline]
510    fn labels(&self) -> &Labels {
511        &self.labels
512    }
513}
514
515impl PartialEq for GaussCPD {
516    fn eq(&self, other: &Self) -> bool {
517        self.labels.eq(&other.labels)
518            && self.conditioning_labels.eq(&other.conditioning_labels)
519            && self.parameters.eq(&other.parameters)
520    }
521}
522
523impl AbsDiffEq for GaussCPD {
524    type Epsilon = f64;
525
526    fn default_epsilon() -> Self::Epsilon {
527        Self::Epsilon::default_epsilon()
528    }
529
530    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
531        self.labels.eq(&other.labels)
532            && self.conditioning_labels.eq(&other.conditioning_labels)
533            && self.parameters.abs_diff_eq(&other.parameters, epsilon)
534    }
535}
536
537impl RelativeEq for GaussCPD {
538    fn default_max_relative() -> Self::Epsilon {
539        Self::Epsilon::default_max_relative()
540    }
541
542    fn relative_eq(
543        &self,
544        other: &Self,
545        epsilon: Self::Epsilon,
546        max_relative: Self::Epsilon,
547    ) -> bool {
548        self.labels.eq(&other.labels)
549            && self.conditioning_labels.eq(&other.conditioning_labels)
550            && self
551                .parameters
552                .relative_eq(&other.parameters, epsilon, max_relative)
553    }
554}
555
556impl CPD for GaussCPD {
557    type Support = GaussSample;
558    type Parameters = GaussCPDP;
559    type Statistics = GaussCPDS;
560
561    #[inline]
562    fn conditioning_labels(&self) -> &Labels {
563        &self.conditioning_labels
564    }
565
566    #[inline]
567    fn parameters(&self) -> &Self::Parameters {
568        &self.parameters
569    }
570
571    #[inline]
572    fn parameters_size(&self) -> usize {
573        let s = {
574            // Covariance matrix is symmetric.
575            let s = self.parameters.s.nrows();
576            s * (s + 1) / 2
577        };
578
579        self.parameters.a.len() + self.parameters.b.len() + s
580    }
581
582    #[inline]
583    fn sample_statistics(&self) -> Option<&Self::Statistics> {
584        self.sample_statistics.as_ref()
585    }
586
587    #[inline]
588    fn sample_log_likelihood(&self) -> Option<f64> {
589        self.sample_log_likelihood
590    }
591
592    fn pf(&self, x: &Self::Support, z: &Self::Support) -> f64 {
593        // Get number of variables.
594        let n = self.labels.len();
595        // Get number of conditioning variables.
596        let m = self.conditioning_labels.len();
597
598        // Assert X matches number of variables.
599        assert_eq!(
600            x.len(),
601            n,
602            "Vector X must match number of variables: \n\
603            \t expected:    |X| == {} , \n\
604            \t found:       |X| == {} .",
605            n,
606            x.len(),
607        );
608        // Assert Z matches number of conditioning variables.
609        assert_eq!(
610            z.len(),
611            m,
612            "Vector Z must match number of conditioning variables: \n\
613            \t expected:    |Z| == {} , \n\
614            \t found:       |Z| == {} .",
615            m,
616            z.len(),
617        );
618
619        // Get parameters.
620        let (a, b, s) = (
621            self.parameters.coefficients(),
622            self.parameters.intercept(),
623            self.parameters.covariance(),
624        );
625
626        // No variables.
627        if n == 0 {
628            return 1.;
629        }
630
631        // One variable ...
632        if n == 1 {
633            // Compute the mean.
634            let mu = match m {
635                // ... no conditioning variables.
636                0 => b[0], // Get the mean.
637                // ... one conditioning variable.
638                1 => f64::mul_add(a[[0, 0]], z[0], b[0]), // Compute the mean.
639                // ... multiple conditioning variables.
640                _ => (a.dot(z) + b)[0], // Compute mean vector.
641            };
642            // Compute deviation from mean.
643            let x_mu = x[0] - mu;
644            // Get the variance.
645            let k = s[[0, 0]];
646            // Compute log probability density function.
647            let ln_pf = -0.5 * (LN_2_PI + f64::ln(k) + f64::powi(x_mu, 2) / k);
648            // Return probability density function.
649            return f64::exp(ln_pf);
650        }
651
652        // Multiple variables, multiple conditioning variables.
653
654        // Compute mean vector.
655        let mu = a.dot(z) + b;
656        // Compute deviation from mean.
657        let x_mu = x - mu;
658        // Compute precision matrix.
659        let k = s.pinv();
660        // Compute log probability density function.
661        let n_ln_2_pi = s.nrows() as f64 * LN_2_PI;
662        let (_, ln_det) = s.sln_det().expect("Failed to compute the determinant.");
663        let ln_pf = -0.5 * (n_ln_2_pi + ln_det + x_mu.dot(&k).dot(&x_mu));
664        // Return probability density function.
665        f64::exp(ln_pf)
666    }
667
668    fn sample<R: Rng>(&self, rng: &mut R, z: &Self::Support) -> Self::Support {
669        // Get number of variables.
670        let n = self.labels.len();
671        // Get number of conditioning variables.
672        let m = self.conditioning_labels.len();
673
674        // Assert Z matches number of conditioning variables.
675        assert_eq!(
676            z.len(),
677            m,
678            "Vector Z must match number of conditioning variables: \n\
679            \t expected:    |Z| == {} , \n\
680            \t found:       |Z| == {} .",
681            m,
682            z.len(),
683        );
684
685        // Get parameters.
686        let (a, b, s) = (
687            self.parameters.coefficients(),
688            self.parameters.intercept(),
689            self.parameters.covariance(),
690        );
691
692        // No variables.
693        if n == 0 {
694            return array![];
695        }
696
697        // One variable ...
698        if n == 1 {
699            // Compute the mean.
700            let mu = match m {
701                // ... no conditioning variables.
702                0 => b[0], // Get the mean.
703                // ... one conditioning variable.
704                1 => f64::mul_add(a[[0, 0]], z[0], b[0]), // Compute the mean.
705                // ... multiple conditioning variables.
706                _ => (a.dot(z) + b)[0], // Compute mean vector.
707            };
708            // Sample from standard normal.
709            let e: f64 = StandardNormal.sample(rng);
710            // Compute the sample.
711            let x = f64::mul_add(s[[0, 0]].sqrt(), e, mu);
712            // Return the sample.
713            return array![x];
714        }
715
716        // Multiple variables, multiple conditioning variables.
717
718        // Compute the mean.
719        let mu = a.dot(z) + b;
720        // Compute the Cholesky decomposition of the covariance matrix.
721        let l = (s + EPSILON * Array::eye(s.nrows()))
722            .cholesky_into(UPLO::Lower)
723            .expect("Failed to compute Cholesky decomposition.");
724        // Sample from standard normal.
725        let e = StandardNormal
726            .sample_iter(rng)
727            .take(s.nrows())
728            .collect::<Array1<_>>();
729        // Compute the sample.
730        l.dot(&e) + mu
731    }
732}
733
734impl Serialize for GaussCPD {
735    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
736    where
737        S: Serializer,
738    {
739        // Count the elements to serialize.
740        let mut size = 4;
741        // Add optional fields, if any.
742        size += self.sample_statistics.is_some() as usize;
743        size += self.sample_log_likelihood.is_some() as usize;
744        // Allocate the map.
745        let mut map = serializer.serialize_map(Some(size))?;
746
747        // Serialize labels.
748        map.serialize_entry("labels", &self.labels)?;
749        // Serialize conditioning labels.
750        map.serialize_entry("conditioning_labels", &self.conditioning_labels)?;
751        // Serialize parameters.
752        map.serialize_entry("parameters", &self.parameters)?;
753
754        // Serialize sample statistics, if any.
755        if let Some(sample_statistics) = &self.sample_statistics {
756            map.serialize_entry("sample_statistics", sample_statistics)?;
757        }
758
759        // Serialize sample log-likelihood, if any.
760        if let Some(sample_log_likelihood) = &self.sample_log_likelihood {
761            map.serialize_entry("sample_log_likelihood", sample_log_likelihood)?;
762        }
763
764        // Serialize type.
765        map.serialize_entry("type", "gausscpd")?;
766
767        // End the map.
768        map.end()
769    }
770}
771
772impl<'de> Deserialize<'de> for GaussCPD {
773    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
774    where
775        D: Deserializer<'de>,
776    {
777        #[derive(Deserialize)]
778        #[serde(field_identifier, rename_all = "snake_case")]
779        enum Field {
780            Labels,
781            ConditioningLabels,
782            Parameters,
783            SampleStatistics,
784            SampleLogLikelihood,
785            Type,
786        }
787
788        struct GaussCPDVisitor;
789
790        impl<'de> Visitor<'de> for GaussCPDVisitor {
791            type Value = GaussCPD;
792
793            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
794                formatter.write_str("struct GaussCPD")
795            }
796
797            fn visit_map<V>(self, mut map: V) -> Result<GaussCPD, V::Error>
798            where
799                V: MapAccess<'de>,
800            {
801                use serde::de::Error as E;
802
803                // Allocate the fields.
804                let mut labels = None;
805                let mut conditioning_labels = None;
806                let mut parameters = None;
807                let mut sample_statistics = None;
808                let mut sample_log_likelihood = None;
809                let mut type_ = None;
810
811                // Parse the map.
812                while let Some(key) = map.next_key()? {
813                    match key {
814                        Field::Labels => {
815                            if labels.is_some() {
816                                return Err(E::duplicate_field("labels"));
817                            }
818                            labels = Some(map.next_value()?);
819                        }
820                        Field::ConditioningLabels => {
821                            if conditioning_labels.is_some() {
822                                return Err(E::duplicate_field("conditioning_labels"));
823                            }
824                            conditioning_labels = Some(map.next_value()?);
825                        }
826                        Field::Parameters => {
827                            if parameters.is_some() {
828                                return Err(E::duplicate_field("parameters"));
829                            }
830                            parameters = Some(map.next_value()?);
831                        }
832                        Field::SampleStatistics => {
833                            if sample_statistics.is_some() {
834                                return Err(E::duplicate_field("sample_statistics"));
835                            }
836                            sample_statistics = Some(map.next_value()?);
837                        }
838                        Field::SampleLogLikelihood => {
839                            if sample_log_likelihood.is_some() {
840                                return Err(E::duplicate_field("sample_log_likelihood"));
841                            }
842                            sample_log_likelihood = Some(map.next_value()?);
843                        }
844                        Field::Type => {
845                            if type_.is_some() {
846                                return Err(E::duplicate_field("type"));
847                            }
848                            type_ = Some(map.next_value()?);
849                        }
850                    }
851                }
852
853                // Extract the fields.
854                let labels = labels.ok_or_else(|| E::missing_field("labels"))?;
855                let conditioning_labels =
856                    conditioning_labels.ok_or_else(|| E::missing_field("conditioning_labels"))?;
857                let parameters = parameters.ok_or_else(|| E::missing_field("parameters"))?;
858
859                // Assert type is correct.
860                let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
861                assert_eq!(type_, "gausscpd", "Invalid type for GaussCPD.");
862
863                Ok(GaussCPD::with_optionals(
864                    labels,
865                    conditioning_labels,
866                    parameters,
867                    sample_statistics,
868                    sample_log_likelihood,
869                ))
870            }
871        }
872
873        const FIELDS: &[&str] = &[
874            "labels",
875            "conditioning_labels",
876            "parameters",
877            "sample_statistics",
878            "sample_log_likelihood",
879            "type",
880        ];
881
882        deserializer.deserialize_struct("GaussCPD", FIELDS, GaussCPDVisitor)
883    }
884}
885
// Implement `JsonIO` for `GaussCPD` via the crate's `impl_json_io!` macro.
// NOTE(review): presumably this builds on the `Serialize`/`Deserialize` impls
// in this file; confirm in the macro definition.
impl_json_io!(GaussCPD);