//! Nonparametric hypothesis tests: Mann-Whitney U, Kruskal-Wallis H,
//! one- and two-sample Kolmogorov-Smirnov, and Shapiro-Wilk normality.

use crate::error::{InferustError, Result};
use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};

/// Result of a Mann-Whitney U test.
#[derive(Debug, Clone)]
pub struct MannWhitneyResult {
pub u_statistic: f64,
pub p_value: f64,
pub n1: usize,
pub n2: usize,
}

impl MannWhitneyResult {
pub fn print(&self) {
println!();
println!("── Mann-Whitney U Test ─────────────────────────────────");
println!(" H₀: distributions are equal");
println!(
" n1 = {} n2 = {} U = {:.2} p = {:.6}",
self.n1, self.n2, self.u_statistic, self.p_value
);
let verdict = if self.p_value < 0.05 {
"✓ reject H₀ (p < 0.05)"
} else {
"✗ fail to reject H₀"
};
println!(" {verdict}");
println!();
}
}

/// Mann-Whitney U test (Wilcoxon rank-sum) for two independent samples,
/// using the normal approximation with tie and continuity corrections.
/// H0: both samples are drawn from the same distribution.
pub fn mann_whitney(a: &[f64], b: &[f64]) -> Result<MannWhitneyResult> {
let n1 = a.len();
let n2 = b.len();
if n1 < 1 || n2 < 1 {
return Err(InferustError::InsufficientData {
needed: 1,
got: n1.min(n2),
});
}
let mut combined: Vec<(f64, usize)> = a
.iter()
.map(|&v| (v, 0))
.chain(b.iter().map(|&v| (v, 1)))
.collect();
combined.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let total = combined.len();
let mut ranks = vec![0.0_f64; total];
let mut i = 0;
while i < total {
let mut j = i;
while j < total && (combined[j].0 - combined[i].0).abs() < f64::EPSILON {
j += 1;
}
        // Tied observations share the average of the 1-based ranks they span.
        let avg_rank = (i + 1 + j) as f64 / 2.0;
        for rank in ranks.iter_mut().take(j).skip(i) {
*rank = avg_rank;
}
i = j;
}
let r1: f64 = combined
.iter()
.zip(ranks.iter())
.filter(|((_, g), _)| *g == 0)
.map(|(_, r)| r)
.sum();
let u1 = r1 - n1 as f64 * (n1 as f64 + 1.0) / 2.0;
let u2 = n1 as f64 * n2 as f64 - u1;
let u = u1.min(u2);
let mu_u = n1 as f64 * n2 as f64 / 2.0;
let n = total as f64;
    let tie_correction = tie_correction_factor(&ranks, n);
    // sigma_U^2 = (n1 * n2 / 12) * ((n + 1) - sum(t^3 - t) / (n * (n - 1)))
    let sigma_u = ((n1 as f64 * n2 as f64 / 12.0) * (n + 1.0 - tie_correction)).sqrt();
    // Normal approximation with continuity correction: shrink the deviation
    // |u - mu_u| by 0.5 toward the null before standardizing.
    let z = ((u - mu_u).abs() - 0.5).max(0.0) / sigma_u.max(f64::EPSILON);
let normal = Normal::new(0.0, 1.0)
.map_err(|_| InferustError::InvalidInput("normal distribution error".into()))?;
let p = 2.0 * (1.0 - normal.cdf(z));
    Ok(MannWhitneyResult {
        // Report U for the first sample; the p-value is based on min(U1, U2).
        u_statistic: u1,
p_value: p.min(1.0),
n1,
n2,
})
}

/// Tie term for the Mann-Whitney variance: the sum over tied groups of
/// (t^3 - t), divided by n(n - 1), where t is the size of each group of
/// tied ranks.
fn tie_correction_factor(ranks: &[f64], n: f64) -> f64 {
let mut i = 0;
let total = ranks.len();
let mut sum_tc = 0.0_f64;
let mut sorted = ranks.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
while i < total {
let mut j = i;
while j < total && (sorted[j] - sorted[i]).abs() < f64::EPSILON {
j += 1;
}
let t = (j - i) as f64;
sum_tc += t * (t * t - 1.0);
i = j;
}
    sum_tc / (n * (n - 1.0)).max(f64::EPSILON)
}

/// Result of a Kruskal-Wallis H test.
#[derive(Debug, Clone)]
pub struct KruskalWallisResult {
pub h_statistic: f64,
pub p_value: f64,
pub df: usize,
pub group_sizes: Vec<usize>,
}

impl KruskalWallisResult {
pub fn print(&self) {
println!();
println!("── Kruskal-Wallis H Test ─────────────────────────────────");
println!(" H₀: all group distributions are equal");
println!(
" k = {} H({}) = {:.4} p = {:.6}",
self.group_sizes.len(),
self.df,
self.h_statistic,
self.p_value
);
let verdict = if self.p_value < 0.05 {
"✓ reject H₀ (p < 0.05)"
} else {
"✗ fail to reject H₀"
};
println!(" {verdict}");
println!();
}
}

/// Kruskal-Wallis H test: rank-based one-way analysis of variance across
/// k groups, with tie correction and a chi-squared approximation on
/// k - 1 degrees of freedom.
pub fn kruskal_wallis(groups: &[&[f64]]) -> Result<KruskalWallisResult> {
let k = groups.len();
if k < 2 {
return Err(InferustError::InvalidInput(
"Kruskal-Wallis requires at least 2 groups".into(),
));
}
for g in groups.iter() {
if g.is_empty() {
return Err(InferustError::InsufficientData { needed: 1, got: 0 });
}
}
let n: usize = groups.iter().map(|g| g.len()).sum();
let mut combined: Vec<(f64, usize)> = groups
.iter()
.enumerate()
.flat_map(|(gi, g)| g.iter().map(move |&v| (v, gi)))
.collect();
combined.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let total = combined.len();
let mut ranks = vec![0.0_f64; total];
let mut i = 0;
let mut tie_sum = 0.0_f64;
while i < total {
let mut j = i;
while j < total && (combined[j].0 - combined[i].0).abs() < f64::EPSILON {
j += 1;
}
        // Tied observations share the average rank; accumulate the tie term
        // sum(t^3 - t) for the correction factor C below.
        let avg_rank = (i + j + 1) as f64 / 2.0;
        let t = (j - i) as f64;
        tie_sum += t * t * t - t;
for rank in ranks.iter_mut().take(j).skip(i) {
*rank = avg_rank;
}
i = j;
}
let mut rank_sums = vec![0.0_f64; k];
let mut group_sizes = vec![0usize; k];
for (idx, &(_, gi)) in combined.iter().enumerate() {
rank_sums[gi] += ranks[idx];
group_sizes[gi] += 1;
}
let n_f = n as f64;
    let h_num: f64 = group_sizes
        .iter()
        .zip(rank_sums.iter())
        .map(|(&ni, &ri)| ri * ri / ni as f64)
        .sum();
    // H = 12 / (N(N+1)) * sum(R_i^2 / n_i) - 3(N+1), then divided by the
    // tie-correction factor C = 1 - sum(t^3 - t) / (N^3 - N).
    let h = (12.0 / (n_f * (n_f + 1.0))) * h_num - 3.0 * (n_f + 1.0);
    let c = 1.0 - tie_sum / (n_f * n_f * n_f - n_f);
    let h_corrected = if c.abs() > f64::EPSILON { h / c } else { h };
let df = k - 1;
let chi = ChiSquared::new(df as f64)
.map_err(|_| InferustError::InvalidInput("chi-squared distribution error".into()))?;
let p = 1.0 - chi.cdf(h_corrected.max(0.0));
Ok(KruskalWallisResult {
h_statistic: h_corrected,
p_value: p,
df,
group_sizes,
})
}

/// Result of a one- or two-sample Kolmogorov-Smirnov test. For the
/// two-sample test, `n` holds the rounded effective sample size.
#[derive(Debug, Clone)]
pub struct KsResult {
pub statistic: f64,
pub p_value: f64,
pub n: usize,
}

impl KsResult {
pub fn print(&self) {
println!();
println!("── Kolmogorov-Smirnov Test ─────────────────────────────");
println!(
" n = {} D = {:.4} p ≈ {:.6}",
self.n, self.statistic, self.p_value
);
let verdict = if self.p_value < 0.05 {
"✓ reject H₀ (p < 0.05)"
} else {
"✗ fail to reject H₀"
};
println!(" {verdict}");
println!();
}
}

/// One-sample Kolmogorov-Smirnov test against Normal(mean, std). Missing
/// parameters are estimated from the data; in that case the asymptotic
/// p-value is only approximate (no Lilliefors correction is applied).
pub fn ks_one_sample(data: &[f64], mean: Option<f64>, std: Option<f64>) -> Result<KsResult> {
let n = data.len();
if n < 2 {
return Err(InferustError::InsufficientData { needed: 2, got: n });
}
let mu = mean.unwrap_or_else(|| data.iter().sum::<f64>() / n as f64);
let sigma = std
.unwrap_or_else(|| {
let m = data.iter().sum::<f64>() / n as f64;
(data.iter().map(|x| (x - m).powi(2)).sum::<f64>() / (n - 1) as f64).sqrt()
})
.max(f64::EPSILON);
let mut sorted = data.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let normal = Normal::new(mu, sigma)
.map_err(|_| InferustError::InvalidInput("invalid normal parameters".into()))?;
    // D = sup |ECDF - F|; check both one-sided gaps at each order statistic.
    let mut d = 0.0_f64;
for (i, &x) in sorted.iter().enumerate() {
let ecdf_hi = (i + 1) as f64 / n as f64;
let ecdf_lo = i as f64 / n as f64;
let cdf_val = normal.cdf(x);
d = d
.max((ecdf_hi - cdf_val).abs())
.max((ecdf_lo - cdf_val).abs());
}
    let p = ks_p_value(d, n as f64);
Ok(KsResult {
statistic: d,
p_value: p,
n,
})
}

/// Two-sample Kolmogorov-Smirnov test. D is the largest absolute gap
/// between the two empirical CDFs, evaluated at every distinct value.
pub fn ks_two_sample(a: &[f64], b: &[f64]) -> Result<KsResult> {
let n1 = a.len();
let n2 = b.len();
if n1 < 1 || n2 < 1 {
return Err(InferustError::InsufficientData {
needed: 1,
got: n1.min(n2),
});
}
let mut sorted_a = a.to_vec();
let mut sorted_b = b.to_vec();
sorted_a.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
sorted_b.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
let mut all: Vec<f64> = sorted_a.iter().chain(sorted_b.iter()).copied().collect();
all.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
all.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
    // Evaluate both ECDFs at every distinct observed value; the supremum
    // of |F_a - F_b| is attained at one of these points.
    let mut d = 0.0_f64;
for &x in &all {
let fa = sorted_a.partition_point(|&v| v <= x) as f64 / n1 as f64;
let fb = sorted_b.partition_point(|&v| v <= x) as f64 / n2 as f64;
d = d.max((fa - fb).abs());
}
    // Effective sample size n1 * n2 / (n1 + n2), kept in f64 so integer
    // division does not truncate it.
    let n_eff = n1 as f64 * n2 as f64 / (n1 + n2) as f64;
    let p = ks_p_value(d, n_eff);
    Ok(KsResult {
        statistic: d,
        p_value: p,
        n: n_eff.round() as usize,
})
}

/// Asymptotic two-sided p-value for a KS statistic `d` with effective
/// sample size `n_eff`, using the Kolmogorov series with the common
/// small-sample adjustment (sqrt(n) + 0.12 + 0.11 / sqrt(n)) to lambda.
fn ks_p_value(d: f64, n_eff: f64) -> f64 {
    let sqn = n_eff.sqrt();
    let lambda = (sqn + 0.12 + 0.11 / sqn) * d;
    let mut p = 0.0_f64;
    for j in 1_i32..=100 {
        let term = (-2.0 * (j as f64).powi(2) * lambda * lambda).exp();
        if term < 1e-15 {
            // Series has converged.
            return (2.0 * p).clamp(0.0, 1.0);
        }
        p += if j % 2 == 1 { term } else { -term };
    }
    // For very small lambda the alternating series does not settle within
    // 100 terms; the limiting p-value in that regime is 1.
    1.0
}

/// Result of a Shapiro-Wilk normality test.
#[derive(Debug, Clone)]
pub struct ShapiroWilkResult {
pub w_statistic: f64,
pub p_value: f64,
pub n: usize,
}

impl ShapiroWilkResult {
pub fn print(&self) {
println!();
println!("── Shapiro-Wilk Normality Test ──────────────────────────");
println!(" H₀: sample is normally distributed");
println!(
" n = {} W = {:.4} p = {:.6}",
self.n, self.w_statistic, self.p_value
);
let verdict = if self.p_value < 0.05 {
"✓ reject H₀ (p < 0.05): non-normal"
} else {
"✗ fail to reject H₀ (consistent with normality)"
};
println!(" {verdict}");
println!();
}
}

/// Shapiro-Wilk normality test using approximate (Shapiro-Francia-style)
/// weights and a Royston-style transformation for the p-value.
/// Supported for 3 <= n <= 5000.
pub fn shapiro_wilk(data: &[f64]) -> Result<ShapiroWilkResult> {
let n = data.len();
if n < 3 {
return Err(InferustError::InsufficientData { needed: 3, got: n });
}
if n > 5000 {
return Err(InferustError::InvalidInput(
"Shapiro-Wilk Royston approximation is valid for n ≤ 5000".into(),
));
}
let mut sorted = data.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let a = shapiro_wilk_weights(n);
let m = sorted.iter().sum::<f64>() / n as f64;
let ss = sorted.iter().map(|x| (x - m).powi(2)).sum::<f64>();
    // Zero variance: every observation is identical; treat as degenerate.
    if ss < f64::EPSILON {
return Ok(ShapiroWilkResult {
w_statistic: 1.0,
p_value: 1.0,
n,
});
}
    // Weighted sum of symmetric sample gaps a_i * (x_(n-i) - x_(i+1));
    // for odd n the middle point carries zero weight and is omitted.
    let mut num = 0.0_f64;
for i in 0..a.len() {
num += a[i] * (sorted[n - 1 - i] - sorted[i]);
}
let w = (num * num) / ss;
let w = w.clamp(0.0, 1.0);
let p = royston_p_value(w, n);
Ok(ShapiroWilkResult {
w_statistic: w,
p_value: p,
n,
})
}

/// Approximate Shapiro-Wilk weights: Blom-scored expected normal order
/// statistics, normalized. These are Shapiro-Francia-style weights rather
/// than Royston's polynomial-corrected coefficients, so W is approximate.
fn shapiro_wilk_weights(n: usize) -> Vec<f64> {
let half = n / 2;
let mut a = Vec::with_capacity(half);
let normal = Normal::new(0.0, 1.0).unwrap();
let m: Vec<f64> = (1..=n)
.map(|i| normal.inverse_cdf((i as f64 - 0.375) / (n as f64 + 0.25)))
.collect();
let c_sq: f64 = m.iter().map(|v| v * v).sum();
let c = c_sq.sqrt();
for i in 0..half {
a.push(m[n - 1 - i] / c);
}
    // Rescale the half-vector so its squared sum is 0.5 (the full,
    // antisymmetric weight vector then has unit norm). Given the symmetry
    // of `m` this is close to a no-op, kept to guard against rounding.
    let a_sq: f64 = a.iter().map(|v| v * v).sum();
if a_sq > f64::EPSILON {
let scale = (0.5_f64).sqrt() / a_sq.sqrt();
for ai in a.iter_mut() {
*ai *= scale;
}
}
a
}

/// Royston-style p-value for the Shapiro-Wilk W. For n >= 12 this uses the
/// Royston (1995) coefficients for ln(1 - W); the small-sample branch is a
/// rougher log-n fit rather than Royston's exact small-n transform.
fn royston_p_value(w: f64, n: usize) -> f64 {
    let n_f = n as f64;
    let y = (1.0 - w).max(f64::EPSILON).ln();
    let ln_n = n_f.ln();
    let (mu, sigma) = if n <= 11 {
        let mu = -1.26233 + 1.19529 * ln_n - 0.57767 * ln_n.powi(2) + 0.10694 * ln_n.powi(3);
        let sig = (0.60637 - 0.31474 * ln_n + 0.06285 * ln_n.powi(2)).exp();
        (mu, sig)
    } else {
        let mu = 0.0038915 * ln_n.powi(3) - 0.083751 * ln_n.powi(2) - 0.31082 * ln_n - 1.5861;
        let sig = (0.0030302 * ln_n.powi(2) - 0.082676 * ln_n - 0.4803).exp();
        (mu, sig)
    };
    let z = (y - mu) / sigma.max(f64::EPSILON);
    let normal = Normal::new(0.0, 1.0).unwrap();
    // Large z (large ln(1 - W), i.e. small W) is evidence against
    // normality, so the p-value is the upper tail.
    (1.0 - normal.cdf(z)).clamp(0.0, 1.0)
}

#[cfg(test)]
mod tests {
use super::{kruskal_wallis, ks_one_sample, ks_two_sample, mann_whitney, shapiro_wilk};
fn assert_close(a: f64, b: f64, tol: f64) {
assert!((a - b).abs() <= tol, "expected ≈{b:.6} got {a:.6}");
}
#[test]
fn mann_whitney_identical_groups_high_p() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let res = mann_whitney(&a, &b).unwrap();
assert!(
res.p_value > 0.5,
"identical groups should have high p, got {}",
res.p_value
);
}
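    // Heavy ties exercise the average-rank path; U must stay within
    // [0, n1 * n2] and the p-value must remain a valid probability.
    #[test]
    fn mann_whitney_handles_ties() {
        let a = vec![1.0, 1.0, 2.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 2.0, 3.0, 3.0];
        let res = mann_whitney(&a, &b).unwrap();
        assert!(res.u_statistic >= 0.0 && res.u_statistic <= 25.0);
        assert!(res.p_value > 0.0 && res.p_value <= 1.0);
    }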
#[test]
fn mann_whitney_very_different_groups() {
let a: Vec<f64> = (1..=20).map(|i| i as f64).collect();
let b: Vec<f64> = (100..=120).map(|i| i as f64).collect();
let res = mann_whitney(&a, &b).unwrap();
assert!(
res.p_value < 0.001,
"very different groups p = {}",
res.p_value
);
}
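    // When every observation is tied, U equals its null mean and the
    // continuity-corrected z is clamped at zero, so p should be 1.
    #[test]
    fn mann_whitney_all_tied_values() {
        let a = vec![5.0; 4];
        let b = vec![5.0; 6];
        let res = mann_whitney(&a, &b).unwrap();
        assert!(res.p_value > 0.99, "all-tied data p = {}", res.p_value);
    }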
#[test]
fn kruskal_wallis_identical_groups() {
let g1 = [1.0, 2.0, 3.0];
let g2 = [1.0, 2.0, 3.0];
let g3 = [1.0, 2.0, 3.0];
let res = kruskal_wallis(&[&g1, &g2, &g3]).unwrap();
assert!(res.p_value > 0.5);
}
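    // Identical tied values across all groups give H = 0 (the
    // tie-correction denominator degenerates to C = 0), hence p near 1.
    #[test]
    fn kruskal_wallis_all_tied_values() {
        let g1 = [7.0, 7.0, 7.0];
        let g2 = [7.0, 7.0];
        let g3 = [7.0, 7.0, 7.0, 7.0];
        let res = kruskal_wallis(&[&g1, &g2, &g3]).unwrap();
        assert!(res.p_value > 0.99, "all-tied data p = {}", res.p_value);
    }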
#[test]
fn kruskal_wallis_distinct_groups() {
let g1 = [1.0, 2.0, 3.0];
let g2 = [10.0, 20.0, 30.0];
let g3 = [100.0, 200.0, 300.0];
let res = kruskal_wallis(&[&g1, &g2, &g3]).unwrap();
assert!(
res.p_value < 0.05,
"clearly distinct groups, p = {}",
res.p_value
);
assert_eq!(res.df, 2);
}
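    // Fewer than two groups, or any empty group, is rejected as an error.
    #[test]
    fn kruskal_wallis_rejects_bad_input() {
        let g1 = [1.0, 2.0, 3.0];
        let empty: &[f64] = &[];
        assert!(kruskal_wallis(&[&g1[..]]).is_err());
        assert!(kruskal_wallis(&[&g1[..], empty]).is_err());
    }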
#[test]
fn ks_one_sample_standard_normal() {
let data = vec![-1.5, -0.5, 0.0, 0.5, 1.5];
let res = ks_one_sample(&data, Some(0.0), Some(1.0)).unwrap();
assert!(res.statistic < 0.5);
}
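    // With `mean`/`std` left as None the parameters are estimated from the
    // sample itself, so D should be small and p a valid probability.
    #[test]
    fn ks_one_sample_estimated_parameters() {
        let data = vec![-1.2, -0.6, -0.1, 0.3, 0.7, 1.1, 1.6];
        let res = ks_one_sample(&data, None, None).unwrap();
        assert!(res.statistic < 0.5, "D = {}", res.statistic);
        assert!(res.p_value > 0.0 && res.p_value <= 1.0);
    }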
#[test]
fn ks_two_sample_same_distribution() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![1.5, 2.5, 3.5, 4.5, 5.5];
let res = ks_two_sample(&a, &b).unwrap();
assert!(
res.p_value > 0.05,
"similar distributions, p = {}",
res.p_value
);
}
#[test]
fn ks_two_sample_different_distributions() {
let a: Vec<f64> = (1..=30).map(|i| i as f64).collect();
let b: Vec<f64> = (100..=130).map(|i| i as f64).collect();
let res = ks_two_sample(&a, &b).unwrap();
assert!(
res.p_value < 0.001,
"clearly different distributions, p = {}",
res.p_value
);
assert_close(res.statistic, 1.0, 0.01);
}
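    // Identical samples give D = 0; the non-convergence fallback in
    // ks_p_value should then report p = 1 rather than a degenerate value.
    #[test]
    fn ks_two_sample_identical_samples() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let res = ks_two_sample(&a, &a).unwrap();
        assert_close(res.statistic, 0.0, 1e-12);
        assert!(res.p_value > 0.99, "identical samples p = {}", res.p_value);
    }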
#[test]
fn shapiro_wilk_normal_data() {
let data = vec![-2.1, -1.3, -0.7, -0.2, 0.1, 0.4, 0.8, 1.2, 1.9, 2.4];
let res = shapiro_wilk(&data).unwrap();
assert!(res.w_statistic > 0.8, "W = {:.4}", res.w_statistic);
assert!(
res.p_value > 0.05,
"near-normal data p = {:.4}",
res.p_value
);
}
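    // Zero-variance input takes the degenerate early return (W = 1, p = 1)
    // rather than dividing by a zero sum of squares.
    #[test]
    fn shapiro_wilk_constant_data() {
        let data = vec![4.2; 10];
        let res = shapiro_wilk(&data).unwrap();
        assert_close(res.w_statistic, 1.0, 1e-12);
        assert_close(res.p_value, 1.0, 1e-12);
    }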
#[test]
    fn shapiro_wilk_uniform_data_stays_bounded() {
        let data: Vec<f64> = (0..30).map(|i| i as f64).collect();
        let res = shapiro_wilk(&data).unwrap();
        assert!(res.w_statistic >= 0.0 && res.w_statistic <= 1.0);
        assert!(res.p_value >= 0.0 && res.p_value <= 1.0);
}
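    // Sample-size guards reject inputs the approximations cannot handle
    // (n < 3 for Shapiro-Wilk, n < 2 for one-sample KS, empty samples
    // for Mann-Whitney).
    #[test]
    fn insufficient_data_is_rejected() {
        assert!(shapiro_wilk(&[1.0, 2.0]).is_err());
        assert!(mann_whitney(&[], &[1.0]).is_err());
        assert!(ks_one_sample(&[1.0], None, None).is_err());
    }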
}