causal_hub/models/bayesian_network/gaussian/sufficient_statistics.rs

1use std::ops::{Add, AddAssign};
2
3use ndarray::prelude::*;
4use serde::{
5    Deserialize, Deserializer, Serialize, Serializer,
6    de::{MapAccess, Visitor},
7    ser::SerializeMap,
8};
9
10/// Sample (sufficient) statistics for a Gaussian CPD.
11#[derive(Clone, Debug)]
12pub struct GaussCPDS {
13    /// Response mean vector |X|.
14    mu_x: Array1<f64>,
15    /// Design mean vector |Z|.
16    mu_z: Array1<f64>,
17    /// Response covariance (uncentered) matrix |X| x |X|.
18    m_xx: Array2<f64>,
19    /// Cross-covariance (uncentered) matrix |X| x |Z|.
20    m_xz: Array2<f64>,
21    /// Design covariance (uncentered) matrix |Z| x |Z|.
22    m_zz: Array2<f64>,
23    /// Sample size.
24    n: f64,
25}
26
27impl GaussCPDS {
28    /// Creates a new `GaussCPDS` instance.
29    ///
30    /// # Arguments
31    ///
32    /// * `mu_x` - Response mean vector |X|.
33    /// * `mu_z` - Design mean vector |Z|.
34    /// * `m_xx` - Response covariance (uncentered) matrix |X| x |X|.
35    /// * `m_xz` - Cross-covariance (uncentered) matrix |X| x |Z|.
36    /// * `m_zz` - Design covariance (uncentered) matrix |Z| x |Z|.
37    /// * `n` - Sample size.
38    ///
39    /// # Panics
40    ///
41    /// * Panics if `mu_x` length does not match `m_xx` size.
42    /// * Panics if `mu_z` length does not match `m_zz` size.
43    /// * Panics if `m_xx` is not square.
44    /// * Panics if the number of rows of `m_xz` does not match the size of `m_xx`.
45    /// * Panics if the number of columns of `m_xz` does not match the size of `m_zz`.
46    /// * Panics if `m_zz` is not square.
47    /// * Panics if any of the values in `mu_x`, `mu_z`, `m_xx`, `m_xz`, or `m_zz` are not finite.
48    /// * Panics if `n` is not finite or is negative.
49    ///
50    /// # Returns
51    ///
52    /// A new `GaussCPDS` instance.
53    ///
54    #[inline]
55    pub fn new(
56        mu_x: Array1<f64>,
57        mu_z: Array1<f64>,
58        m_xx: Array2<f64>,
59        m_xz: Array2<f64>,
60        m_zz: Array2<f64>,
61        n: f64,
62    ) -> Self {
63        // Assert the dimensions are correct.
64        assert_eq!(
65            mu_x.len(),
66            m_xx.nrows(),
67            "Response mean vector length must match response covariance matrix size."
68        );
69        assert_eq!(
70            mu_z.len(),
71            m_zz.nrows(),
72            "Design mean vector length must match design covariance matrix size."
73        );
74        assert!(
75            m_xx.is_square(),
76            "Response covariance matrix must be square."
77        );
78        assert_eq!(
79            m_xz.nrows(),
80            m_xx.nrows(),
81            "Cross-covariance matrix must have the same \n\
82            number of rows as the response covariance matrix."
83        );
84        assert_eq!(
85            m_xz.ncols(),
86            m_zz.nrows(),
87            "Cross-covariance matrix must have the same \n\
88            number of columns as the design covariance matrix."
89        );
90        assert!(m_zz.is_square(), "Design covariance matrix must be square.");
91        // Assert values are finite.
92        assert!(
93            mu_x.iter().all(|&x| x.is_finite()),
94            "Response mean vector must have finite values."
95        );
96        assert!(
97            mu_z.iter().all(|&x| x.is_finite()),
98            "Design mean vector must have finite values."
99        );
100        assert!(
101            m_xx.iter().all(|&x| x.is_finite()),
102            "Response covariance matrix must have finite values."
103        );
104        assert!(
105            m_xz.iter().all(|&x| x.is_finite()),
106            "Cross-covariance matrix must have finite values."
107        );
108        assert!(
109            m_zz.iter().all(|&x| x.is_finite()),
110            "Design covariance matrix must have finite values."
111        );
112        assert!(
113            n.is_finite() && n >= 0.0,
114            "Sample size must be non-negative."
115        );
116
117        Self {
118            mu_x,
119            mu_z,
120            m_xx,
121            m_xz,
122            m_zz,
123            n,
124        }
125    }
126
127    /// Returns the response mean vector |X|.
128    ///
129    /// # Returns
130    ///
131    /// A reference to the response mean vector.
132    ///
133    #[inline]
134    pub fn sample_response_mean(&self) -> &Array1<f64> {
135        &self.mu_x
136    }
137
138    /// Returns the design mean vector |Z|.
139    ///
140    /// # Returns
141    ///
142    /// A reference to the design mean vector.
143    ///
144    #[inline]
145    pub fn sample_design_mean(&self) -> &Array1<f64> {
146        &self.mu_z
147    }
148
149    /// Returns the response covariance matrix |X| x |X|.
150    ///
151    /// # Returns
152    ///
153    /// A reference to the response covariance matrix.
154    ///
155    #[inline]
156    pub fn sample_response_covariance(&self) -> Array2<f64> {
157        // Compute the centering factor.
158        let col_mu_x = self.mu_x.view().insert_axis(Axis(1));
159        let row_mu_x = self.mu_x.view().insert_axis(Axis(0));
160        // Apply centering.
161        &self.m_xx - self.n * &col_mu_x.dot(&row_mu_x)
162    }
163
164    /// Returns the cross-covariance matrix |X| x (|Z| + 1).
165    ///
166    /// # Returns
167    ///
168    /// A reference to the cross-covariance matrix.
169    ///
170    #[inline]
171    pub fn sample_cross_covariance(&self) -> Array2<f64> {
172        // Compute the centering factor.
173        let col_mu_x = self.mu_x.view().insert_axis(Axis(1));
174        let row_mu_z = self.mu_z.view().insert_axis(Axis(0));
175        // Apply centering.
176        &self.m_xz - self.n * &col_mu_x.dot(&row_mu_z)
177    }
178
179    /// Returns the design covariance matrix (|Z| + 1) x (|Z| + 1).
180    ///
181    /// # Returns
182    ///
183    /// A reference to the design covariance matrix.
184    ///
185    #[inline]
186    pub fn sample_design_covariance(&self) -> Array2<f64> {
187        // Compute the centering factor.
188        let col_mu_z = self.mu_z.view().insert_axis(Axis(1));
189        let row_mu_z = self.mu_z.view().insert_axis(Axis(0));
190        // Apply centering.
191        &self.m_zz - self.n * &col_mu_z.dot(&row_mu_z)
192    }
193
194    /// Returns the sample size.
195    ///
196    /// # Returns
197    ///
198    /// The sample size.
199    ///
200    #[inline]
201    pub fn sample_size(&self) -> f64 {
202        self.n
203    }
204}
205
206impl AddAssign for GaussCPDS {
207    fn add_assign(&mut self, other: Self) {
208        // Compute the total sample sizes.
209        let n = self.n + other.n;
210        // Update the response mean vector.
211        self.mu_x = (self.n * &self.mu_x + other.n * &other.mu_x) / n;
212        // Update the design mean vector.
213        self.mu_z = (self.n * &self.mu_z + other.n * &other.mu_z) / n;
214        // Update the response covariance matrix.
215        self.m_xx = (self.n * &self.m_xx + other.n * &other.m_xx) / n;
216        // Update the cross-covariance matrix.
217        self.m_xz = (self.n * &self.m_xz + other.n * &other.m_xz) / n;
218        // Update the design covariance matrix.
219        self.m_zz = (self.n * &self.m_zz + other.n * &other.m_zz) / n;
220        // Update the sample size.
221        self.n = n;
222    }
223}
224
225impl Add for GaussCPDS {
226    type Output = Self;
227
228    #[inline]
229    fn add(mut self, rhs: Self) -> Self::Output {
230        self += rhs;
231        self
232    }
233}
234
235impl Serialize for GaussCPDS {
236    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
237    where
238        S: Serializer,
239    {
240        // Allocate the map.
241        let mut map = serializer.serialize_map(Some(6))?;
242
243        // Convert the sample response mean to a flat format.
244        let sample_response_mean = self.mu_x.to_vec();
245        // Serialize sample response mean.
246        map.serialize_entry("sample_response_mean", &sample_response_mean)?;
247
248        // Convert the sample design mean to a flat format.
249        let sample_design_mean = self.mu_z.to_vec();
250        // Serialize sample design mean.
251        map.serialize_entry("sample_design_mean", &sample_design_mean)?;
252
253        // Convert the sample response covariance to a flat format.
254        let sample_response_covariance: Vec<_> =
255            self.m_xx.rows().into_iter().map(|x| x.to_vec()).collect();
256        // Serialize sample response covariance.
257        map.serialize_entry("sample_response_covariance", &sample_response_covariance)?;
258
259        // Convert the sample cross covariance to a flat format.
260        let sample_cross_covariance: Vec<_> =
261            self.m_xz.rows().into_iter().map(|x| x.to_vec()).collect();
262        // Serialize sample cross covariance.
263        map.serialize_entry("sample_cross_covariance", &sample_cross_covariance)?;
264
265        // Convert the sample design covariance to a flat format.
266        let sample_design_covariance: Vec<_> =
267            self.m_zz.rows().into_iter().map(|x| x.to_vec()).collect();
268        // Serialize sample design covariance.
269        map.serialize_entry("sample_design_covariance", &sample_design_covariance)?;
270
271        // Serialize sample size.
272        map.serialize_entry("sample_size", &self.n)?;
273
274        // End the map.
275        map.end()
276    }
277}
278
279impl<'de> Deserialize<'de> for GaussCPDS {
280    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
281    where
282        D: Deserializer<'de>,
283    {
284        #[derive(Deserialize)]
285        #[serde(field_identifier, rename_all = "snake_case")]
286        #[allow(clippy::enum_variant_names)]
287        enum Field {
288            SampleResponseMean,
289            SampleDesignMean,
290            SampleResponseCovariance,
291            SampleCrossCovariance,
292            SampleDesignCovariance,
293            SampleSize,
294        }
295
296        struct GaussCPDSVisitor;
297
298        impl<'de> Visitor<'de> for GaussCPDSVisitor {
299            type Value = GaussCPDS;
300
301            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
302                formatter.write_str("struct GaussCPDS")
303            }
304
305            fn visit_map<V>(self, mut map: V) -> Result<GaussCPDS, V::Error>
306            where
307                V: MapAccess<'de>,
308            {
309                use serde::de::Error as E;
310
311                // Allocate the fields.
312                let mut sample_response_mean = None;
313                let mut sample_design_mean = None;
314                let mut sample_response_covariance = None;
315                let mut sample_cross_covariance = None;
316                let mut sample_design_covariance = None;
317                let mut sample_size = None;
318
319                while let Some(key) = map.next_key()? {
320                    match key {
321                        Field::SampleResponseMean => {
322                            if sample_response_mean.is_some() {
323                                return Err(E::duplicate_field("sample_response_mean"));
324                            }
325                            sample_response_mean = Some(map.next_value()?);
326                        }
327                        Field::SampleDesignMean => {
328                            if sample_design_mean.is_some() {
329                                return Err(E::duplicate_field("sample_design_mean"));
330                            }
331                            sample_design_mean = Some(map.next_value()?);
332                        }
333                        Field::SampleResponseCovariance => {
334                            if sample_response_covariance.is_some() {
335                                return Err(E::duplicate_field("sample_response_covariance"));
336                            }
337                            sample_response_covariance = Some(map.next_value()?);
338                        }
339                        Field::SampleCrossCovariance => {
340                            if sample_cross_covariance.is_some() {
341                                return Err(E::duplicate_field("sample_cross_covariance"));
342                            }
343                            sample_cross_covariance = Some(map.next_value()?);
344                        }
345                        Field::SampleDesignCovariance => {
346                            if sample_design_covariance.is_some() {
347                                return Err(E::duplicate_field("sample_design_covariance"));
348                            }
349                            sample_design_covariance = Some(map.next_value()?);
350                        }
351                        Field::SampleSize => {
352                            if sample_size.is_some() {
353                                return Err(E::duplicate_field("sample_size"));
354                            }
355                            sample_size = Some(map.next_value()?);
356                        }
357                    }
358                }
359
360                // Extract the fields.
361                let sample_response_mean =
362                    sample_response_mean.ok_or_else(|| E::missing_field("sample_response_mean"))?;
363                let sample_design_mean =
364                    sample_design_mean.ok_or_else(|| E::missing_field("sample_design_mean"))?;
365                let sample_response_covariance = sample_response_covariance
366                    .ok_or_else(|| E::missing_field("sample_response_covariance"))?;
367                let sample_cross_covariance = sample_cross_covariance
368                    .ok_or_else(|| E::missing_field("sample_cross_covariance"))?;
369                let sample_design_covariance = sample_design_covariance
370                    .ok_or_else(|| E::missing_field("sample_design_covariance"))?;
371                let sample_size = sample_size.ok_or_else(|| E::missing_field("sample_size"))?;
372
373                // Convert sample response mean to array.
374                let sample_response_mean = Array1::from_vec(sample_response_mean);
375                // Convert sample design mean to array.
376                let sample_design_mean = Array1::from_vec(sample_design_mean);
377                // Convert sample response covariance to array.
378                let sample_response_covariance = {
379                    let values: Vec<Vec<f64>> = sample_response_covariance;
380                    let shape = (values.len(), values.first().map_or(0, |v| v.len()));
381                    Array::from_iter(values.into_iter().flatten())
382                        .into_shape_with_order(shape)
383                        .map_err(|_| E::custom("Invalid sample response covariance shape"))?
384                };
385                // Convert sample cross covariance to array.
386                let sample_cross_covariance = {
387                    let values: Vec<Vec<f64>> = sample_cross_covariance;
388                    let shape = (values.len(), values.first().map_or(0, |v| v.len()));
389                    Array::from_iter(values.into_iter().flatten())
390                        .into_shape_with_order(shape)
391                        .map_err(|_| E::custom("Invalid sample cross covariance shape"))?
392                };
393                // Convert sample design covariance to array.
394                let sample_design_covariance = {
395                    let values: Vec<Vec<f64>> = sample_design_covariance;
396                    let shape = (values.len(), values.first().map_or(0, |v| v.len()));
397                    Array::from_iter(values.into_iter().flatten())
398                        .into_shape_with_order(shape)
399                        .map_err(|_| E::custom("Invalid sample design covariance shape"))?
400                };
401
402                Ok(GaussCPDS::new(
403                    sample_response_mean,
404                    sample_design_mean,
405                    sample_response_covariance,
406                    sample_cross_covariance,
407                    sample_design_covariance,
408                    sample_size,
409                ))
410            }
411        }
412
413        const FIELDS: &[&str] = &[
414            "sample_response_mean",
415            "sample_design_mean",
416            "sample_response_covariance",
417            "sample_cross_covariance",
418            "sample_design_covariance",
419            "sample_size",
420        ];
421
422        deserializer.deserialize_struct("GaussCPDS", FIELDS, GaussCPDSVisitor)
423    }
424}