1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Copyright 2022 The Ferric AI Project Developers
use nalgebra::{Cholesky, DMatrix, DVector, Dyn};
use rand::Rng;
use rand_distr::Distribution as Distribution2;
use rand_distr::Normal as Normal2;
use crate::distributions::Distribution;
/// Multivariate normal (Gaussian) distribution over $\mathbb{R}^k$.
///
/// The PDF is
///
/// $$p(\mathbf{x} \mid \boldsymbol\mu, \boldsymbol\Sigma) =
/// \frac{1}{(2\pi)^{k/2} |\boldsymbol\Sigma|^{1/2}}
/// \exp\!\left(-\tfrac{1}{2}(\mathbf{x}-\boldsymbol\mu)^\top
/// \boldsymbol\Sigma^{-1}(\mathbf{x}-\boldsymbol\mu)\right)$$
///
/// where $\boldsymbol\mu \in \mathbb{R}^k$ is the mean vector and
/// $\boldsymbol\Sigma$ is a $k \times k$ symmetric positive-definite
/// covariance matrix.
///
/// Internally the constructor performs a Cholesky decomposition
/// $\boldsymbol\Sigma = LL^\top$, which is reused for both sampling and
/// log-probability evaluation.
///
/// # Examples
///
/// ```
/// use ferric::distributions::{Distribution, MultivariateNormal};
/// use nalgebra::{DMatrix, DVector};
/// use rand::thread_rng;
///
/// let mean = DVector::from_vec(vec![0.0, 0.0]);
/// let cov = DMatrix::identity(2, 2);
/// let dist = MultivariateNormal::new(mean, cov).unwrap();
/// let x: DVector<f64> = dist.sample(&mut thread_rng());
/// println!("sample = {:?}", x.as_slice());
/// ```
#[derive(Debug, Clone)]
pub struct MultivariateNormal {
    /// Mean vector $\boldsymbol\mu$; its length is the dimension $k$.
    mean: DVector<f64>,
    /// Cholesky decomposition of the covariance matrix $\boldsymbol\Sigma$.
    chol: Cholesky<f64, Dyn>,
}
impl MultivariateNormal {
    /// Construct a MultivariateNormal with mean vector `mean` ($\boldsymbol\mu$)
    /// and covariance matrix `cov` ($\boldsymbol\Sigma$).
    ///
    /// # Errors
    ///
    /// Returns `Err` with a descriptive message if
    ///
    /// * `cov` is not `k × k`, where `k` is the length of `mean`;
    /// * `cov` is not symmetric (entries `(i, j)` and `(j, i)` differ by more
    ///   than a small tolerance relative to the largest entry, or their
    ///   difference is NaN);
    /// * `cov` is not positive definite, i.e. its Cholesky decomposition fails.
    pub fn new(mean: DVector<f64>, cov: DMatrix<f64>) -> Result<MultivariateNormal, String> {
        // Allowed asymmetry, relative to the largest absolute entry of `cov`.
        // Loose enough to accept round-off from numerically assembled matrices.
        const SYMMETRY_REL_TOL: f64 = 1e-9;

        let k = mean.len();
        if cov.nrows() != k || cov.ncols() != k {
            return Err(format!(
                "MultivariateNormal: covariance {}×{} must match mean length {}",
                cov.nrows(),
                cov.ncols(),
                k
            ));
        }

        // `Cholesky::new` only reads the lower triangle, so an asymmetric
        // matrix would otherwise be accepted silently. Check symmetry here.
        let scale = cov.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        let tol = SYMMETRY_REL_TOL * scale;
        for i in 0..k {
            for j in 0..i {
                let gap = (cov[(i, j)] - cov[(j, i)]).abs();
                // A NaN gap compares false against everything; reject it explicitly.
                if gap.is_nan() || gap > tol {
                    return Err("MultivariateNormal: covariance is not symmetric".to_string());
                }
            }
        }

        Cholesky::new(cov)
            .map(|chol| MultivariateNormal { mean, chol })
            .ok_or_else(|| "MultivariateNormal: covariance is not positive definite".to_string())
    }
}
impl<R: Rng + ?Sized> Distribution<R> for MultivariateNormal {
    type Domain = DVector<f64>;

