use ndarray::ArrayView1;
use crate::error::{RagDriftError, Result};
use crate::types::check_min_samples;
/// Outcome of a two-sample Kolmogorov–Smirnov comparison (see `ks_two_sample`).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct KsResult {
    /// The KS statistic `D`: the largest absolute gap between the two empirical CDFs.
    pub statistic: f64,
    /// Approximate p-value from the asymptotic Kolmogorov distribution, kept within `[0, 1]`.
    pub p_value: f64,
}
/// Two-sample Kolmogorov–Smirnov test comparing `baseline` against `current`.
///
/// The statistic `D` is the largest absolute difference between the empirical
/// CDFs of the two samples. The p-value is the asymptotic Kolmogorov survival
/// function evaluated at `λ = (√nₑ + 0.12 + 0.11/√nₑ) · D`, where
/// `nₑ = n·m / (n + m)`. It is an approximation and is least reliable for very
/// small samples.
///
/// # Errors
///
/// * Propagates the error from `check_min_samples` when either input is too
///   short. The required minimum passed here is 2 values.
/// * Returns [`RagDriftError::NumericalInstability`] with step `"ks"` when any
///   value in either input is NaN or infinite.
pub fn ks_two_sample(
    baseline: &ArrayView1<'_, f64>,
    current: &ArrayView1<'_, f64>,
) -> Result<KsResult> {
    check_min_samples(baseline.len(), 2)?;
    check_min_samples(current.len(), 2)?;

    // Reject bad data before paying for the copies and sorts below.
    if baseline.iter().chain(current.iter()).any(|x| !x.is_finite()) {
        return Err(RagDriftError::NumericalInstability {
            step: "ks".into(),
            reason: "non-finite input".into(),
        });
    }

    let mut a: Vec<f64> = baseline.iter().copied().collect();
    let mut b: Vec<f64> = current.iter().copied().collect();
    // `total_cmp` is a total order, so this cannot panic. For the finite values
    // guaranteed above it agrees with `<=`, which is all the sweep relies on.
    a.sort_unstable_by(f64::total_cmp);
    b.sort_unstable_by(f64::total_cmp);

    let d = max_ecdf_gap(&a, &b);

    let n = a.len() as f64;
    let m = b.len() as f64;
    let en = (n * m / (n + m)).sqrt();
    // Finite-sample correction to the argument of the asymptotic distribution.
    let lambda = (en + 0.12 + 0.11 / en) * d;

    Ok(KsResult {
        statistic: d,
        p_value: kolmogorov_q(lambda),
    })
}

