Skip to main content

ferray_random/distributions/
normal.rs

1// ferray-random: Normal distribution sampling — standard_normal, normal, lognormal
2
3use ferray_core::{Array, FerrayError, Ix1};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::{Generator, generate_vec, vec_to_array1};
7
8/// Generate a single standard normal variate using the Box-Muller transform.
9///
10/// Consumes two uniform [0,1) variates and produces two normal variates.
11/// We use both and cache the second, but for simplicity here we just use
12/// the Ziggurat-free approach generating one pair at a time.
13pub(crate) fn standard_normal_pair<B: BitGenerator>(bg: &mut B) -> (f64, f64) {
14    loop {
15        let u1 = bg.next_f64();
16        let u2 = bg.next_f64();
17        // Avoid log(0)
18        if u1 < f64::EPSILON {
19            continue;
20        }
21        let r = (-2.0 * u1.ln()).sqrt();
22        let theta = std::f64::consts::TAU * u2;
23        return (r * theta.cos(), r * theta.sin());
24    }
25}
26
27/// Generate a single standard normal variate.
28pub(crate) fn standard_normal_single<B: BitGenerator>(bg: &mut B) -> f64 {
29    standard_normal_pair(bg).0
30}
31
32impl<B: BitGenerator> Generator<B> {
33    /// Generate an array of standard normal (mean=0, std=1) variates.
34    ///
35    /// Uses the Box-Muller transform for generation.
36    ///
37    /// # Arguments
38    /// * `size` - Number of values to generate.
39    ///
40    /// # Errors
41    /// Returns `FerrayError::InvalidValue` if `size` is zero.
42    pub fn standard_normal(&mut self, size: usize) -> Result<Array<f64, Ix1>, FerrayError> {
43        if size == 0 {
44            return Err(FerrayError::invalid_value("size must be > 0"));
45        }
46        let mut data = Vec::with_capacity(size);
47        while data.len() < size {
48            let (a, b) = standard_normal_pair(&mut self.bg);
49            data.push(a);
50            if data.len() < size {
51                data.push(b);
52            }
53        }
54        vec_to_array1(data)
55    }
56
57    /// Generate an array of normal (Gaussian) variates with given mean and standard deviation.
58    ///
59    /// # Arguments
60    /// * `loc` - Mean of the distribution.
61    /// * `scale` - Standard deviation (must be positive).
62    /// * `size` - Number of values to generate.
63    ///
64    /// # Errors
65    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `size` is zero.
66    pub fn normal(
67        &mut self,
68        loc: f64,
69        scale: f64,
70        size: usize,
71    ) -> Result<Array<f64, Ix1>, FerrayError> {
72        if size == 0 {
73            return Err(FerrayError::invalid_value("size must be > 0"));
74        }
75        if scale <= 0.0 {
76            return Err(FerrayError::invalid_value(format!(
77                "scale must be positive, got {scale}"
78            )));
79        }
80        let data = generate_vec(self, size, |bg| loc + scale * standard_normal_single(bg));
81        vec_to_array1(data)
82    }
83
84    /// Generate an array of log-normal variates.
85    ///
86    /// If X ~ Normal(mean, sigma), then exp(X) ~ LogNormal(mean, sigma).
87    ///
88    /// # Arguments
89    /// * `mean` - Mean of the underlying normal distribution.
90    /// * `sigma` - Standard deviation of the underlying normal distribution (must be positive).
91    /// * `size` - Number of values to generate.
92    ///
93    /// # Errors
94    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `size` is zero.
95    pub fn lognormal(
96        &mut self,
97        mean: f64,
98        sigma: f64,
99        size: usize,
100    ) -> Result<Array<f64, Ix1>, FerrayError> {
101        if size == 0 {
102            return Err(FerrayError::invalid_value("size must be > 0"));
103        }
104        if sigma <= 0.0 {
105            return Err(FerrayError::invalid_value(format!(
106                "sigma must be positive, got {sigma}"
107            )));
108        }
109        let data = generate_vec(self, size, |bg| {
110            (mean + sigma * standard_normal_single(bg)).exp()
111        });
112        vec_to_array1(data)
113    }
114}
115
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    #[test]
    fn standard_normal_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal(1000).unwrap();
        let b = rng2.standard_normal(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_odd_size() {
        // Odd sizes exercise the branch that drops the last pair's second variate.
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal(7).unwrap();
        assert_eq!(arr.as_slice().unwrap().len(), 7);
    }

    #[test]
    fn standard_normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - loc).abs() < 3.0 * se,
            "mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn normal_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal(0.0, 0.0, 100).is_err());
        assert!(rng.normal(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn zero_size_rejected() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.standard_normal(0).is_err());
        assert!(rng.normal(0.0, 1.0, 0).is_err());
        assert!(rng.lognormal(0.0, 1.0, 0).is_err());
    }

    #[test]
    fn lognormal_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let mu = 0.0;
        let sigma = 0.5;
        let arr = rng.lognormal(mu, sigma, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        // E[X] = exp(mu + sigma^2 / 2); Var[X] = (exp(sigma^2) - 1) * exp(2 mu + sigma^2)
        let expected_mean = (mu + sigma * sigma / 2.0).exp();
        let expected_var = ((sigma * sigma).exp() - 1.0) * (2.0 * mu + sigma * sigma).exp();
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "lognormal mean {mean} too far from {expected_mean}"
        );
    }
}