    /// Draw one sample via $\mathbf{x} = \boldsymbol\mu + L\mathbf{z}$
    /// where $L$ is the lower Cholesky factor of $\boldsymbol\Sigma$ and
    /// $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, I_k)$.
    fn sample(&self, rng: &mut R) -> DVector<f64> {
        let k = self.mean.len();
        let normal = Normal2::new(0.0, 1.0).unwrap();
        // k independent standard-normal draws.
        let z = DVector::from_fn(k, |_, _| normal.sample(rng));
        // `l()` is required here: the product needs the upper triangle zeroed.
        &self.mean + self.chol.l() * z
    }

    /// Returns
    /// $-\tfrac{k}{2}\ln(2\pi)
    /// - \tfrac{1}{2}\ln|\boldsymbol\Sigma|
    /// - \tfrac{1}{2}(\mathbf{x}-\boldsymbol\mu)^\top
    /// \boldsymbol\Sigma^{-1}(\mathbf{x}-\boldsymbol\mu)$.
    ///
    /// Uses the Cholesky factor $L$ stored at construction time:
    /// $\ln|\boldsymbol\Sigma| = 2\sum_i \ln L_{ii}$ and the Mahalanobis
    /// term is obtained by solving $\boldsymbol\Sigma\,\mathbf{s} =
    /// \mathbf{x}-\boldsymbol\mu$ through the stored decomposition.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not have the same length as the mean vector
    /// (the vector subtraction requires matching dimensions).
    fn log_prob(&self, x: &DVector<f64>) -> f64 {
        let k = self.mean.len();
        let diff = x - &self.mean;
        let mahal_sq = diff.dot(&self.chol.solve(&diff));
        // Only the diagonal of L is needed, so borrow the stored factor with
        // `l_dirty()` instead of `l()`, which would allocate and copy a k×k
        // matrix on every call. The diagonal is identical in both.
        let factor = self.chol.l_dirty();
        let log_det: f64 = (0..k).map(|i| factor[(i, i)].ln()).sum::<f64>() * 2.0;
        -0.5 * (k as f64 * (2.0 * std::f64::consts::PI).ln() + log_det + mahal_sq)
    }

