1use crate::error::{StatsError, StatsResult};
29use crate::utils::special_functions::ln_gamma;
30use num_traits::ToPrimitive;
31use serde::{Deserialize, Serialize};
32
33#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
48pub struct BinomialConfig<T>
49where
50 T: ToPrimitive,
51{
52 pub n: u64,
54 pub p: T,
56}
57
58impl<T> BinomialConfig<T>
59where
60 T: ToPrimitive,
61{
62 pub fn new(n: u64, p: T) -> StatsResult<Self> {
71 let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
72 message: "BinomialConfig::new: Failed to convert p to f64".to_string(),
73 })?;
74
75 if n == 0 {
76 return Err(StatsError::InvalidInput {
77 message: "BinomialConfig::new: n must be positive".to_string(),
78 });
79 }
80 if !((0.0..=1.0).contains(&p_64)) {
81 return Err(StatsError::InvalidInput {
82 message: "BinomialConfig::new: p must be between 0 and 1".to_string(),
83 });
84 }
85 Ok(Self { n, p })
86 }
87}
88
89#[inline]
118pub fn pmf<T>(k: u64, n: u64, p: T) -> StatsResult<f64>
119where
120 T: ToPrimitive,
121{
122 let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
123 message: "binomial_distribution::pmf: Failed to convert p to f64".to_string(),
124 })?;
125 if n == 0 {
126 return Err(StatsError::InvalidInput {
127 message: "binomial_distribution::pmf: n must be positive".to_string(),
128 });
129 }
130 if !((0.0..=1.0).contains(&p_64)) {
131 return Err(StatsError::InvalidInput {
132 message: "binomial_distribution::pmf: p must be between 0 and 1".to_string(),
133 });
134 }
135 let combinations = combination(n, k)?;
136
137 if p_64 == 0.0 {
145 return Ok(if k == 0 { combinations } else { 0.0 });
147 }
148 if p_64 == 1.0 {
149 return Ok(if k == n { combinations } else { 0.0 });
151 }
152
153 let k_f64 = k as f64;
155 let n_minus_k_f64 = (n - k) as f64;
156
157 let log_prob = k_f64 * p_64.ln() + n_minus_k_f64 * (1.0 - p_64).ln();
160
161 let prob = log_prob.exp();
163
164 Ok(combinations * prob)
165}
166
167#[inline]
196pub fn cdf(k: u64, n: u64, p: f64) -> StatsResult<f64> {
197 if n == 0 {
198 return Err(StatsError::InvalidInput {
199 message: "binomial_distribution::cdf: n must be positive".to_string(),
200 });
201 }
202 if !((0.0..=1.0).contains(&p)) {
203 return Err(StatsError::InvalidInput {
204 message: "binomial_distribution::cdf: p must be between 0 and 1".to_string(),
205 });
206 }
207 if k > n {
208 return Err(StatsError::InvalidInput {
209 message: "binomial_distribution::cdf: k must be less than or equal to n".to_string(),
210 });
211 }
212 if p == 0.0 {
215 return Ok(1.0); }
217 if p == 1.0 {
218 return Ok(if k >= n { 1.0 } else { 0.0 });
219 }
220
221 let q = 1.0 - p;
223 let mut pmf_i = q.powi(n as i32);
224 if pmf_i == 0.0 && n > 0 {
226 let log_pmf_0 = (n as f64) * q.ln();
227 pmf_i = log_pmf_0.exp();
228 }
229 let mut cdf_sum = pmf_i;
230 let ratio = p / q;
231
232 for i in 0..k {
233 pmf_i *= ((n - i) as f64 / (i + 1) as f64) * ratio;
234 cdf_sum += pmf_i;
235 }
236
237 Ok(cdf_sum.clamp(0.0, 1.0))
238}
239
240#[inline]
242fn combination(n: u64, k: u64) -> StatsResult<f64> {
243 if k > n {
244 return Err(StatsError::InvalidInput {
245 message: "binomial_distribution::combination: k must be less than or equal to n"
246 .to_string(),
247 });
248 }
249
250 if k > n / 2 {
252 return combination(n, n - k);
253 }
254
255 Ok((1..=k).fold(1.0_f64, |acc, i| acc * (n - i + 1) as f64 / i as f64))
256}
257
/// Binomial distribution with `n` trials and an `f64` success probability `p`.
///
/// `Binomial::new` validates both parameters. The fields are public, so a value built
/// literally is not guaranteed to satisfy those checks.
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
    /// Number of trials.
    pub n: u64,
    /// Success probability of a single trial.
    pub p: f64,
}
277
278impl Binomial {
279 pub fn new(n: u64, p: f64) -> StatsResult<Self> {
281 if n == 0 {
282 return Err(StatsError::InvalidInput {
283 message: "Binomial::new: n must be at least 1".to_string(),
284 });
285 }
286 if !(0.0..=1.0).contains(&p) {
287 return Err(StatsError::InvalidInput {
288 message: "Binomial::new: p must be in [0, 1]".to_string(),
289 });
290 }
291 Ok(Self { n, p })
292 }
293
294 pub fn fit(data: &[f64]) -> StatsResult<Self> {
296 if data.is_empty() {
297 return Err(StatsError::InvalidInput {
298 message: "Binomial::fit: data must not be empty".to_string(),
299 });
300 }
301 let n = data
302 .iter()
303 .cloned()
304 .fold(f64::NEG_INFINITY, f64::max)
305 .round() as u64;
306 let mean = data.iter().sum::<f64>() / data.len() as f64;
307 let p = if n == 0 { 0.5 } else { mean / n as f64 };
308 Self::new(n.max(1), p.clamp(0.0, 1.0))
309 }
310}
311
312impl crate::distributions::traits::DiscreteDistribution for Binomial {
313 fn name(&self) -> &str {
314 "Binomial"
315 }
316 fn num_params(&self) -> usize {
317 2
318 }
319 fn pmf(&self, k: u64) -> StatsResult<f64> {
320 pmf(k, self.n, self.p)
321 }
322 fn logpmf(&self, k: u64) -> StatsResult<f64> {
326 let n = self.n;
327 if k > n {
328 return Ok(f64::NEG_INFINITY);
329 }
330 let log_binom =
332 ln_gamma((n + 1) as f64) - ln_gamma((k + 1) as f64) - ln_gamma((n - k + 1) as f64);
333 let log_p = match (self.p, k) {
334 (0.0, 0) => 0.0,
335 (0.0, _) => return Ok(f64::NEG_INFINITY),
336 (_, _) => k as f64 * self.p.ln(),
337 };
338 let log_q = match (self.p, n - k) {
339 (1.0, 0) => 0.0,
340 (1.0, _) => return Ok(f64::NEG_INFINITY),
341 (_, nk) => nk as f64 * (1.0 - self.p).ln(),
342 };
343 Ok(log_binom + log_p + log_q)
344 }
345 fn cdf(&self, k: u64) -> StatsResult<f64> {
346 cdf(k, self.n, self.p)
347 }
348 fn mean(&self) -> f64 {
349 self.n as f64 * self.p
350 }
351 fn variance(&self) -> f64 {
352 self.n as f64 * self.p * (1.0 - self.p)
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing against closed-form probabilities.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `result` failed with `StatsError::InvalidInput`.
    fn assert_invalid_input<T>(result: StatsResult<T>) {
        assert!(
            matches!(result, Err(StatsError::InvalidInput { .. })),
            "expected Err(StatsError::InvalidInput)"
        );
    }

    #[test]
    fn test_binomial_pmf() {
        let n = 10;
        let p = 0.5;
        let k = 5;
        let result = pmf(k, n, p).unwrap();
        // C(10, 5) / 2^10 = 252 / 1024.
        assert!(
            (result - 0.24609375).abs() < TOLERANCE,
            "PMF returned {} for k={}, n={}, p={}",
            result,
            k,
            n,
            p
        );
    }

    #[test]
    fn test_binomial_cdf() {
        let n = 10;
        let p = 0.5;
        let k = 5;
        let result = cdf(k, n, p).unwrap();
        // (1 + 10 + 45 + 120 + 210 + 252) / 1024.
        assert!(
            (result - 0.623046875).abs() < TOLERANCE,
            "CDF returned {} for k={}, n={}, p={}",
            result,
            k,
            n,
            p
        );
    }

    #[test]
    fn test_binomial_pmf_large_values_n() {
        let n = 2_200_000_000u64;
        let k = 5u64;
        let p = 0.5;

        // Valid arguments must succeed and yield a probability, not inf or NaN.
        let val = pmf(k, n, p).unwrap();
        assert!(
            (0.0..=1.0).contains(&val),
            "PMF should be a probability for large n, got {}",
            val
        );
    }

    #[test]
    fn test_binomial_pmf_large_values_k() {
        let n = 2u64;
        let k = 2_200_000_000_000u64;
        let p = 0.5;

        // k > n is invalid no matter how large k is.
        assert_invalid_input(pmf(k, n, p));
    }

    #[test]
    fn test_binomial_config_new_valid() {
        let config = BinomialConfig::new(10, 0.5);
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.n, 10);
        assert_eq!(config.p, 0.5);
    }

    #[test]
    fn test_binomial_config_new_n_zero() {
        assert_invalid_input(BinomialConfig::new(0, 0.5));
    }

    #[test]
    fn test_binomial_config_new_p_out_of_range_negative() {
        assert_invalid_input(BinomialConfig::new(10, -0.1));
    }

    #[test]
    fn test_binomial_config_new_p_out_of_range_above_one() {
        assert_invalid_input(BinomialConfig::new(10, 1.1));
    }

    #[test]
    fn test_binomial_config_new_p_zero() {
        let config = BinomialConfig::new(10, 0.0);
        assert!(config.is_ok());
    }

    #[test]
    fn test_binomial_config_new_p_one() {
        let config = BinomialConfig::new(10, 1.0);
        assert!(config.is_ok());
    }

    #[test]
    fn test_binomial_pmf_p_zero_k_zero() {
        let result = pmf(0, 10, 0.0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_pmf_p_zero_k_greater_than_zero() {
        let result = pmf(5, 10, 0.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_binomial_pmf_p_one_k_equals_n() {
        let result = pmf(10, 10, 1.0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_pmf_p_one_k_less_than_n() {
        let result = pmf(5, 10, 1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_binomial_pmf_n_zero() {
        assert_invalid_input(pmf(0, 0, 0.5));
    }

    #[test]
    fn test_binomial_pmf_p_out_of_range() {
        assert_invalid_input(pmf(5, 10, 1.5));
    }

    #[test]
    fn test_binomial_cdf_k_greater_than_n() {
        assert_invalid_input(cdf(15, 10, 0.5));
    }

    #[test]
    fn test_binomial_combination_symmetry() {
        let n = 10u64;
        let k = 8u64;
        let result1 = combination(n, k).unwrap();
        let result2 = combination(n, n - k).unwrap();
        assert_eq!(result1, result2);

        assert_eq!(result1, 45.0);
    }

    #[test]
    fn test_binomial_combination_k_greater_than_n() {
        assert_invalid_input(combination(10, 15));
    }

    #[test]
    fn test_binomial_combination_k_equals_n() {
        let result = combination(10, 10).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_combination_k_zero() {
        let result = combination(10, 0).unwrap();
        assert_eq!(result, 1.0);
    }

    #[test]
    fn test_binomial_config_new_n_one() {
        let config = BinomialConfig::new(1, 0.5);
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.n, 1);
    }

    #[test]
    fn test_binomial_pmf_k_greater_than_n() {
        assert_invalid_input(pmf(15, 10, 0.5));
    }

    #[test]
    fn test_binomial_cdf_n_zero() {
        assert_invalid_input(cdf(5, 0, 0.5));
    }

    #[test]
    fn test_binomial_cdf_p_out_of_range() {
        assert_invalid_input(cdf(5, 10, 1.5));
    }

    #[test]
    fn test_binomial_combination_k_exactly_n_over_2() {
        let n = 10u64;
        let k = 5u64;
        let result = combination(n, k).unwrap();
        assert_eq!(result, 252.0);
    }

    #[test]
    fn test_binomial_combination_k_just_over_n_over_2() {
        let n = 10u64;
        let k = 6u64;
        let result1 = combination(n, k).unwrap();
        let result2 = combination(n, n - k).unwrap();
        assert_eq!(result1, result2);
        assert_eq!(result1, 210.0);
    }
}