// rs_stats/distributions/fitting.rs

//! # Distribution Fitting
//!
//! High-level API for automatic distribution detection and fitting.
//!
//! Given a dataset, this module:
//! 1. **Detects** whether the data is discrete or continuous (`detect_data_type`)
//! 2. **Fits** all applicable distribution candidates using MLE or MOM
//! 3. **Ranks** them by AIC (lower = better fit, penalised for complexity)
//! 4. **Validates** with a Kolmogorov-Smirnov goodness-of-fit test
//!
//! ## Key functions
//!
//! | Function | Description |
//! |----------|-------------|
//! | `auto_fit(data)` | Auto-detect type + return single best fit |
//! | `fit_all(data)` | All 10 continuous distributions, ranked by AIC |
//! | `fit_best(data)` | Best continuous distribution (lowest AIC) |
//! | `fit_all_discrete(data)` | All 4 discrete distributions, ranked by AIC |
//! | `fit_best_discrete(data)` | Best discrete distribution |
//! | `detect_data_type(data)` | `DataKind::Discrete` or `DataKind::Continuous` |
//! | `ks_test(data, cdf)` | Two-sided KS test for continuous distributions |
//! | `ks_test_discrete(data, cdf)` | KS test for discrete distributions |
//!
//! ## Medical example — identifying the best distribution for drug half-life
//!
//! ```rust
//! use rs_stats::distributions::fitting::{fit_all, auto_fit};
//!
//! // Drug half-life (hours) measured in 20 patients — typically log-normal in PK studies
//! let half_lives = vec![
//!     4.2, 6.1, 3.8, 9.5, 5.3, 7.4, 4.9, 11.2, 3.5, 6.8,
//!     8.1, 4.4, 5.7, 7.0, 3.9, 10.3, 5.1, 6.5, 4.7, 8.6,
//! ];
//!
//! // One-call: auto-detect + return single best (lowest AIC)
//! let best = auto_fit(&half_lives).unwrap();
//! println!("Best fit: {} (AIC={:.2}, KS p={:.3})", best.name, best.aic, best.ks_p_value);
//!
//! // Full ranking for model comparison and reporting
//! println!("{:<15} {:>8} {:>8} {:>10}", "Distribution", "AIC", "BIC", "KS p-value");
//! for r in fit_all(&half_lives).unwrap() {
//!     println!("{:<15} {:>8.2} {:>8.2} {:>10.4}", r.name, r.aic, r.bic, r.ks_p_value);
//! }
//! ```
//!
//! ## Medical example — discrete event counts (adverse reactions)
//!
//! ```rust
//! use rs_stats::distributions::fitting::fit_all_discrete;
//!
//! // Adverse drug reaction counts per patient over 6 months
//! let adr_counts = vec![0.0, 1.0, 0.0, 2.0, 0.0, 1.0, 3.0, 0.0, 1.0, 0.0,
//!                       2.0, 1.0, 0.0, 0.0, 1.0, 4.0, 0.0, 2.0, 1.0, 0.0];
//!
//! for r in fit_all_discrete(&adr_counts).unwrap() {
//!     println!("{:<20} AIC={:.2}  KS p={:.3}", r.name, r.aic, r.ks_p_value);
//! }
//! // Poisson usually wins when variance ≈ mean; NegativeBinomial wins if overdispersed
//! ```
60
61use crate::distributions::{
62    beta::Beta,
63    binomial_distribution::Binomial,
64    chi_squared::ChiSquared,
65    f_distribution::FDistribution,
66    gamma_distribution::Gamma,
67    geometric::Geometric,
68    lognormal::LogNormal,
69    negative_binomial::NegativeBinomial,
70    normal_distribution::Normal,
71    poisson_distribution::Poisson,
72    student_t::StudentT,
73    traits::{DiscreteDistribution, Distribution},
74    uniform_distribution::Uniform,
75    weibull::Weibull,
76};
77use crate::error::{StatsError, StatsResult};
78
// ── Data kind detection ────────────────────────────────────────────────────────

/// Classification of a dataset as discrete or continuous.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataKind {
    /// All values are non-negative integers (whole numbers ≥ 0).
    Discrete,
    /// Contains non-integer or negative values — treated as continuous.
    Continuous,
}

