//! ferray_random/distributions/normal.rs

1// ferray-random: Normal distribution sampling — standard_normal, normal, lognormal
2
3use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::distributions::ziggurat::{standard_normal_ziggurat, standard_normal_ziggurat_f32};
7use crate::generator::{
8    Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
9};
10use crate::shape::IntoShape;
11
/// Generate a single standard normal variate via the Ziggurat algorithm.
///
/// Shared scalar kernel behind the f64 distribution methods below
/// (`standard_normal`, `normal`, `lognormal`).
///
/// Ziggurat is ~3× faster than Box-Muller because ~98% of calls take a fast
/// path that uses only one `next_u64`, a multiplication and a comparison.
/// See [`ziggurat`](super::ziggurat) for the layer-table construction and
/// the slow-path rejection logic.
pub(crate) fn standard_normal_single<B: BitGenerator>(bg: &mut B) -> f64 {
    standard_normal_ziggurat(bg)
}
21
/// Generate a single f32 standard normal variate via Ziggurat.
///
/// Shared scalar kernel behind the f32 distribution methods below
/// (`standard_normal_f32`, `normal_f32`, `lognormal_f32`).
///
/// The sampling is performed in f64 (Ziggurat tables are f64) and cast to
/// f32. This costs essentially nothing on modern CPUs and preserves the full
/// tail accuracy of the f64 path.
pub(crate) fn standard_normal_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
    standard_normal_ziggurat_f32(bg)
}
30
31impl<B: BitGenerator> Generator<B> {
32    /// Generate an array of standard normal (mean=0, std=1) variates.
33    ///
34    /// Uses the Marsaglia–Tsang Ziggurat algorithm (256 layers), which is
35    /// roughly 3× faster than Box–Muller for large draws. `shape` accepts
36    /// `usize`, `[usize; N]`, `&[usize]`, or `Vec<usize>` via [`IntoShape`].
37    ///
38    /// # Errors
39    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
40    pub fn standard_normal(
41        &mut self,
42        size: impl IntoShape,
43    ) -> Result<Array<f64, IxDyn>, FerrayError> {
44        let shape = size.into_shape()?;
45        let n = shape_size(&shape);
46        let data = generate_vec(self, n, standard_normal_single);
47        vec_to_array_f64(data, &shape)
48    }
49
50    /// Generate an array of normal (Gaussian) variates with given mean and
51    /// standard deviation. Equivalent to `numpy.random.Generator.normal`.
52    ///
53    /// # Errors
54    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
55    pub fn normal(
56        &mut self,
57        loc: f64,
58        scale: f64,
59        size: impl IntoShape,
60    ) -> Result<Array<f64, IxDyn>, FerrayError> {
61        if scale <= 0.0 {
62            return Err(FerrayError::invalid_value(format!(
63                "scale must be positive, got {scale}"
64            )));
65        }
66        let shape = size.into_shape()?;
67        let n = shape_size(&shape);
68        let data = generate_vec(self, n, |bg| loc + scale * standard_normal_single(bg));
69        vec_to_array_f64(data, &shape)
70    }
71
72    /// Generate an array of standard normal (mean=0, std=1) `f32` variates.
73    ///
74    /// The f32 analogue of [`standard_normal`](Self::standard_normal). Equivalent
75    /// to NumPy's `Generator.standard_normal(size, dtype=np.float32)`. Uses the
76    /// same Ziggurat f64 tables, then casts to f32 to preserve tail accuracy.
77    ///
78    /// # Errors
79    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
80    pub fn standard_normal_f32(
81        &mut self,
82        size: impl IntoShape,
83    ) -> Result<Array<f32, IxDyn>, FerrayError> {
84        let shape = size.into_shape()?;
85        let n = shape_size(&shape);
86        let data = generate_vec_f32(self, n, standard_normal_single_f32);
87        vec_to_array_f32(data, &shape)
88    }
89
90    /// Generate an array of `f32` normal (Gaussian) variates with given mean
91    /// and standard deviation. The f32 analogue of [`normal`](Self::normal).
92    ///
93    /// # Errors
94    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
95    pub fn normal_f32(
96        &mut self,
97        loc: f32,
98        scale: f32,
99        size: impl IntoShape,
100    ) -> Result<Array<f32, IxDyn>, FerrayError> {
101        if scale <= 0.0 {
102            return Err(FerrayError::invalid_value(format!(
103                "scale must be positive, got {scale}"
104            )));
105        }
106        let shape = size.into_shape()?;
107        let n = shape_size(&shape);
108        let data = generate_vec_f32(self, n, |bg| loc + scale * standard_normal_single_f32(bg));
109        vec_to_array_f32(data, &shape)
110    }
111
112    /// Generate an array of `f32` log-normal variates. The f32 analogue of
113    /// [`lognormal`](Self::lognormal).
114    ///
115    /// # Errors
116    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `shape` is invalid.
117    pub fn lognormal_f32(
118        &mut self,
119        mean: f32,
120        sigma: f32,
121        size: impl IntoShape,
122    ) -> Result<Array<f32, IxDyn>, FerrayError> {
123        if sigma <= 0.0 {
124            return Err(FerrayError::invalid_value(format!(
125                "sigma must be positive, got {sigma}"
126            )));
127        }
128        let shape = size.into_shape()?;
129        let n = shape_size(&shape);
130        let data = generate_vec_f32(self, n, |bg| {
131            (mean + sigma * standard_normal_single_f32(bg)).exp()
132        });
133        vec_to_array_f32(data, &shape)
134    }
135
136    /// Generate an array of log-normal variates.
137    ///
138    /// If X ~ Normal(mean, sigma), then exp(X) ~ LogNormal(mean, sigma).
139    ///
140    /// # Errors
141    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `shape` is invalid.
142    pub fn lognormal(
143        &mut self,
144        mean: f64,
145        sigma: f64,
146        size: impl IntoShape,
147    ) -> Result<Array<f64, IxDyn>, FerrayError> {
148        if sigma <= 0.0 {
149            return Err(FerrayError::invalid_value(format!(
150                "sigma must be positive, got {sigma}"
151            )));
152        }
153        let shape = size.into_shape()?;
154        let n = shape_size(&shape);
155        let data = generate_vec(self, n, |bg| {
156            (mean + sigma * standard_normal_single(bg)).exp()
157        });
158        vec_to_array_f64(data, &shape)
159    }
160}
161
162#[cfg(test)]
163mod tests {
164    use crate::default_rng_seeded;
165
166    #[test]
167    fn standard_normal_deterministic() {
168        let mut rng1 = default_rng_seeded(42);
169        let mut rng2 = default_rng_seeded(42);
170        let a = rng1.standard_normal(1000).unwrap();
171        let b = rng2.standard_normal(1000).unwrap();
172        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
173    }
174
175    #[test]
176    fn standard_normal_mean_variance() {
177        let mut rng = default_rng_seeded(42);
178        let n = 100_000;
179        let arr = rng.standard_normal(n).unwrap();
180        let slice = arr.as_slice().unwrap();
181        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
182        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
183        let se = (1.0 / n as f64).sqrt();
184        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
185        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
186    }
187
188    #[test]
189    fn normal_mean_variance() {
190        let mut rng = default_rng_seeded(42);
191        let n = 100_000;
192        let loc = 5.0;
193        let scale = 2.0;
194        let arr = rng.normal(loc, scale, n).unwrap();
195        let slice = arr.as_slice().unwrap();
196        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
197        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
198        let se = (scale * scale / n as f64).sqrt();
199        assert!(
200            (mean - loc).abs() < 3.0 * se,
201            "mean {mean} too far from {loc}"
202        );
203        assert!(
204            (var - scale * scale).abs() < 0.2,
205            "variance {var} too far from {}",
206            scale * scale
207        );
208    }
209
210    #[test]
211    fn normal_bad_scale() {
212        let mut rng = default_rng_seeded(42);
213        assert!(rng.normal(0.0, 0.0, 100).is_err());
214        assert!(rng.normal(0.0, -1.0, 100).is_err());
215    }
216
217    #[test]
218    fn lognormal_positive() {
219        let mut rng = default_rng_seeded(42);
220        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
221        let slice = arr.as_slice().unwrap();
222        for &v in slice {
223            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
224        }
225    }
226
227    #[test]
228    fn lognormal_mean() {
229        let mut rng = default_rng_seeded(42);
230        let n = 100_000;
231        let mu = 0.0;
232        let sigma = 0.5;
233        let arr = rng.lognormal(mu, sigma, n).unwrap();
234        let slice = arr.as_slice().unwrap();
235        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
236        // E[X] = exp(mu + sigma^2 / 2)
237        let expected_mean = (mu + sigma * sigma / 2.0).exp();
238        let expected_var = ((sigma * sigma).exp() - 1.0) * (2.0 * mu + sigma * sigma).exp();
239        let se = (expected_var / n as f64).sqrt();
240        assert!(
241            (mean - expected_mean).abs() < 3.0 * se,
242            "lognormal mean {mean} too far from {expected_mean}"
243        );
244    }
245
246    #[test]
247    fn standard_normal_variance() {
248        let mut rng = default_rng_seeded(42);
249        let n = 100_000;
250        let arr = rng.standard_normal(n).unwrap();
251        let s = arr.as_slice().unwrap();
252        let mean: f64 = s.iter().sum::<f64>() / n as f64;
253        let var: f64 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
254        // Variance should be ~1.0
255        assert!(
256            (var - 1.0).abs() < 0.05,
257            "standard_normal variance {var} too far from 1.0"
258        );
259    }
260
261    #[test]
262    fn normal_mean_and_variance() {
263        let mut rng = default_rng_seeded(42);
264        let n = 100_000;
265        let loc = 5.0;
266        let scale = 2.0;
267        let arr = rng.normal(loc, scale, n).unwrap();
268        let s: Vec<f64> = arr.iter().copied().collect();
269        let mean: f64 = s.iter().sum::<f64>() / n as f64;
270        let var: f64 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
271        assert!(
272            (mean - loc).abs() < 0.05,
273            "normal mean {mean} too far from {loc}"
274        );
275        assert!(
276            (var - scale * scale).abs() < 0.2,
277            "normal variance {var} too far from {}",
278            scale * scale
279        );
280    }
281
282    // ----- N-D shape tests (issue #440) -----
283
284    #[test]
285    fn standard_normal_nd_shape() {
286        let mut rng = crate::default_rng_seeded(42);
287        let arr = rng.standard_normal([3, 4]).unwrap();
288        assert_eq!(arr.shape(), &[3, 4]);
289    }
290
291    #[test]
292    fn normal_nd_shape() {
293        let mut rng = crate::default_rng_seeded(42);
294        let arr = rng.normal(10.0, 2.0, [2, 3, 4]).unwrap();
295        assert_eq!(arr.shape(), &[2, 3, 4]);
296    }
297
298    #[test]
299    fn lognormal_nd_shape() {
300        let mut rng = crate::default_rng_seeded(42);
301        let arr = rng.lognormal(0.0, 1.0, [5, 5]).unwrap();
302        assert_eq!(arr.shape(), &[5, 5]);
303        for &v in arr.iter() {
304            assert!(v > 0.0);
305        }
306    }
307
308    // ---------------------------------------------------------------
309    // f32 variants (issue #441)
310    // ---------------------------------------------------------------
311
312    #[test]
313    fn standard_normal_f32_deterministic() {
314        let mut rng1 = default_rng_seeded(42);
315        let mut rng2 = default_rng_seeded(42);
316        let a = rng1.standard_normal_f32(1000).unwrap();
317        let b = rng2.standard_normal_f32(1000).unwrap();
318        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
319    }
320
321    #[test]
322    fn standard_normal_f32_mean_variance() {
323        let mut rng = default_rng_seeded(42);
324        let n = 100_000usize;
325        let arr = rng.standard_normal_f32(n).unwrap();
326        let slice = arr.as_slice().unwrap();
327        // Accumulate in f64 to avoid compounding f32 error.
328        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
329        let var: f64 = slice
330            .iter()
331            .map(|&x| {
332                let d = x as f64 - mean;
333                d * d
334            })
335            .sum::<f64>()
336            / n as f64;
337        let se = (1.0 / n as f64).sqrt();
338        assert!(mean.abs() < 5.0 * se, "f32 mean {mean} too far from 0");
339        assert!(
340            (var - 1.0).abs() < 0.05,
341            "f32 variance {var} too far from 1"
342        );
343    }
344
345    #[test]
346    fn standard_normal_f32_nd_shape() {
347        let mut rng = default_rng_seeded(42);
348        let arr = rng.standard_normal_f32([3, 4]).unwrap();
349        assert_eq!(arr.shape(), &[3, 4]);
350    }
351
352    #[test]
353    fn normal_f32_mean() {
354        let mut rng = default_rng_seeded(42);
355        let n = 100_000usize;
356        let loc = 5.0f32;
357        let scale = 2.0f32;
358        let arr = rng.normal_f32(loc, scale, n).unwrap();
359        let slice = arr.as_slice().unwrap();
360        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
361        assert!(
362            (mean - loc as f64).abs() < 0.05,
363            "f32 normal mean {mean} too far from {loc}"
364        );
365    }
366
367    #[test]
368    fn normal_f32_bad_scale() {
369        let mut rng = default_rng_seeded(42);
370        assert!(rng.normal_f32(0.0, 0.0, 100).is_err());
371        assert!(rng.normal_f32(0.0, -1.0, 100).is_err());
372    }
373
374    #[test]
375    fn lognormal_f32_positive() {
376        let mut rng = default_rng_seeded(42);
377        let arr = rng.lognormal_f32(0.0, 1.0, 10_000).unwrap();
378        for &v in arr.as_slice().unwrap() {
379            assert!(v > 0.0, "lognormal_f32 produced non-positive value: {v}");
380        }
381    }
382
383    #[test]
384    fn lognormal_f32_bad_sigma() {
385        let mut rng = default_rng_seeded(42);
386        assert!(rng.lognormal_f32(0.0, 0.0, 100).is_err());
387        assert!(rng.lognormal_f32(0.0, -0.5, 100).is_err());
388    }
389
390    // ----- NaN/Inf parameter input tests (#263) -----
391
392    #[test]
393    fn normal_nan_loc_produces_nan_output() {
394        // NumPy: np.random.default_rng(42).normal(np.nan, 1.0, 5) → all NaN
395        let mut rng = default_rng_seeded(42);
396        let arr = rng.normal(f64::NAN, 1.0, 5).unwrap();
397        for &v in arr.as_slice().unwrap() {
398            assert!(v.is_nan(), "expected NaN, got {v}");
399        }
400    }
401
402    #[test]
403    fn normal_inf_scale_produces_inf_output() {
404        // Infinite scale → every sample is ±inf.
405        let mut rng = default_rng_seeded(42);
406        let arr = rng.normal(0.0, f64::INFINITY, 5).unwrap();
407        for &v in arr.as_slice().unwrap() {
408            assert!(v.is_infinite() || v.is_nan(), "expected Inf/NaN, got {v}");
409        }
410    }
411
412    #[test]
413    fn normal_nan_scale_rejected() {
414        // NaN scale should propagate — the output is meaningless but
415        // at minimum must not panic.
416        let mut rng = default_rng_seeded(42);
417        // scale <= 0 is rejected by parameter validation; NaN is
418        // neither > 0 nor <= 0, so the check may or may not catch it
419        // depending on the implementation. Just assert no panic.
420        let _ = rng.normal(0.0, f64::NAN, 5);
421    }
422}