// ferray_random/distributions/discrete.rs
1// ferray-random: Discrete distributions
2//
3// binomial, negative_binomial, poisson, geometric, hypergeometric, logseries
4
5use ferray_core::{Array, FerrayError, Ix1};
6
7use crate::bitgen::BitGenerator;
8use crate::distributions::gamma::standard_gamma_single;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec_i64, vec_to_array1_i64};
11
12/// Generate a single Poisson variate using Knuth's algorithm for small lambda,
13/// or the transformed rejection method (Hormann) for large lambda.
14fn poisson_single<B: BitGenerator>(bg: &mut B, lam: f64) -> i64 {
15    if lam < 30.0 {
16        // Knuth's algorithm
17        let l = (-lam).exp();
18        let mut k: i64 = 0;
19        let mut p = 1.0;
20        loop {
21            k += 1;
22            p *= bg.next_f64();
23            if p <= l {
24                return k - 1;
25            }
26        }
27    } else {
28        // Transformed rejection method (PA algorithm, Ahrens & Dieter)
29        let slam = lam.sqrt();
30        let loglam = lam.ln();
31        let b = 0.931 + 2.53 * slam;
32        let a = -0.059 + 0.02483 * b;
33        let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
34        let vr = 0.9277 - 3.6224 / (b - 2.0);
35
36        loop {
37            let u = bg.next_f64() - 0.5;
38            let v = bg.next_f64();
39            let us = 0.5 - u.abs();
40            let k = ((2.0 * a / us + b) * u + lam + 0.43).floor() as i64;
41            if k < 0 {
42                continue;
43            }
44            if us >= 0.07 && v <= vr {
45                return k;
46            }
47            if k > 0
48                && us >= 0.013
49                && v <= (k as f64)
50                    .ln()
51                    .mul_add(-0.5, (k as f64) * loglam - lam - ln_factorial(k as u64))
52                    .exp()
53                    * inv_alpha
54            {
55                // Removed the fast path — fall through to full test
56            }
57            if us < 0.013 && v > us {
58                continue;
59            }
60            // Full log test
61            let kf = k as f64;
62            let log_accept = -lam + kf * loglam - ln_factorial(k as u64);
63            if v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln() <= log_accept {
64                return k;
65            }
66        }
67    }
68}
69
70/// Approximate ln(n!) using Stirling's approximation with correction terms.
/// Natural log of `n!`.
///
/// Exact summation of logs for `n <= 20`; Stirling's series with two
/// correction terms beyond that (error well below 1e-9 for n > 20).
fn ln_factorial(n: u64) -> f64 {
    if n <= 20 {
        // ln 2 + ln 3 + ... + ln n; the range is empty for n < 2, giving 0.
        (2..=n).map(|i| (i as f64).ln()).sum()
    } else {
        // Stirling: ln n! ~ (n + 1/2) ln n - n + ln(2*pi)/2 + 1/(12n) - 1/(360n^3)
        let nf = n as f64;
        0.5 * (std::f64::consts::TAU).ln() + (nf + 0.5) * nf.ln() - nf + 1.0 / (12.0 * nf)
            - 1.0 / (360.0 * nf * nf * nf)
    }
}
86
87/// Generate a single binomial variate using the inverse transform for small n*p
88/// or the BTPE algorithm for larger n*p.
89fn binomial_single<B: BitGenerator>(bg: &mut B, n: u64, p: f64) -> i64 {
90    if n == 0 || p == 0.0 {
91        return 0;
92    }
93    if p == 1.0 {
94        return n as i64;
95    }
96
97    // Use the smaller of p and 1-p for efficiency
98    let (pp, flipped) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
99
100    let np = n as f64 * pp;
101
102    let result = if np < 30.0 {
103        // Inverse transform (waiting time) method
104        let q = 1.0 - pp;
105        let s = pp / q;
106        let a = (n as f64 + 1.0) * s;
107        let mut r = q.powi(n as i32);
108        let mut u = bg.next_f64();
109        let mut x: i64 = 0;
110        while u > r {
111            u -= r;
112            x += 1;
113            r *= a / (x as f64) - s;
114            if r < 0.0 {
115                break;
116            }
117        }
118        x.min(n as i64)
119    } else {
120        // For large n*p, use normal approximation with correction
121        // (BTPE is complex; this is a simpler but adequate approach)
122        loop {
123            let z = standard_normal_single(bg);
124            let sigma = (np * (1.0 - pp)).sqrt();
125            let x = (np + sigma * z + 0.5).floor() as i64;
126            if x >= 0 && x <= n as i64 {
127                break x;
128            }
129        }
130    };
131
132    if flipped { n as i64 - result } else { result }
133}
134
135impl<B: BitGenerator> Generator<B> {
136    /// Generate an array of binomial-distributed variates.
137    ///
138    /// Each value is the number of successes in `n` Bernoulli trials
139    /// with success probability `p`.
140    ///
141    /// # Arguments
142    /// * `n` - Number of trials.
143    /// * `p` - Probability of success per trial, must be in [0, 1].
144    /// * `size` - Number of values to generate.
145    ///
146    /// # Errors
147    /// Returns `FerrayError::InvalidValue` for invalid parameters.
148    pub fn binomial(
149        &mut self,
150        n: u64,
151        p: f64,
152        size: usize,
153    ) -> Result<Array<i64, Ix1>, FerrayError> {
154        if size == 0 {
155            return Err(FerrayError::invalid_value("size must be > 0"));
156        }
157        if !(0.0..=1.0).contains(&p) {
158            return Err(FerrayError::invalid_value(format!(
159                "p must be in [0, 1], got {p}"
160            )));
161        }
162        let data = generate_vec_i64(self, size, |bg| binomial_single(bg, n, p));
163        vec_to_array1_i64(data)
164    }
165
166    /// Generate an array of negative binomial distributed variates.
167    ///
168    /// The number of failures before `n` successes with success probability `p`.
169    /// Uses the gamma-Poisson mixture.
170    ///
171    /// # Arguments
172    /// * `n` - Number of successes (positive).
173    /// * `p` - Probability of success, must be in (0, 1].
174    /// * `size` - Number of values to generate.
175    ///
176    /// # Errors
177    /// Returns `FerrayError::InvalidValue` for invalid parameters.
178    pub fn negative_binomial(
179        &mut self,
180        n: f64,
181        p: f64,
182        size: usize,
183    ) -> Result<Array<i64, Ix1>, FerrayError> {
184        if size == 0 {
185            return Err(FerrayError::invalid_value("size must be > 0"));
186        }
187        if n <= 0.0 {
188            return Err(FerrayError::invalid_value(format!(
189                "n must be positive, got {n}"
190            )));
191        }
192        if p <= 0.0 || p > 1.0 {
193            return Err(FerrayError::invalid_value(format!(
194                "p must be in (0, 1], got {p}"
195            )));
196        }
197        let data = generate_vec_i64(self, size, |bg| {
198            // Gamma-Poisson mixture:
199            // Y ~ Gamma(n, (1-p)/p), then X ~ Poisson(Y)
200            let y = standard_gamma_single(bg, n) * (1.0 - p) / p;
201            poisson_single(bg, y)
202        });
203        vec_to_array1_i64(data)
204    }
205
206    /// Generate an array of Poisson-distributed variates.
207    ///
208    /// # Arguments
209    /// * `lam` - Expected number of events (lambda), must be non-negative.
210    /// * `size` - Number of values to generate.
211    ///
212    /// # Errors
213    /// Returns `FerrayError::InvalidValue` if `lam < 0` or `size` is zero.
214    pub fn poisson(&mut self, lam: f64, size: usize) -> Result<Array<i64, Ix1>, FerrayError> {
215        if size == 0 {
216            return Err(FerrayError::invalid_value("size must be > 0"));
217        }
218        if lam < 0.0 {
219            return Err(FerrayError::invalid_value(format!(
220                "lam must be non-negative, got {lam}"
221            )));
222        }
223        if lam == 0.0 {
224            let data = vec![0i64; size];
225            return vec_to_array1_i64(data);
226        }
227        let data = generate_vec_i64(self, size, |bg| poisson_single(bg, lam));
228        vec_to_array1_i64(data)
229    }
230
231    /// Generate an array of geometric-distributed variates.
232    ///
233    /// The number of trials until the first success (1-based).
234    ///
235    /// # Arguments
236    /// * `p` - Probability of success, must be in (0, 1].
237    /// * `size` - Number of values to generate.
238    ///
239    /// # Errors
240    /// Returns `FerrayError::InvalidValue` if `p` not in (0, 1] or `size` is zero.
241    pub fn geometric(&mut self, p: f64, size: usize) -> Result<Array<i64, Ix1>, FerrayError> {
242        if size == 0 {
243            return Err(FerrayError::invalid_value("size must be > 0"));
244        }
245        if p <= 0.0 || p > 1.0 {
246            return Err(FerrayError::invalid_value(format!(
247                "p must be in (0, 1], got {p}"
248            )));
249        }
250        if (p - 1.0).abs() < f64::EPSILON {
251            let data = vec![1i64; size];
252            return vec_to_array1_i64(data);
253        }
254        let log_q = (1.0 - p).ln();
255        let data = generate_vec_i64(self, size, |bg| {
256            loop {
257                let u = bg.next_f64();
258                if u > f64::EPSILON {
259                    return (u.ln() / log_q).floor() as i64 + 1;
260                }
261            }
262        });
263        vec_to_array1_i64(data)
264    }
265
266    /// Generate an array of hypergeometric-distributed variates.
267    ///
268    /// Models drawing `nsample` items without replacement from a population
269    /// containing `ngood` success states and `nbad` failure states.
270    ///
271    /// # Arguments
272    /// * `ngood` - Number of success states in the population.
273    /// * `nbad` - Number of failure states in the population.
274    /// * `nsample` - Number of items drawn.
275    /// * `size` - Number of values to generate.
276    ///
277    /// # Errors
278    /// Returns `FerrayError::InvalidValue` if `nsample > ngood + nbad` or `size` is zero.
279    pub fn hypergeometric(
280        &mut self,
281        ngood: u64,
282        nbad: u64,
283        nsample: u64,
284        size: usize,
285    ) -> Result<Array<i64, Ix1>, FerrayError> {
286        if size == 0 {
287            return Err(FerrayError::invalid_value("size must be > 0"));
288        }
289        let total = ngood + nbad;
290        if nsample > total {
291            return Err(FerrayError::invalid_value(format!(
292                "nsample ({nsample}) > ngood + nbad ({total})"
293            )));
294        }
295        let data = generate_vec_i64(self, size, |bg| {
296            hypergeometric_single(bg, ngood, nbad, nsample)
297        });
298        vec_to_array1_i64(data)
299    }
300
301    /// Generate an array of logarithmic series distributed variates.
302    ///
303    /// # Arguments
304    /// * `p` - Shape parameter, must be in (0, 1).
305    /// * `size` - Number of values to generate.
306    ///
307    /// # Errors
308    /// Returns `FerrayError::InvalidValue` if `p` not in (0, 1) or `size` is zero.
309    pub fn logseries(&mut self, p: f64, size: usize) -> Result<Array<i64, Ix1>, FerrayError> {
310        if size == 0 {
311            return Err(FerrayError::invalid_value("size must be > 0"));
312        }
313        if p <= 0.0 || p >= 1.0 {
314            return Err(FerrayError::invalid_value(format!(
315                "p must be in (0, 1), got {p}"
316            )));
317        }
318        let r = (-(-p).ln_1p()).recip();
319        let data = generate_vec_i64(self, size, |bg| {
320            // Kemp's "second" algorithm for the logarithmic distribution.
321            // See Devroye, "Non-Uniform Random Variate Generation", p. 548.
322            loop {
323                let u = bg.next_f64();
324                if u <= f64::EPSILON || u >= 1.0 - f64::EPSILON {
325                    continue;
326                }
327                let v = bg.next_f64();
328                let q = 1.0 - (-r.recip() * u.ln()).exp();
329                if q <= 0.0 {
330                    return 1;
331                }
332                if v < q * q {
333                    let k = (1.0 + v.ln() / q.ln()).floor() as i64;
334                    return k.max(1);
335                }
336                if v < q {
337                    return 2;
338                }
339                return 1;
340            }
341        });
342        vec_to_array1_i64(data)
343    }
344}
345
346/// Generate a single hypergeometric variate using the direct algorithm.
347fn hypergeometric_single<B: BitGenerator>(bg: &mut B, ngood: u64, nbad: u64, nsample: u64) -> i64 {
348    // Direct simulation: draw nsample items from population
349    let mut good_remaining = ngood;
350    let mut total_remaining = ngood + nbad;
351    let mut successes: i64 = 0;
352
353    for _ in 0..nsample {
354        if total_remaining == 0 {
355            break;
356        }
357        let u = bg.next_f64();
358        if u < (good_remaining as f64) / (total_remaining as f64) {
359            successes += 1;
360            good_remaining -= 1;
361        }
362        total_remaining -= 1;
363    }
364    successes
365}
366
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of an integer sample, as f64.
    fn sample_mean(xs: &[i64]) -> f64 {
        xs.iter().map(|&x| x as f64).sum::<f64>() / xs.len() as f64
    }

    #[test]
    fn poisson_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let lam = 5.0;
        let arr = rng.poisson(lam, n).unwrap();
        // Poisson(lam): mean = lam, var = lam
        let mean = sample_mean(arr.as_slice().unwrap());
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_large_lambda() {
        let mut rng = default_rng_seeded(42);
        let n = 50_000;
        let lam = 100.0;
        let arr = rng.poisson(lam, n).unwrap();
        // Exercises the large-lambda (transformed rejection) code path.
        let mean = sample_mean(arr.as_slice().unwrap());
        let se = (lam / n as f64).sqrt();
        assert!(
            (mean - lam).abs() < 3.0 * se,
            "poisson mean {mean} too far from {lam}"
        );
    }

    #[test]
    fn poisson_zero() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.poisson(0.0, 100).unwrap();
        // lam = 0 must produce the constant 0.
        assert!(arr.as_slice().unwrap().iter().all(|&v| v == 0));
    }

    #[test]
    fn binomial_mean() {
        let mut rng = default_rng_seeded(42);
        let size = 100_000;
        let n = 20u64;
        let p = 0.3;
        let arr = rng.binomial(n, p, size).unwrap();
        let slice = arr.as_slice().unwrap();
        // Binomial(n, p): mean = n*p
        let expected_mean = n as f64 * p;
        let expected_var = n as f64 * p * (1.0 - p);
        let se = (expected_var / size as f64).sqrt();
        let mean = sample_mean(slice);
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "binomial mean {mean} too far from {expected_mean}"
        );
        // Values must be in [0, n]
        for &v in slice {
            assert!(
                (0..=n as i64).contains(&v),
                "binomial value {v} out of [0, {n}]"
            );
        }
    }

    #[test]
    fn binomial_edge_cases() {
        let mut rng = default_rng_seeded(42);
        // p=0: always 0
        let zeros = rng.binomial(10, 0.0, 100).unwrap();
        assert!(zeros.as_slice().unwrap().iter().all(|&v| v == 0));
        // p=1: always n
        let tens = rng.binomial(10, 1.0, 100).unwrap();
        assert!(tens.as_slice().unwrap().iter().all(|&v| v == 10));
    }

    #[test]
    fn negative_binomial_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.negative_binomial(5.0, 0.5, 10_000).unwrap();
        arr.as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!(v >= 0, "negative_binomial value {v} must be >= 0"));
    }

    #[test]
    fn geometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let p = 0.3;
        let arr = rng.geometric(p, n).unwrap();
        let slice = arr.as_slice().unwrap();
        // Geometric(p) (1-based): mean = 1/p
        let expected_mean = 1.0 / p;
        let expected_var = (1.0 - p) / (p * p);
        let se = (expected_var / n as f64).sqrt();
        let mean = sample_mean(slice);
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "geometric mean {mean} too far from {expected_mean}"
        );
        for &v in slice {
            assert!(v >= 1, "geometric value {v} must be >= 1");
        }
    }

    #[test]
    fn hypergeometric_range() {
        let mut rng = default_rng_seeded(42);
        let (ngood, nbad, nsample) = (20u64, 30u64, 10u64);
        let arr = rng.hypergeometric(ngood, nbad, nsample, 10_000).unwrap();
        // Support is [max(0, nsample - nbad), min(nsample, ngood)].
        let upper = nsample.min(ngood) as i64;
        for &v in arr.as_slice().unwrap() {
            assert!(
                (0..=upper).contains(&v),
                "hypergeometric value {v} out of range"
            );
        }
    }

    #[test]
    fn hypergeometric_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let (ngood, nbad, nsample) = (20u64, 30u64, 10u64);
        let arr = rng.hypergeometric(ngood, nbad, nsample, n).unwrap();
        let mean = sample_mean(arr.as_slice().unwrap());
        // Hypergeometric: mean = nsample * ngood / (ngood + nbad)
        let total = (ngood + nbad) as f64;
        let expected_mean = nsample as f64 * ngood as f64 / total;
        let expected_var = nsample as f64
            * (ngood as f64 / total)
            * (nbad as f64 / total)
            * (total - nsample as f64)
            / (total - 1.0);
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "hypergeometric mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn logseries_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.logseries(0.5, 10_000).unwrap();
        arr.as_slice()
            .unwrap()
            .iter()
            .for_each(|&v| assert!(v >= 1, "logseries value {v} must be >= 1"));
    }

    #[test]
    fn bad_params() {
        let mut rng = default_rng_seeded(42);
        // Each call must fail validation before generating anything.
        assert!(rng.binomial(10, -0.1, 10).is_err());
        assert!(rng.binomial(10, 1.5, 10).is_err());
        assert!(rng.poisson(-1.0, 10).is_err());
        assert!(rng.geometric(0.0, 10).is_err());
        assert!(rng.geometric(1.5, 10).is_err());
        assert!(rng.hypergeometric(5, 5, 20, 10).is_err());
        assert!(rng.logseries(0.0, 10).is_err());
        assert!(rng.logseries(1.0, 10).is_err());
        assert!(rng.negative_binomial(0.0, 0.5, 10).is_err());
        assert!(rng.negative_binomial(5.0, 0.0, 10).is_err());
    }
}
546}