/// Infer whether `data` is discrete (all non-negative integers) or continuous.
///
/// A slice counts as [`DataKind::Discrete`] only when every value is a finite,
/// non-negative whole number; anything else — fractional, negative, NaN or
/// infinite values — yields [`DataKind::Continuous`]. An empty slice is
/// vacuously classified as discrete.
///
/// # Examples
/// ```
/// use rs_stats::distributions::fitting::{detect_data_type, DataKind};
///
/// assert_eq!(detect_data_type(&[0.0, 1.0, 2.0, 3.0]), DataKind::Discrete);
/// assert_eq!(detect_data_type(&[0.5, 1.5, 2.3]), DataKind::Continuous);
/// ```
pub fn detect_data_type(data: &[f64]) -> DataKind {
    // A single counter-example is enough to force the continuous branch.
    let looks_discrete = |&x: &f64| x.is_finite() && x >= 0.0 && x.fract() == 0.0;
    match data.iter().all(looks_discrete) {
        true => DataKind::Discrete,
        false => DataKind::Continuous,
    }
}
109
// ── Kolmogorov-Smirnov test ────────────────────────────────────────────────────

/// Result of a Kolmogorov-Smirnov goodness-of-fit test.
#[derive(Debug, Clone, Copy)]
pub struct KsResult {
    /// KS statistic D (maximum absolute deviation between empirical and theoretical CDF).
    pub statistic: f64,
    /// Approximate two-sided p-value.
    pub p_value: f64,
}

/// Two-sided Kolmogorov-Smirnov test of `data` against `cdf`.
///
/// `cdf` is the theoretical CDF of the hypothesised continuous distribution.
/// Uses the Kolmogorov distribution (with Stephens' finite-sample correction)
/// for the p-value approximation. An empty slice yields `D = 0`, `p = 1`.
pub fn ks_test(data: &[f64], cdf: impl Fn(f64) -> f64) -> KsResult {
    ks_result(data, cdf)
}

/// KS test for discrete distributions (rounds each observation to the nearest
/// integer before evaluating `cdf` on the integer grid).
///
/// Note: applying the continuous KS p-value to discrete data is conservative —
/// the reported p-value is a lower bound on the exact one.
pub fn ks_test_discrete(data: &[f64], cdf: impl Fn(u64) -> f64) -> KsResult {
    ks_result(data, |x| cdf(x.round() as u64))
}

/// Shared implementation for both KS tests: sorts `data`, computes the KS
/// statistic D against `cdf`, and converts it to an approximate p-value.
/// (Extracted because `ks_test` and `ks_test_discrete` were duplicates.)
fn ks_result(data: &[f64], cdf: impl Fn(f64) -> f64) -> KsResult {
    let n = data.len();
    if n == 0 {
        return KsResult {
            statistic: 0.0,
            p_value: 1.0,
        };
    }
    let mut sorted = data.to_vec();
    // NaNs (if any) compare Equal and are left wherever the sort puts them.
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let nf = n as f64;
    let mut d = 0.0_f64;
    for (i, &x) in sorted.iter().enumerate() {
        let f = cdf(x);
        let upper = (i + 1) as f64 / nf; // empirical CDF just after x
        let lower = i as f64 / nf; // empirical CDF just before x
        d = d.max((upper - f).abs()).max((f - lower).abs());
    }

    // Stephens (1970) correction: evaluate the asymptotic Kolmogorov
    // distribution at (sqrt(n) + 0.12 + 0.11/sqrt(n)) * D.
    let p_value = kolmogorov_p((nf.sqrt() + 0.12 + 0.11 / nf.sqrt()) * d);

    KsResult {
        statistic: d,
        p_value,
    }
}

/// Approximate survival function P(K > x) of the Kolmogorov distribution.
fn kolmogorov_p(x: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }
    // P(K > x) = 2 Σ_{j=1}^∞ (−1)^{j+1} exp(−2j²x²); alternating series,
    // truncated once terms drop below 1e-15 (or after 100 terms for tiny x).
    let mut sum = 0.0_f64;
    for j in 1_u32..=100 {
        let term = (-(2.0 * (j as f64).powi(2) * x * x)).exp();
        if j % 2 == 1 {
            sum += term;
        } else {
            sum -= term;
        }
        if term < 1e-15 {
            break;
        }
    }
    // Clamp: the truncated series can drift slightly outside [0, 1].
    (2.0 * sum).clamp(0.0, 1.0)
}
202
// ── Fit result ─────────────────────────────────────────────────────────────────

/// Summary of a single distribution fit, comparable across candidates.
///
/// Produced by the `fit_*` functions; rankings sort ascending on `aic`.
#[derive(Debug, Clone)]
pub struct FitResult {
    /// Distribution name (e.g. `"Normal"`, `"Gamma"`).
    pub name: String,
    /// Akaike Information Criterion (lower = better).
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better).
    pub bic: f64,
    /// KS test statistic D.
    pub ks_statistic: f64,
    /// KS test p-value (higher = better fit).
    pub ks_p_value: f64,
}
219
// ── Continuous fitting ─────────────────────────────────────────────────────────

