1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// The "sampleable distribution" implementation is finished; it had previously been commented out.

use crate::{
    DependentJoint, Distribution, ExactMultivariateNormalParams, IndependentJoint, InverseWishart,
    InverseWishartParams, MultivariateNormal, RandomVariable, SampleableDistribution,
};
use crate::{DistributionError, NormalInverseWishartParams};
use opensrdk_linear_algebra::pp::trf::PPTRF;
use opensrdk_linear_algebra::{SymmetricPackedMatrix, Vector};
use rand::prelude::*;
use std::{ops::BitAnd, ops::Mul};

/// The normal inverse Wishart distribution over a pair `(μ, Σ)`.
///
/// This is a field-less marker type: every parameter is supplied through the
/// `NormalInverseWishartParams` condition when the distribution is evaluated or sampled.
#[derive(Debug, Clone)]
pub struct NormalInverseWishart;

#[derive(thiserror::Error, Debug)]
pub enum NormalInverseWishartError {
    #[error("Dimension mismatch")]
    DimensionMismatch,
    #[error("'λ' must be positive")]
    LambdaMustBePositive,
    #[error("'ν' must be >= dimension")]
    NuMustBeGTEDimension,
    #[error("Unknown error")]
    Unknown,
}

impl Distribution for NormalInverseWishart {
    type Value = ExactMultivariateNormalParams;
    type Condition = NormalInverseWishartParams;

    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        let mu0 = theta.mu0().clone();
        let lambda = theta.lambda();
        let lpsi = theta.lpsi().clone();
        let nu = theta.nu();
        let dim = mu0.len();

        let mu = x.mu();
        let lsigma = x.lsigma();

        let n = MultivariateNormal::new();
        let w_inv = InverseWishart;

        Ok(n.p_kernel(
            mu,
            &ExactMultivariateNormalParams::new(
                mu0,
                PPTRF(
                    SymmetricPackedMatrix::from(
                        dim,
                        ((1.0 / lambda).sqrt() * lsigma.0.elems().to_vec().col_mat()).vec(),
                    )
                    .unwrap(),
                ),
            )?,
        )? * w_inv.p_kernel(lsigma, &InverseWishartParams::new(lpsi, nu)?)?)
    }
}

/// `self * other`: pairs this distribution with another distribution whose
/// condition is also `NormalInverseWishartParams`, producing an `IndependentJoint`.
impl<D: Distribution<Value = V, Condition = NormalInverseWishartParams>, V: RandomVariable> Mul<D>
    for NormalInverseWishart
{
    type Output =
        IndependentJoint<Self, D, ExactMultivariateNormalParams, V, NormalInverseWishartParams>;

    fn mul(self, other: D) -> Self::Output {
        let joint = IndependentJoint::new(self, other);
        joint
    }
}

/// `self & prior`: chains this distribution onto `prior`, whose value type is
/// `NormalInverseWishartParams` (this distribution's condition), producing a
/// `DependentJoint`.
impl<D: Distribution<Value = NormalInverseWishartParams, Condition = C>, C: RandomVariable>
    BitAnd<D> for NormalInverseWishart
{
    type Output =
        DependentJoint<Self, D, ExactMultivariateNormalParams, NormalInverseWishartParams, C>;

    fn bitand(self, prior: D) -> Self::Output {
        let joint = DependentJoint::new(self, prior);
        joint
    }
}

impl SampleableDistribution for NormalInverseWishart {
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn RngCore,
    ) -> Result<Self::Value, DistributionError> {
        let mu0 = theta.mu0().clone();
        let lambda = theta.lambda();
        let lpsi = theta.lpsi().clone();
        let nu = theta.nu();
        let dim = mu0.len();

        let n = MultivariateNormal::new();
        let winv = InverseWishart;

        let lsigma = winv.sample(&InverseWishartParams::new(lpsi, nu)?, rng)?;
        let mu = n.sample(
            &ExactMultivariateNormalParams::new(
                mu0,
                PPTRF(
                    SymmetricPackedMatrix::from(
                        dim,
                        ((1.0 / lambda).sqrt() * lsigma.0.elems().to_vec().col_mat()).vec(),
                    )
                    .unwrap(),
                ),
            )?,
            rng,
        )?;

        Ok(ExactMultivariateNormalParams::new(mu, lsigma)?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The user-facing error messages are part of the public contract.
    #[test]
    fn error_messages_are_stable() {
        assert_eq!(
            NormalInverseWishartError::DimensionMismatch.to_string(),
            "Dimension mismatch"
        );
        assert_eq!(
            NormalInverseWishartError::LambdaMustBePositive.to_string(),
            "'λ' must be positive"
        );
        assert_eq!(
            NormalInverseWishartError::NuMustBeGTEDimension.to_string(),
            "'ν' must be >= dimension"
        );
        assert_eq!(NormalInverseWishartError::Unknown.to_string(), "Unknown error");
    }

    /// The distribution marker type is cloneable and debug-printable.
    #[test]
    fn distribution_is_clone_and_debug() {
        let d = NormalInverseWishart;
        assert_eq!(format!("{:?}", d.clone()), "NormalInverseWishart");
    }
}