fetish_lib/
func_schmear.rs

1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use ndarray::*;
5
6use crate::schmear::*;
7use crate::linalg_utils::*;
8use crate::func_scatter_tensor::*;
9
10///Represents a probability distribution over linear mappings
11///whose covariance structure is separable between input and output.
12///This can be conceptualized as a specialization of [`Schmear`]
13///for a distribution over matrices where the covariance of the (vectorized)
14///random variable takes the separable form `kron(out_covariance, in_covariance)`
15///for `out_covariance` and `in_covariance` output coordinate and input coordinate
16///covariances, respectively, and `kron` referring to [`crate::linalg_utils::kron`] 
17///See also [`FuncScatterTensor`].
18pub struct FuncSchmear {
19    pub mean : Array2<f32>,
20    pub covariance : FuncScatterTensor
21}
22
23impl FuncSchmear {
24    ///Given a transformation matrix from the full, flattened dimension
25    ///of this [`FuncSchmear`] to a smaller dimension, performs a fused
26    ///[`FuncSchmear::flatten`] and [`Schmear::transform`] operation using
27    ///the specified transformation. This fused operation is written
28    ///to be much faster than manually performing the aforementioned operations.
29    pub fn compress(&self, mat : ArrayView2<f32>) -> Schmear {
30        let t = self.mean.shape()[0];
31        let s = self.mean.shape()[1];
32
33        let mean_flat = self.mean.clone().into_shape((t * s,)).unwrap();
34        let mean_transformed = mat.dot(&mean_flat);
35
36        let covariance_transformed  = self.covariance.compress(mat);
37
38        Schmear {
39            mean : mean_transformed,
40            covariance : covariance_transformed
41        }
42    }
43    
44    ///Converts this [`FuncSchmear`] over linear maps to its corresponding [`Schmear`]
45    ///over vectorized linear mappings.
46    pub fn flatten(&self) -> Schmear {
47        let t = self.mean.shape()[0];
48        let s = self.mean.shape()[1];
49
50        let mean = self.mean.clone().into_shape((t * s,)).unwrap();
51        let covariance = self.covariance.flatten();
52        Schmear {
53            mean,
54            covariance
55        }
56    }
57    ///Computes the output [`Schmear`] of this [`FuncSchmear`] applied
58    ///to a given argument [`Schmear`].
59    ///Given an input schmear, computes the output schmear which
60    ///would result from sampling `(function, input)` pairs,
61    ///computing `function(input)` for each of them, and then
62    ///obtaining the [`Schmear`] over those results.
63    pub fn apply(&self, x : &Schmear) -> Schmear {
64        let sigma_dot_u = frob_inner(self.covariance.in_scatter.view(), x.covariance.view());
65        let u_inner_product = x.mean.dot(&self.covariance.in_scatter).dot(&x.mean);
66        let v_scale = sigma_dot_u + u_inner_product;
67        let v_contrib = v_scale * &self.covariance.out_scatter;
68
69        if (v_scale < 0.0f32) {
70            println!("v scale became negative: {}", v_scale);
71            println!("components: {}, {}",  sigma_dot_u, u_inner_product);
72            if (u_inner_product < 0.0f32) {
73                println!("Non-psd in scatter: {}", &self.covariance.in_scatter);
74                println!("x mean: {}", &x.mean);
75            }
76        }
77
78        let m_sigma_m_t = self.mean.dot(&x.covariance).dot(&self.mean.t());
79
80        let result_covariance = v_contrib + &m_sigma_m_t;
81        let result_mean = self.mean.dot(&x.mean);
82        let result = Schmear {
83            mean : result_mean,
84            covariance : result_covariance
85        };
86        result
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::*;
    use ::rand_distr::StandardNormal;
    use ndarray_linalg::*;
    use crate::test_utils::*;
    use ndarray_rand::*;

    ///Monte-Carlo check of [`FuncSchmear::apply`].
    ///
    ///The test draws (function, argument) pairs and pushes each argument
    ///through its sampled function. It then compares two empirical
    ///quantities against the analytic output schmear:
    ///
    ///* the mean of the outputs, against the predicted mean;
    ///* the scatter of the outputs about the *predicted* mean, against the
    ///  predicted covariance.
    #[test]
    fn schmear_application_accurate() {
        let (t, s) = (3, 3);
        let num_samps = 1000;
        let arg_scale_mult = 0.01f32;

        let distribution = random_normal_inverse_wishart(s, t);
        let func_schmear = distribution.get_schmear();

        //Argument distribution: a random mean, with covariance L Lᵀ for a
        //small random factor L (so samples are mean + L * standard normal).
        let arg_mean = random_vector(s);
        let arg_sqrt = {
            let mut factor = random_matrix(s, s);
            factor *= arg_scale_mult;
            factor
        };
        let arg_schmear = Schmear {
            mean : arg_mean.clone(),
            covariance : arg_sqrt.dot(&arg_sqrt.t())
        };

        let predicted = func_schmear.apply(&arg_schmear);

        let mut empirical_mean = Array1::<f32>::zeros(t);
        let mut empirical_covariance = Array2::<f32>::zeros((t, t));

        let mut rng = rand::thread_rng();
        let weight = 1.0f32 / (num_samps as f32);

        for _ in 0..num_samps {
            let func_samp = distribution.sample(&mut rng);

            let noise = Array::random((s,), StandardNormal);
            let arg_samp = &arg_mean + &arg_sqrt.dot(&noise);

            let out_samp = func_samp.dot(&arg_samp);
            let deviation = &out_samp - &predicted.mean;

            empirical_mean.scaled_add(weight, &out_samp);
            empirical_covariance.scaled_add(weight, &outer(deviation.view(), deviation.view()));
        }

        //Tolerances are presumably loose to absorb Monte-Carlo noise at this
        //sample count.
        assert_equal_vectors_to_within(predicted.mean.view(), empirical_mean.view(), 1.0f32);
        assert_equal_matrices_to_within(predicted.covariance.view(), empirical_covariance.view(), 10.0f32);
    }
}
146