use crate::distributions;
use crate::error::{StatsError, StatsResult};
use crate::{mean, std};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumCast};
/// Alternative hypothesis against which a t-test p-value is computed.
///
/// `TwoSided` looks for a difference in either direction, while `Less` and `Greater`
/// look only at the lower or upper tail of the reference t distribution.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Alternative {
    /// Both tails: the p-value is `2 * (1 - cdf(|t|))`.
    TwoSided,
    /// Lower tail: the p-value is `cdf(t)`.
    Less,
    /// Upper tail: the p-value is `1 - cdf(t)`.
    Greater,
}
#[derive(Debug, Clone)]
pub struct TTestResult<F: Float + std::fmt::Display> {
pub statistic: F,
pub pvalue: F,
pub df: F,
pub alternative: Alternative,
pub info: Option<String>,
}
#[allow(dead_code)]
pub fn ttest_1samp<F>(
a: &ArrayView1<F>,
popmean: F,
alternative: Alternative,
nan_policy: &str,
) -> StatsResult<TTestResult<F>>
where
F: Float
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ NumCast
+ std::marker::Send
+ std::marker::Sync
+ std::fmt::Display
+ 'static
+ scirs2_core::simd_ops::SimdUnifiedOps,
{
let data = match nan_policy {
"propagate" => a.to_owned(),
"raise" => {
if a.iter().any(|x| x.is_nan()) {
return Err(StatsError::InvalidArgument(
"Input array contains NaN values".to_string(),
));
}
a.to_owned()
}
"omit" => {
let validdata: Vec<F> = a.iter().filter(|&&x| !x.is_nan()).copied().collect();
Array1::from(validdata)
}
_ => {
return Err(StatsError::InvalidArgument(format!(
"Invalid nan_policy: {}. Use 'propagate', 'raise', or 'omit'",
nan_policy
)));
}
};
if data.is_empty() {
return Err(StatsError::InvalidArgument(
"Input array cannot be empty".to_string(),
));
}
let sample_mean = mean(&data.view())?;
let sample_std = std(&data.view(), 1, None)?;
let n = F::from(data.len()).expect("Operation failed");
let se = sample_std / n.sqrt();
let t_stat = (sample_mean - popmean) / se;
let df = F::from(data.len() - 1).expect("Operation failed");
let t_dist = distributions::t(df, F::zero(), F::one())?;
let p_value = match alternative {
Alternative::TwoSided => {
let abs_t = t_stat.abs();
F::from(2.0).expect("Failed to convert constant to float")
* (F::one() - t_dist.cdf(abs_t))
}
Alternative::Less => t_dist.cdf(t_stat),
Alternative::Greater => F::one() - t_dist.cdf(t_stat),
};
let info = format!("mean={}, std_err={}, n={}", sample_mean, se, data.len());
Ok(TTestResult {
statistic: t_stat,
pvalue: p_value,
df,
alternative,
info: Some(info),
})
}
#[allow(dead_code)]
pub fn ttest_ind<F>(
a: &ArrayView1<F>,
b: &ArrayView1<F>,
equal_var: bool,
alternative: Alternative,
nan_policy: &str,
) -> StatsResult<TTestResult<F>>
where
F: Float
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ NumCast
+ std::marker::Send
+ std::marker::Sync
+ std::fmt::Display
+ 'static
+ scirs2_core::simd_ops::SimdUnifiedOps,
{
let (data_a, data_b) = match nan_policy {
"propagate" => (a.to_owned(), b.to_owned()),
"raise" => {
if a.iter().any(|x| x.is_nan()) || b.iter().any(|x| x.is_nan()) {
return Err(StatsError::InvalidArgument(
"Input arrays contain NaN values".to_string(),
));
}
(a.to_owned(), b.to_owned())
}
"omit" => {
let valid_a: Vec<F> = a.iter().filter(|&&x| !x.is_nan()).copied().collect();
let valid_b: Vec<F> = b.iter().filter(|&&x| !x.is_nan()).copied().collect();
(Array1::from(valid_a), Array1::from(valid_b))
}
_ => {
return Err(StatsError::InvalidArgument(format!(
"Invalid nan_policy: {}. Use 'propagate', 'raise', or 'omit'",
nan_policy
)));
}
};
if data_a.is_empty() || data_b.is_empty() {
return Err(StatsError::InvalidArgument(
"Input arrays cannot be empty".to_string(),
));
}
let mean_a = mean(&data_a.view())?;
let mean_b = mean(&data_b.view())?;
let n_a = F::from(data_a.len()).expect("Operation failed");
let n_b = F::from(data_b.len()).expect("Operation failed");
let std_a = std(&data_a.view(), 1, None)?;
let std_b = std(&data_b.view(), 1, None)?;
let t_stat: F;
let df: F;
let variance_a = std_a * std_a;
let variance_b = std_b * std_b;
let test_description: String;
if equal_var {
let pooled_var = ((n_a - F::one()) * variance_a + (n_b - F::one()) * variance_b)
/ (n_a + n_b - F::from(2.0).expect("Failed to convert constant to float"));
let se = (pooled_var * (F::one() / n_a + F::one() / n_b)).sqrt();
t_stat = (mean_a - mean_b) / se;
df = n_a + n_b - F::from(2.0).expect("Failed to convert constant to float");
test_description = "Student's t-test".to_string();
} else {
let var_a_over_n_a = variance_a / n_a;
let var_b_over_n_b = variance_b / n_b;
let se = (var_a_over_n_a + var_b_over_n_b).sqrt();
t_stat = (mean_a - mean_b) / se;
let numerator = (var_a_over_n_a + var_b_over_n_b).powi(2);
let denominator = (var_a_over_n_a.powi(2) / (n_a - F::one()))
+ (var_b_over_n_b.powi(2) / (n_b - F::one()));
df = numerator / denominator;
test_description = "Welch's t-test".to_string();
}
let t_dist = distributions::t(df, F::zero(), F::one())?;
let p_value = match alternative {
Alternative::TwoSided => {
let abs_t = t_stat.abs();
F::from(2.0).expect("Failed to convert constant to float")
* (F::one() - t_dist.cdf(abs_t))
}
Alternative::Less => t_dist.cdf(t_stat),
Alternative::Greater => F::one() - t_dist.cdf(t_stat),
};
let info = format!(
"{}: mean_a={}, mean_b={}, std_a={}, std_b={}, n_a={}, n_b={}",
test_description,
mean_a,
mean_b,
std_a,
std_b,
data_a.len(),
data_b.len()
);
Ok(TTestResult {
statistic: t_stat,
pvalue: p_value,
df,
alternative,
info: Some(info),
})
}
/// Paired-sample (dependent) t-test of the null hypothesis that `mean(a - b) == 0`.
///
/// The element-wise differences `a[i] - b[i]` are tested with [`ttest_1samp`] against a
/// population mean of zero.
///
/// # Arguments
/// * `a`, `b` - paired observations; they must have the same length.
/// * `alternative` - alternative hypothesis used for the p-value.
/// * `nan_policy` - one of:
///   * `"propagate"` keeps NaN values, so a NaN difference propagates into the result;
///   * `"raise"` rejects input containing NaN;
///   * `"omit"` drops every pair whose difference is NaN, including any pair where
///     `a[i]` or `b[i]` is NaN.
///
/// # Returns
/// The one-sample result on the differences, with an `info` string reporting the means
/// of `a` and `b` over the pairs used, their mean difference, and the number of pairs.
///
/// # Errors
/// * `InvalidArgument` for an unknown `nan_policy`, for NaN input under `"raise"`,
///   or when no pairs remain.
/// * `DimensionMismatch` when `a` and `b` differ in length.
/// * Any error from [`ttest_1samp`] is propagated.
#[allow(dead_code)]
pub fn ttest_rel<F>(
    a: &ArrayView1<F>,
    b: &ArrayView1<F>,
    alternative: Alternative,
    nan_policy: &str,
) -> StatsResult<TTestResult<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + std::marker::Send
        + std::marker::Sync
        + std::fmt::Display
        + 'static
        + scirs2_core::simd_ops::SimdUnifiedOps,
{
    // Validate the policy first (and, for "raise", the NaN check) so that error
    // precedence matches the documented contract.
    let omit_nan = match nan_policy {
        "propagate" => false,
        "raise" => {
            if a.iter().any(|x| x.is_nan()) || b.iter().any(|x| x.is_nan()) {
                return Err(StatsError::InvalidArgument(
                    "Input arrays contain NaN values".to_string(),
                ));
            }
            false
        }
        "omit" => true,
        _ => {
            return Err(StatsError::InvalidArgument(format!(
                "Invalid nan_policy: {}. Use 'propagate', 'raise', or 'omit'",
                nan_policy
            )));
        }
    };
    if a.len() != b.len() {
        return Err(StatsError::DimensionMismatch(
            "Input arrays must have the same length for paired t-test".to_string(),
        ));
    }

    // Keep the aligned (a_i, b_i) pairs that enter the test. Under "omit" a pair is
    // dropped when its difference is NaN, which covers a NaN in either member.
    let pairs: Vec<(F, F)> = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x, y))
        .filter(|&(x, y)| !omit_nan || !(x - y).is_nan())
        .collect();
    if pairs.is_empty() {
        return Err(StatsError::InvalidArgument(
            "No valid paired data after NaN removal".to_string(),
        ));
    }
    let paireddata = Array1::from(pairs.iter().map(|&(x, y)| x - y).collect::<Vec<F>>());

    // The NaN policy has already been applied above. Forwarding "propagate" keeps NaN
    // differences in the result under the caller's "propagate" policy instead of
    // silently dropping them (which "omit" would do).
    let one_sample_result =
        ttest_1samp(&paireddata.view(), F::zero(), alternative, "propagate")?;

    // Descriptive statistics over exactly the pairs the test used.
    let n_pairs = F::from(pairs.len()).expect("Operation failed");
    let mean_a = pairs.iter().map(|&(x, _)| x).sum::<F>() / n_pairs;
    let mean_b = pairs.iter().map(|&(_, y)| y).sum::<F>() / n_pairs;
    let mean_diff = paireddata.iter().copied().sum::<F>() / n_pairs;

    let info = format!(
        "Paired t-test: mean_a={}, mean_b={}, mean_diff={}, n_pairs={}",
        mean_a,
        mean_b,
        mean_diff,
        pairs.len()
    );
    Ok(TTestResult {
        statistic: one_sample_result.statistic,
        pvalue: one_sample_result.pvalue,
        df: one_sample_result.df,
        alternative,
        info: Some(info),
    })
}
/// Two-sample t-test computed from summary statistics instead of raw observations.
///
/// `std1`/`std2` are used as the sample standard deviations of the two groups, and
/// `nobs1`/`nobs2` as their sizes. The formulas match the raw-data two-sample test:
/// * `equal_var == true`: Student's test with pooled variance and `df = n1 + n2 - 2`;
/// * `equal_var == false`: Welch's test with Welch–Satterthwaite degrees of freedom.
///
/// # Errors
/// * `InvalidArgument` if a sample size is zero, a mean or standard deviation is NaN,
///   or a standard deviation is negative.
/// * Errors from `distributions::t` are propagated.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn ttest_ind_from_stats<F>(
    mean1: F,
    std1: F,
    nobs1: usize,
    mean2: F,
    std2: F,
    nobs2: usize,
    equal_var: bool,
    alternative: Alternative,
) -> StatsResult<TTestResult<F>>
where
    F: Float + NumCast + std::marker::Send + std::marker::Sync + std::fmt::Display + 'static,
{
    // Checks are evaluated in priority order; the first failing one is reported.
    let problem = if nobs1 == 0 || nobs2 == 0 {
        Some("Sample sizes must be positive")
    } else if [mean1, std1, mean2, std2].iter().any(|v| v.is_nan()) {
        Some("Means and standard deviations must not be NaN")
    } else if std1 < F::zero() || std2 < F::zero() {
        Some("Standard deviations must be non-negative")
    } else {
        None
    };
    if let Some(message) = problem {
        return Err(StatsError::InvalidArgument(message.to_string()));
    }

    let n1 = F::from(nobs1).expect("Failed to convert to float");
    let n2 = F::from(nobs2).expect("Failed to convert to float");
    let two = F::from(2.0).expect("Failed to convert constant to float");
    let (var1, var2) = (std1 * std1, std2 * std2);
    let mean_gap = mean1 - mean2;

    let (statistic, df, test_description) = if equal_var {
        let df = n1 + n2 - two;
        let pooled_var = ((n1 - F::one()) * var1 + (n2 - F::one()) * var2) / df;
        let std_err = (pooled_var * (F::one() / n1 + F::one() / n2)).sqrt();
        (mean_gap / std_err, df, "Student's t-test")
    } else {
        let share1 = var1 / n1;
        let share2 = var2 / n2;
        let std_err = (share1 + share2).sqrt();
        let df = (share1 + share2).powi(2)
            / (share1.powi(2) / (n1 - F::one()) + share2.powi(2) / (n2 - F::one()));
        (mean_gap / std_err, df, "Welch's t-test")
    };

    let reference = distributions::t(df, F::zero(), F::one())?;
    let pvalue = match alternative {
        Alternative::Less => reference.cdf(statistic),
        Alternative::Greater => F::one() - reference.cdf(statistic),
        Alternative::TwoSided => two * (F::one() - reference.cdf(statistic.abs())),
    };

    Ok(TTestResult {
        statistic,
        pvalue,
        df,
        alternative,
        info: Some(format!(
            "{}: mean1={}, mean2={}, std1={}, std2={}, n1={}, n2={}",
            test_description, mean1, mean2, std1, std2, nobs1, nobs2
        )),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Five-element sample with a NaN at index 2.
    fn array_with_nan<F: Float + Copy>() -> Array1<F> {
        let mut data = array![
            F::from(5.1).expect("Failed to convert constant to float"),
            F::from(4.9).expect("Failed to convert constant to float"),
            F::from(6.2).expect("Failed to convert constant to float"),
            F::from(5.7).expect("Failed to convert constant to float"),
            F::from(5.5).expect("Failed to convert constant to float")
        ];
        data[2] = F::nan();
        data
    }

    /// Sample mean and sample (ddof = 1) standard deviation, computed directly.
    fn mean_and_sample_std(x: &Array1<f64>) -> (f64, f64) {
        let n = x.len() as f64;
        let m = x.iter().sum::<f64>() / n;
        let ss: f64 = x.iter().map(|v| (v - m).powi(2)).sum();
        (m, (ss / (n - 1.0)).sqrt())
    }

    #[test]
    fn test_ttest_1samp() {
        let data = array![5.1f64, 4.9, 6.2, 5.7, 5.5, 5.1, 5.2, 5.0];
        let null_mean = 5.0;

        let two_sided = ttest_1samp(&data.view(), null_mean, Alternative::TwoSided, "omit")
            .expect("Operation failed");
        assert_relative_eq!(two_sided.statistic, 2.183, epsilon = 0.1);
        assert_relative_eq!(two_sided.df, 7.0);
        // Reference value for t = 2.183 with df = 7 is p ~ 0.065.
        assert!(two_sided.pvalue > 0.04 && two_sided.pvalue < 0.1);

        let greater = ttest_1samp(&data.view(), null_mean, Alternative::Greater, "omit")
            .expect("Operation failed");
        assert_relative_eq!(greater.statistic, 2.183, epsilon = 0.1);
        // For a positive statistic the two-sided p-value is twice the upper tail.
        assert_relative_eq!(two_sided.pvalue, 2.0 * greater.pvalue, epsilon = 1e-12);

        let less = ttest_1samp(&data.view(), null_mean, Alternative::Less, "omit")
            .expect("Operation failed");
        assert_relative_eq!(less.statistic, 2.183, epsilon = 0.1);
        assert!(less.pvalue > 0.5);
        // The two one-sided p-values are complementary.
        assert_relative_eq!(less.pvalue + greater.pvalue, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_ttest_1samp_nan_handling() {
        let data = array_with_nan::<f64>();
        let null_mean = 5.0;

        let result = ttest_1samp(&data.view(), null_mean, Alternative::TwoSided, "omit")
            .expect("Operation failed");
        assert!(!result.statistic.is_nan());
        assert!(!result.pvalue.is_nan());
        // Four non-NaN observations remain.
        assert_relative_eq!(result.df, 3.0);

        let result = ttest_1samp(&data.view(), null_mean, Alternative::TwoSided, "raise");
        assert!(result.is_err());

        let result = ttest_1samp(&data.view(), null_mean, Alternative::TwoSided, "propagate")
            .expect("Operation failed");
        assert!(result.statistic.is_nan() || result.pvalue.is_nan());

        let result = ttest_1samp(&data.view(), null_mean, Alternative::TwoSided, "ignore");
        assert!(result.is_err());
    }

    #[test]
    fn test_ttest_ind() {
        let group1 = array![5.1f64, 4.9, 6.2, 5.7, 5.5];
        let group2 = array![4.8f64, 5.2, 5.1, 4.7, 4.9];

        let result = ttest_ind(
            &group1.view(),
            &group2.view(),
            true,
            Alternative::TwoSided,
            "omit",
        )
        .expect("Operation failed");
        assert_relative_eq!(result.statistic, 2.186, epsilon = 0.2);
        assert_relative_eq!(result.df, 8.0);
        assert!(result.pvalue < 1.0);

        let result = ttest_ind(
            &group1.view(),
            &group2.view(),
            false,
            Alternative::TwoSided,
            "omit",
        )
        .expect("Operation failed");
        assert_relative_eq!(result.statistic, 2.186, epsilon = 0.5);
        // Welch df lies between min(n_a, n_b) - 1 and n_a + n_b - 2.
        assert!(result.df >= 4.0 && result.df <= 8.0);
        assert!(result.pvalue < 1.0);

        let result = ttest_ind(
            &group1.view(),
            &group2.view(),
            true,
            Alternative::Greater,
            "omit",
        )
        .expect("Operation failed");
        assert_relative_eq!(result.statistic, 2.186, epsilon = 0.5);
        // Positive statistic => upper-tail p-value is below one half.
        assert!(result.pvalue < 0.5);
    }

    #[test]
    fn test_ttest_ind_nan_handling() {
        let group1 = array_with_nan::<f64>();
        let group2 = array![4.8f64, 5.2, 5.1, 4.7, 4.9];

        let result = ttest_ind(
            &group1.view(),
            &group2.view(),
            true,
            Alternative::TwoSided,
            "omit",
        )
        .expect("Operation failed");
        assert!(!result.statistic.is_nan());
        assert!(!result.pvalue.is_nan());
        // 4 + 5 observations remain, so df = 4 + 5 - 2.
        assert_relative_eq!(result.df, 7.0);

        let result = ttest_ind(
            &group1.view(),
            &group2.view(),
            true,
            Alternative::TwoSided,
            "raise",
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_ttest_ind_from_stats_matches_raw_data() {
        let group1 = array![5.1f64, 4.9, 6.2, 5.7, 5.5];
        let group2 = array![4.8f64, 5.2, 5.1, 4.7, 4.9];
        let (m1, s1) = mean_and_sample_std(&group1);
        let (m2, s2) = mean_and_sample_std(&group2);

        for &equal_var in &[true, false] {
            let raw = ttest_ind(
                &group1.view(),
                &group2.view(),
                equal_var,
                Alternative::TwoSided,
                "raise",
            )
            .expect("Operation failed");
            let summary = ttest_ind_from_stats(
                m1,
                s1,
                group1.len(),
                m2,
                s2,
                group2.len(),
                equal_var,
                Alternative::TwoSided,
            )
            .expect("Operation failed");
            assert_relative_eq!(raw.statistic, summary.statistic, epsilon = 1e-10);
            assert_relative_eq!(raw.df, summary.df, epsilon = 1e-10);
            assert_relative_eq!(raw.pvalue, summary.pvalue, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ttest_rel() {
        let before = array![68.5f64, 70.2, 65.3, 72.1, 69.8];
        let after = array![67.2f64, 68.5, 66.1, 70.3, 68.7];

        let two_sided = ttest_rel(&before.view(), &after.view(), Alternative::TwoSided, "omit")
            .expect("Operation failed");
        assert_relative_eq!(two_sided.statistic, 2.5, epsilon = 0.5);
        assert_relative_eq!(two_sided.df, 4.0);
        assert!(two_sided.pvalue < 0.5 && two_sided.pvalue > 0.01);

        let greater = ttest_rel(&before.view(), &after.view(), Alternative::Greater, "omit")
            .expect("Operation failed");
        assert_relative_eq!(greater.statistic, 2.5, epsilon = 0.5);
        assert!(greater.pvalue < 0.25);

        let less = ttest_rel(&before.view(), &after.view(), Alternative::Less, "omit")
            .expect("Operation failed");
        assert_relative_eq!(less.statistic, 2.5, epsilon = 0.5);
        assert!(less.pvalue > 0.5);
        assert_relative_eq!(less.pvalue + greater.pvalue, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_ttest_rel_length_mismatch() {
        let a = array![1.0f64, 2.0, 3.0];
        let b = array![1.5f64, 2.5];
        for &policy in &["propagate", "raise", "omit"] {
            assert!(ttest_rel(&a.view(), &b.view(), Alternative::TwoSided, policy).is_err());
        }
    }

    #[test]
    fn test_ttest_ind_from_stats() {
        let mean1 = 5.48f64;
        let std1 = 0.49f64;
        let n1 = 5;
        let mean2 = 4.94f64;
        let std2 = 0.21f64;
        let n2 = 5;

        let result = ttest_ind_from_stats(
            mean1,
            std1,
            n1,
            mean2,
            std2,
            n2,
            true,
            Alternative::TwoSided,
        )
        .expect("Operation failed");
        assert_relative_eq!(result.statistic, 2.3, epsilon = 0.3);
        assert_relative_eq!(result.df, 8.0);
        assert!(result.pvalue < 1.0);

        let result = ttest_ind_from_stats(
            mean1,
            -1.0,
            n1,
            mean2,
            std2,
            n2,
            true,
            Alternative::TwoSided,
        );
        assert!(result.is_err());

        let result =
            ttest_ind_from_stats(mean1, std1, 0, mean2, std2, n2, true, Alternative::TwoSided);
        assert!(result.is_err());

        let result = ttest_ind_from_stats(
            f64::NAN,
            std1,
            n1,
            mean2,
            std2,
            n2,
            false,
            Alternative::TwoSided,
        );
        assert!(result.is_err());
    }
}