1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use std::ops;
5use ndarray::*;
6use crate::multiple::*;
7use crate::prior_specification::*;
8use crate::input_to_schmeared_output::*;
9use crate::array_utils::*;
10use crate::params::*;
11use crate::data_points::*;
12use crate::pseudoinverse::*;
13use crate::func_scatter_tensor::*;
14use crate::func_inverse_schmear::*;
15use crate::func_schmear::*;
16use crate::schmear::*;
17use crate::data_point::*;
18use crate::sherman_morrison::*;
19use crate::linalg_utils::*;
20use crate::normal_inverse_wishart_sampler::*;
21use crate::function_space_info::*;
22
23use rand::prelude::*;
24
25
///Matrix-normal-inverse-Wishart distribution over linear maps from an
///`s`-dimensional input (feature) space to a `t`-dimensional output space.
///The cached fields `precision_u` and `sigma` must be kept in sync with
///`mean` and `precision` (see `recompute_derived`).
#[derive(Clone)]
pub struct NormalInverseWishart {
    ///Mean of the matrix-normal component (t x s).
    pub mean : Array2<f32>,
    ///Cached product `mean * precision` (t x s); derived quantity.
    pub precision_u : Array2<f32>,
    ///Input-space precision matrix (s x s).
    pub precision : Array2<f32>,
    ///Cached pseudoinverse of `precision` (s x s); derived quantity.
    pub sigma : Array2<f32>,
    ///Scale matrix of the inverse-Wishart component over outputs (t x t).
    pub big_v : Array2<f32>,
    ///Degrees-of-freedom parameter of the inverse-Wishart component.
    pub little_v : f32,
    ///Output dimension (number of rows of `mean`).
    pub t : usize,
    ///Input/feature dimension (number of columns of `mean`).
    pub s : usize
}
43
44impl NormalInverseWishart {
45 pub fn recompute_derived(&mut self) {
51 self.sigma = pseudoinverse_h(&self.precision);
52 self.precision_u = self.mean.dot(&self.precision);
53 }
54
55 pub fn get_total_dims(&self) -> usize {
57 self.s * self.t
58 }
59
60 pub fn sample(&self, rng : &mut ThreadRng) -> Array2<f32> {
62 let sampler = NormalInverseWishartSampler::new(&self);
63 sampler.sample(rng)
64 }
65
66 pub fn sample_as_vec(&self, rng : &mut ThreadRng) -> Array1<f32> {
68 let thick = self.sample(rng);
69 let total_dims = self.get_total_dims();
70 thick.into_shape((total_dims,)).unwrap()
71 }
72
73 pub fn get_mean_as_vec(&self) -> ArrayView1::<f32> {
75 flatten_matrix(self.mean.view())
76 }
77 pub fn get_mean(&self) -> Array2<f32> {
79 self.mean.clone()
80 }
81 pub fn get_schmear(&self) -> FuncSchmear {
83 FuncSchmear {
84 mean : self.mean.clone(),
85 covariance : self.get_covariance()
86 }
87 }
88 pub fn get_inverse_schmear(&self) -> FuncInverseSchmear {
90 FuncInverseSchmear {
91 mean : self.mean.clone(),
92 precision : self.get_precision()
93 }
94 }
95 pub fn get_precision(&self) -> FuncScatterTensor {
97 let scale = self.little_v - (self.t as f32) - 1.0f32;
98 let mut out_precision = pseudoinverse_h(&self.big_v);
99 out_precision *= scale;
100 FuncScatterTensor {
101 in_scatter : self.precision.clone(),
102 out_scatter : out_precision
103 }
104 }
105 pub fn get_covariance(&self) -> FuncScatterTensor {
107 let scale = 1.0f32 / (self.little_v - (self.t as f32) - 1.0f32);
108 let big_v_scaled = scale * &self.big_v;
109 FuncScatterTensor {
110 in_scatter : self.sigma.clone(),
111 out_scatter : big_v_scaled
112 }
113 }
114}
115
116impl NormalInverseWishart {
117 pub fn from_in_out_dims(prior_specification : &dyn PriorSpecification,
120 feat_dims : usize, out_dims : usize) -> NormalInverseWishart {
121 let mean : Array2<f32> = Array::zeros((out_dims, feat_dims));
122
123 let in_precision_multiplier = prior_specification.get_in_precision_multiplier(feat_dims);
124 let out_covariance_multiplier = prior_specification.get_out_covariance_multiplier(out_dims);
125
126 let in_precision : Array2<f32> = in_precision_multiplier * Array::eye(feat_dims);
127 let out_covariance : Array2<f32> = out_covariance_multiplier * Array::eye(out_dims);
128
129 let little_v = prior_specification.get_out_pseudo_observations(out_dims);
130
131 NormalInverseWishart::new(mean, in_precision, out_covariance, little_v)
132 }
133 pub fn from_space_info(prior_specification : &dyn PriorSpecification,
136 func_space_info : &FunctionSpaceInfo) -> NormalInverseWishart {
137 let feat_dims = func_space_info.get_feature_dimensions();
138 let out_dims = func_space_info.get_output_dimensions();
139
140 NormalInverseWishart::from_in_out_dims(prior_specification, feat_dims, out_dims)
141 }
142 pub fn new(mean : Array2<f32>, precision : Array2<f32>, big_v : Array2<f32>, little_v : f32) -> NormalInverseWishart {
145 let precision_u : Array2<f32> = mean.dot(&precision);
146 let sigma : Array2<f32> = pseudoinverse_h(&precision);
147 let t = mean.shape()[0];
148 let s = mean.shape()[1];
149
150 NormalInverseWishart {
151 mean,
152 precision_u,
153 precision,
154 sigma,
155 big_v,
156 little_v,
157 t,
158 s
159 }
160 }
161}
162
///In-place "negation" of the distribution (invoked as `niw ^= ()`).
///Flips the sign of the additively-combined sufficient statistics so that
///subsequently adding this distribution elsewhere acts as a downdate.
///NOTE(review): `mean` is deliberately left untouched — `precision_u` is
///`mean * precision`, so negating `precision` and `precision_u` together
///keeps them consistent. The `little_v` reflection `2t - little_v` mirrors
///the downdate branch of `update_combine`.
impl ops::BitXorAssign<()> for NormalInverseWishart {
    fn bitxor_assign(&mut self, _rhs: ()) {
        self.precision_u *= -1.0;
        self.precision *= -1.0;
        self.sigma *= -1.0;
        self.little_v = 2.0 * (self.t as f32) - self.little_v;
        self.big_v *= -1.0;
    }
}
181
182fn zero_normal_inverse_wishart(t : usize, s : usize) -> NormalInverseWishart {
185 NormalInverseWishart {
186 mean: Array::zeros((t, s)),
187 precision_u: Array::zeros((t, s)),
188 precision: Array::zeros((t, s)),
189 sigma: Array::zeros((t, s)),
190 big_v: Array::zeros((t, s)),
191 little_v: (t as f32),
192 t,
193 s
194 }
195}
196
impl NormalInverseWishart {
    ///Batch Bayesian update from a collection of (input, output) pairs,
    ///following the matrix-normal-inverse-Wishart conjugate update:
    ///precision gains X^T X, `precision_u` gains Y^T X, `big_v` absorbs the
    ///mean-shift scatter plus the residual scatter, and `little_v` grows by
    ///the number of points.
    pub fn update_datapoints(&mut self, data_points : &DataPoints) {
        let n = data_points.num_points();
        //X: inputs stacked as rows; Y: outputs stacked as rows.
        let X = &data_points.in_vecs;
        let Y = &data_points.out_vecs;

        //Sufficient statistics of the batch.
        let XTX = X.t().dot(X);
        let YTX = Y.t().dot(X);

        //Posterior input-space precision = prior precision + X^T X.
        let mut out_precision = self.precision.clone();
        out_precision += &XTX;

        self.sigma = pseudoinverse_h(&out_precision);

        self.precision_u += &YTX;

        //Posterior mean solves mean_out * out_precision = precision_u.
        let out_mean = self.precision_u.dot(&self.sigma);

        //Scatter of the mean shift, measured in the *prior* precision metric.
        let mean_diff = &out_mean - &self.mean;
        let mean_diff_t = mean_diff.t().clone();
        let mean_product = mean_diff.dot(&self.precision).dot(&mean_diff_t);
        self.big_v += &mean_product;

        //Residual scatter R R^T where R = Y^T - mean_out * X^T.
        let XT = X.t().clone();
        let BNX = out_mean.dot(&XT);
        let R = &Y.t() - &BNX;
        let RT = R.t().clone();
        let RTR = R.dot(&RT);
        self.big_v += &RTR;


        self.mean = out_mean;
        self.precision = out_precision;
        self.little_v += (n as f32);
    }
    ///Rank-one (down)date from a single weighted data point. With
    ///`downdate == true` the point's weight is negated, removing its
    ///contribution from the posterior.
    fn update(&mut self, data_point : &DataPoint, downdate : bool) {
        let x = &data_point.in_vec;
        let x_norm_sq = x.dot(x);

        //Skip near-zero inputs: they carry no usable information and would
        //make the 1 / x_norm_sq term below numerically explode.
        if (x_norm_sq < UPDATE_SQ_NORM_TRUNCATION_THRESH) {
            return;
        }

        let y = &data_point.out_vec;
        //Effective (signed) weight: negative weight implements a downdate.
        let s = if (downdate) {-1.0f32} else {1.0f32};
        let w = data_point.weight * s;

        //Rank-one update of the precision and (in tandem, via
        //Sherman-Morrison) its cached inverse `sigma`.
        let mut out_precision = self.precision.clone();
        sherman_morrison_update(&mut out_precision, &mut self.sigma, w, x.view());

        self.precision_u += &(w * outer(y.view(), x.view()));

        let out_mean = self.precision_u.dot(&self.sigma);

        self.little_v += w;

        //`update_mean` is the least-squares map that sends x exactly to y;
        //`update_precision` is the data point's own precision contribution.
        let update_mean = (1.0f32 / x_norm_sq) * outer(y.view(), x.view());
        let update_precision = w * outer(x.view(), x.view());

        let initial_mean_diff = &out_mean - &self.mean;
        let update_mean_diff = &out_mean - &update_mean;

        //Accumulate the mean-shift scatter from both the data point and the
        //previous posterior, each in its own precision metric.
        self.big_v += &update_mean_diff.dot(&update_precision).dot(&update_mean_diff.t());
        self.big_v += &initial_mean_diff.dot(&self.precision).dot(&initial_mean_diff.t());

        //Diagnostic only: `big_v` must stay positive (semi-)definite; a
        //negative leading entry signals numerical trouble with downdates.
        if (self.big_v[[0, 0]] < 0.0f32) {
            println!("Big v became negative due to data update");
            println!("In vec: {}", &data_point.in_vec);
            println!("Out vec: {}", &data_point.out_vec);
            println!("Weight: {}", &data_point.weight);
            println!("Big v: {}", &self.big_v);
        }

        self.mean = out_mean;
        self.precision = out_precision;
    }
    ///(Down)date from an input paired with an uncertain ("schmeared") output:
    ///the output mean is treated as a unit-weight data point and the output
    ///covariance is folded into `big_v`. The covariance is removed *before*
    ///a downdate but added *after* an update, so the two operations are exact
    ///inverses of one another.
    fn update_input_to_schmeared_output(&mut self, update : &InputToSchmearedOutput, downdate : bool) {
        let data_point = DataPoint {
            in_vec : update.in_vec.clone(),
            out_vec : update.out_schmear.mean.clone(),
            weight : 1.0f32
        };
        if (downdate) {
            self.big_v -= &update.out_schmear.covariance;
        }
        self.update(&data_point, downdate);
        if (!downdate) {
            self.big_v += &update.out_schmear.covariance;
        }
    }
}
290
291impl ops::AddAssign<&Multiple<InputToSchmearedOutput>> for NormalInverseWishart {
292 fn add_assign(&mut self, update : &Multiple<InputToSchmearedOutput>) {
293 for _ in 0..update.count {
295 self.update_input_to_schmeared_output(&update.elem, false);
296 }
297 }
298}
299
300impl ops::SubAssign<&Multiple<InputToSchmearedOutput>> for NormalInverseWishart {
301 fn sub_assign(&mut self, update : &Multiple<InputToSchmearedOutput>) {
302 for _ in 0..update.count {
303 self.update_input_to_schmeared_output(&update.elem, true);
304 }
305 }
306}
307
///`model += &update`: incorporates one input/schmeared-output observation.
impl ops::AddAssign<&InputToSchmearedOutput> for NormalInverseWishart {
    fn add_assign(&mut self, update : &InputToSchmearedOutput) {
        self.update_input_to_schmeared_output(update, false);
    }
}
315
///`model -= &update`: removes (downdates) one input/schmeared-output observation.
impl ops::SubAssign<&InputToSchmearedOutput> for NormalInverseWishart {
    fn sub_assign(&mut self, update : &InputToSchmearedOutput) {
        self.update_input_to_schmeared_output(update, true);
    }
}
323
///`model += &data_point`: incorporates one weighted data point.
impl ops::AddAssign<&DataPoint> for NormalInverseWishart {
    fn add_assign(&mut self, other: &DataPoint) {
        self.update(other, false)
    }
}
331
///`model -= &data_point`: removes (downdates) one weighted data point.
impl ops::SubAssign<&DataPoint> for NormalInverseWishart {
    fn sub_assign(&mut self, other: &DataPoint) {
        self.update(other, true)
    }
}
339
impl NormalInverseWishart {
    ///Combines this distribution with another one by adding the additively-
    ///combined sufficient statistics (`precision`, `precision_u`, `big_v`,
    ///degrees of freedom). With `downdate == true` the other distribution's
    ///statistics are negated first, removing a previous combination.
    fn update_combine(&mut self, other : &NormalInverseWishart, downdate : bool) {
        //Sign applied to the other distribution's additive statistics.
        let s = if (downdate) {-1.0f32} else {1.0f32};

        let mut other_precision = other.precision.clone();
        other_precision *= s;

        let mut other_big_v = other.big_v.clone();
        other_big_v *= s;

        let other_mean = &other.mean;
        //Degrees of freedom are reflected (2t - v) rather than negated on a
        //downdate — mirrors the `BitXorAssign` negation above.
        let other_little_v = if (downdate) {(self.t as f32) * 2.0f32 - other.little_v} else {other.little_v};

        let mut other_precision_u = other.precision_u.clone();
        other_precision_u *= s;


        self.precision_u += &other_precision_u;

        //Combined input-space precision and its (pseudo)inverse.
        let precision_out = &self.precision + &other_precision;

        self.sigma = pseudoinverse_h(&precision_out);

        //Combined mean solves mean_out * precision_out = precision_u.
        let mean_out = self.precision_u.dot(&self.sigma);

        //Mean-shift scatter of each operand, in its own precision metric.
        let mean_one_diff = &self.mean - &mean_out;
        let mean_two_diff = other_mean - &mean_out;

        let u_diff_l_u_diff_one = mean_one_diff.dot(&self.precision).dot(&mean_one_diff.t());
        let u_diff_l_u_diff_two = mean_two_diff.dot(&other_precision).dot(&mean_two_diff.t());

        //The `- t` keeps the combined pseudo-observation count consistent
        //with the reflection convention used for downdates.
        self.little_v += other_little_v - (self.t as f32);
        self.precision = precision_out;
        self.mean = mean_out;

        self.big_v += &other_big_v;
        self.big_v += &u_diff_l_u_diff_one;
        self.big_v += &u_diff_l_u_diff_two;

        //Diagnostic only: flags loss of positive-definiteness in `big_v`.
        if (self.big_v[[0, 0]] < 0.0f32) {
            println!("Big v became negative due to prior update");
            println!("Big v: {}", &self.big_v);
            println!("Other big v: {}", &other_big_v);
        }
    }
}
386
387impl ops::AddAssign<&Multiple<NormalInverseWishart>> for NormalInverseWishart {
388 fn add_assign(&mut self, other : &Multiple<NormalInverseWishart>) {
389 for _ in 0..other.count {
392 self.update_combine(&other.elem, false);
393 }
394 }
395}
396
397impl ops::SubAssign<&Multiple<NormalInverseWishart>> for NormalInverseWishart {
398 fn sub_assign(&mut self, other : &Multiple<NormalInverseWishart>) {
399 for _ in 0..other.count {
400 self.update_combine(&other.elem, true);
401 }
402 }
403}
404
///`model += &other`: combines in another distribution's statistics.
impl ops::AddAssign<&NormalInverseWishart> for NormalInverseWishart {
    fn add_assign(&mut self, other: &NormalInverseWishart) {
        self.update_combine(other, false);
    }
}
///`model -= &other`: removes another distribution's combined contribution.
impl ops::SubAssign<&NormalInverseWishart> for NormalInverseWishart {
    fn sub_assign(&mut self, other : &NormalInverseWishart) {
        self.update_combine(other, true);
    }
}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;
    use crate::type_id::*;

    ///Adding and then subtracting the same prior should return the model to
    ///its original distribution (within tolerance) — i.e. `update_combine`
    ///with `downdate == true` exactly inverts the combination.
    #[test]
    fn prior_updates_undo_cleanly() {
        let ctxt = get_test_vector_only_context();
        let expected = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);

        let mut model = expected.clone();
        let other = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);

        model.data += &other.data;
        model.data -= &other.data;

        assert_equal_distributions_to_within(&model.data, &expected.data, 1.0f32);
    }

    ///Feeding many noiseless samples of a fixed linear map (with high weight)
    ///should drive the posterior mean toward that map.
    #[test]
    fn test_model_convergence_noiseless() {
        let num_samps = 1000;
        let s = 5;
        let t = 4;
        //Large weight accelerates convergence toward the generating matrix.
        let out_weight = 100.0f32;
        let mut model = standard_normal_inverse_wishart(s, t);

        let mat = random_matrix(t, s);
        for _ in 0..num_samps {
            let vec = random_vector(s);
            let out = mat.dot(&vec);

            let data_point = DataPoint {
                in_vec : vec,
                out_vec : out,
                weight : out_weight
            };

            model += &data_point;
        }

        assert_equal_matrices(model.mean.view(), mat.view());
    }
}