1use nalgebra::{allocator::Allocator, linalg, DefaultAllocator, Dim, OMatrix, OVector, RealField};
41
42#[derive(Debug)]
44pub struct Error {
45 kind: ErrorKind,
46}
47
48impl Error {
49 pub fn kind(&self) -> &ErrorKind {
50 &self.kind
51 }
52}
53
54impl std::error::Error for Error {}
55
56impl std::fmt::Display for Error {
57 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
58 write!(f, "{:?}", self.kind)
59 }
60}
61
/// The category of an [`Error`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// The covariance matrix could not be Cholesky-decomposed, which nalgebra
    /// reports for matrices that are not positive definite.
    NotDefinitePositive,
}
67
68#[derive(Debug, Clone)]
72pub struct MultivariateNormal<Real, N>
73where
74 Real: RealField,
75 N: Dim + nalgebra::DimMin<N, Output = N>,
76 DefaultAllocator: Allocator<N>,
77 DefaultAllocator: Allocator<N, N>,
78 DefaultAllocator: Allocator<nalgebra::U1, N>,
79 DefaultAllocator: Allocator<<N as nalgebra::DimMin<N>>::Output>,
80{
81 neg_mu: nalgebra::OVector<Real, N>,
83 precision: nalgebra::OMatrix<Real, N, N>,
85 fac: Real,
87}
88
89impl<Real, N> MultivariateNormal<Real, N>
90where
91 Real: RealField,
92 N: Dim + nalgebra::DimMin<N, Output = N> + nalgebra::DimSub<nalgebra::Dyn>,
93 DefaultAllocator: Allocator<N>,
94 DefaultAllocator: Allocator<N, N>,
95 DefaultAllocator: Allocator<nalgebra::U1, N>,
96 DefaultAllocator: Allocator<<N as nalgebra::DimMin<N>>::Output>,
97{
98 pub fn from_mean_and_precision(
103 mu: &nalgebra::OVector<Real, N>,
104 precision: &nalgebra::OMatrix<Real, N, N>,
105 ) -> Self {
106 let precision_det = nalgebra::linalg::LU::new(precision.clone()).determinant();
111 let det = Real::one() / precision_det;
112
113 let ndim = mu.nrows();
114 let fac: Real = Real::one() / (Real::two_pi().powi(ndim as i32) * det.abs()).sqrt();
115
116 Self {
117 neg_mu: -mu,
118 precision: precision.clone(),
119 fac,
120 }
121 }
122
123 pub fn from_mean_and_covariance(
132 mu: &nalgebra::OVector<Real, N>,
133 covariance: &nalgebra::OMatrix<Real, N, N>,
134 ) -> Result<Self, Error> {
135 let precision = linalg::Cholesky::new(covariance.clone())
137 .ok_or(Error {
138 kind: ErrorKind::NotDefinitePositive,
139 })?
140 .inverse();
141 let result = Self::from_mean_and_precision(mu, &precision);
142 Ok(result)
143 }
144
145 pub fn mean(&self) -> nalgebra::OVector<Real, N> {
147 -&self.neg_mu
148 }
149
150 pub fn precision(&self) -> &nalgebra::OMatrix<Real, N, N> {
152 &self.precision
153 }
154
155 fn inner_pdf<Count>(
156 &self,
157 xs_t: &nalgebra::OMatrix<Real, Count, N>,
158 ) -> nalgebra::OVector<Real, Count>
159 where
160 Count: Dim,
161 DefaultAllocator: Allocator<Count>,
162 DefaultAllocator: Allocator<N, Count>,
163 DefaultAllocator: Allocator<Count, N>,
164 DefaultAllocator: Allocator<Count, Count>,
165 {
166 let dvs: nalgebra::OMatrix<Real, Count, N> = broadcast_add(xs_t, &self.neg_mu);
167
168 let left: nalgebra::OMatrix<Real, Count, N> = &dvs * &self.precision;
169 let ny2_tmp: nalgebra::OMatrix<Real, Count, N> = left.component_mul(&dvs);
170 let ones = nalgebra::OMatrix::<Real, N, nalgebra::U1>::repeat_generic(
171 N::from_usize(self.neg_mu.nrows()),
172 nalgebra::Const,
173 nalgebra::convert::<f64, Real>(1.0),
174 );
175 let ny2: nalgebra::OVector<Real, Count> = ny2_tmp * ones;
176 let y: nalgebra::OVector<Real, Count> = ny2 * nalgebra::convert::<f64, Real>(-0.5);
177 y
178 }
179
180 pub fn pdf<Count>(
184 &self,
185 xs: &nalgebra::OMatrix<Real, Count, N>,
186 ) -> nalgebra::OVector<Real, Count>
187 where
188 Count: Dim,
189 DefaultAllocator: Allocator<Count>,
190 DefaultAllocator: Allocator<N, Count>,
191 DefaultAllocator: Allocator<Count, N>,
192 DefaultAllocator: Allocator<Count, Count>,
193 {
194 let y = self.inner_pdf(xs);
195 vec_exp(&y) * self.fac.clone()
196 }
197
198 pub fn logpdf<Count>(
202 &self,
203 xs: &nalgebra::OMatrix<Real, Count, N>,
204 ) -> nalgebra::OVector<Real, Count>
205 where
206 Count: Dim,
207 DefaultAllocator: Allocator<Count>,
208 DefaultAllocator: Allocator<N, Count>,
209 DefaultAllocator: Allocator<Count, N>,
210 DefaultAllocator: Allocator<Count, Count>,
211 {
212 let y = self.inner_pdf(xs);
213 vec_add(&y, self.fac.clone().ln())
214 }
215}
216
217fn vec_exp<Real, Count>(v: &nalgebra::OVector<Real, Count>) -> nalgebra::OVector<Real, Count>
218where
219 Real: RealField,
220 Count: Dim,
221 DefaultAllocator: Allocator<Count>,
222{
223 let nrows = Count::from_usize(v.nrows());
224 OVector::from_iterator_generic(nrows, nalgebra::Const, v.iter().map(|vi| vi.clone().exp()))
225}
226
227fn vec_add<Real, Count>(
228 v: &nalgebra::OVector<Real, Count>,
229 rhs: Real,
230) -> nalgebra::OVector<Real, Count>
231where
232 Real: RealField,
233 Count: Dim,
234 DefaultAllocator: Allocator<Count>,
235{
236 let nrows = Count::from_usize(v.nrows());
237 OVector::from_iterator_generic(
238 nrows,
239 nalgebra::Const,
240 v.iter().map(|vi| vi.clone() + rhs.clone()),
241 )
242}
243
244fn broadcast_add<Real, R, C>(
249 arr: &OMatrix<Real, R, C>,
250 vec: &OVector<Real, C>,
251) -> OMatrix<Real, R, C>
252where
253 Real: RealField,
254 R: Dim,
255 C: Dim,
256 DefaultAllocator: Allocator<R, C>,
257 DefaultAllocator: Allocator<C>,
258{
259 let ndim = arr.nrows();
260 let nrows = R::from_usize(arr.nrows());
261 let ncols = C::from_usize(arr.ncols());
262
263 OMatrix::from_iterator_generic(
265 nrows,
266 ncols,
267 arr.iter().enumerate().map(|(i, el)| {
268 let vi = i / ndim; el.clone() + vec[vi].clone()
270 }),
271 )
272}
273
#[cfg(test)]
mod tests {
    use crate::*;
    use approx::assert_relative_eq;
    use nalgebra as na;

