numra_stats/distributions/
normal.rs1use numra_core::Scalar;
8use numra_special::{erf, erfinv};
9use rand::RngCore;
10
11use super::ContinuousDistribution;
12
/// Normal (Gaussian) distribution parameterised by its mean `mu` and standard
/// deviation `sigma`.
///
/// Both fields are public and are stored exactly as given; nothing in this type
/// checks that `sigma` is positive.
// No `S: Scalar` bound on the struct itself: bounds belong on the impl blocks that
// actually need them (Rust API Guidelines, C-STRUCT-BOUNDS), so removing it only
// relaxes the type. `Copy`/`PartialEq` are derived so the two-parameter value can be
// passed by value and compared; derived impls are conditional on `S` supporting them.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Normal<S> {
    /// Mean (location) of the distribution.
    pub mu: S,
    /// Standard deviation (scale) of the distribution; the variance is `sigma * sigma`.
    pub sigma: S,
}
19
20impl<S: Scalar> Normal<S> {
21 pub fn new(mu: S, sigma: S) -> Self {
22 Self { mu, sigma }
23 }
24
25 pub fn standard() -> Self {
27 Self {
28 mu: S::ZERO,
29 sigma: S::ONE,
30 }
31 }
32}
33
34impl<S: Scalar> ContinuousDistribution<S> for Normal<S> {
35 fn pdf(&self, x: S) -> S {
36 let two = S::TWO;
37 let pi2 = S::from_f64(core::f64::consts::TAU);
38 let z = (x - self.mu) / self.sigma;
39 (S::ZERO - z * z / two).exp() / (pi2.sqrt() * self.sigma)
40 }
41
42 fn cdf(&self, x: S) -> S {
43 let sqrt2 = S::from_f64(core::f64::consts::SQRT_2);
44 let half = S::HALF;
45 let z = (x - self.mu) / (self.sigma * sqrt2);
46 half * (S::ONE + erf(z))
47 }
48
49 fn quantile(&self, p: S) -> S {
50 let two = S::TWO;
51 let sqrt2 = S::from_f64(core::f64::consts::SQRT_2);
52 self.mu + self.sigma * sqrt2 * erfinv(two * p - S::ONE)
53 }
54
55 fn mean(&self) -> S {
56 self.mu
57 }
58
59 fn variance(&self) -> S {
60 self.sigma * self.sigma
61 }
62
63 fn sample(&self, rng: &mut dyn RngCore) -> S {
64 let u1 = random_uniform_01::<S>(rng);
66 let u2 = random_uniform_01::<S>(rng);
67 let two = S::TWO;
68 let pi2 = S::from_f64(core::f64::consts::TAU);
69 let z = (S::ZERO - two * u1.ln()).sqrt() * (pi2 * u2).cos();
70 self.mu + self.sigma * z
71 }
72}
73
74pub(crate) fn random_uniform_01<S: Scalar>(rng: &mut dyn RngCore) -> S {
76 let bits = rng.next_u32();
78 S::from_f64((bits as f64 + 1.0) / (u32::MAX as f64 + 2.0))
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is strictly within `tol` of `expected`.
    fn within(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_normal_pdf_at_mean() {
        let standard = Normal::new(0.0_f64, 1.0);
        // Peak density of N(0, 1) is 1 / sqrt(2*pi).
        let expected_peak = 1.0 / (2.0 * core::f64::consts::PI).sqrt();
        assert!(within(standard.pdf(0.0), expected_peak, 1e-12));
    }

    #[test]
    fn test_normal_cdf_at_mean() {
        let standard = Normal::new(0.0_f64, 1.0);
        assert!(within(standard.cdf(0.0), 0.5, 1e-12));
    }

    #[test]
    fn test_normal_cdf_tails() {
        let standard = Normal::<f64>::standard();
        let lower = standard.cdf(-5.0);
        let upper = standard.cdf(5.0);
        assert!(lower < 1e-5);
        assert!(upper > 1.0 - 1e-5);
    }

    #[test]
    fn test_normal_quantile_roundtrip() {
        let shifted = Normal::new(2.0_f64, 3.0);
        let probabilities = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99];
        for p in probabilities.iter().copied() {
            // cdf(quantile(p)) should reproduce p.
            let p2 = shifted.cdf(shifted.quantile(p));
            assert!(within(p2, p, 1e-10), "p={}, p2={}", p, p2);
        }
    }

    #[test]
    fn test_normal_mean_variance() {
        let dist = Normal::new(3.0_f64, 2.0);
        assert!(within(dist.mean(), 3.0, 1e-14));
        assert!(within(dist.variance(), 4.0, 1e-14));
    }

    #[test]
    fn test_normal_sample() {
        use rand::SeedableRng;

        let standard = Normal::new(0.0_f64, 1.0);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let draws = standard.sample_n(&mut rng, 10000);
        let mean = draws.iter().sum::<f64>() / draws.len() as f64;
        assert!(mean.abs() < 0.1, "sample mean = {}", mean);
    }
}