// ferray-random: Normal distribution sampling — standard_normal, normal, lognormal
2
3use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::distributions::ziggurat::{standard_normal_ziggurat, standard_normal_ziggurat_f32};
7use crate::generator::{
8    Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
9};
10use crate::shape::IntoShape;
11
12/// Generate a single standard normal variate via the Ziggurat algorithm.
13///
14/// Ziggurat is ~3× faster than Box-Muller because ~98% of calls take a fast
15/// path that uses only one `next_u64`, a multiplication and a comparison.
16/// See [`ziggurat`](super::ziggurat) for the layer-table construction and
17/// the slow-path rejection logic.
18pub(crate) fn standard_normal_single<B: BitGenerator>(bg: &mut B) -> f64 {
19    standard_normal_ziggurat(bg)
20}
21
22/// Generate a single f32 standard normal variate via Ziggurat.
23///
24/// The sampling is performed in f64 (Ziggurat tables are f64) and cast to
25/// f32. This costs essentially nothing on modern CPUs and preserves the full
26/// tail accuracy of the f64 path.
27pub(crate) fn standard_normal_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
28    standard_normal_ziggurat_f32(bg)
29}
30
31impl<B: BitGenerator> Generator<B> {
32    /// Generate an array of standard normal (mean=0, std=1) variates.
33    ///
34    /// Uses the Marsaglia–Tsang Ziggurat algorithm (256 layers), which is
35    /// roughly 3× faster than Box–Muller for large draws. `shape` accepts
36    /// `usize`, `[usize; N]`, `&[usize]`, or `Vec<usize>` via [`IntoShape`].
37    ///
38    /// # Errors
39    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
40    pub fn standard_normal(
41        &mut self,
42        size: impl IntoShape,
43    ) -> Result<Array<f64, IxDyn>, FerrayError> {
44        let shape = size.into_shape()?;
45        let n = shape_size(&shape);
46        let data = generate_vec(self, n, standard_normal_single);
47        vec_to_array_f64(data, &shape)
48    }
49
50    /// Generate an array of normal (Gaussian) variates with given mean and
51    /// standard deviation. Equivalent to `numpy.random.Generator.normal`.
52    ///
53    /// # Errors
54    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
55    pub fn normal(
56        &mut self,
57        loc: f64,
58        scale: f64,
59        size: impl IntoShape,
60    ) -> Result<Array<f64, IxDyn>, FerrayError> {
61        if scale <= 0.0 {
62            return Err(FerrayError::invalid_value(format!(
63                "scale must be positive, got {scale}"
64            )));
65        }
66        let shape = size.into_shape()?;
67        let n = shape_size(&shape);
68        let data = generate_vec(self, n, |bg| scale.mul_add(standard_normal_single(bg), loc));
69        vec_to_array_f64(data, &shape)
70    }
71
72    /// Generate an array of standard normal (mean=0, std=1) `f32` variates.
73    ///
74    /// The f32 analogue of [`standard_normal`](Self::standard_normal). Equivalent
75    /// to `NumPy`'s `Generator.standard_normal(size, dtype=np.float32)`. Uses the
76    /// same Ziggurat f64 tables, then casts to f32 to preserve tail accuracy.
77    ///
78    /// # Errors
79    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
80    pub fn standard_normal_f32(
81        &mut self,
82        size: impl IntoShape,
83    ) -> Result<Array<f32, IxDyn>, FerrayError> {
84        let shape = size.into_shape()?;
85        let n = shape_size(&shape);
86        let data = generate_vec_f32(self, n, standard_normal_single_f32);
87        vec_to_array_f32(data, &shape)
88    }
89
90    /// Generate an array of `f32` normal (Gaussian) variates with given mean
91    /// and standard deviation. The f32 analogue of [`normal`](Self::normal).
92    ///
93    /// # Errors
94    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
95    pub fn normal_f32(
96        &mut self,
97        loc: f32,
98        scale: f32,
99        size: impl IntoShape,
100    ) -> Result<Array<f32, IxDyn>, FerrayError> {
101        if scale <= 0.0 {
102            return Err(FerrayError::invalid_value(format!(
103                "scale must be positive, got {scale}"
104            )));
105        }
106        let shape = size.into_shape()?;
107        let n = shape_size(&shape);
108        let data = generate_vec_f32(self, n, |bg| {
109            scale.mul_add(standard_normal_single_f32(bg), loc)
110        });
111        vec_to_array_f32(data, &shape)
112    }
113
114    /// Generate an array of `f32` log-normal variates. The f32 analogue of
115    /// [`lognormal`](Self::lognormal).
116    ///
117    /// # Errors
118    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `shape` is invalid.
119    pub fn lognormal_f32(
120        &mut self,
121        mean: f32,
122        sigma: f32,
123        size: impl IntoShape,
124    ) -> Result<Array<f32, IxDyn>, FerrayError> {
125        if sigma <= 0.0 {
126            return Err(FerrayError::invalid_value(format!(
127                "sigma must be positive, got {sigma}"
128            )));
129        }
130        let shape = size.into_shape()?;
131        let n = shape_size(&shape);
132        let data = generate_vec_f32(self, n, |bg| {
133            sigma.mul_add(standard_normal_single_f32(bg), mean).exp()
134        });
135        vec_to_array_f32(data, &shape)
136    }
137
138    /// Generate an array of log-normal variates.
139    ///
140    /// If X ~ Normal(mean, sigma), then exp(X) ~ LogNormal(mean, sigma).
141    ///
142    /// # Errors
143    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `shape` is invalid.
144    pub fn lognormal(
145        &mut self,
146        mean: f64,
147        sigma: f64,
148        size: impl IntoShape,
149    ) -> Result<Array<f64, IxDyn>, FerrayError> {
150        if sigma <= 0.0 {
151            return Err(FerrayError::invalid_value(format!(
152                "sigma must be positive, got {sigma}"
153            )));
154        }
155        let shape = size.into_shape()?;
156        let n = shape_size(&shape);
157        let data = generate_vec(self, n, |bg| {
158            sigma.mul_add(standard_normal_single(bg), mean).exp()
159        });
160        vec_to_array_f64(data, &shape)
161    }
162}
163
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Same seed must yield bit-identical output streams.
    #[test]
    fn standard_normal_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal(1000).unwrap();
        let b = rng2.standard_normal(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        // Standard error of the sample mean for unit variance: sqrt(1/n).
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - loc).abs() < 3.0 * se,
            "mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "variance {var} too far from {}",
            scale * scale
        );
    }

    /// scale == 0 and scale < 0 are both rejected by parameter validation.
    #[test]
    fn normal_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal(0.0, 0.0, 100).is_err());
        assert!(rng.normal(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let mu = 0.0;
        let sigma = 0.5;
        let arr = rng.lognormal(mu, sigma, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        // E[X] = exp(mu + sigma^2 / 2)
        let expected_mean = (mu + sigma * sigma / 2.0).exp();
        // Var[X] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
        let expected_var = (sigma * sigma).exp_m1() * 2.0f64.mul_add(mu, sigma * sigma).exp();
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "lognormal mean {mean} too far from {expected_mean}"
        );
    }

    // NOTE(review): this largely duplicates `standard_normal_mean_variance`
    // above; kept for coverage, consider consolidating.
    #[test]
    fn standard_normal_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let s = arr.as_slice().unwrap();
        let mean: f64 = s.iter().sum::<f64>() / n as f64;
        let var: f64 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        // Variance should be ~1.0
        assert!(
            (var - 1.0).abs() < 0.05,
            "standard_normal variance {var} too far from 1.0"
        );
    }

    // NOTE(review): this largely duplicates `normal_mean_variance` above
    // (it also exercises `iter()` rather than `as_slice()`); kept for coverage.
    #[test]
    fn normal_mean_and_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let s: Vec<f64> = arr.iter().copied().collect();
        let mean: f64 = s.iter().sum::<f64>() / n as f64;
        let var: f64 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        assert!(
            (mean - loc).abs() < 0.05,
            "normal mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "normal variance {var} too far from {}",
            scale * scale
        );
    }

    // ----- N-D shape tests (issue #440) -----

    #[test]
    fn standard_normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(10.0, 2.0, [2, 3, 4]).unwrap();
        assert_eq!(arr.shape(), &[2, 3, 4]);
    }

    #[test]
    fn lognormal_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, [5, 5]).unwrap();
        assert_eq!(arr.shape(), &[5, 5]);
        for &v in arr.iter() {
            assert!(v > 0.0);
        }
    }

    // ---------------------------------------------------------------
    // f32 variants (issue #441)
    // ---------------------------------------------------------------

    #[test]
    fn standard_normal_f32_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal_f32(1000).unwrap();
        let b = rng2.standard_normal_f32(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_f32_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let arr = rng.standard_normal_f32(n).unwrap();
        let slice = arr.as_slice().unwrap();
        // Accumulate in f64 to avoid compounding f32 error.
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        let var: f64 = slice
            .iter()
            .map(|&x| {
                let d = x as f64 - mean;
                d * d
            })
            .sum::<f64>()
            / n as f64;
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 5.0 * se, "f32 mean {mean} too far from 0");
        assert!(
            (var - 1.0).abs() < 0.05,
            "f32 variance {var} too far from 1"
        );
    }

    #[test]
    fn standard_normal_f32_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_f32([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let loc = 5.0f32;
        let scale = 2.0f32;
        let arr = rng.normal_f32(loc, scale, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        assert!(
            (mean - loc as f64).abs() < 0.05,
            "f32 normal mean {mean} too far from {loc}"
        );
    }

    #[test]
    fn normal_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.normal_f32(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal_f32(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal_f32 produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_f32_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal_f32(0.0, -0.5, 100).is_err());
    }

    // ----- NaN/Inf parameter input tests (#263) -----

    #[test]
    fn normal_nan_loc_produces_nan_output() {
        // NumPy: np.random.default_rng(42).normal(np.nan, 1.0, 5) → all NaN
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(f64::NAN, 1.0, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_nan(), "expected NaN, got {v}");
        }
    }

    #[test]
    fn normal_inf_scale_produces_inf_output() {
        // Infinite scale → every sample is ±inf.
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(0.0, f64::INFINITY, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_infinite() || v.is_nan(), "expected Inf/NaN, got {v}");
        }
    }

    #[test]
    fn normal_nan_scale_rejected() {
        // NaN scale should propagate — the output is meaningless but
        // at minimum must not panic.
        let mut rng = default_rng_seeded(42);
        // scale <= 0 is rejected by parameter validation; NaN is
        // neither > 0 nor <= 0, so the check may or may not catch it
        // depending on the implementation. Just assert no panic.
        let _ = rng.normal(0.0, f64::NAN, 5);
    }
}