    /// The distribution is continuous, so this is always `false`.
    fn is_discrete(&self) -> bool {
        false
    }
}
impl std::fmt::Display for MultivariateNormal {
    /// Writes the dimension `k` and the mean, rendered as a row vector, in the
    /// form `MultivariateNormal { k = .., mean = .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dim = self.mean.len();
        let mean_row = self.mean.transpose();
        write!(
            f,
            "MultivariateNormal {{ k = {}, mean = {} }}",
            dim, mean_row
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;
    use rand::thread_rng;

    /// `log_prob` does not depend on the RNG type parameter of `Distribution`,
    /// so pin it to `ThreadRng` once here instead of at every call site.
    fn log_prob_of(dist: &MultivariateNormal, x: &DVector<f64>) -> f64 {
        <MultivariateNormal as Distribution<ThreadRng>>::log_prob(dist, x)
    }

    /// Unwrap the error message of a constructor call that must fail.
    ///
    /// Written as a `match` so it does not require `MultivariateNormal: Debug`.
    fn expect_err(result: Result<MultivariateNormal, String>) -> String {
        match result {
            Ok(_) => panic!("expected MultivariateNormal::new to return Err"),
            Err(msg) => msg,
        }
    }

    #[test]
    fn mvn_sample() {
        let mut rng = thread_rng();
        let k = 3;
        let mean = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let cov = DMatrix::identity(k, k);
        let dist = MultivariateNormal::new(mean.clone(), cov).unwrap();
        println!("dist = {}", dist);

        let trials = 10_000;
        let mut sums = DVector::zeros(k);
        for _ in 0..trials {
            sums += dist.sample(&mut rng);
        }
        let empirical_mean = sums / trials as f64;

        // 5-sigma error for each component (unit variance per component).
        let err = 5.0 / (trials as f64).sqrt();
        for (i, (got, expected)) in empirical_mean.iter().zip(mean.iter()).enumerate() {
            assert!(
                (got - expected).abs() < err,
                "component {} empirical mean {:.4} != expected {:.4}",
                i,
                got,
                expected
            );
        }
    }

    #[test]
    fn mvn_log_prob() {
        // Standard bivariate normal: log p(0,0 | N(0,I)) = -ln(2π)
        let dist =
            MultivariateNormal::new(DVector::from_vec(vec![0.0, 0.0]), DMatrix::identity(2, 2))
                .unwrap();
        let x = DVector::from_vec(vec![0.0, 0.0]);
        let lp = log_prob_of(&dist, &x);
        let expected = -(2.0 * std::f64::consts::PI).ln();
        assert!(
            (lp - expected).abs() < 1e-10,
            "log_prob {:.6} != expected {:.6}",
            lp,
            expected
        );
        assert!(!<MultivariateNormal as Distribution<ThreadRng>>::is_discrete(&dist));
    }

    #[test]
    fn mvn_nrows_mismatch() {
        // nrows != mean length: 3×3 against a 2-element mean.
        let msg = expect_err(MultivariateNormal::new(
            DVector::from_vec(vec![0.0, 0.0]),
            DMatrix::identity(3, 3),
        ));
        assert!(
            msg.contains("must match mean length"),
            "unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn mvn_ncols_mismatch() {
        // nrows == mean length but ncols != mean length: 2×3 against a 2-element mean.
        let msg = expect_err(MultivariateNormal::new(
            DVector::from_vec(vec![0.0, 0.0]),
            DMatrix::zeros(2, 3),
        ));
        assert!(
            msg.contains("must match mean length"),
            "unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn mvn_not_positive_definite() {
        // [[1, 2], [2, 1]] has eigenvalues 3 and -1 — not positive definite.
        let msg = expect_err(MultivariateNormal::new(
            DVector::from_vec(vec![0.0, 0.0]),
            DMatrix::from_vec(2, 2, vec![1.0, 2.0, 2.0, 1.0]),
        ));
        assert!(
            msg.contains("not positive definite"),
            "unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn mvn_display() {
        let dist =
            MultivariateNormal::new(DVector::from_vec(vec![1.0, 2.0]), DMatrix::identity(2, 2))
                .unwrap();
        let s = format!("{}", dist);
        assert!(s.contains("MultivariateNormal"), "missing type name: {}", s);
        assert!(s.contains("k = 2"), "missing dimension: {}", s);
    }

    #[test]
    fn mvn_log_prob_off_center() {
        // mu = [1, 2], Sigma = 2*I_2, x = [2, 4]
        // diff = [1, 2]
        // Mahalanobis^2 = diff^T (2I)^{-1} diff = 0.5*(1+4) = 2.5
        // log|Sigma| = log(det(2*I_2)) = 2*log(2)
        // log_prob = -0.5*(2*log(2π) + 2*log(2) + 2.5)
        let dist = MultivariateNormal::new(
            DVector::from_vec(vec![1.0, 2.0]),
            DMatrix::from_diagonal_element(2, 2, 2.0),
        )
        .unwrap();
        let x = DVector::from_vec(vec![2.0, 4.0]);
        let lp = log_prob_of(&dist, &x);
        let expected = -0.5 * (2.0 * (2.0 * std::f64::consts::PI).ln() + 2.0 * 2.0_f64.ln() + 2.5);
        assert!(
            (lp - expected).abs() < 1e-10,
            "log_prob {:.6} != expected {:.6}",
            lp,
            expected
        );
    }
}