// ferray_random/distributions/normal.rs
// ferray-random: Normal distribution sampling — standard_normal, normal, lognormal
3use ferray_core::dimension::broadcast::broadcast_shapes;
4use ferray_core::{Array, FerrayError, IxDyn};
5
6use crate::bitgen::BitGenerator;
7use crate::distributions::ziggurat::{standard_normal_ziggurat, standard_normal_ziggurat_f32};
8use crate::generator::{
9    Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
10};
11use crate::shape::IntoShape;
12
13/// Generate a single standard normal variate via the Ziggurat algorithm.
14///
15/// Ziggurat is ~3× faster than Box-Muller because ~98% of calls take a fast
16/// path that uses only one `next_u64`, a multiplication and a comparison.
17/// See [`ziggurat`](super::ziggurat) for the layer-table construction and
18/// the slow-path rejection logic.
19pub(crate) fn standard_normal_single<B: BitGenerator>(bg: &mut B) -> f64 {
20    standard_normal_ziggurat(bg)
21}
22
23/// Generate a single f32 standard normal variate via Ziggurat.
24///
25/// The sampling is performed in f64 (Ziggurat tables are f64) and cast to
26/// f32. This costs essentially nothing on modern CPUs and preserves the full
27/// tail accuracy of the f64 path.
28pub(crate) fn standard_normal_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
29    standard_normal_ziggurat_f32(bg)
30}
31
32impl<B: BitGenerator> Generator<B> {
33    /// Generate an array of standard normal (mean=0, std=1) variates.
34    ///
35    /// Uses the Marsaglia–Tsang Ziggurat algorithm (256 layers), which is
36    /// roughly 3× faster than Box–Muller for large draws. `shape` accepts
37    /// `usize`, `[usize; N]`, `&[usize]`, or `Vec<usize>` via [`IntoShape`].
38    ///
39    /// # Errors
40    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
41    pub fn standard_normal(
42        &mut self,
43        size: impl IntoShape,
44    ) -> Result<Array<f64, IxDyn>, FerrayError> {
45        let shape = size.into_shape()?;
46        let n = shape_size(&shape);
47        let data = generate_vec(self, n, standard_normal_single);
48        vec_to_array_f64(data, &shape)
49    }
50
51    /// Generate an array of normal (Gaussian) variates with given mean and
52    /// standard deviation. Equivalent to `numpy.random.Generator.normal`.
53    ///
54    /// # Errors
55    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
56    pub fn normal(
57        &mut self,
58        loc: f64,
59        scale: f64,
60        size: impl IntoShape,
61    ) -> Result<Array<f64, IxDyn>, FerrayError> {
62        if scale <= 0.0 {
63            return Err(FerrayError::invalid_value(format!(
64                "scale must be positive, got {scale}"
65            )));
66        }
67        let shape = size.into_shape()?;
68        let n = shape_size(&shape);
69        let data = generate_vec(self, n, |bg| scale.mul_add(standard_normal_single(bg), loc));
70        vec_to_array_f64(data, &shape)
71    }
72
73    /// Fill a pre-allocated `out` buffer with standard normal
74    /// (mean=0, std=1) variates (#454).
75    ///
76    /// Equivalent to `numpy.random.Generator.standard_normal(out=buffer)`.
77    /// Each slot is overwritten with a fresh Ziggurat draw — no heap
78    /// allocation.
79    ///
80    /// # Errors
81    /// `FerrayError::InvalidValue` if `out` is non-contiguous.
82    pub fn standard_normal_into(&mut self, out: &mut Array<f64, IxDyn>) -> Result<(), FerrayError> {
83        let slice = out.as_slice_mut().ok_or_else(|| {
84            FerrayError::invalid_value("standard_normal_into requires a contiguous out buffer")
85        })?;
86        for v in slice.iter_mut() {
87            *v = standard_normal_single(&mut self.bg);
88        }
89        Ok(())
90    }
91
92    /// Generate an array of normal variates with broadcast array
93    /// parameters (#449).
94    ///
95    /// `loc` and `scale` are arrays that NumPy-broadcast against each
96    /// other to produce the output shape. Each output element is
97    /// `loc[i] + scale[i] * Z` where `Z` is a fresh standard-normal
98    /// draw. Equivalent to `numpy.random.Generator.normal(loc, scale)`
99    /// when both are arrays.
100    ///
101    /// For scalar parameters, prefer [`normal`](Self::normal) — it
102    /// avoids the broadcast view and is faster.
103    ///
104    /// # Errors
105    /// - `FerrayError::ShapeMismatch` if the two shapes are not
106    ///   broadcast-compatible.
107    /// - `FerrayError::InvalidValue` if any `scale` element is `<= 0`.
108    pub fn normal_array(
109        &mut self,
110        loc: &Array<f64, IxDyn>,
111        scale: &Array<f64, IxDyn>,
112    ) -> Result<Array<f64, IxDyn>, FerrayError> {
113        let target = broadcast_shapes(loc.shape(), scale.shape())?;
114        let loc_v = loc.broadcast_to(&target)?;
115        let scale_v = scale.broadcast_to(&target)?;
116        let total: usize = target.iter().product();
117        let mut out: Vec<f64> = Vec::with_capacity(total);
118        for (&l, &s) in loc_v.iter().zip(scale_v.iter()) {
119            if s <= 0.0 {
120                return Err(FerrayError::invalid_value(format!(
121                    "scale must be positive, got {s}"
122                )));
123            }
124            out.push(s.mul_add(standard_normal_single(&mut self.bg), l));
125        }
126        Array::<f64, IxDyn>::from_vec(IxDyn::new(&target), out)
127    }
128
129    /// Generate an array of standard normal (mean=0, std=1) `f32` variates.
130    ///
131    /// The f32 analogue of [`standard_normal`](Self::standard_normal). Equivalent
132    /// to `NumPy`'s `Generator.standard_normal(size, dtype=np.float32)`. Uses the
133    /// same Ziggurat f64 tables, then casts to f32 to preserve tail accuracy.
134    ///
135    /// # Errors
136    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
137    pub fn standard_normal_f32(
138        &mut self,
139        size: impl IntoShape,
140    ) -> Result<Array<f32, IxDyn>, FerrayError> {
141        let shape = size.into_shape()?;
142        let n = shape_size(&shape);
143        let data = generate_vec_f32(self, n, standard_normal_single_f32);
144        vec_to_array_f32(data, &shape)
145    }
146
147    /// Generate an array of `f32` normal (Gaussian) variates with given mean
148    /// and standard deviation. The f32 analogue of [`normal`](Self::normal).
149    ///
150    /// # Errors
151    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
152    pub fn normal_f32(
153        &mut self,
154        loc: f32,
155        scale: f32,
156        size: impl IntoShape,
157    ) -> Result<Array<f32, IxDyn>, FerrayError> {
158        if scale <= 0.0 {
159            return Err(FerrayError::invalid_value(format!(
160                "scale must be positive, got {scale}"
161            )));
162        }
163        let shape = size.into_shape()?;
164        let n = shape_size(&shape);
165        let data = generate_vec_f32(self, n, |bg| {
166            scale.mul_add(standard_normal_single_f32(bg), loc)
167        });
168        vec_to_array_f32(data, &shape)
169    }
170
171    /// Generate an array of `f32` log-normal variates. The f32 analogue of
172    /// [`lognormal`](Self::lognormal).
173    ///
174    /// # Errors
175    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `shape` is invalid.
176    pub fn lognormal_f32(
177        &mut self,
178        mean: f32,
179        sigma: f32,
180        size: impl IntoShape,
181    ) -> Result<Array<f32, IxDyn>, FerrayError> {
182        if sigma <= 0.0 {
183            return Err(FerrayError::invalid_value(format!(
184                "sigma must be positive, got {sigma}"
185            )));
186        }
187        let shape = size.into_shape()?;
188        let n = shape_size(&shape);
189        let data = generate_vec_f32(self, n, |bg| {
190            sigma.mul_add(standard_normal_single_f32(bg), mean).exp()
191        });
192        vec_to_array_f32(data, &shape)
193    }
194
195    /// Generate an array of log-normal variates.
196    ///
197    /// If X ~ Normal(mean, sigma), then exp(X) ~ LogNormal(mean, sigma).
198    ///
199    /// # Errors
200    /// Returns `FerrayError::InvalidValue` if `sigma <= 0` or `shape` is invalid.
201    pub fn lognormal(
202        &mut self,
203        mean: f64,
204        sigma: f64,
205        size: impl IntoShape,
206    ) -> Result<Array<f64, IxDyn>, FerrayError> {
207        if sigma <= 0.0 {
208            return Err(FerrayError::invalid_value(format!(
209                "sigma must be positive, got {sigma}"
210            )));
211        }
212        let shape = size.into_shape()?;
213        let n = shape_size(&shape);
214        let data = generate_vec(self, n, |bg| {
215            sigma.mul_add(standard_normal_single(bg), mean).exp()
216        });
217        vec_to_array_f64(data, &shape)
218    }
219}
220
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    // Same seed ⇒ bit-identical output: the sampler must be fully
    // deterministic given the bit-generator state.
    #[test]
    fn standard_normal_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal(1000).unwrap();
        let b = rng2.standard_normal(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    // Statistical sanity check: sample mean within 3 standard errors of 0
    // (~99.7% pass rate for a correct sampler at this fixed seed).
    #[test]
    fn standard_normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        // se = sqrt(Var(Z)/n) — standard error of the sample mean.
        let se = (1.0 / n as f64).sqrt();
        assert!(mean.abs() < 3.0 * se, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    // Shifted/scaled draw: mean → loc, variance → scale².
    #[test]
    fn normal_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - loc).abs() < 3.0 * se,
            "mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "variance {var} too far from {}",
            scale * scale
        );
    }

    // Parameter validation: zero and negative scale must be rejected.
    #[test]
    fn normal_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal(0.0, 0.0, 100).is_err());
        assert!(rng.normal(0.0, -1.0, 100).is_err());
    }

    // The fill-in-place variant must consume the stream identically to the
    // allocating variant, so equal seeds produce equal buffers.
    #[test]
    fn standard_normal_into_matches_allocating_version() {
        use ferray_core::{Array, IxDyn};
        let mut a = default_rng_seeded(42);
        let mut b = default_rng_seeded(42);
        let allocated = a.standard_normal([4, 5]).unwrap();
        let mut buf = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[4, 5]), vec![0.0; 20]).unwrap();
        b.standard_normal_into(&mut buf).unwrap();
        assert_eq!(allocated.as_slice().unwrap(), buf.as_slice().unwrap());
    }

    #[test]
    fn normal_array_broadcast_scalar_x_vector() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(42);
        // loc shape (3,), scale shape (1,) — broadcast to (3,).
        let loc =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0.0, 10.0, -5.0])
                .unwrap();
        let scale =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[1]), vec![1.0]).unwrap();
        let out = rng.normal_array(&loc, &scale).unwrap();
        assert_eq!(out.shape(), &[3]);
    }

    #[test]
    fn normal_array_2d_broadcast_means_match_loc() {
        use ferray_core::IxDyn;
        // loc shape (3, 1), scale shape (1, 4) → output (3, 4) where
        // every row j shares loc[j] and every column shares scale.
        // With many draws per element the per-row mean → loc[j].
        let mut rng = default_rng_seeded(7);
        let loc =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3, 1]), vec![0.0, 5.0, -3.0])
                .unwrap();
        let scale = ferray_core::Array::<f64, IxDyn>::from_vec(
            IxDyn::new(&[1, 4]),
            vec![1.0, 0.5, 2.0, 0.1],
        )
        .unwrap();

        let n_trials = 5_000;
        let mut row_sums = [0.0_f64; 3];
        for _ in 0..n_trials {
            let out = rng.normal_array(&loc, &scale).unwrap();
            assert_eq!(out.shape(), &[3, 4]);
            let s = out.as_slice().unwrap();
            for r in 0..3 {
                for c in 0..4 {
                    row_sums[r] += s[r * 4 + c];
                }
            }
        }
        // Row r averages over 4 columns × n_trials draws with mean loc[r].
        let denom = (n_trials * 4) as f64;
        let expected = [0.0, 5.0, -3.0];
        for r in 0..3 {
            let m = row_sums[r] / denom;
            assert!(
                (m - expected[r]).abs() < 0.05,
                "row {r} mean {m} too far from {}",
                expected[r]
            );
        }
    }

    // A single non-positive element anywhere in `scale` must fail the call.
    #[test]
    fn normal_array_bad_scale_errors() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(0);
        let loc =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![0.0, 0.0]).unwrap();
        let scale =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, -0.5]).unwrap();
        assert!(rng.normal_array(&loc, &scale).is_err());
    }

    // (3,) vs (2,) is not broadcast-compatible → error, no panic.
    #[test]
    fn normal_array_shape_mismatch_errors() {
        use ferray_core::IxDyn;
        let mut rng = default_rng_seeded(0);
        let loc =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0.0; 3]).unwrap();
        let scale =
            ferray_core::Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0; 2]).unwrap();
        assert!(rng.normal_array(&loc, &scale).is_err());
    }

    // exp(anything finite) > 0, so every lognormal sample must be positive.
    #[test]
    fn lognormal_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0, "lognormal produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let mu = 0.0;
        let sigma = 0.5;
        let arr = rng.lognormal(mu, sigma, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        // E[X] = exp(mu + sigma^2 / 2)
        let expected_mean = (mu + sigma * sigma / 2.0).exp();
        // Var[X] = (exp(sigma^2) - 1) * exp(2*mu + sigma^2)
        let expected_var = (sigma * sigma).exp_m1() * 2.0f64.mul_add(mu, sigma * sigma).exp();
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "lognormal mean {mean} too far from {expected_mean}"
        );
    }

    // NOTE(review): largely overlaps with `standard_normal_mean_variance`
    // above — candidate for consolidation.
    #[test]
    fn standard_normal_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.standard_normal(n).unwrap();
        let s = arr.as_slice().unwrap();
        let mean: f64 = s.iter().sum::<f64>() / n as f64;
        let var: f64 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        // Variance should be ~1.0
        assert!(
            (var - 1.0).abs() < 0.05,
            "standard_normal variance {var} too far from 1.0"
        );
    }

    // NOTE(review): near-duplicate of `normal_mean_variance` above (this one
    // exercises `iter()` instead of `as_slice()`) — candidate for merging.
    #[test]
    fn normal_mean_and_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let loc = 5.0;
        let scale = 2.0;
        let arr = rng.normal(loc, scale, n).unwrap();
        let s: Vec<f64> = arr.iter().copied().collect();
        let mean: f64 = s.iter().sum::<f64>() / n as f64;
        let var: f64 = s.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        assert!(
            (mean - loc).abs() < 0.05,
            "normal mean {mean} too far from {loc}"
        );
        assert!(
            (var - scale * scale).abs() < 0.2,
            "normal variance {var} too far from {}",
            scale * scale
        );
    }

    // ----- N-D shape tests (issue #440) -----

    #[test]
    fn standard_normal_nd_shape() {
        let mut rng = crate::default_rng_seeded(42);
        let arr = rng.standard_normal([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_nd_shape() {
        let mut rng = crate::default_rng_seeded(42);
        let arr = rng.normal(10.0, 2.0, [2, 3, 4]).unwrap();
        assert_eq!(arr.shape(), &[2, 3, 4]);
    }

    #[test]
    fn lognormal_nd_shape() {
        let mut rng = crate::default_rng_seeded(42);
        let arr = rng.lognormal(0.0, 1.0, [5, 5]).unwrap();
        assert_eq!(arr.shape(), &[5, 5]);
        for &v in arr.iter() {
            assert!(v > 0.0);
        }
    }

    // ---------------------------------------------------------------
    // f32 variants (issue #441)
    // ---------------------------------------------------------------

    #[test]
    fn standard_normal_f32_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.standard_normal_f32(1000).unwrap();
        let b = rng2.standard_normal_f32(1000).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn standard_normal_f32_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let arr = rng.standard_normal_f32(n).unwrap();
        let slice = arr.as_slice().unwrap();
        // Accumulate in f64 to avoid compounding f32 error.
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        let var: f64 = slice
            .iter()
            .map(|&x| {
                let d = x as f64 - mean;
                d * d
            })
            .sum::<f64>()
            / n as f64;
        let se = (1.0 / n as f64).sqrt();
        // 5σ bound here (vs 3σ for f64) — f32 rounding adds a little slack.
        assert!(mean.abs() < 5.0 * se, "f32 mean {mean} too far from 0");
        assert!(
            (var - 1.0).abs() < 0.05,
            "f32 variance {var} too far from 1"
        );
    }

    #[test]
    fn standard_normal_f32_nd_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.standard_normal_f32([3, 4]).unwrap();
        assert_eq!(arr.shape(), &[3, 4]);
    }

    #[test]
    fn normal_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000usize;
        let loc = 5.0f32;
        let scale = 2.0f32;
        let arr = rng.normal_f32(loc, scale, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().map(|&x| x as f64).sum::<f64>() / n as f64;
        assert!(
            (mean - loc as f64).abs() < 0.05,
            "f32 normal mean {mean} too far from {loc}"
        );
    }

    #[test]
    fn normal_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.normal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.normal_f32(0.0, -1.0, 100).is_err());
    }

    #[test]
    fn lognormal_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.lognormal_f32(0.0, 1.0, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v > 0.0, "lognormal_f32 produced non-positive value: {v}");
        }
    }

    #[test]
    fn lognormal_f32_bad_sigma() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.lognormal_f32(0.0, 0.0, 100).is_err());
        assert!(rng.lognormal_f32(0.0, -0.5, 100).is_err());
    }

    // ----- NaN/Inf parameter input tests (#263) -----

    #[test]
    fn normal_nan_loc_produces_nan_output() {
        // NumPy: np.random.default_rng(42).normal(np.nan, 1.0, 5) → all NaN
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(f64::NAN, 1.0, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_nan(), "expected NaN, got {v}");
        }
    }

    #[test]
    fn normal_inf_scale_produces_inf_output() {
        // Infinite scale → every sample is ±inf.
        let mut rng = default_rng_seeded(42);
        let arr = rng.normal(0.0, f64::INFINITY, 5).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v.is_infinite() || v.is_nan(), "expected Inf/NaN, got {v}");
        }
    }

    #[test]
    fn normal_nan_scale_rejected() {
        // NaN scale should propagate — the output is meaningless but
        // at minimum must not panic.
        let mut rng = default_rng_seeded(42);
        // scale <= 0 is rejected by parameter validation; NaN is
        // neither > 0 nor <= 0, so the check may or may not catch it
        // depending on the implementation. Just assert no panic.
        let _ = rng.normal(0.0, f64::NAN, 5);
    }
}
569}