use crate::error::{AprenderError, Result};
use std::f32::consts::PI;
#[derive(Debug, Clone)]
pub struct TTestResult {
pub statistic: f32,
pub pvalue: f32,
pub df: f32,
}
#[derive(Debug, Clone)]
pub struct ChiSquareResult {
pub statistic: f32,
pub pvalue: f32,
pub df: usize,
}
#[derive(Debug, Clone)]
pub struct AnovaResult {
pub statistic: f32,
pub pvalue: f32,
pub df_between: usize,
pub df_within: usize,
}
pub fn ttest_1samp(sample: &[f32], population_mean: f32) -> Result<TTestResult> {
let n = sample.len();
if n < 2 {
return Err(AprenderError::Other(
"t-test requires at least 2 samples".into(),
));
}
let sample_mean = sample.iter().sum::<f32>() / n as f32;
let variance = sample
.iter()
.map(|&x| (x - sample_mean).powi(2))
.sum::<f32>()
/ (n - 1) as f32;
let std = variance.sqrt();
let se = std / (n as f32).sqrt();
let t_stat = (sample_mean - population_mean) / se;
let df = (n - 1) as f32;
let pvalue = t_distribution_pvalue(t_stat.abs(), df);
Ok(TTestResult {
statistic: t_stat,
pvalue,
df,
})
}
pub fn ttest_ind(sample1: &[f32], sample2: &[f32], equal_var: bool) -> Result<TTestResult> {
let n1 = sample1.len();
let n2 = sample2.len();
if n1 < 2 || n2 < 2 {
return Err(AprenderError::Other(
"Each sample must have at least 2 observations".into(),
));
}
let mean1 = sample1.iter().sum::<f32>() / n1 as f32;
let mean2 = sample2.iter().sum::<f32>() / n2 as f32;
let var1 = sample1.iter().map(|&x| (x - mean1).powi(2)).sum::<f32>() / (n1 - 1) as f32;
let var2 = sample2.iter().map(|&x| (x - mean2).powi(2)).sum::<f32>() / (n2 - 1) as f32;
let (t_stat, df) = if equal_var {
let pooled_var = ((n1 - 1) as f32 * var1 + (n2 - 1) as f32 * var2) / (n1 + n2 - 2) as f32;
let se = (pooled_var * (1.0 / n1 as f32 + 1.0 / n2 as f32)).sqrt();
let t = (mean1 - mean2) / se;
let df = (n1 + n2 - 2) as f32;
(t, df)
} else {
let se = (var1 / n1 as f32 + var2 / n2 as f32).sqrt();
let t = (mean1 - mean2) / se;
let numerator = (var1 / n1 as f32 + var2 / n2 as f32).powi(2);
let denominator = (var1 / n1 as f32).powi(2) / (n1 - 1) as f32
+ (var2 / n2 as f32).powi(2) / (n2 - 1) as f32;
let df = numerator / denominator;
(t, df)
};
let pvalue = t_distribution_pvalue(t_stat.abs(), df);
Ok(TTestResult {
statistic: t_stat,
pvalue,
df,
})
}
pub fn ttest_rel(sample1: &[f32], sample2: &[f32]) -> Result<TTestResult> {
if sample1.len() != sample2.len() {
return Err(AprenderError::DimensionMismatch {
expected: format!("{} samples in sample1", sample1.len()),
actual: format!("{} samples in sample2", sample2.len()),
});
}
let diffs: Vec<f32> = sample1
.iter()
.zip(sample2.iter())
.map(|(&x1, &x2)| x1 - x2)
.collect();
ttest_1samp(&diffs, 0.0)
}
pub fn chisquare(observed: &[f32], expected: &[f32]) -> Result<ChiSquareResult> {
if observed.len() != expected.len() {
return Err(AprenderError::DimensionMismatch {
expected: format!("{} categories in expected", expected.len()),
actual: format!("{} categories in observed", observed.len()),
});
}
let k = observed.len();
if k < 2 {
return Err(AprenderError::Other(
"Chi-square test requires at least 2 categories".into(),
));
}
for &exp in expected {
if exp <= 0.0 {
return Err(AprenderError::Other(
"Expected frequencies must be positive".into(),
));
}
}
let chi2_stat = observed
.iter()
.zip(expected.iter())
.map(|(&obs, &exp)| (obs - exp).powi(2) / exp)
.sum::<f32>();
let df = k - 1;
let pvalue = chi_square_pvalue(chi2_stat, df);
Ok(ChiSquareResult {
statistic: chi2_stat,
pvalue,
df,
})
}
pub fn f_oneway(groups: &[Vec<f32>]) -> Result<AnovaResult> {
let k = groups.len();
if k < 2 {
return Err(AprenderError::Other(
"ANOVA requires at least 2 groups".into(),
));
}
for (i, group) in groups.iter().enumerate() {
if group.is_empty() {
return Err(AprenderError::Other(format!(
"Group {i} is empty. All groups must have at least 1 observation"
)));
}
}
let group_means: Vec<f32> = groups
.iter()
.map(|g| g.iter().sum::<f32>() / g.len() as f32)
.collect();
let n_total: usize = groups.iter().map(Vec::len).sum();
let grand_mean = groups.iter().flat_map(|g| g.iter()).sum::<f32>() / n_total as f32;
let ss_between = groups
.iter()
.zip(group_means.iter())
.map(|(group, &mean)| group.len() as f32 * (mean - grand_mean).powi(2))
.sum::<f32>();
let ss_within = groups
.iter()
.zip(group_means.iter())
.map(|(group, &mean)| group.iter().map(|&val| (val - mean).powi(2)).sum::<f32>())
.sum::<f32>();
let df_between = k - 1;
let df_within = n_total - k;
if df_within == 0 {
return Err(AprenderError::Other(
"Not enough observations for within-group variance".into(),
));
}
let ms_between = ss_between / df_between as f32;
let ms_within = ss_within / df_within as f32;
let f_stat = ms_between / ms_within;
let pvalue = f_distribution_pvalue(f_stat, df_between, df_within);
Ok(AnovaResult {
statistic: f_stat,
pvalue,
df_between,
df_within,
})
}
fn t_distribution_pvalue(t: f32, df: f32) -> f32 {
if df > 30.0 {
return 2.0 * normal_cdf(-t.abs());
}
let x = df / (df + t * t);
let p_one_tail = 0.5 * incomplete_beta(df / 2.0, 0.5, x);
2.0 * p_one_tail.clamp(0.0, 1.0)
}
fn chi_square_pvalue(chi2: f32, df: usize) -> f32 {
let k = df as f32 / 2.0;
1.0 - incomplete_gamma(k, chi2 / 2.0)
}
fn f_distribution_pvalue(f: f32, df1: usize, df2: usize) -> f32 {
let x = df2 as f32 / (df2 as f32 + df1 as f32 * f);
incomplete_beta(df2 as f32 / 2.0, df1 as f32 / 2.0, x).clamp(0.0, 1.0)
}
fn normal_cdf(x: f32) -> f32 {
0.5 * (1.0 + erf(x / 2.0_f32.sqrt()))
}
fn erf(x: f32) -> f32 {
batuta_common::math::erf_f32(x)
}
fn incomplete_gamma(a: f32, x: f32) -> f32 {
if x <= 0.0 {
return 0.0;
}
if a <= 0.0 {
return 1.0;
}
let mut sum = 1.0 / a;
let mut term = 1.0 / a;
for n in 1..100 {
term *= x / (a + n as f32);
sum += term;
if term.abs() < 1e-7 {
break;
}
}
((-x).exp() * x.powf(a) * sum / gamma(a)).clamp(0.0, 1.0)
}
fn incomplete_beta(a: f32, b: f32, x: f32) -> f32 {
if x <= 0.0 {
return 0.0;
}
if x >= 1.0 {
return 1.0;
}
let bt = (x.powf(a) * (1.0 - x).powf(b)) / (a * beta_function(a, b));
if x < (a + 1.0) / (a + b + 2.0) {
bt * beta_continued_fraction(a, b, x) / a
} else {
1.0 - bt * beta_continued_fraction(b, a, 1.0 - x) / b
}
}
fn beta_function(a: f32, b: f32) -> f32 {
gamma(a) * gamma(b) / gamma(a + b)
}
include!("beta_continued_fraction.rs");
include!("hypothesis_tests.rs");
#[cfg(test)]
#[path = "tests_hypothesis_contract.rs"]
mod tests_hypothesis_contract;