222/// Fit all continuous distributions to `data` and return ranked results (by AIC).
223///
224/// Distributions that fail to fit (e.g. Beta when data are not in (0,1)) are silently skipped.
225pub fn fit_all(data: &[f64]) -> StatsResult<Vec<FitResult>> {
226    if data.is_empty() {
227        return Err(StatsError::InvalidInput {
228            message: "fit_all: data must not be empty".to_string(),
229        });
230    }
231
232    let mut results: Vec<FitResult> = Vec::new();
233
234    macro_rules! try_fit {
235        ($dist_type:ty, $fit_expr:expr) => {
236            if let Ok(dist) = $fit_expr {
237                if let (Ok(aic), Ok(bic)) = (dist.aic(data), dist.bic(data)) {
238                    if aic.is_finite() && bic.is_finite() {
239                        let ks = ks_test(data, |x| dist.cdf(x).unwrap_or(0.0));
240                        results.push(FitResult {
241                            name: dist.name().to_string(),
242                            aic,
243                            bic,
244                            ks_statistic: ks.statistic,
245                            ks_p_value: ks.p_value,
246                        });
247                    }
248                }
249            }
250        };
251    }
252
253    try_fit!(Normal, Normal::fit(data));
254    try_fit!(
255        Exponential,
256        crate::distributions::exponential_distribution::Exponential::fit(data)
257    );
258    try_fit!(Uniform, Uniform::fit(data));
259    try_fit!(Gamma, Gamma::fit(data));
260    try_fit!(LogNormal, LogNormal::fit(data));
261    try_fit!(Weibull, Weibull::fit(data));
262    try_fit!(Beta, Beta::fit(data));
263    try_fit!(StudentT, StudentT::fit(data));
264    try_fit!(FDistribution, FDistribution::fit(data));
265    try_fit!(ChiSquared, ChiSquared::fit(data));
266
267    if results.is_empty() {
268        return Err(StatsError::InvalidInput {
269            message: "fit_all: no distribution could be fitted to the data".to_string(),
270        });
271    }
272
273    results.sort_by(|a, b| {
274        a.aic
275            .partial_cmp(&b.aic)
276            .unwrap_or(std::cmp::Ordering::Equal)
277    });
278    Ok(results)
279}
280
281/// Fit all continuous distributions and return the best one (lowest AIC).
282pub fn fit_best(data: &[f64]) -> StatsResult<FitResult> {
283    let mut all = fit_all(data)?;
284    Ok(all.remove(0))
285}
286
// ── Discrete fitting ───────────────────────────────────────────────────────────

