use ndarray::{Array1, Array2, ArrayView2, Axis};
use rand::seq::IndexedRandom;
use rand::{Rng, SeedableRng};
use crate::error::{RagDriftError, Result};
use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};
/// Tunable parameters for [`QueryDriftDetector`].
///
/// `PartialEq` is derived so configurations can be compared directly, for
/// example in tests or when checking whether settings have changed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryDriftConfig {
    /// Threshold handed to `DriftScore::new` together with the computed score.
    pub threshold: f64,
    /// Number of k-means centroids fitted on the baseline embeddings.
    /// `QueryDriftDetector::detect` rejects values below 2.
    pub n_clusters: usize,
    /// Upper bound on the number of k-means refinement iterations.
    pub max_iter: usize,
    /// Seed for the random number generator used while fitting k-means.
    pub seed: u64,
}

impl Default for QueryDriftConfig {
    /// Returns the stock configuration: threshold `0.1`, 8 clusters,
    /// at most 25 iterations, seed `0`.
    fn default() -> Self {
        Self {
            threshold: 0.1,
            n_clusters: 8,
            max_iter: 25,
            seed: 0,
        }
    }
}
/// Drift detector for query embeddings, driven by a [`QueryDriftConfig`].
///
/// The derived `Default` builds a detector from `QueryDriftConfig::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct QueryDriftDetector {
/// Settings read by `detect` (cluster count, iteration cap, seed, threshold).
config: QueryDriftConfig,
}
impl QueryDriftDetector {
    /// Creates a detector that uses the given configuration.
    pub fn new(config: QueryDriftConfig) -> Self {
        Self { config }
    }

