use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::thread_rng;
use statrs::distribution::{Normal, ContinuousCDF};
use crate::{CDF, Statistic, TestResult};
/// One-sample Kolmogorov–Smirnov goodness-of-fit test against the standard
/// normal distribution N(0, 1).
///
/// This is a stateless, zero-sized marker; the test is run through its
/// `Statistic::compute` implementation.
#[derive(Clone, Copy, Debug)]
pub struct KSTest;
/// Survival function of the Kolmogorov distribution, `Q(λ) = P(K > λ)`, where `K`
/// is the limiting distribution of `sqrt(n) * D_n` for the one-sample KS statistic.
///
/// Two exact series representations are used so that each converges in a handful
/// of terms on its side of the cut-over:
/// * `λ < 1`:  `P(K ≤ λ) = sqrt(2π)/λ · Σ_{k≥1} exp(-(2k-1)²π² / (8λ²))`
/// * `λ ≥ 1`:  `Q(λ) = 2 · Σ_{k≥1} (-1)^(k-1) exp(-2k²λ²)`
///
/// (The alternating form alone converges far too slowly for small `λ`.)
/// Returns `1.0` for `λ <= 0` or NaN; the result is clamped to `[0, 1]`.
fn kolmogorov_sf(lambda: f64) -> f64 {
    const CUTOVER: f64 = 1.0;
    const MAX_TERMS: u32 = 100;

    if lambda.is_nan() || lambda <= 0.0 {
        return 1.0;
    }

    if lambda < CUTOVER {
        let pi = std::f64::consts::PI;
        let scale = -pi * pi / (8.0 * lambda * lambda);
        let mut sum = 0.0_f64;
        for k in 1..=MAX_TERMS {
            let odd = f64::from(2 * k - 1);
            let term = (odd * odd * scale).exp();
            sum += term;
            // Terms shrink super-exponentially; stop once they no longer affect the sum.
            if term <= f64::EPSILON * sum {
                break;
            }
        }
        if sum == 0.0 {
            // Every term underflowed: the CDF is indistinguishable from 0.
            return 1.0;
        }
        let cdf = (2.0 * pi).sqrt() / lambda * sum;
        (1.0 - cdf).clamp(0.0, 1.0)
    } else {
        let lambda_sq = lambda * lambda;
        let mut sum = 0.0_f64;
        for k in 1..=MAX_TERMS {
            let kf = f64::from(k);
            let term = (-2.0 * kf * kf * lambda_sq).exp();
            sum += if k % 2 == 1 { term } else { -term };
            // Partial sums stay positive (decreasing alternating terms), so this is a
            // relative-tolerance stop; it also fires when `term` underflows to 0.
            if term <= f64::EPSILON * sum {
                break;
            }
        }
        (2.0 * sum).clamp(0.0, 1.0)
    }
}

impl<D, F> Statistic<D, TestResult<F>> for KSTest
where
    D: AsRef<[F]> + Clone,
    F: Float + FromPrimitive + ToPrimitive + Copy,
{
    /// One-sample, two-sided Kolmogorov–Smirnov test of `data` against the
    /// standard normal distribution N(0, 1).
    ///
    /// Returns `observed_statistic = D = max(D+, D-)` and the asymptotic p-value
    /// `Q(sqrt(n) * D)` from the limiting Kolmogorov distribution. An empty sample
    /// yields `D = 0`, `p = 1`; `D >= 1` yields `p = 0`.
    ///
    /// # Panics
    /// Panics if the sample size, an index, a sample value, the statistic, or the
    /// p-value cannot be converted between `F` and `f64`.
    fn compute(&self, data: &D) -> TestResult<F> {
        let n = data.as_ref().len();
        if n == 0 {
            return TestResult {
                observed_statistic: F::zero(),
                p_value: F::one(),
            };
        }

        let n_f = F::from(n).expect("sample size fits in float");
        let normal = Normal::new(0.0, 1.0).expect("Valid standard normal distribution");

        // NOTE(review): the order-statistic formulas below assume `points()` yields
        // every sample value in ascending order, duplicates included — confirm
        // against the `CDF` implementation.
        let ecdf = CDF.compute(data);
        let mut d_plus_max = F::zero();
        let mut d_minus_max = F::zero();
        for (i, &x) in ecdf.points().iter().enumerate() {
            let i_f = F::from(i).expect("index fits in float");
            let i1_f = F::from(i + 1).expect("index+1 fits in float");
            let x_f64 = x.to_f64().expect("value convertible to f64");
            let f_x = F::from(normal.cdf(x_f64)).expect("CDF value convertible to float");
            // D+ : how far the empirical step (i+1)/n rises above F(x_(i)).
            // D- : how far F(x_(i)) rises above the previous step i/n.
            // `Float::max` keeps the running maximum when a candidate is NaN.
            d_plus_max = d_plus_max.max(i1_f / n_f - f_x);
            d_minus_max = d_minus_max.max(f_x - i_f / n_f);
        }
        let d_max = d_plus_max.max(d_minus_max);

        let d_f64 = d_max.to_f64().expect("statistic convertible to f64");
        let p_value_f64 = if d_f64 <= 0.0 {
            1.0
        } else if d_f64 >= 1.0 {
            0.0
        } else {
            kolmogorov_sf((n as f64).sqrt() * d_f64)
        };
        let p_value = F::from(p_value_f64).expect("p-value convertible to float");

        TestResult {
            observed_statistic: d_max,
            p_value,
        }
    }
}