use crate::distributions::{
beta::Beta,
binomial_distribution::Binomial,
chi_squared::ChiSquared,
f_distribution::FDistribution,
gamma_distribution::Gamma,
geometric::Geometric,
lognormal::LogNormal,
negative_binomial::NegativeBinomial,
normal_distribution::Normal,
poisson_distribution::Poisson,
student_t::StudentT,
traits::{DiscreteDistribution, Distribution},
uniform_distribution::Uniform,
weibull::Weibull,
};
use crate::error::{StatsError, StatsResult};
/// Coarse classification of a sample, used to choose between the discrete
/// and the continuous fitting pipelines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataKind {
/// Every value is a finite, non-negative whole number (count-like data).
Discrete,
/// At least one value is negative, fractional, or non-finite.
Continuous,
}
/// Classifies a sample as count-like ([`DataKind::Discrete`]) or real-valued
/// ([`DataKind::Continuous`]).
///
/// A sample is discrete only when *every* value is finite, non-negative and
/// has no fractional part. A single negative, fractional, NaN or infinite
/// value makes the whole sample continuous. An empty slice has no
/// disqualifying value and is therefore reported as discrete.
pub fn detect_data_type(data: &[f64]) -> DataKind {
    // Look for the first value that cannot be a count.
    let has_non_count = data
        .iter()
        .any(|&value| !value.is_finite() || value < 0.0 || value.fract() != 0.0);

    if has_non_count {
        DataKind::Continuous
    } else {
        DataKind::Discrete
    }
}
/// Outcome of a Kolmogorov–Smirnov goodness-of-fit test as returned by
/// `ks_test` and `ks_test_discrete`.
#[derive(Debug, Clone, Copy)]
pub struct KsResult {
/// Largest distance found between the empirical CDF steps and the model CDF.
pub statistic: f64,
/// Approximate p-value obtained by passing the scaled statistic to `kolmogorov_p`.
pub p_value: f64,
}
/// One-sample Kolmogorov–Smirnov test of `data` against a continuous model.
///
/// `cdf` is the model's cumulative distribution function. The statistic is
/// the largest gap between `cdf(x)` and the empirical CDF just before and
/// just after each sorted observation. The p-value comes from `kolmogorov_p`
/// evaluated at `(sqrt(n) + 0.12 + 0.11 / sqrt(n)) * D` (the small-sample
/// correction commonly attributed to Stephens).
///
/// An empty sample returns a statistic of `0.0` and a p-value of `1.0`.
/// Values that cannot be ordered (NaN) compare as equal while sorting.
pub fn ks_test(data: &[f64], cdf: impl Fn(f64) -> f64) -> KsResult {
    if data.is_empty() {
        return KsResult {
            statistic: 0.0,
            p_value: 1.0,
        };
    }

    let sample_size = data.len() as f64;
    let mut ordered: Vec<f64> = data.to_owned();
    ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));

    // Walk the sorted sample, tracking the widest gap on either side of each
    // step of the empirical CDF.
    let statistic = ordered
        .iter()
        .enumerate()
        .fold(0.0_f64, |widest, (rank, &value)| {
            let model = cdf(value);
            let step_top = (rank + 1) as f64 / sample_size;
            let step_bottom = rank as f64 / sample_size;
            widest
                .max((step_top - model).abs())
                .max((model - step_bottom).abs())
        });

    let root_n = sample_size.sqrt();
    let p_value = kolmogorov_p((root_n + 0.12 + 0.11 / root_n) * statistic);

    KsResult { statistic, p_value }
}
/// One-sample Kolmogorov–Smirnov test of count data against a discrete model.
///
/// Each observation is rounded to the nearest integer and converted to `u64`
/// (negative and NaN inputs saturate to `0`). `cdf(k)` must return the model's
/// `P(X <= k)`.
///
/// Both the empirical CDF and a discrete model CDF are step functions that
/// jump at integers. The statistic is therefore evaluated once per *distinct*
/// observed value `k`:
///
/// * just before the jump: `|F(k - 1) - ECDF(k⁻)|`, with `F(-1) = 0`;
/// * at the jump: `|F(k) - ECDF(k)|`, where `ECDF(k)` includes every tie at `k`.
///
/// Applying the continuous per-observation formula instead compares `F(k)`
/// against the ECDF *before* the jump and against partial tie counts. That
/// inflates the distance — a perfectly fitting point mass would score
/// `D = 1`.
///
/// The p-value uses the same asymptotic Kolmogorov approximation as
/// `ks_test`. NOTE(review): for discrete models this approximation is
/// generally regarded as conservative — treat the p-value as a rough guide.
///
/// An empty sample returns a statistic of `0.0` and a p-value of `1.0`.
pub fn ks_test_discrete(data: &[f64], cdf: impl Fn(u64) -> f64) -> KsResult {
    let n = data.len();
    if n == 0 {
        return KsResult {
            statistic: 0.0,
            p_value: 1.0,
        };
    }

    // Round first, then sort, so that values rounding to the same integer
    // end up adjacent and form one tie group.
    let mut counts: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();
    counts.sort_unstable();

    let nf = n as f64;
    let mut d = 0.0_f64;
    let mut start = 0;
    while start < n {
        let k = counts[start];
        let mut end = start + 1;
        while end < n && counts[end] == k {
            end += 1;
        }

        // ECDF immediately below k and at k (after absorbing all ties).
        let ecdf_before = start as f64 / nf;
        let ecdf_at = end as f64 / nf;
        // Model CDF immediately below k; the support starts at 0.
        let cdf_before = if k == 0 { 0.0 } else { cdf(k - 1) };
        let cdf_at = cdf(k);

        d = d
            .max((cdf_before - ecdf_before).abs())
            .max((cdf_at - ecdf_at).abs());
        start = end;
    }

    let root_n = nf.sqrt();
    let p_value = kolmogorov_p((root_n + 0.12 + 0.11 / root_n) * d);
    KsResult {
        statistic: d,
        p_value,
    }
}
/// Survival function `Q(x) = P(K > x)` of the Kolmogorov distribution.
///
/// Two mathematically equivalent series are used, each where it converges
/// quickly:
///
/// * `x >= 1.18`: `Q(x) = 2 * Σ_{j>=1} (-1)^(j-1) * exp(-2 j² x²)`
/// * `0 < x < 1.18`: `Q(x) = 1 - sqrt(2π)/x * Σ_{j>=1} exp(-(2j-1)² π² / (8 x²))`
///
/// The alternating series alone is not enough. For small `x` its terms decay
/// so slowly that a 100-term truncation is far from converged (about `0.86`
/// at `x = 0.01` instead of `≈ 1`), which would report near-perfect fits as
/// poor ones.
///
/// Returns `1.0` for `x <= 0`. The result is clamped to `[0, 1]`.
fn kolmogorov_p(x: f64) -> f64 {
    // Below this point four terms of the dual series reach f64 precision;
    // above it the alternating series needs only a handful of terms.
    const SERIES_SWITCH: f64 = 1.18;

    if x <= 0.0 {
        return 1.0;
    }

    if x < SERIES_SWITCH {
        let pi = std::f64::consts::PI;
        let decay = -(pi * pi) / (8.0 * x * x);
        let series: f64 = [1.0_f64, 3.0, 5.0, 7.0]
            .iter()
            .map(|&odd| (decay * odd * odd).exp())
            .sum();
        // Every term underflowed: the CDF is indistinguishable from 0. Bail
        // out before `sqrt(2π) / x` can overflow and produce `inf * 0 = NaN`.
        if series == 0.0 {
            return 1.0;
        }
        let cdf = (2.0 * pi).sqrt() / x * series;
        return (1.0 - cdf).clamp(0.0, 1.0);
    }

    let mut sum = 0.0_f64;
    for j in 1_u32..=100 {
        let term = (-(2.0 * (j as f64).powi(2) * x * x)).exp();
        if j % 2 == 1 {
            sum += term;
        } else {
            sum -= term;
        }
        // Terms shrink monotonically; stop once they no longer matter.
        if term < 1e-15 {
            break;
        }
    }
    (2.0 * sum).clamp(0.0, 1.0)
}
/// Scores for one distribution that was successfully fitted to a sample.
#[derive(Debug, Clone)]
pub struct FitResult {
/// Name reported by the fitted distribution's `name()` method.
pub name: String,
/// Value returned by the distribution's `aic` method; result lists are sorted ascending on it.
pub aic: f64,
/// Value returned by the distribution's `bic` method.
pub bic: f64,
/// Kolmogorov–Smirnov statistic of the sample against the fitted CDF.
pub ks_statistic: f64,
/// Approximate p-value belonging to `ks_statistic`.
pub ks_p_value: f64,
}
pub fn fit_all(data: &[f64]) -> StatsResult<Vec<FitResult>> {
if data.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all: data must not be empty".to_string(),
});
}
let mut results: Vec<FitResult> = Vec::new();
macro_rules! try_fit {
($dist_type:ty, $fit_expr:expr) => {
if let Ok(dist) = $fit_expr {
if let (Ok(aic), Ok(bic)) = (dist.aic(data), dist.bic(data)) {
if aic.is_finite() && bic.is_finite() {
let ks = ks_test(data, |x| dist.cdf(x).unwrap_or(0.0));
results.push(FitResult {
name: dist.name().to_string(),
aic,
bic,
ks_statistic: ks.statistic,
ks_p_value: ks.p_value,
});
}
}
}
};
}
try_fit!(Normal, Normal::fit(data));
try_fit!(
Exponential,
crate::distributions::exponential_distribution::Exponential::fit(data)
);
try_fit!(Uniform, Uniform::fit(data));
try_fit!(Gamma, Gamma::fit(data));
try_fit!(LogNormal, LogNormal::fit(data));
try_fit!(Weibull, Weibull::fit(data));
try_fit!(Beta, Beta::fit(data));
try_fit!(StudentT, StudentT::fit(data));
try_fit!(FDistribution, FDistribution::fit(data));
try_fit!(ChiSquared, ChiSquared::fit(data));
if results.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all: no distribution could be fitted to the data".to_string(),
});
}
results.sort_by(|a, b| {
a.aic
.partial_cmp(&b.aic)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
pub fn fit_best(data: &[f64]) -> StatsResult<FitResult> {
let mut all = fit_all(data)?;
Ok(all.remove(0))
}
/// Record of a candidate distribution that the verbose fitting functions
/// left out of the ranked results.
#[derive(Debug, Clone)]
pub struct SkippedFit {
/// Label of the candidate as written at the fitting call site.
pub name: &'static str,
/// Why it was skipped: either `"fit failed: …"` carrying the fit error, or
/// the fixed message about a non-finite AIC/BIC.
pub reason: String,
}
pub fn fit_all_verbose(data: &[f64]) -> StatsResult<(Vec<FitResult>, Vec<SkippedFit>)> {
if data.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all_verbose: data must not be empty".to_string(),
});
}
let mut results: Vec<FitResult> = Vec::new();
let mut skipped: Vec<SkippedFit> = Vec::new();
macro_rules! try_fit_v {
($name:literal, $fit_expr:expr) => {
match $fit_expr {
Err(e) => skipped.push(SkippedFit {
name: $name,
reason: format!("fit failed: {e}"),
}),
Ok(dist) => match (dist.aic(data), dist.bic(data)) {
(Ok(aic), Ok(bic)) if aic.is_finite() && bic.is_finite() => {
let ks = ks_test(data, |x| dist.cdf(x).unwrap_or(0.0));
results.push(FitResult {
name: dist.name().to_string(),
aic,
bic,
ks_statistic: ks.statistic,
ks_p_value: ks.p_value,
});
}
_ => skipped.push(SkippedFit {
name: $name,
reason: "non-finite AIC/BIC (log-likelihood diverged)".to_string(),
}),
},
}
};
}
try_fit_v!("Normal", Normal::fit(data));
try_fit_v!(
"Exponential",
crate::distributions::exponential_distribution::Exponential::fit(data)
);
try_fit_v!("Uniform", Uniform::fit(data));
try_fit_v!("Gamma", Gamma::fit(data));
try_fit_v!("LogNormal", LogNormal::fit(data));
try_fit_v!("Weibull", Weibull::fit(data));
try_fit_v!("Beta", Beta::fit(data));
try_fit_v!("StudentT", StudentT::fit(data));
try_fit_v!("FDistribution", FDistribution::fit(data));
try_fit_v!("ChiSquared", ChiSquared::fit(data));
if results.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all_verbose: no distribution could be fitted to the data".to_string(),
});
}
results.sort_by(|a, b| {
a.aic
.partial_cmp(&b.aic)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok((results, skipped))
}
pub fn fit_all_discrete_verbose(data: &[f64]) -> StatsResult<(Vec<FitResult>, Vec<SkippedFit>)> {
if data.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all_discrete_verbose: data must not be empty".to_string(),
});
}
let int_data: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();
let mut results: Vec<FitResult> = Vec::new();
let mut skipped: Vec<SkippedFit> = Vec::new();
macro_rules! try_fit_disc_v {
($name:literal, $fit_expr:expr) => {
match $fit_expr {
Err(e) => skipped.push(SkippedFit {
name: $name,
reason: format!("fit failed: {e}"),
}),
Ok(dist) => match (dist.aic(&int_data), dist.bic(&int_data)) {
(Ok(aic), Ok(bic)) if aic.is_finite() && bic.is_finite() => {
let ks = ks_test_discrete(data, |k| dist.cdf(k).unwrap_or(0.0));
results.push(FitResult {
name: dist.name().to_string(),
aic,
bic,
ks_statistic: ks.statistic,
ks_p_value: ks.p_value,
});
}
_ => skipped.push(SkippedFit {
name: $name,
reason: "non-finite AIC/BIC (log-likelihood diverged)".to_string(),
}),
},
}
};
}
try_fit_disc_v!("Poisson", Poisson::fit(data));
try_fit_disc_v!("Geometric", Geometric::fit(data));
try_fit_disc_v!("NegativeBinomial", NegativeBinomial::fit(data));
try_fit_disc_v!("Binomial", Binomial::fit(data));
if results.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all_discrete_verbose: no distribution could be fitted".to_string(),
});
}
results.sort_by(|a, b| {
a.aic
.partial_cmp(&b.aic)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok((results, skipped))
}
pub fn fit_all_discrete(data: &[f64]) -> StatsResult<Vec<FitResult>> {
if data.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all_discrete: data must not be empty".to_string(),
});
}
let int_data: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();
let mut results: Vec<FitResult> = Vec::new();
macro_rules! try_fit_disc {
($fit_expr:expr) => {
if let Ok(dist) = $fit_expr {
if let (Ok(aic), Ok(bic)) = (dist.aic(&int_data), dist.bic(&int_data)) {
if aic.is_finite() && bic.is_finite() {
let ks = ks_test_discrete(data, |k| dist.cdf(k).unwrap_or(0.0));
results.push(FitResult {
name: dist.name().to_string(),
aic,
bic,
ks_statistic: ks.statistic,
ks_p_value: ks.p_value,
});
}
}
}
};
}
try_fit_disc!(Poisson::fit(data));
try_fit_disc!(Geometric::fit(data));
try_fit_disc!(NegativeBinomial::fit(data));
try_fit_disc!(Binomial::fit(data));
if results.is_empty() {
return Err(StatsError::InvalidInput {
message: "fit_all_discrete: no distribution could be fitted to the data".to_string(),
});
}
results.sort_by(|a, b| {
a.aic
.partial_cmp(&b.aic)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
pub fn fit_best_discrete(data: &[f64]) -> StatsResult<FitResult> {
let mut all = fit_all_discrete(data)?;
Ok(all.remove(0))
}
pub fn auto_fit(data: &[f64]) -> StatsResult<FitResult> {
match detect_data_type(data) {
DataKind::Discrete => fit_best_discrete(data),
DataKind::Continuous => fit_best(data),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-negative whole numbers are classified as count data.
    #[test]
    fn test_detect_data_type_discrete() {
        let count_like: [&[f64]; 2] = [&[0.0, 1.0, 2.0, 3.0], &[0.0, 0.0, 1.0]];
        for sample in count_like {
            assert_eq!(detect_data_type(sample), DataKind::Discrete);
        }
    }

    /// Any fractional or negative value makes the sample continuous.
    #[test]
    fn test_detect_data_type_continuous() {
        let real_valued: [&[f64]; 3] = [&[0.5, 1.5, 2.3], &[-1.0, 0.0, 1.0], &[1.0, 2.5, 3.0]];
        for sample in real_valued {
            assert_eq!(detect_data_type(sample), DataKind::Continuous);
        }
    }

    /// An evenly spaced grid on [0, 1) stays close to the uniform CDF.
    #[test]
    fn test_ks_test_uniform() {
        let grid: Vec<f64> = (0..20).map(|step| step as f64 / 20.0).collect();
        let outcome = ks_test(&grid, |x| x.clamp(0.0, 1.0));
        assert!(outcome.statistic < 0.15);
    }

    /// `fit_all` yields at least one fit, ordered by non-decreasing AIC.
    #[test]
    fn test_fit_all_returns_results() {
        let ramp: Vec<f64> = (0..50).map(|step| (step as f64) * 0.1 + 0.5).collect();
        let ranked = fit_all(&ramp).unwrap();
        assert!(!ranked.is_empty());
        assert!(ranked.windows(2).all(|pair| pair[1].aic >= pair[0].aic));
    }

    /// The best fit for bell-shaped data carries a finite AIC.
    #[test]
    fn test_fit_best_normal_data() {
        let sample = [
            4.1, 5.2, 5.8, 4.7, 5.3, 4.9, 6.1, 4.5, 5.5, 5.0, 4.8, 5.1, 4.3, 5.7, 4.6, 5.4, 4.2,
            5.9, 5.2, 4.4,
        ];
        let winner = fit_best(&sample).unwrap();
        assert!(winner.aic.is_finite());
    }

    /// Count data produces at least one discrete fit.
    #[test]
    fn test_fit_all_discrete() {
        let sample = [0.0, 1.0, 2.0, 3.0, 1.0, 0.0, 2.0, 1.0, 0.0, 4.0];
        let ranked = fit_all_discrete(&sample).unwrap();
        assert!(!ranked.is_empty());
    }

    /// `auto_fit` returns a named distribution for real-valued data.
    #[test]
    fn test_auto_fit_continuous() {
        let sample = [1.5, 2.3, 1.8, 2.1, 2.7, 1.9, 2.4, 2.0];
        let winner = auto_fit(&sample).unwrap();
        assert!(!winner.name.is_empty());
    }

    /// `auto_fit` returns a named distribution for count data.
    #[test]
    fn test_auto_fit_discrete() {
        let sample = [0.0, 1.0, 2.0, 1.0, 0.0, 3.0, 1.0, 2.0];
        let winner = auto_fit(&sample).unwrap();
        assert!(!winner.name.is_empty());
    }

    /// Empty input is rejected instead of being fitted.
    #[test]
    fn test_fit_all_empty_data() {
        assert!(fit_all(&[]).is_err());
    }
}