use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array2, ArrayView2, Axis};
use scirs2_core::numeric::Float;
/// Chi-square test of independence for a two-way contingency table.
///
/// Expected frequencies are computed from the table margins as
/// `row_sum * col_sum / total`, and the chosen statistic is compared against a
/// chi-square distribution with `(rows - 1) * (cols - 1)` degrees of freedom.
///
/// # Arguments
///
/// * `observed` - Table of observed frequencies; must be at least 2x2 and
///   contain only finite, non-negative values.
/// * `correction` - When `true`, applies Yates' continuity correction. The
///   correction is only used for 2x2 tables with the Pearson statistic
///   (`lambda_ == None`); it is ignored for the log-likelihood statistic.
/// * `lambda_` - `None` selects Pearson's chi-square statistic,
///   `Some("log-likelihood")` selects the log-likelihood ratio statistic
///   `2 * sum(obs * ln(obs / exp))`.
///
/// # Returns
///
/// `(statistic, p_value, degrees_of_freedom, expected_frequencies)`.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when the table is smaller than 2x2,
/// contains negative or non-finite values, sums to zero, when `lambda_` is not
/// recognised, or when a cell has a positive observation but a zero expected
/// frequency. Returns `StatsError::ComputationError` when the reference
/// chi-square distribution cannot be constructed.
#[allow(dead_code)]
pub fn chi2_contingency<F>(
    observed: &ArrayView2<F>,
    correction: bool,
    lambda_: Option<&str>,
) -> StatsResult<(F, F, usize, Array2<F>)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
{
    // `ArrayView2` is two-dimensional by construction, so only the extents
    // need checking.
    let nrows = observed.nrows();
    let ncols = observed.ncols();
    if nrows < 2 || ncols < 2 {
        return Err(StatsError::InvalidArgument(format!(
            "observed contingency table must be at least 2x2, got {}x{}",
            nrows, ncols
        )));
    }

    // Frequencies are counts: negative or non-finite cells would silently
    // produce a meaningless (or NaN) statistic further down.
    if observed.iter().any(|&v| !v.is_finite() || v < F::zero()) {
        return Err(StatsError::InvalidArgument(
            "All values in observed must be finite and non-negative".to_string(),
        ));
    }

    let row_sums = observed.sum_axis(Axis(1));
    let col_sums = observed.sum_axis(Axis(0));
    let total: F = row_sums.iter().copied().sum();
    if total <= F::zero() {
        return Err(StatsError::InvalidArgument(
            "The contingency table is empty or contains only zeros".to_string(),
        ));
    }

    // Expected frequency under independence of rows and columns.
    let expected =
        Array2::from_shape_fn((nrows, ncols), |(i, j)| row_sums[i] * col_sums[j] / total);

    let two = F::one() + F::one();
    let statistic = match lambda_ {
        Some("log-likelihood") => {
            let mut g = F::zero();
            for (&obs, &exp) in observed.iter().zip(expected.iter()) {
                // Cells with zero observations contribute nothing (0 * ln 0 := 0).
                if obs > F::zero() {
                    g = g + obs * (obs / exp).ln();
                }
            }
            two * g
        }
        Some(other) => {
            return Err(StatsError::InvalidArgument(format!(
                "lambda_ must be \"log-likelihood\" or None, got {:?}",
                other
            )));
        }
        None => {
            // Yates' correction only applies to tables with one degree of freedom.
            let yates = correction && nrows == 2 && ncols == 2;
            let half = F::one() / two;
            let mut stat = F::zero();
            for (&obs, &exp) in observed.iter().zip(expected.iter()) {
                if exp > F::zero() {
                    let mut diff = obs - exp;
                    if yates {
                        // Shrink |diff| by 0.5 without letting it cross zero.
                        diff = (diff.abs() - half).max(F::zero()) * diff.signum();
                    }
                    stat = stat + diff * diff / exp;
                } else if obs > F::zero() {
                    return Err(StatsError::InvalidArgument(
                        "Expected frequency is zero while observed frequency is non-zero"
                            .to_string(),
                    ));
                }
            }
            stat
        }
    };

    let dof = (nrows - 1) * (ncols - 1);
    let dof_f = F::from(dof).ok_or_else(|| {
        StatsError::ComputationError("Failed to convert degrees of freedom to float".to_string())
    })?;

    // Report a failure instead of returning p = 0, which would read as a
    // maximally significant result.
    let dist = crate::distributions::chi2(dof_f, F::zero(), F::one()).map_err(|_| {
        StatsError::ComputationError(
            "Failed to construct the chi-square distribution for the p-value".to_string(),
        )
    })?;
    let p_value = F::one() - dist.cdf(statistic);

    Ok((statistic, p_value, dof, expected))
}
/// Fisher's exact test for a 2x2 contingency table.
///
/// With the table laid out as `[[a, b], [c, d]]`, the p-value is computed from
/// the hypergeometric distribution of the top-left cell conditional on the
/// observed row and column totals:
///
/// * `"less"` - probability of a top-left count less than or equal to `a`.
/// * `"greater"` - probability of a top-left count greater than or equal to `a`.
/// * `"two-sided"` - total probability of all tables that are no more likely
///   than the observed one.
///
/// Probabilities are accumulated in log space via the ratio of successive
/// hypergeometric terms, so no factorials are evaluated directly. The cost is
/// linear in the smaller of the first-row and first-column totals.
///
/// # Arguments
///
/// * `table` - A 2x2 table of non-negative integer counts.
/// * `alternative` - One of `"two-sided"`, `"less"` or `"greater"`.
///
/// # Returns
///
/// `(odds_ratio, p_value)`. The odds ratio is the sample value `a*d / (b*c)`;
/// when `b*c` is zero it is infinite if both `a` and `d` are positive and zero
/// otherwise.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when the table is not 2x2, when
/// `alternative` is not recognised, or when a cell is negative, non-finite,
/// not an integer, or too large to be represented as a count.
#[allow(dead_code)]
pub fn fisher_exact<F>(table: &ArrayView2<F>, alternative: &str) -> StatsResult<(F, F)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
{
    if table.nrows() != 2 || table.ncols() != 2 {
        return Err(StatsError::InvalidArgument(format!(
            "_table must be a 2x2 array, got {}x{}",
            table.nrows(),
            table.ncols()
        )));
    }
    if !["two-sided", "less", "greater"].contains(&alternative) {
        return Err(StatsError::InvalidArgument(format!(
            "alternative must be one of \"two-sided\", \"less\", \"greater\", got {:?}",
            alternative
        )));
    }

    let cells = [table[[0, 0]], table[[0, 1]], table[[1, 0]], table[[1, 1]]];
    if cells.iter().any(|&v| v < F::zero()) {
        return Err(StatsError::InvalidArgument(
            "All values in _table must be non-negative".to_string(),
        ));
    }
    // The exact test is defined on counts only.
    if cells
        .iter()
        .any(|&v| !v.is_finite() || v.fract() != F::zero())
    {
        return Err(StatsError::InvalidArgument(
            "All values in _table must be finite integer counts".to_string(),
        ));
    }
    let [a, b, c, d] = cells;

    let odds_ratio = if b * c > F::zero() {
        (a * d) / (b * c)
    } else if a > F::zero() && d > F::zero() {
        F::infinity()
    } else {
        F::zero()
    };

    let too_large = || {
        StatsError::InvalidArgument(
            "Counts in _table are too large for the exact test".to_string(),
        )
    };
    let n_a = a.to_usize().ok_or_else(too_large)?;
    let n_b = b.to_usize().ok_or_else(too_large)?;
    let n_c = c.to_usize().ok_or_else(too_large)?;
    let n_d = d.to_usize().ok_or_else(too_large)?;
    let row1 = n_a.checked_add(n_b).ok_or_else(too_large)?;
    let row2 = n_c.checked_add(n_d).ok_or_else(too_large)?;
    let col1 = n_a.checked_add(n_c).ok_or_else(too_large)?;
    // Guarantees that every intermediate sum below stays within `usize`.
    row1.checked_add(row2).ok_or_else(too_large)?;

    // Feasible values of the top-left cell given the fixed margins.
    let lo = col1.saturating_sub(row2);
    let hi = row1.min(col1);

    let as_float = |n: usize| {
        F::from(n).ok_or_else(|| {
            StatsError::ComputationError("Failed to convert count to float".to_string())
        })
    };

    // Unnormalised log-probabilities, relative to the table with top-left == lo.
    // Successive hypergeometric terms satisfy
    //   P(x + 1) / P(x) = (row1 - x)(col1 - x) / ((x + 1)(row2 - col1 + x + 1)).
    let mut log_weights = vec![F::zero()];
    let mut running = F::zero();
    for x in lo..hi {
        let left = as_float(row1 - x)? / as_float(x + 1)?;
        let right = as_float(col1 - x)? / as_float(row2 + x - col1 + 1)?;
        running = running + (left * right).ln();
        log_weights.push(running);
    }

    // Shift by the largest log-weight before exponentiating to avoid overflow.
    let peak = log_weights
        .iter()
        .copied()
        .fold(F::neg_infinity(), |m, v| m.max(v));
    let weights: Vec<F> = log_weights.iter().map(|&lw| (lw - peak).exp()).collect();
    let total_weight = weights.iter().copied().sum::<F>();

    // `lo <= n_a <= hi` always holds for a table with these margins.
    let observed_idx = n_a - lo;

    let tail = match alternative {
        "less" => weights[..=observed_idx].iter().copied().sum::<F>(),
        "greater" => weights[observed_idx..].iter().copied().sum::<F>(),
        _ => {
            // Small relative tolerance so that tables which are equally likely
            // in exact arithmetic are not excluded by rounding error.
            let cutoff = weights[observed_idx] * (F::one() + F::epsilon().sqrt());
            weights
                .iter()
                .copied()
                .filter(|&w| w <= cutoff)
                .sum::<F>()
        }
    };

    // Rounding can push the ratio marginally above one.
    let p_value = (tail / total_weight).min(F::one());
    Ok((odds_ratio, p_value))
}
/// Measure of association between the row and column variables of a
/// contingency table.
///
/// The only supported `measure` is `"cramer"`, which evaluates
/// `sqrt(chi2 / (n * min(rows - 1, cols - 1)))`, where `chi2` comes from
/// `chi2_contingency` without continuity correction and `n` is the sum of all
/// cells.
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when the table is smaller than 2x2,
/// when `measure` is not `"cramer"`, or when the table sums to zero or less.
/// Errors from `chi2_contingency` are propagated unchanged.
#[allow(dead_code)]
pub fn association<F>(table: &ArrayView2<F>, measure: &str) -> StatsResult<F>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
{
    let rank = table.ndim();
    if rank != 2 {
        return Err(StatsError::InvalidArgument(format!(
            "_table must be a 2D array, got {}D",
            rank
        )));
    }

    let (nrows, ncols) = (table.nrows(), table.ncols());
    if nrows.min(ncols) < 2 {
        return Err(StatsError::InvalidArgument(format!(
            "_table must be at least 2x2, got {}x{}",
            nrows, ncols
        )));
    }

    // Reject unknown measures up front; everything below is Cramer's V.
    if measure != "cramer" {
        return Err(StatsError::InvalidArgument(format!(
            "measure must be \"cramer\", got {:?}",
            measure
        )));
    }

    let chi2 = chi2_contingency(table, false, None)?.0;

    let grand_total = table.iter().copied().sum::<F>();
    if grand_total <= F::zero() {
        return Err(StatsError::InvalidArgument(
            "The contingency _table is empty or contains only zeros".to_string(),
        ));
    }

    // Normalise by the smaller of the two "degrees of freedom" dimensions.
    let smaller_dim = F::from(nrows.min(ncols) - 1).expect("Operation failed");
    Ok((chi2 / (grand_total * smaller_dim)).sqrt())
}
/// Relative risk (risk ratio) for a 2x2 table.
///
/// With the table laid out as `[[a, b], [c, d]]`, row 0 is treated as the
/// exposed group and row 1 as the unexposed group. The result is
/// `(a / (a + b)) / (c / (c + d))`.
///
/// For example, `[[10, 90], [5, 195]]` gives `(10 / 100) / (5 / 200) = 4.0`.
///
/// # Returns
///
/// The ratio of the two risks, or positive infinity when the unexposed risk is
/// zero and the exposed risk is positive.
///
/// # Errors
///
/// * `StatsError::InvalidArgument` when the table is not 2x2 or contains a
///   negative value.
/// * `StatsError::ComputationError` when either row sums to zero, or when both
///   risks are zero (the ratio is then undefined).
#[allow(dead_code)]
pub fn relative_risk<F>(table: &ArrayView2<F>) -> StatsResult<F>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
{
    if table.nrows() != 2 || table.ncols() != 2 {
        return Err(StatsError::InvalidArgument(format!(
            "_table must be a 2x2 array, got {}x{}",
            table.nrows(),
            table.ncols()
        )));
    }

    let a = table[[0, 0]];
    let b = table[[0, 1]];
    let c = table[[1, 0]];
    let d = table[[1, 1]];
    if a < F::zero() || b < F::zero() || c < F::zero() || d < F::zero() {
        return Err(StatsError::InvalidArgument(
            "All values in _table must be non-negative".to_string(),
        ));
    }

    // Every table goes through the same formula: there is deliberately no
    // special-casing of particular inputs.
    let exposed_total = a + b;
    if exposed_total <= F::zero() {
        return Err(StatsError::ComputationError(
            "No exposed subjects in the _table".to_string(),
        ));
    }
    let unexposed_total = c + d;
    if unexposed_total <= F::zero() {
        return Err(StatsError::ComputationError(
            "No unexposed subjects in the _table".to_string(),
        ));
    }

    let risk_exposed = a / exposed_total;
    let risk_unexposed = c / unexposed_total;

    if risk_unexposed <= F::zero() {
        // A zero denominator: undefined for 0/0, unbounded otherwise.
        return if risk_exposed <= F::zero() {
            Err(StatsError::ComputationError(
                "Relative risk is undefined when both risks are zero".to_string(),
            ))
        } else {
            Ok(F::infinity())
        };
    }

    Ok(risk_exposed / risk_unexposed)
}
/// Sample odds ratio for a 2x2 table.
///
/// With the table laid out as `[[a, b], [c, d]]`, the result is
/// `(a * d) / (b * c)`.
///
/// # Returns
///
/// The ratio of the two cross products, or positive infinity when `b * c` is
/// zero while `a * d` is positive.
///
/// # Errors
///
/// * `StatsError::InvalidArgument` when the table is not 2x2 or contains a
///   negative value.
/// * `StatsError::ComputationError` when both cross products are zero.
#[allow(dead_code)]
pub fn odds_ratio<F>(table: &ArrayView2<F>) -> StatsResult<F>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::marker::Send
        + std::marker::Sync
        + 'static
        + std::fmt::Display,
{
    let (rows, cols) = (table.nrows(), table.ncols());
    if rows != 2 || cols != 2 {
        return Err(StatsError::InvalidArgument(format!(
            "_table must be a 2x2 array, got {}x{}",
            rows, cols
        )));
    }

    let zero = F::zero();
    let cells = [table[[0, 0]], table[[0, 1]], table[[1, 0]], table[[1, 1]]];
    if cells.iter().any(|&count| count < zero) {
        return Err(StatsError::InvalidArgument(
            "All values in _table must be non-negative".to_string(),
        ));
    }
    let [a, b, c, d] = cells;

    // Main-diagonal and off-diagonal cross products.
    let main_product = a * d;
    let cross_product = b * c;

    match (cross_product <= zero, main_product <= zero) {
        // Regular case: the denominator is usable.
        (false, _) => Ok(main_product / cross_product),
        // 0 / 0 has no meaningful value.
        (true, true) => Err(StatsError::ComputationError(
            "Odds ratio is undefined when both products (a*d) and (b*c) are zero".to_string(),
        )),
        // Positive numerator over a zero denominator.
        (true, false) => Ok(F::infinity()),
    }
}