1#[cfg(feature = "serde1")]
3use serde::{Deserialize, Serialize};
4
5use rand::Rng;
6use rand_distr::Normal;
7use special::Error as _;
8use std::f64::consts::SQRT_2;
9use std::fmt;
10
11use crate::consts::{HALF_LN_2PI, HALF_LN_2PI_E};
12use crate::data::GaussianSuffStat;
13use crate::impl_display;
14use crate::traits::HasDensity;
15use crate::traits::{
16 Cdf, ContinuousDistr, Entropy, HasSuffStat, InverseCdf, KlDivergence,
17 Kurtosis, Mean, Median, Mode, Parameterized, QuadBounds, Sampleable,
18 Scalable, Shiftable, Skewness, Support, Variance,
19};
20
/// Gaussian (Normal) distribution, N(μ, σ), parameterized by its mean `mu`
/// and standard deviation `sigma`.
///
/// The natural log of `sigma` is stored next to the parameters so that
/// log-density and entropy computations can reuse it instead of calling
/// `ln` again.
///
/// With the `serde1` feature enabled, (de)serialization goes through
/// `GaussianParameters` (`try_from` on the way in, `into` on the way out).
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(
        rename_all = "snake_case",
        try_from = "GaussianParameters",
        into = "GaussianParameters"
    )
)]
pub struct Gaussian {
    /// Mean
    mu: f64,
    /// Standard deviation
    sigma: f64,
    /// Cached value of `sigma.ln()`
    ln_sigma: f64,
}
56
57impl PartialEq for Gaussian {
58 fn eq(&self, other: &Gaussian) -> bool {
59 self.mu == other.mu && self.sigma == other.sigma
60 }
61}
62
/// Plain-data parameters of a [`Gaussian`]: mean `mu` and standard deviation
/// `sigma`.
///
/// The fields are public and carry no validation of their own; converting
/// into a [`Gaussian`] with `TryFrom` runs the checks in `Gaussian::new`.
/// With the `serde1` feature enabled this is also the serialized form of a
/// [`Gaussian`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct GaussianParameters {
    /// Mean
    pub mu: f64,
    /// Standard deviation
    pub sigma: f64,
}
69
70impl TryFrom<GaussianParameters> for Gaussian {
71 type Error = GaussianError;
72
73 fn try_from(params: GaussianParameters) -> Result<Self, Self::Error> {
74 Gaussian::new(params.mu, params.sigma)
75 }
76}
77
78impl From<Gaussian> for GaussianParameters {
79 fn from(gauss: Gaussian) -> Self {
80 GaussianParameters {
81 mu: gauss.mu,
82 sigma: gauss.sigma,
83 }
84 }
85}
86
87impl Parameterized for Gaussian {
88 type Parameters = GaussianParameters;
89
90 fn emit_params(&self) -> Self::Parameters {
91 Self::Parameters {
92 mu: self.mu(),
93 sigma: self.sigma(),
94 }
95 }
96
97 fn from_params(params: Self::Parameters) -> Self {
98 Self::new_unchecked(params.mu, params.sigma)
99 }
100}
101
/// Reasons a `mu` / `sigma` pair is rejected when building or updating a
/// `Gaussian`. Each variant carries the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde1",
    derive(Serialize, Deserialize),
    serde(rename_all = "snake_case")
)]
pub enum GaussianError {
    /// The `mu` parameter is infinite or NaN
    MuNotFinite { mu: f64 },
    /// The `sigma` parameter is less than or equal to zero
    SigmaTooLow { sigma: f64 },
    /// The `sigma` parameter is infinite or NaN
    SigmaNotFinite { sigma: f64 },
}
113
114impl Gaussian {
115 pub fn new(mu: f64, sigma: f64) -> Result<Self, GaussianError> {
121 if !mu.is_finite() {
122 Err(GaussianError::MuNotFinite { mu })
123 } else if sigma <= 0.0 {
124 Err(GaussianError::SigmaTooLow { sigma })
125 } else if !sigma.is_finite() {
126 Err(GaussianError::SigmaNotFinite { sigma })
127 } else {
128 Ok(Gaussian {
129 mu,
130 sigma,
131 ln_sigma: sigma.ln(),
132 })
133 }
134 }
135
136 #[inline]
139 #[must_use]
140 pub fn new_unchecked(mu: f64, sigma: f64) -> Self {
141 Gaussian {
142 mu,
143 sigma,
144 ln_sigma: sigma.ln(),
145 }
146 }
147
148 #[inline]
159 #[must_use]
160 pub fn standard() -> Self {
161 Gaussian {
162 mu: 0.0,
163 sigma: 1.0,
164 ln_sigma: 0.0,
165 }
166 }
167
168 #[inline]
179 #[must_use]
180 pub fn mu(&self) -> f64 {
181 self.mu
182 }
183
184 #[inline]
208 pub fn set_mu(&mut self, mu: f64) -> Result<(), GaussianError> {
209 if mu.is_finite() {
210 self.set_mu_unchecked(mu);
211 Ok(())
212 } else {
213 Err(GaussianError::MuNotFinite { mu })
214 }
215 }
216
217 #[inline]
219 pub fn set_mu_unchecked(&mut self, mu: f64) {
220 self.mu = mu;
221 }
222
223 #[inline]
234 #[must_use]
235 pub fn sigma(&self) -> f64 {
236 self.sigma
237 }
238
239 #[inline]
265 pub fn set_sigma(&mut self, sigma: f64) -> Result<(), GaussianError> {
266 if sigma <= 0.0 {
267 Err(GaussianError::SigmaTooLow { sigma })
268 } else if !sigma.is_finite() {
269 Err(GaussianError::SigmaNotFinite { sigma })
270 } else {
271 self.set_sigma_unchecked(sigma);
272 Ok(())
273 }
274 }
275
276 #[inline]
278 pub fn set_sigma_unchecked(&mut self, sigma: f64) {
279 self.sigma = sigma;
280 self.ln_sigma = sigma.ln();
281 }
282}
283
284impl Default for Gaussian {
285 fn default() -> Self {
286 Gaussian::standard()
287 }
288}
289
290impl From<&Gaussian> for String {
291 fn from(gauss: &Gaussian) -> String {
292 format!("N(μ: {}, σ: {})", gauss.mu, gauss.sigma)
293 }
294}
295
// Generate the `Display` impl for `Gaussian` with the crate-level macro.
// NOTE(review): presumably this formats through the
// `From<&Gaussian> for String` conversion defined just above — confirm
// against the definition of `crate::impl_display`.
impl_display!(Gaussian);
297
298impl Shiftable for Gaussian {
299 type Output = Self;
300 type Error = GaussianError;
301
302 fn shifted(self, shift: f64) -> Result<Self::Output, Self::Error>
303 where
304 Self: Sized,
305 {
306 Self::new(self.mu() + shift, self.sigma())
307 }
308
309 fn shifted_unchecked(self, shift: f64) -> Self::Output
310 where
311 Self: Sized,
312 {
313 Self::new_unchecked(self.mu() + shift, self.sigma())
314 }
315}
316
317impl Scalable for Gaussian {
318 type Output = Self;
319 type Error = GaussianError;
320
321 fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error>
322 where
323 Self: Sized,
324 {
325 Self::new(self.mu() * scale, self.sigma() * scale)
326 }
327
328 fn scaled_unchecked(self, scale: f64) -> Self::Output
329 where
330 Self: Sized,
331 {
332 Self::new_unchecked(self.mu() * scale, self.sigma() * scale)
333 }
334}
335
/// Implements the traits of `Gaussian` that depend on the data type for the
/// floating point type `$kind`.
///
/// All arithmetic is carried out in `f64`: inputs are widened with
/// `f64::from` and outputs are narrowed with `as $kind`.
macro_rules! impl_traits {
    ($kind:ty) => {
        impl HasDensity<$kind> for Gaussian {
            /// Log density at `x`: `-z²/2 - ln σ - HALF_LN_2PI` with
            /// `z = (x - μ) / σ`.
            fn ln_f(&self, x: &$kind) -> f64 {
                let z = (f64::from(*x) - self.mu) / self.sigma;
                let half_z = 0.5 * z;
                half_z.mul_add(-z, -self.ln_sigma) - HALF_LN_2PI
            }
        }

        impl Sampleable<$kind> for Gaussian {
            /// Draws a single value.
            ///
            /// Panics if `rand_distr::Normal::new` rejects the current
            /// parameters.
            fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
                let normal = Normal::new(self.mu, self.sigma).unwrap();
                let x: f64 = rng.sample(normal);
                x as $kind
            }

            /// Draws `n` values, building the sampler only once.
            ///
            /// Panics if `rand_distr::Normal::new` rejects the current
            /// parameters.
            fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<$kind> {
                let normal = Normal::new(self.mu, self.sigma).unwrap();
                let mut draws = Vec::with_capacity(n);
                for _ in 0..n {
                    draws.push(rng.sample(normal) as $kind);
                }
                draws
            }
        }

        impl ContinuousDistr<$kind> for Gaussian {}

        impl Support<$kind> for Gaussian {
            /// Every finite value is supported; NaN and ±∞ are not.
            fn supports(&self, x: &$kind) -> bool {
                x.is_finite()
            }
        }

        impl Cdf<$kind> for Gaussian {
            /// `(1 + erf((x - μ) / (σ√2))) / 2`
            fn cdf(&self, x: &$kind) -> f64 {
                let centered = f64::from(*x) - self.mu;
                let erf = (centered / (self.sigma * SQRT_2)).error();
                0.5 * (1.0 + erf)
            }
        }

        impl InverseCdf<$kind> for Gaussian {
            /// `μ + σ√2 · erf⁻¹(2p - 1)`
            ///
            /// Panics if `p` is outside `[0, 1]` (NaN included).
            fn invcdf(&self, p: f64) -> $kind {
                assert!((0.0..=1.0).contains(&p), "P out of range");

                let erf_inv = 2.0_f64.mul_add(p, -1.0).inv_error();
                let x = (self.sigma * SQRT_2).mul_add(erf_inv, self.mu);
                x as $kind
            }
        }

        impl Mean<$kind> for Gaussian {
            /// Always `Some(mu)`.
            fn mean(&self) -> Option<$kind> {
                let mu = self.mu as $kind;
                Some(mu)
            }
        }

        impl Median<$kind> for Gaussian {
            /// Always `Some(mu)`.
            fn median(&self) -> Option<$kind> {
                let mu = self.mu as $kind;
                Some(mu)
            }
        }

        impl Mode<$kind> for Gaussian {
            /// Always `Some(mu)`.
            fn mode(&self) -> Option<$kind> {
                let mu = self.mu as $kind;
                Some(mu)
            }
        }

        impl HasSuffStat<$kind> for Gaussian {
            type Stat = GaussianSuffStat;

            fn empty_suffstat(&self) -> Self::Stat {
                GaussianSuffStat::new()
            }

            /// Log likelihood of the data summarized by `stat`, computed
            /// from its count, sum and sum of squares.
            fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
                let n = stat.n() as f64;
                let inv_two_var = (2.0 * self.sigma * self.sigma).recip();
                // Σ(x - μ)² expanded as Σx² - 2μΣx + nμ²
                let cross = self
                    .mu
                    .mul_add(-2.0 * stat.sum_x(), n * self.mu * self.mu);
                let sum_sq_dev = stat.sum_x_sq() + cross;
                -(n.mul_add(
                    self.ln_sigma + HALF_LN_2PI,
                    inv_two_var * sum_sq_dev,
                ))
            }
        }
    };
}
422
423impl Variance<f64> for Gaussian {
424 fn variance(&self) -> Option<f64> {
425 Some(self.sigma * self.sigma)
426 }
427}
428
429impl Entropy for Gaussian {
430 fn entropy(&self) -> f64 {
431 HALF_LN_2PI_E + self.ln_sigma
432 }
433}
434
435impl Skewness for Gaussian {
436 fn skewness(&self) -> Option<f64> {
437 Some(0.0)
438 }
439}
440
441impl Kurtosis for Gaussian {
442 fn kurtosis(&self) -> Option<f64> {
443 Some(0.0)
444 }
445}
446
447impl KlDivergence for Gaussian {
448 #[allow(clippy::suspicious_operation_groupings)]
449 fn kl(&self, other: &Self) -> f64 {
450 let m1 = self.mu;
451 let m2 = other.mu;
452
453 let s1 = self.sigma;
454 let s2 = other.sigma;
455
456 let term1 = s2.ln() - s1.ln();
457 let term2 = s1.mul_add(s1, (m1 - m2) * (m1 - m2)) / (2.0 * s2 * s2);
458
459 term1 + term2 - 0.5
460 }
461}
462
463impl QuadBounds for Gaussian {
464 fn quad_bounds(&self) -> (f64, f64) {
465 self.interval(0.999_999_999_999)
466 }
467}
468
// Instantiate the data-type-dependent trait impls generated by
// `impl_traits!`. The `f16` impls are only compiled when the
// `experimental` feature is enabled; `f32` and `f64` are always available.
#[cfg(feature = "experimental")]
impl_traits!(f16);
impl_traits!(f32);
impl_traits!(f64);
473
474impl std::error::Error for GaussianError {}
475
476#[cfg_attr(coverage_nightly, coverage(off))]
477impl fmt::Display for GaussianError {
478 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
479 match self {
480 Self::MuNotFinite { mu } => write!(f, "non-finite mu: {mu}"),
481 Self::SigmaTooLow { sigma } => {
482 write!(f, "sigma ({sigma}) must be greater than zero")
483 }
484 Self::SigmaNotFinite { sigma } => {
485 write!(f, "non-finite sigma: {sigma}")
486 }
487 }
488 }
489}
490
/// Unit tests for `Gaussian`: construction, moments, density, CDF and
/// quantile, entropy, KL divergence, sufficient statistics and parameter
/// round-tripping.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating point comparisons below.
    const TOL: f64 = 1E-12;

    use crate::test_basic_impls;
    test_basic_impls!(f64, Gaussian);

    use crate::test_shiftable_cdf;
    use crate::test_shiftable_density;
    use crate::test_shiftable_entropy;
    use crate::test_shiftable_invcdf;
    use crate::test_shiftable_method;

    test_shiftable_method!(Gaussian::new(2.0, 4.0).unwrap(), mean);
    test_shiftable_method!(Gaussian::new(2.0, 4.0).unwrap(), median);
    test_shiftable_method!(Gaussian::new(2.0, 4.0).unwrap(), variance);
    test_shiftable_method!(Gaussian::new(2.0, 4.0).unwrap(), skewness);
    test_shiftable_method!(Gaussian::new(2.0, 4.0).unwrap(), kurtosis);
    test_shiftable_density!(Gaussian::new(2.0, 4.0).unwrap());
    test_shiftable_entropy!(Gaussian::new(2.0, 4.0).unwrap());
    test_shiftable_cdf!(Gaussian::new(2.0, 4.0).unwrap());
    test_shiftable_invcdf!(Gaussian::new(2.0, 4.0).unwrap());

    use crate::test_scalable_cdf;
    use crate::test_scalable_density;
    use crate::test_scalable_entropy;
    use crate::test_scalable_invcdf;
    use crate::test_scalable_method;

    test_scalable_method!(Gaussian::new(2.0, 4.0).unwrap(), mean);
    test_scalable_method!(Gaussian::new(2.0, 4.0).unwrap(), median);
    test_scalable_method!(Gaussian::new(2.0, 4.0).unwrap(), variance);
    test_scalable_method!(Gaussian::new(2.0, 4.0).unwrap(), skewness);
    test_scalable_method!(Gaussian::new(2.0, 4.0).unwrap(), kurtosis);
    test_scalable_density!(Gaussian::new(2.0, 4.0).unwrap());
    test_scalable_entropy!(Gaussian::new(2.0, 4.0).unwrap());
    test_scalable_cdf!(Gaussian::new(2.0, 4.0).unwrap());
    test_scalable_invcdf!(Gaussian::new(2.0, 4.0).unwrap());

    #[test]
    fn new() {
        let gauss = Gaussian::new(1.2, 3.0).unwrap();
        assert::close(gauss.mu, 1.2, TOL);
        assert::close(gauss.sigma, 3.0, TOL);
    }

    #[test]
    fn standard() {
        let gauss = Gaussian::standard();
        assert::close(gauss.mu, 0.0, TOL);
        assert::close(gauss.sigma, 1.0, TOL);
    }

    #[test]
    fn standard_gaussian_mean_should_be_zero() {
        let mean: f64 = Gaussian::standard().mean().unwrap();
        assert::close(mean, 0.0, TOL);
    }

    #[test]
    fn standard_gaussian_variance_should_be_one() {
        assert::close(Gaussian::standard().variance().unwrap(), 1.0, TOL);
    }

    #[test]
    fn mean_should_be_mu() {
        let mu = 3.4;
        let mean: f64 = Gaussian::new(mu, 0.5).unwrap().mean().unwrap();
        assert::close(mean, mu, TOL);
    }

    #[test]
    fn median_should_be_mu() {
        let mu = 3.4;
        let median: f64 = Gaussian::new(mu, 0.5).unwrap().median().unwrap();
        assert::close(median, mu, TOL);
    }

    #[test]
    fn mode_should_be_mu() {
        let mu = 3.4;
        let mode: f64 = Gaussian::new(mu, 0.5).unwrap().mode().unwrap();
        assert::close(mode, mu, TOL);
    }

    #[test]
    fn variance_should_be_sigma_squared() {
        let sigma = 0.5;
        let gauss = Gaussian::new(3.4, sigma).unwrap();
        assert::close(gauss.variance().unwrap(), sigma * sigma, TOL);
    }

    #[test]
    fn draws_should_be_finite() {
        let mut rng = rand::rng();
        let gauss = Gaussian::standard();
        for _ in 0..100 {
            let x: f64 = gauss.draw(&mut rng);
            assert!(x.is_finite());
        }
    }

    #[test]
    fn sample_length() {
        let mut rng = rand::rng();
        let gauss = Gaussian::standard();
        let xs: Vec<f64> = gauss.sample(10, &mut rng);
        assert_eq!(xs.len(), 10);
    }

    #[test]
    fn standard_ln_pdf_at_zero() {
        let gauss = Gaussian::standard();
        assert::close(gauss.ln_pdf(&0.0_f64), -0.918_938_533_204_672_7, TOL);
    }

    #[test]
    fn standard_ln_pdf_off_zero() {
        let gauss = Gaussian::standard();
        assert::close(gauss.ln_pdf(&2.1_f64), -3.123_938_533_204_672_7, TOL);
    }

    #[test]
    fn nonstandard_ln_pdf_on_mean() {
        let gauss = Gaussian::new(-1.2, 0.33).unwrap();
        assert::close(gauss.ln_pdf(&-1.2_f64), 0.189_724_091_316_938_46, TOL);
    }

    #[test]
    fn nonstandard_ln_pdf_off_mean() {
        let gauss = Gaussian::new(-1.2, 0.33).unwrap();
        assert::close(gauss.ln_pdf(&0.0_f32), -6.421_846_156_616_945, TOL);
    }

    #[test]
    fn should_contain_finite_values() {
        let gauss = Gaussian::standard();
        assert!(gauss.supports(&0.0_f32));
        assert!(gauss.supports(&10E8_f64));
        assert!(gauss.supports(&-10E8_f64));
    }

    #[test]
    fn should_not_contain_nan() {
        let gauss = Gaussian::standard();
        assert!(!gauss.supports(&f64::NAN));
    }

    #[test]
    fn should_not_contain_positive_or_negative_infinity() {
        let gauss = Gaussian::standard();
        assert!(!gauss.supports(&f64::INFINITY));
        assert!(!gauss.supports(&f64::NEG_INFINITY));
    }

    #[test]
    fn skewness_should_be_zero() {
        let gauss = Gaussian::new(-12.3, 45.6).unwrap();
        assert::close(gauss.skewness().unwrap(), 0.0, TOL);
    }

    #[test]
    fn kurtosis_should_be_zero() {
        let gauss = Gaussian::new(-12.3, 45.6).unwrap();
        // Must exercise `kurtosis`, not `skewness`, so that this test
        // actually covers the `Kurtosis` impl.
        assert::close(gauss.kurtosis().unwrap(), 0.0, TOL);
    }

    #[test]
    fn cdf_at_mean_should_be_one_half() {
        let mu1: f64 = 2.3;
        let gauss1 = Gaussian::new(mu1, 0.2).unwrap();
        assert::close(gauss1.cdf(&mu1), 0.5, TOL);

        let mu2: f32 = -8.0;
        let gauss2 = Gaussian::new(mu2.into(), 100.0).unwrap();
        assert::close(gauss2.cdf(&mu2), 0.5, TOL);
    }

    #[test]
    fn cdf_value_at_one() {
        let gauss = Gaussian::standard();
        assert::close(gauss.cdf(&1.0_f64), 0.841_344_746_068_542_9, TOL);
    }

    #[test]
    fn cdf_value_at_neg_two() {
        let gauss = Gaussian::standard();
        assert::close(gauss.cdf(&-2.0_f64), 0.022_750_131_948_179_195, TOL);
    }

    #[test]
    fn quantile_at_one_half_should_be_mu() {
        let mu = 1.2315;
        let gauss = Gaussian::new(mu, 1.0).unwrap();
        let x: f64 = gauss.quantile(0.5);
        assert::close(x, mu, TOL);
    }

    #[test]
    fn quantile_agree_with_cdf() {
        let mut rng = rand::rng();
        let gauss = Gaussian::standard();
        let xs: Vec<f64> = gauss.sample(100, &mut rng);

        for x in &xs {
            let p = gauss.cdf(x);
            let y: f64 = gauss.quantile(p);
            assert::close(y, *x, TOL);
        }
    }

    #[test]
    fn quad_on_pdf_agrees_with_cdf_x() {
        use peroxide::numerical::integral::{
            Integral, gauss_kronrod_quadrature,
        };
        let gauss = Gaussian::new(-2.3, 0.5).unwrap();
        let pdf = |x: f64| gauss.f(&x);
        let mut rng = rand::rng();
        for _ in 0..100 {
            let x: f64 = gauss.draw(&mut rng);
            // The lower bound of -10 is more than 15 standard deviations
            // below the mean, so the truncated tail is negligible.
            let res = gauss_kronrod_quadrature(
                pdf,
                (-10.0, x),
                Integral::G7K15(1e-12, 100),
            );
            let cdf = gauss.cdf(&x);
            assert::close(res, cdf, 1e-9);
        }
    }

    #[test]
    fn standard_gaussian_entropy() {
        let gauss = Gaussian::standard();
        assert::close(gauss.entropy(), 1.418_938_533_204_672_7, TOL);
    }

    #[test]
    fn entropy() {
        let gauss = Gaussian::new(3.0, 12.3).unwrap();
        assert::close(gauss.entropy(), 3.928_537_795_583_044_7, TOL);
    }

    #[test]
    fn kl_of_identical_distributions_should_be_zero() {
        let gauss = Gaussian::new(1.2, 3.4).unwrap();
        assert::close(gauss.kl(&gauss), 0.0, TOL);
    }

    #[test]
    fn kl() {
        let g1 = Gaussian::new(1.0, 2.0).unwrap();
        let g2 = Gaussian::new(2.0, 1.0).unwrap();
        let kl = 0.5_f64.ln() + 5.0 / 2.0 - 0.5;
        assert::close(g1.kl(&g2), kl, TOL);
    }

    #[test]
    fn ln_f_after_set_mu_works() {
        let mut gauss = Gaussian::standard();
        assert::close(gauss.ln_pdf(&0.0_f64), -0.918_938_533_204_672_7, TOL);

        gauss.set_mu(1.0).unwrap();
        assert::close(gauss.ln_pdf(&1.0_f64), -0.918_938_533_204_672_7, TOL);
    }

    #[test]
    fn ln_f_after_set_sigma_works() {
        let mut gauss = Gaussian::new(-1.2, 5.0).unwrap();

        // Checks that the cached `ln_sigma` is refreshed by the setter.
        gauss.set_sigma(0.33).unwrap();
        assert::close(gauss.ln_pdf(&-1.2_f64), 0.189_724_091_316_938_46, TOL);
        assert::close(gauss.ln_pdf(&0.0_f32), -6.421_846_156_616_945, TOL);
    }

    #[test]
    fn ln_f_stat() {
        use crate::traits::SuffStat;

        let data: Vec<f64> = vec![0.1, 0.23, 1.4, 0.65, 0.22, 3.1];
        let mut stat = GaussianSuffStat::new();
        stat.observe_many(&data);

        let gauss = Gaussian::new(-0.3, 2.33).unwrap();

        let ln_f_base: f64 = data.iter().map(|x| gauss.ln_f(x)).sum();
        let ln_f_stat: f64 =
            <Gaussian as HasSuffStat<f64>>::ln_f_stat(&gauss, &stat);

        assert::close(ln_f_base, ln_f_stat, TOL);
    }

    #[cfg(feature = "serde1")]
    crate::test_serde_params!(Gaussian::new(-1.3, 2.4).unwrap(), Gaussian, f64);

    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = Gaussian::new(3.0, 5.0).unwrap();
        let dist_b = Gaussian::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }
}