use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::ArrayView1;
use scirs2_core::numeric::{Float, NumCast};
use std::cmp::Ordering;
/// Levene's test for equality of variances across two or more samples.
///
/// Every observation is replaced by its absolute deviation from the centre of its own
/// sample, and a one-way ANOVA style statistic `W` is computed on those deviations.
///
/// # Arguments
///
/// * `samples` - Two or more non-empty one-dimensional samples.
/// * `center` - One of `"mean"`, `"median"` or `"trimmed"`. With `"trimmed"` each sample is
///   sorted and trimmed at both ends; the mean of the trimmed data is the centre and the
///   trimmed data are also the observations used for the deviations.
/// * `proportion_to_cut` - Fraction removed from each end when `center == "trimmed"`;
///   not consulted for the other centres. Values `<= 0` or `>= 0.5` disable trimming.
///
/// # Returns
///
/// `(W, p_value)`, where the p-value is the upper tail of an F distribution with
/// `(k - 1, N - k)` degrees of freedom (`k` samples, `N` observations after trimming).
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when `center` is not recognised, fewer than two
/// samples are given, any sample is empty, or no sample holds more than one observation
/// (the denominator degrees of freedom `N - k` would be zero).
#[allow(dead_code)]
pub fn levene<F>(
    samples: &[ArrayView1<F>],
    center: &str,
    proportion_to_cut: F,
) -> StatsResult<(F, F)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + std::fmt::Debug
        + std::fmt::Display,
{
    if !matches!(center, "mean" | "median" | "trimmed") {
        return Err(StatsError::InvalidArgument(format!(
            "Invalid center parameter: {}. Use 'mean', 'median', or 'trimmed'",
            center
        )));
    }
    let k = samples.len();
    if k < 2 {
        return Err(StatsError::InvalidArgument(
            "At least two samples are required for Levene's test".to_string(),
        ));
    }
    if let Some(i) = samples.iter().position(|sample| sample.is_empty()) {
        return Err(StatsError::InvalidArgument(format!(
            "Sample {} is empty",
            i
        )));
    }

    // Owned copies of the data; for the trimmed centre they are sorted and cut first.
    let samples_processed: Vec<Vec<F>> = samples
        .iter()
        .map(|sample| {
            if center == "trimmed" {
                let mut sorted_sample = sample.to_vec();
                sorted_sample.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
                trim_both(&sorted_sample, proportion_to_cut)
            } else {
                sample.to_vec()
            }
        })
        .collect();

    // With N == k the statistic is 0/0 and the F distribution has zero denominator
    // degrees of freedom; evaluating its tail would hit the gamma function at zero,
    // which panics. Reject the input up front instead.
    let total_observations: usize = samples_processed.iter().map(Vec::len).sum();
    if total_observations <= k {
        return Err(StatsError::InvalidArgument(
            "Levene's test requires at least one sample with two or more observations"
                .to_string(),
        ));
    }

    let n_i: Vec<F> = samples_processed
        .iter()
        .map(|sample| F::from(sample.len()).expect("Failed to convert to float"))
        .collect();
    let n_tot = n_i.iter().cloned().sum::<F>();

    // Absolute deviations of every observation from its own sample's centre.
    let z_ij: Vec<Vec<F>> = samples_processed
        .iter()
        .map(|sample| {
            // "mean" and "trimmed" both use the mean of the (possibly trimmed) data.
            let center_i = if center == "median" {
                calculate_median(sample)
            } else {
                calculate_mean(sample)
            };
            sample.iter().map(|&x| (x - center_i).abs()).collect()
        })
        .collect();

    // Per-sample mean deviation and the size-weighted grand mean deviation.
    let z_i: Vec<F> = z_ij
        .iter()
        .map(|deviations| calculate_mean(deviations.as_slice()))
        .collect();
    let z_bar = z_i
        .iter()
        .zip(&n_i)
        .fold(F::zero(), |acc, (&z, &n)| acc + z * n)
        / n_tot;

    let between_groups = z_i
        .iter()
        .zip(&n_i)
        .fold(F::zero(), |acc, (&z, &n)| acc + n * (z - z_bar).powi(2));
    let within_groups = z_ij.iter().zip(&z_i).fold(F::zero(), |acc, (deviations, &mean_dev)| {
        deviations
            .iter()
            .fold(acc, |inner, &z| inner + (z - mean_dev).powi(2))
    });

    let df1 = F::from(k - 1).expect("Failed to convert to float");
    let df2 = n_tot - F::from(k).expect("Failed to convert to float");

    // W = [(N - k) / (k - 1)] * between / within
    let w = (between_groups * df2) / (within_groups * df1);
    let p_value = f_distribution_sf(w, df1, df2);
    Ok((w, p_value))
}
/// Arithmetic mean of `data`: the sum of the elements divided by their count.
///
/// No emptiness check is made; an empty slice divides the empty sum by a zero count.
#[allow(dead_code)]
fn calculate_mean<F>(data: &[F]) -> F
where
    F: Float + std::iter::Sum<F> + std::fmt::Display,
{
    let total: F = data.iter().copied().sum();
    let count = F::from(data.len()).expect("Operation failed");
    total / count
}
/// Median of `data`, computed on a sorted copy so the input is left untouched.
///
/// For an odd number of elements the middle one is returned; for an even number the
/// two middle elements are averaged. Incomparable pairs (e.g. NaN) are treated as equal
/// while sorting. The slice must not be empty.
#[allow(dead_code)]
fn calculate_median<F>(data: &[F]) -> F
where
    F: Float + Copy + std::fmt::Display,
{
    let mut ordered: Vec<F> = data.to_vec();
    ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal));

    let len = ordered.len();
    let upper_mid = len / 2;
    if !len.is_multiple_of(2) {
        return ordered[upper_mid];
    }

    // Even length: average the two elements straddling the centre.
    let pair_sum = ordered[upper_mid - 1] + ordered[upper_mid];
    pair_sum / F::from(2.0).expect("Failed to convert constant to float")
}
/// Removes `floor(n * proportion)` elements from each end of an already sorted slice.
///
/// `sorteddata` is expected to be sorted by the caller; this function only slices it.
/// A `proportion` that is NaN, `<= 0`, or `>= 0.5` disables trimming and the data are
/// returned unchanged (as a new `Vec`).
#[allow(dead_code)]
fn trim_both<F>(sorteddata: &[F], proportion: F) -> Vec<F>
where
    F: Float + Copy + std::fmt::Display,
{
    let half = F::from(0.5).expect("Failed to convert constant to float");
    // NaN compares false against both bounds, so it needs its own check; without it
    // the float-to-usize conversion below would fail and panic.
    if proportion.is_nan() || proportion <= F::zero() || proportion >= half {
        return sorteddata.to_vec();
    }

    let n = sorteddata.len();
    // The guard above keeps the product finite and in [0, n / 2), so the conversion
    // to `usize` succeeds and `cut < n - cut`.
    let cut = (F::from(n).expect("Failed to convert to float") * proportion)
        .floor()
        .to_usize()
        .expect("Operation failed");

    // When `cut` is zero this is simply the whole slice.
    sorteddata[cut..n - cut].to_vec()
}
/// Upper-tail probability of an F distribution with `(df1, df2)` degrees of freedom at `f`.
///
/// The inputs are converted to `f64`, the tail is evaluated through `beta_cdf` at
/// `df2 / (df2 + df1 * f)` with shape parameters `(df2 / 2, df1 / 2)`, and the result is
/// converted back to `F`.
#[allow(dead_code)]
fn f_distribution_sf<F: Float + NumCast>(f: F, df1: F, df2: F) -> F {
    let as_f64 = |value: F| <f64 as NumCast>::from(value).expect("Operation failed");
    let (statistic, numer_df, denom_df) = (as_f64(f), as_f64(df1), as_f64(df2));

    let beta_arg = denom_df / (denom_df + numer_df * statistic);
    let tail = beta_cdf(beta_arg, denom_df / 2.0, numer_df / 2.0);
    F::from(tail).expect("Failed to convert to float")
}
/// Cumulative distribution function of a Beta(`a`, `b`) distribution at `x`.
///
/// Returns 0 for `x <= 0` and 1 for `x >= 1`. Otherwise the continued fraction is
/// evaluated directly when `x <= a / (a + b)`, and through the mirrored arguments
/// `(1 - x, b, a)` with the result subtracted from 1 when it is not.
#[allow(dead_code)]
fn beta_cdf(x: f64, a: f64, b: f64) -> f64 {
    const MAX_ITERATIONS: usize = 100;
    const TOLERANCE: f64 = 1e-10;

    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    let switch_point = a / (a + b);
    if x <= switch_point {
        let unnormalised = beta_continued_fraction(x, a, b, MAX_ITERATIONS, TOLERANCE);
        return unnormalised / beta_function(a, b);
    }

    let mirrored = beta_continued_fraction(1.0 - x, b, a, MAX_ITERATIONS, TOLERANCE);
    1.0 - mirrored / beta_function(b, a)
}
/// Unnormalised incomplete beta integral, evaluated with a continued fraction.
///
/// Returns `x^a * (1 - x)^b * h / a`, where `h` is the continued-fraction value;
/// the caller divides by the beta function to obtain a probability. Each iteration
/// applies an even and an odd step; intermediate values whose magnitude drops below
/// `eps` are replaced by `eps` to avoid dividing by (nearly) zero. Iteration stops
/// after `maxiter - 1` rounds or once the last update factor is within `eps` of 1.
#[allow(dead_code)]
fn beta_continued_fraction(x: f64, a: f64, b: f64, maxiter: usize, eps: f64) -> f64 {
    // Keeps a value usable as a divisor.
    let away_from_zero = |value: f64| if value.abs() < eps { eps } else { value };

    let sum_ab = a + b;
    let a_plus = a + 1.0;
    let a_minus = a - 1.0;

    let mut c = 1.0;
    let mut d = 1.0 / away_from_zero(1.0 - sum_ab * x / a_plus);
    let mut h = d;

    for m in 1..maxiter {
        let step = m as f64;
        let twice = (2 * m) as f64;

        // Even step of the recurrence.
        let even = step * (b - step) * x / ((a_minus + twice) * (a + twice));
        d = away_from_zero(1.0 + even * d);
        c = away_from_zero(1.0 + even / c);
        d = 1.0 / d;
        h *= d * c;

        // Odd step of the recurrence.
        let odd = -(a + step) * (sum_ab + step) * x / ((a + twice) * (a_plus + twice));
        d = away_from_zero(1.0 + odd * d);
        c = away_from_zero(1.0 + odd / c);
        d = 1.0 / d;
        let delta = d * c;
        h *= delta;

        if (delta - 1.0).abs() < eps {
            break;
        }
    }

    x.powf(a) * (1.0 - x).powf(b) * h / a
}
/// Beta function `B(a, b) = Γ(a) · Γ(b) / Γ(a + b)`.
///
/// The quotient is evaluated in log space. Forming the three gamma values directly
/// overflows `f64` once an argument exceeds roughly 171, which turns the result into
/// NaN (or zero) even though `B(a, b)` itself is representable.
///
/// # Panics
///
/// Panics if `a` or `b` is not strictly positive.
#[allow(dead_code)]
fn beta_function(a: f64, b: f64) -> f64 {
    /// Natural logarithm of Γ(x) for `x > 0` (Lanczos approximation, g = 7).
    fn ln_gamma(x: f64) -> f64 {
        const LANCZOS: [f64; 8] = [
            676.5203681218851,
            -1259.1392167224028,
            771.32342877765313,
            -176.61502916214059,
            12.507343278686905,
            -0.13857109526572012,
            9.9843695780195716e-6,
            1.5056327351493116e-7,
        ];
        if x <= 0.0 {
            panic!("Gamma function not defined for non-positive values");
        }
        if x < 0.5 {
            // Reflection formula: Γ(x) · Γ(1 - x) = π / sin(πx); the sine is positive here.
            let pi = std::f64::consts::PI;
            return (pi / (pi * x).sin()).ln() - ln_gamma(1.0 - x);
        }
        let z = x - 1.0;
        let series = LANCZOS
            .iter()
            .enumerate()
            .fold(0.99999999999980993, |acc, (i, &coeff)| {
                acc + coeff / (z + (i + 1) as f64)
            });
        let t = z + 7.5;
        // ln( sqrt(2π) · t^(z + 0.5) · e^(-t) · series )
        0.5 * std::f64::consts::TAU.ln() + (z + 0.5) * t.ln() - t + series.ln()
    }

    (ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b)).exp()
}
/// Gamma function Γ(x) for strictly positive `x`, via a Lanczos-type approximation.
///
/// Arguments below 0.5 are mapped through the reflection formula
/// `Γ(x) = π / (sin(πx) · Γ(1 - x))`; everything else is evaluated from the
/// coefficient series directly.
///
/// # Panics
///
/// Panics if `x <= 0`.
#[allow(dead_code)]
fn gamma_function(x: f64) -> f64 {
    use std::f64::consts::PI;

    const LANCZOS: [f64; 8] = [
        676.5203681218851,
        -1259.1392167224028,
        771.323428777653,
        -176.61502916214,
        12.507343278687,
        -0.1385710952657,
        9.984369578019e-6,
        1.50563273515e-7,
    ];
    const SERIES_START: f64 = 0.9999999999998;
    const SQRT_TWO_PI: f64 = 2.506628274631;

    if x <= 0.0 {
        panic!("Gamma function not defined for non-positive values");
    }
    if x < 0.5 {
        let reflected = gamma_function(1.0 - x);
        return PI / ((PI * x).sin() * reflected);
    }

    let shifted = x - 1.0;
    // Coefficient `i` (1-based) is divided by `shifted + i`.
    let series = LANCZOS
        .iter()
        .zip(1u32..)
        .fold(SERIES_START, |acc, (&coeff, offset)| {
            acc + coeff / (shifted + f64::from(offset))
        });
    let t = shifted + LANCZOS.len() as f64 - 0.5;
    SQRT_TWO_PI * t.powf(shifted + 0.5) * (-t).exp() * series
}
/// Bartlett's test for equality of variances across two or more samples.
///
/// The statistic compares the log of the pooled variance with the degrees-of-freedom
/// weighted logs of the individual sample variances, divided by Bartlett's correction
/// factor.
///
/// # Arguments
///
/// * `samples` - Two or more one-dimensional samples, each with at least two observations.
///
/// # Returns
///
/// `(statistic, p_value)`, where the p-value is the upper tail of a chi-square
/// distribution with `k - 1` degrees of freedom (`k` samples).
///
/// # Errors
///
/// Returns `StatsError::InvalidArgument` when fewer than two samples are given, any
/// sample is empty, or any sample has fewer than two observations.
#[allow(dead_code)]
pub fn bartlett<F>(samples: &[ArrayView1<F>]) -> StatsResult<(F, F)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + std::fmt::Debug
        + std::fmt::Display,
{
    let k = samples.len();
    if k < 2 {
        return Err(StatsError::InvalidArgument(
            "At least two samples are required for Bartlett's test".to_string(),
        ));
    }
    // Empty samples are reported (with their index) before the size-two requirement
    // is checked for any sample.
    for (i, sample) in samples.iter().enumerate() {
        if sample.is_empty() {
            return Err(StatsError::InvalidArgument(format!(
                "Sample {} is empty",
                i
            )));
        }
    }

    // Per-sample unbiased variance and degrees of freedom, plus the total count.
    let mut v_i = Vec::with_capacity(k);
    let mut df_i = Vec::with_capacity(k);
    let mut n_tot = F::zero();
    for sample in samples {
        let n = sample.len();
        if n < 2 {
            return Err(StatsError::InvalidArgument(
                "Each sample must have at least 2 observations".to_string(),
            ));
        }
        let n_f = F::from(n).expect("Failed to convert to float");
        let df = n_f - F::one();
        let mean = sample.iter().cloned().sum::<F>() / n_f;
        let variance = sample.iter().map(|&x| (x - mean).powi(2)).sum::<F>() / df;
        n_tot = n_tot + n_f;
        v_i.push(variance);
        df_i.push(df);
    }

    let groups_minus_one = F::from(k - 1).expect("Failed to convert to float");
    let df_tot = n_tot - F::from(k).expect("Failed to convert to float");

    // Pooled variance: degrees-of-freedom weighted average of the sample variances.
    let pooled_var = df_i
        .iter()
        .zip(v_i.iter())
        .fold(F::zero(), |acc, (&df, &v)| acc + df * v)
        / df_tot;

    // C = 1 + [sum(1 / df_i) - 1 / df_tot] / (3 (k - 1))
    let correction_factor = F::one()
        + (F::one() / (F::from(3).expect("Failed to convert constant to float") * groups_minus_one))
            * (df_i.iter().map(|&df| F::one() / df).sum::<F>() - F::one() / df_tot);

    let test_statistic = (df_tot * pooled_var.ln()
        - df_i
            .iter()
            .zip(v_i.iter())
            .map(|(&df, &v)| df * v.ln())
            .sum::<F>())
        / correction_factor;

    let p_value = chi_square_sf(test_statistic, groups_minus_one);
    Ok((test_statistic, p_value))
}
/// Upper-tail probability of a chi-square distribution with `df` degrees of freedom at `x`.
///
/// Both arguments are converted to `f64`. A non-positive `x` yields exactly one;
/// otherwise the result is `1 - chi_square_cdf(x, df)` converted back to `F`.
#[allow(dead_code)]
fn chi_square_sf<F: Float + NumCast>(x: F, df: F) -> F {
    let as_f64 = |value: F| <f64 as NumCast>::from(value).expect("Operation failed");
    let (statistic, dof) = (as_f64(x), as_f64(df));

    if statistic <= 0.0 {
        F::one()
    } else {
        let upper_tail = 1.0 - chi_square_cdf(statistic, dof);
        F::from(upper_tail).expect("Failed to convert to float")
    }
}
/// Cumulative distribution function of a chi-square distribution with `df` degrees of
/// freedom at `x`.
///
/// Returns 0 for `x <= 0`; otherwise delegates to `gamma_p(df / 2, x / 2)`.
#[allow(dead_code)]
fn chi_square_cdf(x: f64, df: f64) -> f64 {
    if x <= 0.0 {
        0.0
    } else {
        let shape = df / 2.0;
        let scaled = x / 2.0;
        gamma_p(shape, scaled)
    }
}
/// Regularised lower incomplete gamma function `P(a, x)`.
///
/// Returns 0 for `x <= 0` and short-circuits to 1 once `x > 200 * a`. In between, the
/// power series is used while `x < a + 1` and the continued fraction (which yields the
/// complementary upper part) otherwise.
///
/// The `x < a + 1` switch is applied for every `a`, including `a < 1`. The series
/// helper is capped at a fixed number of terms and converges slowly for `x` well above
/// `a + 1`, so using it there truncates the sum and under-reports `P(a, x)` (e.g. a
/// chi-square statistic of 180 with one degree of freedom).
#[allow(dead_code)]
fn gamma_p(a: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x > 200.0 * a {
        return 1.0;
    }
    if x < a + 1.0 {
        // Series gives the lower incomplete gamma integral directly.
        let series_sum = gamma_series(a, x);
        let gamma_a = gamma_function(a);
        series_sum / gamma_a
    } else {
        // Continued fraction gives the upper integral; take the complement.
        let cf = gamma_continued_fraction(a, x);
        let gamma_a = gamma_function(a);
        1.0 - cf / gamma_a
    }
}
/// Lower incomplete gamma integral γ(a, x) from its power series (not normalised).
///
/// Returns 0 for `x <= 0`. The series starts at `1 / a`; each further term is the
/// previous one times `x / (a + n)`. Summation stops after at most 99 additional terms
/// or as soon as a term falls below `1e-10` times the running sum. The sum is finally
/// scaled by `e^(-x) * x^a`.
#[allow(dead_code)]
fn gamma_series(a: f64, x: f64) -> f64 {
    const MAX_TERMS: usize = 100;
    const RELATIVE_TOLERANCE: f64 = 1e-10;

    if x <= 0.0 {
        return 0.0;
    }

    let mut current = 1.0 / a;
    let mut total = current;
    for index in 1..MAX_TERMS {
        current *= x / (a + index as f64);
        total += current;
        if current < RELATIVE_TOLERANCE * total {
            break;
        }
    }

    total * (-x).exp() * x.powf(a)
}
/// Upper incomplete gamma integral Γ(a, x) from a continued fraction (not normalised).
///
/// For `x <= 0` the complete gamma function `gamma_function(a)` is returned. Otherwise
/// the fraction is iterated for at most 99 steps, stopping early once the
/// multiplicative update is within `1e-10` of 1, and the result is scaled by
/// `e^(-x) * x^a`.
#[allow(dead_code)]
fn gamma_continued_fraction(a: f64, x: f64) -> f64 {
    const MAX_STEPS: usize = 100;
    const TOLERANCE: f64 = 1e-10;

    if x <= 0.0 {
        return gamma_function(a);
    }

    let mut partial_denominator = x + 1.0 - a;
    // `c` starts out huge so that the first `coeff / c` contribution is negligible.
    let mut c = 1.0 / 1e-30;
    let mut d = 1.0 / partial_denominator;
    let mut fraction = d;

    for step in 1..MAX_STEPS {
        let s = step as f64;
        let coeff = -s * (s - a);
        partial_denominator += 2.0;
        d = 1.0 / (partial_denominator + coeff * d);
        c = partial_denominator + coeff / c;
        let delta = c * d;
        fraction *= delta;
        if (delta - 1.0).abs() < TOLERANCE {
            break;
        }
    }

    fraction * (-x).exp() * x.powf(a)
}
/// Brown–Forsythe test for equality of variances.
///
/// A thin wrapper that runs `levene` with the `"median"` centre and returns its
/// `(statistic, p_value)` pair (or its error) unchanged.
#[allow(dead_code)]
pub fn brown_forsythe<F>(samples: &[ArrayView1<F>]) -> StatsResult<(F, F)>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + NumCast
        + std::fmt::Debug
        + std::fmt::Display,
{
    // `levene` only consults the trimming proportion for the "trimmed" centre, so this
    // value merely satisfies the signature.
    let trim_fraction = F::from(0.05).expect("Failed to convert constant to float");
    levene(samples, "median", trim_fraction)
}