1use crate::error::{BayesError, Result};
4use nalgebra::DVector;
5use rand_distr::Distribution as RandDistribution;
6use std::f64::consts::PI;
7
/// Largest rate that [`Poisson::new`] accepts; any larger `lambda` is rejected as an
/// invalid parameter.
// NOTE(review): presumably chosen so that sampled counts stay far inside the `u64`
// range and the `rand_distr` sampler accepts the rate — confirm the rationale.
const MAX_POISSON_SAMPLE_RATE: f64 = 1_000_000_000_000.0;
12
/// A univariate continuous distribution, described through its log-density.
pub trait Distribution {
    /// Natural logarithm of the probability density evaluated at `x`.
    fn log_pdf(&self, x: f64) -> f64;

    /// Probability density at `x`, obtained by exponentiating [`Distribution::log_pdf`].
    fn pdf(&self, x: f64) -> f64 {
        let log_density = self.log_pdf(x);
        log_density.exp()
    }
}
23
24pub trait DiscreteDistribution {
30 fn log_pmf(&self, k: u64) -> f64;
32
33 fn pmf(&self, k: u64) -> f64 {
35 self.log_pmf(k).exp()
36 }
37
38 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u64;
40}
41
42pub trait MultivariateDistribution {
44 fn log_pdf(&self, x: &DVector<f64>) -> f64;
46
47 fn pdf(&self, x: &DVector<f64>) -> f64 {
49 self.log_pdf(x).exp()
50 }
51}
52
53#[derive(Debug, Clone, PartialEq)]
55pub struct Bernoulli {
56 p: f64,
57}
58
59impl Bernoulli {
60 pub fn new(p: f64) -> Result<Self> {
62 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
63 return Err(BayesError::invalid_parameter(
64 "probability must be finite and between 0 and 1",
65 ));
66 }
67
68 Ok(Self { p })
69 }
70
71 pub fn probability(&self) -> f64 {
73 self.p
74 }
75
76 pub fn mean(&self) -> f64 {
78 self.p
79 }
80
81 pub fn variance(&self) -> f64 {
83 self.p * (1.0 - self.p)
84 }
85}
86
87impl DiscreteDistribution for Bernoulli {
88 fn log_pmf(&self, k: u64) -> f64 {
89 match k {
90 0 => {
91 if self.p == 1.0 {
92 f64::NEG_INFINITY
93 } else {
94 (-self.p).ln_1p()
95 }
96 }
97 1 => {
98 if self.p == 0.0 {
99 f64::NEG_INFINITY
100 } else {
101 self.p.ln()
102 }
103 }
104 _ => f64::NEG_INFINITY,
105 }
106 }
107
108 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
109 if rng.gen_bool(self.p) {
110 1
111 } else {
112 0
113 }
114 }
115}
116
117#[derive(Debug, Clone, PartialEq)]
119pub struct Binomial {
120 n: u64,
121 p: f64,
122}
123
124impl Binomial {
125 pub fn new(n: u64, p: f64) -> Result<Self> {
127 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
128 return Err(BayesError::invalid_parameter(
129 "probability must be finite and between 0 and 1",
130 ));
131 }
132
133 Ok(Self { n, p })
134 }
135
136 pub fn trials(&self) -> u64 {
138 self.n
139 }
140
141 pub fn probability(&self) -> f64 {
143 self.p
144 }
145
146 pub fn mean(&self) -> f64 {
148 self.n as f64 * self.p
149 }
150
151 pub fn variance(&self) -> f64 {
153 self.n as f64 * self.p * (1.0 - self.p)
154 }
155}
156
157impl DiscreteDistribution for Binomial {
158 fn log_pmf(&self, k: u64) -> f64 {
159 if k > self.n {
160 return f64::NEG_INFINITY;
161 }
162
163 if self.p == 0.0 {
164 return if k == 0 { 0.0 } else { f64::NEG_INFINITY };
165 }
166 if self.p == 1.0 {
167 return if k == self.n { 0.0 } else { f64::NEG_INFINITY };
168 }
169
170 log_binomial_coefficient(self.n, k)
171 + k as f64 * self.p.ln()
172 + (self.n - k) as f64 * (-self.p).ln_1p()
173 }
174
175 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
176 if self.p == 0.0 {
177 return 0;
178 }
179 if self.p == 1.0 {
180 return self.n;
181 }
182
183 let binomial = rand_distr::Binomial::new(self.n, self.p)
184 .expect("validated binomial parameters should be accepted by rand_distr");
185 binomial.sample(rng)
186 }
187}
188
189#[derive(Debug, Clone, PartialEq)]
191pub struct Poisson {
192 lambda: f64,
193}
194
195impl Poisson {
196 pub fn new(lambda: f64) -> Result<Self> {
202 if !lambda.is_finite() || lambda <= 0.0 || lambda > MAX_POISSON_SAMPLE_RATE {
203 return Err(BayesError::invalid_parameter(format!(
204 "lambda must be finite, positive, and no greater than {MAX_POISSON_SAMPLE_RATE:e}",
205 )));
206 }
207
208 Ok(Self { lambda })
209 }
210
211 pub fn rate(&self) -> f64 {
213 self.lambda
214 }
215
216 pub fn mean(&self) -> f64 {
218 self.lambda
219 }
220
221 pub fn variance(&self) -> f64 {
223 self.lambda
224 }
225}
226
227impl DiscreteDistribution for Poisson {
228 fn log_pmf(&self, k: u64) -> f64 {
229 k as f64 * self.lambda.ln() - self.lambda - gamma_ln(k as f64 + 1.0)
230 }
231
232 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
233 let poisson = rand_distr::Poisson::new(self.lambda)
234 .expect("validated Poisson rate should be accepted by rand_distr");
235 poisson.sample(rng) as u64
237 }
238}
239
240#[derive(Debug, Clone, PartialEq)]
245pub struct Categorical {
246 probabilities: Vec<f64>,
247 cumulative_probabilities: Vec<f64>,
248}
249
250impl Categorical {
251 pub fn new(weights: Vec<f64>) -> Result<Self> {
256 if weights.is_empty() {
257 return Err(BayesError::invalid_parameter(
258 "categorical weights must not be empty",
259 ));
260 }
261
262 if weights.iter().any(|&w| !w.is_finite() || w < 0.0) {
263 return Err(BayesError::invalid_parameter(
264 "categorical weights must be finite and non-negative",
265 ));
266 }
267
268 let total_mass: f64 = weights.iter().sum();
269 if !total_mass.is_finite() || total_mass <= 0.0 {
270 return Err(BayesError::invalid_parameter(
271 "categorical weights must have positive finite total mass",
272 ));
273 }
274
275 let probabilities: Vec<f64> = weights.into_iter().map(|w| w / total_mass).collect();
276 let mut running_total = 0.0;
277 let mut cumulative_probabilities: Vec<f64> = probabilities
278 .iter()
279 .map(|&p| {
280 running_total += p;
281 running_total
282 })
283 .collect();
284
285 if let Some(last_positive) = probabilities.iter().rposition(|&p| p > 0.0) {
289 for cumulative_probability in &mut cumulative_probabilities[last_positive..] {
290 *cumulative_probability = 1.0;
291 }
292 }
293
294 Ok(Self {
295 probabilities,
296 cumulative_probabilities,
297 })
298 }
299
300 pub fn probabilities(&self) -> &[f64] {
302 &self.probabilities
303 }
304
305 pub fn category_count(&self) -> usize {
307 self.probabilities.len()
308 }
309
310 pub fn mean(&self) -> f64 {
312 self.probabilities
313 .iter()
314 .enumerate()
315 .map(|(category, &p)| category as f64 * p)
316 .sum()
317 }
318
319 pub fn variance(&self) -> f64 {
321 let mean = self.mean();
322 self.probabilities
323 .iter()
324 .enumerate()
325 .map(|(category, &p)| {
326 let diff = category as f64 - mean;
327 diff * diff * p
328 })
329 .sum()
330 }
331}
332
333impl DiscreteDistribution for Categorical {
334 fn log_pmf(&self, k: u64) -> f64 {
335 let Ok(index) = usize::try_from(k) else {
336 return f64::NEG_INFINITY;
337 };
338 let Some(&probability) = self.probabilities.get(index) else {
339 return f64::NEG_INFINITY;
340 };
341
342 if probability == 0.0 {
343 f64::NEG_INFINITY
344 } else {
345 probability.ln()
346 }
347 }
348
349 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
350 let u: f64 = rng.gen();
351 self.cumulative_probabilities
352 .iter()
353 .position(|&cdf| u < cdf)
354 .unwrap_or(self.probabilities.len() - 1) as u64
355 }
356}
357
358#[derive(Debug, Clone, PartialEq)]
360pub struct Normal {
361 mu: f64,
362 sigma: f64,
363 log_sigma: f64,
364 inv_sigma: f64,
365}
366
367impl Normal {
368 pub fn new(mu: f64, sigma: f64) -> Result<Self> {
370 if sigma <= 0.0 {
371 return Err(BayesError::invalid_parameter("sigma must be positive"));
372 }
373 if !mu.is_finite() || !sigma.is_finite() {
374 return Err(BayesError::invalid_parameter("parameters must be finite"));
375 }
376
377 Ok(Self {
378 mu,
379 sigma,
380 log_sigma: sigma.ln(),
381 inv_sigma: 1.0 / sigma,
382 })
383 }
384
385 pub fn mean(&self) -> f64 {
387 self.mu
388 }
389
390 pub fn std_dev(&self) -> f64 {
392 self.sigma
393 }
394
395 pub fn variance(&self) -> f64 {
397 self.sigma * self.sigma
398 }
399}
400
401impl Distribution for Normal {
402 fn log_pdf(&self, x: f64) -> f64 {
403 if !x.is_finite() {
404 return f64::NEG_INFINITY;
405 }
406
407 let diff = x - self.mu;
408 -0.5 * (2.0 * PI).ln()
409 - self.log_sigma
410 - 0.5 * diff * diff * self.inv_sigma * self.inv_sigma
411 }
412}
413
414#[derive(Debug, Clone)]
416pub struct MultivariateNormal {
417 mu: DVector<f64>,
418 precision: nalgebra::DMatrix<f64>, log_det_precision: f64,
420 dim: usize,
421}
422
423impl MultivariateNormal {
424 pub fn new(mu: DVector<f64>, covariance: nalgebra::DMatrix<f64>) -> Result<Self> {
426 if mu.len() != covariance.nrows() || covariance.nrows() != covariance.ncols() {
427 return Err(BayesError::dimension_mismatch(mu.len(), covariance.nrows()));
428 }
429
430 let chol = covariance.clone().cholesky().ok_or_else(|| {
431 BayesError::numerical_error("Covariance matrix is not positive definite")
432 })?;
433
434 let precision = chol.inverse();
435 let log_det_precision = -2.0 * chol.l().diagonal().iter().map(|x| x.ln()).sum::<f64>();
436
437 Ok(Self {
438 dim: mu.len(),
439 mu,
440 precision,
441 log_det_precision,
442 })
443 }
444
445 pub fn new_diagonal(mu: DVector<f64>, variances: DVector<f64>) -> Result<Self> {
447 if mu.len() != variances.len() {
448 return Err(BayesError::dimension_mismatch(mu.len(), variances.len()));
449 }
450
451 if variances.iter().any(|&v| v <= 0.0) {
452 return Err(BayesError::invalid_parameter(
453 "All variances must be positive",
454 ));
455 }
456
457 let dim = mu.len();
458 let mut covariance = nalgebra::DMatrix::zeros(dim, dim);
459 for i in 0..dim {
460 covariance[(i, i)] = variances[i];
461 }
462
463 Self::new(mu, covariance)
464 }
465
466 pub fn mean(&self) -> &DVector<f64> {
468 &self.mu
469 }
470
471 pub fn dimension(&self) -> usize {
473 self.dim
474 }
475}
476
477impl MultivariateDistribution for MultivariateNormal {
478 fn log_pdf(&self, x: &DVector<f64>) -> f64 {
479 if x.len() != self.dim {
480 return f64::NEG_INFINITY;
481 }
482
483 if !x.iter().all(|&val| val.is_finite()) {
484 return f64::NEG_INFINITY;
485 }
486
487 let diff = x - &self.mu;
488 let quadratic_form = diff.dot(&(self.precision.clone() * &diff));
489
490 -0.5 * (self.dim as f64 * (2.0 * PI).ln() - self.log_det_precision + quadratic_form)
491 }
492}
493
494#[derive(Debug, Clone, PartialEq)]
496pub struct Gamma {
497 alpha: f64,
498 beta: f64,
499 log_gamma_alpha: f64,
500}
501
502impl Gamma {
503 pub fn new(alpha: f64, beta: f64) -> Result<Self> {
505 if alpha <= 0.0 || beta <= 0.0 {
506 return Err(BayesError::invalid_parameter(
507 "alpha and beta must be positive",
508 ));
509 }
510 if !alpha.is_finite() || !beta.is_finite() {
511 return Err(BayesError::invalid_parameter("parameters must be finite"));
512 }
513
514 Ok(Self {
515 alpha,
516 beta,
517 log_gamma_alpha: gamma_ln(alpha),
518 })
519 }
520
521 pub fn shape(&self) -> f64 {
523 self.alpha
524 }
525
526 pub fn rate(&self) -> f64 {
528 self.beta
529 }
530
531 pub fn mean(&self) -> f64 {
533 self.alpha / self.beta
534 }
535
536 pub fn variance(&self) -> f64 {
538 self.alpha / (self.beta * self.beta)
539 }
540}
541
542impl Distribution for Gamma {
543 fn log_pdf(&self, x: f64) -> f64 {
544 if x <= 0.0 || !x.is_finite() {
545 return f64::NEG_INFINITY;
546 }
547
548 (self.alpha - 1.0) * x.ln() - self.beta * x + self.alpha * self.beta.ln()
549 - self.log_gamma_alpha
550 }
551}
552
553#[derive(Debug, Clone, PartialEq)]
555pub struct Beta {
556 alpha: f64,
557 beta: f64,
558 log_beta_function: f64,
559}
560
561impl Beta {
562 pub fn new(alpha: f64, beta: f64) -> Result<Self> {
564 if alpha <= 0.0 || beta <= 0.0 {
565 return Err(BayesError::invalid_parameter(
566 "alpha and beta must be positive",
567 ));
568 }
569 if !alpha.is_finite() || !beta.is_finite() {
570 return Err(BayesError::invalid_parameter("parameters must be finite"));
571 }
572
573 let log_beta_function = gamma_ln(alpha) + gamma_ln(beta) - gamma_ln(alpha + beta);
574
575 Ok(Self {
576 alpha,
577 beta,
578 log_beta_function,
579 })
580 }
581
582 pub fn alpha(&self) -> f64 {
584 self.alpha
585 }
586
587 pub fn beta(&self) -> f64 {
589 self.beta
590 }
591
592 pub fn mean(&self) -> f64 {
594 self.alpha / (self.alpha + self.beta)
595 }
596
597 pub fn variance(&self) -> f64 {
599 let ab = self.alpha + self.beta;
600 (self.alpha * self.beta) / (ab * ab * (ab + 1.0))
601 }
602}
603
604impl Distribution for Beta {
605 fn log_pdf(&self, x: f64) -> f64 {
606 if x <= 0.0 || x >= 1.0 || !x.is_finite() {
607 return f64::NEG_INFINITY;
608 }
609
610 (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln() - self.log_beta_function
611 }
612}
613
614#[derive(Debug, Clone, PartialEq)]
616pub struct Exponential {
617 rate: f64,
618 log_rate: f64,
619}
620
621impl Exponential {
622 pub fn new(rate: f64) -> Result<Self> {
624 if rate <= 0.0 {
625 return Err(BayesError::invalid_parameter("rate must be positive"));
626 }
627 if !rate.is_finite() {
628 return Err(BayesError::invalid_parameter("rate must be finite"));
629 }
630
631 Ok(Self {
632 rate,
633 log_rate: rate.ln(),
634 })
635 }
636
637 pub fn rate(&self) -> f64 {
639 self.rate
640 }
641
642 pub fn mean(&self) -> f64 {
644 1.0 / self.rate
645 }
646
647 pub fn variance(&self) -> f64 {
649 1.0 / (self.rate * self.rate)
650 }
651}
652
653impl Distribution for Exponential {
654 fn log_pdf(&self, x: f64) -> f64 {
655 if x < 0.0 || !x.is_finite() {
656 return f64::NEG_INFINITY;
657 }
658
659 self.log_rate - self.rate * x
660 }
661}
662
663#[derive(Debug, Clone, PartialEq)]
665pub struct Uniform {
666 a: f64,
667 b: f64,
668 log_density: f64,
669}
670
671impl Uniform {
672 pub fn new(a: f64, b: f64) -> Result<Self> {
674 if a >= b {
675 return Err(BayesError::invalid_parameter("a must be less than b"));
676 }
677 if !a.is_finite() || !b.is_finite() {
678 return Err(BayesError::invalid_parameter("parameters must be finite"));
679 }
680
681 Ok(Self {
682 a,
683 b,
684 log_density: -(b - a).ln(),
685 })
686 }
687
688 pub fn lower_bound(&self) -> f64 {
690 self.a
691 }
692
693 pub fn upper_bound(&self) -> f64 {
695 self.b
696 }
697
698 pub fn mean(&self) -> f64 {
700 (self.a + self.b) / 2.0
701 }
702
703 pub fn variance(&self) -> f64 {
705 (self.b - self.a).powi(2) / 12.0
706 }
707}
708
709impl Distribution for Uniform {
710 fn log_pdf(&self, x: f64) -> f64 {
711 if x < self.a || x > self.b || !x.is_finite() {
712 return f64::NEG_INFINITY;
713 }
714
715 self.log_density
716 }
717}
718
719#[derive(Debug, Clone, PartialEq)]
721pub struct StudentT {
722 nu: f64,
723 mu: f64,
724 sigma: f64,
725 log_normalizer: f64,
726}
727
728impl StudentT {
729 pub fn new(nu: f64, mu: f64, sigma: f64) -> Result<Self> {
731 if nu <= 0.0 || sigma <= 0.0 {
732 return Err(BayesError::invalid_parameter(
733 "nu and sigma must be positive",
734 ));
735 }
736 if !nu.is_finite() || !mu.is_finite() || !sigma.is_finite() {
737 return Err(BayesError::invalid_parameter("parameters must be finite"));
738 }
739
740 let log_normalizer =
741 gamma_ln((nu + 1.0) / 2.0) - gamma_ln(nu / 2.0) - 0.5 * (nu * PI).ln() - sigma.ln();
742
743 Ok(Self {
744 nu,
745 mu,
746 sigma,
747 log_normalizer,
748 })
749 }
750
751 pub fn degrees_of_freedom(&self) -> f64 {
753 self.nu
754 }
755
756 pub fn location(&self) -> f64 {
758 self.mu
759 }
760
761 pub fn scale(&self) -> f64 {
763 self.sigma
764 }
765
766 pub fn mean(&self) -> Option<f64> {
768 if self.nu > 1.0 {
769 Some(self.mu)
770 } else {
771 None
772 }
773 }
774
775 pub fn variance(&self) -> Option<f64> {
777 if self.nu > 2.0 {
778 Some(self.sigma * self.sigma * self.nu / (self.nu - 2.0))
779 } else {
780 None
781 }
782 }
783}
784
785impl Distribution for StudentT {
786 fn log_pdf(&self, x: f64) -> f64 {
787 if !x.is_finite() {
788 return f64::NEG_INFINITY;
789 }
790
791 let standardized = (x - self.mu) / self.sigma;
792 self.log_normalizer
793 - 0.5 * (self.nu + 1.0) * (1.0 + standardized * standardized / self.nu).ln()
794 }
795}
796
797fn log_binomial_coefficient(n: u64, k: u64) -> f64 {
799 if k > n {
800 return f64::NEG_INFINITY;
801 }
802
803 if k == 0 || k == n {
804 return 0.0;
805 }
806
807 gamma_ln(n as f64 + 1.0) - gamma_ln(k as f64 + 1.0) - gamma_ln((n - k) as f64 + 1.0)
808}
809
/// Natural log of the gamma function for positive finite `x`.
///
/// Returns negative infinity when `x` is non-positive, NaN, or infinite. Arguments
/// below `0.5` go through the reflection identity
/// `lnG(x) = ln(pi) - ln(sin(pi * x)) - lnG(1 - x)`; everything else is evaluated
/// with a Lanczos series (`g = 7`, nine coefficients).
fn gamma_ln(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    let in_domain = x.is_finite() && x > 0.0;
    if !in_domain {
        return f64::NEG_INFINITY;
    }

    if x < 0.5 {
        let log_sin = (PI * x).sin().ln();
        return PI.ln() - log_sin - gamma_ln(1.0 - x);
    }

    let z = x - 1.0;
    // Partial-fraction series: c0 + sum_{i >= 1} c_i / (z + i).
    let series = LANCZOS_COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEFFS[0], |acc, (i, coeff)| acc + coeff / (z + i as f64));
    let t = z + LANCZOS_G + 0.5;

    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
846
// Unit tests for the distributions and numeric helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use rand::SeedableRng;

    // Accessors, moments, and pmf/log_pmf values of a Bernoulli with p = 0.25.
    #[test]
    fn test_bernoulli_distribution() {
        let bernoulli = Bernoulli::new(0.25).unwrap();
        assert_eq!(bernoulli.probability(), 0.25);
        assert_eq!(bernoulli.mean(), 0.25);
        assert_abs_diff_eq!(bernoulli.variance(), 0.1875, epsilon = 1e-12);

        assert_abs_diff_eq!(bernoulli.pmf(0), 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(bernoulli.pmf(1), 0.25, epsilon = 1e-12);
        assert_eq!(bernoulli.pmf(2), 0.0);
        assert_abs_diff_eq!(bernoulli.log_pmf(1), 0.25_f64.ln(), epsilon = 1e-12);

        // For a tiny p, log_pmf(0) must match ln_1p(-p) to a very tight tolerance.
        let rare_success = Bernoulli::new(1.0e-12).unwrap();
        assert_abs_diff_eq!(
            rare_success.log_pmf(0),
            (-1.0e-12_f64).ln_1p(),
            epsilon = 1e-24
        );
    }

    // p = 0 and p = 1 put all mass on a single outcome.
    #[test]
    fn test_bernoulli_edge_cases() {
        let always_zero = Bernoulli::new(0.0).unwrap();
        assert_eq!(always_zero.pmf(0), 1.0);
        assert_eq!(always_zero.pmf(1), 0.0);
        assert_eq!(always_zero.log_pmf(1), f64::NEG_INFINITY);

        let always_one = Bernoulli::new(1.0).unwrap();
        assert_eq!(always_one.pmf(0), 0.0);
        assert_eq!(always_one.pmf(1), 1.0);
        assert_eq!(always_one.log_pmf(0), f64::NEG_INFINITY);
    }

    // Out-of-range, NaN, and infinite probabilities are rejected.
    #[test]
    fn test_bernoulli_invalid_params() {
        assert!(Bernoulli::new(-0.1).is_err());
        assert!(Bernoulli::new(1.1).is_err());
        assert!(Bernoulli::new(f64::NAN).is_err());
        assert!(Bernoulli::new(f64::INFINITY).is_err());
    }

    // Two generators with the same seed must yield the same sample sequence.
    #[test]
    fn test_bernoulli_sampling_seeded() {
        let bernoulli = Bernoulli::new(0.5).unwrap();
        let mut rng_a = rand::rngs::StdRng::seed_from_u64(42);
        let mut rng_b = rand::rngs::StdRng::seed_from_u64(42);
        let samples_a: Vec<_> = (0..16).map(|_| bernoulli.sample(&mut rng_a)).collect();
        let samples_b: Vec<_> = (0..16).map(|_| bernoulli.sample(&mut rng_b)).collect();

        assert_eq!(samples_a, samples_b);
        assert!(samples_a.iter().all(|&x| x <= 1));
    }

    // Accessors, moments, and pmf/log_pmf values of Binomial(10, 0.5).
    #[test]
    fn test_binomial_distribution() {
        let binomial = Binomial::new(10, 0.5).unwrap();
        assert_eq!(binomial.trials(), 10);
        assert_eq!(binomial.probability(), 0.5);
        assert_abs_diff_eq!(binomial.mean(), 5.0, epsilon = 1e-12);
        assert_abs_diff_eq!(binomial.variance(), 2.5, epsilon = 1e-12);

        assert_abs_diff_eq!(binomial.pmf(0), 1.0 / 1024.0, epsilon = 1e-12);
        assert_abs_diff_eq!(binomial.pmf(5), 252.0 / 1024.0, epsilon = 1e-12);
        assert_eq!(binomial.pmf(11), 0.0);
        assert_eq!(binomial.log_pmf(11), f64::NEG_INFINITY);
        assert_abs_diff_eq!(
            binomial.log_pmf(5),
            (252.0_f64 / 1024.0).ln(),
            epsilon = 1e-12
        );
    }

    // Degenerate cases: p = 0, p = 1, and zero trials.
    #[test]
    fn test_binomial_edge_cases() {
        let zero_prob = Binomial::new(4, 0.0).unwrap();
        assert_eq!(zero_prob.pmf(0), 1.0);
        assert_eq!(zero_prob.pmf(1), 0.0);
        assert_eq!(zero_prob.sample(&mut rand::thread_rng()), 0);

        let one_prob = Binomial::new(4, 1.0).unwrap();
        assert_eq!(one_prob.pmf(3), 0.0);
        assert_eq!(one_prob.pmf(4), 1.0);
        assert_eq!(one_prob.sample(&mut rand::thread_rng()), 4);

        let zero_trials = Binomial::new(0, 0.75).unwrap();
        assert_eq!(zero_trials.pmf(0), 1.0);
        assert_eq!(zero_trials.pmf(1), 0.0);
        assert_eq!(zero_trials.sample(&mut rand::thread_rng()), 0);
    }

    // Out-of-range, NaN, and infinite probabilities are rejected.
    #[test]
    fn test_binomial_invalid_params() {
        assert!(Binomial::new(10, -0.1).is_err());
        assert!(Binomial::new(10, 1.1).is_err());
        assert!(Binomial::new(10, f64::NAN).is_err());
        assert!(Binomial::new(10, f64::INFINITY).is_err());
    }

    // Seeded sampling is reproducible and never exceeds the trial count.
    #[test]
    fn test_binomial_sampling_seeded() {
        let binomial = Binomial::new(20, 0.25).unwrap();
        let mut rng_a = rand::rngs::StdRng::seed_from_u64(7);
        let mut rng_b = rand::rngs::StdRng::seed_from_u64(7);
        let samples_a: Vec<_> = (0..8).map(|_| binomial.sample(&mut rng_a)).collect();
        let samples_b: Vec<_> = (0..8).map(|_| binomial.sample(&mut rng_b)).collect();

        assert_eq!(samples_a, samples_b);
        assert!(samples_a.iter().all(|&x| x <= binomial.trials()));

        // A large trial count with a small probability still samples within range.
        let large_n = Binomial::new(1_000_000, 0.0001).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(11);
        assert!(large_n.sample(&mut rng) <= large_n.trials());
    }

    // Accessors, moments, and pmf/log_pmf values of Poisson(3).
    #[test]
    fn test_poisson_distribution() {
        let poisson = Poisson::new(3.0).unwrap();
        assert_eq!(poisson.rate(), 3.0);
        assert_eq!(poisson.mean(), 3.0);
        assert_eq!(poisson.variance(), 3.0);

        assert_abs_diff_eq!(poisson.pmf(0), (-3.0_f64).exp(), epsilon = 1e-12);
        assert_abs_diff_eq!(poisson.log_pmf(0), -3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(poisson.pmf(2), 4.5 * (-3.0_f64).exp(), epsilon = 1e-12);
        assert_abs_diff_eq!(
            poisson.log_pmf(2),
            (4.5_f64 * (-3.0_f64).exp()).ln(),
            epsilon = 1e-12
        );
    }

    // Non-positive, NaN, infinite, and too-large rates are rejected; the maximum
    // rate itself is accepted.
    #[test]
    fn test_poisson_invalid_params() {
        assert!(Poisson::new(0.0).is_err());
        assert!(Poisson::new(-1.0).is_err());
        assert!(Poisson::new(f64::NAN).is_err());
        assert!(Poisson::new(f64::INFINITY).is_err());
        assert!(Poisson::new(MAX_POISSON_SAMPLE_RATE).is_ok());
        assert!(Poisson::new(MAX_POISSON_SAMPLE_RATE * 2.0).is_err());
    }

    // Seeded sampling is reproducible.
    #[test]
    fn test_poisson_sampling_seeded() {
        let poisson = Poisson::new(4.0).unwrap();
        let mut rng_a = rand::rngs::StdRng::seed_from_u64(99);
        let mut rng_b = rand::rngs::StdRng::seed_from_u64(99);
        let samples_a: Vec<_> = (0..8).map(|_| poisson.sample(&mut rng_a)).collect();
        let samples_b: Vec<_> = (0..8).map(|_| poisson.sample(&mut rng_b)).collect();

        assert_eq!(samples_a, samples_b);
        assert!(samples_a.iter().all(|&x| x < 100));
    }

    // Weights 1:2:3 normalise to 1/6, 2/6, 3/6; checks moments and pmf values.
    #[test]
    fn test_categorical_distribution() {
        let categorical = Categorical::new(vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(categorical.category_count(), 3);
        assert_abs_diff_eq!(categorical.probabilities()[0], 1.0 / 6.0, epsilon = 1e-12);
        assert_abs_diff_eq!(categorical.probabilities()[1], 2.0 / 6.0, epsilon = 1e-12);
        assert_abs_diff_eq!(categorical.probabilities()[2], 3.0 / 6.0, epsilon = 1e-12);
        assert_abs_diff_eq!(categorical.mean(), 4.0 / 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(categorical.variance(), 5.0 / 9.0, epsilon = 1e-12);

        assert_abs_diff_eq!(categorical.pmf(0), 1.0 / 6.0, epsilon = 1e-12);
        assert_abs_diff_eq!(categorical.pmf(1), 2.0 / 6.0, epsilon = 1e-12);
        assert_abs_diff_eq!(categorical.pmf(2), 3.0 / 6.0, epsilon = 1e-12);
        assert_eq!(categorical.pmf(3), 0.0);
        assert_eq!(categorical.log_pmf(3), f64::NEG_INFINITY);
        assert_abs_diff_eq!(categorical.log_pmf(2), (0.5_f64).ln(), epsilon = 1e-12);
    }

    // Zero-weight categories keep zero mass and are never sampled.
    #[test]
    fn test_categorical_zero_probability_categories() {
        let categorical = Categorical::new(vec![0.0, 5.0, 0.0]).unwrap();
        assert_eq!(categorical.probabilities(), &[0.0, 1.0, 0.0]);
        assert_eq!(categorical.pmf(0), 0.0);
        assert_eq!(categorical.pmf(1), 1.0);
        assert_eq!(categorical.pmf(2), 0.0);
        assert_eq!(categorical.log_pmf(0), f64::NEG_INFINITY);
        assert_eq!(categorical.log_pmf(2), f64::NEG_INFINITY);
        assert_eq!(categorical.mean(), 1.0);
        assert_eq!(categorical.variance(), 0.0);
        assert_eq!(categorical.sample(&mut rand::thread_rng()), 1);
    }

    // A trailing zero-weight category must not be reachable by sampling.
    #[test]
    fn test_categorical_trailing_zero_weight_never_samples_zero_mass_category() {
        let categorical = Categorical::new(vec![1.0, 0.1, 0.0]).unwrap();
        let mut rng = rand::rngs::StdRng::seed_from_u64(321);
        let samples: Vec<_> = (0..1_000).map(|_| categorical.sample(&mut rng)).collect();

        assert!(samples.iter().all(|&x| x < 2));
        assert_eq!(categorical.pmf(2), 0.0);
        assert_eq!(categorical.log_pmf(2), f64::NEG_INFINITY);
    }

    // Empty, all-zero, negative, NaN, infinite, and overflowing weights are rejected.
    #[test]
    fn test_categorical_invalid_params() {
        assert!(Categorical::new(vec![]).is_err());
        assert!(Categorical::new(vec![0.0, 0.0]).is_err());
        assert!(Categorical::new(vec![1.0, -0.1]).is_err());
        assert!(Categorical::new(vec![1.0, f64::NAN]).is_err());
        assert!(Categorical::new(vec![1.0, f64::INFINITY]).is_err());
        assert!(Categorical::new(vec![f64::MAX, f64::MAX]).is_err());
    }

    // Seeded sampling is reproducible, stays in range, and hits every category.
    #[test]
    fn test_categorical_sampling_seeded() {
        let categorical = Categorical::new(vec![0.2, 0.3, 0.5]).unwrap();
        let mut rng_a = rand::rngs::StdRng::seed_from_u64(123);
        let mut rng_b = rand::rngs::StdRng::seed_from_u64(123);
        let samples_a: Vec<_> = (0..16).map(|_| categorical.sample(&mut rng_a)).collect();
        let samples_b: Vec<_> = (0..16).map(|_| categorical.sample(&mut rng_b)).collect();

        assert_eq!(samples_a, samples_b);
        assert!(samples_a
            .iter()
            .all(|&x| x < categorical.category_count() as u64));
        assert!(samples_a.contains(&0));
        assert!(samples_a.contains(&1));
        assert!(samples_a.contains(&2));
    }

    // gamma_ln against closed-form values, including a tiny argument.
    #[test]
    fn test_gamma_ln_known_values() {
        assert_abs_diff_eq!(gamma_ln(0.5), 0.5 * PI.ln(), epsilon = 1e-12);
        assert_abs_diff_eq!(gamma_ln(1.0e-8), 18.42068073818021, epsilon = 1e-12);
        assert_abs_diff_eq!(gamma_ln(1.0), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(gamma_ln(5.0), 24.0_f64.ln(), epsilon = 1e-12);
        assert_abs_diff_eq!(gamma_ln(10.0), 362_880.0_f64.ln(), epsilon = 1e-10);
    }

    // Pins absolute log-probabilities of every distribution that depends on gamma_ln.
    #[test]
    fn test_gamma_ln_regression_absolute_log_probabilities() {
        let exponential_gamma = Gamma::new(1.0, 1.0).unwrap();
        assert_abs_diff_eq!(exponential_gamma.log_pdf(1.0), -1.0, epsilon = 1e-12);

        let non_integer_gamma = Gamma::new(2.5, 1.5).unwrap();
        assert_abs_diff_eq!(
            non_integer_gamma.log_pdf(1.2),
            -0.797_537_765_011_576_5,
            epsilon = 1e-12
        );

        let uniform_beta = Beta::new(1.0, 1.0).unwrap();
        assert_abs_diff_eq!(uniform_beta.log_pdf(0.5), 0.0, epsilon = 1e-12);

        let fractional_beta = Beta::new(2.5, 3.5).unwrap();
        assert_abs_diff_eq!(
            fractional_beta.log_pdf(0.4),
            0.650_335_112_735_843_9,
            epsilon = 1e-12
        );

        let cauchy = StudentT::new(1.0, 0.0, 1.0).unwrap();
        assert_abs_diff_eq!(cauchy.log_pdf(0.0), -PI.ln(), epsilon = 1e-12);

        let scaled_student_t = StudentT::new(5.0, 0.5, 1.25).unwrap();
        assert_abs_diff_eq!(
            scaled_student_t.log_pdf(1.75),
            -1.738_727_810_750_798_4,
            epsilon = 1e-12
        );

        let poisson = Poisson::new(3.0).unwrap();
        assert_abs_diff_eq!(poisson.log_pmf(0), -3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(
            poisson.log_pmf(4),
            -1.783_604_675_675_505_7,
            epsilon = 1e-12
        );

        let fair_coin = Binomial::new(4, 0.5).unwrap();
        assert_abs_diff_eq!(fair_coin.log_pmf(2), (6.0_f64 / 16.0).ln(), epsilon = 1e-12);
        assert_abs_diff_eq!(
            log_binomial_coefficient(5, 2),
            10.0_f64.ln(),
            epsilon = 1e-12
        );
        assert_abs_diff_eq!(
            log_binomial_coefficient(10, 3),
            120.0_f64.ln(),
            epsilon = 1e-12
        );
    }

    // Accessors of the standard normal.
    #[test]
    fn test_normal_creation() {
        let normal = Normal::new(0.0, 1.0).unwrap();
        assert_eq!(normal.mean(), 0.0);
        assert_eq!(normal.std_dev(), 1.0);
        assert_eq!(normal.variance(), 1.0);
    }

    // Non-positive sigma and non-finite parameters are rejected.
    #[test]
    fn test_normal_invalid_params() {
        assert!(Normal::new(0.0, 0.0).is_err());
        assert!(Normal::new(0.0, -1.0).is_err());
        assert!(Normal::new(f64::NAN, 1.0).is_err());
        assert!(Normal::new(0.0, f64::INFINITY).is_err());
    }

    // Peak value, symmetry, and pdf/log_pdf agreement for the standard normal.
    #[test]
    fn test_normal_pdf() {
        let normal = Normal::new(0.0, 1.0).unwrap();

        let pdf_at_mean = normal.pdf(0.0);
        assert_abs_diff_eq!(pdf_at_mean, 1.0 / (2.0 * PI).sqrt(), epsilon = 1e-10);

        assert_abs_diff_eq!(normal.pdf(1.0), normal.pdf(-1.0), epsilon = 1e-10);

        assert_abs_diff_eq!(normal.log_pdf(0.0), pdf_at_mean.ln(), epsilon = 1e-10);
    }

    // A 2-D identity-covariance normal constructs and gives a finite log-density.
    #[test]
    fn test_multivariate_normal() {
        let mu = DVector::from_vec(vec![0.0, 0.0]);
        let cov = nalgebra::DMatrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]);

        let mvn = MultivariateNormal::new(mu, cov).unwrap();
        assert_eq!(mvn.dimension(), 2);

        let x = DVector::from_vec(vec![0.0, 0.0]);
        let log_pdf = mvn.log_pdf(&x);
        assert!(log_pdf.is_finite());
    }

    // Accessors, moments, and density behaviour of Gamma(2, 1).
    #[test]
    fn test_gamma_distribution() {
        let gamma = Gamma::new(2.0, 1.0).unwrap();
        assert_eq!(gamma.shape(), 2.0);
        assert_eq!(gamma.rate(), 1.0);
        assert_eq!(gamma.mean(), 2.0);
        assert_eq!(gamma.variance(), 2.0);

        assert!(gamma.pdf(1.0) > 0.0);
        assert!(gamma.pdf(1.0).is_finite());
        assert_abs_diff_eq!(gamma.log_pdf(1.0), -1.0, epsilon = 1e-12);

        // Zero density at and below zero.
        assert_eq!(gamma.pdf(0.0), 0.0);
        assert_eq!(gamma.pdf(-1.0), 0.0);
    }

    // Non-positive shape or rate is rejected.
    #[test]
    fn test_gamma_invalid_params() {
        assert!(Gamma::new(0.0, 1.0).is_err());
        assert!(Gamma::new(1.0, 0.0).is_err());
        assert!(Gamma::new(-1.0, 1.0).is_err());
        assert!(Gamma::new(1.0, -1.0).is_err());
    }

    // Accessors, moments, and density behaviour of Beta(2, 3).
    #[test]
    fn test_beta_distribution() {
        let beta = Beta::new(2.0, 3.0).unwrap();
        assert_eq!(beta.alpha(), 2.0);
        assert_eq!(beta.beta(), 3.0);
        assert_abs_diff_eq!(beta.mean(), 2.0 / 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(beta.variance(), 6.0 / 150.0, epsilon = 1e-10);

        assert!(beta.pdf(0.5) > 0.0);
        assert!(beta.pdf(0.5).is_finite());
        assert_abs_diff_eq!(beta.pdf(0.5), 1.5, epsilon = 1e-12);

        // Zero density at the endpoints and outside the unit interval.
        assert_eq!(beta.pdf(0.0), 0.0);
        assert_eq!(beta.pdf(1.0), 0.0);
        assert_eq!(beta.pdf(-0.1), 0.0);
        assert_eq!(beta.pdf(1.1), 0.0);
    }

    // Non-positive shape parameters are rejected.
    #[test]
    fn test_beta_invalid_params() {
        assert!(Beta::new(0.0, 1.0).is_err());
        assert!(Beta::new(1.0, 0.0).is_err());
        assert!(Beta::new(-1.0, 1.0).is_err());
        assert!(Beta::new(1.0, -1.0).is_err());
    }

    // Accessors, moments, and density behaviour of Exponential(2).
    #[test]
    fn test_exponential_distribution() {
        let exp = Exponential::new(2.0).unwrap();
        assert_eq!(exp.rate(), 2.0);
        assert_abs_diff_eq!(exp.mean(), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(exp.variance(), 0.25, epsilon = 1e-10);

        assert!(exp.pdf(1.0) > 0.0);
        assert!(exp.pdf(1.0).is_finite());

        // Zero density for negative input.
        assert_eq!(exp.pdf(-1.0), 0.0);

        // Density at zero equals the rate.
        assert_eq!(exp.pdf(0.0), 2.0);
    }

    // Non-positive, NaN, and infinite rates are rejected.
    #[test]
    fn test_exponential_invalid_params() {
        assert!(Exponential::new(0.0).is_err());
        assert!(Exponential::new(-1.0).is_err());
        assert!(Exponential::new(f64::NAN).is_err());
        assert!(Exponential::new(f64::INFINITY).is_err());
    }

    // Accessors, moments, and density behaviour of Uniform(0, 1).
    #[test]
    fn test_uniform_distribution() {
        let uniform = Uniform::new(0.0, 1.0).unwrap();
        assert_eq!(uniform.lower_bound(), 0.0);
        assert_eq!(uniform.upper_bound(), 1.0);
        assert_abs_diff_eq!(uniform.mean(), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(uniform.variance(), 1.0 / 12.0, epsilon = 1e-10);

        // Constant density inside the interval, endpoints included.
        assert_abs_diff_eq!(uniform.pdf(0.5), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(uniform.pdf(0.0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(uniform.pdf(1.0), 1.0, epsilon = 1e-10);

        // Zero density outside the interval.
        assert_eq!(uniform.pdf(-0.1), 0.0);
        assert_eq!(uniform.pdf(1.1), 0.0);
    }

    // Reversed, empty, and NaN bounds are rejected.
    #[test]
    fn test_uniform_invalid_params() {
        assert!(Uniform::new(1.0, 0.0).is_err());
        assert!(Uniform::new(1.0, 1.0).is_err());
        assert!(Uniform::new(f64::NAN, 1.0).is_err());
        assert!(Uniform::new(0.0, f64::NAN).is_err());
    }

    // Accessors, moments, and symmetry of StudentT(3, 0, 1).
    #[test]
    fn test_student_t_distribution() {
        let t = StudentT::new(3.0, 0.0, 1.0).unwrap();
        assert_eq!(t.degrees_of_freedom(), 3.0);
        assert_eq!(t.location(), 0.0);
        assert_eq!(t.scale(), 1.0);
        assert_eq!(t.mean(), Some(0.0));
        assert!(t.variance().is_some());

        assert!(t.pdf(0.0) > 0.0);
        assert!(t.pdf(0.0).is_finite());

        assert_abs_diff_eq!(t.pdf(1.0), t.pdf(-1.0), epsilon = 1e-10);
    }

    // Non-positive degrees of freedom or scale are rejected.
    #[test]
    fn test_student_t_invalid_params() {
        assert!(StudentT::new(0.0, 0.0, 1.0).is_err());
        assert!(StudentT::new(1.0, 0.0, 0.0).is_err());
        assert!(StudentT::new(-1.0, 0.0, 1.0).is_err());
        assert!(StudentT::new(1.0, 0.0, -1.0).is_err());
    }

    // The mean is undefined for nu <= 1 and the variance for nu <= 2.
    #[test]
    fn test_student_t_moments() {
        let t1 = StudentT::new(0.5, 0.0, 1.0).unwrap();
        assert!(t1.mean().is_none());

        let t2 = StudentT::new(1.5, 0.0, 1.0).unwrap();
        assert!(t2.variance().is_none());
    }
}