Skip to main content

rs_stats/distributions/
beta.rs

1//! # Beta Distribution
2//!
3//! The Beta(α, β) distribution is a continuous distribution on [0, 1], making it
4//! the natural model for proportions, rates, and probabilities.
5//!
6//! **PDF**: f(x; α, β) = x^(α−1) · (1−x)^(β−1) / B(α, β),  x ∈ [0, 1]
7//!
8//! **Fit**: method-of-moments — estimates μ̂ and σ̂² from data, then solves for α̂, β̂.
9//!
10//! **Mean**: α / (α + β)   **Variance**: αβ / [(α+β)²(α+β+1)]
11//!
12//! ## When to use
13//!
14//! Use Beta whenever your outcome is a **proportion** constrained to (0, 1):
15//! it can be symmetric (α=β), right-skewed (α<β), left-skewed (α>β), or U-shaped (α,β<1).
16//!
17//! ## Medical applications
18//!
19//! | Proportion | Description |
20//! |------------|-------------|
21//! | **Sensitivity / Specificity** | Diagnostic test performance across studies (meta-analysis) |
22//! | **Time-in-therapeutic range (TTR)** | Anticoagulation quality (warfarin, DOAC) |
23//! | **Medication adherence rate** | Fraction of prescribed doses taken |
24//! | **Tumour response rate** | Proportion of patients achieving response |
25//! | **Prevalence** | Bayesian prior / posterior for disease frequency |
26//!
27//! ## Example — warfarin time-in-therapeutic range (TTR)
28//!
29//! ```rust
30//! use rs_stats::distributions::beta::Beta;
31//! use rs_stats::distributions::traits::Distribution;
32//!
33//! // TTR values (0–1) for anticoagulated patients
34//! // TTR ≥ 0.70 is the recommended target for well-controlled anticoagulation
35//! let ttr = vec![
36//!     0.72, 0.65, 0.88, 0.55, 0.91, 0.78, 0.62, 0.84,
37//!     0.70, 0.58, 0.79, 0.93, 0.67, 0.75, 0.48, 0.82,
38//! ];
39//! let b = Beta::fit(&ttr).unwrap();
40//! println!("Beta(α={:.2}, β={:.2})", b.alpha, b.beta);
41//!
42//! let p_controlled = 1.0 - b.cdf(0.70).unwrap();
43//! println!("P(TTR ≥ 70%) = {:.1}%", p_controlled * 100.0);
44//!
45//! let median_ttr = b.inverse_cdf(0.5).unwrap();
46//! println!("Median TTR   = {:.1}%", median_ttr * 100.0);
47//! ```
48
49use crate::distributions::traits::Distribution;
50use crate::error::{StatsError, StatsResult};
51use crate::utils::special_functions::{bisect_inverse_cdf, ln_beta, regularized_incomplete_beta};
52
/// Beta distribution Beta(α, β).
///
/// The fields are public, so a value can be built directly with struct
/// syntax; doing so bypasses the positivity check performed by [`Beta::new`].
///
/// Two `Beta` values compare equal when both shape parameters are equal under
/// `f64` comparison (so a NaN parameter never compares equal to anything).
///
/// # Examples
/// ```
/// use rs_stats::distributions::beta::Beta;
/// use rs_stats::distributions::traits::Distribution;
///
/// let b = Beta::new(2.0, 5.0).unwrap();
/// assert!((b.mean() - 2.0 / 7.0).abs() < 1e-10);
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Beta {
    /// Shape parameter α > 0
    pub alpha: f64,
    /// Shape parameter β > 0
    pub beta: f64,
}
70
71impl Beta {
72    /// Creates a `Beta(α, β)` distribution. Both parameters must be positive.
73    pub fn new(alpha: f64, beta: f64) -> StatsResult<Self> {
74        if alpha <= 0.0 || beta <= 0.0 {
75            return Err(StatsError::InvalidInput {
76                message: "Beta::new: alpha and beta must be positive".to_string(),
77            });
78        }
79        Ok(Self { alpha, beta })
80    }
81
82    /// MLE via method of moments from data in [0, 1].
83    ///
84    /// Requires all data in (0, 1). Estimates α and β from sample mean and variance.
85    pub fn fit(data: &[f64]) -> StatsResult<Self> {
86        if data.is_empty() {
87            return Err(StatsError::InvalidInput {
88                message: "Beta::fit: data must not be empty".to_string(),
89            });
90        }
91        if data.iter().any(|&x| x <= 0.0 || x >= 1.0) {
92            return Err(StatsError::InvalidInput {
93                message: "Beta::fit: all data values must be in (0, 1)".to_string(),
94            });
95        }
96        let n = data.len() as f64;
97        let mean = data.iter().sum::<f64>() / n;
98        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
99
100        // Method of moments: α = mean*(mean*(1-mean)/var - 1), β = (1-mean)*...
101        let common = mean * (1.0 - mean) / variance - 1.0;
102        let alpha = mean * common;
103        let beta = (1.0 - mean) * common;
104        Self::new(alpha, beta)
105    }
106}
107
108impl Distribution for Beta {
109    fn name(&self) -> &str {
110        "Beta"
111    }
112    fn num_params(&self) -> usize {
113        2
114    }
115
116    fn pdf(&self, x: f64) -> StatsResult<f64> {
117        if !(0.0..=1.0).contains(&x) {
118            return Ok(0.0);
119        }
120        if x == 0.0 {
121            return Ok(if self.alpha >= 1.0 {
122                0.0
123            } else {
124                f64::INFINITY
125            });
126        }
127        if x == 1.0 {
128            return Ok(if self.beta >= 1.0 { 0.0 } else { f64::INFINITY });
129        }
130        let log_pdf = (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln()
131            - ln_beta(self.alpha, self.beta);
132        Ok(log_pdf.exp())
133    }
134
135    fn logpdf(&self, x: f64) -> StatsResult<f64> {
136        if x <= 0.0 || x >= 1.0 {
137            return Ok(f64::NEG_INFINITY);
138        }
139        Ok(
140            (self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln()
141                - ln_beta(self.alpha, self.beta),
142        )
143    }
144
145    fn cdf(&self, x: f64) -> StatsResult<f64> {
146        if x <= 0.0 {
147            return Ok(0.0);
148        }
149        if x >= 1.0 {
150            return Ok(1.0);
151        }
152        Ok(regularized_incomplete_beta(self.alpha, self.beta, x))
153    }
154
155    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
156        if !(0.0..=1.0).contains(&p) {
157            return Err(StatsError::InvalidInput {
158                message: "Beta::inverse_cdf: p must be in [0, 1]".to_string(),
159            });
160        }
161        if p == 0.0 {
162            return Ok(0.0);
163        }
164        if p == 1.0 {
165            return Ok(1.0);
166        }
167        let alpha = self.alpha;
168        let beta = self.beta;
169        Ok(bisect_inverse_cdf(
170            |x| regularized_incomplete_beta(alpha, beta, x),
171            p,
172            0.0,
173            1.0,
174        ))
175    }
176
177    fn mean(&self) -> f64 {
178        self.alpha / (self.alpha + self.beta)
179    }
180
181    fn variance(&self) -> f64 {
182        let s = self.alpha + self.beta;
183        self.alpha * self.beta / (s * s * (s + 1.0))
184    }
185}
186
187// ── Log-likelihood with analytically stable logpdf ────────────────────────────
188// (default impl already calls logpdf, which is overridden above)
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing against closed-form values.
    const TOL: f64 = 1e-10;

    /// Mean and variance of Beta(2, 5) match their closed-form expressions.
    #[test]
    fn test_beta_mean_variance() {
        let dist = Beta::new(2.0, 5.0).unwrap();

        let mean_err = (dist.mean() - 2.0 / 7.0).abs();
        assert!(mean_err < TOL);

        // αβ / [(α+β)²(α+β+1)] with α = 2, β = 5
        let var_err = (dist.variance() - 2.0 * 5.0 / (49.0 * 8.0)).abs();
        assert!(var_err < TOL);
    }

    /// Beta(1, 1) is the uniform distribution, so its density at 0.5 is 1.
    #[test]
    fn test_beta_pdf_at_mean() {
        let uniform = Beta::new(1.0, 1.0).unwrap();
        let density = uniform.pdf(0.5).unwrap();
        assert!((density - 1.0).abs() < TOL);
    }

    /// The CDF is exactly 0 at the lower bound and exactly 1 at the upper bound.
    #[test]
    fn test_beta_cdf_bounds() {
        let dist = Beta::new(2.0, 3.0).unwrap();
        for (x, expected) in [(0.0, 0.0), (1.0, 1.0)] {
            assert_eq!(dist.cdf(x).unwrap(), expected);
        }
    }

    /// `cdf(inverse_cdf(p))` recovers `p` for several interior probabilities.
    #[test]
    fn test_beta_inverse_cdf_roundtrip() {
        let dist = Beta::new(2.0, 3.0).unwrap();
        [0.1, 0.25, 0.5, 0.75, 0.9].into_iter().for_each(|p| {
            let quantile = dist.inverse_cdf(p).unwrap();
            let p_back = dist.cdf(quantile).unwrap();
            assert!((p - p_back).abs() < 1e-7, "p={p}: roundtrip failed");
        });
    }

    /// The fitted distribution reproduces the sample mean.
    #[test]
    fn test_beta_fit() {
        let data = vec![0.1, 0.2, 0.15, 0.25, 0.3, 0.18, 0.22, 0.12, 0.28, 0.16];
        let fitted = Beta::fit(&data).unwrap();

        let sample_mean = data.iter().sum::<f64>() / data.len() as f64;
        assert!((fitted.mean() - sample_mean).abs() < TOL);
    }

    /// Zero or negative shape parameters are rejected by the constructor.
    #[test]
    fn test_beta_invalid_params() {
        for (alpha, beta) in [(0.0, 1.0), (1.0, 0.0), (-1.0, 1.0)] {
            assert!(Beta::new(alpha, beta).is_err());
        }
    }
}