/// Largest absolute difference between the empirical CDFs of two samples that
/// are each sorted in ascending order. Returns `0.0` if either slice is empty.
fn max_ecdf_gap(a: &[f64], b: &[f64]) -> f64 {
    let n = a.len() as f64;
    let m = b.len() as f64;
    let (mut i, mut j) = (0usize, 0usize);
    let mut d = 0.0_f64;
    // Once one sample is exhausted its ECDF is 1 and the other ECDF can only
    // rise toward 1, so the gap cannot grow any further.
    while i < a.len() && j < b.len() {
        // Advance both ECDFs past every value equal to the next breakpoint.
        // Ties within and across samples therefore step together, which gives
        // the correct value of D when the samples share points.
        let x = a[i].min(b[j]);
        while i < a.len() && a[i] <= x {
            i += 1;
        }
        while j < b.len() && b[j] <= x {
            j += 1;
        }
        let gap = (i as f64 / n - j as f64 / m).abs();
        d = d.max(gap);
    }
    d
}
/// Survival function of the Kolmogorov distribution:
/// `Q(λ) = 2 · Σ_{k≥1} (-1)^(k-1) · exp(-2 k² λ²)`.
///
/// Returns exactly `1.0` when `λ < 0.18`. Otherwise it sums at most 100 terms of
/// the alternating series. It stops as soon as adding a term moves the partial
/// sum by less than `1e-12`, then clamps `2 · sum` into `[0, 1]`.
fn kolmogorov_q(lambda: f64) -> f64 {
    // For λ this small, Q(λ) is within about 1e-15 of 1, so skip the series.
    if lambda < 0.18 {
        return 1.0;
    }

    let mut partial = 0.0_f64;
    let mut sign = 1.0_f64;
    let mut k = 1u32;
    while k <= 100 {
        let before = partial;
        partial += sign * (-2.0 * f64::from(k).powi(2) * lambda * lambda).exp();
        // The terms shrink like exp(-k²), so a negligible step means the rest is negligible.
        if (partial - before).abs() < 1e-12 {
            break;
        }
        sign = -sign;
        k += 1;
    }

    // Clamp so truncation or rounding can never produce an out-of-range probability.
    (2.0 * partial).clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array1;

    #[test]
    fn identical_samples_zero_d() {
        let a = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let r = ks_two_sample(&a.view(), &a.view()).unwrap();
        assert_eq!(r.statistic, 0.0);
        assert!(r.p_value > 0.99);
    }

    #[test]
    fn fully_disjoint_samples_d_is_one() {
        let a = Array1::from(vec![0.0, 0.1, 0.2, 0.3]);
        let b = Array1::from(vec![10.0, 10.1, 10.2, 10.3]);
        let r = ks_two_sample(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(r.statistic, 1.0, epsilon = 1e-12);
        assert!(r.p_value < 0.05);
    }

    #[test]
    fn shifted_uniform() {
        let a = Array1::from(vec![0.0, 1.0, 2.0, 3.0]);
        let b = Array1::from(vec![0.5, 1.5, 2.5, 3.5]);
        let r = ks_two_sample(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(r.statistic, 0.25, epsilon = 1e-12);
    }

    #[test]
    fn input_order_does_not_matter() {
        // Same data as `shifted_uniform`, but unsorted, to exercise the sort.
        let a = Array1::from(vec![3.0, 0.0, 2.0, 1.0]);
        let b = Array1::from(vec![2.5, 0.5, 3.5, 1.5]);
        let r = ks_two_sample(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(r.statistic, 0.25, epsilon = 1e-12);
    }

    #[test]
    fn ties_across_samples_step_together() {
        // Both ECDFs jump at 1.0 and at 2.0. The gap is 0.5 - 0.25 after 1.0
        // and 0 after 2.0. Stepping tied values one at a time would
        // overstate D.
        let a = Array1::from(vec![1.0, 1.0, 2.0, 2.0]);
        let b = Array1::from(vec![1.0, 2.0, 2.0, 2.0]);
        let r = ks_two_sample(&a.view(), &b.view()).unwrap();
        assert_abs_diff_eq!(r.statistic, 0.25, epsilon = 1e-12);
    }

    #[test]
    fn result_is_symmetric_in_arguments() {
        let a = Array1::from(vec![0.3, 1.7, 2.2, 5.0, 5.0]);
        let b = Array1::from(vec![1.0, 1.5, 4.0]);
        let ab = ks_two_sample(&a.view(), &b.view()).unwrap();
        let ba = ks_two_sample(&b.view(), &a.view()).unwrap();
        assert_eq!(ab, ba);
    }

    #[test]
    fn rejects_too_few_samples() {
        let a = Array1::from(vec![1.0]);
        let b = Array1::from(vec![1.0, 2.0]);
        assert!(ks_two_sample(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn rejects_non_finite() {
        let a = Array1::from(vec![1.0, f64::NAN, 3.0]);
        let b = Array1::from(vec![1.0, 2.0, 3.0]);
        assert!(ks_two_sample(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn rejects_infinite_in_current() {
        let a = Array1::from(vec![1.0, 2.0, 3.0]);
        let b = Array1::from(vec![1.0, f64::INFINITY, 3.0]);
        assert!(ks_two_sample(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn kolmogorov_q_bounds() {
        assert_eq!(kolmogorov_q(0.0), 1.0);
        assert!(kolmogorov_q(10.0) < 1e-30);
        assert!(kolmogorov_q(2.0) > 0.0 && kolmogorov_q(2.0) < 1.0);
    }

    #[test]
    fn kolmogorov_q_known_values() {
        // Q(1) = 2(e^-2 - e^-8 + e^-18 - ...) ≈ 0.26999967
        assert_abs_diff_eq!(kolmogorov_q(1.0), 0.269_999_67, epsilon = 1e-7);
        // Q is non-increasing in λ.
        assert!(kolmogorov_q(0.5) > kolmogorov_q(1.0));
        assert!(kolmogorov_q(1.0) > kolmogorov_q(1.5));
    }
}