// rs_stats/distributions/fitting.rs
//! # Distribution Fitting
//!
//! High-level API for automatic distribution detection and fitting.
//!
//! Given a dataset, this module:
//! 1. **Detects** whether the data is discrete or continuous (`detect_data_type`)
//! 2. **Fits** all applicable distribution candidates using MLE or MOM
//! 3. **Ranks** them by AIC (lower = better fit, penalised for complexity)
//! 4. **Validates** with a Kolmogorov-Smirnov goodness-of-fit test
//!
//! ## Key functions
//!
//! | Function | Description |
//! |----------|-------------|
//! | `auto_fit(data)` | Auto-detect type + return single best fit |
//! | `fit_all(data)` | All 10 continuous distributions, ranked by AIC |
//! | `fit_best(data)` | Best continuous distribution (lowest AIC) |
//! | `fit_all_discrete(data)` | All 4 discrete distributions, ranked by AIC |
//! | `fit_best_discrete(data)` | Best discrete distribution |
//! | `detect_data_type(data)` | `DataKind::Discrete` or `DataKind::Continuous` |
//! | `ks_test(data, cdf)` | Two-sided KS test for continuous distributions |
//! | `ks_test_discrete(data, cdf)` | KS test for discrete distributions |
//!
//! ## Medical example — identifying the best distribution for drug half-life
//!
//! ```rust
//! use rs_stats::distributions::fitting::{fit_all, auto_fit};
//!
//! // Drug half-life (hours) measured in 20 patients — typically log-normal in PK studies
//! let half_lives = vec![
//!     4.2, 6.1, 3.8, 9.5, 5.3, 7.4, 4.9, 11.2, 3.5, 6.8,
//!     8.1, 4.4, 5.7, 7.0, 3.9, 10.3, 5.1, 6.5, 4.7, 8.6,
//! ];
//!
//! // One-call: auto-detect + return single best (lowest AIC)
//! let best = auto_fit(&half_lives).unwrap();
//! println!("Best fit: {} (AIC={:.2}, KS p={:.3})", best.name, best.aic, best.ks_p_value);
//!
//! // Full ranking for model comparison and reporting
//! println!("{:<15} {:>8} {:>8} {:>10}", "Distribution", "AIC", "BIC", "KS p-value");
//! for r in fit_all(&half_lives).unwrap() {
//!     println!("{:<15} {:>8.2} {:>8.2} {:>10.4}", r.name, r.aic, r.bic, r.ks_p_value);
//! }
//! ```
//!
//! ## Medical example — discrete event counts (adverse reactions)
//!
//! ```rust
//! use rs_stats::distributions::fitting::fit_all_discrete;
//!
//! // Adverse drug reaction counts per patient over 6 months
//! let adr_counts = vec![0.0, 1.0, 0.0, 2.0, 0.0, 1.0, 3.0, 0.0, 1.0, 0.0,
//!                       2.0, 1.0, 0.0, 0.0, 1.0, 4.0, 0.0, 2.0, 1.0, 0.0];
//!
//! for r in fit_all_discrete(&adr_counts).unwrap() {
//!     println!("{:<20} AIC={:.2}  KS p={:.3}", r.name, r.aic, r.ks_p_value);
//! }
//! // Poisson usually wins when variance ≈ mean; NegativeBinomial wins if overdispersed
//! ```

