1use nalgebra::{OMatrix, OVector};
3use rand::distr::Distribution;
4use rand_distr::Normal;
5
6use nalgebra::allocator::Allocator;
7use nalgebra::{DefaultAllocator, Dim, DimName, DimSub, Dyn, RealField};
8
9#[derive(Debug)]
11pub struct Error {
12 kind: ErrorKind,
13}
14
15impl Error {
16 pub fn kind(&self) -> &ErrorKind {
17 &self.kind
18 }
19}
20
21impl std::error::Error for Error {}
22
23impl std::fmt::Display for Error {
24 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
25 write!(f, "{:?}", self.kind)
26 }
27}
28
/// The category of an [`Error`].
// `PartialEq`/`Eq` let callers compare kinds directly, e.g.
// `*err.kind() == ErrorKind::NotDefinitePositive`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorKind {
    /// The covariance matrix has no Cholesky factorization, i.e. it is not
    /// positive definite.
    NotDefinitePositive,
}
34
35fn standard_normal<Real: RealField, R: Dim, C: Dim>(nrows: R, ncols: C) -> OMatrix<Real, R, C>
36where
37 DefaultAllocator: Allocator<R, C>,
38{
39 let normal = Normal::new(0.0, 1.0).expect("creating normal");
40 let mut rng = rand::rng();
41 OMatrix::<Real, R, C>::from_fn_generic(nrows, ncols, |_row, _col| {
42 nalgebra::convert::<f64, Real>(normal.sample(&mut rng))
43 })
44}
45
46pub fn rand_mvn_generic<Real, Count, N>(
51 n_samples: Count,
52 mu: &OVector<Real, N>,
53 sigma: nalgebra::OMatrix<Real, N, N>,
54) -> Result<OMatrix<Real, Count, N>, Error>
55where
56 Real: RealField,
57 Count: Dim,
58 N: Dim + DimSub<Dyn>,
59 DefaultAllocator: Allocator<Count, N>,
60 DefaultAllocator: Allocator<N, Count>,
61 DefaultAllocator: Allocator<N, N>,
62 DefaultAllocator: Allocator<N>,
63{
64 let ncols = N::from_usize(mu.nrows());
65 let norm_data: OMatrix<Real, N, Count> = standard_normal(ncols, n_samples);
66 let sigma_chol: nalgebra::OMatrix<Real, N, N> = nalgebra::linalg::Cholesky::new(sigma)
67 .ok_or(Error {
68 kind: ErrorKind::NotDefinitePositive,
69 })?
70 .l();
71 Ok(broadcast_add(&(sigma_chol * norm_data).transpose(), mu))
72}
73
74pub fn rand_mvn<Real, Count, N>(
79 mu: &OVector<Real, N>,
80 sigma: nalgebra::OMatrix<Real, N, N>,
81) -> Result<OMatrix<Real, Count, N>, Error>
82where
83 Real: RealField,
84 Count: DimName,
85 N: DimName,
86 DefaultAllocator: Allocator<Count, N>,
87 DefaultAllocator: Allocator<N, Count>,
88 DefaultAllocator: Allocator<N, N>,
89 DefaultAllocator: Allocator<N>,
90{
91 let nrows = Count::name();
92 rand_mvn_generic(nrows, mu, sigma)
93}
94
95fn broadcast_add<Real, R, C>(
100 arr: &OMatrix<Real, R, C>,
101 vec: &OVector<Real, C>,
102) -> OMatrix<Real, R, C>
103where
104 Real: RealField,
105 R: Dim,
106 C: Dim,
107 DefaultAllocator: Allocator<R, C>,
108 DefaultAllocator: Allocator<C>,
109{
110 let ndim = arr.nrows();
111 let nrows = R::from_usize(arr.nrows());
112 let ncols = C::from_usize(arr.ncols());
113
114 OMatrix::from_iterator_generic(
116 nrows,
117 ncols,
118 arr.iter().enumerate().map(|(i, el)| {
119 let vi = i / ndim; el.clone() + vec[vi].clone()
121 }),
122 )
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra as na;

    /// Unbiased sample covariance (normalised by `n - 1`) of `arr`, treating
    /// each row as one observation.
    fn sample_covariance<Real: RealField, M: Dim, N: Dim>(
        arr: &OMatrix<Real, M, N>,
    ) -> nalgebra::OMatrix<Real, N, N>
    where
        DefaultAllocator: Allocator<M, N>,
        DefaultAllocator: Allocator<N, M>,
        DefaultAllocator: Allocator<N, N>,
        DefaultAllocator: Allocator<N>,
    {
        let mu: OVector<Real, N> = mean_axis0(arr);
        let centered = broadcast_add(arr, &-mu);
        let n: Real = nalgebra::convert(arr.nrows() as f64);
        (centered.transpose() * centered) / (n - Real::one())
    }

    /// Mean of each column of `arr`, i.e. the mean taken over rows.
    fn mean_axis0<Real, R, C>(arr: &OMatrix<Real, R, C>) -> OVector<Real, C>
    where
        Real: RealField,
        R: Dim,
        C: Dim,
        DefaultAllocator: Allocator<R, C>,
        DefaultAllocator: Allocator<C>,
    {
        let vec_dim: C = C::from_usize(arr.ncols());
        let mut mu = OVector::zeros_generic(vec_dim, na::Const);
        let scale: Real = Real::one() / na::convert(arr.nrows() as f64);
        for j in 0..arr.ncols() {
            let col_sum = arr
                .column(j)
                .iter()
                .fold(Real::zero(), |acc, x| acc.clone() + x.clone());
            mu[j] = col_sum * scale.clone();
        }
        mu
    }

    #[test]
    fn test_covar() {
        use nalgebra::core::dimension::{U2, U3};

        let arr = OMatrix::<f64, U2, U3>::new(-2.1, -1.0, 4.3, 3.0, 1.1, 0.12).transpose();

        let c = sample_covariance(&arr);

        let expected = nalgebra::OMatrix::<f64, U2, U2>::new(11.71, -4.286, -4.286, 2.144133);

        assert_relative_eq!(c, expected, epsilon = 1e-3);
    }

    #[test]
    fn test_mean_axis0() {
        use nalgebra::dimension::{U2, U4};

        let a1 = OMatrix::<f64, U2, U4>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let actual1: OVector<f64, U4> = mean_axis0(&a1);
        let expected1 = &[3.0, 4.0, 5.0, 6.0];
        assert_eq!(actual1.as_slice(), expected1);

        let a2 = OMatrix::<f64, U4, U2>::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let actual2: OVector<f64, U2> = mean_axis0(&a2);
        let expected2 = &[4.0, 5.0];
        assert_eq!(actual2.as_slice(), expected2);
    }

    #[test]
    fn test_rand() {
        use na::dimension::U4;
        // With 400 samples every tolerance below is roughly 7 standard errors
        // or more from the true value. With the previous 25 samples, the
        // variance-2.0 entry had a standard error of ~0.58 against a tolerance
        // of 1.0, so the test failed on a noticeable fraction of runs.
        type NSamples = na::Const<400>;

        let mu = OVector::<f64, U4>::new(1.0, 2.0, 3.0, 4.0);
        let sigma = nalgebra::OMatrix::<f64, U4, U4>::new(
            2.0, 0.1, 0.0, 0.0, 0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
        );
        let y: OMatrix<f64, NSamples, U4> = rand_mvn(&mu, sigma).unwrap();
        assert_eq!(y.nrows(), 400);
        assert_eq!(y.ncols(), 4);

        let mu2 = mean_axis0(&y);
        assert_relative_eq!(mu, mu2, epsilon = 0.5);
        let sigma2 = sample_covariance(&y);
        assert_relative_eq!(sigma, sigma2, epsilon = 1.0);
    }

    #[test]
    fn test_rand_dynamic() {
        use na::dimension::{Dyn, U4};
        let mu = OVector::<f64, U4>::new(1.0, 2.0, 3.0, 4.0);
        let sigma = nalgebra::OMatrix::<f64, U4, U4>::new(
            2.0, 0.1, 0.0, 0.0, 0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
        );
        // 10 000 samples put the variance-2.0 entry's tolerance (0.2) about
        // 7 standard errors out. With 1 000 samples it was only ~2.2, so a few
        // percent of runs failed.
        let nrows = Dyn(10_000);
        // `sigma` is a fixed-size matrix and therefore `Copy`; it stays usable below.
        let y: nalgebra::OMatrix<f64, Dyn, U4> = rand_mvn_generic(nrows, &mu, sigma).unwrap();
        assert_eq!(y.nrows(), 10_000);
        assert_eq!(y.ncols(), 4);

        let mu2 = mean_axis0(&y);
        assert_relative_eq!(mu, mu2, epsilon = 0.2);
        let sigma2 = sample_covariance(&y);
        assert_relative_eq!(sigma, sigma2, epsilon = 0.2);
    }
}