use crate::distributions::chi2;
use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, NumCast, PrimInt};
use std::fmt::Debug;
/// Outcome of a chi-square test (goodness-of-fit, independence, or Yates-corrected).
#[derive(Debug, Clone)]
pub struct ChiSquareResult<F> {
/// The chi-square test statistic, accumulated as `sum((observed - expected)^2 / expected)`
/// (with a continuity correction applied to `|observed - expected|` for Yates' test).
pub statistic: F,
/// p-value, computed as `1 - CDF(statistic)` of a chi-square distribution
/// with `df` degrees of freedom.
pub p_value: F,
/// Degrees of freedom of the reference chi-square distribution.
pub df: usize,
/// Expected frequencies used to compute the statistic. For the goodness-of-fit
/// test this is an `n x 1` column; for contingency-table tests it has the
/// same shape as the observed table.
pub expected: Array2<F>,
}
#[allow(dead_code)]
pub fn chi2_gof<F, I>(
observed: &ArrayView1<I>,
expected: Option<ArrayView1<F>>,
) -> StatsResult<ChiSquareResult<F>>
where
F: Float
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ NumCast
+ Debug
+ std::marker::Send
+ std::marker::Sync
+ 'static
+ std::fmt::Display,
I: PrimInt + NumCast + std::fmt::Display,
{
if observed.is_empty() {
return Err(StatsError::InvalidArgument(
"Observed frequencies cannot be empty".to_string(),
));
}
let mut obs_float = Array1::<F>::zeros(observed.len());
for (i, &val) in observed.iter().enumerate() {
obs_float[i] = F::from(val).expect("Failed to convert to float");
}
let exp_float = match expected {
Some(exp) => {
if exp.len() != observed.len() {
return Err(StatsError::DimensionMismatch(
"Observed and expected frequencies must have the same dimensions".to_string(),
));
}
exp.to_owned()
}
None => {
let total_obs = obs_float.sum();
let uniform_exp = total_obs / F::from(observed.len()).expect("Test: operation failed");
Array1::<F>::from_elem(observed.len(), uniform_exp)
}
};
for &val in exp_float.iter() {
if val <= F::zero() {
return Err(StatsError::InvalidArgument(
"Expected frequencies must be positive".to_string(),
));
}
}
let mut chi2_stat = F::zero();
for (obs, exp) in obs_float.iter().zip(exp_float.iter()) {
chi2_stat = chi2_stat + (*obs - *exp).powi(2) / *exp;
}
let df = observed.len() - 1;
let chi2_dist = chi2(
F::from(df).expect("Failed to convert to float"),
F::zero(),
F::one(),
)?;
let p_value = F::one() - chi2_dist.cdf(chi2_stat);
Ok(ChiSquareResult {
statistic: chi2_stat,
p_value,
df,
expected: Array2::from_shape_vec((exp_float.len(), 1), exp_float.to_vec())
.map_err(|_| StatsError::ComputationError("Failed to reshape array".to_string()))?,
})
}
/// Pearson's chi-square test of independence for an `r x c` contingency table.
///
/// Each expected count is `row_sum[i] * col_sum[j] / total`. The statistic is
/// `sum((O - E)^2 / E)`, referred to a chi-square distribution with
/// `(r - 1) * (c - 1)` degrees of freedom. The p-value is `1 - CDF(statistic)`.
///
/// # Arguments
///
/// * `observed` - Contingency table of non-negative counts with at least two
///   rows and two columns.
///
/// # Returns
///
/// A [`ChiSquareResult`] whose `expected` field has the same shape as `observed`.
///
/// # Errors
///
/// * `StatsError::InvalidArgument` in any of these cases:
///   - The table is empty.
///   - The table has fewer than 2 rows or 2 columns.
///   - The table contains a negative count.
///   - An expected count is not strictly positive. This happens when a row or
///     column sums to zero, or when the whole table is zero.
/// * `StatsError::ComputationError` if a count cannot be converted to `F`.
/// * Any error returned when constructing the chi-square distribution.
#[allow(dead_code)]
pub fn chi2_independence<F, I>(observed: &ArrayView2<I>) -> StatsResult<ChiSquareResult<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
    I: PrimInt + NumCast + std::fmt::Display,
{
    if observed.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Observed frequencies cannot be empty".to_string(),
        ));
    }
    let (rows, cols) = observed.dim();
    if rows < 2 || cols < 2 {
        return Err(StatsError::InvalidArgument(
            "Contingency table must have at least 2 rows and 2 columns".to_string(),
        ));
    }

    // Both arrays iterate in logical (row-major) order, so the zip pairs cells 1:1.
    let mut obs_float = Array2::<F>::zeros((rows, cols));
    for (slot, &val) in obs_float.iter_mut().zip(observed.iter()) {
        let count = F::from(val).ok_or_else(|| {
            StatsError::ComputationError("Failed to convert to float".to_string())
        })?;
        if count < F::zero() {
            return Err(StatsError::InvalidArgument(
                "Observed frequencies cannot be negative".to_string(),
            ));
        }
        *slot = count;
    }

    let row_sums = obs_float.sum_axis(Axis(1));
    let col_sums = obs_float.sum_axis(Axis(0));
    let total = obs_float.sum();
    let expected =
        Array2::from_shape_fn((rows, cols), |(i, j)| row_sums[i] * col_sums[j] / total);

    // An all-zero table makes every cell 0 / 0 = NaN, which `<= 0` alone would
    // accept and turn into a NaN statistic; reject NaN explicitly.
    if expected.iter().any(|&e| e.is_nan() || e <= F::zero()) {
        return Err(StatsError::InvalidArgument(
            "Expected frequencies must be positive".to_string(),
        ));
    }

    // Row-major accumulation, matching the cell order of a nested i/j loop.
    let chi2_stat = obs_float
        .iter()
        .zip(expected.iter())
        .fold(F::zero(), |acc, (&o, &e)| acc + (o - e).powi(2) / e);

    let df = (rows - 1) * (cols - 1);
    let df_float = F::from(df).ok_or_else(|| {
        StatsError::ComputationError("Failed to convert to float".to_string())
    })?;
    let chi2_dist = chi2(df_float, F::zero(), F::one())?;
    let p_value = F::one() - chi2_dist.cdf(chi2_stat);

    Ok(ChiSquareResult {
        statistic: chi2_stat,
        p_value,
        df,
        expected,
    })
}
/// Chi-square test of independence for a 2x2 table with Yates' continuity correction.
///
/// Expected counts are `row_sum[i] * col_sum[j] / total`. Each cell contributes
/// `max(|O - E| - 0.5, 0)^2 / E`, so cells with `|O - E| <= 0.5` contribute zero.
/// The statistic is referred to a chi-square distribution with 1 degree of
/// freedom, and the p-value is `1 - CDF(statistic)`.
///
/// # Arguments
///
/// * `observed` - A 2x2 contingency table of non-negative counts.
///
/// # Errors
///
/// * `StatsError::InvalidArgument` in any of these cases:
///   - The table is not 2x2.
///   - The table contains a negative count.
///   - An expected count is not strictly positive. This happens when a row or
///     column sums to zero, or when the whole table is zero.
/// * `StatsError::ComputationError` if a value cannot be converted to `F`.
/// * Any error returned when constructing the chi-square distribution.
#[allow(dead_code)]
pub fn chi2_yates<F, I>(observed: &ArrayView2<I>) -> StatsResult<ChiSquareResult<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
    I: PrimInt + NumCast + std::fmt::Display,
{
    let (rows, cols) = observed.dim();
    if rows != 2 || cols != 2 {
        return Err(StatsError::InvalidArgument(
            "Yates' correction requires a 2x2 contingency table".to_string(),
        ));
    }

    // Both arrays iterate in logical (row-major) order, so the zip pairs cells 1:1.
    let mut obs_float = Array2::<F>::zeros((2, 2));
    for (slot, &val) in obs_float.iter_mut().zip(observed.iter()) {
        let count = F::from(val).ok_or_else(|| {
            StatsError::ComputationError("Failed to convert to float".to_string())
        })?;
        if count < F::zero() {
            return Err(StatsError::InvalidArgument(
                "Observed frequencies cannot be negative".to_string(),
            ));
        }
        *slot = count;
    }

    let row_sums = obs_float.sum_axis(Axis(1));
    let col_sums = obs_float.sum_axis(Axis(0));
    let total = obs_float.sum();
    let expected = Array2::from_shape_fn((2, 2), |(i, j)| row_sums[i] * col_sums[j] / total);

    // An all-zero table makes every cell 0 / 0 = NaN, which `<= 0` alone would
    // accept and turn into a NaN statistic; reject NaN explicitly.
    if expected.iter().any(|&e| e.is_nan() || e <= F::zero()) {
        return Err(StatsError::InvalidArgument(
            "Expected frequencies must be positive".to_string(),
        ));
    }

    // Continuity correction constant, converted once rather than per cell.
    let half = F::from(0.5).ok_or_else(|| {
        StatsError::ComputationError("Failed to convert constant to float".to_string())
    })?;
    let chi2_stat = obs_float
        .iter()
        .zip(expected.iter())
        .fold(F::zero(), |acc, (&o, &e)| {
            // Clamp at zero so the correction never increases a cell's contribution.
            let corrected = ((o - e).abs() - half).max(F::zero());
            acc + corrected.powi(2) / e
        });

    let df = 1;
    let df_float = F::from(df).ok_or_else(|| {
        StatsError::ComputationError("Failed to convert to float".to_string())
    })?;
    let chi2_dist = chi2(df_float, F::zero(), F::one())?;
    let p_value = F::one() - chi2_dist.cdf(chi2_stat);

    Ok(ChiSquareResult {
        statistic: chi2_stat,
        p_value,
        df,
        expected,
    })
}