1use core::f64::consts::PI;
8
9use cyanea_core::{CyaneaError, Result};
10
/// Error function `erf(x)`.
///
/// Uses the rational approximation of Abramowitz & Stegun, formula 7.1.26
/// (documented absolute error of roughly 1.5e-7). The approximation is
/// evaluated on `|x|` and the sign is restored afterwards, since `erf` is odd.
pub fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Coefficients a1..a5, lowest order first.
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner scheme for a1 + a2 t + a3 t^2 + a4 t^3 + a5 t^4, innermost term first.
    let inner = A.iter().rev().fold(0.0_f64, |acc, &coef| coef + t * acc);
    let magnitude = 1.0 - t * inner * (-ax * ax).exp();

    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
23
/// Natural logarithm of the gamma function, `ln Γ(x)`.
///
/// For `x >= 0.5` this evaluates the Lanczos approximation (g = 7, nine
/// coefficients). Smaller arguments are mapped into that range with the
/// reflection formula `Γ(x) Γ(1 - x) = π / sin(πx)`.
///
/// In the reflection branch the logarithm is taken of `π / sin(πx)` without an
/// absolute value, so arguments where `Γ(x)` is negative yield NaN.
pub fn ln_gamma(x: f64) -> f64 {
    const LANCZOS: [f64; 8] = [
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];
    const BASE: f64 = 0.99999999999980993;

    // Reflection. The test is deliberately `x < 0.5` (not `!(x >= 0.5)`) so a
    // NaN argument falls through to the Lanczos branch instead of recursing forever.
    if x < 0.5 {
        let reflected = (PI / (PI * x).sin()).ln();
        return reflected - ln_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let series = LANCZOS
        .iter()
        .enumerate()
        .fold(BASE, |acc, (i, &coef)| acc + coef / (z + i as f64 + 1.0));
    let t = z + 7.5;
    let half_ln_two_pi = 0.5 * (2.0 * PI).ln();
    half_ln_two_pi + (z + 0.5) * t.ln() - t + series.ln()
}
51
52pub fn betai(a: f64, b: f64, x: f64) -> Result<f64> {
57 if x < 0.0 || x > 1.0 {
58 return Err(CyaneaError::InvalidInput(
59 "betai: x must be in [0, 1]".into(),
60 ));
61 }
62 if x == 0.0 || x == 1.0 {
63 return Ok(x);
64 }
65
66 if x > (a + 1.0) / (a + b + 2.0) {
68 return Ok(1.0 - betai(b, a, 1.0 - x)?);
69 }
70
71 let ln_prefactor = ln_gamma(a + b) - ln_gamma(a) - ln_gamma(b)
72 + a * x.ln()
73 + b * (1.0 - x).ln();
74 let prefactor = ln_prefactor.exp();
75
76 let tiny = 1e-30_f64;
78 let eps = 1e-10_f64;
79 let max_iter = 200;
80
81 let mut c = 1.0_f64;
82 let mut d = (1.0 - (a + b) * x / (a + 1.0)).recip();
83 if d.abs() < tiny {
84 d = tiny;
85 }
86 let mut h = d;
87
88 for m in 1..=max_iter {
89 let m_f64 = m as f64;
90
91 let num_even = m_f64 * (b - m_f64) * x / ((a + 2.0 * m_f64 - 1.0) * (a + 2.0 * m_f64));
93 d = 1.0 + num_even * d;
94 if d.abs() < tiny {
95 d = tiny;
96 }
97 d = d.recip();
98 c = 1.0 + num_even / c;
99 if c.abs() < tiny {
100 c = tiny;
101 }
102 h *= d * c;
103
104 let num_odd = -((a + m_f64) * (a + b + m_f64) * x)
106 / ((a + 2.0 * m_f64) * (a + 2.0 * m_f64 + 1.0));
107 d = 1.0 + num_odd * d;
108 if d.abs() < tiny {
109 d = tiny;
110 }
111 d = d.recip();
112 c = 1.0 + num_odd / c;
113 if c.abs() < tiny {
114 c = tiny;
115 }
116 let delta = d * c;
117 h *= delta;
118
119 if (delta - 1.0).abs() < eps {
120 return Ok(prefactor * h / a);
121 }
122 }
123
124 Ok(prefactor * h / a)
125}
126
/// Common interface shared by the univariate distributions in this module.
///
/// The discrete distributions in this module implement `pdf` as their
/// probability mass function evaluated at `x`.
pub trait Distribution {
    /// Expected value `E[X]`.
    fn mean(&self) -> f64;

    /// Variance `Var[X]`.
    fn variance(&self) -> f64;

    /// Standard deviation. The default is the square root of
    /// [`Distribution::variance`].
    fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }

    /// Density (continuous) or probability mass (discrete) at `x`.
    fn pdf(&self, x: f64) -> f64;

    /// Cumulative distribution function `P(X <= x)`.
    fn cdf(&self, x: f64) -> f64;
}
148
149#[derive(Debug, Clone, Copy)]
153pub struct Normal {
154 mu: f64,
155 sigma: f64,
156}
157
158impl Normal {
159 pub fn new(mu: f64, sigma: f64) -> Result<Self> {
161 if sigma <= 0.0 {
162 return Err(CyaneaError::InvalidInput(
163 "Normal: sigma must be positive".into(),
164 ));
165 }
166 Ok(Self { mu, sigma })
167 }
168
169 pub fn standard() -> Self {
171 Self {
172 mu: 0.0,
173 sigma: 1.0,
174 }
175 }
176}
177
178impl Distribution for Normal {
179 fn pdf(&self, x: f64) -> f64 {
180 let z = (x - self.mu) / self.sigma;
181 (-0.5 * z * z).exp() / (self.sigma * (2.0 * PI).sqrt())
182 }
183
184 fn cdf(&self, x: f64) -> f64 {
185 let z = (x - self.mu) / self.sigma;
186 0.5 * (1.0 + erf(z / core::f64::consts::SQRT_2))
187 }
188
189 fn mean(&self) -> f64 {
190 self.mu
191 }
192
193 fn variance(&self) -> f64 {
194 self.sigma * self.sigma
195 }
196}
197
198#[derive(Debug, Clone, Copy)]
202pub struct Poisson {
203 lambda: f64,
204}
205
206impl Poisson {
207 pub fn new(lambda: f64) -> Result<Self> {
209 if lambda <= 0.0 {
210 return Err(CyaneaError::InvalidInput(
211 "Poisson: lambda must be positive".into(),
212 ));
213 }
214 Ok(Self { lambda })
215 }
216}
217
218impl Distribution for Poisson {
219 fn pdf(&self, x: f64) -> f64 {
220 let k = x.round() as i64;
221 if k < 0 || (x - k as f64).abs() > 1e-9 {
222 return 0.0;
223 }
224 let k = k as f64;
225 (k * self.lambda.ln() - self.lambda - ln_gamma(k + 1.0)).exp()
227 }
228
229 fn cdf(&self, x: f64) -> f64 {
230 let k_max = x.floor() as i64;
231 if k_max < 0 {
232 return 0.0;
233 }
234 let mut sum = 0.0;
235 for k in 0..=k_max {
236 let k = k as f64;
237 sum += (k * self.lambda.ln() - self.lambda - ln_gamma(k + 1.0)).exp();
238 }
239 sum.min(1.0)
240 }
241
242 fn mean(&self) -> f64 {
243 self.lambda
244 }
245
246 fn variance(&self) -> f64 {
247 self.lambda
248 }
249}
250
251pub fn gammainc(a: f64, x: f64) -> Result<f64> {
258 if a <= 0.0 {
259 return Err(CyaneaError::InvalidInput("gammainc: a must be positive".into()));
260 }
261 if x < 0.0 {
262 return Err(CyaneaError::InvalidInput("gammainc: x must be non-negative".into()));
263 }
264 if x == 0.0 {
265 return Ok(0.0);
266 }
267
268 if x < a + 1.0 {
269 gammainc_series(a, x)
271 } else {
272 let q = gammainc_cf(a, x)?;
274 Ok(1.0 - q)
275 }
276}
277
278fn gammainc_series(a: f64, x: f64) -> Result<f64> {
280 let max_iter = 200;
281 let eps = 1e-12;
282 let ln_prefix = a * x.ln() - x - ln_gamma(a);
283
284 let mut sum = 1.0 / a;
285 let mut term = 1.0 / a;
286
287 for n in 1..=max_iter {
288 term *= x / (a + n as f64);
289 sum += term;
290 if term.abs() < sum.abs() * eps {
291 return Ok(sum * ln_prefix.exp());
292 }
293 }
294
295 Ok(sum * ln_prefix.exp())
296}
297
298fn gammainc_cf(a: f64, x: f64) -> Result<f64> {
300 let max_iter = 200;
301 let eps = 1e-12;
302 let tiny = 1e-30_f64;
303 let ln_prefix = a * x.ln() - x - ln_gamma(a);
304
305 let mut b = x + 1.0 - a;
306 let mut c = 1.0 / tiny;
307 let mut d = 1.0 / b;
308 let mut h = d;
309
310 for i in 1..=max_iter {
311 let an = -(i as f64) * (i as f64 - a);
312 b += 2.0;
313 d = an * d + b;
314 if d.abs() < tiny {
315 d = tiny;
316 }
317 c = b + an / c;
318 if c.abs() < tiny {
319 c = tiny;
320 }
321 d = 1.0 / d;
322 let delta = d * c;
323 h *= delta;
324 if (delta - 1.0).abs() < eps {
325 break;
326 }
327 }
328
329 Ok(h * ln_prefix.exp())
330}
331
332#[derive(Debug, Clone, Copy)]
336pub struct ChiSquared {
337 k: f64,
338}
339
340impl ChiSquared {
341 pub fn new(k: f64) -> Result<Self> {
343 if k <= 0.0 {
344 return Err(CyaneaError::InvalidInput(
345 "ChiSquared: k must be positive".into(),
346 ));
347 }
348 Ok(Self { k })
349 }
350
351 pub fn df(&self) -> f64 {
353 self.k
354 }
355}
356
357impl Distribution for ChiSquared {
358 fn pdf(&self, x: f64) -> f64 {
359 if x <= 0.0 {
360 return 0.0;
361 }
362 let half_k = self.k / 2.0;
363 let ln_pdf = (half_k - 1.0) * x.ln() - x / 2.0 - half_k * 2.0_f64.ln() - ln_gamma(half_k);
364 ln_pdf.exp()
365 }
366
367 fn cdf(&self, x: f64) -> f64 {
368 if x <= 0.0 {
369 return 0.0;
370 }
371 gammainc(self.k / 2.0, x / 2.0).unwrap_or(0.0)
372 }
373
374 fn mean(&self) -> f64 {
375 self.k
376 }
377
378 fn variance(&self) -> f64 {
379 2.0 * self.k
380 }
381}
382
383#[derive(Debug, Clone, Copy)]
387pub struct FDistribution {
388 d1: f64,
389 d2: f64,
390}
391
392impl FDistribution {
393 pub fn new(d1: f64, d2: f64) -> Result<Self> {
395 if d1 <= 0.0 || d2 <= 0.0 {
396 return Err(CyaneaError::InvalidInput(
397 "FDistribution: both d1 and d2 must be positive".into(),
398 ));
399 }
400 Ok(Self { d1, d2 })
401 }
402}
403
404impl Distribution for FDistribution {
405 fn pdf(&self, x: f64) -> f64 {
406 if x <= 0.0 {
407 return 0.0;
408 }
409 let d1 = self.d1;
410 let d2 = self.d2;
411 let ln_pdf = 0.5 * d1 * (d1 * x / (d1 * x + d2)).ln()
412 + 0.5 * d2 * (d2 / (d1 * x + d2)).ln()
413 - x.ln()
414 - ln_gamma(d1 / 2.0)
415 - ln_gamma(d2 / 2.0)
416 + ln_gamma((d1 + d2) / 2.0);
417 ln_pdf.exp()
418 }
419
420 fn cdf(&self, x: f64) -> f64 {
421 if x <= 0.0 {
422 return 0.0;
423 }
424 let ix = self.d1 * x / (self.d1 * x + self.d2);
425 betai(self.d1 / 2.0, self.d2 / 2.0, ix).unwrap_or(0.0)
426 }
427
428 fn mean(&self) -> f64 {
429 if self.d2 > 2.0 {
430 self.d2 / (self.d2 - 2.0)
431 } else {
432 f64::INFINITY
433 }
434 }
435
436 fn variance(&self) -> f64 {
437 if self.d2 > 4.0 {
438 let d1 = self.d1;
439 let d2 = self.d2;
440 2.0 * d2 * d2 * (d1 + d2 - 2.0)
441 / (d1 * (d2 - 2.0).powi(2) * (d2 - 4.0))
442 } else {
443 f64::INFINITY
444 }
445 }
446}
447
448#[derive(Debug, Clone, Copy)]
452pub struct Binomial {
453 n: usize,
454 p: f64,
455}
456
457impl Binomial {
458 pub fn new(n: usize, p: f64) -> Result<Self> {
460 if !(0.0..=1.0).contains(&p) {
461 return Err(CyaneaError::InvalidInput(
462 "Binomial: p must be in [0, 1]".into(),
463 ));
464 }
465 Ok(Self { n, p })
466 }
467
468 pub fn trials(&self) -> usize {
470 self.n
471 }
472
473 pub fn prob(&self) -> f64 {
475 self.p
476 }
477
478 pub fn pmf(&self, k: usize) -> f64 {
480 if k > self.n {
481 return 0.0;
482 }
483 let ln_binom = ln_gamma(self.n as f64 + 1.0)
484 - ln_gamma(k as f64 + 1.0)
485 - ln_gamma((self.n - k) as f64 + 1.0);
486 let ln_pmf = ln_binom + k as f64 * self.p.ln() + (self.n - k) as f64 * (1.0 - self.p).ln();
487 ln_pmf.exp()
488 }
489}
490
491impl Distribution for Binomial {
492 fn pdf(&self, x: f64) -> f64 {
493 let k = x.round() as i64;
494 if k < 0 || (x - k as f64).abs() > 1e-9 {
495 return 0.0;
496 }
497 self.pmf(k as usize)
498 }
499
500 fn cdf(&self, x: f64) -> f64 {
501 let k_max = x.floor() as i64;
502 if k_max < 0 {
503 return 0.0;
504 }
505 let k_max = k_max as usize;
506 if k_max >= self.n {
507 return 1.0;
508 }
509 betai((self.n - k_max) as f64, (k_max + 1) as f64, 1.0 - self.p).unwrap_or(1.0)
511 }
512
513 fn mean(&self) -> f64 {
514 self.n as f64 * self.p
515 }
516
517 fn variance(&self) -> f64 {
518 self.n as f64 * self.p * (1.0 - self.p)
519 }
520}
521
522#[derive(Debug, Clone, Copy)]
534pub struct NegativeBinomial {
535 r: f64,
536 p: f64,
537}
538
539impl NegativeBinomial {
540 pub fn new(r: f64, p: f64) -> Result<Self> {
544 if r <= 0.0 {
545 return Err(CyaneaError::InvalidInput(
546 "NegativeBinomial: r must be positive".into(),
547 ));
548 }
549 if p <= 0.0 || p > 1.0 {
550 return Err(CyaneaError::InvalidInput(
551 "NegativeBinomial: p must be in (0, 1]".into(),
552 ));
553 }
554 Ok(Self { r, p })
555 }
556
557 pub fn from_mean_dispersion(mu: f64, alpha: f64) -> Result<Self> {
564 if mu <= 0.0 {
565 return Err(CyaneaError::InvalidInput(
566 "NegativeBinomial: mu must be positive".into(),
567 ));
568 }
569 if alpha <= 0.0 {
570 return Err(CyaneaError::InvalidInput(
571 "NegativeBinomial: alpha must be positive".into(),
572 ));
573 }
574 let r = 1.0 / alpha;
575 let p = r / (r + mu);
576 Ok(Self { r, p })
577 }
578
579 pub fn r(&self) -> f64 {
581 self.r
582 }
583
584 pub fn p(&self) -> f64 {
586 self.p
587 }
588
589 pub fn ln_pmf(&self, k: usize) -> f64 {
593 let k_f = k as f64;
594 ln_gamma(k_f + self.r)
595 - ln_gamma(k_f + 1.0)
596 - ln_gamma(self.r)
597 + self.r * self.p.ln()
598 + k_f * (1.0 - self.p).ln()
599 }
600
601 pub fn pmf(&self, k: usize) -> f64 {
603 self.ln_pmf(k).exp()
604 }
605}
606
607impl Distribution for NegativeBinomial {
608 fn pdf(&self, x: f64) -> f64 {
609 let k = x.round() as i64;
610 if k < 0 || (x - k as f64).abs() > 1e-9 {
611 return 0.0;
612 }
613 self.pmf(k as usize)
614 }
615
616 fn cdf(&self, x: f64) -> f64 {
617 let k_max = x.floor() as i64;
618 if k_max < 0 {
619 return 0.0;
620 }
621 betai(self.r, (k_max + 1) as f64, self.p).unwrap_or(1.0)
623 }
624
625 fn mean(&self) -> f64 {
626 self.r * (1.0 - self.p) / self.p
627 }
628
629 fn variance(&self) -> f64 {
630 self.r * (1.0 - self.p) / (self.p * self.p)
631 }
632}
633
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for comparisons against closed-form values.
    const TOL: f64 = 1e-6;

    // ----- special functions -----

    #[test]
    fn erf_zero() {
        assert!((erf(0.0)).abs() < TOL);
    }

    #[test]
    fn erf_one() {
        assert!((erf(1.0) - 0.8427007929).abs() < 1e-5);
    }

    #[test]
    fn erf_negative_symmetry() {
        // erf is odd.
        assert!((erf(-0.5) + erf(0.5)).abs() < TOL);
    }

    #[test]
    fn ln_gamma_integers() {
        // ln Gamma(n) = ln((n - 1)!)
        assert!((ln_gamma(1.0) - 0.0).abs() < TOL);
        assert!((ln_gamma(2.0) - 0.0).abs() < TOL);
        assert!((ln_gamma(5.0) - (24.0_f64).ln()).abs() < TOL);
        assert!((ln_gamma(7.0) - (720.0_f64).ln()).abs() < TOL);
    }

    #[test]
    fn ln_gamma_half() {
        // Gamma(1/2) = sqrt(pi)
        assert!((ln_gamma(0.5) - 0.5 * PI.ln()).abs() < 1e-5);
    }

    #[test]
    fn betai_boundaries() {
        assert_eq!(betai(1.0, 1.0, 0.0).unwrap(), 0.0);
        assert_eq!(betai(1.0, 1.0, 1.0).unwrap(), 1.0);
    }

    #[test]
    fn betai_uniform() {
        // I_x(1, 1) = x
        assert!((betai(1.0, 1.0, 0.5).unwrap() - 0.5).abs() < TOL);
        assert!((betai(1.0, 1.0, 0.3).unwrap() - 0.3).abs() < TOL);
    }

    #[test]
    fn betai_symmetry() {
        // I_x(a, b) = 1 - I_{1-x}(b, a)
        let a = 2.0;
        let b = 3.0;
        let x = 0.4;
        let lhs = betai(a, b, x).unwrap();
        let rhs = 1.0 - betai(b, a, 1.0 - x).unwrap();
        assert!((lhs - rhs).abs() < TOL);
    }

    #[test]
    fn betai_invalid_x() {
        assert!(betai(1.0, 1.0, -0.1).is_err());
        assert!(betai(1.0, 1.0, 1.1).is_err());
    }

    // ----- Normal -----

    #[test]
    fn normal_standard_cdf() {
        let n = Normal::standard();
        assert!((n.cdf(0.0) - 0.5).abs() < TOL);
        assert!((n.cdf(1.0) - 0.8413447).abs() < 1e-5);
        assert!((n.cdf(-1.0) - 0.1586553).abs() < 1e-5);
        assert!((n.cdf(2.0) - 0.9772499).abs() < 1e-5);
    }

    #[test]
    fn normal_standard_pdf_at_zero() {
        let n = Normal::standard();
        let expected = 1.0 / (2.0 * PI).sqrt();
        assert!((n.pdf(0.0) - expected).abs() < TOL);
    }

    #[test]
    fn normal_invalid_sigma() {
        assert!(Normal::new(0.0, 0.0).is_err());
        assert!(Normal::new(0.0, -1.0).is_err());
    }

    // ----- Poisson -----

    #[test]
    fn poisson_pmf() {
        let p = Poisson::new(3.0).unwrap();
        // P(X = 0) = e^{-3}
        assert!((p.pdf(0.0) - (-3.0_f64).exp()).abs() < TOL);
        // P(X = 3) = 3^3 e^{-3} / 3!
        let expected = 27.0 * (-3.0_f64).exp() / 6.0;
        assert!((p.pdf(3.0) - expected).abs() < TOL);
    }

    #[test]
    fn poisson_cdf() {
        let p = Poisson::new(1.0).unwrap();
        // P(X <= 0) = e^{-1}
        assert!((p.cdf(0.0) - (-1.0_f64).exp()).abs() < TOL);
        // P(X <= 1) = e^{-1} + e^{-1}
        assert!((p.cdf(1.0) - 2.0 * (-1.0_f64).exp()).abs() < TOL);
    }

    #[test]
    fn poisson_invalid_lambda() {
        assert!(Poisson::new(0.0).is_err());
        assert!(Poisson::new(-1.0).is_err());
    }

    // ----- incomplete gamma -----

    #[test]
    fn gammainc_zero() {
        assert_eq!(gammainc(1.0, 0.0).unwrap(), 0.0);
    }

    #[test]
    fn gammainc_exponential() {
        // P(1, x) = 1 - e^{-x}
        let x: f64 = 2.0;
        let expected = 1.0 - (-x).exp();
        assert!((gammainc(1.0, x).unwrap() - expected).abs() < 1e-8);
    }

    #[test]
    fn gammainc_half_integer() {
        // P(1/2, x) = erf(sqrt(x))
        let x: f64 = 1.0;
        let expected = erf(x.sqrt());
        assert!((gammainc(0.5, x).unwrap() - expected).abs() < 1e-6);
    }

    #[test]
    fn gammainc_large_x() {
        // Exercises the continued-fraction branch (x >= a + 1).
        assert!((gammainc(2.0, 50.0).unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn gammainc_invalid() {
        assert!(gammainc(-1.0, 1.0).is_err());
        assert!(gammainc(1.0, -1.0).is_err());
    }

    // ----- Chi-squared -----

    #[test]
    fn chi_squared_cdf_known_values() {
        let chi2 = ChiSquared::new(2.0).unwrap();
        // 5.991 is the 95th percentile of chi-squared with 2 degrees of freedom.
        let x = 5.991;
        let p = chi2.cdf(x);
        assert!((p - 0.95).abs() < 0.01, "p={}", p);
    }

    #[test]
    fn chi_squared_cdf_df1() {
        let chi2 = ChiSquared::new(1.0).unwrap();
        // 3.841 is the 95th percentile of chi-squared with 1 degree of freedom.
        assert!((chi2.cdf(3.841) - 0.95).abs() < 0.01);
    }

    #[test]
    fn chi_squared_mean_variance() {
        let chi2 = ChiSquared::new(5.0).unwrap();
        assert!((chi2.mean() - 5.0).abs() < TOL);
        assert!((chi2.variance() - 10.0).abs() < TOL);
    }

    #[test]
    fn chi_squared_cdf_at_zero() {
        let chi2 = ChiSquared::new(3.0).unwrap();
        assert_eq!(chi2.cdf(0.0), 0.0);
    }

    #[test]
    fn chi_squared_invalid() {
        assert!(ChiSquared::new(0.0).is_err());
        assert!(ChiSquared::new(-1.0).is_err());
    }

    // ----- F distribution -----

    #[test]
    fn f_dist_cdf_known() {
        let f = FDistribution::new(5.0, 10.0).unwrap();
        // 3.326 is the 95th percentile of F(5, 10).
        let p = f.cdf(3.326);
        assert!((p - 0.95).abs() < 0.02, "p={}", p);
    }

    #[test]
    fn f_dist_cdf_at_zero() {
        let f = FDistribution::new(3.0, 5.0).unwrap();
        assert_eq!(f.cdf(0.0), 0.0);
    }

    #[test]
    fn f_dist_mean() {
        let f = FDistribution::new(4.0, 8.0).unwrap();
        // mean = d2 / (d2 - 2)
        assert!((f.mean() - 8.0 / 6.0).abs() < TOL);
    }

    #[test]
    fn f_dist_invalid() {
        assert!(FDistribution::new(0.0, 5.0).is_err());
        assert!(FDistribution::new(5.0, 0.0).is_err());
    }

    // ----- Binomial -----

    #[test]
    fn binomial_pmf() {
        let b = Binomial::new(10, 0.5).unwrap();
        // C(10, 5) / 2^10 = 252 / 1024
        assert!((b.pmf(5) - 0.24609375).abs() < 1e-6);
    }

    #[test]
    fn binomial_pmf_sum() {
        let b = Binomial::new(8, 0.3).unwrap();
        let sum: f64 = (0..=8).map(|k| b.pmf(k)).sum();
        assert!((sum - 1.0).abs() < 1e-8);
    }

    #[test]
    fn binomial_cdf() {
        let b = Binomial::new(10, 0.5).unwrap();
        // P(X <= 5) = 638 / 1024
        assert!((b.cdf(5.0) - 0.623047).abs() < 0.01);
    }

    #[test]
    fn binomial_cdf_boundaries() {
        let b = Binomial::new(5, 0.5).unwrap();
        assert!(b.cdf(-1.0) == 0.0);
        assert!((b.cdf(5.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn binomial_mean_variance() {
        let b = Binomial::new(20, 0.3).unwrap();
        assert!((b.mean() - 6.0).abs() < TOL);
        assert!((b.variance() - 4.2).abs() < TOL);
    }

    #[test]
    fn binomial_invalid() {
        assert!(Binomial::new(10, -0.1).is_err());
        assert!(Binomial::new(10, 1.1).is_err());
    }

    // ----- Negative binomial -----

    #[test]
    fn nb_pmf_known_values() {
        let nb = NegativeBinomial::new(3.0, 0.5).unwrap();
        // P(0) = p^r = 0.125; P(1) = r p^r (1 - p) = 0.1875
        assert!((nb.pmf(0) - 0.125).abs() < TOL);
        assert!((nb.pmf(1) - 0.1875).abs() < TOL);
    }

    #[test]
    fn nb_pmf_sums_near_one() {
        let nb = NegativeBinomial::new(5.0, 0.4).unwrap();
        let sum: f64 = (0..200).map(|k| nb.pmf(k)).sum();
        assert!((sum - 1.0).abs() < 1e-4, "sum={sum}");
    }

    #[test]
    fn nb_cdf_consistency() {
        let nb = NegativeBinomial::new(3.0, 0.6).unwrap();
        // The closed-form CDF must agree with summing the PMF.
        for k in [0, 1, 5, 10] {
            let cdf_val = nb.cdf(k as f64);
            let pmf_sum: f64 = (0..=k).map(|j| nb.pmf(j)).sum();
            assert!(
                (cdf_val - pmf_sum).abs() < 1e-5,
                "k={k}: cdf={cdf_val}, pmf_sum={pmf_sum}"
            );
        }
    }

    #[test]
    fn nb_mean_variance() {
        let nb = NegativeBinomial::new(4.0, 0.5).unwrap();
        // mean = r (1 - p) / p = 4
        assert!((nb.mean() - 4.0).abs() < TOL);
        // variance = r (1 - p) / p^2 = 8
        assert!((nb.variance() - 8.0).abs() < TOL);
    }

    #[test]
    fn nb_from_mean_dispersion_roundtrip() {
        let mu = 10.0;
        let alpha = 0.5;
        let nb = NegativeBinomial::from_mean_dispersion(mu, alpha).unwrap();
        assert!((nb.mean() - mu).abs() < 1e-10);
        // variance = mu + alpha * mu^2 = 10 + 0.5 * 100
        assert!((nb.variance() - 60.0).abs() < 1e-8);
    }

    #[test]
    fn nb_deseq2_parameterization() {
        let nb = NegativeBinomial::from_mean_dispersion(5.0, 1e-6).unwrap();
        // Vanishing dispersion approaches Poisson: variance / mean -> 1.
        let ratio = nb.variance() / nb.mean();
        assert!((ratio - 1.0).abs() < 1e-3, "ratio={ratio}");
    }

    #[test]
    fn nb_invalid_params() {
        assert!(NegativeBinomial::new(0.0, 0.5).is_err());
        assert!(NegativeBinomial::new(-1.0, 0.5).is_err());
        assert!(NegativeBinomial::new(3.0, 0.0).is_err());
        assert!(NegativeBinomial::new(3.0, 1.1).is_err());
        assert!(NegativeBinomial::from_mean_dispersion(0.0, 0.5).is_err());
        assert!(NegativeBinomial::from_mean_dispersion(5.0, 0.0).is_err());
    }

    #[test]
    fn nb_pdf_non_integer_is_zero() {
        let nb = NegativeBinomial::new(3.0, 0.5).unwrap();
        assert_eq!(nb.pdf(1.5), 0.0);
        assert_eq!(nb.pdf(-1.0), 0.0);
    }

    #[test]
    fn nb_cdf_negative_is_zero() {
        let nb = NegativeBinomial::new(3.0, 0.5).unwrap();
        assert_eq!(nb.cdf(-1.0), 0.0);
    }
}