numra_stats/distributions/
beta_dist.rs1use numra_core::Scalar;
8use numra_special::{betainc, lgamma};
9use rand::RngCore;
10
11use super::gamma_dist::GammaDist;
12use super::ContinuousDistribution;
13
14#[derive(Clone, Debug)]
16pub struct BetaDist<S: Scalar> {
17 pub alpha: S,
18 pub beta: S,
19}
20
21impl<S: Scalar> BetaDist<S> {
22 pub fn new(alpha: S, beta: S) -> Self {
23 Self { alpha, beta }
24 }
25}
26
27impl<S: Scalar> ContinuousDistribution<S> for BetaDist<S> {
28 fn pdf(&self, x: S) -> S {
29 if x < S::ZERO || x > S::ONE {
30 return S::ZERO;
31 }
32 let log_pdf = (self.alpha - S::ONE) * x.ln() + (self.beta - S::ONE) * (S::ONE - x).ln()
33 - lbeta(self.alpha, self.beta);
34 log_pdf.exp()
35 }
36
37 fn cdf(&self, x: S) -> S {
38 if x <= S::ZERO {
39 return S::ZERO;
40 }
41 if x >= S::ONE {
42 return S::ONE;
43 }
44 betainc(self.alpha, self.beta, x)
45 }
46
47 fn quantile(&self, p: S) -> S {
48 if p <= S::ZERO {
49 return S::ZERO;
50 }
51 if p >= S::ONE {
52 return S::ONE;
53 }
54 let mu = self.mean();
56 let mut x = mu;
57 let p_f64 = p.to_f64();
59 if p_f64 < 0.05 {
60 x = S::from_f64(p_f64.powf(1.0 / self.alpha.to_f64()));
61 } else if p_f64 > 0.95 {
62 x = S::ONE - S::from_f64((1.0 - p_f64).powf(1.0 / self.beta.to_f64()));
63 }
64 if x <= S::ZERO {
66 x = S::from_f64(0.01);
67 }
68 if x >= S::ONE {
69 x = S::from_f64(0.99);
70 }
71 for _ in 0..100 {
73 let f_val = self.cdf(x) - p;
74 let f_prime = self.pdf(x);
75 if f_prime.to_f64().abs() < 1e-300 {
76 break;
77 }
78 let step = f_val / f_prime;
79 let mut s = step;
81 for _ in 0..10 {
82 let xn = x - s;
83 if xn.to_f64() > 1e-10 && xn.to_f64() < 1.0 - 1e-10 {
84 break;
85 }
86 s *= S::HALF;
87 }
88 x -= s;
89 if x <= S::ZERO {
91 x = S::from_f64(1e-10);
92 }
93 if x >= S::ONE {
94 x = S::from_f64(1.0 - 1e-10);
95 }
96 if s.to_f64().abs() < 1e-12 {
97 break;
98 }
99 }
100 x
101 }
102
103 fn mean(&self) -> S {
104 self.alpha / (self.alpha + self.beta)
105 }
106
107 fn variance(&self) -> S {
108 let ab = self.alpha + self.beta;
109 self.alpha * self.beta / (ab * ab * (ab + S::ONE))
110 }
111
112 fn sample(&self, rng: &mut dyn RngCore) -> S {
113 let gx = GammaDist::new(self.alpha, S::ONE);
116 let gy = GammaDist::new(self.beta, S::ONE);
117 let x = gx.sample(rng);
118 let y = gy.sample(rng);
119 x / (x + y)
120 }
121}
122
123fn lbeta<S: Scalar>(a: S, b: S) -> S {
125 lgamma(a) + lgamma(b) - lgamma(a + b)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beta_uniform_case() {
        // Beta(1, 1) is the uniform distribution on [0, 1].
        let dist = BetaDist::new(1.0_f64, 1.0);
        let density = dist.pdf(0.5);
        assert!((density - 1.0).abs() < 1e-12);
        let prob = dist.cdf(0.5);
        assert!((prob - 0.5).abs() < 1e-8);
    }

    #[test]
    fn test_beta_symmetric() {
        // Beta(2, 2) is symmetric about 1/2, so the mean is 1/2 and pdf(t) == pdf(1 - t).
        let dist = BetaDist::new(2.0_f64, 2.0);
        assert!((dist.mean() - 0.5).abs() < 1e-14);
        let (left, right) = (dist.pdf(0.3), dist.pdf(0.7));
        assert!((left - right).abs() < 1e-12);
    }

    #[test]
    fn test_beta_quantile_roundtrip() {
        // cdf(quantile(p)) should recover p.
        let dist = BetaDist::new(2.0_f64, 5.0);
        for p in [0.1_f64, 0.5, 0.9].iter().copied() {
            let recovered = dist.cdf(dist.quantile(p));
            assert!((p - recovered).abs() < 1e-6, "p={}, p2={}", p, recovered);
        }
    }

    #[test]
    fn test_beta_mean_variance() {
        // Beta(2, 3): mean = 2 / 5 = 0.4, variance = 6 / (25 * 6) = 0.04.
        let dist = BetaDist::new(2.0_f64, 3.0);
        assert!((dist.mean() - 0.4).abs() < 1e-14);
        assert!((dist.variance() - 0.04).abs() < 1e-14);
    }
}