1use approx::{AbsDiffEq, RelativeEq};
2use ndarray::prelude::*;
3use ndarray_linalg::{CholeskyInto, Determinant, UPLO};
4use rand::Rng;
5use rand_distr::{Distribution, StandardNormal};
6use serde::{
7 Deserialize, Deserializer, Serialize, Serializer,
8 de::{MapAccess, Visitor},
9 ser::SerializeMap,
10};
11
12use crate::{
13 datasets::GaussSample,
14 impl_json_io,
15 models::{CPD, GaussCPDS, GaussPhi, Labelled, Phi},
16 types::{EPSILON, LN_2_PI, Labels, Set},
17 utils::PseudoInverse,
18};
19
20#[derive(Clone, Debug)]
22pub struct GaussCPDP {
23 a: Array2<f64>,
25 b: Array1<f64>,
27 s: Array2<f64>,
29}
30
31impl GaussCPDP {
32 pub fn new(a: Array2<f64>, b: Array1<f64>, s: Array2<f64>) -> Self {
52 assert_eq!(
54 a.nrows(),
55 b.len(),
56 "Coefficient matrix rows must match intercept vector size."
57 );
58 assert_eq!(
59 a.nrows(),
60 s.nrows(),
61 "Coefficient matrix rows must match covariance matrix size."
62 );
63 assert!(s.is_square(), "Covariance matrix must be square.");
64 assert!(
66 a.iter().all(|&x| x.is_finite()),
67 "Coefficient matrix must have finite values."
68 );
69 assert!(
70 b.iter().all(|&x| x.is_finite()),
71 "Intercept vector must have finite values."
72 );
73 assert!(
74 s.iter().all(|&x| x.is_finite()),
75 "Covariance matrix must have finite values."
76 );
77 assert_eq!(s, s.t(), "Covariance matrix must be symmetric.");
78
79 Self { a, b, s }
80 }
81
82 #[inline]
89 pub const fn coefficients(&self) -> &Array2<f64> {
90 &self.a
91 }
92
93 #[inline]
100 pub const fn intercept(&self) -> &Array1<f64> {
101 &self.b
102 }
103
104 #[inline]
111 pub const fn covariance(&self) -> &Array2<f64> {
112 &self.s
113 }
114}
115
116impl PartialEq for GaussCPDP {
117 fn eq(&self, other: &Self) -> bool {
118 self.a.eq(&other.a) && self.b.eq(&other.b) && self.s.eq(&other.s)
119 }
120}
121
122impl AbsDiffEq for GaussCPDP {
123 type Epsilon = f64;
124
125 fn default_epsilon() -> Self::Epsilon {
126 Self::Epsilon::default_epsilon()
127 }
128
129 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
130 self.a.abs_diff_eq(&other.a, epsilon)
131 && self.b.abs_diff_eq(&other.b, epsilon)
132 && self.s.abs_diff_eq(&other.s, epsilon)
133 }
134}
135
136impl RelativeEq for GaussCPDP {
137 fn default_max_relative() -> Self::Epsilon {
138 Self::Epsilon::default_max_relative()
139 }
140
141 fn relative_eq(
142 &self,
143 other: &Self,
144 epsilon: Self::Epsilon,
145 max_relative: Self::Epsilon,
146 ) -> bool {
147 self.a.relative_eq(&other.a, epsilon, max_relative)
148 && self.b.relative_eq(&other.b, epsilon, max_relative)
149 && self.s.relative_eq(&other.s, epsilon, max_relative)
150 }
151}
152
153impl Serialize for GaussCPDP {
154 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
155 where
156 S: Serializer,
157 {
158 let mut map = serializer.serialize_map(Some(3))?;
160
161 let coefficients: Vec<_> = self.a.rows().into_iter().map(|x| x.to_vec()).collect();
163 map.serialize_entry("coefficients", &coefficients)?;
165
166 let intercept = self.b.to_vec();
168 map.serialize_entry("intercept", &intercept)?;
170
171 let covariance: Vec<_> = self.s.rows().into_iter().map(|x| x.to_vec()).collect();
173 map.serialize_entry("covariance", &covariance)?;
175
176 map.end()
178 }
179}
180
181impl<'de> Deserialize<'de> for GaussCPDP {
182 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
183 where
184 D: Deserializer<'de>,
185 {
186 #[derive(Deserialize)]
187 #[serde(field_identifier, rename_all = "snake_case")]
188 enum Field {
189 Coefficients,
190 Intercept,
191 Covariance,
192 }
193
194 struct GaussCPDPVisitor;
195
196 impl<'de> Visitor<'de> for GaussCPDPVisitor {
197 type Value = GaussCPDP;
198
199 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
200 formatter.write_str("struct GaussCPDP")
201 }
202
203 fn visit_map<V>(self, mut map: V) -> Result<GaussCPDP, V::Error>
204 where
205 V: MapAccess<'de>,
206 {
207 use serde::de::Error as E;
208
209 let mut coefficients = None;
211 let mut intercept = None;
212 let mut covariance = None;
213
214 while let Some(key) = map.next_key()? {
216 match key {
217 Field::Coefficients => {
218 if coefficients.is_some() {
219 return Err(E::duplicate_field("coefficients"));
220 }
221 coefficients = Some(map.next_value()?);
222 }
223 Field::Intercept => {
224 if intercept.is_some() {
225 return Err(E::duplicate_field("intercept"));
226 }
227 intercept = Some(map.next_value()?);
228 }
229 Field::Covariance => {
230 if covariance.is_some() {
231 return Err(E::duplicate_field("covariance"));
232 }
233 covariance = Some(map.next_value()?);
234 }
235 }
236 }
237
238 let coefficients = coefficients.ok_or_else(|| E::missing_field("coefficients"))?;
240 let intercept = intercept.ok_or_else(|| E::missing_field("intercept"))?;
241 let covariance = covariance.ok_or_else(|| E::missing_field("covariance"))?;
242
243 let coefficients = {
245 let values: Vec<Vec<f64>> = coefficients;
246 let shape = (values.len(), values[0].len());
247 Array::from_iter(values.into_iter().flatten())
248 .into_shape_with_order(shape)
249 .map_err(|_| E::custom("Invalid coefficients shape"))?
250 };
251 let intercept = Array1::from_vec(intercept);
253 let covariance = {
255 let values: Vec<Vec<f64>> = covariance;
256 let shape = (values.len(), values[0].len());
257 Array::from_iter(values.into_iter().flatten())
258 .into_shape_with_order(shape)
259 .map_err(|_| E::custom("Invalid covariance shape"))?
260 };
261
262 Ok(GaussCPDP::new(coefficients, intercept, covariance))
263 }
264 }
265
266 const FIELDS: &[&str] = &["coefficients", "intercept", "covariance"];
267
268 deserializer.deserialize_struct("GaussCPDP", FIELDS, GaussCPDPVisitor)
269 }
270}
271
272#[derive(Clone, Debug)]
274pub struct GaussCPD {
275 labels: Labels,
277 conditioning_labels: Labels,
279 parameters: GaussCPDP,
281 sample_statistics: Option<GaussCPDS>,
283 sample_log_likelihood: Option<f64>,
285}
286
287impl GaussCPD {
288 pub fn new(
306 mut labels: Labels,
307 mut conditioning_labels: Labels,
308 mut parameters: GaussCPDP,
309 ) -> Self {
310 assert!(
312 labels.is_disjoint(&conditioning_labels),
313 "Labels and conditioning labels must be disjoint."
314 );
315 assert_eq!(
317 parameters.a.nrows(),
318 labels.len(),
319 "Coefficient matrix rows must match labels length."
320 );
321 assert_eq!(
322 parameters.a.ncols(),
323 conditioning_labels.len(),
324 "Coefficient matrix columns must match conditioning labels length."
325 );
326 assert_eq!(
327 parameters.b.len(),
328 labels.len(),
329 "Intercept vector size must match labels length."
330 );
331 assert_eq!(
332 parameters.s.nrows(),
333 labels.len(),
334 "Covariance matrix rows must match labels length."
335 );
336 assert_eq!(
337 parameters.s.ncols(),
338 labels.len(),
339 "Covariance matrix columns must match labels length."
340 );
341
342 if !labels.is_sorted() {
344 let mut indices: Vec<usize> = (0..labels.len()).collect();
346 indices.sort_by_key(|&i| &labels[i]);
348 labels.sort();
350 let mut new_a = parameters.a.clone();
352 let mut new_b = parameters.b.clone();
353 let mut new_s = parameters.s.clone();
354 for (i, &j) in indices.iter().enumerate() {
356 new_a.row_mut(i).assign(¶meters.a.row(j));
357 }
358 for (i, &j) in indices.iter().enumerate() {
360 new_b[i] = parameters.b[j];
361 }
362 for (i, &j) in indices.iter().enumerate() {
364 new_s.row_mut(i).assign(¶meters.s.row(j));
365 }
366 let _s = new_s.clone();
368 for (i, &j) in indices.iter().enumerate() {
370 new_s.column_mut(i).assign(&_s.column(j));
371 }
372 parameters.a = new_a;
374 parameters.b = new_b;
375 parameters.s = new_s;
376 }
377
378 if !conditioning_labels.is_sorted() {
380 let mut indices: Vec<usize> = (0..conditioning_labels.len()).collect();
382 indices.sort_by_key(|&i| &conditioning_labels[i]);
384 conditioning_labels.sort();
386 let mut new_a = parameters.a.clone();
388 for (i, &j) in indices.iter().enumerate() {
390 new_a.column_mut(i).assign(¶meters.a.column(j));
391 }
392 parameters.a = new_a;
394 }
395
396 Self {
397 labels,
398 conditioning_labels,
399 parameters,
400 sample_statistics: None,
401 sample_log_likelihood: None,
402 }
403 }
404
405 pub fn marginalize(&self, x: &Set<usize>, z: &Set<usize>) -> Self {
417 if x.is_empty() && z.is_empty() {
419 return self.clone();
420 }
421 let labels_x = self.labels();
423 let labels_z = self.conditioning_labels();
424 let not_x = (0..labels_x.len()).filter(|i| !x.contains(i)).collect();
426 let not_z = (0..labels_z.len()).filter(|i| !z.contains(i)).collect();
427 let phi = self.clone().into_phi();
429 let x = phi.indices_from(x, labels_x);
431 let z = phi.indices_from(z, labels_z);
432 let phi = phi.marginalize(&(&x | &z));
434 let not_x = phi.indices_from(¬_x, labels_x);
436 let not_z = phi.indices_from(¬_z, labels_z);
437 phi.into_cpd(¬_x, ¬_z)
439 }
440
441 pub fn with_optionals(
456 labels: Labels,
457 conditioning_labels: Labels,
458 parameters: GaussCPDP,
459 sample_statistics: Option<GaussCPDS>,
460 sample_log_likelihood: Option<f64>,
461 ) -> Self {
462 let mut cpd = Self::new(labels, conditioning_labels, parameters);
466
467 cpd.sample_statistics = sample_statistics;
471 cpd.sample_log_likelihood = sample_log_likelihood;
472
473 cpd
474 }
475
476 #[inline]
488 pub fn from_phi(phi: GaussPhi, x: &Set<usize>, z: &Set<usize>) -> Self {
489 phi.into_cpd(x, z)
490 }
491
492 #[inline]
503 pub fn into_phi(self) -> GaussPhi {
504 GaussPhi::from_cpd(self)
505 }
506}
507
508impl Labelled for GaussCPD {
509 #[inline]
510 fn labels(&self) -> &Labels {
511 &self.labels
512 }
513}
514
515impl PartialEq for GaussCPD {
516 fn eq(&self, other: &Self) -> bool {
517 self.labels.eq(&other.labels)
518 && self.conditioning_labels.eq(&other.conditioning_labels)
519 && self.parameters.eq(&other.parameters)
520 }
521}
522
523impl AbsDiffEq for GaussCPD {
524 type Epsilon = f64;
525
526 fn default_epsilon() -> Self::Epsilon {
527 Self::Epsilon::default_epsilon()
528 }
529
530 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
531 self.labels.eq(&other.labels)
532 && self.conditioning_labels.eq(&other.conditioning_labels)
533 && self.parameters.abs_diff_eq(&other.parameters, epsilon)
534 }
535}
536
537impl RelativeEq for GaussCPD {
538 fn default_max_relative() -> Self::Epsilon {
539 Self::Epsilon::default_max_relative()
540 }
541
542 fn relative_eq(
543 &self,
544 other: &Self,
545 epsilon: Self::Epsilon,
546 max_relative: Self::Epsilon,
547 ) -> bool {
548 self.labels.eq(&other.labels)
549 && self.conditioning_labels.eq(&other.conditioning_labels)
550 && self
551 .parameters
552 .relative_eq(&other.parameters, epsilon, max_relative)
553 }
554}
555
556impl CPD for GaussCPD {
557 type Support = GaussSample;
558 type Parameters = GaussCPDP;
559 type Statistics = GaussCPDS;
560
561 #[inline]
562 fn conditioning_labels(&self) -> &Labels {
563 &self.conditioning_labels
564 }
565
566 #[inline]
567 fn parameters(&self) -> &Self::Parameters {
568 &self.parameters
569 }
570
571 #[inline]
572 fn parameters_size(&self) -> usize {
573 let s = {
574 let s = self.parameters.s.nrows();
576 s * (s + 1) / 2
577 };
578
579 self.parameters.a.len() + self.parameters.b.len() + s
580 }
581
582 #[inline]
583 fn sample_statistics(&self) -> Option<&Self::Statistics> {
584 self.sample_statistics.as_ref()
585 }
586
587 #[inline]
588 fn sample_log_likelihood(&self) -> Option<f64> {
589 self.sample_log_likelihood
590 }
591
592 fn pf(&self, x: &Self::Support, z: &Self::Support) -> f64 {
593 let n = self.labels.len();
595 let m = self.conditioning_labels.len();
597
598 assert_eq!(
600 x.len(),
601 n,
602 "Vector X must match number of variables: \n\
603 \t expected: |X| == {} , \n\
604 \t found: |X| == {} .",
605 n,
606 x.len(),
607 );
608 assert_eq!(
610 z.len(),
611 m,
612 "Vector Z must match number of conditioning variables: \n\
613 \t expected: |Z| == {} , \n\
614 \t found: |Z| == {} .",
615 m,
616 z.len(),
617 );
618
619 let (a, b, s) = (
621 self.parameters.coefficients(),
622 self.parameters.intercept(),
623 self.parameters.covariance(),
624 );
625
626 if n == 0 {
628 return 1.;
629 }
630
631 if n == 1 {
633 let mu = match m {
635 0 => b[0], 1 => f64::mul_add(a[[0, 0]], z[0], b[0]), _ => (a.dot(z) + b)[0], };
642 let x_mu = x[0] - mu;
644 let k = s[[0, 0]];
646 let ln_pf = -0.5 * (LN_2_PI + f64::ln(k) + f64::powi(x_mu, 2) / k);
648 return f64::exp(ln_pf);
650 }
651
652 let mu = a.dot(z) + b;
656 let x_mu = x - mu;
658 let k = s.pinv();
660 let n_ln_2_pi = s.nrows() as f64 * LN_2_PI;
662 let (_, ln_det) = s.sln_det().expect("Failed to compute the determinant.");
663 let ln_pf = -0.5 * (n_ln_2_pi + ln_det + x_mu.dot(&k).dot(&x_mu));
664 f64::exp(ln_pf)
666 }
667
668 fn sample<R: Rng>(&self, rng: &mut R, z: &Self::Support) -> Self::Support {
669 let n = self.labels.len();
671 let m = self.conditioning_labels.len();
673
674 assert_eq!(
676 z.len(),
677 m,
678 "Vector Z must match number of conditioning variables: \n\
679 \t expected: |Z| == {} , \n\
680 \t found: |Z| == {} .",
681 m,
682 z.len(),
683 );
684
685 let (a, b, s) = (
687 self.parameters.coefficients(),
688 self.parameters.intercept(),
689 self.parameters.covariance(),
690 );
691
692 if n == 0 {
694 return array![];
695 }
696
697 if n == 1 {
699 let mu = match m {
701 0 => b[0], 1 => f64::mul_add(a[[0, 0]], z[0], b[0]), _ => (a.dot(z) + b)[0], };
708 let e: f64 = StandardNormal.sample(rng);
710 let x = f64::mul_add(s[[0, 0]].sqrt(), e, mu);
712 return array![x];
714 }
715
716 let mu = a.dot(z) + b;
720 let l = (s + EPSILON * Array::eye(s.nrows()))
722 .cholesky_into(UPLO::Lower)
723 .expect("Failed to compute Cholesky decomposition.");
724 let e = StandardNormal
726 .sample_iter(rng)
727 .take(s.nrows())
728 .collect::<Array1<_>>();
729 l.dot(&e) + mu
731 }
732}
733
734impl Serialize for GaussCPD {
735 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
736 where
737 S: Serializer,
738 {
739 let mut size = 4;
741 size += self.sample_statistics.is_some() as usize;
743 size += self.sample_log_likelihood.is_some() as usize;
744 let mut map = serializer.serialize_map(Some(size))?;
746
747 map.serialize_entry("labels", &self.labels)?;
749 map.serialize_entry("conditioning_labels", &self.conditioning_labels)?;
751 map.serialize_entry("parameters", &self.parameters)?;
753
754 if let Some(sample_statistics) = &self.sample_statistics {
756 map.serialize_entry("sample_statistics", sample_statistics)?;
757 }
758
759 if let Some(sample_log_likelihood) = &self.sample_log_likelihood {
761 map.serialize_entry("sample_log_likelihood", sample_log_likelihood)?;
762 }
763
764 map.serialize_entry("type", "gausscpd")?;
766
767 map.end()
769 }
770}
771
772impl<'de> Deserialize<'de> for GaussCPD {
773 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
774 where
775 D: Deserializer<'de>,
776 {
777 #[derive(Deserialize)]
778 #[serde(field_identifier, rename_all = "snake_case")]
779 enum Field {
780 Labels,
781 ConditioningLabels,
782 Parameters,
783 SampleStatistics,
784 SampleLogLikelihood,
785 Type,
786 }
787
788 struct GaussCPDVisitor;
789
790 impl<'de> Visitor<'de> for GaussCPDVisitor {
791 type Value = GaussCPD;
792
793 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
794 formatter.write_str("struct GaussCPD")
795 }
796
797 fn visit_map<V>(self, mut map: V) -> Result<GaussCPD, V::Error>
798 where
799 V: MapAccess<'de>,
800 {
801 use serde::de::Error as E;
802
803 let mut labels = None;
805 let mut conditioning_labels = None;
806 let mut parameters = None;
807 let mut sample_statistics = None;
808 let mut sample_log_likelihood = None;
809 let mut type_ = None;
810
811 while let Some(key) = map.next_key()? {
813 match key {
814 Field::Labels => {
815 if labels.is_some() {
816 return Err(E::duplicate_field("labels"));
817 }
818 labels = Some(map.next_value()?);
819 }
820 Field::ConditioningLabels => {
821 if conditioning_labels.is_some() {
822 return Err(E::duplicate_field("conditioning_labels"));
823 }
824 conditioning_labels = Some(map.next_value()?);
825 }
826 Field::Parameters => {
827 if parameters.is_some() {
828 return Err(E::duplicate_field("parameters"));
829 }
830 parameters = Some(map.next_value()?);
831 }
832 Field::SampleStatistics => {
833 if sample_statistics.is_some() {
834 return Err(E::duplicate_field("sample_statistics"));
835 }
836 sample_statistics = Some(map.next_value()?);
837 }
838 Field::SampleLogLikelihood => {
839 if sample_log_likelihood.is_some() {
840 return Err(E::duplicate_field("sample_log_likelihood"));
841 }
842 sample_log_likelihood = Some(map.next_value()?);
843 }
844 Field::Type => {
845 if type_.is_some() {
846 return Err(E::duplicate_field("type"));
847 }
848 type_ = Some(map.next_value()?);
849 }
850 }
851 }
852
853 let labels = labels.ok_or_else(|| E::missing_field("labels"))?;
855 let conditioning_labels =
856 conditioning_labels.ok_or_else(|| E::missing_field("conditioning_labels"))?;
857 let parameters = parameters.ok_or_else(|| E::missing_field("parameters"))?;
858
859 let type_: String = type_.ok_or_else(|| E::missing_field("type"))?;
861 assert_eq!(type_, "gausscpd", "Invalid type for GaussCPD.");
862
863 Ok(GaussCPD::with_optionals(
864 labels,
865 conditioning_labels,
866 parameters,
867 sample_statistics,
868 sample_log_likelihood,
869 ))
870 }
871 }
872
873 const FIELDS: &[&str] = &[
874 "labels",
875 "conditioning_labels",
876 "parameters",
877 "sample_statistics",
878 "sample_log_likelihood",
879 "type",
880 ];
881
882 deserializer.deserialize_struct("GaussCPD", FIELDS, GaussCPDVisitor)
883 }
884}
885
// Presumably derives the crate's JSON read/write helpers for `GaussCPD` on top
// of the serde impls above — confirm against the `impl_json_io!` definition.
impl_json_io!(GaussCPD);