rs_stats/distributions/
beta.rs1use crate::distributions::traits::Distribution;
50use crate::error::{StatsError, StatsResult};
51use crate::utils::special_functions::{bisect_inverse_cdf, ln_beta, regularized_incomplete_beta};
52
/// Beta distribution on the unit interval, parameterised by two shape values.
///
/// The mean of the distribution is `alpha / (alpha + beta)`.
///
/// Prefer [`Beta::new`] for construction: it rejects non-positive parameters.
/// Both fields are public, so building the struct literally bypasses that check.
#[derive(Clone, Copy, Debug)]
pub struct Beta {
    /// First shape parameter (the numerator of the mean).
    pub alpha: f64,
    /// Second shape parameter.
    pub beta: f64,
}
70
71impl Beta {
72 pub fn new(alpha: f64, beta: f64) -> StatsResult<Self> {
74 if alpha <= 0.0 || beta <= 0.0 {
75 return Err(StatsError::InvalidInput {
76 message: "Beta::new: alpha and beta must be positive".to_string(),
77 });
78 }
79 Ok(Self { alpha, beta })
80 }
81
82 pub fn fit(data: &[f64]) -> StatsResult<Self> {
86 if data.is_empty() {
87 return Err(StatsError::InvalidInput {
88 message: "Beta::fit: data must not be empty".to_string(),
89 });
90 }
91 if data.iter().any(|&x| x <= 0.0 || x >= 1.0) {
92 return Err(StatsError::InvalidInput {
93 message: "Beta::fit: all data values must be in (0, 1)".to_string(),
94 });
95 }
96 let n = data.len() as f64;
97 let mean = data.iter().sum::<f64>() / n;
98 let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
99
100 let common = mean * (1.0 - mean) / variance - 1.0;
102 let alpha = mean * common;
103 let beta = (1.0 - mean) * common;
104 Self::new(alpha, beta)
105 }
106}
107
108impl Distribution for Beta {
109 fn name(&self) -> &str {
110 "Beta"
111 }
112 fn num_params(&self) -> usize {
113 2
114 }
115
116 fn pdf(&self, x: f64) -> StatsResult<f64> {
117 if !(0.0..=1.0).contains(&x) {
118 return Ok(0.0);
119 }
120 if x == 0.0 {
121 return Ok(if self.alpha >= 1.0 {
122 0.0
123 } else {
124 f64::INFINITY
125 });
126 }
127 if x == 1.0 {
128 return Ok(if self.beta >= 1.0 { 0.0 } else { f64::INFINITY });
129 }
130 let log_pdf = (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln()
131 - ln_beta(self.alpha, self.beta);
132 Ok(log_pdf.exp())
133 }
134
135 fn logpdf(&self, x: f64) -> StatsResult<f64> {
136 if x <= 0.0 || x >= 1.0 {
137 return Ok(f64::NEG_INFINITY);
138 }
139 Ok(
140 (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln()
141 - ln_beta(self.alpha, self.beta),
142 )
143 }
144
145 fn cdf(&self, x: f64) -> StatsResult<f64> {
146 if x <= 0.0 {
147 return Ok(0.0);
148 }
149 if x >= 1.0 {
150 return Ok(1.0);
151 }
152 Ok(regularized_incomplete_beta(self.alpha, self.beta, x))
153 }
154
155 fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
156 if !(0.0..=1.0).contains(&p) {
157 return Err(StatsError::InvalidInput {
158 message: "Beta::inverse_cdf: p must be in [0, 1]".to_string(),
159 });
160 }
161 if p == 0.0 {
162 return Ok(0.0);
163 }
164 if p == 1.0 {
165 return Ok(1.0);
166 }
167 let alpha = self.alpha;
168 let beta = self.beta;
169 Ok(bisect_inverse_cdf(
170 |x| regularized_incomplete_beta(alpha, beta, x),
171 p,
172 0.0,
173 1.0,
174 ))
175 }
176
177 fn mean(&self) -> f64 {
178 self.alpha / (self.alpha + self.beta)
179 }
180
181 fn variance(&self) -> f64 {
182 let s = self.alpha + self.beta;
183 self.alpha * self.beta / (s * s * (s + 1.0))
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for the closed-form comparisons below.
    const TOL: f64 = 1e-10;

    /// Mean and variance of Beta(2, 5) match their closed-form values.
    #[test]
    fn test_beta_mean_variance() {
        let dist = Beta::new(2.0, 5.0).unwrap();

        let mean_error = (dist.mean() - 2.0 / 7.0).abs();
        assert!(mean_error < TOL);

        // alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1)) with (2, 5).
        let expected_var = 2.0 * 5.0 / (49.0 * 8.0);
        let var_error = (dist.variance() - expected_var).abs();
        assert!(var_error < TOL);
    }

    /// Beta(1, 1) has density 1 at the midpoint of the unit interval.
    #[test]
    fn test_beta_pdf_at_mean() {
        let dist = Beta::new(1.0, 1.0).unwrap();
        let density = dist.pdf(0.5).unwrap();
        assert!((density - 1.0).abs() < TOL);
    }

    /// The CDF is exactly 0 at the lower bound and exactly 1 at the upper bound.
    #[test]
    fn test_beta_cdf_bounds() {
        let dist = Beta::new(2.0, 3.0).unwrap();
        let lower = dist.cdf(0.0).unwrap();
        let upper = dist.cdf(1.0).unwrap();
        assert_eq!(lower, 0.0);
        assert_eq!(upper, 1.0);
    }

    /// `cdf(inverse_cdf(p))` recovers `p` for a handful of probabilities.
    #[test]
    fn test_beta_inverse_cdf_roundtrip() {
        let dist = Beta::new(2.0, 3.0).unwrap();
        let probabilities = [0.1, 0.25, 0.5, 0.75, 0.9];
        for p in probabilities {
            let quantile = dist.inverse_cdf(p).unwrap();
            let recovered = dist.cdf(quantile).unwrap();
            let gap = (p - recovered).abs();
            assert!(gap < 1e-7, "p={p}: roundtrip failed");
        }
    }

    /// The fitted distribution reproduces the sample mean.
    #[test]
    fn test_beta_fit() {
        let data = vec![0.1, 0.2, 0.15, 0.25, 0.3, 0.18, 0.22, 0.12, 0.28, 0.16];
        let fitted = Beta::fit(&data).unwrap();
        let sample_mean = data.iter().sum::<f64>() / data.len() as f64;
        assert!((fitted.mean() - sample_mean).abs() < TOL);
    }

    /// Zero or negative shape parameters are rejected by the constructor.
    #[test]
    fn test_beta_invalid_params() {
        let rejected = [(0.0, 1.0), (1.0, 0.0), (-1.0, 1.0)];
        for (alpha, beta) in rejected {
            assert!(Beta::new(alpha, beta).is_err());
        }
    }
}