fetish_lib/normal_inverse_wishart.rs

extern crate ndarray;
extern crate ndarray_linalg;

use std::ops;
use ndarray::*;
use crate::multiple::*;
use crate::prior_specification::*;
use crate::input_to_schmeared_output::*;
use crate::array_utils::*;
use crate::params::*;
use crate::data_points::*;
use crate::pseudoinverse::*;
use crate::func_scatter_tensor::*;
use crate::func_inverse_schmear::*;
use crate::func_schmear::*;
use crate::schmear::*;
use crate::data_point::*;
use crate::sherman_morrison::*;
use crate::linalg_utils::*;
use crate::normal_inverse_wishart_sampler::*;
use crate::function_space_info::*;

use rand::prelude::*;

///Matrix-normal-inverse-Wishart distribution representation
///for Bayesian inference.
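///
///A minimal construction sketch (hypothetical dimensions; the shape
///conventions follow [`NormalInverseWishart::new`] below):
///```ignore
///let mean : Array2<f32> = Array::zeros((4, 5));  //t x s mean linear map
///let precision = Array::eye(5);                  //s x s input precision
///let big_v = Array::eye(4);                      //t x t output error covariance
///let little_v = 6.0f32;                          //pseudo-observations; > t + 1
///let dist = NormalInverseWishart::new(mean, precision, big_v, little_v);
///```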
#[derive(Clone)]
pub struct NormalInverseWishart {
    pub mean : Array2<f32>,
    ///This is always maintained to equal `mean.dot(&precision)`
    pub precision_u : Array2<f32>,
    pub precision : Array2<f32>,
    ///This is always maintained to equal `pseudoinverse_h(&precision)`
    pub sigma : Array2<f32>,
    pub big_v : Array2<f32>,
    pub little_v : f32,
    ///The output dimensionality
    pub t : usize,
    ///The input dimensionality
    pub s : usize
}

impl NormalInverseWishart {
    ///Manually re-computes `self.sigma` and `self.precision_u` to
    ///align with their definitions based on the other fields in `self`.
    ///This method takes cubic time, so it's best reserved for when you
    ///observe these fields drifting from their definitions due to
    ///accumulated numerical error.
    pub fn recompute_derived(&mut self) {
        self.sigma = pseudoinverse_h(&self.precision);
        self.precision_u = self.mean.dot(&self.precision);
    }

    ///Gets the total dimensionality of `self.mean`.
    pub fn get_total_dims(&self) -> usize {
        self.s * self.t
    }

    ///Draws a sample from the represented MNIW distribution.
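    ///
    ///A usage sketch (assumes a `NormalInverseWishart` value `dist`):
    ///```ignore
    ///let mut rng = rand::thread_rng();
    ///let linear_map = dist.sample(&mut rng);  //yields a t x s matrix
    ///```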
    pub fn sample(&self, rng : &mut ThreadRng) -> Array2<f32> {
        let sampler = NormalInverseWishartSampler::new(self);
        sampler.sample(rng)
    }

    ///The same as [`Self::sample`], but the result is flattened.
    pub fn sample_as_vec(&self, rng : &mut ThreadRng) -> Array1<f32> {
        let thick = self.sample(rng);
        let total_dims = self.get_total_dims();
        thick.into_shape((total_dims,)).unwrap()
    }

    ///Returns the mean of the represented MNIW distribution, but flattened to be a vector.
    pub fn get_mean_as_vec(&self) -> ArrayView1<f32> {
        flatten_matrix(self.mean.view())
    }
    ///Returns the mean of the represented MNIW distribution as a linear map.
    pub fn get_mean(&self) -> Array2<f32> {
        self.mean.clone()
    }
    ///Gets the [`FuncSchmear`] over linear mappings given by this MNIW distribution.
    pub fn get_schmear(&self) -> FuncSchmear {
        FuncSchmear {
            mean : self.mean.clone(),
            covariance : self.get_covariance()
        }
    }
    ///Gets the [`FuncInverseSchmear`] over linear mappings given by this MNIW distribution.
    pub fn get_inverse_schmear(&self) -> FuncInverseSchmear {
        FuncInverseSchmear {
            mean : self.mean.clone(),
            precision : self.get_precision()
        }
    }
    ///Gets the precision (inverse covariance) [`FuncScatterTensor`] of this MNIW distribution.
    pub fn get_precision(&self) -> FuncScatterTensor {
        let scale = self.little_v - (self.t as f32) - 1.0f32;
        let mut out_precision = pseudoinverse_h(&self.big_v);
        out_precision *= scale;
        FuncScatterTensor {
            in_scatter : self.precision.clone(),
            out_scatter : out_precision
        }
    }
    ///Gets the covariance [`FuncScatterTensor`] of this MNIW distribution.
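    ///
    ///The output scatter here is `big_v / (little_v - t - 1)`: the mean of an
    ///inverse-Wishart distribution with scale matrix `big_v` and `little_v`
    ///degrees of freedom, which exists only when `little_v > t + 1`.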
    pub fn get_covariance(&self) -> FuncScatterTensor {
        let scale = 1.0f32 / (self.little_v - (self.t as f32) - 1.0f32);
        let big_v_scaled = scale * &self.big_v;
        FuncScatterTensor {
            in_scatter : self.sigma.clone(),
            out_scatter : big_v_scaled
        }
    }
}

impl NormalInverseWishart {
    ///Constructs a [`NormalInverseWishart`] distribution from the given [`PriorSpecification`],
    ///the given feature dimensions, and the given output dimensions.
    pub fn from_in_out_dims(prior_specification : &dyn PriorSpecification,
                            feat_dims : usize, out_dims : usize) -> NormalInverseWishart {
        let mean : Array2<f32> = Array::zeros((out_dims, feat_dims));

        let in_precision_multiplier = prior_specification.get_in_precision_multiplier(feat_dims);
        let out_covariance_multiplier = prior_specification.get_out_covariance_multiplier(out_dims);

        let in_precision : Array2<f32> = in_precision_multiplier * Array::eye(feat_dims);
        let out_covariance : Array2<f32> = out_covariance_multiplier * Array::eye(out_dims);

        let little_v = prior_specification.get_out_pseudo_observations(out_dims);

        NormalInverseWishart::new(mean, in_precision, out_covariance, little_v)
    }
    ///Constructs a [`NormalInverseWishart`] distribution from the given [`PriorSpecification`]
    ///and the given [`FunctionSpaceInfo`].
    pub fn from_space_info(prior_specification : &dyn PriorSpecification,
                           func_space_info : &FunctionSpaceInfo) -> NormalInverseWishart {
        let feat_dims = func_space_info.get_feature_dimensions();
        let out_dims = func_space_info.get_output_dimensions();

        NormalInverseWishart::from_in_out_dims(prior_specification, feat_dims, out_dims)
    }
    ///Constructs a [`NormalInverseWishart`] distribution with the given mean, input precision,
    ///total output error covariance, and (pseudo-)observation count.
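    ///
    ///Expected shapes: `mean` is `t x s`, `precision` is `s x s`, and `big_v` is `t x t`.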
    pub fn new(mean : Array2<f32>, precision : Array2<f32>, big_v : Array2<f32>, little_v : f32) -> NormalInverseWishart {
        let precision_u : Array2<f32> = mean.dot(&precision);
        let sigma : Array2<f32> = pseudoinverse_h(&precision);
        let t = mean.shape()[0];
        let s = mean.shape()[1];

        NormalInverseWishart {
            mean,
            precision_u,
            precision,
            sigma,
            big_v,
            little_v,
            t,
            s
        }
    }
}

impl ops::BitXorAssign<()> for NormalInverseWishart {
    ///Inverts this [`NormalInverseWishart`] distribution in place
    ///with respect to the addition operation given by the MNIW sum.
    ///This always satisfies:
    ///```ignore
    ///let mut other = self.clone();
    ///other ^= ();
    ///other += &self;
    /////other now equals zero_normal_inverse_wishart(self.t, self.s)
    ///```
    fn bitxor_assign(&mut self, _rhs: ()) {
        self.precision_u *= -1.0;
        self.precision *= -1.0;
        self.sigma *= -1.0;
        self.little_v = 2.0 * (self.t as f32) - self.little_v;
        self.big_v *= -1.0;
    }
}

///Constructs the zero element with respect to the MNIW sum, of the given
///output and input dimensions, respectively.
fn zero_normal_inverse_wishart(t : usize, s : usize) -> NormalInverseWishart {
    NormalInverseWishart {
        mean: Array::zeros((t, s)),
        precision_u: Array::zeros((t, s)),
        precision: Array::zeros((s, s)),
        sigma: Array::zeros((s, s)),
        big_v: Array::zeros((t, t)),
        little_v: (t as f32),
        t,
        s
    }
}

impl NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to reflect
    ///new data points for linear regression.
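    ///
    ///A usage sketch (hypothetical construction; it assumes `DataPoints`
    ///carries an `n x s` matrix `in_vecs` and an `n x t` matrix `out_vecs`,
    ///which is how the body below reads them):
    ///```ignore
    ///let data_points = DataPoints {
    ///    in_vecs : x_matrix,
    ///    out_vecs : y_matrix
    ///};
    ///model.update_datapoints(&data_points);
    ///```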
    pub fn update_datapoints(&mut self, data_points : &DataPoints) {
        let n = data_points.num_points();
        let X = &data_points.in_vecs;
        let Y = &data_points.out_vecs;

        let XTX = X.t().dot(X);
        let YTX = Y.t().dot(X);

        let mut out_precision = self.precision.clone();
        out_precision += &XTX;

        self.sigma = pseudoinverse_h(&out_precision);

        self.precision_u += &YTX;

        let out_mean = self.precision_u.dot(&self.sigma);

        let mean_diff = &out_mean - &self.mean;
        let mean_product = mean_diff.dot(&self.precision).dot(&mean_diff.t());
        self.big_v += &mean_product;

        let BNX = out_mean.dot(&X.t());
        let R = &Y.t() - &BNX;
        let RTR = R.dot(&R.t());
        self.big_v += &RTR;

        self.mean = out_mean;
        self.precision = out_precision;
        self.little_v += n as f32;
    }
    fn update(&mut self, data_point : &DataPoint, downdate : bool) {
        let x = &data_point.in_vec;
        let x_norm_sq = x.dot(x);

        if x_norm_sq < UPDATE_SQ_NORM_TRUNCATION_THRESH {
            return;
        }

        let y = &data_point.out_vec;
        let s = if downdate {-1.0f32} else {1.0f32};
        let w = data_point.weight * s;

        let mut out_precision = self.precision.clone();
        sherman_morrison_update(&mut out_precision, &mut self.sigma, w, x.view());

        self.precision_u += &(w * outer(y.view(), x.view()));

        let out_mean = self.precision_u.dot(&self.sigma);

        self.little_v += w;

        let update_mean = (1.0f32 / x_norm_sq) * outer(y.view(), x.view());
        let update_precision = w * outer(x.view(), x.view());

        let initial_mean_diff = &out_mean - &self.mean;
        let update_mean_diff = &out_mean - &update_mean;

        self.big_v += &update_mean_diff.dot(&update_precision).dot(&update_mean_diff.t());
        self.big_v += &initial_mean_diff.dot(&self.precision).dot(&initial_mean_diff.t());

        if self.big_v[[0, 0]] < 0.0f32 {
            println!("Big v became negative due to data update");
            println!("In vec: {}", &data_point.in_vec);
            println!("Out vec: {}", &data_point.out_vec);
            println!("Weight: {}", &data_point.weight);
            println!("Big v: {}", &self.big_v);
        }

        self.mean = out_mean;
        self.precision = out_precision;
    }
    fn update_input_to_schmeared_output(&mut self, update : &InputToSchmearedOutput, downdate : bool) {
        let data_point = DataPoint {
            in_vec : update.in_vec.clone(),
            out_vec : update.out_schmear.mean.clone(),
            weight : 1.0f32
        };
        if downdate {
            self.big_v -= &update.out_schmear.covariance;
        }
        self.update(&data_point, downdate);
        if !downdate {
            self.big_v += &update.out_schmear.covariance;
        }
    }
}

impl ops::AddAssign<&Multiple<InputToSchmearedOutput>> for NormalInverseWishart {
    fn add_assign(&mut self, update : &Multiple<InputToSchmearedOutput>) {
        //TODO: closed form expression here
        for _ in 0..update.count {
            self.update_input_to_schmeared_output(&update.elem, false);
        }
    }
}

impl ops::SubAssign<&Multiple<InputToSchmearedOutput>> for NormalInverseWishart {
    fn sub_assign(&mut self, update : &Multiple<InputToSchmearedOutput>) {
        for _ in 0..update.count {
            self.update_input_to_schmeared_output(&update.elem, true);
        }
    }
}

impl ops::AddAssign<&InputToSchmearedOutput> for NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to incorporate
    ///regression information from the given [`InputToSchmearedOutput`].
    fn add_assign(&mut self, update : &InputToSchmearedOutput) {
        self.update_input_to_schmeared_output(update, false);
    }
}

impl ops::SubAssign<&InputToSchmearedOutput> for NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to remove
    ///regression information from the given [`InputToSchmearedOutput`].
    fn sub_assign(&mut self, update : &InputToSchmearedOutput) {
        self.update_input_to_schmeared_output(update, true);
    }
}

impl ops::AddAssign<&DataPoint> for NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to incorporate
    ///regression information from the given [`DataPoint`].
    fn add_assign(&mut self, other: &DataPoint) {
        self.update(other, false)
    }
}

impl ops::SubAssign<&DataPoint> for NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to remove
    ///regression information from the given [`DataPoint`].
    fn sub_assign(&mut self, other: &DataPoint) {
        self.update(other, true)
    }
}

impl NormalInverseWishart {
    fn update_combine(&mut self, other : &NormalInverseWishart, downdate : bool) {
        let s = if downdate {-1.0f32} else {1.0f32};

        let mut other_precision = other.precision.clone();
        other_precision *= s;

        let mut other_big_v = other.big_v.clone();
        other_big_v *= s;

        let other_mean = &other.mean;
        let other_little_v = if downdate {(self.t as f32) * 2.0f32 - other.little_v} else {other.little_v};

        let mut other_precision_u = other.precision_u.clone();
        other_precision_u *= s;

        self.precision_u += &other_precision_u;

        let precision_out = &self.precision + &other_precision;

        self.sigma = pseudoinverse_h(&precision_out);

        let mean_out = self.precision_u.dot(&self.sigma);

        let mean_one_diff = &self.mean - &mean_out;
        let mean_two_diff = other_mean - &mean_out;

        let u_diff_l_u_diff_one = mean_one_diff.dot(&self.precision).dot(&mean_one_diff.t());
        let u_diff_l_u_diff_two = mean_two_diff.dot(&other_precision).dot(&mean_two_diff.t());

        self.little_v += other_little_v - (self.t as f32);
        self.precision = precision_out;
        self.mean = mean_out;

        self.big_v += &other_big_v;
        self.big_v += &u_diff_l_u_diff_one;
        self.big_v += &u_diff_l_u_diff_two;

        if self.big_v[[0, 0]] < 0.0f32 {
            println!("Big v became negative due to prior update");
            println!("Big v: {}", &self.big_v);
            println!("Other big v: {}", &other_big_v);
        }
    }
}

impl ops::AddAssign<&Multiple<NormalInverseWishart>> for NormalInverseWishart {
    fn add_assign(&mut self, other : &Multiple<NormalInverseWishart>) {
        //TODO: Use a closed-form expression for this and for the subtraction,
        //and then add tests verifying that it matches
        for _ in 0..other.count {
            self.update_combine(&other.elem, false);
        }
    }
}

impl ops::SubAssign<&Multiple<NormalInverseWishart>> for NormalInverseWishart {
    fn sub_assign(&mut self, other : &Multiple<NormalInverseWishart>) {
        for _ in 0..other.count {
            self.update_combine(&other.elem, true);
        }
    }
}

impl ops::AddAssign<&NormalInverseWishart> for NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to the
    ///MNIW-sum of it and the passed-in [`NormalInverseWishart`] distribution.
    fn add_assign(&mut self, other: &NormalInverseWishart) {
        self.update_combine(other, false);
    }
}

impl ops::SubAssign<&NormalInverseWishart> for NormalInverseWishart {
    ///Updates this [`NormalInverseWishart`] distribution to the
    ///MNIW-sum of it and the additive inverse of the passed-in [`NormalInverseWishart`] distribution.
    fn sub_assign(&mut self, other : &NormalInverseWishart) {
        self.update_combine(other, true);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;
    use crate::type_id::*;

    #[test]
    fn prior_updates_undo_cleanly() {
        let ctxt = get_test_vector_only_context();
        let expected = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);

        let mut model = expected.clone();
        let other = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);

        model.data += &other.data;
        model.data -= &other.data;

        assert_equal_distributions_to_within(&model.data, &expected.data, 1.0f32);
    }
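
    ///A sketch exercising the inversion invariant documented on the
    ///`BitXorAssign` impl: inverting a distribution and then adding it back
    ///should land on the zero element (up to numerical round-off).
    #[test]
    fn inversion_followed_by_sum_yields_zero() {
        let ctxt = get_test_vector_only_context();
        let model = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);

        let mut inverted = model.data.clone();
        inverted ^= ();
        inverted += &model.data;

        let zero = zero_normal_inverse_wishart(inverted.t, inverted.s);
        assert_equal_distributions_to_within(&inverted, &zero, 1.0f32);
    }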

    #[test]
    fn test_model_convergence_noiseless() {
        let num_samps = 1000;
        let s = 5;
        let t = 4;
        let out_weight = 100.0f32;
        let mut model = standard_normal_inverse_wishart(s, t);

        let mat = random_matrix(t, s);
        for _ in 0..num_samps {
            let vec = random_vector(s);
            let out = mat.dot(&vec);

            let data_point = DataPoint {
                in_vec : vec,
                out_vec : out,
                weight : out_weight
            };

            model += &data_point;
        }

        assert_equal_matrices(model.mean.view(), mat.view());
    }
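
    ///A sketch checking that a single data-point update followed by the
    ///matching downdate approximately restores the original model
    ///(assumes the `random_vector` and `standard_normal_inverse_wishart`
    ///helpers used above; tolerance mirrors `prior_updates_undo_cleanly`).
    #[test]
    fn data_updates_undo_cleanly() {
        let s = 5;
        let t = 4;
        let expected = standard_normal_inverse_wishart(s, t);

        let mut model = expected.clone();
        let data_point = DataPoint {
            in_vec : random_vector(s),
            out_vec : random_vector(t),
            weight : 1.0f32
        };

        model += &data_point;
        model -= &data_point;

        assert_equal_distributions_to_within(&model, &expected, 1.0f32);
    }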
}