use ndarray::ArrayView1;
use crate::error::{RagDriftError, Result};
use crate::types::check_min_samples;
/// How [`psi`] places its bin edges; each variant carries the requested bin count.
///
/// Both strategies derive their edges from the *baseline* sample only, and
/// both use `-inf` / `+inf` as the outermost edges.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PsiBinning {
    /// Interior edges at evenly spaced quantiles of the baseline, linearly
    /// interpolated between neighbouring order statistics.
    Quantile(usize),
    /// Interior edges evenly spaced between the baseline minimum and maximum.
    EqualWidth(usize),
}

/// The default is ten quantile bins.
impl Default for PsiBinning {
    fn default() -> Self {
        PsiBinning::Quantile(10)
    }
}
/// Population Stability Index of `current` relative to `baseline`.
///
/// Bin edges are derived from `baseline` according to `binning`; the outer
/// edges are `-inf` / `+inf`, and duplicate interior edges are collapsed.
/// Per-bin proportions `p` (baseline) and `q` (current) are floored at `1e-6`
/// before summing `(p - q) * ln(p / q)`. The result is clamped to be >= 0.
/// Returns `0.0` when fewer than two bins survive edge deduplication.
///
/// # Errors
/// * [`RagDriftError::InvalidConfig`] if fewer than 2 bins are requested.
/// * Whatever `check_min_samples` returns when `baseline` has fewer samples
///   than bins or `current` has fewer than 2 (NOTE(review): confirm the
///   exact variant in `crate::types`).
/// * [`RagDriftError::NumericalInstability`] if any input value is
///   non-finite, or the accumulated PSI is non-finite.
pub fn psi(
    baseline: &ArrayView1<'_, f64>,
    current: &ArrayView1<'_, f64>,
    binning: PsiBinning,
) -> Result<f64> {
    /// Fraction of `values` falling into each half-open bin `[edges[i], edges[i + 1])`.
    fn bin_fractions(edges: &[f64], values: &ArrayView1<'_, f64>) -> Vec<f64> {
        let mut hits = vec![0.0; edges.len() - 1];
        for &v in values.iter() {
            hits[bin_for(edges, v)] += 1.0;
        }
        let population = values.len() as f64;
        hits.into_iter().map(|h| h / population).collect()
    }

    // Floor applied to empty-bin proportions so `ln(p / q)` stays finite.
    const FLOOR: f64 = 1e-6;

    let (PsiBinning::Quantile(n_bins) | PsiBinning::EqualWidth(n_bins)) = binning;
    if n_bins < 2 {
        return Err(RagDriftError::InvalidConfig(format!(
            "PSI needs at least 2 bins, got {n_bins}"
        )));
    }
    check_min_samples(baseline.len(), n_bins)?;
    check_min_samples(current.len(), 2)?;

    let every_value_finite = baseline.iter().chain(current.iter()).all(|v| v.is_finite());
    if !every_value_finite {
        return Err(RagDriftError::NumericalInstability {
            step: "psi".into(),
            reason: "non-finite input".into(),
        });
    }

    let raw_edges = match binning {
        PsiBinning::Quantile(_) => quantile_edges(baseline, n_bins),
        PsiBinning::EqualWidth(_) => equal_width_edges(baseline, n_bins),
    };
    let edges = enforce_strictly_increasing(&raw_edges);
    // Fewer than three edges means fewer than two bins: nothing to compare.
    if edges.len() < 3 {
        return Ok(0.0);
    }

    let expected = bin_fractions(&edges, baseline);
    let observed = bin_fractions(&edges, current);
    let index = expected
        .iter()
        .zip(&observed)
        .fold(0.0_f64, |acc, (&p_raw, &q_raw)| {
            let p = p_raw.max(FLOOR);
            let q = q_raw.max(FLOOR);
            acc + (p - q) * (p / q).ln()
        });

    if index.is_finite() {
        Ok(index.max(0.0))
    } else {
        Err(RagDriftError::NumericalInstability {
            step: "psi".into(),
            reason: "non-finite PSI".into(),
        })
    }
}
/// Quantile bin edges of `values`: `-inf`, the `n_bins - 1` interior
/// quantiles at `i / n_bins` (for `i` in `1..n_bins`), then `+inf`.
///
/// Each interior quantile is linearly interpolated between the two nearest
/// order statistics. Edges may repeat when `values` has ties; the caller
/// collapses duplicates. Accepts any iterator over `&f64` (e.g. `&ArrayView1`
/// or a slice).
///
/// Sorting uses `f64::total_cmp`, so a stray NaN no longer panics. Callers
/// are expected to pass finite values; with NaNs the edges are unspecified.
/// An empty input yields just `[-inf, +inf]` instead of underflowing
/// `len - 1`.
fn quantile_edges<'a>(values: impl IntoIterator<Item = &'a f64>, n_bins: usize) -> Vec<f64> {
    let mut sorted: Vec<f64> = values.into_iter().copied().collect();
    if sorted.is_empty() {
        return vec![f64::NEG_INFINITY, f64::INFINITY];
    }
    sorted.sort_unstable_by(f64::total_cmp);
    let last_index = (sorted.len() - 1) as f64;

    let mut edges = Vec::with_capacity(n_bins + 1);
    edges.push(f64::NEG_INFINITY);
    for i in 1..n_bins {
        let q = i as f64 / n_bins as f64;
        // q < 1, so `pos` never exceeds the last valid index.
        let pos = q * last_index;
        let lo = pos.floor() as usize;
        let hi = pos.ceil() as usize;
        let frac = pos - lo as f64;
        edges.push(sorted[lo] * (1.0 - frac) + sorted[hi] * frac);
    }
    edges.push(f64::INFINITY);
    edges
}
/// Equal-width bin edges of `values`: `-inf`, `n_bins - 1` interior edges
/// evenly splitting `[min, max]`, then `+inf`.
///
/// A constant input produces repeated interior edges, which the caller
/// collapses. An input with no non-NaN values (e.g. empty) yields just
/// `[-inf, +inf]`. Accepts any iterator over `&f64` (e.g. `&ArrayView1` or a
/// slice).
fn equal_width_edges<'a>(values: impl IntoIterator<Item = &'a f64>, n_bins: usize) -> Vec<f64> {
    let (min, max) = values.into_iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), &x| (if x < lo { x } else { lo }, if x > hi { x } else { hi }),
    );
    if min > max {
        // Nothing observed: without this guard the edges would be NaN.
        return vec![f64::NEG_INFINITY, f64::INFINITY];
    }

    let span = max - min;
    let n = n_bins as f64;
    let mut edges = Vec::with_capacity(n_bins + 1);
    edges.push(f64::NEG_INFINITY);
    if span.is_finite() {
        let step = span / n;
        edges.extend((1..n_bins).map(|i| min + step * i as f64));
    } else {
        // `max - min` overflowed (values near +/-f64::MAX). `min + step * i`
        // would make every interior edge +inf and dump all samples into bin 0.
        // A weighted lerp never forms the span, so it stays finite.
        edges.extend((1..n_bins).map(|i| {
            let t = i as f64 / n;
            min * (1.0 - t) + max * t
        }));
    }
    edges.push(f64::INFINITY);
    edges
}
/// Drops every edge that is not strictly greater than the last edge kept.
///
/// The first edge is always kept. A boundary `+inf` is kept whenever the
/// edge before it is finite. Repeated infinities are now removed as well, so
/// `[-inf, inf, inf]` becomes `[-inf, inf]`. The previous finite-only
/// exemption let such duplicates through and broke the strictly-increasing
/// contract.
fn enforce_strictly_increasing(edges: &[f64]) -> Vec<f64> {
    let mut out: Vec<f64> = Vec::with_capacity(edges.len());
    for &edge in edges {
        if matches!(out.last(), Some(&last) if edge <= last) {
            continue;
        }
        out.push(edge);
    }
    out
}
/// Binary-searches `edges` for the bin holding `x`.
///
/// Returns `i` with `edges[i] <= x < edges[i + 1]`, assuming `edges` is
/// sorted ascending with at least two entries. The result is always in
/// `0..=edges.len() - 2`. A NaN `x` compares false against every edge and
/// lands in the last bin.
fn bin_for(edges: &[f64], x: f64) -> usize {
    let (mut left, mut right) = (0usize, edges.len() - 1);
    // Invariant: the answer lies in [left, right).
    while right - left > 1 {
        let mid = left + (right - left) / 2;
        if x < edges[mid] {
            right = mid;
        } else {
            left = mid;
        }
    }
    left
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array1;
    #[test]
    fn identical_samples_zero_psi() {
        let a = Array1::from((0..100).map(|x| x as f64).collect::<Vec<_>>());
        let v = psi(&a.view(), &a.view(), PsiBinning::Quantile(10)).unwrap();
        assert!(v < 1e-3, "expected near-zero PSI, got {v}");
    }
    #[test]
    fn shifted_distribution_psi_above_threshold() {
        let a: Vec<f64> = (0..1000).map(|x| x as f64 / 1000.0).collect();
        let b: Vec<f64> = (0..1000).map(|x| (x as f64 / 1000.0) + 0.5).collect();
        let v = psi(
            &Array1::from(a).view(),
            &Array1::from(b).view(),
            PsiBinning::Quantile(10),
        )
        .unwrap();
        assert!(
            v > 0.25,
            "shifted distribution should produce significant PSI, got {v}"
        );
    }
    /// The equal-width strategy must also flag a half-range shift.
    #[test]
    fn shifted_distribution_equal_width_psi_above_threshold() {
        let a: Vec<f64> = (0..1000).map(|x| x as f64 / 1000.0).collect();
        let b: Vec<f64> = (0..1000).map(|x| (x as f64 / 1000.0) + 0.5).collect();
        let v = psi(
            &Array1::from(a).view(),
            &Array1::from(b).view(),
            PsiBinning::EqualWidth(10),
        )
        .unwrap();
        assert!(
            v > 0.25,
            "shifted distribution should produce significant PSI, got {v}"
        );
    }
    #[test]
    fn equal_width_binning_works() {
        let a: Vec<f64> = (0..100).map(|x| x as f64).collect();
        let b: Vec<f64> = a.clone();
        let v = psi(
            &Array1::from(a).view(),
            &Array1::from(b).view(),
            PsiBinning::EqualWidth(10),
        )
        .unwrap();
        assert_abs_diff_eq!(v, 0.0, epsilon = 1e-3);
    }
    #[test]
    fn rejects_too_few_bins() {
        let a = Array1::from(vec![1.0, 2.0, 3.0]);
        assert!(psi(&a.view(), &a.view(), PsiBinning::Quantile(1)).is_err());
    }
    /// The bin-count check applies to the equal-width strategy too, and
    /// reports an `InvalidConfig` error.
    #[test]
    fn rejects_too_few_equal_width_bins() {
        let a = Array1::from(vec![1.0, 2.0, 3.0]);
        let err = psi(&a.view(), &a.view(), PsiBinning::EqualWidth(1)).unwrap_err();
        assert!(matches!(err, RagDriftError::InvalidConfig(_)));
    }
    #[test]
    fn rejects_too_few_samples() {
        let a = Array1::from(vec![1.0, 2.0]);
        assert!(psi(&a.view(), &a.view(), PsiBinning::Quantile(10)).is_err());
    }
    /// A NaN in the baseline or an infinity in the current sample must be
    /// rejected with `NumericalInstability`, not silently binned.
    #[test]
    fn rejects_non_finite_input() {
        let values: Vec<f64> = (0..100).map(|x| x as f64).collect();
        let clean = Array1::from(values.clone());
        let mut with_nan = values.clone();
        with_nan[17] = f64::NAN;
        let with_nan = Array1::from(with_nan);
        let mut with_inf = values;
        with_inf[0] = f64::INFINITY;
        let with_inf = Array1::from(with_inf);

        let err = psi(&with_nan.view(), &clean.view(), PsiBinning::Quantile(10)).unwrap_err();
        assert!(matches!(err, RagDriftError::NumericalInstability { .. }));
        let err = psi(&clean.view(), &with_inf.view(), PsiBinning::EqualWidth(10)).unwrap_err();
        assert!(matches!(err, RagDriftError::NumericalInstability { .. }));
    }
    #[test]
    fn degenerate_baseline_returns_zero() {
        let a = Array1::from(vec![5.0; 100]);
        let b = Array1::from(vec![5.0; 100]);
        let v = psi(&a.view(), &b.view(), PsiBinning::Quantile(10)).unwrap();
        assert_eq!(v, 0.0);
    }
    #[test]
    fn default_binning_is_ten_quantile_bins() {
        assert_eq!(PsiBinning::default(), PsiBinning::Quantile(10));
    }
}