1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use crate::{
DependentJoint, Distribution, ExactMultivariateNormalParams, IndependentJoint, InverseWishart,
InverseWishartParams, MultivariateNormal, RandomVariable, SampleableDistribution,
};
use crate::{DistributionError, NormalInverseWishartParams};
use opensrdk_linear_algebra::pp::trf::PPTRF;
use opensrdk_linear_algebra::{SymmetricPackedMatrix, Vector};
use rand::prelude::*;
use std::{ops::BitAnd, ops::Mul};
/// Unit type standing for the normal-inverse-Wishart distribution.
///
/// It carries no state: the trait implementations in this file use
/// [`ExactMultivariateNormalParams`] as the value type and
/// [`NormalInverseWishartParams`] as the condition type.
#[derive(Debug, Clone)]
pub struct NormalInverseWishart;
/// Error conditions associated with the normal-inverse-Wishart distribution.
///
/// NOTE(review): none of these variants is constructed in the visible part of
/// this file — confirm where (or whether) they are raised.
#[derive(thiserror::Error, Debug)]
pub enum NormalInverseWishartError {
/// Displayed as "Dimension mismatch".
#[error("Dimension mismatch")]
DimensionMismatch,
/// Displayed as "'λ' must be positive".
#[error("'λ' must be positive")]
LambdaMustBePositive,
/// Displayed as "'ν' must be >= dimension".
#[error("'ν' must be >= dimension")]
NuMustBeGTEDimension,
/// Displayed as "Unknown error".
#[error("Unknown error")]
Unknown,
}
impl Distribution for NormalInverseWishart {
    type Value = ExactMultivariateNormalParams;
    type Condition = NormalInverseWishartParams;

    /// Evaluates the kernel as the product of two factors:
    ///
    /// 1. the [`MultivariateNormal`] kernel of `x.mu()` conditioned on mean
    ///    `theta.mu0()` and on `x.lsigma()` with every packed element
    ///    multiplied by `sqrt(1 / theta.lambda())`;
    /// 2. the [`InverseWishart`] kernel of `x.lsigma()` conditioned on
    ///    `theta.lpsi()` and `theta.nu()`.
    ///
    /// NOTE(review): `lsigma` is presumably the packed Cholesky factor of the
    /// covariance, which is what makes the `sqrt(1 / λ)` scaling correspond to
    /// a covariance of `Σ / λ` — confirm against `PPTRF`.
    ///
    /// # Errors
    ///
    /// Propagates any error from building either parameter set or from either
    /// inner `p_kernel` call.
    ///
    /// # Panics
    ///
    /// Panics when `SymmetricPackedMatrix::from` rejects the scaled elements;
    /// presumably that happens when the element count of `x.lsigma()` does not
    /// match `theta.mu0().len()` — TODO confirm.
    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        let prior_mean = theta.mu0().clone();
        let lambda = theta.lambda();
        let prior_lpsi = theta.lpsi().clone();
        let nu = theta.nu();
        let dim = prior_mean.len();

        let observed_mean = x.mu();
        let observed_lsigma = x.lsigma();

        let normal = MultivariateNormal::new();
        let inverse_wishart = InverseWishart;

        // Scale the packed factor element-wise by sqrt(1 / λ) and repack it
        // with the dimension taken from the prior mean.
        let shrink = (1.0 / lambda).sqrt();
        let shrunk_elems = (shrink * observed_lsigma.0.elems().to_vec().col_mat()).vec();
        let shrunk_lsigma = PPTRF(SymmetricPackedMatrix::from(dim, shrunk_elems).unwrap());

        // Normal factor first, then the inverse Wishart factor, so failures
        // surface in the same order as a single nested expression would.
        let mean_condition = ExactMultivariateNormalParams::new(prior_mean, shrunk_lsigma)?;
        let mean_kernel = normal.p_kernel(observed_mean, &mean_condition)?;

