use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::ArrayView1;
use scirs2_core::numeric::{Float, NumCast};
use std::cmp::Ordering;
/// Wilcoxon signed-rank test for two paired samples.
///
/// The absolute paired differences `x[i] - y[i]` are ranked (tied magnitudes
/// share their average rank), the ranks of the positive and of the negative
/// differences are summed separately, and a two-sided p-value is obtained
/// from a normal approximation of the smaller rank sum.
///
/// # Arguments
///
/// * `x`, `y` - Paired observations; both must be non-empty and of equal length.
/// * `zero_method` - How zero differences are treated:
///   * `"wilcox"`: zero differences are discarded before ranking.
///   * `"pratt"`: zero differences take part in the ranking, but their ranks
///     are dropped from both rank sums and from the null mean/variance.
///   * `"zsplit"`: recognised but not implemented; returns an error.
/// * `correction` - When `true`, a 0.5 continuity correction is applied to
///   the z-score (shrinking it toward zero).
///
/// # Returns
///
/// `(w, p_value)`, where `w` is the smaller of the positive and negative rank
/// sums and `p_value` is the two-sided p-value.
///
/// # Errors
///
/// * `StatsError::InvalidArgument` if either input is empty, if `zero_method`
///   is unknown or `"zsplit"`, or if every difference is zero.
/// * `StatsError::DimensionMismatch` if the inputs differ in length.
#[allow(dead_code)]
pub fn wilcoxon<F>(
    x: &ArrayView1<F>,
    y: &ArrayView1<F>,
    zero_method: &str,
    correction: bool,
) -> StatsResult<(F, F)>
where
    F: Float + std::iter::Sum<F> + std::ops::Div<Output = F> + NumCast + std::fmt::Display,
{
    if x.is_empty() || y.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Input arrays cannot be empty".to_string(),
        ));
    }
    if x.len() != y.len() {
        return Err(StatsError::DimensionMismatch(
            "Input arrays must have the same length for paired test".to_string(),
        ));
    }
    let keep_zeros = match zero_method {
        "wilcox" => false,
        "pratt" => true,
        "zsplit" => {
            return Err(StatsError::InvalidArgument(
                "zsplit _method not implemented yet".to_string(),
            ));
        }
        _ => {
            return Err(StatsError::InvalidArgument(format!(
                "Unknown zero_method: {}. Use 'wilcox', 'pratt', or 'zsplit'",
                zero_method
            )));
        }
    };

    let from_usize = |v: usize| -> F { F::from(v).expect("Failed to convert to float") };
    let from_f64 = |v: f64| -> F { F::from(v).expect("Failed to convert constant to float") };

    let mut diffs: Vec<F> = x.iter().zip(y.iter()).map(|(&a, &b)| a - b).collect();
    if !keep_zeros {
        diffs.retain(|d| !d.is_zero());
    }
    // Covers both "nothing left after filtering" (wilcox) and "only zeros
    // were ranked" (pratt): neither leaves any signed rank to test.
    if diffs.iter().all(|d| d.is_zero()) {
        return Err(StatsError::InvalidArgument(
            "All differences are zero, cannot perform Wilcoxon test".to_string(),
        ));
    }

    diffs.sort_by(|a, b| a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal));

    let count = diffs.len();
    let two = from_f64(2.0);
    let mut w_plus = F::zero();
    let mut w_minus = F::zero();
    let mut n_zero = 0usize;
    // Sum of (t^3 - t) over groups of tied non-zero magnitudes.
    let mut tie_term = F::zero();

    let mut start = 0;
    while start < count {
        let magnitude = diffs[start].abs();
        let mut end = start;
        while end + 1 < count && diffs[end + 1].abs() == magnitude {
            end += 1;
        }
        // Average of the 1-based ranks start+1 ..= end+1.
        let avg_rank = from_usize(start + end) / two + F::one();
        let group_len = end - start + 1;
        if magnitude.is_zero() {
            // Pratt: zeros occupy the lowest ranks but contribute to neither sum.
            n_zero += group_len;
        } else {
            if group_len > 1 {
                let t = from_usize(group_len);
                tie_term = tie_term + (t.powi(3) - t);
            }
            for d in &diffs[start..=end] {
                if d.is_sign_positive() {
                    w_plus = w_plus + avg_rank;
                } else {
                    w_minus = w_minus + avg_rank;
                }
            }
        }
        start = end + 1;
    }

    let w = if w_plus < w_minus { w_plus } else { w_minus };

    // Null mean and variance of the rank sum, reduced by the ranks taken up
    // by zero differences (Pratt) and corrected for ties.
    let n = from_usize(count);
    let nz = from_usize(n_zero);
    let w_mean = (n * (n + F::one()) - nz * (nz + F::one())) / from_f64(4.0);
    let w_var = (n * (n + F::one()) * (two * n + F::one())
        - nz * (nz + F::one()) * (two * nz + F::one())
        - tie_term / two)
        / from_f64(24.0);
    let w_sd = w_var.sqrt();

    let correction_factor = if correction {
        from_f64(0.5)
    } else {
        F::zero()
    };
    // `w` is the smaller rank sum, so `w - w_mean <= 0`; the continuity
    // correction moves it toward zero but must never flip its sign.
    let z = (w - w_mean + correction_factor).min(F::zero()) / w_sd;
    let p_value = (two * normal_cdf(-z.abs())).min(F::one());
    Ok((w, p_value))
}
/// Mann-Whitney U test (Wilcoxon rank-sum test) for two independent samples.
///
/// Both samples are pooled and ranked (ties share their average rank). The
/// U statistic of `x` is `u1 = R_x - n1 (n1 + 1) / 2`, where `R_x` is the rank
/// sum of `x`; `u2 = n1 n2 - u1`. The p-value comes from a normal
/// approximation whose variance is corrected for ties.
///
/// # Arguments
///
/// * `x`, `y` - The two samples; both must be non-empty.
/// * `alternative` - `"two-sided"`, `"less"` (small `u1`, i.e. `x` tends to
///   rank below `y`, is evidence against the null) or `"greater"` (large `u1`
///   is evidence against the null).
/// * `use_continuity` - When `true`, a 0.5 continuity correction is applied.
///
/// # Returns
///
/// `(u, p_value)` where `u = min(u1, u2)`.
///
/// # Errors
///
/// `StatsError::InvalidArgument` if either sample is empty or `alternative`
/// is not one of the accepted strings.
#[allow(dead_code)]
pub fn mann_whitney<F>(
    x: &ArrayView1<F>,
    y: &ArrayView1<F>,
    alternative: &str,
    use_continuity: bool,
) -> StatsResult<(F, F)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + std::fmt::Debug
        + std::fmt::Display,
{
    if x.is_empty() || y.is_empty() {
        return Err(StatsError::InvalidArgument(
            "Input arrays cannot be empty".to_string(),
        ));
    }
    if !matches!(alternative, "two-sided" | "less" | "greater") {
        return Err(StatsError::InvalidArgument(format!(
            "Unknown alternative: {}. Use 'two-sided', 'less', or 'greater'",
            alternative
        )));
    }

    let to_float = |v: usize| -> F { F::from(v).expect("Failed to convert to float") };
    let constant = |v: f64| -> F { F::from(v).expect("Failed to convert constant to float") };

    let n1 = x.len();
    let n2 = y.len();
    let n = n1 + n2;

    // Pool both samples, remembering which observations came from `x`.
    let mut pooled: Vec<(F, bool)> = Vec::with_capacity(n);
    pooled.extend(x.iter().map(|&value| (value, true)));
    pooled.extend(y.iter().map(|&value| (value, false)));
    pooled.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));

    let two = constant(2.0);
    let mut rank_sum_x = F::zero();
    // Sum of (t^3 - t) over all groups of tied values.
    let mut tie_sum = F::zero();

    let mut start = 0;
    while start < n {
        let value = pooled[start].0;
        let mut end = start;
        while end + 1 < n && pooled[end + 1].0 == value {
            end += 1;
        }
        // Average of the 1-based ranks start+1 ..= end+1.
        let avg_rank = to_float(start + end) / two + F::one();
        for &(_, from_x) in &pooled[start..=end] {
            if from_x {
                rank_sum_x = rank_sum_x + avg_rank;
            }
        }
        if end > start {
            let t = to_float(end - start + 1);
            tie_sum = tie_sum + (t.powi(3) - t);
        }
        start = end + 1;
    }

    let n1_f = to_float(n1);
    let n2_f = to_float(n2);
    let n_f = to_float(n);

    let u1 = rank_sum_x - (n1_f * (n1_f + F::one())) / two;
    let u2 = n1_f * n2_f - u1;
    let u = u1.min(u2);
    let mean_u = n1_f * n2_f / two;

    // n >= 2 because both samples are non-empty, so n^3 - n > 0.
    let tie_fraction = tie_sum / (n_f.powi(3) - n_f);
    let var_u = (n1_f * n2_f * (n_f + F::one()) / constant(12.0)) * (F::one() - tie_fraction);
    if var_u <= F::zero() {
        // Every pooled value is identical: the statistic carries no
        // information, so report no evidence against the null.
        return Ok((u, F::one()));
    }
    let std_dev_u = var_u.sqrt();

    let correction = if use_continuity {
        constant(0.5)
    } else {
        F::zero()
    };

    let p_value = match alternative {
        // One-sided tests are directional and therefore based on u1 itself,
        // not on min(u1, u2).
        "less" => normal_cdf((u1 + correction - mean_u) / std_dev_u),
        "greater" => F::one() - normal_cdf((u1 - correction - mean_u) / std_dev_u),
        _ => {
            // u <= mean_u; the continuity correction shrinks the deviation
            // toward zero without letting it change sign.
            let z = (u - mean_u + correction).min(F::zero()) / std_dev_u;
            (two * normal_cdf(z)).min(F::one())
        }
    };
    Ok((u, p_value))
}
/// Kruskal-Wallis H test for two or more independent samples.
///
/// All observations are pooled and ranked (ties share their average rank).
/// The statistic is
/// `H = 12 / (N (N + 1)) * sum(R_i^2 / n_i) - 3 (N + 1)`,
/// where `R_i` and `n_i` are the rank sum and size of group `i` and `N` is
/// the total number of observations. When ties are present `H` is divided by
/// `1 - sum(t^3 - t) / (N^3 - N)`. The p-value is taken from `chi_square_sf`
/// with `samples.len() - 1` degrees of freedom.
///
/// # Returns
///
/// `(h, p_value)`.
///
/// # Errors
///
/// `StatsError::InvalidArgument` if fewer than two samples are given, if any
/// sample is empty, or if every pooled value is identical (the tie
/// correction is then zero and `H` is undefined).
#[allow(dead_code)]
pub fn kruskal_wallis<F>(samples: &[ArrayView1<F>]) -> StatsResult<(F, F)>
where
    F: Float + std::iter::Sum<F> + std::ops::Div<Output = F> + NumCast + std::fmt::Display,
{
    if samples.len() < 2 {
        return Err(StatsError::InvalidArgument(
            "At least two samples are required for Kruskal-Wallis test".to_string(),
        ));
    }
    if let Some(empty_idx) = samples.iter().position(|sample| sample.is_empty()) {
        return Err(StatsError::InvalidArgument(format!(
            "Sample {} is empty",
            empty_idx
        )));
    }

    let to_float = |v: usize| -> F { F::from(v).expect("Failed to convert to float") };
    let constant = |v: f64| -> F { F::from(v).expect("Failed to convert constant to float") };

    // Pool every observation together with the index of its group.
    let total: usize = samples.iter().map(|sample| sample.len()).sum();
    let mut pooled: Vec<(F, usize)> = Vec::with_capacity(total);
    for (group_idx, sample) in samples.iter().enumerate() {
        pooled.extend(sample.iter().map(|&value| (value, group_idx)));
    }
    pooled.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));

    let n = pooled.len();
    let n_f = to_float(n);
    let two = constant(2.0);
    // n >= 2 (at least two non-empty samples), so this denominator is > 0.
    let tie_denominator = n_f.powi(3) - n_f;

    let mut rank_sums = vec![F::zero(); samples.len()];
    let mut tie_correction = F::one();

    let mut start = 0;
    while start < n {
        let value = pooled[start].0;
        let mut end = start;
        while end + 1 < n && pooled[end + 1].0 == value {
            end += 1;
        }
        // Average of the 1-based ranks start+1 ..= end+1.
        let avg_rank = to_float(start + end) / two + F::one();
        for &(_, group) in &pooled[start..=end] {
            rank_sums[group] = rank_sums[group] + avg_rank;
        }
        if end > start {
            let t = to_float(end - start + 1);
            tie_correction = tie_correction - (t.powi(3) - t) / tie_denominator;
        }
        start = end + 1;
    }

    if tie_correction <= F::zero() {
        // A single tie group spans the whole data set; dividing H by the
        // correction would be 0/0.
        return Err(StatsError::InvalidArgument(
            "All values are identical, cannot perform Kruskal-Wallis test".to_string(),
        ));
    }

    let mut weighted_squares = F::zero();
    for (sample, &rank_sum) in samples.iter().zip(rank_sums.iter()) {
        let n_i = to_float(sample.len());
        weighted_squares = weighted_squares + (rank_sum * rank_sum) / n_i;
    }
    let mut h = (constant(12.0) / (n_f * (n_f + F::one()))) * weighted_squares
        - constant(3.0) * (n_f + F::one());
    if tie_correction < F::one() {
        h = h / tie_correction;
    }

    let df = F::from(samples.len() - 1).expect("Operation failed");
    let p_value = chi_square_sf(h, df);
    Ok((h, p_value))
}
/// Friedman test for repeated measurements.
///
/// Each row of `data` is one subject and each column one treatment. Values
/// are ranked within every row (tied values share their average rank), the
/// ranks are summed per column, and the statistic
/// `12 / (n k (k + 1)) * sum(R_j^2) - 3 n (k + 1)` is computed, with `n` rows,
/// `k` columns and column rank sums `R_j`. No tie correction is applied to
/// the statistic. The p-value is taken from `chi_square_sf` with `k - 1`
/// degrees of freedom.
///
/// # Returns
///
/// `(chi2, p_value)`.
///
/// # Errors
///
/// `StatsError::InvalidArgument` if `data` has fewer than 2 rows or fewer
/// than 2 columns.
#[allow(dead_code)]
pub fn friedman<F>(data: &scirs2_core::ndarray::ArrayView2<F>) -> StatsResult<(F, F)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + std::ops::AddAssign
        + std::fmt::Display,
{
    let (subjects, treatments) = (data.nrows(), data.ncols());
    if subjects < 2 || treatments < 2 {
        return Err(StatsError::InvalidArgument(
            "At least 2 subjects and 2 treatments are required for Friedman test".to_string(),
        ));
    }

    // Per-treatment rank totals, accumulated subject by subject.
    let mut rank_sums = vec![F::zero(); treatments];
    for subject in 0..subjects {
        let row = data.row(subject);
        let mut ordered: Vec<(F, usize)> = (0..treatments).map(|col| (row[col], col)).collect();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal));

        let mut first = 0;
        while first < treatments {
            let value = ordered[first].0;
            // Extend `last` to the final position holding the same value.
            let mut last = first;
            while last + 1 < treatments && ordered[last + 1].0 == value {
                last += 1;
            }
            // Average of the 1-based ranks first+1 ..= last+1.
            let shared_rank = F::from(first + last).expect("Failed to convert to float")
                / F::from(2.0).expect("Failed to convert constant to float")
                + F::one();
            for &(_, col) in &ordered[first..=last] {
                rank_sums[col] += shared_rank;
            }
            first = last + 1;
        }
    }

    let n_f = F::from(subjects).expect("Failed to convert to float");
    let k_f = F::from(treatments).expect("Failed to convert to float");
    let squared_total = rank_sums
        .iter()
        .fold(F::zero(), |acc, &rank_sum| acc + rank_sum.powi(2));

    let scale = F::from(12.0).expect("Failed to convert constant to float")
        / (n_f * k_f * (k_f + F::one()));
    let offset =
        F::from(3.0).expect("Failed to convert constant to float") * n_f * (k_f + F::one());
    let chi2 = scale * squared_total - offset;

    let df = k_f - F::one();
    let p_value = chi_square_sf(chi2, df);
    Ok((chi2, p_value))
}
/// Approximate cumulative distribution function of the standard normal.
///
/// The argument is converted to `f64`, evaluated there, and converted back
/// to `F`. Inputs below -8 return exactly 0 and inputs above 8 return
/// exactly 1. In between, the upper-tail probability of `|x|` is computed as
/// `d * p`, where `d` is `0.3989423 * exp(-x^2 / 2)` and `p` is a fifth-degree
/// polynomial in `t = 1 / (1 + 0.2316419 |x|)`; the result is `1 - d * p` for
/// `x >= 0` and `d * p` otherwise.
///
/// # Panics
///
/// Panics if the conversion between `F` and `f64` fails.
#[allow(dead_code)]
fn normal_cdf<F: Float + NumCast>(x: F) -> F {
    // Saturation bound beyond which the CDF is returned as exactly 0 or 1.
    const CUTOFF: f64 = 8.0;
    // Polynomial coefficients, lowest power of `t` first.
    const POLY: [f64; 5] = [
        0.319381530,
        -0.356563782,
        1.781477937,
        -1.821255978,
        1.330274429,
    ];

    let value = <f64 as NumCast>::from(x).expect("Operation failed");

    let probability = if value < -CUTOFF {
        0.0
    } else if value > CUTOFF {
        1.0
    } else {
        let t = 1.0 / (1.0 + 0.2316419 * value.abs());
        let density = 0.3989423 * (-0.5 * value * value).exp();
        // Horner evaluation from the highest coefficient down.
        let poly = t * POLY.iter().rev().fold(0.0, |acc, &coeff| coeff + t * acc);
        let upper_tail = density * poly;
        if value >= 0.0 {
            1.0 - upper_tail
        } else {
            upper_tail
        }
    };

    F::from(probability).expect("Failed to convert to float")
}
#[allow(dead_code)]
fn chi_square_sf<F: Float + NumCast>(x: F, df: F) -> F {
let x_f64 = <f64 as NumCast>::from(x).expect("Operation failed");
let df_f64 = <f64 as NumCast>::from(df).expect("Operation failed");
if x_f64 <= 0.0 {
return F::one();
}
if df_f64 <= 0.0 {
return F::zero();
}
let z;
if df_f64 > 100.0 {
z = (x_f64 / df_f64).powf(1.0 / 3.0)
- (1.0 - 2.0 / (9.0 * df_f64)) / (1.0 / (3.0 * df_f64).sqrt());
} else if df_f64 > 1.0 {
z = (x_f64 / df_f64 - 1.0) * (0.5 * df_f64).sqrt();
} else {
z = (x_f64 * 0.5).sqrt();
}
let p = 1.0 - normal_cdf::<f64>(z);
F::from(p).expect("Failed to convert to float")
}