Skip to main content

ferray_random/distributions/
discrete.rs

1// ferray-random: Discrete distributions
2//
3// binomial, negative_binomial, poisson, geometric, hypergeometric, logseries, zipf
4
5use ferray_core::{Array, FerrayError, IxDyn};
6
7use crate::bitgen::BitGenerator;
8use crate::distributions::gamma::standard_gamma_single;
9use crate::generator::{Generator, generate_vec_i64, shape_size, vec_to_array_i64};
10use crate::shape::IntoShape;
11
12/// Generate a single Poisson variate using Knuth's algorithm for small lambda,
13/// or the transformed rejection method (Hormann) for large lambda.
14fn poisson_single<B: BitGenerator>(bg: &mut B, lam: f64) -> i64 {
15    if lam < 30.0 {
16        // Knuth's algorithm
17        let l = (-lam).exp();
18        let mut k: i64 = 0;
19        let mut p = 1.0;
20        loop {
21            k += 1;
22            p *= bg.next_f64();
23            if p <= l {
24                return k - 1;
25            }
26        }
27    } else {
28        // Transformed rejection method (PA algorithm, Ahrens & Dieter)
29        let slam = lam.sqrt();
30        let loglam = lam.ln();
31        let b = 2.53f64.mul_add(slam, 0.931);
32        let a = 0.02483f64.mul_add(b, -0.059);
33        let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
34        let vr = 0.9277 - 3.6224 / (b - 2.0);
35
36        loop {
37            let u = bg.next_f64() - 0.5;
38            let v = bg.next_f64();
39            let us = 0.5 - u.abs();
40            let k = ((2.0 * a / us + b).mul_add(u, lam) + 0.43).floor() as i64;
41            if k < 0 {
42                continue;
43            }
44            if us >= 0.07 && v <= vr {
45                return k;
46            }
47            if k > 0
48                && us >= 0.013
49                && v <= (k as f64)
50                    .ln()
51                    .mul_add(
52                        -0.5,
53                        (k as f64).mul_add(loglam, -lam) - ln_factorial(k as u64),
54                    )
55                    .exp()
56                    * inv_alpha
57            {
58                return k;
59            }
60            if us < 0.013 && v > us {
61                continue;
62            }
63            // Full log test
64            let kf = k as f64;
65            let log_accept = -lam + kf * loglam - ln_factorial(k as u64);
66            if v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln() <= log_accept {
67                return k;
68            }
69        }
70    }
71}
72
/// Natural logarithm of `n!`.
///
/// For `n <= 20` the value is the direct sum `ln 2 + ln 3 + ... + ln n`
/// (zero for `n` of 0 or 1). Larger arguments use Stirling's series
/// truncated after the `1/(360 n^3)` correction term.
fn ln_factorial(n: u64) -> f64 {
    if n > 20 {
        let x = n as f64;
        // 0.5 * ln(2*pi) + (x + 0.5) * ln(x)
        let leading = 0.5f64.mul_add((std::f64::consts::TAU).ln(), (x + 0.5) * x.ln());
        let cubic = 360.0 * x * x * x;
        return leading - x + 1.0 / (12.0 * x) - 1.0 / cubic;
    }

    // Explicit +0.0 seed keeps the empty-range result identical to a manual loop.
    (2..=n).fold(0.0_f64, |acc, i| acc + (i as f64).ln())
}
89
90/// Generate a single binomial variate using the inverse transform for small n*p
91/// or the BTPE algorithm for larger n*p.
92fn binomial_single<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> i64 {
93    if n == 0 || p == 0.0 {
94        return 0;
95    }
96    if p == 1.0 {
97        return n as i64;
98    }
99
100    // Use the smaller of p and 1-p for efficiency
101    let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
102
103    let np = n as f64 * pp;
104    let q = 1.0 - pp;
105
106    let result = if np < 30.0 {
107        // Inverse transform (waiting time) method
108        let s = pp / q;
109        let a = (n as f64 + 1.0) * s;
110        let mut r = q.powf(n as f64);
111        let mut u = bg.next_f64();
112        let mut x: i64 = 0;
113        while u > r {
114            u -= r;
115            x += 1;
116            r *= a / (x as f64) - s;
117            if r < 0.0 {
118                break;
119            }
120        }
121        x.min(n as i64)
122    } else {
123        // BTPE algorithm (Hormann 1993) for large n*p.
124        // Based on the transformed rejection method with decomposition
125        // into triangular, parallelogram, and exponential regions.
126        let fm = np + pp;
127        let m = fm.floor() as i64;
128        let mf = m as f64;
129        let p1 = 2.195f64.mul_add((np * q).sqrt(), -(4.6 * q)).floor() + 0.5;
130        let xm = mf + 0.5;
131        let xl = xm - p1;
132        let xr = xm + p1;
133        let c = 0.134 + 20.5 / (15.3 + mf);
134        let a = (fm - xl) / (fm - xl * pp);
135        let lambda_l = a * 0.5f64.mul_add(a, 1.0);
136        let a2 = (xr - fm) / (xr * q);
137        let lambda_r = a2 * 0.5f64.mul_add(a2, 1.0);
138        let p2 = p1 * 2.0f64.mul_add(c, 1.0);
139        let p3 = p2 + c / lambda_l;
140        let p4 = p3 + c / lambda_r;
141
142        loop {
143            let u = bg.next_f64() * p4;
144            let v = bg.next_f64();
145            let y: i64;
146
147            if u <= p1 {
148                // Triangular region
149                y = (xm - p1 * v + u).floor() as i64;
150            } else if u <= p2 {
151                // Parallelogram region
152                let x = xl + (u - p1) / c;
153                // BTPE acceptance test: w = v + (x - xm)^2 / p1^2.
154                // clippy::suspicious_operation_groupings would rewrite the
155                // squared denominator to `x * p1`, which is mathematically
156                // wrong here.
157                #[allow(clippy::suspicious_operation_groupings)]
158                let w = v + (x - xm) * (x - xm) / (p1 * p1);
159                if w > 1.0 {
160                    continue;
161                }
162                y = x.floor() as i64;
163            } else if u <= p3 {
164                // Left exponential tail
165                y = (xl + v.ln() / lambda_l).floor() as i64;
166                if y < 0 {
167                    continue;
168                }
169            } else {
170                // Right exponential tail
171                y = (xr - v.ln() / lambda_r).floor() as i64;
172                if y > n as i64 {
173                    continue;
174                }
175            }
176
177            // Squeeze acceptance
178            let k = (y - m).abs();
179            if k <= 20 || k as f64 >= (0.5 * np).mul_add(q, -1.0) {
180                // Full acceptance/rejection via log-factorial comparison
181                let kf = k as f64;
182                let yf = y as f64;
183                let rho =
184                    (kf / (np * q)) * (kf.mul_add(kf / 3.0 + 0.625, 1.0 / 6.0) / (np * q) + 0.5);
185                let t = -kf * kf / (2.0 * np * q);
186                let log_a = t - rho;
187                if v.ln() <= log_a {
188                    break y;
189                }
190                // Full log-factorial test
191                let log_v = v.ln();
192                let log_accept = (yf - mf).mul_add(
193                    (pp / q).ln(),
194                    ln_factorial(m as u64) - ln_factorial(y as u64) - ln_factorial(n - y as u64)
195                        + ln_factorial(n - m as u64),
196                );
197                if log_v <= log_accept {
198                    break y;
199                }
200            } else {
201                break y;
202            }
203        }
204    };
205
206    if flipped { n as i64 - result } else { result }
207}
208
209impl<B: BitGenerator> Generator<B> {
210    /// Generate an array of binomial-distributed variates.
211    ///
212    /// Each value is the number of successes in `n` Bernoulli trials
213    /// with success probability `p`.
214    ///
215    /// # Arguments
216    /// * `n` - Number of trials.
217    /// * `p` - Probability of success per trial, must be in [0, 1].
218    /// * `size` - Number of values to generate.
219    ///
220    /// # Errors
221    /// Returns `FerrayError::InvalidValue` for invalid parameters.
222    pub fn binomial(
223        &mut self,
224        n: u64,
225        p: f64,
226        size: impl IntoShape,
227    ) -> Result<Array<i64, IxDyn>, FerrayError> {
228        if !(0.0..=1.0).contains(&p) {
229            return Err(FerrayError::invalid_value(format!(
230                "p must be in [0, 1], got {p}"
231            )));
232        }
233        let shape_vec = size.into_shape()?;
234        let total = shape_size(&shape_vec);
235        let data = generate_vec_i64(self, total, |bg| binomial_single(bg, n, p));
236        vec_to_array_i64(data, &shape_vec)
237    }
238
239    /// Generate an array of negative binomial distributed variates.
240    ///
241    /// The number of failures before `n` successes with success probability `p`.
242    /// Uses the gamma-Poisson mixture.
243    ///
244    /// # Arguments
245    /// * `n` - Number of successes (positive).
246    /// * `p` - Probability of success, must be in (0, 1].
247    /// * `size` - Number of values to generate.
248    ///
249    /// # Errors
250    /// Returns `FerrayError::InvalidValue` for invalid parameters.
251    pub fn negative_binomial(
252        &mut self,
253        n: f64,
254        p: f64,
255        size: impl IntoShape,
256    ) -> Result<Array<i64, IxDyn>, FerrayError> {
257        if n <= 0.0 {
258            return Err(FerrayError::invalid_value(format!(
259                "n must be positive, got {n}"
260            )));
261        }
262        if p <= 0.0 || p > 1.0 {
263            return Err(FerrayError::invalid_value(format!(
264                "p must be in (0, 1], got {p}"
265            )));
266        }
267        let shape_vec = size.into_shape()?;
268        let total = shape_size(&shape_vec);
269        let data = generate_vec_i64(self, total, |bg| {
270            // Gamma-Poisson mixture:
271            // Y ~ Gamma(n, (1-p)/p), then X ~ Poisson(Y)
272            let y = standard_gamma_single(bg, n) * (1.0 - p) / p;
273            poisson_single(bg, y)
274        });
275        vec_to_array_i64(data, &shape_vec)
276    }
277
278    /// Generate an array of Poisson-distributed variates.
279    ///
280    /// # Arguments
281    /// * `lam` - Expected number of events (lambda), must be non-negative.
282    /// * `size` - Number of values to generate.
283    ///
284    /// # Errors
285    /// Returns `FerrayError::InvalidValue` if `lam < 0` or `size` is zero.
286    pub fn poisson(
287        &mut self,
288        lam: f64,
289        size: impl IntoShape,
290    ) -> Result<Array<i64, IxDyn>, FerrayError> {
291        if lam < 0.0 {
292            return Err(FerrayError::invalid_value(format!(
293                "lam must be non-negative, got {lam}"
294            )));
295        }
296        let shape_vec = size.into_shape()?;
297        let total = shape_size(&shape_vec);
298        if lam == 0.0 {
299            let data = vec![0i64; total];
300            return vec_to_array_i64(data, &shape_vec);
301        }
302        let data = generate_vec_i64(self, total, |bg| poisson_single(bg, lam));
303        vec_to_array_i64(data, &shape_vec)
304    }
305
306    /// Generate an array of Poisson variates with a broadcast `lam`
307    /// parameter (#449).
308    ///
309    /// `lam` is an array of expected counts; each output element is
310    /// sampled from `Poisson(lam[i])`. Output shape equals
311    /// `lam.shape()`.
312    ///
313    /// # Errors
314    /// - `FerrayError::InvalidValue` if any `lam` element is negative.
315    pub fn poisson_array(
316        &mut self,
317        lam: &Array<f64, IxDyn>,
318    ) -> Result<Array<i64, IxDyn>, FerrayError> {
319        let shape = lam.shape().to_vec();
320        let total: usize = shape.iter().product();
321        let mut out: Vec<i64> = Vec::with_capacity(total);
322        for &l in lam.iter() {
323            if l < 0.0 {
324                return Err(FerrayError::invalid_value(format!(
325                    "lam must be non-negative, got {l}"
326                )));
327            }
328            if l == 0.0 {
329                out.push(0);
330            } else {
331                out.push(poisson_single(&mut self.bg, l));
332            }
333        }
334        Array::<i64, IxDyn>::from_vec(IxDyn::new(&shape), out)
335    }
336
337    /// Generate an array of geometric-distributed variates.
338    ///
339    /// The number of trials until the first success (1-based).
340    ///
341    /// # Arguments
342    /// * `p` - Probability of success, must be in (0, 1].
343    /// * `size` - Number of values to generate.
344    ///
345    /// # Errors
346    /// Returns `FerrayError::InvalidValue` if `p` not in (0, 1] or `size` is zero.
347    pub fn geometric(
348        &mut self,
349        p: f64,
350        size: impl IntoShape,
351    ) -> Result<Array<i64, IxDyn>, FerrayError> {
352        if p <= 0.0 || p > 1.0 {
353            return Err(FerrayError::invalid_value(format!(
354                "p must be in (0, 1], got {p}"
355            )));
356        }
357        let shape_vec = size.into_shape()?;
358        let total = shape_size(&shape_vec);
359        if (p - 1.0).abs() < f64::EPSILON {
360            let data = vec![1i64; total];
361            return vec_to_array_i64(data, &shape_vec);
362        }
363        let log_q = (1.0 - p).ln();
364        let data = generate_vec_i64(self, total, |bg| {
365            loop {
366                let u = bg.next_f64();
367                if u > f64::EPSILON {
368                    return (u.ln() / log_q).floor() as i64 + 1;
369                }
370            }
371        });
372        vec_to_array_i64(data, &shape_vec)
373    }
374
375    /// Generate an array of hypergeometric-distributed variates.
376    ///
377    /// Models drawing `nsample` items without replacement from a population
378    /// containing `ngood` success states and `nbad` failure states.
379    ///
380    /// # Arguments
381    /// * `ngood` - Number of success states in the population.
382    /// * `nbad` - Number of failure states in the population.
383    /// * `nsample` - Number of items drawn.
384    /// * `size` - Number of values to generate.
385    ///
386    /// # Errors
387    /// Returns `FerrayError::InvalidValue` if `nsample > ngood + nbad` or `size` is zero.
388    pub fn hypergeometric(
389        &mut self,
390        ngood: u64,
391        nbad: u64,
392        nsample: u64,
393        size: impl IntoShape,
394    ) -> Result<Array<i64, IxDyn>, FerrayError> {
395        let total = ngood + nbad;
396        if nsample > total {
397            return Err(FerrayError::invalid_value(format!(
398                "nsample ({nsample}) > ngood + nbad ({total})"
399            )));
400        }
401        let shape_vec = size.into_shape()?;
402        let total_n = shape_size(&shape_vec);
403        let data = generate_vec_i64(self, total_n, |bg| {
404            hypergeometric_single(bg, ngood, nbad, nsample)
405        });
406        vec_to_array_i64(data, &shape_vec)
407    }
408
409    /// Generate an array of logarithmic series distributed variates.
410    ///
411    /// # Arguments
412    /// * `p` - Shape parameter, must be in (0, 1).
413    /// * `size` - Number of values to generate.
414    ///
415    /// # Errors
416    /// Returns `FerrayError::InvalidValue` if `p` not in (0, 1) or `size` is zero.
417    pub fn logseries(
418        &mut self,
419        p: f64,
420        size: impl IntoShape,
421    ) -> Result<Array<i64, IxDyn>, FerrayError> {
422        if p <= 0.0 || p >= 1.0 {
423            return Err(FerrayError::invalid_value(format!(
424                "p must be in (0, 1), got {p}"
425            )));
426        }
427        let r = (-(-p).ln_1p()).recip();
428        let shape_vec = size.into_shape()?;
429        let total = shape_size(&shape_vec);
430        let data = generate_vec_i64(self, total, |bg| {
431            // Kemp's "second" algorithm for the logarithmic distribution.
432            // See Devroye, "Non-Uniform Random Variate Generation", p. 548.
433            loop {
434                let u = bg.next_f64();
435                if u <= f64::EPSILON || u >= 1.0 - f64::EPSILON {
436                    continue;
437                }
438                let v = bg.next_f64();
439                let q = 1.0 - (-r.recip() * u.ln()).exp();
440                if q <= 0.0 {
441                    return 1;
442                }
443                if v < q * q {
444                    let k = (1.0 + v.log(q)).floor() as i64;
445                    return k.max(1);
446                }
447                if v < q {
448                    return 2;
449                }
450                return 1;
451            }
452        });
453        vec_to_array_i64(data, &shape_vec)
454    }
455
456    /// Generate an array of Zipf-distributed variates.
457    ///
458    /// Samples from the Zipf (zeta) distribution with shape parameter `a > 1`,
459    /// using Devroye's rejection algorithm (Non-Uniform Random Variate
460    /// Generation, p. 551). The PMF is `P(k) = k^(-a) / zeta(a)` for
461    /// `k = 1, 2, ...`.
462    ///
463    /// Equivalent to `numpy.random.Generator.zipf`.
464    ///
465    /// # Errors
466    /// - `FerrayError::InvalidValue` if `a <= 1` or `size` is invalid.
467    pub fn zipf(&mut self, a: f64, size: impl IntoShape) -> Result<Array<i64, IxDyn>, FerrayError> {
468        if a <= 1.0 {
469            return Err(FerrayError::invalid_value(format!(
470                "a must be > 1 for Zipf, got {a}"
471            )));
472        }
473        let am1 = a - 1.0;
474        let b = 2.0_f64.powf(am1);
475        let shape_vec = size.into_shape()?;
476        let total = shape_size(&shape_vec);
477        let data = generate_vec_i64(self, total, |bg| {
478            loop {
479                let u = 1.0 - bg.next_f64();
480                let v = bg.next_f64();
481                let x = u.powf(-1.0 / am1).floor();
482                // Guard against overflow / non-positive results.
483                if !x.is_finite() || x < 1.0 {
484                    continue;
485                }
486                let t = (1.0 + 1.0 / x).powf(am1);
487                // Devroye's acceptance: v * x * (t - 1) / (b - 1) <= t / b
488                if v * x * (t - 1.0) / (b - 1.0) <= t / b {
489                    if x > i64::MAX as f64 {
490                        continue;
491                    }
492                    return x as i64;
493                }
494            }
495        });
496        vec_to_array_i64(data, &shape_vec)
497    }
498}
499
500/// Generate a single hypergeometric variate using the direct algorithm.
501fn hypergeometric_single<B: BitGenerator>(bg: &mut B, ngood: u64, nbad: u64, nsample: u64) -> i64 {
502    // Direct simulation: draw nsample items from population
503    let mut good_remaining = ngood;
504    let mut total_remaining = ngood + nbad;
505    let mut successes: i64 = 0;
506
507    for _ in 0..nsample {
508        if total_remaining == 0 {
509            break;
510        }
511        let u = bg.next_f64();
512        if u < (good_remaining as f64) / (total_remaining as f64) {
513            successes += 1;
514            good_remaining -= 1;
515        }
516        total_remaining -= 1;
517    }
518    successes
519}
520
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, IxDyn};

    /// Arithmetic mean of integer samples, accumulated in `f64`.
    fn sample_mean(xs: &[i64]) -> f64 {
        xs.iter().map(|&x| x as f64).sum::<f64>() / xs.len() as f64
    }

    // ---- poisson_array (#449) ------------------------------------------

    #[test]
    fn poisson_array_shape_matches_lam() {
        let mut rng = default_rng_seeded(42);
        let rates = vec![1.0, 5.0, 50.0, 0.0];
        let lam = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 2]), rates).unwrap();
        let out = rng.poisson_array(&lam).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
        // Last element is lam=0 → must be exactly 0.
        let s = out.as_slice().unwrap();
        assert_eq!(s[3], 0);
        for &v in s {
            assert!(v >= 0);
        }
    }

    #[test]
    fn poisson_array_per_element_mean() {
        let mut rng = default_rng_seeded(11);
        let lams = [3.0_f64, 50.0];
        let lam = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), lams.to_vec()).unwrap();
        let n_trials = 5_000;
        let mut sums = [0.0_f64; 2];
        for _ in 0..n_trials {
            let out = rng.poisson_array(&lam).unwrap();
            let s = out.as_slice().unwrap();
            for (acc, &x) in sums.iter_mut().zip(&s[..2]) {
                *acc += x as f64;
            }
        }
        for (j, (&total, &rate)) in sums.iter().zip(lams.iter()).enumerate() {
            let mean = total / n_trials as f64;
            // Poisson(λ) variance = λ; SE for n_trials draws = sqrt(λ/n).
            let se = (rate / n_trials as f64).sqrt();
            assert!(
                (mean - rate).abs() < 4.0 * se,
                "elt {j}: mean {mean} too far from {}",
                rate
            );
        }
    }

    #[test]
    fn poisson_array_negative_lam_errors() {
        let mut rng = default_rng_seeded(0);
        let lam = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, -2.0]).unwrap();
        assert!(rng.poisson_array(&lam).is_err());
    }

    // ---- poisson ---------------------------------------------------------

    #[test]
    fn poisson_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let lam = 5.0;
        let arr = rng.poisson(lam, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        // Poisson(lam): mean = lam, var = lam
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_large_lambda() {
        let mut rng = default_rng_seeded(42);
        let n = 50_000;
        let lam = 100.0;
        let arr = rng.poisson(lam, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_zero() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(0.0, 100).unwrap();
        for v in arr.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 0);
        }
    }

    // ---- binomial --------------------------------------------------------

    #[test]
    fn binomial_mean() {
        let mut rng = default_rng_seeded(42);
        let size = 100_000;
        let n = 20u64;
        let p = 0.3;
        let arr = rng.binomial(n, p, size).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        // Binomial(n, p): mean = n*p
        let expected_mean = n as f64 * p;
        let expected_var = n as f64 * p * (1.0 - p);
        let se = (expected_var / size as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "binomial mean {mean} too far from {expected_mean}"
        );
        // Values must be in [0, n]
        for &v in slice {
            assert!(
                v >= 0 && v <= n as i64,
                "binomial value {v} out of [0, {n}]"
            );
        }
    }

    #[test]
    fn binomial_edge_cases() {
        let mut rng = default_rng_seeded(42);
        // p=0: always 0
        let zeros = rng.binomial(10, 0.0, 100).unwrap();
        for v in zeros.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 0);
        }
        // p=1: always n
        let full = rng.binomial(10, 1.0, 100).unwrap();
        for v in full.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 10);
        }
    }

    // ---- negative binomial -----------------------------------------------

    #[test]
    fn negative_binomial_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.negative_binomial(5.0, 0.5, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 0, "negative_binomial value {v} must be >= 0");
        }
    }

    // ---- geometric -------------------------------------------------------

    #[test]
    fn geometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let p = 0.3;
        let arr = rng.geometric(p, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        // Geometric(p) (1-based): mean = 1/p
        let expected_mean = 1.0 / p;
        let expected_var = (1.0 - p) / (p * p);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "geometric mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(v >= 1, "geometric value {v} must be >= 1");
        }
    }

    // ---- hypergeometric --------------------------------------------------

    #[test]
    fn hypergeometric_range() {
        let mut rng = default_rng_seeded(42);
        let (ngood, nbad, nsample) = (20u64, 30u64, 10u64);
        let arr = rng.hypergeometric(ngood, nbad, nsample, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(
                v >= 0 && v <= nsample.min(ngood) as i64,
                "hypergeometric value {v} out of range"
            );
        }
    }

    #[test]
    fn hypergeometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let (ngood, nbad, nsample) = (20u64, 30u64, 10u64);
        let arr = rng.hypergeometric(ngood, nbad, nsample, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        // Hypergeometric: mean = nsample * ngood / (ngood + nbad)
        let total = (ngood + nbad) as f64;
        let expected_mean = nsample as f64 * ngood as f64 / total;
        let expected_var = nsample as f64
            * (ngood as f64 / total)
            * (nbad as f64 / total)
            * (total - nsample as f64)
            / (total - 1.0);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "hypergeometric mean {mean} too far from {expected_mean}"
        );
    }

    // ---- logseries -------------------------------------------------------

    #[test]
    fn logseries_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.logseries(0.5, 10_000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 1, "logseries value {v} must be >= 1");
        }
    }

    // ---- parameter validation --------------------------------------------

    #[test]
    fn bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.binomial(10, -0.1, 10).is_err());
        assert!(rng.binomial(10, 1.5, 10).is_err());
        assert!(rng.poisson(-1.0, 10).is_err());
        assert!(rng.geometric(0.0, 10).is_err());
        assert!(rng.geometric(1.5, 10).is_err());
        assert!(rng.hypergeometric(5, 5, 20, 10).is_err());
        assert!(rng.logseries(0.0, 10).is_err());
        assert!(rng.logseries(1.0, 10).is_err());
        assert!(rng.negative_binomial(0.0, 0.5, 10).is_err());
        assert!(rng.negative_binomial(5.0, 0.0, 10).is_err());
    }

    // ---- zipf ------------------------------------------------------------

    #[test]
    fn zipf_positive_integers() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.zipf(2.5, 1000).unwrap();
        for &v in arr.as_slice().unwrap() {
            assert!(v >= 1, "zipf output must be >= 1, got {v}");
        }
    }

    #[test]
    fn zipf_seed_reproducible() {
        let mut first = default_rng_seeded(7);
        let mut second = default_rng_seeded(7);
        let xs = first.zipf(3.0, 200).unwrap();
        let ys = second.zipf(3.0, 200).unwrap();
        assert_eq!(xs.as_slice().unwrap(), ys.as_slice().unwrap());
    }

    #[test]
    fn zipf_bad_a_errs() {
        let mut rng = default_rng_seeded(0);
        assert!(rng.zipf(1.0, 10).is_err());
        assert!(rng.zipf(0.5, 10).is_err());
        assert!(rng.zipf(-2.0, 10).is_err());
    }
}