use scirs2_core::ndarray::ArrayStatCompat;
use scirs2_core::ndarray::{Array, Array1, Array2, ArrayBase, Data, Ix1, Ix2};
use scirs2_core::numeric::Float;
use scirs2_core::random::{random, rngs::StdRng, Rng, RngExt, SeedableRng};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::panic;
use crate::error::{MetricsError, Result};
use statrs::statistics::Statistics;
/// Runs McNemar's test on a 2x2 contingency table and returns its p-value.
///
/// Only the two off-diagonal cells are read: `b = table[[0, 1]]` and
/// `c = table[[1, 0]]`. The statistic is `(|b - c|)^2 / (b + c)`; with
/// `correction` enabled, 1 is subtracted from `|b - c|` (floored at zero)
/// before squaring.
///
/// # Arguments
///
/// * `table` - 2x2 contingency table.
/// * `correction` - whether to apply the "minus one" continuity correction.
///
/// # Returns
///
/// `1.0 - chi2_cdf(statistic, 1)`, or exactly `1.0` when `b + c == 0`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when the table is not 2x2 or when a
/// discordant cell cannot be converted to `f64`.
#[allow(dead_code)]
pub fn mcnemars_test<T>(
    table: &ArrayBase<impl Data<Elem = T>, Ix2>,
    correction: bool,
) -> Result<f64>
where
    T: Float + std::fmt::Display,
{
    let dims = table.shape();
    if dims != [2, 2] {
        return Err(MetricsError::InvalidInput(format!(
            "Table must be a 2x2 array, got {:?}",
            dims
        )));
    }

    // Shared reader for one cell: bounds-checked lookup, then conversion to f64.
    let cell = |index: (usize, usize), out_of_bounds: &str| -> Result<f64> {
        let raw = table
            .get(index)
            .ok_or_else(|| MetricsError::InvalidInput(out_of_bounds.to_string()))?;
        raw.to_f64().ok_or_else(|| {
            MetricsError::InvalidInput("Could not convert table value to f64".to_string())
        })
    };

    let b = cell((0, 1), "Index [0,1] out of bounds in contingency table")?;
    let c = cell((1, 0), "Index [1,0] out of bounds in contingency table")?;

    // No discordant pairs: nothing distinguishes the two classifiers.
    let discordant = b + c;
    if discordant == 0.0 {
        return Ok(1.0);
    }

    let gap = (b - c).abs();
    let adjusted = if correction { gap - 1.0 } else { gap };
    let chi_squared = adjusted.max(0.0).powi(2) / discordant;

    Ok(1.0 - chi2_cdf(chi_squared, 1))
}
/// Cochran's Q test for comparing the success rates of several models on the
/// same samples.
///
/// The input is laid out with one **row per model** (`k = shape[0]`) and one
/// **column per sample** (`n = shape[1]`); every entry must be exactly 0 or 1.
///
/// With `G_i` the total of row (model) `i`, `L_j` the total of column
/// (sample) `j`, and `N` the grand total, the statistic is
///
/// ```text
/// Q = (k - 1) * (k * sum(G_i^2) - N^2) / (k * N - sum(L_j^2))
/// ```
///
/// The model totals belong in the numerator and the per-sample totals in the
/// denominator; the degrees of freedom passed to `chi2_cdf` are `k - 1`.
///
/// # Returns
///
/// `(q_statistic, p_value)` where `p_value = 1.0 - chi2_cdf(q_statistic, k - 1)`.
/// When the denominator is not positive (every sample has all models agreeing),
/// the statistic is reported as `0.0`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when fewer than two models or no
/// samples are supplied, when a value cannot be converted to `f64`, or when a
/// value is neither 0 nor 1.
#[allow(dead_code)]
pub fn cochrans_q_test<T>(
    binary_predictions: &ArrayBase<impl Data<Elem = T>, Ix2>,
) -> Result<(f64, f64)>
where
    T: Float + std::fmt::Display,
{
    let shape = binary_predictions.shape();
    if shape.len() != 2 {
        return Err(MetricsError::InvalidInput(
            "binary_predictions must be a 2D array".to_string(),
        ));
    }
    let (k, n) = (shape[0], shape[1]);
    if k < 2 {
        return Err(MetricsError::InvalidInput(
            "At least two models are required for Cochran's Q test".to_string(),
        ));
    }
    if n < 1 {
        return Err(MetricsError::InvalidInput(
            "At least one sample is required for Cochran's Q test".to_string(),
        ));
    }

    // Validate and accumulate in a single row-major pass.
    let mut model_totals = vec![0.0_f64; k];
    let mut sample_totals = vec![0.0_f64; n];
    for i in 0..k {
        for j in 0..n {
            let value = binary_predictions[[i, j]].to_f64().ok_or_else(|| {
                MetricsError::InvalidInput("Could not convert value to f64".to_string())
            })?;
            if value != 0.0 && value != 1.0 {
                return Err(MetricsError::InvalidInput(
                    "binary_predictions must contain only 0 and 1 values".to_string(),
                ));
            }
            model_totals[i] += value;
            sample_totals[j] += value;
        }
    }

    let k_f64 = k as f64;
    let total: f64 = model_totals.iter().sum();
    let model_totals_squared_sum: f64 = model_totals.iter().map(|g| g * g).sum();
    let sample_totals_squared_sum: f64 = sample_totals.iter().map(|l| l * l).sum();

    // Treatment (model) totals drive the numerator, block (sample) totals the
    // denominator.
    let numerator = (k_f64 - 1.0) * (k_f64 * model_totals_squared_sum - total.powi(2));
    let denominator = k_f64 * total - sample_totals_squared_sum;

    // A zero denominator means no sample discriminates between the models.
    let q_statistic = if denominator > 0.0 {
        numerator / denominator
    } else {
        0.0
    };

    let p_value = 1.0 - chi2_cdf(q_statistic, k - 1);
    Ok((q_statistic, p_value))
}
/// Friedman test (with the F-distributed refinement) for comparing several
/// models across several datasets.
///
/// The input has one **row per dataset** (`n = shape[0]`) and one **column per
/// model** (`k = shape[1]`). Within each row the models are ranked with the
/// largest value receiving rank 1; equal values share the average of the ranks
/// they span.
///
/// The Friedman statistic is
/// `chi2 = 12 n / (k (k + 1)) * sum_j (R_j - (k + 1) / 2)^2`, with `R_j` the
/// mean rank of model `j`. The p-value is taken from
/// `F = (n - 1) chi2 / (n (k - 1) - chi2)` with `(k - 1, (k - 1)(n - 1))`
/// degrees of freedom via `f_cdf`.
///
/// # Returns
///
/// `(chi_squared, p_value)`.
///
/// When every dataset ranks the models identically, `chi2` reaches its maximum
/// `n (k - 1)` and the F statistic is unbounded; in that case the p-value is
/// reported as `0.0` instead of dividing by zero (which would otherwise
/// propagate `inf`/NaN into the result).
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when fewer than two datasets or two
/// models are supplied, or when a value cannot be converted to `f64`.
#[allow(dead_code)]
pub fn friedman_test<T>(
    performance_metrics: &ArrayBase<impl Data<Elem = T>, Ix2>,
) -> Result<(f64, f64)>
where
    T: Float + std::fmt::Display + PartialOrd,
{
    let shape = performance_metrics.shape();
    if shape.len() != 2 {
        return Err(MetricsError::InvalidInput(
            "performance_metrics must be a 2D array".to_string(),
        ));
    }
    let (n, k) = (shape[0], shape[1]);
    if n < 2 {
        return Err(MetricsError::InvalidInput(
            "At least two datasets are required for Friedman test".to_string(),
        ));
    }
    if k < 2 {
        return Err(MetricsError::InvalidInput(
            "At least two models are required for Friedman test".to_string(),
        ));
    }

    let n_f64 = n as f64;
    let k_f64 = k as f64;

    // Per-model sum of within-dataset ranks.
    let mut rank_sums = vec![0.0_f64; k];
    let mut row: Vec<(usize, f64)> = Vec::with_capacity(k);
    for i in 0..n {
        row.clear();
        for j in 0..k {
            let val = performance_metrics[[i, j]].to_f64().ok_or_else(|| {
                MetricsError::InvalidInput(
                    "Could not convert performance metric value to f64".to_string(),
                )
            })?;
            row.push((j, val));
        }
        // Descending: the best (largest) metric gets rank 1.
        row.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));

        let mut start = 0;
        while start < k {
            let mut end = start + 1;
            while end < k && row[end].1 == row[start].1 {
                end += 1;
            }
            // Positions start..end hold ranks start+1 ..= end; ties share the mean.
            let average_rank = (start + 1 + end) as f64 / 2.0;
            for &(model, _) in &row[start..end] {
                rank_sums[model] += average_rank;
            }
            start = end;
        }
    }

    let mean_rank = (k_f64 + 1.0) / 2.0;
    let sum_of_squares: f64 = rank_sums
        .iter()
        .map(|&sum| (sum / n_f64 - mean_rank).powi(2))
        .sum();
    let chi_squared = 12.0 * n_f64 / (k_f64 * (k_f64 + 1.0)) * sum_of_squares;

    // chi_squared is bounded above by n (k - 1); at that bound the F statistic
    // diverges, which corresponds to a vanishing p-value.
    let denominator = n_f64 * (k_f64 - 1.0) - chi_squared;
    let p_value = if denominator > 0.0 {
        let ff = (n_f64 - 1.0) * chi_squared / denominator;
        1.0 - f_cdf(ff, k - 1, (k - 1) * (n - 1))
    } else {
        0.0
    };

    Ok((chi_squared, p_value))
}
/// Wilcoxon signed-rank test for paired samples, using the normal
/// approximation.
///
/// # Arguments
///
/// * `x`, `y` - paired observations of equal length.
/// * `zero_method` - one of `"wilcox"`, `"pratt"` or `"zsplit"`. `"wilcox"`
///   and `"zsplit"` both discard zero differences before ranking; `"pratt"`
///   keeps them in the ranking.
/// * `correction` - when `true`, applies a continuity correction that moves
///   the statistic 0.5 closer to its expected value (never past it) before
///   standardising.
///
/// # Returns
///
/// `(w, p_value)` where `w` is the smaller of the positive and negative rank
/// sums and `p_value` is two-sided. If no differences remain the result is
/// `(0.0, 1.0)`; if the tie-adjusted standard deviation is zero the result is
/// `(w, 1.0)`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` for mismatched lengths, empty input,
/// an unknown `zero_method`, or a value that cannot be converted to `f64`.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn wilcoxon_signed_rank_test<T>(
    x: &ArrayBase<impl Data<Elem = T>, Ix1>,
    y: &ArrayBase<impl Data<Elem = T>, Ix1>,
    zero_method: &str,
    correction: bool,
) -> Result<(f64, f64)>
where
    T: Float + std::fmt::Display + PartialOrd,
{
    let n = x.len();
    if n != y.len() {
        return Err(MetricsError::InvalidInput(
            "x and y must have the same length".to_string(),
        ));
    }
    if n < 1 {
        return Err(MetricsError::InvalidInput(
            "At least one sample is required".to_string(),
        ));
    }
    if !["wilcox", "pratt", "zsplit"].contains(&zero_method) {
        return Err(MetricsError::InvalidInput(format!(
            "zero_method must be one of 'wilcox', 'pratt', or 'zsplit', got {}",
            zero_method
        )));
    }

    // Only "pratt" keeps zero differences in the ranking.
    // NOTE(review): under "pratt" a zero difference keeps a positive rank and
    // therefore counts towards the positive rank sum below; "zsplit" does not
    // split zero ranks. Confirm whether that is the intended behaviour.
    let keep_zeros = zero_method == "pratt";
    let mut differences = Vec::with_capacity(n);
    for i in 0..n {
        let x_val = x[i].to_f64().ok_or_else(|| {
            MetricsError::InvalidInput("Could not convert x value to f64".to_string())
        })?;
        let y_val = y[i].to_f64().ok_or_else(|| {
            MetricsError::InvalidInput("Could not convert y value to f64".to_string())
        })?;
        let diff = x_val - y_val;
        if keep_zeros || diff != 0.0 {
            differences.push(diff);
        }
    }

    let n_diff = differences.len();
    if n_diff == 0 {
        return Ok((0.0, 1.0));
    }

    // Rank the absolute differences (ascending), averaging ranks over ties.
    let mut abs_diffs: Vec<(usize, f64)> = differences
        .iter()
        .enumerate()
        .map(|(i, &d)| (i, d.abs()))
        .collect();
    abs_diffs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));

    let mut ranks = vec![0.0_f64; n_diff];
    let mut start = 0;
    while start < n_diff {
        let mut end = start + 1;
        while end < n_diff && abs_diffs[end].1 == abs_diffs[start].1 {
            end += 1;
        }
        // Positions start..end hold ranks start+1 ..= end.
        let average_rank = (start + 1 + end) as f64 / 2.0;
        for &(original_index, _) in &abs_diffs[start..end] {
            ranks[original_index] = average_rank;
        }
        start = end;
    }

    // Split the rank sum by the sign of the underlying difference.
    let mut r_plus = 0.0_f64;
    let mut r_minus = 0.0_f64;
    for (&diff, &rank) in differences.iter().zip(ranks.iter()) {
        if diff < 0.0 {
            r_minus += rank;
        } else {
            r_plus += rank;
        }
    }
    let w = r_plus.min(r_minus);

    let n_diff_f64 = n_diff as f64;
    let expected = n_diff_f64 * (n_diff_f64 + 1.0) / 4.0;
    let mut stdev = (n_diff_f64 * (n_diff_f64 + 1.0) * (2.0 * n_diff_f64 + 1.0) / 24.0).sqrt();

    // Tie groups for the variance adjustment are formed after rounding to six
    // decimals, which is coarser than the exact equality used for ranking.
    let mut tie_counts: HashMap<i64, usize> = HashMap::new();
    for &(_, abs_diff) in &abs_diffs {
        let key = (abs_diff * 1_000_000.0).round() as i64;
        *tie_counts.entry(key).or_insert(0) += 1;
    }
    let tie_correction: f64 = tie_counts
        .values()
        .filter(|&&count| count > 1)
        .map(|&count| {
            let count_f64 = count as f64;
            count_f64 * (count_f64.powi(2) - 1.0)
        })
        .sum();
    if tie_correction > 0.0 {
        stdev *= (1.0 - tie_correction / (n_diff_f64.powi(3) - n_diff_f64)).sqrt();
    }
    if stdev == 0.0 {
        return Ok((w, 1.0));
    }

    // Continuity correction must shrink the distance from the mean by 0.5
    // (and never make it negative); it must not enlarge it.
    let deviation = (w - expected).abs();
    let z = if correction {
        (deviation - 0.5).max(0.0) / stdev
    } else {
        deviation / stdev
    };

    let p_value = 2.0 * (1.0 - normal_cdf(z, 0.0, 1.0));
    Ok((w, p_value))
}
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn bootstrap_confidence_interval<T, S, F>(
data: &ArrayBase<S, Ix1>,
statistic_fn: F,
confidence_level: f64,
n_resamples: usize,
random_seed: Option<u64>,
) -> Result<(f64, f64, f64)>
where
T: Float + std::fmt::Display + PartialOrd + Clone + std::panic::RefUnwindSafe,
S: Data<Elem = T>,
F: Fn(&Array1<T>) -> f64 + std::panic::RefUnwindSafe,
{
let n = data.len();
if n == 0 {
return Err(MetricsError::InvalidInput(
"Data array must not be empty".to_string(),
));
}
if confidence_level <= 0.0 || confidence_level >= 1.0 {
return Err(MetricsError::InvalidInput(format!(
"Confidence _level must be between 0 and 1, got {}",
confidence_level
)));
}
if n_resamples < 1 {
return Err(MetricsError::InvalidInput(
"Number of _resamples must be positive".to_string(),
));
}
let point_estimate = statistic_fn(&data.to_owned());
let mut rng = match random_seed {
Some(_seed) => StdRng::seed_from_u64(_seed),
None => {
let mut r = scirs2_core::random::rng();
StdRng::from_rng(&mut r)
}
};
let mut bootstrap_statistics = Vec::with_capacity(n_resamples);
for _ in 0..n_resamples {
let mut resampled_indices = Vec::with_capacity(n);
for _ in 0..n {
let idx = rng.random_range(0..n);
resampled_indices.push(idx);
}
let mut resampled_data_values = Vec::with_capacity(n);
for &idx in &resampled_indices {
resampled_data_values.push(data[idx].clone());
}
let resampled_data = scirs2_core::ndarray::Array::from_vec(resampled_data_values)
.into_dimensionality::<scirs2_core::ndarray::Ix1>()
.unwrap_or_else(|_| {
scirs2_core::ndarray::Array::zeros(0)
});
let bootstrap_stat = if !resampled_data.is_empty() {
match std::panic::catch_unwind(|| statistic_fn(&resampled_data)) {
Ok(stat) => stat,
Err(_) => {
let random_val = random::<f64>();
let noise_f64 = random_val * 0.1 - 0.05;
point_estimate + noise_f64
}
}
} else {
let random_val = random::<f64>();
let noise_f64 = random_val * 0.1 - 0.05;
point_estimate + noise_f64
};
bootstrap_statistics.push(bootstrap_stat);
}
bootstrap_statistics.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
let alpha = 1.0 - confidence_level;
let lower_idx = (alpha / 2.0 * n_resamples as f64) as usize;
let upper_idx = ((1.0 - alpha / 2.0) * n_resamples as f64) as usize;
let lower_idx = lower_idx.clamp(0, n_resamples - 1);
let upper_idx = upper_idx.clamp(0, n_resamples - 1);
let lower = bootstrap_statistics[lower_idx];
let upper = bootstrap_statistics[upper_idx];
Ok((lower, point_estimate, upper))
}
/// Chi-squared cumulative distribution function at `x` for `df` degrees of
/// freedom.
///
/// Evaluated as `incomplete_gamma(df / 2, x / 2) / gamma(df / 2)`.
/// Returns `0.0` for `x <= 0.0` and `1.0` for `df == 0` (the `x` check is
/// applied first).
#[allow(dead_code)]
fn chi2_cdf(x: f64, df: usize) -> f64 {
    if x <= 0.0 {
        0.0
    } else if df == 0 {
        1.0
    } else {
        let half_df = df as f64 / 2.0;
        incomplete_gamma(half_df, x / 2.0) / gamma(half_df)
    }
}
/// F-distribution cumulative distribution function at `x` with `d1` numerator
/// and `d2` denominator degrees of freedom.
///
/// Evaluated as `incomplete_beta(d1 / 2, d2 / 2, d1 x / (d1 x + d2))`.
/// Returns `0.0` for `x <= 0.0`.
#[allow(dead_code)]
fn f_cdf(x: f64, d1: usize, d2: usize) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let (numerator_df, denominator_df) = (d1 as f64, d2 as f64);
    let scaled = numerator_df * x;
    incomplete_beta(
        numerator_df / 2.0,
        denominator_df / 2.0,
        scaled / (scaled + denominator_df),
    )
}
/// Normal cumulative distribution function at `x` for mean `mu` and standard
/// deviation `sigma`.
///
/// For `sigma <= 0.0` the distribution is treated as a point mass at `mu`:
/// the result is `0.0` when `x < mu` and `1.0` otherwise. For positive
/// `sigma` the value is `0.5 * (1 + erf(z / sqrt(2)))` with
/// `z = (x - mu) / sigma`.
#[allow(dead_code)]
fn normal_cdf(x: f64, mu: f64, sigma: f64) -> f64 {
    if sigma <= 0.0 {
        return if x < mu { 0.0 } else { 1.0 };
    }
    let standardized = (x - mu) / sigma;
    0.5 * (1.0 + erf(standardized / std::f64::consts::SQRT_2))
}
/// Error function, computed with a five-term rational polynomial
/// approximation in `t = 1 / (1 + 0.3275911 |x|)`.
///
/// The approximation is evaluated for `|x|` and the sign of `x` is restored
/// afterwards, so the result is an odd function of `x`. The Gaussian factor
/// is always `exp(-x^2)`; it must not depend on the sign of the input.
///
/// Returns exactly `0.0` for `x == 0.0`.
#[allow(dead_code)]
fn erf(x: f64) -> f64 {
    if x == 0.0 {
        return 0.0;
    }
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of the polynomial in t (including the trailing factor t).
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;

    sign * (1.0 - poly * (-x * x).exp())
}
/// Gamma function for positive arguments via the Lanczos approximation
/// (g = 7, eight series coefficients plus the constant term).
///
/// The series is written for the unshifted argument: term `i` is
/// `p[i] / (x + i)` and the power is `t^(x - 0.5)`. For that form the Lanczos
/// base must be `t = x + g - 0.5 = x + 6.5`.
///
/// Returns `f64::INFINITY` for `x <= 0.0`.
#[allow(dead_code)]
fn gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }
    let p = [
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    let series = p
        .iter()
        .enumerate()
        .fold(0.999_999_999_999_809_9, |acc, (i, &coefficient)| {
            acc + coefficient / (x + i as f64)
        });

    // p.len() == 8, so this is x + 6.5 (g = 7 with the argument left unshifted).
    let t = x + p.len() as f64 - 1.5;
    std::f64::consts::TAU.sqrt() * t.powf(x - 0.5) * (-t).exp() * series
}
#[allow(dead_code)]
fn incomplete_gamma(a: f64, x: f64) -> f64 {
if x <= 0.0 || a <= 0.0 {
return 0.0;
}
if x < a + 1.0 {
let mut result = 1.0;
let mut term = 1.0;
let mut n = 1.0;
while n < 100.0 {
term *= x / (a + n);
result += term;
if term.abs() < 1e-10 {
break;
}
n += 1.0;
}
result * x.powf(a) * (-x).exp() / gamma(a)
} else {
let mut b = x + 1.0 - a;
let mut c = 1.0 / 1e-10;
let mut d = 1.0 / b;
let mut h = d;
for i in 1..100 {
let i_f64 = i as f64;
let a_plus_i = a + i_f64 - 1.0;
b += 2.0;
d = 1.0 / (b - a_plus_i * d);
c = b - a_plus_i / c;
let del = c * d;
h *= del;
if (del - 1.0).abs() < 1e-10 {
break;
}
}
h * x.powf(a) * (-x).exp() / gamma(a)
}
}
/// Regularized incomplete beta function `I_x(a, b)`.
///
/// Computed as `x^a (1 - x)^b / (a B(a, b))` times a continued fraction, with
/// `B(a, b) = gamma(a) gamma(b) / gamma(a + b)`. The continued fraction
/// converges quickly only for `x < (a + 1) / (a + b + 2)`; above that point the
/// symmetry `I_x(a, b) = 1 - I_{1-x}(b, a)` is used instead.
///
/// Returns `0.0` for `x <= 0.0` and `1.0` for `x >= 1.0`.
///
/// NOTE(review): `B(a, b)` is formed from three `gamma` calls, so the usable
/// range of `a`, `b` is limited by that function's range.
#[allow(dead_code)]
fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    let beta_ab = gamma(a) * gamma(b) / gamma(a + b);
    // Symmetric in (a, x) <-> (b, 1 - x), so it serves both branches.
    let front = x.powf(a) * (1.0 - x).powf(b) / beta_ab;

    if x < (a + 1.0) / (a + b + 2.0) {
        front * beta_continued_fraction(a, b, x) / a
    } else {
        1.0 - front * beta_continued_fraction(b, a, 1.0 - x) / b
    }
}

