1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
extern crate ndarray;
extern crate ndarray_linalg;
use ndarray::*;
use rand::prelude::*;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::StandardNormal;
use crate::pseudoinverse::*;
use crate::wishart::*;
use crate::normal_inverse_wishart::*;
use crate::sqrtm::*;
/// Reusable sampler built from a `NormalInverseWishart` distribution.
///
/// The constructor precomputes everything that does not depend on the random
/// draw, so each call to `sample` only generates random numbers and performs
/// two matrix products and one addition.
pub struct NormalInverseWishartSampler {
    /// Wishart distribution constructed from the pseudoinverse of the source
    /// distribution's `big_v` together with its `little_v`.
    wishart: Wishart,
    /// Copy of the source distribution's `mean`; added to every draw.
    mean: Array2<f32>,
    /// Output of `sqrtm` applied to the source distribution's `sigma`; used as
    /// the right-hand factor when shaping the noise matrix.
    /// NOTE(review): the field name says "cholesky", but the value comes from
    /// `sqrtm` — confirm which factorisation that helper actually returns.
    covariance_cholesky_factor: Array2<f32>,
    /// Number of rows of the standard-normal matrix drawn per sample
    /// (copied from the source distribution's `t`).
    t: usize,
    /// Number of columns of the standard-normal matrix drawn per sample
    /// (copied from the source distribution's `s`).
    s: usize,
}
impl NormalInverseWishartSampler {
    /// Builds a sampler from `distr`, doing all draw-independent work up front.
    ///
    /// Steps, in order:
    /// 1. `pseudoinverse_h` is applied to `distr.big_v`, and the result plus
    ///    `distr.little_v` is handed to `Wishart::new`.
    /// 2. `distr.mean` is cloned.
    /// 3. `sqrtm` is applied to `distr.sigma`.
    /// 4. The dimensions `distr.t` and `distr.s` are copied.
    ///
    /// `distr` itself is only borrowed; the sampler keeps no reference to it.
    pub fn new(distr: &NormalInverseWishart) -> NormalInverseWishartSampler {
        // Fields are listed so the helper calls run in the same order as
        // before: pseudoinverse -> Wishart -> mean clone -> sqrtm.
        Self {
            wishart: Wishart::new(pseudoinverse_h(&distr.big_v), distr.little_v),
            mean: distr.mean.clone(),
            covariance_cholesky_factor: sqrtm(&distr.sigma),
            t: distr.t,
            s: distr.s,
        }
    }

    /// Draws one matrix from the sampler.
    ///
    /// The result is `mean + A · Z · B`, where
    /// * `A` is whatever `Wishart::sample_inv_cholesky_factor` returns for this
    ///   call (presumably a factor of an inverse-Wishart draw — confirm against
    ///   the `wishart` module),
    /// * `Z` is a fresh `t × s` matrix of standard-normal entries, and
    /// * `B` is the factor stored in `covariance_cholesky_factor`.
    ///
    /// `rng` is forwarded to the Wishart draw only. The standard-normal matrix
    /// is produced by `Array::random`, which is not given `rng` here.
    pub fn sample(&self, rng: &mut ThreadRng) -> Array2<f32> {
        let left_factor = self.wishart.sample_inv_cholesky_factor(rng);
        let right_factor = &self.covariance_cholesky_factor;

        // Element type of the noise matrix is inferred from the products below.
        let standard_normal = Array::random((self.t, self.s), StandardNormal);
        let shaped_noise = left_factor.dot(&standard_normal).dot(right_factor);

        // Start from a copy of the mean and shift it in place by the noise.
        let mut draw = self.mean.clone();
        draw += &shaped_noise;
        draw
    }
}