Skip to main content

numra_stats/distributions/
beta_dist.rs

1//! Beta distribution.
2//!
3//! Author: Moussa Leblouba
4//! Date: 9 February 2026
5//! Modified: 2 May 2026
6
7use numra_core::Scalar;
8use numra_special::{betainc, lgamma};
9use rand::RngCore;
10
11use super::gamma_dist::GammaDist;
12use super::ContinuousDistribution;
13
14/// Beta distribution on [0, 1] with parameters alpha and beta.
15#[derive(Clone, Debug)]
16pub struct BetaDist<S: Scalar> {
17    pub alpha: S,
18    pub beta: S,
19}
20
21impl<S: Scalar> BetaDist<S> {
22    pub fn new(alpha: S, beta: S) -> Self {
23        Self { alpha, beta }
24    }
25}
26
27impl<S: Scalar> ContinuousDistribution<S> for BetaDist<S> {
28    fn pdf(&self, x: S) -> S {
29        if x < S::ZERO || x > S::ONE {
30            return S::ZERO;
31        }
32        let log_pdf = (self.alpha - S::ONE) * x.ln() + (self.beta - S::ONE) * (S::ONE - x).ln()
33            - lbeta(self.alpha, self.beta);
34        log_pdf.exp()
35    }
36
37    fn cdf(&self, x: S) -> S {
38        if x <= S::ZERO {
39            return S::ZERO;
40        }
41        if x >= S::ONE {
42            return S::ONE;
43        }
44        betainc(self.alpha, self.beta, x)
45    }
46
47    fn quantile(&self, p: S) -> S {
48        if p <= S::ZERO {
49            return S::ZERO;
50        }
51        if p >= S::ONE {
52            return S::ONE;
53        }
54        // Initial guess: use mean of the distribution, adjusted toward p
55        let mu = self.mean();
56        let mut x = mu;
57        // Refine initial guess: blend mean with p-based guess
58        let p_f64 = p.to_f64();
59        if p_f64 < 0.05 {
60            x = S::from_f64(p_f64.powf(1.0 / self.alpha.to_f64()));
61        } else if p_f64 > 0.95 {
62            x = S::ONE - S::from_f64((1.0 - p_f64).powf(1.0 / self.beta.to_f64()));
63        }
64        // Clamp initial guess
65        if x <= S::ZERO {
66            x = S::from_f64(0.01);
67        }
68        if x >= S::ONE {
69            x = S::from_f64(0.99);
70        }
71        // Newton iteration on CDF
72        for _ in 0..100 {
73            let f_val = self.cdf(x) - p;
74            let f_prime = self.pdf(x);
75            if f_prime.to_f64().abs() < 1e-300 {
76                break;
77            }
78            let step = f_val / f_prime;
79            // Halve step if it would move outside (0, 1)
80            let mut s = step;
81            for _ in 0..10 {
82                let xn = x - s;
83                if xn.to_f64() > 1e-10 && xn.to_f64() < 1.0 - 1e-10 {
84                    break;
85                }
86                s *= S::HALF;
87            }
88            x -= s;
89            // Hard clamp
90            if x <= S::ZERO {
91                x = S::from_f64(1e-10);
92            }
93            if x >= S::ONE {
94                x = S::from_f64(1.0 - 1e-10);
95            }
96            if s.to_f64().abs() < 1e-12 {
97                break;
98            }
99        }
100        x
101    }
102
103    fn mean(&self) -> S {
104        self.alpha / (self.alpha + self.beta)
105    }
106
107    fn variance(&self) -> S {
108        let ab = self.alpha + self.beta;
109        self.alpha * self.beta / (ab * ab * (ab + S::ONE))
110    }
111
112    fn sample(&self, rng: &mut dyn RngCore) -> S {
113        // Sample via gamma: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1),
114        // then X/(X+Y) ~ Beta(alpha, beta)
115        let gx = GammaDist::new(self.alpha, S::ONE);
116        let gy = GammaDist::new(self.beta, S::ONE);
117        let x = gx.sample(rng);
118        let y = gy.sample(rng);
119        x / (x + y)
120    }
121}
122
123/// Log of the Beta function: ln(B(a, b)) = ln(Gamma(a)) + ln(Gamma(b)) - ln(Gamma(a+b))
124fn lbeta<S: Scalar>(a: S, b: S) -> S {
125    lgamma(a) + lgamma(b) - lgamma(a + b)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beta_uniform_case() {
        // With alpha = beta = 1 the Beta law reduces to Uniform(0, 1).
        let uniform = BetaDist::new(1.0_f64, 1.0);
        let density_err = (uniform.pdf(0.5) - 1.0).abs();
        assert!(density_err < 1e-12);
        let cdf_err = (uniform.cdf(0.5) - 0.5).abs();
        assert!(cdf_err < 1e-8);
    }

    #[test]
    fn test_beta_symmetric() {
        let dist = BetaDist::new(2.0_f64, 2.0);
        assert!((dist.mean() - 0.5).abs() < 1e-14);
        // Equal shape parameters make the density mirror-symmetric about x = 0.5.
        let (left, right) = (dist.pdf(0.3), dist.pdf(0.7));
        assert!((left - right).abs() < 1e-12);
    }

    #[test]
    fn test_beta_quantile_roundtrip() {
        let dist = BetaDist::new(2.0_f64, 5.0);
        for p in [0.1, 0.5, 0.9] {
            let p2 = dist.cdf(dist.quantile(p));
            assert!((p - p2).abs() < 1e-6, "p={}, p2={}", p, p2);
        }
    }

    #[test]
    fn test_beta_mean_variance() {
        let dist = BetaDist::new(2.0_f64, 3.0);
        let (mean, var) = (dist.mean(), dist.variance());
        assert!((mean - 0.4).abs() < 1e-14);
        // Var = 2*3 / (5*5*6) = 6/150 = 0.04
        assert!((var - 0.04).abs() < 1e-14);
    }
}