causal_hub/models/bayesian_network/gaussian/
sufficient_statistics.rs1use std::ops::{Add, AddAssign};
2
3use ndarray::prelude::*;
4use serde::{
5 Deserialize, Deserializer, Serialize, Serializer,
6 de::{MapAccess, Visitor},
7 ser::SerializeMap,
8};
9
/// Sufficient statistics relating a response vector `X` to a design vector `Z`
/// (the name presumably stands for "Gaussian CPD sufficient statistics").
///
/// The means are stored as means. The three matrices appear to be accumulated,
/// uncentered second moments (presumably `sum(x * x^T)` over the samples — TODO
/// confirm against the estimator that builds them). This is why the
/// `sample_*_covariance` methods subtract `n * mu * mu^T` from them instead of
/// returning them directly.
#[derive(Clone, Debug)]
pub struct GaussCPDS {
    /// Sample mean of the response variables, of length |X|.
    mu_x: Array1<f64>,
    /// Sample mean of the design variables, of length |Z|.
    mu_z: Array1<f64>,
    /// Stored response matrix, |X| x |X| (square, as asserted by `new`).
    m_xx: Array2<f64>,
    /// Stored response/design cross matrix, |X| x |Z|.
    m_xz: Array2<f64>,
    /// Stored design matrix, |Z| x |Z| (square, as asserted by `new`).
    m_zz: Array2<f64>,
    /// Sample size: finite and non-negative (enforced by `new`). It is an `f64`,
    /// so non-integer values are not rejected.
    n: f64,
}
26
27impl GaussCPDS {
28 #[inline]
55 pub fn new(
56 mu_x: Array1<f64>,
57 mu_z: Array1<f64>,
58 m_xx: Array2<f64>,
59 m_xz: Array2<f64>,
60 m_zz: Array2<f64>,
61 n: f64,
62 ) -> Self {
63 assert_eq!(
65 mu_x.len(),
66 m_xx.nrows(),
67 "Response mean vector length must match response covariance matrix size."
68 );
69 assert_eq!(
70 mu_z.len(),
71 m_zz.nrows(),
72 "Design mean vector length must match design covariance matrix size."
73 );
74 assert!(
75 m_xx.is_square(),
76 "Response covariance matrix must be square."
77 );
78 assert_eq!(
79 m_xz.nrows(),
80 m_xx.nrows(),
81 "Cross-covariance matrix must have the same \n\
82 number of rows as the response covariance matrix."
83 );
84 assert_eq!(
85 m_xz.ncols(),
86 m_zz.nrows(),
87 "Cross-covariance matrix must have the same \n\
88 number of columns as the design covariance matrix."
89 );
90 assert!(m_zz.is_square(), "Design covariance matrix must be square.");
91 assert!(
93 mu_x.iter().all(|&x| x.is_finite()),
94 "Response mean vector must have finite values."
95 );
96 assert!(
97 mu_z.iter().all(|&x| x.is_finite()),
98 "Design mean vector must have finite values."
99 );
100 assert!(
101 m_xx.iter().all(|&x| x.is_finite()),
102 "Response covariance matrix must have finite values."
103 );
104 assert!(
105 m_xz.iter().all(|&x| x.is_finite()),
106 "Cross-covariance matrix must have finite values."
107 );
108 assert!(
109 m_zz.iter().all(|&x| x.is_finite()),
110 "Design covariance matrix must have finite values."
111 );
112 assert!(
113 n.is_finite() && n >= 0.0,
114 "Sample size must be non-negative."
115 );
116
117 Self {
118 mu_x,
119 mu_z,
120 m_xx,
121 m_xz,
122 m_zz,
123 n,
124 }
125 }
126
127 #[inline]
134 pub fn sample_response_mean(&self) -> &Array1<f64> {
135 &self.mu_x
136 }
137
138 #[inline]
145 pub fn sample_design_mean(&self) -> &Array1<f64> {
146 &self.mu_z
147 }
148
149 #[inline]
156 pub fn sample_response_covariance(&self) -> Array2<f64> {
157 let col_mu_x = self.mu_x.view().insert_axis(Axis(1));
159 let row_mu_x = self.mu_x.view().insert_axis(Axis(0));
160 &self.m_xx - self.n * &col_mu_x.dot(&row_mu_x)
162 }
163
164 #[inline]
171 pub fn sample_cross_covariance(&self) -> Array2<f64> {
172 let col_mu_x = self.mu_x.view().insert_axis(Axis(1));
174 let row_mu_z = self.mu_z.view().insert_axis(Axis(0));
175 &self.m_xz - self.n * &col_mu_x.dot(&row_mu_z)
177 }
178
179 #[inline]
186 pub fn sample_design_covariance(&self) -> Array2<f64> {
187 let col_mu_z = self.mu_z.view().insert_axis(Axis(1));
189 let row_mu_z = self.mu_z.view().insert_axis(Axis(0));
190 &self.m_zz - self.n * &col_mu_z.dot(&row_mu_z)
192 }
193
194 #[inline]
201 pub fn sample_size(&self) -> f64 {
202 self.n
203 }
204}
205
206impl AddAssign for GaussCPDS {
207 fn add_assign(&mut self, other: Self) {
208 let n = self.n + other.n;
210 self.mu_x = (self.n * &self.mu_x + other.n * &other.mu_x) / n;
212 self.mu_z = (self.n * &self.mu_z + other.n * &other.mu_z) / n;
214 self.m_xx = (self.n * &self.m_xx + other.n * &other.m_xx) / n;
216 self.m_xz = (self.n * &self.m_xz + other.n * &other.m_xz) / n;
218 self.m_zz = (self.n * &self.m_zz + other.n * &other.m_zz) / n;
220 self.n = n;
222 }
223}
224
225impl Add for GaussCPDS {
226 type Output = Self;
227
228 #[inline]
229 fn add(mut self, rhs: Self) -> Self::Output {
230 self += rhs;
231 self
232 }
233}
234
235impl Serialize for GaussCPDS {
236 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
237 where
238 S: Serializer,
239 {
240 let mut map = serializer.serialize_map(Some(6))?;
242
243 let sample_response_mean = self.mu_x.to_vec();
245 map.serialize_entry("sample_response_mean", &sample_response_mean)?;
247
248 let sample_design_mean = self.mu_z.to_vec();
250 map.serialize_entry("sample_design_mean", &sample_design_mean)?;
252
253 let sample_response_covariance: Vec<_> =
255 self.m_xx.rows().into_iter().map(|x| x.to_vec()).collect();
256 map.serialize_entry("sample_response_covariance", &sample_response_covariance)?;
258
259 let sample_cross_covariance: Vec<_> =
261 self.m_xz.rows().into_iter().map(|x| x.to_vec()).collect();
262 map.serialize_entry("sample_cross_covariance", &sample_cross_covariance)?;
264
265 let sample_design_covariance: Vec<_> =
267 self.m_zz.rows().into_iter().map(|x| x.to_vec()).collect();
268 map.serialize_entry("sample_design_covariance", &sample_design_covariance)?;
270
271 map.serialize_entry("sample_size", &self.n)?;
273
274 map.end()
276 }
277}
278
279impl<'de> Deserialize<'de> for GaussCPDS {
280 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
281 where
282 D: Deserializer<'de>,
283 {
284 #[derive(Deserialize)]
285 #[serde(field_identifier, rename_all = "snake_case")]
286 #[allow(clippy::enum_variant_names)]
287 enum Field {
288 SampleResponseMean,
289 SampleDesignMean,
290 SampleResponseCovariance,
291 SampleCrossCovariance,
292 SampleDesignCovariance,
293 SampleSize,
294 }
295
296 struct GaussCPDSVisitor;
297
298 impl<'de> Visitor<'de> for GaussCPDSVisitor {
299 type Value = GaussCPDS;
300
301 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
302 formatter.write_str("struct GaussCPDS")
303 }
304
305 fn visit_map<V>(self, mut map: V) -> Result<GaussCPDS, V::Error>
306 where
307 V: MapAccess<'de>,
308 {
309 use serde::de::Error as E;
310
311 let mut sample_response_mean = None;
313 let mut sample_design_mean = None;
314 let mut sample_response_covariance = None;
315 let mut sample_cross_covariance = None;
316 let mut sample_design_covariance = None;
317 let mut sample_size = None;
318
319 while let Some(key) = map.next_key()? {
320 match key {
321 Field::SampleResponseMean => {
322 if sample_response_mean.is_some() {
323 return Err(E::duplicate_field("sample_response_mean"));
324 }
325 sample_response_mean = Some(map.next_value()?);
326 }
327 Field::SampleDesignMean => {
328 if sample_design_mean.is_some() {
329 return Err(E::duplicate_field("sample_design_mean"));
330 }
331 sample_design_mean = Some(map.next_value()?);
332 }
333 Field::SampleResponseCovariance => {
334 if sample_response_covariance.is_some() {
335 return Err(E::duplicate_field("sample_response_covariance"));
336 }
337 sample_response_covariance = Some(map.next_value()?);
338 }
339 Field::SampleCrossCovariance => {
340 if sample_cross_covariance.is_some() {
341 return Err(E::duplicate_field("sample_cross_covariance"));
342 }
343 sample_cross_covariance = Some(map.next_value()?);
344 }
345 Field::SampleDesignCovariance => {
346 if sample_design_covariance.is_some() {
347 return Err(E::duplicate_field("sample_design_covariance"));
348 }
349 sample_design_covariance = Some(map.next_value()?);
350 }
351 Field::SampleSize => {
352 if sample_size.is_some() {
353 return Err(E::duplicate_field("sample_size"));
354 }
355 sample_size = Some(map.next_value()?);
356 }
357 }
358 }
359
360 let sample_response_mean =
362 sample_response_mean.ok_or_else(|| E::missing_field("sample_response_mean"))?;
363 let sample_design_mean =
364 sample_design_mean.ok_or_else(|| E::missing_field("sample_design_mean"))?;
365 let sample_response_covariance = sample_response_covariance
366 .ok_or_else(|| E::missing_field("sample_response_covariance"))?;
367 let sample_cross_covariance = sample_cross_covariance
368 .ok_or_else(|| E::missing_field("sample_cross_covariance"))?;
369 let sample_design_covariance = sample_design_covariance
370 .ok_or_else(|| E::missing_field("sample_design_covariance"))?;
371 let sample_size = sample_size.ok_or_else(|| E::missing_field("sample_size"))?;
372
373 let sample_response_mean = Array1::from_vec(sample_response_mean);
375 let sample_design_mean = Array1::from_vec(sample_design_mean);
377 let sample_response_covariance = {
379 let values: Vec<Vec<f64>> = sample_response_covariance;
380 let shape = (values.len(), values.first().map_or(0, |v| v.len()));
381 Array::from_iter(values.into_iter().flatten())
382 .into_shape_with_order(shape)
383 .map_err(|_| E::custom("Invalid sample response covariance shape"))?
384 };
385 let sample_cross_covariance = {
387 let values: Vec<Vec<f64>> = sample_cross_covariance;
388 let shape = (values.len(), values.first().map_or(0, |v| v.len()));
389 Array::from_iter(values.into_iter().flatten())
390 .into_shape_with_order(shape)
391 .map_err(|_| E::custom("Invalid sample cross covariance shape"))?
392 };
393 let sample_design_covariance = {
395 let values: Vec<Vec<f64>> = sample_design_covariance;
396 let shape = (values.len(), values.first().map_or(0, |v| v.len()));
397 Array::from_iter(values.into_iter().flatten())
398 .into_shape_with_order(shape)
399 .map_err(|_| E::custom("Invalid sample design covariance shape"))?
400 };
401
402 Ok(GaussCPDS::new(
403 sample_response_mean,
404 sample_design_mean,
405 sample_response_covariance,
406 sample_cross_covariance,
407 sample_design_covariance,
408 sample_size,
409 ))
410 }
411 }
412
413 const FIELDS: &[&str] = &[
414 "sample_response_mean",
415 "sample_design_mean",
416 "sample_response_covariance",
417 "sample_cross_covariance",
418 "sample_design_covariance",
419 "sample_size",
420 ];
421
422 deserializer.deserialize_struct("GaussCPDS", FIELDS, GaussCPDSVisitor)
423 }
424}