use crate::distributions::f;
use crate::error::{StatsError, StatsResult};
use crate::mean;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumCast};
use std::fmt::Debug;
/// Summary of a one-way analysis of variance, as produced by `one_way_anova`.
///
/// Derives `PartialEq` (in addition to `Debug`/`Clone`) so that results can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct AnovaResult<F> {
    /// F statistic, `ms_treatment / ms_error`.
    pub f_statistic: F,
    /// Upper-tail probability of `f_statistic` under an F(`df_treatment`, `df_error`)
    /// distribution.
    pub p_value: F,
    /// Between-group degrees of freedom (number of groups minus one).
    pub df_treatment: usize,
    /// Within-group degrees of freedom (total observations minus number of groups).
    pub df_error: usize,
    /// Between-group sum of squares, `Σ n_i (mean_i - grand_mean)²`.
    pub ss_treatment: F,
    /// Within-group sum of squares, `Σ_i Σ_j (x_ij - mean_i)²`.
    pub ss_error: F,
    /// Between-group mean square, `ss_treatment / df_treatment`.
    pub ms_treatment: F,
    /// Within-group mean square, `ss_error / df_error`.
    pub ms_error: F,
    /// Total sum of squares around the grand mean, `Σ_i Σ_j (x_ij - grand_mean)²`.
    pub ss_total: F,
}
#[allow(dead_code)]
pub fn one_way_anova<F>(groups: &[&ArrayView1<F>]) -> StatsResult<AnovaResult<F>>
where
F: Float
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ NumCast
+ Debug
+ std::fmt::Display
+ scirs2_core::simd_ops::SimdUnifiedOps,
{
if groups.len() < 2 {
return Err(StatsError::InvalidArgument(
"At least two groups are required for ANOVA".to_string(),
));
}
for (i, group) in groups.iter().enumerate() {
if group.is_empty() {
return Err(StatsError::InvalidArgument(format!(
"Group {} is empty",
i + 1
)));
}
}
let n_total = groups.iter().map(|group| group.len()).sum::<usize>();
if n_total <= groups.len() {
return Err(StatsError::InvalidArgument(
"Not enough data for ANOVA (need more observations than groups)".to_string(),
));
}
let mut all_values = Array1::<F>::zeros(n_total);
let mut index = 0;
for group in groups {
for &value in group.iter() {
all_values[index] = value;
index += 1;
}
}
let grand_mean = mean(&all_values.view())?;
let mut group_means = Vec::with_capacity(groups.len());
let mut groupsizes = Vec::with_capacity(groups.len());
for group in groups {
group_means.push(mean(group)?);
groupsizes.push(group.len());
}
let mut ss_treatment = F::zero();
let mut ss_error = F::zero();
let mut ss_total = F::zero();
for (&group_mean, &groupsize) in group_means.iter().zip(groupsizes.iter()) {
let size_f = F::from(groupsize).expect("Failed to convert to float");
ss_treatment = ss_treatment + size_f * (group_mean - grand_mean).powi(2);
}
for (group, &group_mean) in groups.iter().zip(group_means.iter()) {
for &value in group.iter() {
ss_error = ss_error + (value - group_mean).powi(2);
ss_total = ss_total + (value - grand_mean).powi(2);
}
}
let df_treatment = groups.len() - 1;
let df_error = n_total - groups.len();
let ms_treatment = ss_treatment / F::from(df_treatment).expect("Failed to convert to float");
let ms_error = ss_error / F::from(df_error).expect("Failed to convert to float");
let f_statistic = ms_treatment / ms_error;
let f_dist = f(
F::from(df_treatment).expect("Failed to convert to float"),
F::from(df_error).expect("Failed to convert to float"),
F::zero(),
F::one(),
)?;
let p_value = F::one() - f_dist.cdf(f_statistic);
Ok(AnovaResult {
f_statistic,
p_value,
df_treatment,
df_error,
ss_treatment,
ss_error,
ms_treatment,
ms_error,
ss_total,
})
}
/// Output of `tukey_hsd`: one tuple per pair of groups `(i, j)` with `i < j`.
///
/// Tuple fields, in order:
/// 0. index of the first group (`i`, 0-based),
/// 1. index of the second group (`j`, 0-based),
/// 2. absolute difference of the two group means,
/// 3. p-value of the pairwise studentized range statistic,
/// 4. whether the statistic exceeds the critical value at the requested `alpha`.
pub type TukeyHSDResult<F> = Vec<(usize, usize, F, F, bool)>;
/// Tukey's honestly-significant-difference (HSD) post-hoc test, Tukey–Kramer form.
///
/// First runs `one_way_anova` on `groups` to obtain the pooled within-group mean square
/// `MS_error` and its degrees of freedom. It then compares every pair of groups `(i, j)`
/// with `i < j` using the studentized range statistic
/// `q = |mean_i - mean_j| / sqrt(MS_error / h_ij)`, where
/// `h_ij = 2·n_i·n_j / (n_i + n_j)` is the harmonic mean of the two group sizes. Using
/// the harmonic mean is the Tukey–Kramer adjustment for unequal group sizes.
///
/// # Arguments
/// * `groups` - the samples; at least two, each non-empty.
/// * `alpha` - significance level used for the `significant` flag.
///
/// # Returns
/// One `(i, j, |mean_i - mean_j|, p_value, significant)` tuple per pair, ordered by `i`
/// then `j`. `significant` is `q > q_crit(alpha, k, df_error)` with `k` the number of
/// groups.
///
/// # Errors
/// * `StatsError::InvalidArgument` if fewer than two groups are supplied.
/// * Any error from `one_way_anova` or `mean`.
/// * Any error from `calculate_studentized_range_critical_value`. It only accepts
///   `alpha` of 0.05 or 0.01 and between 2 and 6 groups.
#[allow(dead_code)]
pub fn tukey_hsd<F>(groups: &[&ArrayView1<F>], alpha: F) -> StatsResult<TukeyHSDResult<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + Debug
        + std::fmt::Display
        + scirs2_core::simd_ops::SimdUnifiedOps,
{
    if groups.len() < 2 {
        return Err(StatsError::InvalidArgument(
            "At least two groups are required for Tukey's HSD".to_string(),
        ));
    }
    let anova_result = one_way_anova(groups)?;

    let n_groups = groups.len();
    let mut group_means = Vec::with_capacity(n_groups);
    let mut groupsizes = Vec::with_capacity(n_groups);
    for group in groups {
        group_means.push(mean(group)?);
        groupsizes.push(F::from(group.len()).expect("group size always converts to float"));
    }

    // Parameters of the studentized range distribution; identical for every pair,
    // so convert them once instead of inside the pair loop.
    let k = F::from(n_groups).expect("group count always converts to float");
    let df_error =
        F::from(anova_result.df_error).expect("degrees of freedom always convert to float");
    let critical_q = calculate_studentized_range_critical_value(alpha, k, df_error)?;
    let two = F::from(2.0).expect("Failed to convert constant to float");

    let mut results = Vec::with_capacity(n_groups * (n_groups - 1) / 2);
    for i in 0..n_groups {
        for j in (i + 1)..n_groups {
            let mean_diff = (group_means[i] - group_means[j]).abs();
            // Harmonic mean of the two sizes (Tukey–Kramer adjustment for unequal n).
            let harmonic_mean_n =
                (two * groupsizes[i] * groupsizes[j]) / (groupsizes[i] + groupsizes[j]);
            let std_error = (anova_result.ms_error / harmonic_mean_n).sqrt();
            let q_stat = mean_diff / std_error;
            let p_value = calculate_studentized_range_p_value(q_stat, k, df_error);
            let significant = q_stat > critical_q;
            results.push((i, j, mean_diff, p_value, significant));
        }
    }
    Ok(results)
}
/// Upper-`alpha` critical value `q_{alpha; k, df}` of the studentized range distribution.
///
/// Looks up the standard table of studentized-range percentage points for
/// `alpha ∈ {0.05, 0.01}` and `k = 2..=6` groups. The table is tabulated at error degrees
/// of freedom 1–20, 24, 30, 40, 60, 120 and ∞. Between tabulated rows it interpolates
/// linearly in `1/df` (with `1/∞ = 0`), the usual recommendation for these tables.
///
/// # Errors
/// `StatsError::InvalidArgument` when:
/// - `alpha` is outside `(0, 1)`;
/// - `k` is outside `[2, 6]`;
/// - `df` is below 1 or NaN;
/// - `alpha` is not 0.05 or 0.01 (within 1e-6, which admits `f32` rounding only).
#[allow(dead_code)]
fn calculate_studentized_range_critical_value<F: Float + NumCast>(
    alpha: F,
    k: F,
    df: F,
) -> StatsResult<F> {
    // Error degrees of freedom of each table row (the last row is the df = ∞ limit).
    const DF_ROWS: [f64; 26] = [
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        17.0, 18.0, 19.0, 20.0, 24.0, 30.0, 40.0, 60.0, 120.0, f64::INFINITY,
    ];
    // Upper 5% points; columns are k = 2, 3, 4, 5, 6.
    const Q_05: [[f64; 5]; 26] = [
        [17.97, 26.98, 32.82, 37.08, 40.41],
        [6.08, 8.33, 9.80, 10.88, 11.74],
        [4.50, 5.91, 6.82, 7.50, 8.04],
        [3.93, 5.04, 5.76, 6.29, 6.71],
        [3.64, 4.60, 5.22, 5.67, 6.03],
        [3.46, 4.34, 4.90, 5.30, 5.63],
        [3.34, 4.16, 4.68, 5.06, 5.36],
        [3.26, 4.04, 4.53, 4.89, 5.17],
        [3.20, 3.95, 4.41, 4.76, 5.02],
        [3.15, 3.88, 4.33, 4.65, 4.91],
        [3.11, 3.82, 4.26, 4.57, 4.82],
        [3.08, 3.77, 4.20, 4.51, 4.75],
        [3.06, 3.73, 4.15, 4.45, 4.69],
        [3.03, 3.70, 4.11, 4.41, 4.64],
        [3.01, 3.67, 4.08, 4.37, 4.59],
        [3.00, 3.65, 4.05, 4.33, 4.56],
        [2.98, 3.63, 4.02, 4.30, 4.52],
        [2.97, 3.61, 4.00, 4.28, 4.49],
        [2.96, 3.59, 3.98, 4.25, 4.47],
        [2.95, 3.58, 3.96, 4.23, 4.45],
        [2.92, 3.53, 3.90, 4.17, 4.37],
        [2.89, 3.49, 3.85, 4.10, 4.30],
        [2.86, 3.44, 3.79, 4.04, 4.23],
        [2.83, 3.40, 3.74, 3.98, 4.16],
        [2.80, 3.36, 3.68, 3.92, 4.10],
        [2.77, 3.31, 3.63, 3.86, 4.03],
    ];
    // Upper 1% points; columns are k = 2, 3, 4, 5, 6.
    const Q_01: [[f64; 5]; 26] = [
        [90.03, 135.0, 164.3, 185.6, 202.2],
        [14.04, 19.02, 22.29, 24.72, 26.63],
        [8.26, 10.62, 12.17, 13.33, 14.24],
        [6.51, 8.12, 9.17, 9.96, 10.58],
        [5.70, 6.98, 7.80, 8.42, 8.91],
        [5.24, 6.33, 7.03, 7.56, 7.97],
        [4.95, 5.92, 6.54, 7.01, 7.37],
        [4.75, 5.64, 6.20, 6.62, 6.96],
        [4.60, 5.43, 5.96, 6.35, 6.66],
        [4.48, 5.27, 5.77, 6.14, 6.43],
        [4.39, 5.15, 5.62, 5.97, 6.25],
        [4.32, 5.05, 5.50, 5.84, 6.10],
        [4.26, 4.96, 5.40, 5.73, 5.98],
        [4.21, 4.89, 5.32, 5.63, 5.88],
        [4.17, 4.84, 5.25, 5.56, 5.80],
        [4.13, 4.79, 5.19, 5.49, 5.72],
        [4.10, 4.74, 5.14, 5.43, 5.66],
        [4.07, 4.70, 5.09, 5.38, 5.60],
        [4.05, 4.67, 5.05, 5.33, 5.55],
        [4.02, 4.64, 5.02, 5.29, 5.51],
        [3.96, 4.55, 4.91, 5.17, 5.37],
        [3.89, 4.45, 4.80, 5.05, 5.24],
        [3.82, 4.37, 4.70, 4.93, 5.11],
        [3.76, 4.28, 4.59, 4.82, 4.99],
        [3.70, 4.20, 4.50, 4.71, 4.87],
        [3.64, 4.12, 4.40, 4.60, 4.76],
    ];
    // Only genuine 0.05 / 0.01 (allowing for f32 rounding) may use the tables; a looser
    // tolerance would silently apply, e.g., the 1% table to alpha = 0.001.
    const ALPHA_TOLERANCE: f64 = 1e-6;

    let alpha_f64 = <f64 as NumCast>::from(alpha).expect("Float values always convert to f64");
    let k_f64 = <f64 as NumCast>::from(k).expect("Float values always convert to f64");
    let df_f64 = <f64 as NumCast>::from(df).expect("Float values always convert to f64");

    if alpha_f64 <= 0.0 || alpha_f64 >= 1.0 {
        return Err(StatsError::InvalidArgument(
            "Alpha must be between 0 and 1".to_string(),
        ));
    }
    if !(2.0..=6.0).contains(&k_f64) {
        return Err(StatsError::InvalidArgument(
            "This approximation supports only 2 to 6 groups".to_string(),
        ));
    }
    if df_f64.is_nan() || df_f64 < 1.0 {
        return Err(StatsError::InvalidArgument(
            "Degrees of freedom must be positive".to_string(),
        ));
    }

    let table = if (alpha_f64 - 0.05).abs() < ALPHA_TOLERANCE {
        &Q_05
    } else if (alpha_f64 - 0.01).abs() < ALPHA_TOLERANCE {
        &Q_01
    } else {
        return Err(StatsError::InvalidArgument(
            "This approximation supports only alpha=0.05 or alpha=0.01".to_string(),
        ));
    };

    // k is validated to lie in [2, 6]; any fractional part is truncated.
    let col = k_f64 as usize - 2;

    // First tabulated df >= requested df. DF_ROWS ends with +∞, so one always exists.
    let hi = DF_ROWS
        .iter()
        .position(|&row_df| row_df >= df_f64)
        .unwrap_or(DF_ROWS.len() - 1);

    let q = if hi == 0 || DF_ROWS[hi] == df_f64 {
        table[hi][col]
    } else {
        // Linear interpolation in 1/df between the bracketing rows (1/∞ == 0).
        let lo = hi - 1;
        let inv_lo = 1.0 / DF_ROWS[lo];
        let inv_hi = 1.0 / DF_ROWS[hi];
        let weight = (inv_lo - 1.0 / df_f64) / (inv_lo - inv_hi);
        table[lo][col] + weight * (table[hi][col] - table[lo][col])
    };

    Ok(F::from(q).expect("Failed to convert to float"))
}
/// Upper-tail p-value `P(Q > q)` of the studentized range distribution.
///
/// `Q` is the range of `k` independent standard normal variables divided by an independent
/// `sqrt(chi²_df / df)` variable; it is the statistic used by Tukey's HSD. The CDF is
/// evaluated numerically from its standard double-integral representation:
///
/// ```text
/// P(Q <= q) = ∫_0^∞ g_df(s) · W(q·s) ds
/// W(w)      = k ∫ φ(z) [Φ(z) − Φ(z − w)]^(k−1) dz
/// ```
///
/// Here `W` is the range CDF for known variance and `g_df` is the density of
/// `sqrt(chi²_df / df)`.
///
/// - Both integrals use composite Simpson's rule on truncated ranges.
/// - Φ uses the Numerical Recipes `erfcc` Chebyshev fit.
/// - `ln Γ` uses a Lanczos approximation.
/// - For `df > 1e6` the `s` integral is skipped (`s ≈ 1`).
///
/// The result is intended for reporting p-values, not for extreme tail probabilities.
///
/// Special cases:
/// - NaN when any argument is NaN or `df <= 0`, or when a value cannot be converted;
/// - 1 for `q <= 0`;
/// - 0 for `q = +∞`.
#[allow(dead_code)]
fn calculate_studentized_range_p_value<F: Float + NumCast>(q: F, k: F, df: F) -> F {
    /// Standard normal CDF, `Φ(x) = erfc(-x/√2) / 2`, using the Numerical Recipes
    /// `erfcc` fit (fractional error below 1.2e-7).
    fn normal_cdf(x: f64) -> f64 {
        let u = -x / std::f64::consts::SQRT_2;
        let z = u.abs();
        let t = 1.0 / (1.0 + 0.5 * z);
        let exponent = -z * z - 1.265_512_23
            + t * (1.000_023_68
                + t * (0.374_091_96
                    + t * (0.096_784_18
                        + t * (-0.186_288_06
                            + t * (0.278_868_07
                                + t * (-1.135_203_98
                                    + t * (1.488_515_87
                                        + t * (-0.822_152_23 + t * 0.170_872_77))))))));
        let erfc_abs = t * exponent.exp();
        let erfc_u = if u >= 0.0 { erfc_abs } else { 2.0 - erfc_abs };
        0.5 * erfc_u
    }

    /// Composite Simpson's rule on `[a, b]` with an even number of sub-intervals.
    fn simpson(a: f64, b: f64, intervals: usize, integrand: impl Fn(f64) -> f64) -> f64 {
        let h = (b - a) / intervals as f64;
        let interior: f64 = (1..intervals)
            .map(|i| {
                let weight = if i % 2 == 1 { 4.0 } else { 2.0 };
                weight * integrand(a + h * i as f64)
            })
            .sum();
        (integrand(a) + interior + integrand(b)) * h / 3.0
    }

    /// `ln Γ(x)` for `x >= 0.5` (Lanczos approximation, g = 7, n = 9).
    fn ln_gamma(x: f64) -> f64 {
        const COEFFS: [f64; 9] = [
            0.999_999_999_999_809_9,
            676.520_368_121_885_1,
            -1_259.139_216_722_402_8,
            771.323_428_777_653_1,
            -176.615_029_162_140_6,
            12.507_343_278_686_905,
            -0.138_571_095_265_720_12,
            9.984_369_578_019_572e-6,
            1.505_632_735_149_311_6e-7,
        ];
        let z = x - 1.0;
        let t = z + 7.5;
        let series = COEFFS
            .iter()
            .enumerate()
            .skip(1)
            .fold(COEFFS[0], |acc, (i, &c)| acc + c / (z + i as f64));
        0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
    }

    /// CDF at `w` of the range of `k` independent standard normals (known variance).
    fn range_cdf_known_variance(w: f64, k: f64) -> f64 {
        if w <= 0.0 {
            return 0.0;
        }
        let inv_sqrt_2pi = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        let integrand = |z: f64| {
            let density = inv_sqrt_2pi * (-0.5 * z * z).exp();
            let inside = (normal_cdf(z) - normal_cdf(z - w)).max(0.0);
            k * density * inside.powf(k - 1.0)
        };
        // φ(z) is negligible outside [-8, 8].
        simpson(-8.0, 8.0, 128, integrand).clamp(0.0, 1.0)
    }

    /// `P(Q <= q)` for the studentized range with `k` groups and `df` error d.f.
    fn studentized_range_cdf(q: f64, k: f64, df: f64) -> f64 {
        if q.is_nan() || k.is_nan() || df.is_nan() || df <= 0.0 {
            return f64::NAN;
        }
        if q <= 0.0 {
            return 0.0;
        }
        if q.is_infinite() {
            return 1.0;
        }
        // With huge df the divisor is essentially 1; the neglected spread is O(1/df).
        if df > 1.0e6 {
            return range_cdf_known_variance(q, k);
        }
        // s = sqrt(chi²_df / df) has mean ≈ 1 and sd ≈ 1/sqrt(2·df); ±10 sd (clipped at 0)
        // contains essentially all of its mass.
        let spread = 10.0 / (2.0 * df).sqrt();
        let lower = (1.0 - spread).max(0.0);
        let upper = 1.0 + spread;
        let log_norm = 0.5 * df * df.ln()
            - (0.5 * df - 1.0) * std::f64::consts::LN_2
            - ln_gamma(0.5 * df);
        let integrand = |s: f64| {
            if s <= 0.0 {
                // W(0) = 0, so s = 0 contributes nothing (and ln(0) is avoided).
                return 0.0;
            }
            let log_density = log_norm + (df - 1.0) * s.ln() - 0.5 * df * s * s;
            log_density.exp() * range_cdf_known_variance(q * s, k)
        };
        // W(q·s) climbs from 0 to ~1 over s ∈ [0, ~8/q]; integrating that stretch
        // separately keeps the step small when q is large.
        let split = (8.0 / q).clamp(lower, upper);
        let total =
            simpson(lower, split, 128, &integrand) + simpson(split, upper, 128, &integrand);
        total.clamp(0.0, 1.0)
    }

    let as_f64 = |v: F| <f64 as NumCast>::from(v).unwrap_or(f64::NAN);
    let cdf = studentized_range_cdf(as_f64(q), as_f64(k), as_f64(df));
    // `clamp` leaves NaN untouched, so invalid inputs propagate as NaN.
    let p_value = (1.0 - cdf).clamp(0.0, 1.0);
    F::from(p_value).unwrap_or_else(F::nan)
}