/// Fit all discrete distributions to integer `data` (passed as f64) and return ranked results.
///
/// Candidates: Poisson, Geometric, NegativeBinomial, Binomial — ranked by AIC
/// (ascending, best first). Skips distributions that cannot be fitted or whose
/// AIC/BIC come out non-finite.
///
/// # Errors
/// Returns `StatsError::InvalidInput` when `data` is empty or when no candidate
/// distribution could be fitted.
pub fn fit_all_discrete(data: &[f64]) -> StatsResult<Vec<FitResult>> {
    if data.is_empty() {
        return Err(StatsError::InvalidInput {
            message: "fit_all_discrete: data must not be empty".to_string(),
        });
    }

    // Convert to u64 for discrete distributions (rounds to nearest integer).
    // NOTE(review): `fit` below receives the raw f64 slice while AIC/BIC use
    // this rounded copy — confirm the discrete `fit` impls round consistently.
    let int_data: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();

    let mut results: Vec<FitResult> = Vec::new();

    // Fit one candidate; skip it silently on any failure (fit error, AIC/BIC
    // error, or non-finite criterion values).
    macro_rules! try_fit_disc {
        ($fit_expr:expr) => {
            if let Ok(dist) = $fit_expr {
                if let (Ok(aic), Ok(bic)) = (dist.aic(&int_data), dist.bic(&int_data)) {
                    if aic.is_finite() && bic.is_finite() {
                        let ks = ks_test_discrete(data, |k| dist.cdf(k).unwrap_or(0.0));
                        results.push(FitResult {
                            name: dist.name().to_string(),
                            aic,
                            bic,
                            ks_statistic: ks.statistic,
                            ks_p_value: ks.p_value,
                        });
                    }
                }
            }
        };
    }

    try_fit_disc!(Poisson::fit(data));
    try_fit_disc!(Geometric::fit(data));
    try_fit_disc!(NegativeBinomial::fit(data));
    try_fit_disc!(Binomial::fit(data));

    if results.is_empty() {
        return Err(StatsError::InvalidInput {
            message: "fit_all_discrete: no distribution could be fitted to the data".to_string(),
        });
    }

    // Ascending AIC; NaN (filtered above) would compare Equal defensively.
    results.sort_by(|a, b| {
        a.aic
            .partial_cmp(&b.aic)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    Ok(results)
}
341
342/// Fit discrete distributions and return the best (lowest AIC).
343pub fn fit_best_discrete(data: &[f64]) -> StatsResult<FitResult> {
344    let mut all = fit_all_discrete(data)?;
345    Ok(all.remove(0))
346}
347
// ── Auto-detect and fit ────────────────────────────────────────────────────────

350/// Automatically detect whether data is discrete or continuous, then fit all applicable
351/// distributions and return the best match (lowest AIC).
352///
353/// # Examples
354/// ```
355/// use rs_stats::distributions::fitting::auto_fit;
356///
357/// let data = vec![1.2, 2.3, 1.8, 2.9, 1.5];
358/// let best = auto_fit(&data).unwrap();
359/// println!("Best fit: {}", best.name);
360/// ```
361pub fn auto_fit(data: &[f64]) -> StatsResult<FitResult> {
362    match detect_data_type(data) {
363        DataKind::Discrete => fit_best_discrete(data),
364        DataKind::Continuous => fit_best(data),
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    // Non-negative whole numbers (as f64) classify as discrete.
    #[test]
    fn test_detect_data_type_discrete() {
        assert_eq!(detect_data_type(&[0.0, 1.0, 2.0, 3.0]), DataKind::Discrete);
        assert_eq!(detect_data_type(&[0.0, 0.0, 1.0]), DataKind::Discrete);
    }

    // Any fractional or negative value forces the continuous classification.
    #[test]
    fn test_detect_data_type_continuous() {
        assert_eq!(detect_data_type(&[0.5, 1.5, 2.3]), DataKind::Continuous);
        assert_eq!(detect_data_type(&[-1.0, 0.0, 1.0]), DataKind::Continuous);
        assert_eq!(detect_data_type(&[1.0, 2.5, 3.0]), DataKind::Continuous);
    }

    #[test]
    fn test_ks_test_uniform() {
        // Data from Uniform(0,1) should give large p-value against U(0,1) CDF
        let data: Vec<f64> = (0..20).map(|i| i as f64 / 20.0).collect();
        let ks = ks_test(&data, |x| x.clamp(0.0, 1.0));
        assert!(ks.statistic < 0.15);
    }

    #[test]
    fn test_fit_all_returns_results() {
        let data: Vec<f64> = (0..50).map(|i| (i as f64) * 0.1 + 0.5).collect();
        let results = fit_all(&data).unwrap();
        assert!(!results.is_empty());
        // Results sorted by AIC (ascending)
        for i in 1..results.len() {
            assert!(results[i].aic >= results[i - 1].aic);
        }
    }

    #[test]
    fn test_fit_best_normal_data() {
        // Data generated from N(5, 1)
        let data = vec![
            4.1, 5.2, 5.8, 4.7, 5.3, 4.9, 6.1, 4.5, 5.5, 5.0, 4.8, 5.1, 4.3, 5.7, 4.6, 5.4, 4.2,
            5.9, 5.2, 4.4,
        ];
        let best = fit_best(&data).unwrap();
        // Normal should win (or be competitive)
        assert!(best.aic.is_finite());
    }

    #[test]
    fn test_fit_all_discrete() {
        let data = vec![0.0, 1.0, 2.0, 3.0, 1.0, 0.0, 2.0, 1.0, 0.0, 4.0];
        let results = fit_all_discrete(&data).unwrap();
        assert!(!results.is_empty());
    }

    // auto_fit should route non-integer data through the continuous pipeline.
    #[test]
    fn test_auto_fit_continuous() {
        let data = vec![1.5, 2.3, 1.8, 2.1, 2.7, 1.9, 2.4, 2.0];
        let best = auto_fit(&data).unwrap();
        assert!(!best.name.is_empty());
    }

    // auto_fit should route whole-number data through the discrete pipeline.
    #[test]
    fn test_auto_fit_discrete() {
        let data = vec![0.0, 1.0, 2.0, 1.0, 0.0, 3.0, 1.0, 2.0];
        let best = auto_fit(&data).unwrap();
        assert!(!best.name.is_empty());
    }

    // Empty input is rejected rather than classified/fitted.
    #[test]
    fn test_fit_all_empty_data() {
        assert!(fit_all(&[]).is_err());
    }
}