use crate::error::{StatsError, StatsResult};
use num_traits::ToPrimitive;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use std::fmt::Debug;
/// Summary statistics of a one-way analysis of variance, as returned by
/// [`one_way_anova`].
#[derive(Debug, Clone, PartialEq)]
pub struct AnovaResult {
    /// F statistic: `ms_between / ms_within`.
    pub f_statistic: f64,
    /// Between-group degrees of freedom: number of groups minus one.
    pub df_between: usize,
    /// Within-group degrees of freedom: total number of observations minus the
    /// number of groups.
    pub df_within: usize,
    /// p-value for `f_statistic`, i.e. the probability under the
    /// F(`df_between`, `df_within`) distribution of a value at least this large.
    pub p_value: f64,
    /// Between-group sum of squares: sum over groups of
    /// `n_i * (group_mean_i - grand_mean)^2`.
    pub ss_between: f64,
    /// Within-group sum of squares: sum of squared deviations of every
    /// observation from its own group mean.
    pub ss_within: f64,
    /// Between-group mean square: `ss_between / df_between`.
    pub ms_between: f64,
    /// Within-group mean square: `ss_within / df_within`.
    pub ms_within: f64,
}
/// Performs a one-way analysis of variance (ANOVA) on two or more groups.
///
/// Each slice in `groups_data` holds the observations of one group. Every
/// value is converted to `f64` with [`ToPrimitive::to_f64`] before any
/// arithmetic is done.
///
/// The returned [`AnovaResult`] holds the between/within sums of squares,
/// degrees of freedom and mean squares, the F statistic
/// `ms_between / ms_within`, and the p-value
/// `1 - f_distribution_cdf(f_statistic, df_between, df_within)`.
///
/// If every group has zero spread, `ms_within` is zero and `f_statistic` is
/// `+inf`. It is NaN when the group means are identical as well.
///
/// # Errors
///
/// * `StatsError::invalid_input` if fewer than two groups are given, if any
///   group has fewer than two observations, or if any observation converts to
///   NaN or an infinite value.
/// * `StatsError::conversion_error` if an observation cannot be converted to
///   `f64`.
pub fn one_way_anova<T>(groups_data: &[&[T]]) -> StatsResult<AnovaResult>
where
    T: ToPrimitive + Copy + Debug,
{
    if groups_data.len() < 2 {
        return Err(StatsError::invalid_input(
            "ANOVA requires at least 2 groups",
        ));
    }

    let mut groups: Vec<Vec<f64>> = Vec::with_capacity(groups_data.len());
    for (group_idx, group) in groups_data.iter().enumerate() {
        let mut converted_group = Vec::with_capacity(group.len());
        for (value_idx, &value) in group.iter().enumerate() {
            let f64_value = value.to_f64().ok_or_else(|| {
                StatsError::conversion_error(format!(
                    "Failed to convert value at group {}, index {} to f64",
                    group_idx, value_idx
                ))
            })?;
            // A single NaN or infinity would silently poison every sum below
            // and make all statistics NaN; reject it with a precise location.
            if !f64_value.is_finite() {
                return Err(StatsError::invalid_input(format!(
                    "Value at group {}, index {} is not finite ({})",
                    group_idx, value_idx, f64_value
                )));
            }
            converted_group.push(f64_value);
        }
        if converted_group.len() < 2 {
            return Err(StatsError::invalid_input(format!(
                "Each group must have at least 2 observations (group {} has {})",
                group_idx,
                converted_group.len()
            )));
        }
        groups.push(converted_group);
    }

    let n_total: usize = groups.iter().map(|group| group.len()).sum();
    let grand_mean = groups
        .iter()
        .flat_map(|group| group.iter().copied())
        .sum::<f64>()
        / (n_total as f64);

    #[cfg(feature = "parallel")]
    let group_means: Vec<f64> = groups
        .par_iter()
        .map(|group| group.iter().sum::<f64>() / (group.len() as f64))
        .collect();
    #[cfg(not(feature = "parallel"))]
    let group_means: Vec<f64> = groups
        .iter()
        .map(|group| group.iter().sum::<f64>() / (group.len() as f64))
        .collect();

    // Between-group sum of squares: sum_i n_i * (mean_i - grand_mean)^2.
    let ss_between: f64 = groups
        .iter()
        .zip(group_means.iter())
        .map(|(group, &group_mean)| (group_mean - grand_mean).powi(2) * (group.len() as f64))
        .sum();
    // Within-group sum of squares: sum_i sum_j (x_ij - mean_i)^2.
    let ss_within: f64 = groups
        .iter()
        .zip(group_means.iter())
        .map(|(group, &group_mean)| {
            group
                .iter()
                .map(|&value| (value - group_mean).powi(2))
                .sum::<f64>()
        })
        .sum();

    let df_between = groups.len() - 1;
    let df_within = n_total - groups.len();
    let ms_between = ss_between / (df_between as f64);
    let ms_within = ss_within / (df_within as f64);
    let f_statistic = ms_between / ms_within;

    // `f_distribution_cdf` takes u32 degrees of freedom. Saturate rather than
    // let `as` silently wrap around for enormous sample sizes.
    let to_df = |df: usize| df.min(u32::MAX as usize) as u32;
    let p_value = 1.0 - f_distribution_cdf(f_statistic, to_df(df_between), to_df(df_within));

    Ok(AnovaResult {
        f_statistic,
        df_between,
        df_within,
        p_value,
        ss_between,
        ss_within,
        ms_between,
        ms_within,
    })
}
/// Cumulative distribution function `P(X <= f)` of the F distribution with
/// `df1` numerator and `df2` denominator degrees of freedom.
///
/// Uses `P(X <= f) = I_x(df1 / 2, df2 / 2)` with `x = df1*f / (df1*f + df2)`,
/// where `I_x` is the regularized incomplete beta function.
///
/// Returns `0.0` for `f <= 0` and `1.0` for `f = +inf`. Returns NaN when `f`
/// is NaN or when either degrees-of-freedom value is zero (the distribution is
/// undefined there).
fn f_distribution_cdf(f: f64, df1: u32, df2: u32) -> f64 {
    if f.is_nan() || df1 == 0 || df2 == 0 {
        return f64::NAN;
    }
    if f <= 0.0 {
        return 0.0;
    }
    if f.is_infinite() {
        return 1.0;
    }
    let d1 = f64::from(df1);
    let d2 = f64::from(df2);
    // Algebraically equal to d1*f / (d1*f + d2), but written so that a huge
    // `f` (where d1*f overflows to +inf) gives x = 1 instead of inf/inf = NaN,
    // and a tiny `f` (where d1*f underflows to 0) gives x = 0.
    let x = 1.0 / (1.0 + d2 / (d1 * f));
    regularized_incomplete_beta(x, d1 / 2.0, d2 / 2.0).clamp(0.0, 1.0)
}
/// Regularized incomplete beta function `I_x(a, b)`.
///
/// Evaluates the continued-fraction representation with the modified Lentz
/// method (the classic `betai`/`betacf` scheme). The fraction converges
/// quickly for `x < (a + 1) / (a + b + 2)`. Otherwise the symmetry
/// `I_x(a, b) = 1 - I_{1-x}(b, a)` is used so the fraction is always evaluated
/// in its fast-converging region.
///
/// Returns `0.0` for `x <= 0` and `1.0` for `x >= 1`. Returns NaN when `x` is
/// NaN or when `a` or `b` is not a positive finite number. The result is
/// clamped to `[0, 1]`.
fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
    let valid_shape = |s: f64| s.is_finite() && s > 0.0;
    if x.is_nan() || !valid_shape(a) || !valid_shape(b) {
        return f64::NAN;
    }
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    // ln of the common prefactor x^a (1 - x)^b / B(a, b). `ln_1p` keeps
    // ln(1 - x) accurate when x is small.
    let ln_beta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
    let prefactor = (a * x.ln() + b * (-x).ln_1p() - ln_beta).exp();

    let value = if x < (a + 1.0) / (a + b + 2.0) {
        prefactor * beta_continued_fraction(x, a, b) / a
    } else {
        1.0 - prefactor * beta_continued_fraction(1.0 - x, b, a) / b
    };
    value.clamp(0.0, 1.0)
}

