use crate::error::{StatsError, StatsResult};
use crate::prob::erf;
use crate::utils::constants::{LN_2PI, SQRT_2};
use num_traits::ToPrimitive;
use std::f64;
use std::fmt::Debug;
/// Outcome of a t-test, as produced by the t-test functions in this file.
///
/// The layout of `mean_values` and `std_devs` depends on which test built the
/// value:
/// - one-sample: `[mean]` / `[std_dev]`
/// - two-sample: `[mean1, mean2]` / `[std_dev1, std_dev2]`
/// - paired: `[mean_of_differences, mean1, mean2]` /
///   `[std_dev_of_differences, std_dev1, std_dev2]`
#[derive(Debug, Clone)]
pub struct TTestResult {
/// Signed t statistic (estimate divided by `std_error`).
pub t_statistic: f64,
/// Degrees of freedom passed to the p-value computation; fractional for the
/// unequal-variance two-sample test.
pub degrees_of_freedom: f64,
/// p-value returned by `calculate_p_value` for the absolute t statistic.
pub p_value: f64,
/// Sample means; see the struct-level note for the per-test layout.
pub mean_values: Vec<f64>,
/// Sample standard deviations (n - 1 denominator), in the same order as
/// `mean_values`.
pub std_devs: Vec<f64>,
/// Standard error used as the denominator of `t_statistic`.
pub std_error: f64,
}
/// One-sample t-test of `data` against a hypothesised `population_mean`.
///
/// The statistic is `(sample mean - population_mean) / (s / sqrt(n))` with
/// `n - 1` degrees of freedom, where `s` is the sample standard deviation.
/// The p-value is whatever `calculate_p_value` returns for `|t|` and `n - 1`.
///
/// # Errors
/// - empty-data error when `data` is empty,
/// - invalid-input error when `data` holds a single value,
/// - conversion error when `population_mean` or any element cannot be
///   converted to `f64`.
///
/// When every element is identical the standard error is zero, so the
/// statistic is infinite or NaN rather than an error.
pub fn one_sample_t_test<T>(data: &[T], population_mean: T) -> StatsResult<TTestResult>
where
    T: ToPrimitive + Debug + Copy,
{
    match data.len() {
        0 => {
            return Err(StatsError::empty_data(
                "Cannot perform t-test on empty data",
            ));
        }
        1 => {
            return Err(StatsError::invalid_input(
                "Need at least 2 data points for t-test",
            ));
        }
        _ => {}
    }

    let hypothesised_mean = population_mean
        .to_f64()
        .ok_or_else(|| StatsError::conversion_error("Failed to convert population mean to f64"))?;

    let sample_size = data.len() as f64;
    let sample_mean = calculate_mean(data)?;
    let sample_std_dev = calculate_variance(data, sample_mean)?.sqrt();

    let std_error = sample_std_dev / sample_size.sqrt();
    let t_statistic = (sample_mean - hypothesised_mean) / std_error;
    let degrees_of_freedom = sample_size - 1.0;

    Ok(TTestResult {
        t_statistic,
        degrees_of_freedom,
        p_value: calculate_p_value(t_statistic.abs(), degrees_of_freedom),
        mean_values: vec![sample_mean],
        std_devs: vec![sample_std_dev],
        std_error,
    })
}
/// Independent two-sample t-test comparing the means of `data1` and `data2`.
///
/// With `equal_variances == true` the pooled-variance form is used and the
/// degrees of freedom are `n1 + n2 - 2`. Otherwise the standard error is
/// `sqrt(var1/n1 + var2/n2)` and the degrees of freedom come from the
/// Welch–Satterthwaite expression, which is generally fractional.
/// The p-value is whatever `calculate_p_value` returns for `|t|` and those
/// degrees of freedom.
///
/// # Errors
/// - empty-data error when either slice is empty,
/// - invalid-input error when either slice has fewer than two values,
/// - conversion error when an element cannot be converted to `f64`.
pub fn two_sample_t_test<T>(
    data1: &[T],
    data2: &[T],
    equal_variances: bool,
) -> StatsResult<TTestResult>
where
    T: ToPrimitive + Debug + Copy,
{
    // The smaller group decides both guards: 0 => empty, 1 => too few points.
    let smallest_group = data1.len().min(data2.len());
    if smallest_group == 0 {
        return Err(StatsError::empty_data(
            "Cannot perform t-test on empty data",
        ));
    }
    if smallest_group < 2 {
        return Err(StatsError::invalid_input(
            "Need at least 2 data points in each group for t-test",
        ));
    }

    let (n1, n2) = (data1.len() as f64, data2.len() as f64);
    let mean1 = calculate_mean(data1)?;
    let mean2 = calculate_mean(data2)?;
    let var1 = calculate_variance(data1, mean1)?;
    let var2 = calculate_variance(data2, mean2)?;

    let (std_error, degrees_of_freedom) = if equal_variances {
        // Student's form: pool the two variances, weighted by their own df.
        let pooled_df = n1 + n2 - 2.0;
        let pooled_variance = ((n1 - 1.0) * var1 + (n2 - 1.0) * var2) / pooled_df;
        ((pooled_variance * (1.0 / n1 + 1.0 / n2)).sqrt(), pooled_df)
    } else {
        // Welch's form: per-group variance of the mean, Satterthwaite df.
        let mean_var1 = var1 / n1;
        let mean_var2 = var2 / n2;
        let combined = mean_var1 + mean_var2;
        let welch_df = combined.powi(2)
            / ((mean_var1.powi(2) / (n1 - 1.0)) + (mean_var2.powi(2) / (n2 - 1.0)));
        (combined.sqrt(), welch_df)
    };

    let t_statistic = (mean1 - mean2) / std_error;

    Ok(TTestResult {
        t_statistic,
        degrees_of_freedom,
        p_value: calculate_p_value(t_statistic.abs(), degrees_of_freedom),
        mean_values: vec![mean1, mean2],
        std_devs: vec![var1.sqrt(), var2.sqrt()],
        std_error,
    })
}
/// Paired (dependent-samples) t-test on the element-wise differences
/// `data1[i] - data2[i]`.
///
/// The statistic is `mean(d) / (sd(d) / sqrt(n))` with `n - 1` degrees of
/// freedom. The p-value is whatever `calculate_p_value` returns for `|t|`
/// and `n - 1`.
///
/// The returned vectors are laid out as
/// `mean_values = [mean(d), mean(data1), mean(data2)]` and
/// `std_devs = [sd(d), sd(data1), sd(data2)]`, all using the `n - 1`
/// denominator.
///
/// # Errors
/// - empty-data error when either slice is empty,
/// - dimension-mismatch error when the slices differ in length,
/// - invalid-input error when there is only one pair,
/// - conversion error when an element cannot be converted to `f64`.
///
/// When all differences are identical the standard error is zero, so the
/// statistic is infinite or NaN rather than an error.
pub fn paired_t_test<T>(data1: &[T], data2: &[T]) -> StatsResult<TTestResult>
where
    T: ToPrimitive + Debug + Copy,
{
    if data1.is_empty() || data2.is_empty() {
        return Err(StatsError::empty_data(
            "Cannot perform paired t-test on empty data",
        ));
    }
    if data1.len() != data2.len() {
        return Err(StatsError::dimension_mismatch(format!(
            "Paired t-test requires equal sample sizes (got {} and {})",
            data1.len(),
            data2.len()
        )));
    }
    if data1.len() < 2 {
        return Err(StatsError::invalid_input(
            "Need at least 2 pairs for paired t-test",
        ));
    }

    let n = data1.len() as f64;

    // Running (mean, sum of squared deviations) for each sample and for the
    // differences, all accumulated in a single pass.
    let (mut mean1, mut m2_1) = (0.0_f64, 0.0_f64);
    let (mut mean2, mut m2_2) = (0.0_f64, 0.0_f64);
    let (mut diff_mean, mut diff_m2) = (0.0_f64, 0.0_f64);

    for (i, (first, second)) in data1.iter().zip(data2.iter()).enumerate() {
        let val1 = first.to_f64().ok_or_else(|| {
            StatsError::conversion_error(format!(
                "Failed to convert data1 value at index {} to f64",
                i
            ))
        })?;
        let val2 = second.to_f64().ok_or_else(|| {
            StatsError::conversion_error(format!(
                "Failed to convert data2 value at index {} to f64",
                i
            ))
        })?;

        let count = (i + 1) as f64;
        // Each update must measure the deviation from the mean of the
        // *previous* observations; welford_update keeps that mean itself so
        // the divisor can never be off by one.
        welford_update(&mut mean1, &mut m2_1, val1, count);
        welford_update(&mut mean2, &mut m2_2, val2, count);
        welford_update(&mut diff_mean, &mut diff_m2, val1 - val2, count);
    }

    let std_dev1 = (m2_1 / (n - 1.0)).sqrt();
    let std_dev2 = (m2_2 / (n - 1.0)).sqrt();
    let std_dev = (diff_m2 / (n - 1.0)).sqrt();
    let std_error = std_dev / n.sqrt();
    let t_statistic = diff_mean / std_error;
    let degrees_of_freedom = n - 1.0;
    let p_value = calculate_p_value(t_statistic.abs(), degrees_of_freedom);

    Ok(TTestResult {
        t_statistic,
        degrees_of_freedom,
        p_value,
        mean_values: vec![diff_mean, mean1, mean2],
        std_devs: vec![std_dev, std_dev1, std_dev2],
        std_error,
    })
}

