use crate::error::{Error, Result};
use crate::stats::{AnovaResult, ChiSquareResult, MannWhitneyResult, TTestResult};
use std::collections::HashMap;
use std::f64::consts::PI;
fn normal_cdf(z: f64) -> f64 {
const A1: f64 = 0.254829592;
const A2: f64 = -0.284496736;
const A3: f64 = 1.421413741;
const A4: f64 = -1.453152027;
const A5: f64 = 1.061405429;
const P: f64 = 0.3275911;
let sign = if z < 0.0 { -1.0 } else { 1.0 };
let x = z.abs() / (2.0_f64).sqrt();
let t = 1.0 / (1.0 + P * x);
let y = 1.0 - (((((A5 * t + A4) * t) + A3) * t + A2) * t + A1) * t * (-x * x).exp();
0.5 * (1.0 + sign * y)
}
fn t_distribution_cdf(t: f64, df: usize) -> f64 {
if df > 30 {
return normal_cdf(t);
}
let df_f64 = df as f64;
let x = df_f64 / (df_f64 + t * t);
let a = 0.5 * df_f64;
let b = 0.5;
let beta_approx = if t > 0.0 {
1.0 - 0.5 * x.powf(a)
} else {
0.5 * x.powf(a)
};
beta_approx
}
pub(crate) fn ttest_impl(
sample1: &[f64],
sample2: &[f64],
alpha: f64,
equal_var: bool,
) -> Result<TTestResult> {
if sample1.is_empty() || sample2.is_empty() {
return Err(Error::EmptyData("t-test requires data".into()));
}
let n1 = sample1.len();
let n2 = sample2.len();
if n1 < 2 || n2 < 2 {
return Err(Error::InsufficientData(
"t-test requires at least 2 data points in each group".into(),
));
}
let mean1 = sample1.iter().sum::<f64>() / n1 as f64;
let mean2 = sample2.iter().sum::<f64>() / n2 as f64;
let var1 = sample1.iter().map(|&x| (x - mean1).powi(2)).sum::<f64>() / (n1 - 1) as f64;
let var2 = sample2.iter().map(|&x| (x - mean2).powi(2)).sum::<f64>() / (n2 - 1) as f64;
let (t_stat, df) = if equal_var {
let pooled_var = ((n1 - 1) as f64 * var1 + (n2 - 1) as f64 * var2) / (n1 + n2 - 2) as f64;
let std_err = (pooled_var * (1.0 / n1 as f64 + 1.0 / n2 as f64)).sqrt();
let t_value = (mean1 - mean2) / std_err;
(t_value, n1 + n2 - 2)
} else {
let std_err = (var1 / n1 as f64 + var2 / n2 as f64).sqrt();
let t_value = (mean1 - mean2) / std_err;
let df_num = (var1 / n1 as f64 + var2 / n2 as f64).powi(2);
let df_denom = (var1 / n1 as f64).powi(2) / (n1 - 1) as f64
+ (var2 / n2 as f64).powi(2) / (n2 - 1) as f64;
let df_welch = df_num / df_denom;
(t_value, df_welch.floor() as usize)
};
let p_value = 2.0 * (1.0 - t_distribution_cdf(t_stat.abs(), df));
Ok(TTestResult {
statistic: t_stat,
pvalue: p_value,
significant: p_value < alpha,
df,
})
}
fn chi2_to_pvalue(chi2: f64, df: usize) -> f64 {
let k = df as f64 / 2.0;
let x = chi2 / 2.0;
let gamma_k = if df % 2 == 0 {
1.0 } else {
(PI * 2.0).sqrt() };
let p = if chi2 > df as f64 + 2.0 {
1.0 - gamma_k * (1.0 - x.exp() * (1.0 + x + 0.5 * x.powi(2)))
} else {
gamma_k * x.exp() * x.powf(k - 1.0)
};
1.0 - p.min(1.0).max(0.0)
}
fn f_distribution_cdf(f: f64, df1: usize, df2: usize) -> f64 {
let df1_f64 = df1 as f64;
let df2_f64 = df2 as f64;
let x = df1_f64 * f / (df1_f64 * f + df2_f64);
let a = df1_f64 / 2.0;
let b = df2_f64 / 2.0;
let beta_approx = if x > 0.5 {
1.0 - (1.0 - x).powf(b)
* (1.0 + (1.0 - x) * a / b + (1.0 - x).powi(2) * a * (a + 1.0) / (b * (b + 1.0)) / 2.0)
} else {
x.powf(a) * (1.0 + x * b / a + x.powi(2) * b * (b + 1.0) / (a * (a + 1.0)) / 2.0)
};
beta_approx.min(1.0).max(0.0)
}
pub(crate) fn anova_impl(groups: &HashMap<&str, &[f64]>, alpha: f64) -> Result<AnovaResult> {
if groups.is_empty() {
return Err(Error::EmptyData("ANOVA requires at least one group".into()));
}
if groups.len() < 2 {
return Err(Error::InsufficientData(
"ANOVA requires at least two groups".into(),
));
}
let mut total_n = 0;
let mut global_sum = 0.0;
for (_, values) in groups.iter() {
if values.is_empty() {
return Err(Error::EmptyData("There is an empty group".into()));
}
total_n += values.len();
global_sum += values.iter().sum::<f64>();
}
let global_mean = global_sum / total_n as f64;
let mut ss_between = 0.0;
let mut ss_within = 0.0;
let mut ss_total = 0.0;
for (_, values) in groups.iter() {
let group_n = values.len();
let group_mean = values.iter().sum::<f64>() / group_n as f64;
ss_between += group_n as f64 * (group_mean - global_mean).powi(2);
for &value in *values {
ss_within += (value - group_mean).powi(2);
ss_total += (value - global_mean).powi(2);
}
}
let df_between = groups.len() - 1;
let df_within = total_n - groups.len();
let df_total = total_n - 1;
let ms_between = ss_between / df_between as f64;
let ms_within = ss_within / df_within as f64;
let f_statistic = ms_between / ms_within;
let p_value = 1.0 - f_distribution_cdf(f_statistic, df_between, df_within);
Ok(AnovaResult {
f_statistic,
p_value,
ss_between,
ss_within,
ss_total,
df_between,
df_within,
df_total,
ms_between,
ms_within,
significant: p_value < alpha,
})
}
pub(crate) fn mann_whitney_u_impl(
sample1: &[f64],
sample2: &[f64],
alpha: f64,
) -> Result<MannWhitneyResult> {
if sample1.is_empty() || sample2.is_empty() {
return Err(Error::EmptyData("Mann-Whitney U test requires data".into()));
}
let n1 = sample1.len();
let n2 = sample2.len();
let mut combined: Vec<(f64, usize, usize)> = Vec::with_capacity(n1 + n2);
for (i, &val) in sample1.iter().enumerate() {
combined.push((val, 0, i)); }
for (i, &val) in sample2.iter().enumerate() {
combined.push((val, 1, i)); }
combined.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
let mut ranks = vec![0.0; n1 + n2];
let mut i = 0;
while i < n1 + n2 {
let mut j = i;
while j < n1 + n2 - 1 && (combined[j].0 - combined[j + 1].0).abs() < f64::EPSILON {
j += 1;
}
if j > i {
let rank_avg = (i + 1 + j + 1) as f64 / 2.0;
for k in i..=j {
let (_, group, idx) = combined[k];
if group == 0 {
ranks[idx] = rank_avg;
} else {
ranks[idx + n1] = rank_avg;
}
}
} else {
let (_, group, idx) = combined[i];
if group == 0 {
ranks[idx] = (i + 1) as f64;
} else {
ranks[idx + n1] = (i + 1) as f64;
}
}
i = j + 1;
}
let r1: f64 = ranks.iter().take(n1).sum();
let u1 = r1 - (n1 * (n1 + 1)) as f64 / 2.0;
let u2 = (n1 * n2) as f64 - u1;
let u_statistic = u1.min(u2);
let mean_u = (n1 * n2) as f64 / 2.0;
let std_u = ((n1 * n2 * (n1 + n2 + 1)) as f64 / 12.0).sqrt();
let z = (u_statistic - mean_u) / std_u;
let p_value = 2.0 * normal_cdf(-z.abs());
Ok(MannWhitneyResult {
u_statistic,
p_value,
significant: p_value < alpha,
})
}
pub(crate) fn chi_square_test_impl(observed: &[Vec<f64>], alpha: f64) -> Result<ChiSquareResult> {
if observed.is_empty() {
return Err(Error::EmptyData(
"Chi-square test requires observed data".into(),
));
}
let rows = observed.len();
if rows < 2 {
return Err(Error::InsufficientData(
"Chi-square test requires at least 2 rows of data".into(),
));
}
let cols = observed[0].len();
if cols < 2 {
return Err(Error::InsufficientData(
"Chi-square test requires at least 2 columns of data".into(),
));
}
for row in observed.iter() {
if row.len() != cols {
return Err(Error::InvalidInput(
"All rows must have the same number of columns".into(),
));
}
}
let mut row_sums = vec![0.0; rows];
let mut col_sums = vec![0.0; cols];
let mut total_sum = 0.0;
for i in 0..rows {
for j in 0..cols {
let value = observed[i][j];
if value < 0.0 {
return Err(Error::InvalidInput(
"Observed values must not be negative".into(),
));
}
row_sums[i] += value;
col_sums[j] += value;
total_sum += value;
}
}
if total_sum < 1.0 {
return Err(Error::InvalidInput("Sum of observed data is zero".into()));
}
let mut expected = vec![vec![0.0; cols]; rows];
let mut chi2_statistic = 0.0;
for i in 0..rows {
for j in 0..cols {
expected[i][j] = row_sums[i] * col_sums[j] / total_sum;
if expected[i][j] < 5.0 {
}
let diff = observed[i][j] - expected[i][j];
chi2_statistic += diff * diff / expected[i][j];
}
}
let df = (rows - 1) * (cols - 1);
let p_value = chi2_to_pvalue(chi2_statistic, df);
Ok(ChiSquareResult {
chi2_statistic,
p_value,
df,
significant: p_value < alpha,
expected_freq: expected,
})
}
pub fn chi_square_test_with_statistic(
chi2: f64,
df: usize,
alpha: f64,
expected: Vec<Vec<f64>>,
) -> Result<ChiSquareResult> {
let p_value = chi2_to_pvalue(chi2, df);
Ok(ChiSquareResult {
chi2_statistic: chi2,
p_value,
df,
significant: p_value < alpha,
expected_freq: expected,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ttest_equal_means() {
let sample1 = vec![5.0, 6.0, 7.0, 8.0, 9.0];
let sample2 = vec![6.0, 7.0, 8.0, 9.0, 10.0];
let result = ttest_impl(&sample1, &sample2, 0.05, true).expect("operation should succeed");
assert!((result.statistic + 1.0).abs() < 1.0); assert!(result.pvalue > 0.05); assert!(!result.significant);
}
#[test]
fn test_ttest_different_means() {
let sample1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let sample2 = vec![11.0, 12.0, 13.0, 14.0, 15.0];
let result = ttest_impl(&sample1, &sample2, 0.05, true).expect("operation should succeed");
assert!(result.statistic < -5.0); assert!(result.pvalue < 0.05); assert!(result.significant);
}
#[test]
fn test_ttest_welch() {
let sample1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let sample2 = vec![11.0, 13.0, 15.0, 17.0, 19.0];
let result_equal_var =
ttest_impl(&sample1, &sample2, 0.05, true).expect("operation should succeed");
let result_welch =
ttest_impl(&sample1, &sample2, 0.05, false).expect("operation should succeed");
assert!(result_equal_var.significant);
assert!(result_welch.significant);
assert!(result_equal_var.df != result_welch.df);
}
#[test]
fn test_ttest_empty() {
let sample1 = vec![1.0, 2.0, 3.0];
let sample2: Vec<f64> = vec![];
let result = ttest_impl(&sample1, &sample2, 0.05, true);
assert!(result.is_err());
}
#[test]
fn test_anova_basic() {
let mut groups = HashMap::new();
let a_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b_values = vec![2.0, 3.0, 4.0, 5.0, 6.0];
let c_values = vec![3.0, 4.0, 5.0, 6.0, 7.0];
groups.insert("A", a_values.as_slice());
groups.insert("B", b_values.as_slice());
groups.insert("C", c_values.as_slice());
let result = anova_impl(&groups, 0.05).expect("operation should succeed");
assert!(result.f_statistic > 0.0);
assert_eq!(result.df_between, 2);
assert_eq!(result.df_within, 12);
assert_eq!(result.df_total, 14);
}
#[test]
fn test_anova_significant_difference() {
let mut groups = HashMap::new();
let a_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b_values = vec![11.0, 12.0, 13.0, 14.0, 15.0];
let c_values = vec![21.0, 22.0, 23.0, 24.0, 25.0];
groups.insert("A", a_values.as_slice());
groups.insert("B", b_values.as_slice());
groups.insert("C", c_values.as_slice());
let result = anova_impl(&groups, 0.05).expect("operation should succeed");
assert!(result.f_statistic > 100.0);
assert!(result.p_value < 0.05);
assert!(result.significant);
}
#[test]
fn test_mann_whitney_u() {
let sample1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let sample2 = vec![6.0, 7.0, 8.0, 9.0, 10.0];
let result =
mann_whitney_u_impl(&sample1, &sample2, 0.05).expect("operation should succeed");
assert!(result.u_statistic == 0.0); assert!(result.p_value < 0.05);
assert!(result.significant);
}
#[test]
fn test_chi_square() {
let observed = vec![vec![10.0, 10.0], vec![10.0, 20.0]];
let result = chi_square_test_impl(&observed, 0.05).expect("operation should succeed");
assert!(result.chi2_statistic > 0.0);
assert_eq!(result.df, 1);
assert_eq!(result.expected_freq.len(), 2);
assert_eq!(result.expected_freq[0].len(), 2);
}
}