/// Evaluates the continued fraction for `I_x(a, b)` by the modified Lentz
/// method. The caller multiplies the result by `x^a (1-x)^b / (a B(a, b))`.
fn beta_continued_fraction(x: f64, a: f64, b: f64) -> f64 {
    // Substitute for exact zeros so Lentz's recurrences never divide by zero.
    const TINY: f64 = 1e-300;
    const EPSILON: f64 = 1e-15;
    let nonzero = |v: f64| if v.abs() < TINY { TINY } else { v };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;
    // Convergence takes O(sqrt(max(a, b))) iterations, so scale the cap with
    // the shape parameters instead of using a fixed limit.
    let max_iterations = 100 + 10 * (a.max(b).sqrt() as usize);

    let mut c = 1.0;
    let mut d = 1.0 / nonzero(1.0 - qab * x / qap);
    let mut h = d;
    for m in 1..=max_iterations {
        let m = m as f64;
        let m2 = 2.0 * m;

        // Even coefficient d_{2m}.
        let aa = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = 1.0 / nonzero(1.0 + aa * d);
        c = nonzero(1.0 + aa / c);
        h *= d * c;

        // Odd coefficient d_{2m+1}.
        let aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = 1.0 / nonzero(1.0 + aa * d);
        c = nonzero(1.0 + aa / c);
        let delta = d * c;
        h *= delta;

        if (delta - 1.0).abs() < EPSILON {
            break;
        }
    }
    h
}
/// Natural logarithm of the gamma function, `ln Γ(x)`, for `x > 0`.
///
/// For `x >= 0.5` this uses the Lanczos approximation with g = 7 and nine
/// coefficients. For `0 < x < 0.5`, where that series loses accuracy, it uses
/// the reflection formula `Γ(x) Γ(1 - x) = π / sin(πx)`.
///
/// Returns `+inf` for `x <= 0` and for `x = +inf`. A NaN input yields NaN.
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 || x.is_infinite() {
        return f64::INFINITY;
    }
    let pi = std::f64::consts::PI;
    if x < 0.5 {
        // ln Γ(x) = ln π - ln sin(πx) - ln Γ(1 - x). sin(πx) > 0 on (0, 0.5),
        // and 1 - x > 0.5 lands in the Lanczos branch, so there is no deep
        // recursion.
        return pi.ln() - (pi * x).sin().ln() - ln_gamma(1.0 - x);
    }

    const LANCZOS: [f64; 8] = [
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];
    let z = x - 1.0;
    let series = LANCZOS
        .iter()
        .enumerate()
        .fold(0.999_999_999_999_809_9, |acc, (i, &coef)| {
            acc + coef / (z + i as f64 + 1.0)
        });
    // t = z + g + 0.5 with g = 7.
    let t = z + LANCZOS.len() as f64 - 0.5;
    // ln Γ(z + 1) = ½ ln(2π) + (z + ½) ln t − t + ln(series).
    0.5 * (2.0 * pi).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `p` is a finite value inside `[0, 1]`.
    fn assert_probability(p: f64) {
        assert!(p.is_finite(), "p-value should be finite, got {}", p);
        assert!(
            (0.0..=1.0).contains(&p),
            "p-value should be in [0, 1], got {}",
            p
        );
    }

    /// Asserts that `actual` lies within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() <= tolerance,
            "expected {} (+/- {}), got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[test]
    fn test_one_way_anova_basic() {
        let group1 = [5, 7, 9, 8, 6];
        let group2 = [2, 4, 3, 5, 4];
        let group3 = [8, 9, 10, 7, 8];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic > 1.0,
            "F-statistic should be greater than 1.0"
        );
        assert!(result.p_value < 0.05, "p-value should be less than 0.05");
        assert_eq!(result.df_between, 2);
        assert_eq!(result.df_within, 12);
    }

    #[test]
    fn test_one_way_anova_basic_exact_values() {
        // Group means 7, 3.6 and 8.4; grand mean 19/3.
        // SS_between = 2742/45 and SS_within = 20.4, so F = (1371/45) / 1.7 = 914/51.
        let group1 = [5, 7, 9, 8, 6];
        let group2 = [2, 4, 3, 5, 4];
        let group3 = [8, 9, 10, 7, 8];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_close(result.ss_between, 2742.0 / 45.0, 1e-10);
        assert_close(result.ss_within, 20.4, 1e-10);
        assert_close(result.f_statistic, 914.0 / 51.0, 1e-10);
        // For F(2, d2) the upper tail has the closed form (d2 / (d2 + 2F))^(d2 / 2).
        let expected_p = (12.0 / (12.0 + 2.0 * 914.0 / 51.0)).powi(6);
        assert_close(result.p_value, expected_p, 1e-9 * expected_p);
    }

    #[test]
    fn test_one_way_anova_equal_means() {
        let group1 = [5, 7, 6, 5, 7];
        let group2 = [6, 5, 7, 6, 6];
        let group3 = [7, 5, 6, 7, 5];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic < 1.0,
            "F-statistic should be less than 1.0 for equal means"
        );
        assert!(
            result.p_value > 0.05,
            "p-value should be greater than 0.05 for equal means"
        );
        // All three group means are exactly 6, so F = 0 and p = 1.
        assert_eq!(result.p_value, 1.0);
    }

    #[test]
    fn test_one_way_anova_different_group_sizes() {
        let group1 = [5, 7, 9, 8];
        let group2 = [2, 4, 3];
        let group3 = [8, 9, 10, 7, 8, 9];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_eq!(result.df_between, 2);
        assert_eq!(result.df_within, 10);
    }

    #[test]
    fn test_one_way_anova_float_values() {
        let group1 = [5.2, 7.3, 9.1, 8.0, 6.5];
        let group2 = [2.1, 4.3, 3.7, 5.0, 4.2];
        let group3 = [8.1, 9.2, 10.0, 7.5, 8.3];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(result.f_statistic > 1.0);
        assert!(result.p_value < 0.05);
    }

    #[test]
    fn test_one_way_anova_invalid_input() {
        let group1: [i32; 0] = [];
        let group2 = [1, 2, 3];

        let groups1 = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups1);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));

        let groups2 = [&group2[..]];
        let result = one_way_anova(&groups2);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));

        let groups3: [&[i32]; 0] = [];
        let result = one_way_anova(&groups3);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_anova_result_fields() {
        let group1 = [5, 7, 9, 8, 6];
        let group2 = [2, 4, 3, 5, 4];
        let group3 = [8, 9, 10, 7, 8];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(result.ss_between > 0.0);
        assert!(result.ss_within > 0.0);
        assert!(result.ms_between > 0.0);
        assert!(result.ms_within > 0.0);
        let calculated_f = result.ms_between / result.ms_within;
        assert!((calculated_f - result.f_statistic).abs() < 1e-10);
    }

    #[test]
    fn test_f_distribution_cdf_with_df1_eq_df2_eq_1() {
        // Two groups of two observations: df = (1, 2).
        let group1 = [1.0, 2.0];
        let group2 = [3.0, 4.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_probability(result.p_value);
        assert!(
            result.f_statistic.is_finite(),
            "F-statistic should be finite"
        );
    }

    #[test]
    fn test_f_distribution_cdf_edge_case_df1_1_df2_1() {
        // Group means 1.5 and 3.5: SS_between = 4, SS_within = 1, df = (1, 2), F = 8.
        let group1 = [1.0, 2.0];
        let group2 = [3.0, 4.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_eq!(result.df_between, 1);
        assert_eq!(result.df_within, 2);
        assert_close(result.f_statistic, 8.0, 1e-12);
        // F(1, 2) is the square of a t(2) variable and P(|T| > t) = 1 - t / sqrt(2 + t^2).
        // With t^2 = 8 this is 1 - sqrt(0.8).
        assert_close(result.p_value, 1.0 - 0.8_f64.sqrt(), 1e-10);
    }

    #[test]
    fn test_f_distribution_cdf_f_less_than_one() {
        let group1 = [1.0, 2.0, 3.0, 4.0, 5.0];
        let group2 = [1.1, 2.1, 3.1, 4.1, 5.1];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_probability(result.p_value);
        if result.f_statistic < 1.0 {
            assert!(result.f_statistic > 0.0, "F-statistic should be positive");
        }
    }

    #[test]
    fn test_f_distribution_cdf_f_zero() {
        // Every value identical: 0 / 0 gives a NaN F statistic.
        let group1 = [5.0, 5.0, 5.0];
        let group2 = [5.0, 5.0, 5.0];
        let group3 = [5.0, 5.0, 5.0];
        let groups = [&group1[..], &group2[..], &group3[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic.is_nan() || result.f_statistic >= 0.0,
            "F-statistic should be NaN or non-negative"
        );
    }

    #[test]
    fn test_f_distribution_cdf_f_negative() {
        let group1 = [1.0, 1.0, 1.0];
        let group2 = [1.0, 1.0, 1.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert!(
            result.f_statistic.is_nan() || result.f_statistic >= 0.0,
            "F-statistic should be NaN or non-negative"
        );
    }

    #[test]
    fn test_regularized_incomplete_beta_i_zero_vs_i_greater_than_zero() {
        let group1 = [1.0, 2.0, 3.0];
        let group2 = [1.1, 2.1, 3.1];
        let groups = [&group1[..], &group2[..]];
        let result1 = one_way_anova(&groups).unwrap();
        assert_probability(result1.p_value);

        let group3 = [1.0, 2.0, 3.0];
        let group4 = [10.0, 11.0, 12.0];
        let groups2 = [&group3[..], &group4[..]];
        let result2 = one_way_anova(&groups2).unwrap();
        assert_probability(result2.p_value);
    }

    #[test]
    fn test_regularized_incomplete_beta_division_by_zero_edge_case() {
        let group1 = [1.0, 2.0];
        let group2 = [3.0, 4.0];
        let groups = [&group1[..], &group2[..]];
        let result = one_way_anova(&groups).unwrap();
        assert_probability(result.p_value);
    }

    #[test]
    fn test_ln_gamma_reference_values() {
        let pi = std::f64::consts::PI;
        assert_close(ln_gamma(1.0), 0.0, 1e-12);
        assert_close(ln_gamma(2.0), 0.0, 1e-12);
        assert_close(ln_gamma(5.0), 24.0_f64.ln(), 1e-12);
        assert_close(ln_gamma(0.5), 0.5 * pi.ln(), 1e-12);
        assert_close(ln_gamma(1.5), (0.5 * pi.sqrt()).ln(), 1e-12);
    }

    #[test]
    fn test_regularized_incomplete_beta_closed_forms() {
        for &x in &[0.1, 0.25, 0.5, 0.8, 0.95] {
            // I_x(a, 1) = x^a
            assert_close(regularized_incomplete_beta(x, 3.0, 1.0), x.powi(3), 1e-10);
            // I_x(1, b) = 1 - (1 - x)^b
            assert_close(
                regularized_incomplete_beta(x, 1.0, 2.5),
                1.0 - (1.0 - x).powf(2.5),
                1e-10,
            );
            // I_x(1/2, 1/2) = (2 / pi) * asin(sqrt(x))
            let arcsine = 2.0 / std::f64::consts::PI * x.sqrt().asin();
            assert_close(regularized_incomplete_beta(x, 0.5, 0.5), arcsine, 1e-10);
        }
        // Symmetric case: I_{1/2}(a, a) = 1/2.
        assert_close(regularized_incomplete_beta(0.5, 4.0, 4.0), 0.5, 1e-10);
        assert_eq!(regularized_incomplete_beta(0.0, 2.0, 3.0), 0.0);
        assert_eq!(regularized_incomplete_beta(1.0, 2.0, 3.0), 1.0);
    }

    #[test]
    fn test_f_distribution_cdf_reference_values() {
        // F(2, 10): CDF = 1 - (10 / (10 + 2f))^5
        let f = 3.5;
        assert_close(
            f_distribution_cdf(f, 2, 10),
            1.0 - (10.0 / (10.0 + 2.0 * f)).powi(5),
            1e-10,
        );
        // F(1, 2): CDF = sqrt(f / (f + 2))
        assert_close(f_distribution_cdf(8.0, 1, 2), 0.8_f64.sqrt(), 1e-10);
        // F(4, 2): CDF = (4f / (4f + 2))^2; at f = 0.5 this is 0.25.
        assert_close(f_distribution_cdf(0.5, 4, 2), 0.25, 1e-10);
        assert_eq!(f_distribution_cdf(0.0, 3, 7), 0.0);
        assert_eq!(f_distribution_cdf(f64::INFINITY, 3, 7), 1.0);
    }
}