nalgebra_rand_mvn/
lib.rs

1//! Random multi-variate normal generation using nalgebra
2use nalgebra::{OMatrix, OVector};
3use rand::distr::Distribution;
4use rand_distr::Normal;
5
6use nalgebra::allocator::Allocator;
7use nalgebra::{DefaultAllocator, Dim, DimName, DimSub, Dyn, RealField};
8
9/// An error
10#[derive(Debug)]
11pub struct Error {
12    kind: ErrorKind,
13}
14
15impl Error {
16    pub fn kind(&self) -> &ErrorKind {
17        &self.kind
18    }
19}
20
21impl std::error::Error for Error {}
22
23impl std::fmt::Display for Error {
24    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
25        write!(f, "{:?}", self.kind)
26    }
27}
28
/// The reason an [`Error`] was produced.
///
/// Implements `PartialEq`/`Eq` so callers can compare the value returned by
/// [`Error::kind`] against a specific variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorKind {
    /// The Cholesky factorization of the covariance matrix failed, i.e. the
    /// matrix is not positive definite.
    NotDefinitePositive,
}
34
35fn standard_normal<Real: RealField, R: Dim, C: Dim>(nrows: R, ncols: C) -> OMatrix<Real, R, C>
36where
37    DefaultAllocator: Allocator<R, C>,
38{
39    let normal = Normal::new(0.0, 1.0).expect("creating normal");
40    let mut rng = rand::rng();
41    OMatrix::<Real, R, C>::from_fn_generic(nrows, ncols, |_row, _col| {
42        nalgebra::convert::<f64, Real>(normal.sample(&mut rng))
43    })
44}
45
46/// Draw random samples from a multi-variate normal
47///
48/// Return `n_samples` samples from the N dimensional normal given by mean `mu` and
49/// covariance `sigma`.
50pub fn rand_mvn_generic<Real, Count, N>(
51    n_samples: Count,
52    mu: &OVector<Real, N>,
53    sigma: nalgebra::OMatrix<Real, N, N>,
54) -> Result<OMatrix<Real, Count, N>, Error>
55where
56    Real: RealField,
57    Count: Dim,
58    N: Dim + DimSub<Dyn>,
59    DefaultAllocator: Allocator<Count, N>,
60    DefaultAllocator: Allocator<N, Count>,
61    DefaultAllocator: Allocator<N, N>,
62    DefaultAllocator: Allocator<N>,
63{
64    let ncols = N::from_usize(mu.nrows());
65    let norm_data: OMatrix<Real, N, Count> = standard_normal(ncols, n_samples);
66    let sigma_chol: nalgebra::OMatrix<Real, N, N> = nalgebra::linalg::Cholesky::new(sigma)
67        .ok_or(Error {
68            kind: ErrorKind::NotDefinitePositive,
69        })?
70        .l();
71    Ok(broadcast_add(&(sigma_chol * norm_data).transpose(), mu))
72}
73
74/// Draw random samples from a multi-variate normal
75///
76/// Return `Count` samples from the N dimensional normal given by mean `mu` and
77/// covariance `sigma`.
78pub fn rand_mvn<Real, Count, N>(
79    mu: &OVector<Real, N>,
80    sigma: nalgebra::OMatrix<Real, N, N>,
81) -> Result<OMatrix<Real, Count, N>, Error>
82where
83    Real: RealField,
84    Count: DimName,
85    N: DimName,
86    DefaultAllocator: Allocator<Count, N>,
87    DefaultAllocator: Allocator<N, Count>,
88    DefaultAllocator: Allocator<N, N>,
89    DefaultAllocator: Allocator<N>,
90{
91    let nrows = Count::name();
92    rand_mvn_generic(nrows, mu, sigma)
93}
94
95/// Add `vec` to each row of `arr`, returning the result with shape of `arr`.
96///
97/// Inputs `arr` has shape R x C and `vec` is C dimensional. Result
98/// has shape R x C.
99fn broadcast_add<Real, R, C>(
100    arr: &OMatrix<Real, R, C>,
101    vec: &OVector<Real, C>,
102) -> OMatrix<Real, R, C>
103where
104    Real: RealField,
105    R: Dim,
106    C: Dim,
107    DefaultAllocator: Allocator<R, C>,
108    DefaultAllocator: Allocator<C>,
109{
110    let ndim = arr.nrows();
111    let nrows = R::from_usize(arr.nrows());
112    let ncols = C::from_usize(arr.ncols());
113
114    // TODO: remove explicit index calculation and indexing
115    OMatrix::from_iterator_generic(
116        nrows,
117        ncols,
118        arr.iter().enumerate().map(|(i, el)| {
119            let vi = i / ndim; // integer div to get index into vec
120            el.clone() + vec[vi].clone()
121        }),
122    )
123}
124
125#[cfg(test)]
126mod tests {
127    use crate::*;
128    use approx::assert_relative_eq;
129    use nalgebra as na;
130
131    /// Calculate the sample covariance
132    ///
133    /// Calculates the sample covariances among N-dimensional samples with M
134    /// observations each. Calculates N x N covariance matrix from observations
135    /// in `arr`, which is M rows of N columns used to store M vectors of
136    /// dimension N.
137    fn sample_covariance<Real: RealField, M: Dim, N: Dim>(
138        arr: &OMatrix<Real, M, N>,
139    ) -> nalgebra::OMatrix<Real, N, N>
140    where
141        DefaultAllocator: Allocator<M, N>,
142        DefaultAllocator: Allocator<N, M>,
143        DefaultAllocator: Allocator<N, N>,
144        DefaultAllocator: Allocator<N>,
145    {
146        let mu: OVector<Real, N> = mean_axis0(arr);
147        let y = broadcast_add(arr, &-mu);
148        let n: Real = nalgebra::convert(arr.nrows() as f64);
149        let sigma = (y.transpose() * y) / (n - Real::one());
150        sigma
151    }
152
153    /// Calculate the mean of R x C matrix along the rows and return C dim vector
154    fn mean_axis0<Real, R, C>(arr: &OMatrix<Real, R, C>) -> OVector<Real, C>
155    where
156        Real: RealField,
157        R: Dim,
158        C: Dim,
159        DefaultAllocator: Allocator<R, C>,
160        DefaultAllocator: Allocator<C>,
161    {
162        let vec_dim: C = C::from_usize(arr.ncols());
163        let mut mu = OVector::zeros_generic(vec_dim, na::Const);
164        let scale: Real = Real::one() / na::convert(arr.nrows() as f64);
165        for j in 0..arr.ncols() {
166            let col_sum = arr
167                .column(j)
168                .iter()
169                .fold(Real::zero(), |acc, x| acc.clone() + x.clone());
170            mu[j] = col_sum * scale.clone();
171        }
172        mu
173    }
174
175    #[test]
176    fn test_covar() {
177        use nalgebra::core::dimension::{U2, U3};
178
179        // We use the same example as https://numpy.org/doc/stable/reference/generated/numpy.cov.html
180
181        // However, our format is transposed compared to numpy. We have
182        // variables as the columns and samples as rows.
183        let arr = OMatrix::<f64, U2, U3>::new(-2.1, -1.0, 4.3, 3.0, 1.1, 0.12).transpose();
184
185        let c = sample_covariance(&arr);
186
187        let expected = nalgebra::OMatrix::<f64, U2, U2>::new(11.71, -4.286, -4.286, 2.144133);
188
189        assert_relative_eq!(c, expected, epsilon = 1e-3);
190    }
191
192    #[test]
193    fn test_mean_axis0() {
194        use nalgebra::dimension::{U2, U4};
195
196        let a1 = OMatrix::<f64, U2, U4>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
197        let actual1: OVector<f64, U4> = mean_axis0(&a1);
198        let expected1 = &[3.0, 4.0, 5.0, 6.0];
199        assert!(actual1.as_slice() == expected1);
200
201        let a2 = OMatrix::<f64, U4, U2>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
202        let actual2: OVector<f64, U2> = mean_axis0(&a2);
203        let expected2 = &[4.0, 5.0];
204        assert!(actual2.as_slice() == expected2);
205    }
206
207    #[test]
208    fn test_rand() {
209        use na::dimension::{U25, U4};
210        let mu = OVector::<f64, U4>::new(1.0, 2.0, 3.0, 4.0);
211        let sigma = nalgebra::OMatrix::<f64, U4, U4>::new(
212            2.0, 0.1, 0.0, 0.0, 0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
213        );
214        let y: OMatrix<f64, U25, U4> = rand_mvn(&mu, sigma).unwrap();
215        assert!(y.nrows() == 25);
216        assert!(y.ncols() == 4);
217
218        let mu2 = mean_axis0(&y);
219        assert_relative_eq!(mu, mu2, epsilon = 0.5); // expect occasional failures here
220
221        let sigma2 = sample_covariance(&y);
222        assert_relative_eq!(sigma, sigma2, epsilon = 1.0); // expect occasional failures here
223    }
224
225    #[test]
226    fn test_rand_dynamic() {
227        use na::dimension::{Dyn, U4};
228        let mu = OVector::<f64, U4>::new(1.0, 2.0, 3.0, 4.0);
229        let sigma = nalgebra::OMatrix::<f64, U4, U4>::new(
230            2.0, 0.1, 0.0, 0.0, 0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
231        );
232        let nrows = Dyn(1_000);
233        let y: nalgebra::OMatrix<f64, Dyn, U4> =
234            rand_mvn_generic(nrows, &mu, sigma.clone()).unwrap();
235        assert!(y.ncols() == 4);
236
237        let mu2 = mean_axis0(&y);
238        assert_relative_eq!(mu, mu2, epsilon = 0.2); // expect occasional failures here
239
240        let sigma2 = sample_covariance(&y);
241        assert_relative_eq!(sigma, sigma2, epsilon = 0.2); // expect occasional failures here
242    }
243}