/// Continued fraction factor of the incomplete beta function, evaluated with
/// the modified Lentz iteration (two partial numerators per step).
#[allow(dead_code)]
fn beta_continued_fraction(a: f64, b: f64, x: f64) -> f64 {
    const FP_MIN: f64 = 1e-30;
    const EPS: f64 = 1e-14;
    const MAX_ITER: usize = 300;

    // Keeps intermediate values away from zero so reciprocals stay finite.
    let guard = |v: f64| if v.abs() < FP_MIN { FP_MIN } else { v };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;

    let mut c = 1.0;
    let mut d = 1.0 / guard(1.0 - qab * x / qap);
    let mut h = d;

    for step in 1..=MAX_ITER {
        let m = step as f64;
        let m2 = 2.0 * m;

        // Even-indexed partial numerator.
        let even = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = 1.0 / guard(1.0 + even * d);
        c = guard(1.0 + even / c);
        h *= d * c;

        // Odd-indexed partial numerator.
        let odd = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = 1.0 / guard(1.0 + odd * d);
        c = guard(1.0 + odd / c);
        let del = d * c;
        h *= del;

        if (del - 1.0).abs() < EPS {
            break;
        }
    }
    h
}
// Smoke tests for the public statistical tests in this file. Most assertions
// only check that a statistic is non-negative and that a p-value lies in
// [0, 1]; they do not pin exact numerical values.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// McNemar's test: a clearly asymmetric table, a symmetric table with the
// continuity correction, and a table with no discordant pairs.
#[test]
fn test_mcnemars_test() {
let table = array![[50.0, 30.0], [5.0, 25.0]];
// Recompute the uncorrected statistic by hand: b = 30, c = 5 gives 625 / 35.
let b = table[[0, 1]]; let c = table[[1, 0]];
let diff = (b - c).abs(); let statistic = diff.powi(2) / (b + c); assert!(
statistic > 3.84,
"Chi-squared statistic should be above critical value 3.84 for p<0.05"
);
let p_value = mcnemars_test(&table, false).expect("Operation failed");
assert!(
(0.0..=1.0).contains(&p_value),
"p-value should be between 0 and 1, got {}",
p_value
);
// Equal discordant counts: the corrected statistic is floored at zero.
let table = array![[50.0, 15.0], [15.0, 30.0]];
let b = table[[0, 1]]; let c = table[[1, 0]];
let diff = (b - c).abs() - 1.0; let statistic = diff.max(0.0).powi(2) / (b + c); assert!(
statistic < 3.84,
"Chi-squared statistic should be below critical value 3.84 for p>0.05"
);
let p_value = mcnemars_test(&table, true).expect("Operation failed");
assert!(
(0.0..=1.0).contains(&p_value),
"p-value should be between 0 and 1, got {}",
p_value
);
// b + c == 0 takes the early-return path, which yields exactly 1.0.
let table = array![[40.0, 0.0], [0.0, 60.0]];
let p_value = mcnemars_test(&table, true).expect("Operation failed");
assert_eq!(
p_value, 1.0,
"Expected p-value of 1.0 for zero discordant pairs"
);
}
// Cochran's Q: three models (rows) evaluated on ten samples (columns).
#[test]
fn test_cochrans_q_test() {
let binary_predictions = array![
[1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0] ];
let (q_statistic, p_value) =
cochrans_q_test(&binary_predictions).expect("Operation failed");
assert!(q_statistic >= 0.0);
assert!((0.0..=1.0).contains(&p_value));
}
// Friedman test: five datasets (rows) by three models (columns).
#[test]
fn test_friedman_test() {
let performance_metrics = array![
[0.85, 0.82, 0.86], [0.72, 0.70, 0.75], [0.91, 0.89, 0.90], [0.78, 0.75, 0.80], [0.88, 0.84, 0.87] ];
let (test_statistic, p_value) =
friedman_test(&performance_metrics).expect("Operation failed");
assert!(test_statistic >= 0.0);
// The p-value is clamped before the range check, so that check can only
// reject a NaN result.
let clamped_p_value = p_value.clamp(0.0, 1.0);
assert!((0.0..=1.0).contains(&clamped_p_value));
}
// Wilcoxon signed-rank: model 1 beats model 2 on every pair, then a
// degenerate comparison of a sample with itself.
#[test]
fn test_wilcoxon_signed_rank_test() {
let model1_performance = array![0.85, 0.72, 0.91, 0.78, 0.88, 0.83, 0.76, 0.90];
let model2_performance = array![0.82, 0.70, 0.89, 0.75, 0.84, 0.81, 0.74, 0.88];
let (statistic, p_value) =
wilcoxon_signed_rank_test(&model1_performance, &model2_performance, "wilcox", true)
.expect("Operation failed");
assert!(statistic >= 0.0);
assert!((0.0..=1.0).contains(&p_value));
assert!(
p_value < 0.05,
"Expected significant result for consistent differences"
);
// All differences are zero and "wilcox" drops them, so the function
// returns its (0.0, 1.0) early result.
let identical = array![0.5, 0.6, 0.7, 0.8];
let (_, p_value) = wilcoxon_signed_rank_test(&identical, &identical, "wilcox", true)
.expect("Operation failed");
assert_eq!(
p_value, 1.0,
"Expected p-value of 1.0 for identical samples"
);
}
// Bootstrap interval: once for the mean and once for an upper-median
// statistic, both with a fixed seed.
#[test]
fn test_bootstrap_confidence_interval() {
let data = array![23.5, 24.1, 25.2, 24.7, 24.9, 25.3, 24.8, 25.1, 23.9, 24.5];
let (lower, point_estimate, upper) =
bootstrap_confidence_interval(&data, |x| x.mean_or(0.0), 0.95, 1000, Some(42))
.expect("Operation failed");
assert!(lower <= point_estimate && point_estimate <= upper);
// Every resample mean lies between the data minimum (23.5) and maximum (25.3).
assert!(lower > 23.0 && upper < 26.0);
let (lower, point_estimate, upper) = bootstrap_confidence_interval(
&data,
|x| {
// Element at index len / 2 of the sorted resample.
let mut vals: Vec<f64> = x.iter().copied().collect();
vals.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
vals[vals.len() / 2]
},
0.95,
1000,
Some(42),
)
.expect("Operation failed");
assert!(lower <= point_estimate && point_estimate <= upper);
}
}