fetish_lib/
func_schmear.rs1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use ndarray::*;
5
6use crate::schmear::*;
7use crate::linalg_utils::*;
8use crate::func_scatter_tensor::*;
9
10pub struct FuncSchmear {
19 pub mean : Array2<f32>,
20 pub covariance : FuncScatterTensor
21}
22
23impl FuncSchmear {
24 pub fn compress(&self, mat : ArrayView2<f32>) -> Schmear {
30 let t = self.mean.shape()[0];
31 let s = self.mean.shape()[1];
32
33 let mean_flat = self.mean.clone().into_shape((t * s,)).unwrap();
34 let mean_transformed = mat.dot(&mean_flat);
35
36 let covariance_transformed = self.covariance.compress(mat);
37
38 Schmear {
39 mean : mean_transformed,
40 covariance : covariance_transformed
41 }
42 }
43
44 pub fn flatten(&self) -> Schmear {
47 let t = self.mean.shape()[0];
48 let s = self.mean.shape()[1];
49
50 let mean = self.mean.clone().into_shape((t * s,)).unwrap();
51 let covariance = self.covariance.flatten();
52 Schmear {
53 mean,
54 covariance
55 }
56 }
57 pub fn apply(&self, x : &Schmear) -> Schmear {
64 let sigma_dot_u = frob_inner(self.covariance.in_scatter.view(), x.covariance.view());
65 let u_inner_product = x.mean.dot(&self.covariance.in_scatter).dot(&x.mean);
66 let v_scale = sigma_dot_u + u_inner_product;
67 let v_contrib = v_scale * &self.covariance.out_scatter;
68
69 if (v_scale < 0.0f32) {
70 println!("v scale became negative: {}", v_scale);
71 println!("components: {}, {}", sigma_dot_u, u_inner_product);
72 if (u_inner_product < 0.0f32) {
73 println!("Non-psd in scatter: {}", &self.covariance.in_scatter);
74 println!("x mean: {}", &x.mean);
75 }
76 }
77
78 let m_sigma_m_t = self.mean.dot(&x.covariance).dot(&self.mean.t());
79
80 let result_covariance = v_contrib + &m_sigma_m_t;
81 let result_mean = self.mean.dot(&x.mean);
82 let result = Schmear {
83 mean : result_mean,
84 covariance : result_covariance
85 };
86 result
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::*;
    use ::rand_distr::StandardNormal;
    use ndarray_linalg::*;
    use crate::test_utils::*;
    use ndarray_rand::*;

    /// Monte-Carlo check of `FuncSchmear::apply`.
    ///
    /// Functions are sampled from a normal-inverse-Wishart distribution and arguments from
    /// the input schmear. Each argument is pushed through its sampled function, and the
    /// empirical output moments are compared against the analytic ones.
    #[test]
    fn schmear_application_accurate() {
        let out_dim = 3;
        let in_dim = 3;
        let sample_count = 1000;
        // Shrinks the argument's spread.
        let arg_spread = 0.01f32;

        let niw = random_normal_inverse_wishart(in_dim, out_dim);

        let func_schmear = niw.get_schmear();

        let arg_mean = random_vector(in_dim);
        let mut arg_sqrt = random_matrix(in_dim, in_dim);
        arg_sqrt *= arg_spread;
        // L Lᵀ is positive semi-definite by construction.
        let arg_cov = arg_sqrt.dot(&arg_sqrt.t());
        let arg_schmear = Schmear {
            mean: arg_mean.clone(),
            covariance: arg_cov,
        };

        let predicted = func_schmear.apply(&arg_schmear);

        let mut empirical_mean = Array1::<f32>::zeros(out_dim);
        let mut empirical_cov = Array2::<f32>::zeros((out_dim, out_dim));

        let mut rng = rand::thread_rng();

        let weight = 1.0f32 / (sample_count as f32);

        for _ in 0..sample_count {
            let sampled_func = niw.sample(&mut rng);

            let noise = Array::random((in_dim,), StandardNormal);
            let arg_sample = &arg_mean + &arg_sqrt.dot(&noise);

            let out_sample = sampled_func.dot(&arg_sample);

            // Deviations are measured about the analytic mean, not the running empirical one.
            let deviation = &out_sample - &predicted.mean;

            empirical_mean.scaled_add(weight, &out_sample);
            empirical_cov.scaled_add(weight, &outer(deviation.view(), deviation.view()));
        }

        assert_equal_vectors_to_within(predicted.mean.view(), empirical_mean.view(), 1.0f32);
        assert_equal_matrices_to_within(predicted.covariance.view(), empirical_cov.view(), 10.0f32);
    }
}
146