61use crate::distributions::{
62    beta::Beta,
63    binomial_distribution::Binomial,
64    chi_squared::ChiSquared,
65    f_distribution::FDistribution,
66    gamma_distribution::Gamma,
67    geometric::Geometric,
68    lognormal::LogNormal,
69    negative_binomial::NegativeBinomial,
70    normal_distribution::Normal,
71    poisson_distribution::Poisson,
72    student_t::StudentT,
73    traits::{DiscreteDistribution, Distribution},
74    uniform_distribution::Uniform,
75    weibull::Weibull,
76};
77use crate::error::{StatsError, StatsResult};
78
79// ── Data kind detection ────────────────────────────────────────────────────────
80
/// Whether a dataset looks discrete or continuous.
///
/// Produced by [`detect_data_type`] and used by [`auto_fit`] to pick between
/// the discrete and continuous fitting pipelines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataKind {
    /// All values are non-negative integers (whole numbers ≥ 0).
    Discrete,
    /// Contains non-integer or negative values — treated as continuous.
    Continuous,
}
89
90/// Infer whether `data` is discrete (all non-negative integers) or continuous.
91///
92/// # Examples
93/// ```
94/// use rs_stats::distributions::fitting::{detect_data_type, DataKind};
95///
96/// assert_eq!(detect_data_type(&[0.0, 1.0, 2.0, 3.0]), DataKind::Discrete);
97/// assert_eq!(detect_data_type(&[0.5, 1.5, 2.3]), DataKind::Continuous);
98/// ```
99pub fn detect_data_type(data: &[f64]) -> DataKind {
100    if data
101        .iter()
102        .all(|&x| x >= 0.0 && x.fract() == 0.0 && x.is_finite())
103    {
104        DataKind::Discrete
105    } else {
106        DataKind::Continuous
107    }
108}
109
110// ── Kolmogorov-Smirnov test ────────────────────────────────────────────────────
111
/// Result of a Kolmogorov-Smirnov goodness-of-fit test.
///
/// Returned by [`ks_test`] and [`ks_test_discrete`].
#[derive(Debug, Clone, Copy)]
pub struct KsResult {
    /// KS statistic D (maximum absolute deviation between empirical and theoretical CDF).
    pub statistic: f64,
    /// Approximate two-sided p-value; larger values mean the sample is more
    /// consistent with the hypothesised distribution.
    pub p_value: f64,
}
120
121/// Two-sided Kolmogorov-Smirnov test of `data` against `cdf`.
122///
123/// Uses the Kolmogorov distribution for the p-value approximation.
124pub fn ks_test(data: &[f64], cdf: impl Fn(f64) -> f64) -> KsResult {
125    let n = data.len();
126    if n == 0 {
127        return KsResult {
128            statistic: 0.0,
129            p_value: 1.0,
130        };
131    }
132    let mut sorted = data.to_vec();
133    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
134
135    let nf = n as f64;
136    let mut d = 0.0_f64;
137    for (i, &x) in sorted.iter().enumerate() {
138        let f = cdf(x);
139        let upper = (i + 1) as f64 / nf;
140        let lower = i as f64 / nf;
141        d = d.max((upper - f).abs()).max((f - lower).abs());
142    }
143
144    let p_value = kolmogorov_p(((nf).sqrt() + 0.12 + 0.11 / nf.sqrt()) * d);
145
146    KsResult {
147        statistic: d,
148        p_value,
149    }
150}
151
152/// KS test for discrete distributions (uses PMF-based CDF on integer grid).
153pub fn ks_test_discrete(data: &[f64], cdf: impl Fn(u64) -> f64) -> KsResult {
154    let n = data.len();
155    if n == 0 {
156        return KsResult {
157            statistic: 0.0,
158            p_value: 1.0,
159        };
160    }
161    let mut sorted = data.to_vec();
162    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
163
164    let nf = n as f64;
165    let mut d = 0.0_f64;
166    for (i, &x) in sorted.iter().enumerate() {
167        let k = x.round() as u64;
168        let f = cdf(k);
169        let upper = (i + 1) as f64 / nf;
170        let lower = i as f64 / nf;
171        d = d.max((upper - f).abs()).max((f - lower).abs());
172    }
173
174    let p_value = kolmogorov_p(((nf).sqrt() + 0.12 + 0.11 / nf.sqrt()) * d);
175
176    KsResult {
177        statistic: d,
178        p_value,
179    }
180}
181
/// Approximate tail probability P(K > x) of the Kolmogorov distribution.
///
/// Evaluates the alternating series 2 Σ_{j≥1} (−1)^{j+1} exp(−2 j² x²),
/// truncating once terms fall below 1e-15 (at most 100 terms), and clamps
/// the result into [0, 1].
fn kolmogorov_p(x: f64) -> f64 {
    if x <= 0.0 {
        return 1.0;
    }
    let mut acc = 0.0_f64;
    for j in 1_u32..=100 {
        let jf = j as f64;
        let term = (-2.0 * jf * jf * x * x).exp();
        let sign = if j % 2 == 1 { 1.0 } else { -1.0 };
        acc += sign * term;
        if term < 1e-15 {
            break;
        }
    }
    (2.0 * acc).clamp(0.0, 1.0)
}
202
203// ── Fit result ─────────────────────────────────────────────────────────────────
204
/// Summary of a distribution fit.
///
/// AIC/BIC values are only comparable between candidates fitted to the
/// *same* dataset, which is how the `fit_all*` functions use them.
#[derive(Debug, Clone)]
pub struct FitResult {
    /// Distribution name (e.g. `"Normal"`, `"Gamma"`).
    pub name: String,
    /// Akaike Information Criterion (lower = better).
    pub aic: f64,
    /// Bayesian Information Criterion (lower = better).
    pub bic: f64,
    /// KS test statistic D.
    pub ks_statistic: f64,
    /// KS test p-value (higher = better fit).
    pub ks_p_value: f64,
}
219
220// ── Continuous fitting ─────────────────────────────────────────────────────────
221
222/// Fit all continuous distributions to `data` and return ranked results (by AIC).
223///
224/// Distributions that fail to fit (e.g. Beta when data are not in (0,1)) are silently skipped.
225pub fn fit_all(data: &[f64]) -> StatsResult<Vec<FitResult>> {
226    if data.is_empty() {
227        return Err(StatsError::InvalidInput {
228            message: "fit_all: data must not be empty".to_string(),
229        });
230    }
231
232    let mut results: Vec<FitResult> = Vec::new();
233
234    macro_rules! try_fit {
235        ($dist_type:ty, $fit_expr:expr) => {
236            if let Ok(dist) = $fit_expr {
237                if let (Ok(aic), Ok(bic)) = (dist.aic(data), dist.bic(data)) {
238                    if aic.is_finite() && bic.is_finite() {
239                        let ks = ks_test(data, |x| dist.cdf(x).unwrap_or(0.0));
240                        results.push(FitResult {
241                            name: dist.name().to_string(),
242                            aic,
243                            bic,
244                            ks_statistic: ks.statistic,
245                            ks_p_value: ks.p_value,
246                        });
247                    }
248                }
249            }
250        };
251    }
252
253    try_fit!(Normal, Normal::fit(data));
254    try_fit!(
255        Exponential,
256        crate::distributions::exponential_distribution::Exponential::fit(data)
257    );
258    try_fit!(Uniform, Uniform::fit(data));
259    try_fit!(Gamma, Gamma::fit(data));
260    try_fit!(LogNormal, LogNormal::fit(data));
261    try_fit!(Weibull, Weibull::fit(data));
262    try_fit!(Beta, Beta::fit(data));
263    try_fit!(StudentT, StudentT::fit(data));
264    try_fit!(FDistribution, FDistribution::fit(data));
265    try_fit!(ChiSquared, ChiSquared::fit(data));
266
267    if results.is_empty() {
268        return Err(StatsError::InvalidInput {
269            message: "fit_all: no distribution could be fitted to the data".to_string(),
270        });
271    }
272
273    results.sort_by(|a, b| {
274        a.aic
275            .partial_cmp(&b.aic)
276            .unwrap_or(std::cmp::Ordering::Equal)
277    });
278    Ok(results)
279}
280
281/// Fit all continuous distributions and return the best one (lowest AIC).
282pub fn fit_best(data: &[f64]) -> StatsResult<FitResult> {
283    let mut all = fit_all(data)?;
284    Ok(all.remove(0))
285}
286
287// ── Verbose fitting (with diagnostics) ────────────────────────────────────────
288
/// A distribution candidate that failed to fit, with a human-readable reason.
///
/// Returned alongside successful fits by [`fit_all_verbose`] and [`fit_all_discrete_verbose`].
/// Intended for diagnostics/reporting; the reason text is not machine-parseable.
#[derive(Debug, Clone)]
pub struct SkippedFit {
    /// Distribution name (e.g. `"Beta"`).
    pub name: &'static str,
    /// Why this distribution was not included (e.g. `"fit failed: data must be in (0,1)"`).
    pub reason: String,
}
299
300/// Like [`fit_all`] but also reports which distributions were skipped and why.
301///
302/// # Examples
303/// ```
304/// use rs_stats::distributions::fitting::fit_all_verbose;
305///
306/// // Data outside (0,1): Beta will be skipped
307/// let data = vec![2.1, 3.5, 1.8, 4.2, 2.9];
308/// let (fitted, skipped) = fit_all_verbose(&data).unwrap();
309/// println!("{} distributions fitted, {} skipped", fitted.len(), skipped.len());
310/// for s in &skipped {
311///     println!("  Skipped {}: {}", s.name, s.reason);
312/// }
313/// ```
314pub fn fit_all_verbose(data: &[f64]) -> StatsResult<(Vec<FitResult>, Vec<SkippedFit>)> {
315    if data.is_empty() {
316        return Err(StatsError::InvalidInput {
317            message: "fit_all_verbose: data must not be empty".to_string(),
318        });
319    }
320
321    let mut results: Vec<FitResult> = Vec::new();
322    let mut skipped: Vec<SkippedFit> = Vec::new();
323
324    macro_rules! try_fit_v {
325        ($name:literal, $fit_expr:expr) => {
326            match $fit_expr {
327                Err(e) => skipped.push(SkippedFit {
328                    name: $name,
329                    reason: format!("fit failed: {e}"),
330                }),
331                Ok(dist) => match (dist.aic(data), dist.bic(data)) {
332                    (Ok(aic), Ok(bic)) if aic.is_finite() && bic.is_finite() => {
333                        let ks = ks_test(data, |x| dist.cdf(x).unwrap_or(0.0));
334                        results.push(FitResult {
335                            name: dist.name().to_string(),
336                            aic,
337                            bic,
338                            ks_statistic: ks.statistic,
339                            ks_p_value: ks.p_value,
340                        });
341                    }
342                    _ => skipped.push(SkippedFit {
343                        name: $name,
344                        reason: "non-finite AIC/BIC (log-likelihood diverged)".to_string(),
345                    }),
346                },
347            }
348        };
349    }
350
351    try_fit_v!("Normal", Normal::fit(data));
352    try_fit_v!(
353        "Exponential",
354        crate::distributions::exponential_distribution::Exponential::fit(data)
355    );
356    try_fit_v!("Uniform", Uniform::fit(data));
357    try_fit_v!("Gamma", Gamma::fit(data));
358    try_fit_v!("LogNormal", LogNormal::fit(data));
359    try_fit_v!("Weibull", Weibull::fit(data));
360    try_fit_v!("Beta", Beta::fit(data));
361    try_fit_v!("StudentT", StudentT::fit(data));
362    try_fit_v!("FDistribution", FDistribution::fit(data));
363    try_fit_v!("ChiSquared", ChiSquared::fit(data));
364
365    if results.is_empty() {
366        return Err(StatsError::InvalidInput {
367            message: "fit_all_verbose: no distribution could be fitted to the data".to_string(),
368        });
369    }
370
371    results.sort_by(|a, b| {
372        a.aic
373            .partial_cmp(&b.aic)
374            .unwrap_or(std::cmp::Ordering::Equal)
375    });
376    Ok((results, skipped))
377}
378
379/// Like [`fit_all_discrete`] but also reports which distributions were skipped and why.
380pub fn fit_all_discrete_verbose(data: &[f64]) -> StatsResult<(Vec<FitResult>, Vec<SkippedFit>)> {
381    if data.is_empty() {
382        return Err(StatsError::InvalidInput {
383            message: "fit_all_discrete_verbose: data must not be empty".to_string(),
384        });
385    }
386
387    let int_data: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();
388    let mut results: Vec<FitResult> = Vec::new();
389    let mut skipped: Vec<SkippedFit> = Vec::new();
390
391    macro_rules! try_fit_disc_v {
392        ($name:literal, $fit_expr:expr) => {
393            match $fit_expr {
394                Err(e) => skipped.push(SkippedFit {
395                    name: $name,
396                    reason: format!("fit failed: {e}"),
397                }),
398                Ok(dist) => match (dist.aic(&int_data), dist.bic(&int_data)) {
399                    (Ok(aic), Ok(bic)) if aic.is_finite() && bic.is_finite() => {
400                        let ks = ks_test_discrete(data, |k| dist.cdf(k).unwrap_or(0.0));
401                        results.push(FitResult {
402                            name: dist.name().to_string(),
403                            aic,
404                            bic,
405                            ks_statistic: ks.statistic,
406                            ks_p_value: ks.p_value,
407                        });
408                    }
409                    _ => skipped.push(SkippedFit {
410                        name: $name,
411                        reason: "non-finite AIC/BIC (log-likelihood diverged)".to_string(),
412                    }),
413                },
414            }
415        };
416    }
417
418    try_fit_disc_v!("Poisson", Poisson::fit(data));
419    try_fit_disc_v!("Geometric", Geometric::fit(data));
420    try_fit_disc_v!("NegativeBinomial", NegativeBinomial::fit(data));
421    try_fit_disc_v!("Binomial", Binomial::fit(data));
422
423    if results.is_empty() {
424        return Err(StatsError::InvalidInput {
425            message: "fit_all_discrete_verbose: no distribution could be fitted".to_string(),
426        });
427    }
428
429    results.sort_by(|a, b| {
430        a.aic
431            .partial_cmp(&b.aic)
432            .unwrap_or(std::cmp::Ordering::Equal)
433    });
434    Ok((results, skipped))
435}
436
437// ── Discrete fitting ───────────────────────────────────────────────────────────
438
439/// Fit all discrete distributions to integer `data` (passed as f64) and return ranked results.
440///
441/// Skips distributions that cannot be fitted.
442pub fn fit_all_discrete(data: &[f64]) -> StatsResult<Vec<FitResult>> {
443    if data.is_empty() {
444        return Err(StatsError::InvalidInput {
445            message: "fit_all_discrete: data must not be empty".to_string(),
446        });
447    }
448
449    // Convert to u64 for discrete distributions
450    let int_data: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();
451
452    let mut results: Vec<FitResult> = Vec::new();
453
454    macro_rules! try_fit_disc {
455        ($fit_expr:expr) => {
456            if let Ok(dist) = $fit_expr {
457                if let (Ok(aic), Ok(bic)) = (dist.aic(&int_data), dist.bic(&int_data)) {
458                    if aic.is_finite() && bic.is_finite() {
459                        let ks = ks_test_discrete(data, |k| dist.cdf(k).unwrap_or(0.0));
460                        results.push(FitResult {
461                            name: dist.name().to_string(),
462                            aic,
463                            bic,
464                            ks_statistic: ks.statistic,
465                            ks_p_value: ks.p_value,
466                        });
467                    }
468                }
469            }
470        };
471    }
472
473    try_fit_disc!(Poisson::fit(data));
474    try_fit_disc!(Geometric::fit(data));
475    try_fit_disc!(NegativeBinomial::fit(data));
476    try_fit_disc!(Binomial::fit(data));
477
478    if results.is_empty() {
479        return Err(StatsError::InvalidInput {
480            message: "fit_all_discrete: no distribution could be fitted to the data".to_string(),
481        });
482    }
483
484    results.sort_by(|a, b| {
485        a.aic
486            .partial_cmp(&b.aic)
487            .unwrap_or(std::cmp::Ordering::Equal)
488    });
489    Ok(results)
490}
491
492/// Fit discrete distributions and return the best (lowest AIC).
493pub fn fit_best_discrete(data: &[f64]) -> StatsResult<FitResult> {
494    let mut all = fit_all_discrete(data)?;
495    Ok(all.remove(0))
496}
497
498// ── Auto-detect and fit ────────────────────────────────────────────────────────
499
500/// Automatically detect whether data is discrete or continuous, then fit all applicable
501/// distributions and return the best match (lowest AIC).
502///
503/// # Examples
504/// ```
505/// use rs_stats::distributions::fitting::auto_fit;
506///
507/// let data = vec![1.2, 2.3, 1.8, 2.9, 1.5];
508/// let best = auto_fit(&data).unwrap();
509/// println!("Best fit: {}", best.name);
510/// ```
511pub fn auto_fit(data: &[f64]) -> StatsResult<FitResult> {
512    match detect_data_type(data) {
513        DataKind::Discrete => fit_best_discrete(data),
514        DataKind::Continuous => fit_best(data),
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    // Non-negative whole numbers (as f64) classify as discrete.
    #[test]
    fn test_detect_data_type_discrete() {
        assert_eq!(detect_data_type(&[0.0, 1.0, 2.0, 3.0]), DataKind::Discrete);
        assert_eq!(detect_data_type(&[0.0, 0.0, 1.0]), DataKind::Discrete);
    }

    // Any fractional or negative value forces the continuous classification.
    #[test]
    fn test_detect_data_type_continuous() {
        assert_eq!(detect_data_type(&[0.5, 1.5, 2.3]), DataKind::Continuous);
        assert_eq!(detect_data_type(&[-1.0, 0.0, 1.0]), DataKind::Continuous);
        assert_eq!(detect_data_type(&[1.0, 2.5, 3.0]), DataKind::Continuous);
    }

    #[test]
    fn test_ks_test_uniform() {
        // Data from Uniform(0,1) should give large p-value against U(0,1) CDF
        let data: Vec<f64> = (0..20).map(|i| i as f64 / 20.0).collect();
        let ks = ks_test(&data, |x| x.clamp(0.0, 1.0));
        assert!(ks.statistic < 0.15);
    }

    // fit_all must return a non-empty ranking sorted by AIC ascending.
    #[test]
    fn test_fit_all_returns_results() {
        let data: Vec<f64> = (0..50).map(|i| (i as f64) * 0.1 + 0.5).collect();
        let results = fit_all(&data).unwrap();
        assert!(!results.is_empty());
        // Results sorted by AIC (ascending)
        for i in 1..results.len() {
            assert!(results[i].aic >= results[i - 1].aic);
        }
    }

    #[test]
    fn test_fit_best_normal_data() {
        // Data generated from N(5, 1)
        let data = vec![
            4.1, 5.2, 5.8, 4.7, 5.3, 4.9, 6.1, 4.5, 5.5, 5.0, 4.8, 5.1, 4.3, 5.7, 4.6, 5.4, 4.2,
            5.9, 5.2, 4.4,
        ];
        let best = fit_best(&data).unwrap();
        // Normal should win (or be competitive)
        assert!(best.aic.is_finite());
    }

    // Small count sample: at least one discrete candidate must fit.
    #[test]
    fn test_fit_all_discrete() {
        let data = vec![0.0, 1.0, 2.0, 3.0, 1.0, 0.0, 2.0, 1.0, 0.0, 4.0];
        let results = fit_all_discrete(&data).unwrap();
        assert!(!results.is_empty());
    }

    // Fractional data should route auto_fit to the continuous pipeline.
    #[test]
    fn test_auto_fit_continuous() {
        let data = vec![1.5, 2.3, 1.8, 2.1, 2.7, 1.9, 2.4, 2.0];
        let best = auto_fit(&data).unwrap();
        assert!(!best.name.is_empty());
    }

    // Whole-number data should route auto_fit to the discrete pipeline.
    #[test]
    fn test_auto_fit_discrete() {
        let data = vec![0.0, 1.0, 2.0, 1.0, 0.0, 3.0, 1.0, 2.0];
        let best = auto_fit(&data).unwrap();
        assert!(!best.name.is_empty());
    }

    // Empty input is rejected with an error rather than a panic.
    #[test]
    fn test_fit_all_empty_data() {
        assert!(fit_all(&[]).is_err());
    }
}