        let covariance_condition = InverseWishartParams::new(prior_lpsi, nu)?;
        let covariance_kernel = inverse_wishart.p_kernel(observed_lsigma, &covariance_condition)?;

        Ok(mean_kernel * covariance_kernel)
    }
}
/// `NormalInverseWishart * rhs` pairs this distribution with another one that
/// shares the [`NormalInverseWishartParams`] condition, yielding an
/// [`IndependentJoint`] of the two.
impl<R, V> Mul<R> for NormalInverseWishart
where
    V: RandomVariable,
    R: Distribution<Value = V, Condition = NormalInverseWishartParams>,
{
    type Output =
        IndependentJoint<Self, R, ExactMultivariateNormalParams, V, NormalInverseWishartParams>;

    /// Wraps `self` and `rhs` in an [`IndependentJoint`].
    fn mul(self, rhs: R) -> Self::Output {
        IndependentJoint::new(self, rhs)
    }
}
/// `NormalInverseWishart & rhs` chains this distribution with one whose value
/// type is [`NormalInverseWishartParams`] (this distribution's condition),
/// yielding a [`DependentJoint`] conditioned on `rhs`'s own condition type.
impl<R, C> BitAnd<R> for NormalInverseWishart
where
    C: RandomVariable,
    R: Distribution<Value = NormalInverseWishartParams, Condition = C>,
{
    type Output =
        DependentJoint<Self, R, ExactMultivariateNormalParams, NormalInverseWishartParams, C>;

    /// Wraps `self` and `rhs` in a [`DependentJoint`].
    fn bitand(self, rhs: R) -> Self::Output {
        DependentJoint::new(self, rhs)
    }
}
impl SampleableDistribution for NormalInverseWishart {
    /// Draws one `(mu, lsigma)` pair in two stages:
    ///
    /// 1. `lsigma` is sampled from [`InverseWishart`] conditioned on
    ///    `theta.lpsi()` and `theta.nu()`;
    /// 2. `mu` is sampled from [`MultivariateNormal`] with mean `theta.mu0()`
    ///    and the freshly drawn `lsigma` whose packed elements are multiplied
    ///    by `sqrt(1 / theta.lambda())`.
    ///
    /// The same `rng` feeds both stages, in that order.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the parameter sets or from either
    /// inner `sample` call.
    ///
    /// # Panics
    ///
    /// Panics when `SymmetricPackedMatrix::from` rejects the scaled elements;
    /// presumably only if the sampled factor's size disagrees with
    /// `theta.mu0().len()` — TODO confirm.
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn RngCore,
    ) -> Result<Self::Value, DistributionError> {
        let prior_mean = theta.mu0().clone();
        let lambda = theta.lambda();
        let prior_lpsi = theta.lpsi().clone();
        let nu = theta.nu();
        let dim = prior_mean.len();

        let normal = MultivariateNormal::new();
        let inverse_wishart = InverseWishart;

        // Stage 1: the covariance factor.
        let covariance_condition = InverseWishartParams::new(prior_lpsi, nu)?;
        let lsigma = inverse_wishart.sample(&covariance_condition, rng)?;

        // Stage 2: the mean, using the drawn factor scaled by sqrt(1 / λ).
        let shrink = (1.0 / lambda).sqrt();
        let shrunk_elems = (shrink * lsigma.0.elems().to_vec().col_mat()).vec();
        let shrunk_lsigma = PPTRF(SymmetricPackedMatrix::from(dim, shrunk_elems).unwrap());
        let mean_condition = ExactMultivariateNormalParams::new(prior_mean, shrunk_lsigma)?;
        let mu = normal.sample(&mean_condition, rng)?;

        let drawn = ExactMultivariateNormalParams::new(mu, lsigma)?;
        Ok(drawn)
    }
}
#[cfg(test)]
mod tests {
    /// Placeholder sanity check; it does not exercise `NormalInverseWishart`.
    #[test]
    fn it_works() {
        let sum = 2 + 2;
        assert_eq!(sum, 4);
    }
}