1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
extern crate ndarray;
extern crate ndarray_linalg;

use ndarray::*;

use rand::prelude::*;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::StandardNormal;

use crate::pseudoinverse::*;
use crate::wishart::*;
use crate::normal_inverse_wishart::*;
use crate::sqrtm::*;

/// Sampler for one fixed [`NormalInverseWishart`] distribution.
///
/// The pseudoinverse and square-root computations are done once when the
/// sampler is built, so individual draws do not repeat them.
pub struct NormalInverseWishartSampler {
    /// Mean matrix added to every draw (a copy of the distribution's `mean`).
    mean: Array2<f32>,
    /// Wishart distribution built from the pseudoinverse of the distribution's
    /// `big_v` and its `little_v`; each draw takes an inverse Cholesky factor
    /// from it and multiplies it on the left of the noise matrix.
    wishart: Wishart,
    /// Result of `sqrtm` applied to the distribution's `sigma`; multiplied on
    /// the right of the noise matrix.
    /// NOTE(review): despite the field name, this is computed with `sqrtm`
    /// (presumably a matrix square root), not a Cholesky decomposition — confirm.
    covariance_cholesky_factor: Array2<f32>,
    /// Row count of the standard-normal noise matrix (copied from `distr.t`).
    t: usize,
    /// Column count of the standard-normal noise matrix (copied from `distr.s`).
    s: usize,
}

impl NormalInverseWishartSampler {
    ///Constructs a new [`NormalInverseWishartSampler`] from the
    ///given [`NormalInverseWishart`] distribution.
    pub fn new(distr : &NormalInverseWishart) -> NormalInverseWishartSampler {
        let big_v_inverse = pseudoinverse_h(&distr.big_v);
        let wishart : Wishart = Wishart::new(big_v_inverse, distr.little_v);
        let mean = distr.mean.clone();
        let covariance_cholesky_factor = sqrtm(&distr.sigma);
        let s = distr.s;
        let t = distr.t;
        NormalInverseWishartSampler {
            wishart,
            mean,
            covariance_cholesky_factor,
            t,
            s
        }
    }
    ///Draws a sample from the [`NormalInverseWishart`] distribution
    ///that this [`NormalInverseWishartSampler`] was constructed with.
    pub fn sample(&self, rng : &mut ThreadRng) -> Array2<f32> {
        let out_chol = self.wishart.sample_inv_cholesky_factor(rng);
        let in_chol = &self.covariance_cholesky_factor;

        let X = Array::random((self.t, self.s), StandardNormal);
	let T = out_chol.dot(&X).dot(in_chol);

        let mut result = self.mean.clone();
        result += &T;
        result
    }
}