Skip to main content

rs_stats/distributions/
binomial_distribution.rs

1//! # Binomial Distribution
2//!
3//! This module implements the Binomial distribution, a discrete probability distribution
4//! that models the number of successes in a sequence of independent experiments.
5//!
6//! ## Key Characteristics
7//! - Models the number of successes in `n` independent trials
8//! - Each trial has success probability `p`
9//! - Discrete probability distribution
10//!
11//! ## Common Applications
12//! - Quality control testing
13//! - A/B testing
14//! - Risk analysis
15//! - Genetics (Mendelian inheritance)
16//!
17//! ## Mathematical Formulation
18//! The probability mass function (PMF) is given by:
19//!
20//! P(X = k) = C(n,k) * p^k * (1-p)^(n-k)
21//!
22//! where:
23//! - n is the number of trials
24//! - k is the number of successes
25//! - p is the probability of success
26//! - C(n,k) is the binomial coefficient (n choose k)
27
28use crate::error::{StatsError, StatsResult};
29use crate::utils::special_functions::ln_gamma;
30use num_traits::ToPrimitive;
31use serde::{Deserialize, Serialize};
32
33/// Configuration for the Binomial distribution.
34///
35/// # Fields
36/// * `n` - The number of trials (must be positive)
37/// * `p` - The probability of success (must be between 0 and 1)
38///
39/// # Examples
40/// ```
41/// use rs_stats::distributions::binomial_distribution::BinomialConfig;
42///
43/// let config = BinomialConfig { n: 10, p: 0.5 };
44/// assert!(config.n > 0);
45/// assert!(config.p >= 0.0 && config.p <= 1.0);
46/// ```
47#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
48pub struct BinomialConfig<T>
49where
50    T: ToPrimitive,
51{
52    /// The number of trials.
53    pub n: u64,
54    /// The probability of success in a single trial.
55    pub p: T,
56}
57
58impl<T> BinomialConfig<T>
59where
60    T: ToPrimitive,
61{
62    /// Creates a new BinomialConfig with validation
63    ///
64    /// # Arguments
65    /// * `n` - The number of trials
66    /// * `p` - The probability of success
67    ///
68    /// # Returns
69    /// `Some(BinomialConfig)` if parameters are valid, `None` otherwise
70    pub fn new(n: u64, p: T) -> StatsResult<Self> {
71        let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
72            message: "BinomialConfig::new: Failed to convert p to f64".to_string(),
73        })?;
74
75        if n == 0 {
76            return Err(StatsError::InvalidInput {
77                message: "BinomialConfig::new: n must be positive".to_string(),
78            });
79        }
80        if !((0.0..=1.0).contains(&p_64)) {
81            return Err(StatsError::InvalidInput {
82                message: "BinomialConfig::new: p must be between 0 and 1".to_string(),
83            });
84        }
85        Ok(Self { n, p })
86    }
87}
88
89/// Probability mass function (PMF) for the Binomial distribution.
90///
91/// Calculates the probability of observing exactly `k` successes in `n` trials
92/// with success probability `p`.
93///
94/// # Arguments
95/// * `k` - The number of successes (must be ≤ n)
96/// * `n` - The total number of trials (must be positive)
97/// * `p` - The probability of success in a single trial (must be between 0 and 1)
98///
99/// # Returns
100/// The probability of exactly `k` successes occurring.
101///
102/// # Errors
103/// Returns an error if:
104/// - n is zero
105/// - p is not between 0 and 1
106/// - k > n
107/// - Type conversion to f64 fails
108///
109/// # Examples
110/// ```
111/// use rs_stats::distributions::binomial_distribution::pmf;
112///
113/// // Calculate probability of 3 successes in 10 trials with p=0.5
114/// let prob = pmf(3, 10, 0.5).unwrap();
115/// assert!((prob - 0.1171875).abs() < 1e-10);
116/// ```
117#[inline]
118pub fn pmf<T>(k: u64, n: u64, p: T) -> StatsResult<f64>
119where
120    T: ToPrimitive,
121{
122    let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
123        message: "binomial_distribution::pmf: Failed to convert p to f64".to_string(),
124    })?;
125    if n == 0 {
126        return Err(StatsError::InvalidInput {
127            message: "binomial_distribution::pmf: n must be positive".to_string(),
128        });
129    }
130    if !((0.0..=1.0).contains(&p_64)) {
131        return Err(StatsError::InvalidInput {
132            message: "binomial_distribution::pmf: p must be between 0 and 1".to_string(),
133        });
134    }
135    let combinations = combination(n, k)?;
136
137    // Use log-space calculation to avoid:
138    // 1. Casting u64 to i32 (information loss)
139    // 2. Numerical underflow/overflow with large exponents
140    // 3. Better numerical stability
141    // Formula: p^k * (1-p)^(n-k) = exp(k * ln(p) + (n-k) * ln(1-p))
142
143    // Handle edge cases explicitly for correctness
144    if p_64 == 0.0 {
145        // If p = 0, then p^k = 0 for k > 0, and 1 for k = 0
146        return Ok(if k == 0 { combinations } else { 0.0 });
147    }
148    if p_64 == 1.0 {
149        // If p = 1, then (1-p)^(n-k) = 0 for k < n, and 1 for k = n
150        return Ok(if k == n { combinations } else { 0.0 });
151    }
152
153    // Convert to f64 (no information loss for reasonable values)
154    let k_f64 = k as f64;
155    let n_minus_k_f64 = (n - k) as f64;
156
157    // Calculate in log space: k * ln(p) + (n-k) * ln(1-p)
158    // Both p and (1-p) are guaranteed to be in (0, 1) here
159    let log_prob = k_f64 * p_64.ln() + n_minus_k_f64 * (1.0 - p_64).ln();
160
161    // Convert back from log space
162    let prob = log_prob.exp();
163
164    Ok(combinations * prob)
165}
166
167/// Cumulative distribution function (CDF) for the Binomial distribution.
168///
169/// Calculates the probability of observing `k` or fewer successes in `n` trials
170/// with success probability `p`.
171///
172/// # Arguments
173/// * `k` - The maximum number of successes (must be ≤ n)
174/// * `n` - The total number of trials (must be positive)
175/// * `p` - The probability of success in a single trial (must be between 0 and 1)
176///
177/// # Returns
178/// The cumulative probability of `k` or fewer successes occurring.
179///
180/// # Errors
181/// Returns an error if:
182/// - n is zero
183/// - p is not between 0 and 1
184/// - k > n
185/// - Type conversion to f64 fails
186///
187/// # Examples
188/// ```
189/// use rs_stats::distributions::binomial_distribution::cdf;
190///
191/// // Calculate probability of 3 or fewer successes in 10 trials with p=0.5
192/// let prob = cdf(3, 10, 0.5).unwrap();
193/// assert!((prob - 0.171875).abs() < 1e-10);
194/// ```
195#[inline]
196pub fn cdf(k: u64, n: u64, p: f64) -> StatsResult<f64> {
197    if n == 0 {
198        return Err(StatsError::InvalidInput {
199            message: "binomial_distribution::cdf: n must be positive".to_string(),
200        });
201    }
202    if !((0.0..=1.0).contains(&p)) {
203        return Err(StatsError::InvalidInput {
204            message: "binomial_distribution::cdf: p must be between 0 and 1".to_string(),
205        });
206    }
207    if k > n {
208        return Err(StatsError::InvalidInput {
209            message: "binomial_distribution::cdf: k must be less than or equal to n".to_string(),
210        });
211    }
212    // Use PMF recurrence relation: P(X=i+1) = P(X=i) * (n-i)/(i+1) * p/(1-p)
213    // This avoids recomputing factorials at each step: O(k) total, O(1) per step
214    if p == 0.0 {
215        return Ok(1.0); // P(X <= k) = 1 when p = 0 (all mass at k=0)
216    }
217    if p == 1.0 {
218        return Ok(if k >= n { 1.0 } else { 0.0 });
219    }
220
221    // Start with pmf(0) = (1-p)^n
222    let q = 1.0 - p;
223    let mut pmf_i = q.powi(n as i32);
224    // For very large n where powi overflows, fall back to log-space
225    if pmf_i == 0.0 && n > 0 {
226        let log_pmf_0 = (n as f64) * q.ln();
227        pmf_i = log_pmf_0.exp();
228    }
229    let mut cdf_sum = pmf_i;
230    let ratio = p / q;
231
232    for i in 0..k {
233        pmf_i *= ((n - i) as f64 / (i + 1) as f64) * ratio;
234        cdf_sum += pmf_i;
235    }
236
237    Ok(cdf_sum.clamp(0.0, 1.0))
238}
239
240/// Calculate the binomial coefficient (n choose k).
241#[inline]
242fn combination(n: u64, k: u64) -> StatsResult<f64> {
243    if k > n {
244        return Err(StatsError::InvalidInput {
245            message: "binomial_distribution::combination: k must be less than or equal to n"
246                .to_string(),
247        });
248    }
249
250    // Use a more numerically stable algorithm
251    if k > n / 2 {
252        return combination(n, n - k);
253    }
254
255    Ok((1..=k).fold(1.0_f64, |acc, i| acc * (n - i + 1) as f64 / i as f64))
256}
257
258// ── Typed struct + DiscreteDistribution impl ───────────────────────────────────
259
/// The Binomial(n, p) distribution as a concrete value type.
///
/// Build one with [`Binomial::new`] (or [`Binomial::fit`]), which validate the
/// parameters; the fields are public, so a struct literal skips those checks.
///
/// # Examples
/// ```
/// use rs_stats::distributions::binomial_distribution::Binomial;
/// use rs_stats::distributions::traits::DiscreteDistribution;
///
/// let b = Binomial::new(10, 0.5).unwrap();
/// assert!((b.mean() - 5.0).abs() < 1e-10);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    /// Number of independent trials `n`; valid values satisfy `n >= 1`.
    pub n: u64,
    /// Per-trial success probability `p`, valid in the closed interval `[0, 1]`.
    pub p: f64,
}
277
278impl Binomial {
279    /// Creates a `Binomial` distribution with validation.
280    pub fn new(n: u64, p: f64) -> StatsResult<Self> {
281        if n == 0 {
282            return Err(StatsError::InvalidInput {
283                message: "Binomial::new: n must be at least 1".to_string(),
284            });
285        }
286        if !(0.0..=1.0).contains(&p) {
287            return Err(StatsError::InvalidInput {
288                message: "Binomial::new: p must be in [0, 1]".to_string(),
289            });
290        }
291        Ok(Self { n, p })
292    }
293
294    /// MLE: assume n = max(data), p = mean(data) / n.
295    pub fn fit(data: &[f64]) -> StatsResult<Self> {
296        if data.is_empty() {
297            return Err(StatsError::InvalidInput {
298                message: "Binomial::fit: data must not be empty".to_string(),
299            });
300        }
301        let n = data
302            .iter()
303            .cloned()
304            .fold(f64::NEG_INFINITY, f64::max)
305            .round() as u64;
306        let mean = data.iter().sum::<f64>() / data.len() as f64;
307        let p = if n == 0 { 0.5 } else { mean / n as f64 };
308        Self::new(n.max(1), p.clamp(0.0, 1.0))
309    }
310}
311
312impl crate::distributions::traits::DiscreteDistribution for Binomial {
313    fn name(&self) -> &str {
314        "Binomial"
315    }
316    fn num_params(&self) -> usize {
317        2
318    }
319    fn pmf(&self, k: u64) -> StatsResult<f64> {
320        pmf(k, self.n, self.p)
321    }
322    /// Log-space PMF for numerical stability with large n or k.
323    ///
324    /// ln P(X=k) = ln Γ(n+1) − ln Γ(k+1) − ln Γ(n−k+1) + k·ln(p) + (n−k)·ln(1−p)
325    fn logpmf(&self, k: u64) -> StatsResult<f64> {
326        let n = self.n;
327        if k > n {
328            return Ok(f64::NEG_INFINITY);
329        }
330        // ln C(n,k) via ln_gamma — exact and stable for any n, k.
331        let log_binom =
332            ln_gamma((n + 1) as f64) - ln_gamma((k + 1) as f64) - ln_gamma((n - k + 1) as f64);
333        let log_p = match (self.p, k) {
334            (0.0, 0) => 0.0,
335            (0.0, _) => return Ok(f64::NEG_INFINITY),
336            (_, _) => k as f64 * self.p.ln(),
337        };
338        let log_q = match (self.p, n - k) {
339            (1.0, 0) => 0.0,
340            (1.0, _) => return Ok(f64::NEG_INFINITY),
341            (_, nk) => nk as f64 * (1.0 - self.p).ln(),
342        };
343        Ok(log_binom + log_p + log_q)
344    }
345    fn cdf(&self, k: u64) -> StatsResult<f64> {
346        cdf(k, self.n, self.p)
347    }
348    fn mean(&self) -> f64 {
349        self.n as f64 * self.p
350    }
351    fn variance(&self) -> f64 {
352        self.n as f64 * self.p * (1.0 - self.p)
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributions::traits::DiscreteDistribution;

    /// Absolute tolerance for probabilities that are known in closed form.
    const TOL: f64 = 1e-12;

    // ── pmf ──────────────────────────────────────────────────────────────────

    #[test]
    fn test_binomial_pmf() {
        // P(X = 5) for Binomial(10, 0.5) is C(10, 5) / 2^10 = 252 / 1024.
        let (k, n, p) = (5, 10, 0.5);
        let result = pmf(k, n, p).unwrap();
        assert!(
            (result - 252.0 / 1024.0).abs() < TOL,
            "PMF returned {} for k={}, n={}, p={}",
            result,
            k,
            n,
            p
        );
    }

    #[test]
    fn test_binomial_pmf_sums_to_one() {
        let n = 20;
        let total: f64 = (0..=n).map(|k| pmf(k, n, 0.3).unwrap()).sum();
        assert!((total - 1.0).abs() < TOL, "PMF summed to {}", total);
    }

    #[test]
    fn test_binomial_pmf_symmetry() {
        // X ~ Binomial(n, p) implies n - X ~ Binomial(n, 1 - p).
        let n = 15;
        for k in 0..=n {
            let lhs = pmf(k, n, 0.25).unwrap();
            let rhs = pmf(n - k, n, 0.75).unwrap();
            assert!((lhs - rhs).abs() < TOL, "k={}: {} vs {}", k, lhs, rhs);
        }
    }

    #[test]
    fn test_binomial_pmf_large_values_n() {
        // n exceeds i32::MAX (2,147,483,647); no exponent may be narrowed to i32.
        let n = 2_200_000_000u64;
        let k = 5u64;
        let p = 0.5;

        let val = pmf(k, n, p).expect("pmf should accept n > i32::MAX");
        // The true value underflows f64; it must come back as a valid probability,
        // never NaN or infinity.
        assert!(
            val.is_finite() && (0.0..=1.0).contains(&val),
            "PMF should be a valid probability for large n, got {}",
            val
        );
    }

    #[test]
    fn test_binomial_pmf_large_values_k() {
        // k far beyond n is rejected by the k <= n check, not by overflow.
        let result = pmf(2_200_000_000_000u64, 2u64, 0.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_pmf_p_zero_k_zero() {
        // When p=0.0 and k=0, PMF should return combinations (which is 1 for k=0)
        let result = pmf(0, 10, 0.0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_pmf_p_zero_k_greater_than_zero() {
        // When p=0.0 and k>0, PMF should return 0.0
        let result = pmf(5, 10, 0.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_binomial_pmf_p_one_k_equals_n() {
        // When p=1.0 and k=n, PMF should return combinations (which is 1 for k=n)
        let result = pmf(10, 10, 1.0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_pmf_p_one_k_less_than_n() {
        // When p=1.0 and k<n, PMF should return 0.0
        let result = pmf(5, 10, 1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_binomial_pmf_n_zero() {
        let result = pmf(0, 0, 0.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_pmf_p_out_of_range() {
        let result = pmf(5, 10, 1.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_pmf_k_greater_than_n() {
        // When k > n, combination() should return an error
        let result = pmf(15, 10, 0.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    // ── cdf ──────────────────────────────────────────────────────────────────

    #[test]
    fn test_binomial_cdf() {
        // P(X <= 5) for Binomial(10, 0.5) is (1 + 10 + 45 + 120 + 210 + 252) / 1024.
        let (k, n, p) = (5, 10, 0.5);
        let result = cdf(k, n, p).unwrap();
        assert!(
            (result - 638.0 / 1024.0).abs() < TOL,
            "CDF returned {} for k={}, n={}, p={}",
            result,
            k,
            n,
            p
        );
    }

    #[test]
    fn test_binomial_cdf_matches_cumulative_pmf() {
        let n = 20;
        let mut running = 0.0;
        for k in 0..=n {
            running += pmf(k, n, 0.3).unwrap();
            let c = cdf(k, n, 0.3).unwrap();
            assert!((c - running).abs() < TOL, "k={}: cdf {} vs sum {}", k, c, running);
        }
    }

    #[test]
    fn test_binomial_cdf_degenerate_p() {
        assert_eq!(cdf(0, 10, 0.0).unwrap(), 1.0);
        assert_eq!(cdf(9, 10, 1.0).unwrap(), 0.0);
        assert_eq!(cdf(10, 10, 1.0).unwrap(), 1.0);
    }

    #[test]
    fn test_binomial_cdf_k_greater_than_n() {
        let result = cdf(15, 10, 0.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_cdf_n_zero() {
        let result = cdf(5, 0, 0.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_cdf_p_out_of_range() {
        let result = cdf(5, 10, 1.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    // ── BinomialConfig ───────────────────────────────────────────────────────

    #[test]
    fn test_binomial_config_new_valid() {
        let config = BinomialConfig::new(10, 0.5).unwrap();
        assert_eq!(config.n, 10);
        assert_eq!(config.p, 0.5);
    }

    #[test]
    fn test_binomial_config_new_accepts_f32() {
        let config = BinomialConfig::new(5, 0.25f32).unwrap();
        assert_eq!(config.p, 0.25f32);
    }

    #[test]
    fn test_binomial_config_new_n_one() {
        // Edge case: n = 1 is the minimum valid value
        let config = BinomialConfig::new(1, 0.5).unwrap();
        assert_eq!(config.n, 1);
    }

    #[test]
    fn test_binomial_config_new_n_zero() {
        let result = BinomialConfig::new(0, 0.5);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_config_new_p_out_of_range_negative() {
        let result = BinomialConfig::new(10, -0.1);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_config_new_p_out_of_range_above_one() {
        let result = BinomialConfig::new(10, 1.1);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_config_new_p_zero() {
        assert!(BinomialConfig::new(10, 0.0).is_ok());
    }

    #[test]
    fn test_binomial_config_new_p_one() {
        assert!(BinomialConfig::new(10, 1.0).is_ok());
    }

    // ── combination ──────────────────────────────────────────────────────────

    #[test]
    fn test_binomial_combination_symmetry() {
        // C(n, k) == C(n, n - k); k > n/2 exercises the symmetry shortcut.
        let n = 10u64;
        let k = 8u64;
        let result1 = combination(n, k).unwrap();
        let result2 = combination(n, n - k).unwrap();
        assert_eq!(result1, result2);
        // C(10, 8) = C(10, 2) = 45
        assert_eq!(result1, 45.0);
    }

    #[test]
    fn test_binomial_combination_k_greater_than_n() {
        let result = combination(10, 15);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_binomial_combination_k_equals_n() {
        // C(n, n) = 1
        assert_eq!(combination(10, 10).unwrap(), 1.0);
    }

    #[test]
    fn test_binomial_combination_k_zero() {
        // C(n, 0) = 1
        assert_eq!(combination(10, 0).unwrap(), 1.0);
    }

    #[test]
    fn test_binomial_combination_k_exactly_n_over_2() {
        // Boundary: k = n/2 does not take the symmetry shortcut. C(10, 5) = 252
        assert_eq!(combination(10, 5).unwrap(), 252.0);
    }

    #[test]
    fn test_binomial_combination_k_just_over_n_over_2() {
        // k = n/2 + 1 takes the symmetry shortcut. C(10, 6) = C(10, 4) = 210
        let result1 = combination(10, 6).unwrap();
        let result2 = combination(10, 4).unwrap();
        assert_eq!(result1, result2);
        assert_eq!(result1, 210.0);
    }

    // ── Binomial struct + DiscreteDistribution ───────────────────────────────

    #[test]
    fn test_binomial_struct_moments_and_delegation() {
        let b = Binomial::new(10, 0.5).unwrap();
        assert_eq!(b.name(), "Binomial");
        assert_eq!(b.num_params(), 2);
        assert!((b.mean() - 5.0).abs() < TOL);
        assert!((b.variance() - 2.5).abs() < TOL);
        assert_eq!(b.pmf(3).unwrap(), pmf(3, 10, 0.5).unwrap());
        assert_eq!(b.cdf(3).unwrap(), cdf(3, 10, 0.5).unwrap());
    }

    #[test]
    fn test_binomial_new_invalid() {
        assert!(matches!(Binomial::new(0, 0.5), Err(StatsError::InvalidInput { .. })));
        assert!(matches!(Binomial::new(10, 1.5), Err(StatsError::InvalidInput { .. })));
        assert!(matches!(
            Binomial::new(10, f64::NAN),
            Err(StatsError::InvalidInput { .. })
        ));
    }

    #[test]
    fn test_binomial_logpmf_zero_probability_cases() {
        let b = Binomial::new(10, 0.5).unwrap();
        assert_eq!(b.logpmf(11).unwrap(), f64::NEG_INFINITY);
        let never = Binomial::new(10, 0.0).unwrap();
        assert_eq!(never.logpmf(1).unwrap(), f64::NEG_INFINITY);
        let always = Binomial::new(10, 1.0).unwrap();
        assert_eq!(always.logpmf(9).unwrap(), f64::NEG_INFINITY);
    }

    #[test]
    fn test_binomial_logpmf_consistent_with_pmf() {
        let b = Binomial::new(12, 0.3).unwrap();
        for k in 0..=12 {
            let direct = b.pmf(k).unwrap();
            let via_log = b.logpmf(k).unwrap().exp();
            assert!(
                (direct - via_log).abs() <= 1e-8 * direct,
                "k={}: pmf {} vs exp(logpmf) {}",
                k,
                direct,
                via_log
            );
        }
    }

    #[test]
    fn test_binomial_fit() {
        let b = Binomial::fit(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(b.n, 3);
        assert!((b.p - 2.0 / 3.0).abs() < TOL);
        assert!(matches!(Binomial::fit(&[]), Err(StatsError::InvalidInput { .. })));
    }
}