1#![allow(missing_docs)]
2
3mod meanvar;
4pub use meanvar::{NanHandling, col_mean, col_varm, row_mean, row_varm};
5
/// Convenience re-exports for drawing random matrices and vectors.
///
/// Every item in this module is gated behind the `rand` feature; without that feature the module
/// is empty. It bundles this module's sampling types together with the pieces of `rand`,
/// `rand_distr` and `num_complex` that are typically needed alongside them.
pub mod prelude {
	// Sampling types defined by this module.
	#[cfg(feature = "rand")]
	pub use super::{CwiseColDistribution, CwiseMatDistribution, CwiseRowDistribution, DistributionExt, UnitaryMat};

	// Third-party distribution building blocks.
	#[cfg(feature = "rand")]
	pub use num_complex::ComplexDistribution;
	#[cfg(feature = "rand")]
	pub use rand::prelude::*;
	#[cfg(feature = "rand")]
	pub use rand_distr::{Standard, StandardNormal};
}
17
18#[cfg(feature = "rand")]
19pub use self::rand::*;
20
#[cfg(feature = "rand")]
mod rand {
	use crate::internal_prelude::*;
	use rand::distributions::Distribution;

	/// Extension trait that adds a `rand` method to every type.
	///
	/// `dist.rand(rng)` is shorthand for `Distribution::<T>::sample(&dist, rng)`. The target type
	/// `T` selects which `Distribution<T>` impl is used, so it usually comes from a type
	/// annotation or a turbofish at the call site.
	pub trait DistributionExt {
		/// Draws a single value of type `T` from `self`, using `rng` as the randomness source.
		fn rand<T>(&self, rng: &mut (impl ?Sized + rand::Rng)) -> T
		where
			Self: Distribution<T>,
		{
			<Self as Distribution<T>>::sample(self, rng)
		}
	}

	// Blanket impl: the method is available on every (possibly unsized) type, and the
	// `Self: Distribution<T>` bound on `rand` decides where it can actually be called.
	impl<X: ?Sized> DistributionExt for X {}

	/// Distribution over `nrows × ncols` matrices whose entries are each produced by a separate
	/// call to `dist.sample`.
	#[derive(Copy, Clone, Debug)]
	pub struct CwiseMatDistribution<Rows: Shape, Cols: Shape, D> {
		/// Row count of every sampled matrix.
		pub nrows: Rows,
		/// Column count of every sampled matrix.
		pub ncols: Cols,
		/// Distribution each entry is drawn from.
		pub dist: D,
	}

	/// Distribution over column vectors of length `nrows` whose entries are each produced by a
	/// separate call to `dist.sample`.
	#[derive(Copy, Clone, Debug)]
	pub struct CwiseColDistribution<Rows: Shape, D> {
		/// Length of every sampled column.
		pub nrows: Rows,
		/// Distribution each entry is drawn from.
		pub dist: D,
	}

	/// Distribution over row vectors of length `ncols` whose entries are each produced by a
	/// separate call to `dist.sample`.
	#[derive(Copy, Clone, Debug)]
	pub struct CwiseRowDistribution<Cols: Shape, D> {
		/// Length of every sampled row.
		pub ncols: Cols,
		/// Distribution each entry is drawn from.
		pub dist: D,
	}

	/// Distribution over `dim × dim` unitary matrices (orthogonal when the scalar type is real).
	///
	/// Sampling proceeds in three steps:
	/// 1. Fill a square matrix with draws from `standard_normal`.
	/// 2. Take its QR decomposition.
	/// 3. Multiply each column `j` of `Q` by `r_jj / |r_jj|`, the phase of the matching diagonal
	///    entry of `R`. When that entry is exactly zero, the factor is one instead.
	///
	/// If `standard_normal` really is the standard (real or complex) normal distribution, this is
	/// the usual construction for Haar-distributed matrices. The code cannot check that
	/// assumption, so supplying a different distribution silently changes the result's law.
	///
	/// Sampling is currently only implemented for `Dim = usize`.
	#[derive(Copy, Clone, Debug)]
	pub struct UnitaryMat<Dim: Shape, D> {
		/// Number of rows (and columns) of every sampled matrix.
		pub dim: Dim,
		/// Scalar distribution used to fill the Gaussian matrix before orthogonalization.
		pub standard_normal: D,
	}

	impl<T, Rows, Cols, D> Distribution<Mat<T, Rows, Cols>> for CwiseMatDistribution<Rows, Cols, D>
	where
		Rows: Shape,
		Cols: Shape,
		D: Distribution<T>,
	{
		#[inline]
		fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Mat<T, Rows, Cols> {
			let entry = &self.dist;
			Mat::from_fn(self.nrows, self.ncols, |_, _| entry.sample(rng))
		}
	}

	impl<T, Rows, D> Distribution<Col<T, Rows>> for CwiseColDistribution<Rows, D>
	where
		Rows: Shape,
		D: Distribution<T>,
	{
		#[inline]
		fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Col<T, Rows> {
			let entry = &self.dist;
			Col::from_fn(self.nrows, |_| entry.sample(rng))
		}
	}

	impl<T, Cols, D> Distribution<Row<T, Cols>> for CwiseRowDistribution<Cols, D>
	where
		Cols: Shape,
		D: Distribution<T>,
	{
		#[inline]
		fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Row<T, Cols> {
			let entry = &self.dist;
			Row::from_fn(self.ncols, |_| entry.sample(rng))
		}
	}

	impl<T, D> Distribution<Mat<T>> for UnitaryMat<usize, D>
	where
		T: ComplexField,
		D: Distribution<T>,
	{
		#[math]
		fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> Mat<T> {
			let n = self.dim;

			// Square matrix whose entries are drawn from `standard_normal`.
			let gaussian: Mat<T> = CwiseMatDistribution {
				nrows: n,
				ncols: n,
				dist: &self.standard_normal,
			}
			.sample(rng);

			let decomposition = gaussian.qr();
			let diag = decomposition.R().diagonal().column_vector();
			let mut unitary = decomposition.compute_Q();

			// Rescale column j of Q by the unit-modulus phase of R's j-th diagonal entry (or by
			// one when that entry is exactly zero). The result stays unitary.
			for j in 0..n {
				let r_jj = diag.read(j);
				let phase = if r_jj == zero() { one() } else { mul_real(r_jj, recip(abs(r_jj))) };

				z!(unitary.as_mut().col_mut(j)).for_each(|uz!(x)| {
					*x = *x * phase;
				});
			}

			unitary
		}
	}
}