Skip to main content

rs_stats/distributions/
traits.rs

//! # Distribution Traits
//!
//! Defines the [`Distribution`] and [`DiscreteDistribution`] traits that provide
//! a unified interface for all statistical distributions in this crate.
//!
//! ## Usage
//!
//! ```
//! use rs_stats::distributions::traits::Distribution;
//! use rs_stats::distributions::normal_distribution::Normal;
//!
//! let n = Normal::new(0.0, 1.0).unwrap();
//! // The standard normal density at 0 is 1/√(2π) ≈ 0.398 942 280 4.
//! let pdf = n.pdf(0.0).unwrap();
//! assert!((pdf - 0.398_942_280_4).abs() < 1e-8);
//! ```
16
17use crate::error::{StatsError, StatsResult};
18
19// ── Continuous distributions ───────────────────────────────────────────────────
20
21/// Unified interface for continuous probability distributions.
22///
23/// All methods return `StatsResult` to propagate domain errors (e.g. `p ∉ [0,1]`).
24///
25/// The trait is **object-safe**: `Box<dyn Distribution>` works at runtime.
26/// The `fit` associated function is intentionally *not* part of the trait to preserve
27/// object safety; each concrete type exposes `Dist::fit(data)` directly.
28pub trait Distribution {
29    /// Human-readable distribution name, e.g. `"Normal"`.
30    fn name(&self) -> &str;
31
32    /// Number of free parameters (used when computing AIC / BIC).
33    fn num_params(&self) -> usize;
34
35    /// Probability density function f(x).
36    fn pdf(&self, x: f64) -> StatsResult<f64>;
37
38    /// Natural logarithm of the PDF: ln f(x).
39    ///
40    /// Default implementation delegates to `pdf`; override for numerical stability.
41    fn logpdf(&self, x: f64) -> StatsResult<f64> {
42        self.pdf(x).map(|p| p.ln())
43    }
44
45    /// Cumulative distribution function F(x) = P(X ≤ x).
46    fn cdf(&self, x: f64) -> StatsResult<f64>;
47
48    /// Quantile (inverse CDF): find x such that F(x) = p.
49    fn inverse_cdf(&self, p: f64) -> StatsResult<f64>;
50
51    /// Mean (expected value) μ.
52    fn mean(&self) -> f64;
53
54    /// Variance σ².
55    fn variance(&self) -> f64;
56
57    /// Standard deviation σ = √(variance).
58    fn std_dev(&self) -> f64 {
59        self.variance().sqrt()
60    }
61
62    /// Sum of log-likelihoods: Σ ln f(xᵢ).
63    fn log_likelihood(&self, data: &[f64]) -> StatsResult<f64> {
64        let mut ll = 0.0_f64;
65        for &x in data {
66            ll += self.logpdf(x)?;
67        }
68        Ok(ll)
69    }
70
71    /// Akaike Information Criterion: AIC = 2k − 2·ln(L̂).
72    fn aic(&self, data: &[f64]) -> StatsResult<f64> {
73        let ll = self.log_likelihood(data)?;
74        Ok(2.0 * self.num_params() as f64 - 2.0 * ll)
75    }
76
77    /// Bayesian Information Criterion: BIC = k·ln(n) − 2·ln(L̂).
78    fn bic(&self, data: &[f64]) -> StatsResult<f64> {
79        let ll = self.log_likelihood(data)?;
80        let n = data.len() as f64;
81        Ok(self.num_params() as f64 * n.ln() - 2.0 * ll)
82    }
83}
84
85// ── Discrete distributions ─────────────────────────────────────────────────────
86
87/// Unified interface for discrete probability distributions.
88///
89/// Works with non-negative integer observations represented as `u64`.
90///
91/// Object-safe: `Box<dyn DiscreteDistribution>` is valid.
92pub trait DiscreteDistribution {
93    /// Human-readable distribution name.
94    fn name(&self) -> &str;
95
96    /// Number of free parameters (used for AIC / BIC).
97    fn num_params(&self) -> usize;
98
99    /// Probability mass function P(X = k).
100    fn pmf(&self, k: u64) -> StatsResult<f64>;
101
102    /// Natural logarithm of the PMF: ln P(X = k).
103    ///
104    /// Default: delegates to `pmf`; override for stability when p is tiny.
105    fn logpmf(&self, k: u64) -> StatsResult<f64> {
106        self.pmf(k).map(|p| p.ln())
107    }
108
109    /// Cumulative distribution function P(X ≤ k).
110    fn cdf(&self, k: u64) -> StatsResult<f64>;
111
112    /// Quantile function: smallest k ≥ 0 such that CDF(k) ≥ p.
113    ///
114    /// Returns an error if `p ∉ [0, 1]`.
115    ///
116    /// The default implementation performs an **exponential search** followed by
117    /// **binary search** on the CDF, which is correct for any monotone CDF but may
118    /// be slow for distributions with very large quantiles.
119    /// Override with a closed-form formula when available.
120    ///
121    /// # Examples
122    /// ```
123    /// use rs_stats::distributions::poisson_distribution::Poisson;
124    /// use rs_stats::DiscreteDistribution;
125    ///
126    /// let p = Poisson::new(3.0).unwrap();
127    /// // Median of Poisson(3) should be 3
128    /// let median = p.inverse_cdf(0.5).unwrap();
129    /// assert!(median == 2 || median == 3);
130    /// ```
131    fn inverse_cdf(&self, p: f64) -> StatsResult<u64> {
132        if !(0.0..=1.0).contains(&p) {
133            return Err(StatsError::InvalidInput {
134                message: format!("inverse_cdf: p must be in [0, 1], got {p}"),
135            });
136        }
137        if p == 0.0 {
138            return Ok(0);
139        }
140        // Phase 1 — exponential search to bracket the answer.
141        let mut hi: u64 = 1;
142        while self.cdf(hi)? < p {
143            hi = hi.saturating_mul(2);
144            if hi == u64::MAX {
145                return Err(StatsError::NumericalError {
146                    message: "inverse_cdf: quantile exceeds u64::MAX".to_string(),
147                });
148            }
149        }
150        // Phase 2 — binary search in [0, hi].
151        let mut lo: u64 = 0;
152        while lo < hi {
153            let mid = lo + (hi - lo) / 2;
154            if self.cdf(mid)? < p {
155                lo = mid + 1;
156            } else {
157                hi = mid;
158            }
159        }
160        Ok(lo)
161    }
162
163    /// Mean (expected value) μ.
164    fn mean(&self) -> f64;
165
166    /// Variance σ².
167    fn variance(&self) -> f64;
168
169    /// Standard deviation σ = √(variance).
170    fn std_dev(&self) -> f64 {
171        self.variance().sqrt()
172    }
173
174    /// Sum of log-PMFs: Σ ln P(X = kᵢ).
175    fn log_likelihood(&self, data: &[u64]) -> StatsResult<f64> {
176        let mut ll = 0.0_f64;
177        for &k in data {
178            ll += self.logpmf(k)?;
179        }
180        Ok(ll)
181    }
182
183    /// AIC = 2k − 2·ln(L̂).
184    fn aic(&self, data: &[u64]) -> StatsResult<f64> {
185        let ll = self.log_likelihood(data)?;
186        Ok(2.0 * self.num_params() as f64 - 2.0 * ll)
187    }
188
189    /// BIC = k·ln(n) − 2·ln(L̂).
190    fn bic(&self, data: &[u64]) -> StatsResult<f64> {
191        let ll = self.log_likelihood(data)?;
192        let n = data.len() as f64;
193        Ok(self.num_params() as f64 * n.ln() - 2.0 * ll)
194    }
195}