/// One step of Welford's online algorithm: folds `value`, the `count`-th
/// observation (1-based), into the running `mean` and the running sum of
/// squared deviations `m2`.
#[inline]
fn welford_update(mean: &mut f64, m2: &mut f64, value: f64, count: f64) {
    let delta_before = value - *mean;
    *mean += delta_before / count;
    let delta_after = value - *mean;
    *m2 += delta_before * delta_after;
}
/// Arithmetic mean of `data` after converting every element to `f64`.
///
/// # Errors
/// - empty-data error when `data` is empty,
/// - conversion error naming the index of the first element whose
///   `to_f64()` yields `None`.
#[inline]
fn calculate_mean<T>(data: &[T]) -> StatsResult<f64>
where
    T: ToPrimitive + Debug,
{
    if data.is_empty() {
        return Err(StatsError::empty_data(
            "Cannot calculate mean of empty data",
        ));
    }

    // Left-to-right accumulation that stops at the first unconvertible item.
    let total = data
        .iter()
        .enumerate()
        .try_fold(0.0, |running, (index, item)| match item.to_f64() {
            Some(converted) => Ok(running + converted),
            None => Err(StatsError::conversion_error(format!(
                "Failed to convert value at index {} to f64",
                index
            ))),
        })?;

    Ok(total / data.len() as f64)
}
/// Sample variance of `data` around the supplied `mean`, using the `n - 1`
/// denominator.
///
/// `mean` is taken as given; it is not recomputed or checked against `data`.
///
/// # Errors
/// - empty-data error when `data` is empty,
/// - invalid-input error when `data` holds a single value,
/// - conversion error naming the index of the first element whose
///   `to_f64()` yields `None`.
#[inline]
fn calculate_variance<T>(data: &[T], mean: f64) -> StatsResult<f64>
where
    T: ToPrimitive + Debug,
{
    match data.len() {
        0 => {
            return Err(StatsError::empty_data(
                "Cannot calculate variance of empty data",
            ));
        }
        1 => {
            return Err(StatsError::invalid_input(
                "Need at least 2 data points to calculate variance",
            ));
        }
        _ => {}
    }

    // Sum of squared deviations, stopping at the first unconvertible item.
    let squared_deviations = data
        .iter()
        .enumerate()
        .try_fold(0.0, |running, (index, item)| match item.to_f64() {
            Some(converted) => Ok(running + (converted - mean).powi(2)),
            None => Err(StatsError::conversion_error(format!(
                "Failed to convert value at index {} to f64",
                index
            ))),
        })?;

    let count = data.len() as f64;
    Ok(squared_deviations / (count - 1.0))
}
#[inline]
fn calculate_p_value(t_stat: f64, df: f64) -> f64 {
if df > 1000.0 {
let z = t_stat;
return 2.0 * (1.0 - standard_normal_cdf(z));
}
let a = df / (df + t_stat * t_stat);
let ix = incomplete_beta(0.5 * df, 0.5, a);
(2.0 * (1.0 - ix)).clamp(0.0, 1.0)
}
#[inline]
fn standard_normal_cdf(x: f64) -> f64 {
0.5 * (1.0 + erf(x / SQRT_2).unwrap())
}
#[inline]
fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
if x == 0.0 || x == 1.0 {
return x;
}
let symmetry_point = x > (a / (a + b));
let (a_calc, b_calc, x_calc) = if symmetry_point {
(b, a, 1.0 - x)
} else {
(a, b, x)
};
let max_iterations = 200;
let epsilon = 1e-10;
let front_factor = x_calc.powf(a_calc) * (1.0 - x_calc).powf(b_calc) / beta(a_calc, b_calc);
let mut h = 1.0;
let mut d = 1.0;
let mut result = 0.0;
for m in 1..max_iterations {
let m = m as f64;
let m2 = 2.0 * m;
let numerator = (m * (b_calc - m) * x_calc) / ((a_calc + m2 - 1.0) * (a_calc + m2));
d = 1.0 + numerator * d;
if d.abs() < epsilon {
d = epsilon;
}
d = 1.0 / d;
h = 1.0 + numerator / h;
if h.abs() < epsilon {
h = epsilon;
}
result *= h * d;
let numerator = -((a_calc + m) * (a_calc + b_calc + m) * x_calc)
/ ((a_calc + m2) * (a_calc + m2 + 1.0));
d = 1.0 + numerator * d;
if d.abs() < epsilon {
d = epsilon;
}
d = 1.0 / d;
h = 1.0 + numerator / h;
if h.abs() < epsilon {
h = epsilon;
}
let delta = h * d;
result *= delta;
if (delta - 1.0).abs() < epsilon {
break;
}
}
result *= front_factor;
if symmetry_point { 1.0 - result } else { result }
}
#[inline]
fn beta(a: f64, b: f64) -> f64 {
let log_gamma_a = ln_gamma(a);
let log_gamma_b = ln_gamma(b);
let log_gamma_ab = ln_gamma(a + b);
(log_gamma_a + log_gamma_b - log_gamma_ab).exp()
}
#[inline]
fn ln_gamma(x: f64) -> f64 {
let p = [
676.5203681218851,
-1259.1392167224028,
771.323_428_777_653_1,
-176.615_029_162_140_6,
12.507343278686905,
-0.13857109526572012,
9.984_369_578_019_572e-6,
1.5056327351493116e-7,
];
if x < 0.5 {
crate::utils::constants::PI.ln()
- (crate::utils::constants::PI * x).sin().ln()
- ln_gamma(1.0 - x)
} else {
let mut sum = p[0];
for (i, &value) in p.iter().enumerate().skip(1) {
sum += value / (x + i as f64);
}
let t = x + 7.5;
(x + 0.5) * t.ln() - t + LN_2PI * 0.5 + sum.ln() / x
}
}
// Unit tests for the t-test functions.
//
// The p-value tests only check that the result lies in [0, 1] (and is not
// NaN); none of them pins a specific numeric p-value.
#[cfg(test)]
mod tests {
use super::*;
// One-sample test: p-value must lie in [0, 1].
#[test]
fn test_p_value_range_one_sample() {
let data = vec![5.2, 6.4, 6.9, 7.3, 7.5, 7.8, 8.1, 8.4, 9.2, 9.5];
let population_mean = 7.0;
let result = one_sample_t_test(&data, population_mean).unwrap();
assert!(
result.p_value >= 0.0,
"p-value should be >= 0.0, got {}",
result.p_value
);
assert!(
result.p_value <= 1.0,
"p-value should be <= 1.0, got {}",
result.p_value
);
}
// Two-sample test with unequal variances: p-value must lie in [0, 1].
#[test]
fn test_p_value_range_two_sample() {
let group1 = vec![5.2, 6.4, 6.9, 7.3, 7.5, 7.8, 8.1, 8.4, 9.2, 9.5];
let group2 = vec![4.1, 5.0, 5.5, 6.2, 6.3, 6.5, 6.8, 7.1, 7.4, 7.5];
let result = two_sample_t_test(&group1, &group2, false).unwrap();
assert!(
result.p_value >= 0.0,
"p-value should be >= 0.0, got {}",
result.p_value
);
assert!(
result.p_value <= 1.0,
"p-value should be <= 1.0, got {}",
result.p_value
);
}
// Paired test: p-value must lie in [0, 1].
#[test]
fn test_p_value_range_paired() {
let before = vec![12.1, 11.3, 13.7, 14.2, 13.8, 12.5, 11.9, 12.8, 14.0, 13.5];
let after = vec![12.9, 13.0, 14.3, 15.0, 14.8, 13.9, 12.7, 13.5, 15.2, 14.1];
let result = paired_t_test(&before, &after).unwrap();
assert!(
result.p_value >= 0.0,
"p-value should be >= 0.0, got {}",
result.p_value
);
assert!(
result.p_value <= 1.0,
"p-value should be <= 1.0, got {}",
result.p_value
);
}
// Calls calculate_p_value directly on a few (t statistic, degrees of
// freedom) pairs and checks the range only.
#[test]
fn test_p_value_edge_cases() {
let test_cases = vec![
(0.0, 5.0), (1.0, 10.0), (2.0, 20.0), (5.0, 30.0), ];
for (t_stat, df) in test_cases {
let p_value = calculate_p_value(t_stat, df);
assert!(
p_value >= 0.0,
"p-value should be >= 0.0 for t={}, df={}, got {}",
t_stat,
df,
p_value
);
assert!(
p_value <= 1.0,
"p-value should be <= 1.0 for t={}, df={}, got {}",
t_stat,
df,
p_value
);
}
}
// Pooled-variance path: finite outputs and df == n1 + n2 - 2.
#[test]
fn test_two_sample_t_test_equal_variances_true() {
let group1 = vec![5.2, 6.4, 6.9, 7.3, 7.5];
let group2 = vec![4.1, 5.0, 5.5, 6.2, 6.3];
let result = two_sample_t_test(&group1, &group2, true).unwrap();
assert!(
!result.t_statistic.is_nan(),
"t-statistic should not be NaN"
);
assert!(!result.p_value.is_nan(), "p-value should not be NaN");
assert!(
result.p_value >= 0.0 && result.p_value <= 1.0,
"p-value should be in [0, 1]"
);
let expected_df = (group1.len() + group2.len() - 2) as f64;
assert!(
(result.degrees_of_freedom - expected_df).abs() < 1e-10,
"Degrees of freedom should be n1 + n2 - 2 for equal variances"
);
}
// Unequal-variance path: finite outputs and df bounded above by n1 + n2 - 2.
#[test]
fn test_two_sample_t_test_equal_variances_false() {
let group1 = vec![5.2, 6.4, 6.9, 7.3, 7.5];
let group2 = vec![4.1, 5.0, 5.5, 6.2, 6.3];
let result = two_sample_t_test(&group1, &group2, false).unwrap();
assert!(
!result.t_statistic.is_nan(),
"t-statistic should not be NaN"
);
assert!(!result.p_value.is_nan(), "p-value should not be NaN");
assert!(
result.p_value >= 0.0 && result.p_value <= 1.0,
"p-value should be in [0, 1]"
);
// Despite the `_min` suffix this is the upper bound n1 + n2 - 2.
let expected_df_min = (group1.len() + group2.len() - 2) as f64;
assert!(
result.degrees_of_freedom <= expected_df_min + 1e-10,
"Welch's df should be <= n1 + n2 - 2"
);
}
// The two variance assumptions must yield different degrees of freedom on
// this data set.
#[test]
fn test_two_sample_t_test_equal_vs_unequal_variances() {
let group1 = vec![5.2, 6.4, 6.9, 7.3, 7.5, 7.8, 8.1, 8.4, 9.2, 9.5];
let group2 = vec![4.1, 5.0, 5.5, 6.2, 6.3, 6.5, 6.8, 7.1, 7.4, 7.5];
let result_equal = two_sample_t_test(&group1, &group2, true).unwrap();
let result_unequal = two_sample_t_test(&group1, &group2, false).unwrap();
assert!(!result_equal.p_value.is_nan());
assert!(!result_unequal.p_value.is_nan());
assert_ne!(
result_equal.degrees_of_freedom, result_unequal.degrees_of_freedom,
"Degrees of freedom should differ between equal and unequal variance tests"
);
}
// A single observation is rejected as invalid input.
#[test]
fn test_one_sample_t_test_single_data_point() {
let data = vec![5.0];
let result = one_sample_t_test(&data, 5.0);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StatsError::InvalidInput { .. }
));
}
// A single observation in either group is rejected as invalid input.
#[test]
fn test_two_sample_t_test_single_data_point() {
let data1 = vec![5.0];
let data2 = vec![4.0, 5.0];
let result = two_sample_t_test(&data1, &data2, false);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StatsError::InvalidInput { .. }
));
let data1 = vec![4.0, 5.0];
let data2 = vec![5.0];
let result = two_sample_t_test(&data1, &data2, false);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StatsError::InvalidInput { .. }
));
}
// Slices of different lengths are rejected with a dimension mismatch.
#[test]
fn test_paired_t_test_length_mismatch() {
let data1 = vec![1.0, 2.0, 3.0];
let data2 = vec![2.0, 3.0]; let result = paired_t_test(&data1, &data2);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StatsError::DimensionMismatch { .. }
));
}
// A single pair is rejected as invalid input.
#[test]
fn test_paired_t_test_single_data_point() {
let data1 = vec![1.0];
let data2 = vec![2.0];
let result = paired_t_test(&data1, &data2);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
StatsError::InvalidInput { .. }
));
}
}