    /// Scores how far the `current` query embeddings have drifted from `baseline`.
    ///
    /// Both inputs are matrices with one sample per row. K-means centroids are
    /// fitted on `baseline`; each matrix is then reduced to the fraction of its
    /// rows nearest to each centroid, and the reported score is the symmetric KL
    /// divergence between those two distributions. The score is labelled
    /// `"kmeans-skl"` and paired with the configured threshold.
    ///
    /// # Errors
    ///
    /// Propagates any error from `check_same_cols` (column-count comparison of the
    /// two inputs) or `check_min_samples` (baseline against `n_clusters`, current
    /// against 1), and returns [`RagDriftError::InvalidConfig`] when `n_clusters`
    /// is below 2.
    pub fn detect(
        &self,
        baseline: &ArrayView2<'_, f32>,
        current: &ArrayView2<'_, f32>,
    ) -> Result<DriftScore> {
        check_same_cols(baseline, current)?;
        check_min_samples(baseline.nrows(), self.config.n_clusters)?;
        check_min_samples(current.nrows(), 1)?;
        if self.config.n_clusters < 2 {
            return Err(RagDriftError::InvalidConfig(
                "n_clusters must be >= 2".into(),
            ));
        }

        // Centroids come from the baseline only, so both distributions below are
        // measured against the same reference partition.
        let centroids = kmeans_fit(
            baseline,
            self.config.n_clusters,
            self.config.max_iter,
            self.config.seed,
        )?;
        let p = assignment_dist(baseline, &centroids);
        let q = assignment_dist(current, &centroids);
        let kl = symmetric_kl(&p, &q);

        Ok(DriftScore::new(
            DriftDimension::Query,
            kl,
            self.config.threshold,
            "kmeans-skl",
        ))
    }
}
#[allow(clippy::needless_range_loop)] fn kmeans_fit(
data: &ArrayView2<'_, f32>,
k: usize,
max_iter: usize,
seed: u64,
) -> Result<Array2<f32>> {
let n = data.nrows();
let dim = data.ncols();
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut centroids = Array2::<f32>::zeros((k, dim));
let first = rng.random_range(0..n);
centroids.row_mut(0).assign(&data.row(first));
let mut min_d2 = vec![f32::INFINITY; n];
for c in 1..k {
let last = centroids.row(c - 1);
for i in 0..n {
let mut d = 0.0_f32;
for (a, b) in data.row(i).iter().zip(last.iter()) {
let diff = a - b;
d += diff * diff;
}
if d < min_d2[i] {
min_d2[i] = d;
}
}
let total: f64 = min_d2.iter().map(|&d| d as f64).sum();
if total <= 0.0 {
let idx = rng.random_range(0..n);
centroids.row_mut(c).assign(&data.row(idx));
continue;
}
let target: f64 = rng.random_range(0.0..total);
let mut acc = 0.0_f64;
let mut chosen = n - 1;
for (i, &d) in min_d2.iter().enumerate() {
acc += d as f64;
if acc >= target {
chosen = i;
break;
}
}
centroids.row_mut(c).assign(&data.row(chosen));
}
let mut labels = vec![0_usize; n];
for _ in 0..max_iter {
let mut changed = false;
for i in 0..n {
let new = nearest_centroid(&data.row(i), ¢roids.view());
if new != labels[i] {
changed = true;
labels[i] = new;
}
}
if !changed {
break;
}
let mut new_centroids = Array2::<f32>::zeros((k, dim));
let mut counts = vec![0_usize; k];
for i in 0..n {
let label = labels[i];
let mut row = new_centroids.row_mut(label);
for (a, b) in row.iter_mut().zip(data.row(i).iter()) {
*a += *b;
}
counts[label] += 1;
}
for c in 0..k {
if counts[c] > 0 {
new_centroids
.row_mut(c)
.mapv_inplace(|x| x / counts[c] as f32);
} else {
let resample: Vec<usize> = (0..n).collect();
let idx = *resample.choose(&mut rng).unwrap();
new_centroids.row_mut(c).assign(&data.row(idx));
}
}
centroids = new_centroids;
}
Ok(centroids)
}
/// Returns the row index of the centroid closest to `point` by squared
/// Euclidean distance.
///
/// Ties go to the lowest index. Index 0 is returned when `centroids` has no
/// rows or when no distance compares below infinity (for example, NaN input).
fn nearest_centroid(
    point: &ndarray::ArrayView1<'_, f32>,
    centroids: &ndarray::ArrayView2<'_, f32>,
) -> usize {
    let (winner, _) = centroids.axis_iter(Axis(0)).enumerate().fold(
        (0_usize, f32::INFINITY),
        |(best_idx, best_dist), (idx, centroid)| {
            let dist = point
                .iter()
                .zip(centroid.iter())
                .fold(0.0_f32, |sum, (a, b)| {
                    let diff = a - b;
                    sum + diff * diff
                });
            // Strict comparison keeps the earliest centroid on ties.
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        },
    );
    winner
}
/// Computes the fraction of rows in `data` whose nearest centroid is each row
/// of `centroids`.
///
/// The result has one entry per centroid. The divisor is clamped to at least
/// 1, so an input with no rows yields all zeros rather than NaN.
fn assignment_dist(data: &ArrayView2<'_, f32>, centroids: &Array2<f32>) -> Array1<f64> {
    let k = centroids.nrows();
    let centroid_view = centroids.view();

    let mut counts = Array1::<f64>::zeros(k);
    for row in data.axis_iter(Axis(0)) {
        let c = nearest_centroid(&row, &centroid_view);
        counts[c] += 1.0;
    }

    let total = counts.sum().max(1.0);
    counts.mapv_inplace(|x| x / total);
    counts
}
/// Returns half the sum of the two directed KL divergences between `p` and `q`.
///
/// Every probability is floored at `1e-6` before the logarithm is taken so
/// empty bins do not produce infinities. Entries are paired positionally; any
/// surplus entries in the longer input are ignored.
fn symmetric_kl(p: &Array1<f64>, q: &Array1<f64>) -> f64 {
    let floor = 1e-6;
    let divergence_sum = p.iter().zip(q.iter()).fold(0.0_f64, |acc, (pi, qi)| {
        let a = pi.max(floor);
        let b = qi.max(floor);
        acc + (a * (a / b).ln() + b * (b / a).ln())
    });
    0.5 * divergence_sum
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;

    /// Comparing a sample with itself yields identical assignment
    /// distributions, so the score must be (numerically) zero.
    #[test]
    fn identical_query_embeddings_zero_drift() {
        let a = Array2::<f32>::random((128, 8), StandardNormal);
        let detector = QueryDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert!(s.score < 1e-3, "score={}", s.score);
        assert!(!s.exceeded);
    }

    /// Shifting every coordinate of the current sample by +5 should push the
    /// score past the default threshold.
    #[test]
    fn shifted_query_distribution_flagged() {
        let a = Array2::<f32>::random((128, 8), StandardNormal);
        let mut b = Array2::<f32>::random((128, 8), StandardNormal);
        b.mapv_inplace(|v| v + 5.0);
        let detector = QueryDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded, "expected drift, score={}", s.score);
    }

    /// Inputs with different column counts are rejected.
    #[test]
    fn rejects_dim_mismatch() {
        let a = Array2::<f32>::zeros((16, 4));
        let b = Array2::<f32>::zeros((16, 8));
        let detector = QueryDriftDetector::default();
        assert!(detector.detect(&a.view(), &b.view()).is_err());
    }

    /// Four baseline rows are fewer than the default eight clusters.
    #[test]
    fn rejects_too_few_baseline_samples() {
        let a = Array2::<f32>::zeros((4, 4));
        let b = Array2::<f32>::zeros((4, 4));
        let detector = QueryDriftDetector::default();
        assert!(detector.detect(&a.view(), &b.view()).is_err());
    }

    /// A configuration with fewer than two clusters is reported as
    /// `InvalidConfig` even when enough samples are supplied.
    #[test]
    fn rejects_fewer_than_two_clusters() {
        let a = Array2::<f32>::zeros((16, 4));
        let detector = QueryDriftDetector::new(QueryDriftConfig {
            n_clusters: 1,
            ..QueryDriftConfig::default()
        });
        let result = detector.detect(&a.view(), &a.view());
        assert!(matches!(result, Err(RagDriftError::InvalidConfig(_))));
    }
}