use crate::error::{StatsError, StatsResult};
use crate::utils::special_functions::regularized_incomplete_beta as canonical_inc_beta;
use num_traits::ToPrimitive;
use rayon::prelude::*;
use std::fmt::Debug;
/// Summary statistics produced by [`one_way_anova`].
///
/// Sums of squares, mean squares and degrees of freedom follow the classic
/// one-way ANOVA decomposition of total variation into a between-group part
/// and a within-group part.
#[derive(Debug, Clone, PartialEq)]
pub struct AnovaResult {
    /// F statistic: `ms_between / ms_within`.
    pub f_statistic: f64,
    /// Between-group degrees of freedom: number of groups minus one.
    pub df_between: usize,
    /// Within-group degrees of freedom: total observations minus number of groups.
    pub df_within: usize,
    /// Upper-tail probability of the F distribution with
    /// (`df_between`, `df_within`) degrees of freedom at `f_statistic`.
    pub p_value: f64,
    /// Between-group sum of squares: `Σ n_i · (mean_i − grand_mean)²`.
    pub ss_between: f64,
    /// Within-group sum of squares: squared deviations of every observation
    /// from its own group mean, summed over all groups.
    pub ss_within: f64,
    /// Between-group mean square: `ss_between / df_between`.
    pub ms_between: f64,
    /// Within-group mean square: `ss_within / df_within`.
    pub ms_within: f64,
}
/// Performs a one-way analysis of variance (ANOVA) across two or more groups.
///
/// For every group the count, mean and sum of squared deviations from the
/// mean (`M2`) are accumulated with Welford's single-pass update; groups are
/// processed in parallel with rayon. The per-group summaries are combined as:
///
/// * `ss_between = Σ n_i · (mean_i − grand_mean)²`
/// * `ss_within  = Σ M2_i`
/// * `F = (ss_between / df_between) / (ss_within / df_within)`
///
/// # Arguments
///
/// * `groups_data` - one slice of observations per group. Every value must be
///   convertible to `f64` via `ToPrimitive::to_f64`.
///
/// # Returns
///
/// An [`AnovaResult`] whose `p_value` is the upper-tail probability
/// `P(F > f_statistic)` under the F(`df_between`, `df_within`) distribution.
///
/// Degenerate data is reported through the statistics rather than an error:
/// * every observation identical → `f_statistic` is `NaN` (0/0) and so is `p_value`;
/// * every group constant but group means differ → `f_statistic` is `+inf`
///   and `p_value` is `0.0`.
///
/// # Errors
///
/// * an error built by `StatsError::invalid_input` if fewer than 2 groups are
///   supplied or any group has fewer than 2 observations;
/// * an error built by `StatsError::conversion_error` if a value cannot be
///   converted to `f64`.
///
/// If several groups are invalid, the error for the lowest-indexed group is returned.
pub fn one_way_anova<T>(groups_data: &[&[T]]) -> StatsResult<AnovaResult>
where
    T: ToPrimitive + Copy + Debug + Send + Sync,
{
    if groups_data.len() < 2 {
        return Err(StatsError::invalid_input(
            "ANOVA requires at least 2 groups",
        ));
    }
    let n_groups = groups_data.len();

    // Per-group (count, mean, M2). `collect` on an indexed parallel iterator
    // keeps group order, so the sequential unwrap below reports the first
    // failing group deterministically.
    let triples: Vec<Result<(f64, f64, f64), StatsError>> = groups_data
        .par_iter()
        .enumerate()
        .map(|(group_idx, group)| {
            let mut count = 0.0_f64;
            let mut mean = 0.0_f64;
            let mut m2 = 0.0_f64;
            for (value_idx, &value) in group.iter().enumerate() {
                let v = value.to_f64().ok_or_else(|| {
                    StatsError::conversion_error(format!(
                        "Failed to convert value at group {}, index {} to f64",
                        group_idx, value_idx
                    ))
                })?;
                // Welford's update: numerically stable running mean and M2.
                count += 1.0;
                let delta = v - mean;
                mean += delta / count;
                m2 += delta * (v - mean);
            }
            if count < 2.0 {
                return Err(StatsError::invalid_input(format!(
                    "Each group must have at least 2 observations (group {} has {})",
                    group_idx, count as usize
                )));
            }
            Ok((count, mean, m2))
        })
        .collect();

    let mut counts: Vec<f64> = Vec::with_capacity(n_groups);
    let mut means: Vec<f64> = Vec::with_capacity(n_groups);
    let mut m2s: Vec<f64> = Vec::with_capacity(n_groups);
    for triple in triples {
        let (c, m, m2) = triple?;
        counts.push(c);
        means.push(m);
        m2s.push(m2);
    }

    let n_total_f = counts.iter().sum::<f64>();
    let n_total = n_total_f as usize;
    let grand_mean = counts
        .iter()
        .zip(means.iter())
        .map(|(&n, &m)| n * m)
        .sum::<f64>()
        / n_total_f;
    let ss_between: f64 = counts
        .iter()
        .zip(means.iter())
        .map(|(&n, &m)| {
            let d = m - grand_mean;
            d * d * n
        })
        .sum();
    let ss_within: f64 = m2s.iter().sum();

    // Every group has >= 2 observations, so df_within >= n_groups >= 2.
    let df_between = n_groups - 1;
    let df_within = n_total - n_groups;
    let ms_between = ss_between / (df_between as f64);
    let ms_within = ss_within / (df_within as f64);
    let f_statistic = ms_between / ms_within;

    // `f_distribution_cdf` takes u32 degrees of freedom: saturate instead of
    // letting an `as` cast wrap for inputs beyond u32::MAX.
    let df1 = df_between.min(u32::MAX as usize) as u32;
    let df2 = df_within.min(u32::MAX as usize) as u32;
    let p_value = if f_statistic.is_nan() {
        // 0/0: no variation at all, the test is undefined.
        f64::NAN
    } else if f_statistic > 1.0 {
        // P(F(d1, d2) > f) == P(F(d2, d1) < 1/f). Evaluating this lower tail
        // directly avoids the cancellation in `1.0 - cdf` that rounds small
        // p-values to 0, and maps f = +inf (zero within-group variance with
        // differing means) to 1/f = 0 → p = 0 instead of NaN.
        f_distribution_cdf(1.0 / f_statistic, df2, df1)
    } else {
        // For f <= 1 the upper tail is not small, so the complement is accurate.
        1.0 - f_distribution_cdf(f_statistic, df1, df2)
    };

    Ok(AnovaResult {
        f_statistic,
        df_between,
        df_within,
        p_value,
        ss_between,
        ss_within,
        ms_between,
        ms_within,
    })
}
/// Cumulative distribution function `P(F <= f)` of Fisher's F distribution
/// with `df1` numerator and `df2` denominator degrees of freedom.
///
/// Evaluated through the crate's regularized incomplete beta function as
/// `I_x(df1/2, df2/2)` with `x = df1*f / (df2 + df1*f)`; the result is
/// clamped to `[0, 1]`.
///
/// Special cases:
/// * `f` is `NaN` → `NaN`;
/// * `f <= 0` → `0.0`;
/// * `f = +inf`, or `df1*f` overflows to infinity → `1.0` (the formula would
///   otherwise evaluate `inf / inf = NaN`);
/// * `df2 + df1*f` within `1e-15` of zero → `1.0`.
fn f_distribution_cdf(f: f64, df1: u32, df2: u32) -> f64 {
    if f.is_nan() {
        return f64::NAN;
    }
    if f <= 0.0 {
        return 0.0;
    }
    if f.is_infinite() {
        // Only +inf can reach here because f > 0.
        return 1.0;
    }
    let df1f = f64::from(df1);
    let df2f = f64::from(df2);
    let scaled = df1f * f;
    if scaled.is_infinite() {
        // x = scaled / (df2 + scaled) is 1 to within rounding.
        return 1.0;
    }
    let denom = df2f + scaled;
    if denom.abs() < 1e-15 {
        return 1.0;
    }
    let x = scaled / denom;
    canonical_inc_beta(0.5 * df1f, 0.5 * df2f, x).clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `actual` is within `tol` of `expected`, with a readable message.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_one_way_anova_basic() {
        let group1 = [5, 7, 9, 8, 6];
        let group2 = [2, 4, 3, 5, 4];
        let group3 = [8, 9, 10, 7, 8];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic > 1.0,
            "F-statistic should be greater than 1.0"
        );
        assert!(result.p_value < 0.05, "p-value should be less than 0.05");
        assert_eq!(result.df_between, 2);
        assert_eq!(result.df_within, 12);
        // Hand-computed: group means 7.0, 3.6, 8.4; grand mean 95/15;
        // within-group SS 10 + 5.2 + 5.2; between-group SS 914/15.
        assert_close(result.ss_within, 20.4, 1e-9);
        assert_close(result.ss_between, 914.0 / 15.0, 1e-9);
        assert_close(result.f_statistic, (914.0 / 30.0) / 1.7, 1e-9);
    }

    #[test]
    fn test_one_way_anova_equal_means() {
        let group1 = [5, 7, 6, 5, 7];
        let group2 = [6, 5, 7, 6, 6];
        let group3 = [7, 5, 6, 7, 5];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic < 1.0,
            "F-statistic should be less than 1.0 for equal means"
        );
        assert!(
            result.p_value > 0.05,
            "p-value should be greater than 0.05 for equal means"
        );
    }

    #[test]
    fn test_one_way_anova_different_group_sizes() {
        let group1 = [5, 7, 9, 8];
        let group2 = [2, 4, 3];
        let group3 = [8, 9, 10, 7, 8, 9];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_eq!(result.df_between, 2);
        assert_eq!(result.df_within, 10);
    }

    #[test]
    fn test_one_way_anova_float_values() {
        let group1 = [5.2, 7.3, 9.1, 8.0, 6.5];
        let group2 = [2.1, 4.3, 3.7, 5.0, 4.2];
        let group3 = [8.1, 9.2, 10.0, 7.5, 8.3];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(result.f_statistic > 1.0);
        assert!(result.p_value < 0.05);
    }

    #[test]
    fn test_one_way_anova_invalid_input() {
        // An empty group is rejected.
        let group1: [i32; 0] = [];
        let group2 = [1, 2, 3];
        let groups1 = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups1);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
        // A single group is rejected.
        let groups2 = [&group2[..]];
        let result = one_way_anova(&groups2);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
        // No groups at all is rejected.
        let groups3: [&[i32]; 0] = [];
        let result = one_way_anova(&groups3);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_anova_result_fields() {
        let group1 = [5, 7, 9, 8, 6];
        let group2 = [2, 4, 3, 5, 4];
        let group3 = [8, 9, 10, 7, 8];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(result.ss_between > 0.0);
        assert!(result.ss_within > 0.0);
        assert!(result.ms_between > 0.0);
        assert!(result.ms_within > 0.0);
        // Mean squares are sums of squares divided by their degrees of freedom.
        assert_close(
            result.ms_between,
            result.ss_between / result.df_between as f64,
            1e-12,
        );
        assert_close(
            result.ms_within,
            result.ss_within / result.df_within as f64,
            1e-12,
        );
        let calculated_f = result.ms_between / result.ms_within;
        assert!((calculated_f - result.f_statistic).abs() < 1e-10);
    }

    // Note: despite the name, this data yields df_between = 1, df_within = 2.
    #[test]
    fn test_f_distribution_cdf_with_df1_eq_df2_eq_1() {
        let group1 = [1.0, 2.0];
        let group2 = [3.0, 4.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(!result.p_value.is_nan(), "p-value should not be NaN");
        assert!(
            !result.p_value.is_infinite(),
            "p-value should not be infinite"
        );
        assert!(
            result.p_value >= 0.0 && result.p_value <= 1.0,
            "p-value should be in [0, 1]"
        );
        assert!(
            !result.f_statistic.is_nan(),
            "F-statistic should not be NaN"
        );
        assert!(
            !result.f_statistic.is_infinite(),
            "F-statistic should not be infinite"
        );
    }

    // Note: despite the name, this data yields df_between = 1, df_within = 2.
    #[test]
    fn test_f_distribution_cdf_edge_case_df1_1_df2_1() {
        let group1 = [1.0, 2.0];
        let group2 = [3.0, 4.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(!result.p_value.is_nan(), "p-value should not be NaN");
        assert!(
            !result.p_value.is_infinite(),
            "p-value should not be infinite"
        );
        assert!(
            result.p_value >= 0.0 && result.p_value <= 1.0,
            "p-value should be in [0, 1], got {}",
            result.p_value
        );
        assert!(
            !result.f_statistic.is_nan(),
            "F-statistic should not be NaN"
        );
        assert!(
            !result.f_statistic.is_infinite(),
            "F-statistic should not be infinite"
        );
        assert!(
            result.f_statistic >= 0.0,
            "F-statistic should be non-negative"
        );
        assert_eq!(result.df_between, 1);
        assert_eq!(result.df_within, 2);
        // ss_between = 4, ss_within = 1 → F = 4 / (1 / 2) = 8.
        assert_close(result.f_statistic, 8.0, 1e-12);
        // For F(1, 2): CDF(f) = (f / (2 + f))^(1/2), so P(F > 8) = 1 − sqrt(0.8).
        assert_close(result.p_value, 1.0 - 0.8_f64.sqrt(), 1e-6);
    }

    #[test]
    fn test_f_distribution_cdf_f_less_than_one() {
        let group1 = [1.0, 2.0, 3.0, 4.0, 5.0];
        let group2 = [1.1, 2.1, 3.1, 4.1, 5.1];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            !result.p_value.is_nan(),
            "p-value should not be NaN when F < 1.0"
        );
        assert!(
            !result.p_value.is_infinite(),
            "p-value should not be infinite when F < 1.0"
        );
        assert!(
            result.p_value >= 0.0 && result.p_value <= 1.0,
            "p-value should be in [0, 1]"
        );
        // ss_between = 0.025, ms_within = 20 / 8 = 2.5 → F = 0.01.
        assert!(
            result.f_statistic > 0.0 && result.f_statistic < 1.0,
            "F-statistic should be in (0, 1), got {}",
            result.f_statistic
        );
        assert_close(result.f_statistic, 0.01, 1e-9);
    }

    // All observations identical: both sums of squares are exactly zero,
    // so F is 0/0.
    #[test]
    fn test_f_distribution_cdf_f_zero() {
        let group1 = [5.0, 5.0, 5.0];
        let group2 = [5.0, 5.0, 5.0];
        let group3 = [5.0, 5.0, 5.0];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(result.ss_between == 0.0, "ss_between should be exactly 0");
        assert!(result.ss_within == 0.0, "ss_within should be exactly 0");
        assert!(
            result.f_statistic.is_nan() || result.f_statistic >= 0.0,
            "F-statistic should be NaN or non-negative"
        );
    }

    // F is a ratio of non-negative quantities and can never be negative;
    // this checks the all-identical degenerate case with two groups.
    #[test]
    fn test_f_distribution_cdf_f_negative() {
        let group1 = [1.0, 1.0, 1.0];
        let group2 = [1.0, 1.0, 1.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic.is_nan() || result.f_statistic >= 0.0,
            "F-statistic should be NaN or non-negative"
        );
    }

    // Constant groups with different means: no within-group variance, so
    // F is infinite rather than an error.
    #[test]
    fn test_one_way_anova_zero_within_variance() {
        let group1 = [1.0, 1.0];
        let group2 = [2.0, 2.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(result.ss_within == 0.0, "ss_within should be exactly 0");
        assert!(result.ss_between > 0.0);
        assert!(
            result.f_statistic.is_infinite() && result.f_statistic > 0.0,
            "F-statistic should be +inf, got {}",
            result.f_statistic
        );
    }

    #[test]
    fn test_regularized_incomplete_beta_i_zero_vs_i_greater_than_zero() {
        let group1 = [1.0, 2.0, 3.0];
        let group2 = [1.1, 2.1, 3.1];
        let groups = [&group1[..], &group2[..]];
        let result1 = one_way_anova(&groups).unwrap();
        assert!(!result1.p_value.is_nan(), "p-value should not be NaN");
        let group3 = [1.0, 2.0, 3.0];
        let group4 = [10.0, 11.0, 12.0];
        let groups2 = [&group3[..], &group4[..]];
        let result2 = one_way_anova(&groups2).unwrap();
        assert!(!result2.p_value.is_nan(), "p-value should not be NaN");
        assert!(result1.p_value >= 0.0 && result1.p_value <= 1.0);
        assert!(result2.p_value >= 0.0 && result2.p_value <= 1.0);
    }

    #[test]
    fn test_regularized_incomplete_beta_division_by_zero_edge_case() {
        let group1 = [1.0, 2.0];
        let group2 = [3.0, 4.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            !result.p_value.is_nan(),
            "p-value should not be NaN even with edge case parameters"
        );
        assert!(
            !result.p_value.is_infinite(),
            "p-value should not be infinite"
        );
        assert!(
            result.p_value >= 0.0 && result.p_value <= 1.0,
            "p-value should be in [0, 1]"
        );
    }
}