    /// Unbiased sample covariance (divides by `nrows - 1`) of the rows of `arr`.
    fn sample_covariance<Real: RealField, M: Dim, N: Dim>(
        arr: &OMatrix<Real, M, N>,
    ) -> nalgebra::OMatrix<Real, N, N>
    where
        DefaultAllocator: Allocator<M, N>,
        DefaultAllocator: Allocator<N, M>,
        DefaultAllocator: Allocator<N, N>,
        DefaultAllocator: Allocator<N>,
    {
        let mu: OVector<Real, N> = mean_axis0(arr);
        let y = broadcast_add(arr, &-mu);
        let n: Real = nalgebra::convert(arr.nrows() as f64);

        (y.transpose() * y) / (n - Real::one())
    }

    /// Mean of each column of `arr` (averaging over the rows).
    fn mean_axis0<Real, R, C>(arr: &OMatrix<Real, R, C>) -> OVector<Real, C>
    where
        Real: RealField,
        R: Dim,
        C: Dim,
        DefaultAllocator: Allocator<R, C>,
        DefaultAllocator: Allocator<C>,
    {
        let vec_dim: C = C::from_usize(arr.ncols());
        let mut mu = OVector::<Real, C>::zeros_generic(vec_dim, nalgebra::Const);
        let scale: Real = Real::one() / na::convert(arr.nrows() as f64);
        for j in 0..arr.ncols() {
            let col_sum = arr
                .column(j)
                .iter()
                .fold(Real::zero(), |acc, x| acc + x.clone());
            mu[j] = col_sum * scale.clone();
        }
        mu
    }

    #[test]
    fn test_covar() {
        use nalgebra::core::dimension::{U2, U3};

        let arr = OMatrix::<f64, U2, U3>::new(-2.1, -1.0, 4.3, 3.0, 1.1, 0.12).transpose();

        let c = sample_covariance(&arr);

        let expected = nalgebra::OMatrix::<f64, U2, U2>::new(11.71, -4.286, -4.286, 2.144133);

        assert_relative_eq!(c, expected, epsilon = 1e-3);
    }

