Skip to main content

ferray_random/distributions/
discrete.rs

1// ferray-random: Discrete distributions
2//
3// binomial, negative_binomial, poisson, geometric, hypergeometric, logseries
4
5use ferray_core::{Array, FerrayError, IxDyn};
6
7use crate::bitgen::BitGenerator;
8use crate::distributions::gamma::standard_gamma_single;
9use crate::generator::{Generator, generate_vec_i64, shape_size, vec_to_array_i64};
10use crate::shape::IntoShape;
11
12/// Generate a single Poisson variate using Knuth's algorithm for small lambda,
13/// or the transformed rejection method (Hormann) for large lambda.
14fn poisson_single<B: BitGenerator>(bg: &mut B, lam: f64) -> i64 {
15    if lam < 30.0 {
16        // Knuth's algorithm
17        let l = (-lam).exp();
18        let mut k: i64 = 0;
19        let mut p = 1.0;
20        loop {
21            k += 1;
22            p *= bg.next_f64();
23            if p <= l {
24                return k - 1;
25            }
26        }
27    } else {
28        // Transformed rejection method (PA algorithm, Ahrens & Dieter)
29        let slam = lam.sqrt();
30        let loglam = lam.ln();
31        let b = 0.931 + 2.53 * slam;
32        let a = -0.059 + 0.02483 * b;
33        let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
34        let vr = 0.9277 - 3.6224 / (b - 2.0);
35
36        loop {
37            let u = bg.next_f64() - 0.5;
38            let v = bg.next_f64();
39            let us = 0.5 - u.abs();
40            let k = ((2.0 * a / us + b) * u + lam + 0.43).floor() as i64;
41            if k < 0 {
42                continue;
43            }
44            if us >= 0.07 && v <= vr {
45                return k;
46            }
47            if k > 0
48                && us >= 0.013
49                && v <= (k as f64)
50                    .ln()
51                    .mul_add(-0.5, (k as f64) * loglam - lam - ln_factorial(k as u64))
52                    .exp()
53                    * inv_alpha
54            {
55                return k;
56            }
57            if us < 0.013 && v > us {
58                continue;
59            }
60            // Full log test
61            let kf = k as f64;
62            let log_accept = -lam + kf * loglam - ln_factorial(k as u64);
63            if v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln() <= log_accept {
64                return k;
65            }
66        }
67    }
68}
69
/// Natural logarithm of `n!`.
///
/// For `n <= 20` the logarithms of `2..=n` are summed directly. For larger
/// `n` a Stirling series is used, truncated after the `1 / (360 n^3)` term.
fn ln_factorial(n: u64) -> f64 {
    if n > 20 {
        let x = n as f64;
        // 0.5 * ln(2*pi); TAU is 2*pi.
        let half_ln_two_pi = 0.5 * std::f64::consts::TAU.ln();
        let leading = (x + 0.5) * x.ln();
        let first_correction = 1.0 / (12.0 * x);
        let second_correction = 1.0 / (360.0 * x * x * x);
        return half_ln_two_pi + leading - x + first_correction - second_correction;
    }
    // Small arguments: accumulate ln(2) + ln(3) + ... + ln(n) in order.
    (2..=n).fold(0.0_f64, |acc, i| acc + (i as f64).ln())
}
86
87/// Generate a single binomial variate using the inverse transform for small n*p
88/// or the BTPE algorithm for larger n*p.
89fn binomial_single<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> i64 {
90    if n == 0 || p == 0.0 {
91        return 0;
92    }
93    if p == 1.0 {
94        return n as i64;
95    }
96
97    // Use the smaller of p and 1-p for efficiency
98    let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
99
100    let np = n as f64 * pp;
101
102    let result = if np < 30.0 {
103        // Inverse transform (waiting time) method
104        let q = 1.0 - pp;
105        let s = pp / q;
106        let a = (n as f64 + 1.0) * s;
107        let mut r = q.powf(n as f64);
108        let mut u = bg.next_f64();
109        let mut x: i64 = 0;
110        while u > r {
111            u -= r;
112            x += 1;
113            r *= a / (x as f64) - s;
114            if r < 0.0 {
115                break;
116            }
117        }
118        x.min(n as i64)
119    } else {
120        // BTPE algorithm (Hormann 1993) for large n*p.
121        // Based on the transformed rejection method with decomposition
122        // into triangular, parallelogram, and exponential regions.
123        let q = 1.0 - pp;
124        let fm = np + pp;
125        let m = fm.floor() as i64;
126        let mf = m as f64;
127        let p1 = (2.195 * (np * q).sqrt() - 4.6 * q).floor() + 0.5;
128        let xm = mf + 0.5;
129        let xl = xm - p1;
130        let xr = xm + p1;
131        let c = 0.134 + 20.5 / (15.3 + mf);
132        let a = (fm - xl) / (fm - xl * pp);
133        let lambda_l = a * (1.0 + 0.5 * a);
134        let a2 = (xr - fm) / (xr * q);
135        let lambda_r = a2 * (1.0 + 0.5 * a2);
136        let p2 = p1 * (1.0 + 2.0 * c);
137        let p3 = p2 + c / lambda_l;
138        let p4 = p3 + c / lambda_r;
139
140        loop {
141            let u = bg.next_f64() * p4;
142            let v = bg.next_f64();
143            let y: i64;
144
145            if u <= p1 {
146                // Triangular region
147                y = (xm - p1 * v + u).floor() as i64;
148            } else if u <= p2 {
149                // Parallelogram region
150                let x = xl + (u - p1) / c;
151                let w = v + (x - xm) * (x - xm) / (p1 * p1);
152                if w > 1.0 {
153                    continue;
154                }
155                y = x.floor() as i64;
156            } else if u <= p3 {
157                // Left exponential tail
158                y = (xl + v.ln() / lambda_l).floor() as i64;
159                if y < 0 {
160                    continue;
161                }
162            } else {
163                // Right exponential tail
164                y = (xr - v.ln() / lambda_r).floor() as i64;
165                if y > n as i64 {
166                    continue;
167                }
168            }
169
170            // Squeeze acceptance
171            let k = (y - m).abs();
172            if k <= 20 || k as f64 >= 0.5 * np * q - 1.0 {
173                // Full acceptance/rejection via log-factorial comparison
174                let kf = k as f64;
175                let yf = y as f64;
176                let rho =
177                    (kf / (np * q)) * ((kf * (kf / 3.0 + 0.625) + 1.0 / 6.0) / (np * q) + 0.5);
178                let t = -kf * kf / (2.0 * np * q);
179                let log_a = t - rho;
180                if v.ln() <= log_a {
181                    break y;
182                }
183                // Full log-factorial test
184                let log_v = v.ln();
185                let log_accept =
186                    ln_factorial(m as u64) - ln_factorial(y as u64) - ln_factorial(n - y as u64)
187                        + ln_factorial(n - m as u64)
188                        + (yf - mf) * (pp / q).ln();
189                if log_v <= log_accept {
190                    break y;
191                }
192            } else {
193                break y;
194            }
195        }
196    };
197
198    if flipped { n as i64 - result } else { result }
199}
200
201impl<B: BitGenerator> Generator<B> {
202    /// Generate an array of binomial-distributed variates.
203    ///
204    /// Each value is the number of successes in `n` Bernoulli trials
205    /// with success probability `p`.
206    ///
207    /// # Arguments
208    /// * `n` - Number of trials.
209    /// * `p` - Probability of success per trial, must be in [0, 1].
210    /// * `size` - Number of values to generate.
211    ///
212    /// # Errors
213    /// Returns `FerrayError::InvalidValue` for invalid parameters.
214    pub fn binomial(
215        &mut self,
216        n: u64,
217        p: f64,
218        size: impl IntoShape,
219    ) -> Result<Array<i64, IxDyn>, FerrayError> {
220        if !(0.0..=1.0).contains(&p) {
221            return Err(FerrayError::invalid_value(format!(
222                "p must be in [0, 1], got {p}"
223            )));
224        }
225        let shape_vec = size.into_shape()?;
226        let total = shape_size(&shape_vec);
227        let data = generate_vec_i64(self, total, |bg| binomial_single(bg, n, p));
228        vec_to_array_i64(data, &shape_vec)
229    }
230
231    /// Generate an array of negative binomial distributed variates.
232    ///
233    /// The number of failures before `n` successes with success probability `p`.
234    /// Uses the gamma-Poisson mixture.
235    ///
236    /// # Arguments
237    /// * `n` - Number of successes (positive).
238    /// * `p` - Probability of success, must be in (0, 1].
239    /// * `size` - Number of values to generate.
240    ///
241    /// # Errors
242    /// Returns `FerrayError::InvalidValue` for invalid parameters.
243    pub fn negative_binomial(
244        &mut self,
245        n: f64,
246        p: f64,
247        size: impl IntoShape,
248    ) -> Result<Array<i64, IxDyn>, FerrayError> {
249        if n <= 0.0 {
250            return Err(FerrayError::invalid_value(format!(
251                "n must be positive, got {n}"
252            )));
253        }
254        if p <= 0.0 || p > 1.0 {
255            return Err(FerrayError::invalid_value(format!(
256                "p must be in (0, 1], got {p}"
257            )));
258        }
259        let shape_vec = size.into_shape()?;
260        let total = shape_size(&shape_vec);
261        let data = generate_vec_i64(self, total, |bg| {
262            // Gamma-Poisson mixture:
263            // Y ~ Gamma(n, (1-p)/p), then X ~ Poisson(Y)
264            let y = standard_gamma_single(bg, n) * (1.0 - p) / p;
265            poisson_single(bg, y)
266        });
267        vec_to_array_i64(data, &shape_vec)
268    }
269
270    /// Generate an array of Poisson-distributed variates.
271    ///
272    /// # Arguments
273    /// * `lam` - Expected number of events (lambda), must be non-negative.
274    /// * `size` - Number of values to generate.
275    ///
276    /// # Errors
277    /// Returns `FerrayError::InvalidValue` if `lam < 0` or `size` is zero.
278    pub fn poisson(
279        &mut self,
280        lam: f64,
281        size: impl IntoShape,
282    ) -> Result<Array<i64, IxDyn>, FerrayError> {
283        if lam < 0.0 {
284            return Err(FerrayError::invalid_value(format!(
285                "lam must be non-negative, got {lam}"
286            )));
287        }
288        let shape_vec = size.into_shape()?;
289        let total = shape_size(&shape_vec);
290        if lam == 0.0 {
291            let data = vec![0i64; total];
292            return vec_to_array_i64(data, &shape_vec);
293        }
294        let data = generate_vec_i64(self, total, |bg| poisson_single(bg, lam));
295        vec_to_array_i64(data, &shape_vec)
296    }
297
298    /// Generate an array of geometric-distributed variates.
299    ///
300    /// The number of trials until the first success (1-based).
301    ///
302    /// # Arguments
303    /// * `p` - Probability of success, must be in (0, 1].
304    /// * `size` - Number of values to generate.
305    ///
306    /// # Errors
307    /// Returns `FerrayError::InvalidValue` if `p` not in (0, 1] or `size` is zero.
308    pub fn geometric(
309        &mut self,
310        p: f64,
311        size: impl IntoShape,
312    ) -> Result<Array<i64, IxDyn>, FerrayError> {
313        if p <= 0.0 || p > 1.0 {
314            return Err(FerrayError::invalid_value(format!(
315                "p must be in (0, 1], got {p}"
316            )));
317        }
318        let shape_vec = size.into_shape()?;
319        let total = shape_size(&shape_vec);
320        if (p - 1.0).abs() < f64::EPSILON {
321            let data = vec![1i64; total];
322            return vec_to_array_i64(data, &shape_vec);
323        }
324        let log_q = (1.0 - p).ln();
325        let data = generate_vec_i64(self, total, |bg| {
326            loop {
327                let u = bg.next_f64();
328                if u > f64::EPSILON {
329                    return (u.ln() / log_q).floor() as i64 + 1;
330                }
331            }
332        });
333        vec_to_array_i64(data, &shape_vec)
334    }
335
336    /// Generate an array of hypergeometric-distributed variates.
337    ///
338    /// Models drawing `nsample` items without replacement from a population
339    /// containing `ngood` success states and `nbad` failure states.
340    ///
341    /// # Arguments
342    /// * `ngood` - Number of success states in the population.
343    /// * `nbad` - Number of failure states in the population.
344    /// * `nsample` - Number of items drawn.
345    /// * `size` - Number of values to generate.
346    ///
347    /// # Errors
348    /// Returns `FerrayError::InvalidValue` if `nsample > ngood + nbad` or `size` is zero.
349    pub fn hypergeometric(
350        &mut self,
351        ngood: u64,
352        nbad: u64,
353        nsample: u64,
354        size: impl IntoShape,
355    ) -> Result<Array<i64, IxDyn>, FerrayError> {
356        let total = ngood + nbad;
357        if nsample > total {
358            return Err(FerrayError::invalid_value(format!(
359                "nsample ({nsample}) > ngood + nbad ({total})"
360            )));
361        }
362        let shape_vec = size.into_shape()?;
363        let total_n = shape_size(&shape_vec);
364        let data = generate_vec_i64(self, total_n, |bg| {
365            hypergeometric_single(bg, ngood, nbad, nsample)
366        });
367        vec_to_array_i64(data, &shape_vec)
368    }
369
370    /// Generate an array of logarithmic series distributed variates.
371    ///
372    /// # Arguments
373    /// * `p` - Shape parameter, must be in (0, 1).
374    /// * `size` - Number of values to generate.
375    ///
376    /// # Errors
377    /// Returns `FerrayError::InvalidValue` if `p` not in (0, 1) or `size` is zero.
378    pub fn logseries(
379        &mut self,
380        p: f64,
381        size: impl IntoShape,
382    ) -> Result<Array<i64, IxDyn>, FerrayError> {
383        if p <= 0.0 || p >= 1.0 {
384            return Err(FerrayError::invalid_value(format!(
385                "p must be in (0, 1), got {p}"
386            )));
387        }
388        let r = (-(-p).ln_1p()).recip();
389        let shape_vec = size.into_shape()?;
390        let total = shape_size(&shape_vec);
391        let data = generate_vec_i64(self, total, |bg| {
392            // Kemp's "second" algorithm for the logarithmic distribution.
393            // See Devroye, "Non-Uniform Random Variate Generation", p. 548.
394            loop {
395                let u = bg.next_f64();
396                if u <= f64::EPSILON || u >= 1.0 - f64::EPSILON {
397                    continue;
398                }
399                let v = bg.next_f64();
400                let q = 1.0 - (-r.recip() * u.ln()).exp();
401                if q <= 0.0 {
402                    return 1;
403                }
404                if v < q * q {
405                    let k = (1.0 + v.ln() / q.ln()).floor() as i64;
406                    return k.max(1);
407                }
408                if v < q {
409                    return 2;
410                }
411                return 1;
412            }
413        });
414        vec_to_array_i64(data, &shape_vec)
415    }
416}
417
418/// Generate a single hypergeometric variate using the direct algorithm.
419fn hypergeometric_single<B: BitGenerator>(bg: &mut B, ngood: u64, nbad: u64, nsample: u64) -> i64 {
420    // Direct simulation: draw nsample items from population
421    let mut good_remaining = ngood;
422    let mut total_remaining = ngood + nbad;
423    let mut successes: i64 = 0;
424
425    for _ in 0..nsample {
426        if total_remaining == 0 {
427            break;
428        }
429        let u = bg.next_f64();
430        if u < (good_remaining as f64) / (total_remaining as f64) {
431            successes += 1;
432            good_remaining -= 1;
433        }
434        total_remaining -= 1;
435    }
436    successes
437}
438
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Sum the samples as `f64` and divide by `count`.
    fn mean_of(samples: &[i64], count: f64) -> f64 {
        samples.iter().map(|&x| x as f64).sum::<f64>() / count
    }

    #[test]
    fn poisson_mean() {
        let lam = 5.0;
        let n = 100_000;
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(lam, n).unwrap();
        let mean = mean_of(arr.as_slice().unwrap(), n as f64);
        // Poisson(lam) has mean and variance lam, so the standard error of
        // the sample mean is sqrt(lam / n).
        let se = (lam / n as f64).sqrt();
        let deviation = (mean - lam).abs();
        assert!(
            deviation < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_large_lambda() {
        // lam = 100 exercises the large-mean branch of the sampler.
        let lam = 100.0;
        let n = 50_000;
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(lam, n).unwrap();
        let mean = mean_of(arr.as_slice().unwrap(), n as f64);
        let se = (lam / n as f64).sqrt();
        let deviation = (mean - lam).abs();
        assert!(
            deviation < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_zero() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(0.0, 100).unwrap();
        for v in arr.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 0);
        }
    }

    #[test]
    fn binomial_mean() {
        let size = 100_000;
        let n = 20u64;
        let p = 0.3;
        let mut rng = default_rng_seeded(42);
        let arr = rng.binomial(n, p, size).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = mean_of(slice, size as f64);
        // Binomial(n, p): mean = n*p, variance = n*p*(1-p).
        let expected_mean = n as f64 * p;
        let expected_var = n as f64 * p * (1.0 - p);
        let se = (expected_var / size as f64).sqrt();
        let deviation = (mean - expected_mean).abs();
        assert!(
            deviation < 3.0 * se,
            "binomial mean {mean} too far from {expected_mean}"
        );
        // Every draw must lie in [0, n].
        for v in slice.iter().copied() {
            assert!(
                v >= 0 && v <= n as i64,
                "binomial value {v} out of [0, {n}]"
            );
        }
    }

    #[test]
    fn binomial_edge_cases() {
        let mut rng = default_rng_seeded(42);
        // p = 0 can never succeed.
        let never = rng.binomial(10, 0.0, 100).unwrap();
        for v in never.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 0);
        }
        // p = 1 succeeds on every trial.
        let always = rng.binomial(10, 1.0, 100).unwrap();
        for v in always.as_slice().unwrap().iter().copied() {
            assert_eq!(v, 10);
        }
    }

    #[test]
    fn negative_binomial_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.negative_binomial(5.0, 0.5, 10_000).unwrap();
        for v in arr.as_slice().unwrap().iter().copied() {
            assert!(v >= 0, "negative_binomial value {v} must be >= 0");
        }
    }

    #[test]
    fn geometric_mean() {
        let n = 100_000;
        let p = 0.3;
        let mut rng = default_rng_seeded(42);
        let arr = rng.geometric(p, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = mean_of(slice, n as f64);
        // Geometric(p), 1-based: mean = 1/p, variance = (1-p)/p^2.
        let expected_mean = 1.0 / p;
        let expected_var = (1.0 - p) / (p * p);
        let se = (expected_var / n as f64).sqrt();
        let deviation = (mean - expected_mean).abs();
        assert!(
            deviation < 3.0 * se,
            "geometric mean {mean} too far from {expected_mean}"
        );
        for v in slice.iter().copied() {
            assert!(v >= 1, "geometric value {v} must be >= 1");
        }
    }

    #[test]
    fn hypergeometric_range() {
        let ngood = 20u64;
        let nbad = 30u64;
        let nsample = 10u64;
        let mut rng = default_rng_seeded(42);
        let arr = rng.hypergeometric(ngood, nbad, nsample, 10_000).unwrap();
        // Successes can exceed neither the draws nor the good items available.
        for v in arr.as_slice().unwrap().iter().copied() {
            assert!(
                v >= 0 && v <= nsample.min(ngood) as i64,
                "hypergeometric value {v} out of range"
            );
        }
    }

    #[test]
    fn hypergeometric_mean() {
        let n = 100_000;
        let ngood = 20u64;
        let nbad = 30u64;
        let nsample = 10u64;
        let mut rng = default_rng_seeded(42);
        let arr = rng.hypergeometric(ngood, nbad, nsample, n).unwrap();
        let mean = mean_of(arr.as_slice().unwrap(), n as f64);
        // Hypergeometric: mean = nsample * ngood / (ngood + nbad); the
        // variance carries the finite-population correction factor.
        let total = (ngood + nbad) as f64;
        let expected_mean = nsample as f64 * ngood as f64 / total;
        let expected_var = nsample as f64
            * (ngood as f64 / total)
            * (nbad as f64 / total)
            * (total - nsample as f64)
            / (total - 1.0);
        let se = (expected_var / n as f64).sqrt();
        let deviation = (mean - expected_mean).abs();
        assert!(
            deviation < 3.0 * se,
            "hypergeometric mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn logseries_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.logseries(0.5, 10_000).unwrap();
        for v in arr.as_slice().unwrap().iter().copied() {
            assert!(v >= 1, "logseries value {v} must be >= 1");
        }
    }

    #[test]
    fn bad_params() {
        let mut rng = default_rng_seeded(42);
        // binomial: p outside [0, 1]
        assert!(rng.binomial(10, -0.1, 10).is_err());
        assert!(rng.binomial(10, 1.5, 10).is_err());
        // poisson: negative mean
        assert!(rng.poisson(-1.0, 10).is_err());
        // geometric: p outside (0, 1]
        assert!(rng.geometric(0.0, 10).is_err());
        assert!(rng.geometric(1.5, 10).is_err());
        // hypergeometric: more draws than population
        assert!(rng.hypergeometric(5, 5, 20, 10).is_err());
        // logseries: p outside (0, 1)
        assert!(rng.logseries(0.0, 10).is_err());
        assert!(rng.logseries(1.0, 10).is_err());
        // negative_binomial: non-positive n, non-positive p
        assert!(rng.negative_binomial(0.0, 0.5, 10).is_err());
        assert!(rng.negative_binomial(5.0, 0.0, 10).is_err());
    }
}