    #[test]
    fn test_mean_and_precision() {
        let mu = na::Vector2::<f64>::new(0.0, 0.0);
        let precision = na::Matrix2::<f64>::new(1.0, 0.0, 0.0, 1.0);

        let mvn = MultivariateNormal::from_mean_and_precision(&mu, &precision);

        assert_eq!(mu, mvn.mean());
        assert_eq!(&precision, mvn.precision());
    }

    #[test]
    fn test_mean_and_covariance() {
        let mu = na::Vector2::<f64>::new(1.0, 2.0);
        let covariance = na::Matrix2::<f64>::new(2.0, 0.0, 0.0, 4.0);

        let mvn = MultivariateNormal::from_mean_and_covariance(&mu, &covariance)
            .expect("diagonal covariance with positive entries is positive definite");

        let expected_precision = na::Matrix2::<f64>::new(0.5, 0.0, 0.0, 0.25);
        let precision = mvn.precision().clone();
        assert_relative_eq!(precision, expected_precision, epsilon = 1e-12);
        assert_relative_eq!(mvn.mean(), mu, epsilon = 1e-12);
    }

    #[test]
    fn test_covariance_not_positive_definite() {
        let mu = na::Vector2::<f64>::new(0.0, 0.0);
        // Symmetric but indefinite: eigenvalues are 3 and -1.
        let covariance = na::Matrix2::<f64>::new(1.0, 2.0, 2.0, 1.0);

        let err = MultivariateNormal::from_mean_and_covariance(&mu, &covariance)
            .err()
            .expect("indefinite covariance must be rejected");
        assert!(matches!(err.kind(), ErrorKind::NotDefinitePositive));
    }

    #[test]
    fn test_mean_axis0() {
        use nalgebra::core::dimension::{U2, U4};

        let a1 = OMatrix::<f64, U2, U4>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let actual1: OVector<f64, U4> = mean_axis0(&a1);
        let expected1 = &[3.0, 4.0, 5.0, 6.0];
        assert_eq!(actual1.as_slice(), expected1);

        let a2 = OMatrix::<f64, U4, U2>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let actual2: OVector<f64, U2> = mean_axis0(&a2);
        let expected2 = &[4.0, 5.0];
        assert_eq!(actual2.as_slice(), expected2);
    }

    #[test]
    fn test_broadcast_add() {
        use nalgebra::core::dimension::{U3, U4};

        let x = OMatrix::<f64, U3, U4>::new(
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 200.0, 300.0, 400.0,
        );
        let v = OVector::<f64, U4>::new(-3.0, -4.0, -5.0, -3.0);
        let actual = broadcast_add(&x, &v);

        let expected = OMatrix::<f64, U3, U4>::new(
            -2.0, -2.0, -2.0, 1.0, 2.0, 2.0, 2.0, 5.0, 97.0, 196.0, 295.0, 397.0,
        );

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_density() {
        let mu = na::Vector2::<f64>::new(0.0, 0.0);
        let precision = na::Matrix2::<f64>::new(1.0, 0.0, 0.0, 1.0);

        let mvn = MultivariateNormal::from_mean_and_precision(&mu, &precision);

        let xs = na::OMatrix::<f64, na::U2, na::U3>::new(0.0, 1.0, 0.0, 0.0, 0.0, 1.0).transpose();

        let results = mvn.pdf(&xs);

        // Evaluating one row at a time must agree with the batched evaluation.
        for i in 0..xs.nrows() {
            let x = xs.row(i).clone_owned();
            let di = mvn.pdf(&x)[0];
            assert_relative_eq!(di, results[i], epsilon = 1e-10);
        }

        // Standard bivariate normal: 1/(2*pi) at the mean, exp(-0.5)/(2*pi)
        // at unit distance.
        let epsilon = 1e-5;
        assert_relative_eq!(results[0], 0.15915494, epsilon = epsilon);
        assert_relative_eq!(results[1], 0.09653235, epsilon = epsilon);
        assert_relative_eq!(results[2], 0.09653235, epsilon = epsilon);
    }

    #[test]
    fn test_logpdf_matches_ln_pdf() {
        let mu = na::Vector2::<f64>::new(0.5, -1.0);
        let precision = na::Matrix2::<f64>::new(2.0, 0.3, 0.3, 1.0);
        let mvn = MultivariateNormal::from_mean_and_precision(&mu, &precision);

        let xs = na::OMatrix::<f64, na::U3, na::U2>::new(0.0, 0.0, 1.0, -1.0, -0.5, 2.0);
        let pdf = mvn.pdf(&xs);
        let logpdf = mvn.logpdf(&xs);

        for i in 0..xs.nrows() {
            assert_relative_eq!(logpdf[i], pdf[i].ln(), epsilon = 1e-10);
        }
    }
}