ragdrift-core 0.1.4

Pure-Rust core for ragdrift: 5-dimensional drift detection for RAG systems.
Documentation
//! Query-pattern drift detector.
//!
//! Fits k-means on the baseline query embeddings, assigns both the baseline
//! and the current queries to the resulting clusters, and reports the
//! symmetric KL divergence between the two cluster-assignment distributions.

use ndarray::{Array1, Array2, ArrayView2, Axis};
use rand::seq::IndexedRandom;
use rand::{Rng, SeedableRng};

use crate::error::{RagDriftError, Result};
use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};

/// Configuration for [`QueryDriftDetector`].
///
/// The [`Default`] implementation yields `threshold = 0.1`, `n_clusters = 8`,
/// `max_iter = 25` and `seed = 0`.
#[derive(Debug, Clone, Copy)]
pub struct QueryDriftConfig {
    /// Threshold on the symmetric KL divergence; it is handed to the
    /// resulting [`DriftScore`] together with the measured value.
    /// Default: 0.1.
    pub threshold: f64,
    /// Number of k-means clusters fitted on the baseline. Values below 2 are
    /// rejected by [`QueryDriftDetector::detect`]. Default: 8.
    pub n_clusters: usize,
    /// Upper bound on the number of Lloyd iterations. Default: 25.
    pub max_iter: usize,
    /// Seed for the RNG that drives k-means++ initialization and the
    /// re-seeding of empty clusters. Default: 0.
    pub seed: u64,
}

impl Default for QueryDriftConfig {
    fn default() -> Self {
        Self {
            threshold: 0.1,
            n_clusters: 8,
            max_iter: 25,
            seed: 0,
        }
    }
}

/// Detects shift in query distribution via k-means cluster assignment KL.
///
/// The detector stores nothing but its [`QueryDriftConfig`]; clusters are
/// fitted afresh on the baseline inside every
/// [`QueryDriftDetector::detect`] call. The derived [`Default`] uses
/// `QueryDriftConfig::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct QueryDriftDetector {
    /// Settings applied on each `detect` call.
    config: QueryDriftConfig,
}

impl QueryDriftDetector {
    /// Construct a detector from a custom config.
    pub fn new(config: QueryDriftConfig) -> Self {
        Self { config }
    }

    /// Detect drift between two `(n_samples, dim)` query embedding matrices.
    pub fn detect(
        &self,
        baseline: &ArrayView2<'_, f32>,
        current: &ArrayView2<'_, f32>,
    ) -> Result<DriftScore> {
        check_same_cols(baseline, current)?;
        check_min_samples(baseline.nrows(), self.config.n_clusters)?;
        check_min_samples(current.nrows(), 1)?;
        if self.config.n_clusters < 2 {
            return Err(RagDriftError::InvalidConfig(
                "n_clusters must be >= 2".into(),
            ));
        }

        let centroids = kmeans_fit(
            baseline,
            self.config.n_clusters,
            self.config.max_iter,
            self.config.seed,
        )?;

        let p = assignment_dist(baseline, &centroids);
        let q = assignment_dist(current, &centroids);
        let kl = symmetric_kl(&p, &q);

        Ok(DriftScore::new(
            DriftDimension::Query,
            kl,
            self.config.threshold,
            "kmeans-skl",
        ))
    }
}

/// Fit k-means with k-means++ initialization. Returns `(k, dim)` centroids.
///
/// * `data` – `(n, dim)` matrix of points to cluster.
/// * `k` – number of centroids to produce.
/// * `max_iter` – upper bound on Lloyd iterations; the loop stops earlier as
///   soon as an assignment pass changes no label.
/// * `seed` – seeds the RNG used for initialization and for re-seeding empty
///   clusters, so the result is reproducible for a given input.
///
/// Callers must pass `n >= 1` and `k >= 1`; sampling from `0..n` and writing
/// centroid row 0 panic otherwise. `QueryDriftDetector::detect` validates
/// both before calling.
///
/// If the squared-distance mass used for k-means++ sampling is zero or not
/// finite (all points already coincide with a centroid, or `data` contains
/// NaN / infinite / overflowing values), the next centroid is drawn uniformly
/// instead of panicking inside the weighted sampler.
///
/// This implementation has no error path and always returns `Ok`.
#[allow(clippy::needless_range_loop)] // index-driven assignment is clearer here
fn kmeans_fit(
    data: &ArrayView2<'_, f32>,
    k: usize,
    max_iter: usize,
    seed: u64,
) -> Result<Array2<f32>> {
    let n = data.nrows();
    let dim = data.ncols();
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

    let mut centroids = Array2::<f32>::zeros((k, dim));

    // ++ init: first centroid uniformly random.
    let first = rng.random_range(0..n);
    centroids.row_mut(0).assign(&data.row(first));

    // Squared distance from every point to its closest centroid so far.
    let mut min_d2 = vec![f32::INFINITY; n];
    for c in 1..k {
        // Only the centroid placed last can lower a point's minimum.
        let last = centroids.row(c - 1);
        for i in 0..n {
            let mut d = 0.0_f32;
            for (a, b) in data.row(i).iter().zip(last.iter()) {
                let diff = a - b;
                d += diff * diff;
            }
            if d < min_d2[i] {
                min_d2[i] = d;
            }
        }
        // Sample next centroid index ~ D^2 weighting.
        let total: f64 = min_d2.iter().map(|&d| d as f64).sum();
        if !total.is_finite() || total <= 0.0 {
            // Degenerate weights: either every point coincides with a chosen
            // centroid (zero mass) or the mass is NaN/inf, for which
            // `random_range(0.0..total)` would panic. Fall back to uniform.
            let idx = rng.random_range(0..n);
            centroids.row_mut(c).assign(&data.row(idx));
            continue;
        }
        let target: f64 = rng.random_range(0.0..total);
        let mut acc = 0.0_f64;
        // Defensive default; the scan below is expected to hit first because
        // `acc` ends at `total`, which is strictly above `target`.
        let mut chosen = n - 1;
        for (i, &d) in min_d2.iter().enumerate() {
            acc += d as f64;
            // Strict comparison so a zero-weight point (an already chosen
            // centroid) cannot be picked when `target` is exactly 0.
            if acc > target {
                chosen = i;
                break;
            }
        }
        centroids.row_mut(c).assign(&data.row(chosen));
    }

    // Lloyd iterations.
    let mut labels = vec![0_usize; n];
    for _ in 0..max_iter {
        // Assignment step: move every point to its nearest centroid.
        let mut changed = false;
        for i in 0..n {
            let new = nearest_centroid(&data.row(i), &centroids.view());
            if new != labels[i] {
                changed = true;
                labels[i] = new;
            }
        }
        if !changed {
            // Converged: the centroids already match the current labels.
            break;
        }

        // Update step: each centroid becomes the mean of its members.
        let mut new_centroids = Array2::<f32>::zeros((k, dim));
        let mut counts = vec![0_usize; k];
        for i in 0..n {
            let label = labels[i];
            let mut row = new_centroids.row_mut(label);
            for (a, b) in row.iter_mut().zip(data.row(i).iter()) {
                *a += *b;
            }
            counts[label] += 1;
        }
        for c in 0..k {
            if counts[c] > 0 {
                let members = counts[c] as f32;
                new_centroids.row_mut(c).mapv_inplace(|x| x / members);
            } else {
                // Empty cluster: re-seed at a random data point.
                let resample: Vec<usize> = (0..n).collect();
                let idx = *resample
                    .choose(&mut rng)
                    .expect("n >= 1, so the index list is non-empty");
                new_centroids.row_mut(c).assign(&data.row(idx));
            }
        }
        centroids = new_centroids;
    }
    Ok(centroids)
}

/// Index of the row of `centroids` with the smallest squared Euclidean
/// distance to `point`.
///
/// Ties go to the lowest index because only a strictly smaller distance
/// replaces the current best. Index 0 is returned when no distance compares
/// below `f32::INFINITY` (for instance when `centroids` has no rows).
fn nearest_centroid(
    point: &ndarray::ArrayView1<'_, f32>,
    centroids: &ndarray::ArrayView2<'_, f32>,
) -> usize {
    let (winner, _) = centroids.axis_iter(Axis(0)).enumerate().fold(
        (0_usize, f32::INFINITY),
        |(best_idx, best_dist), (idx, centroid)| {
            // Squared distance, accumulated left to right in f32.
            let dist = point
                .iter()
                .zip(centroid.iter())
                .fold(0.0_f32, |sum, (a, b)| {
                    let diff = a - b;
                    sum + diff * diff
                });
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        },
    );
    winner
}

/// Fraction of the rows of `data` that fall into each centroid's cluster.
///
/// Every row is assigned to its nearest centroid and the per-cluster counts
/// are divided by the total count. The divisor is clamped to at least 1, so
/// an input without rows yields an all-zero vector rather than NaNs.
fn assignment_dist(data: &ArrayView2<'_, f32>, centroids: &Array2<f32>) -> Array1<f64> {
    let centroid_view = centroids.view();
    let mut hist = Array1::<f64>::zeros(centroids.nrows());

    data.axis_iter(Axis(0))
        .map(|sample| nearest_centroid(&sample, &centroid_view))
        .for_each(|cluster| hist[cluster] += 1.0);

    let denom = hist.sum().max(1.0);
    hist.mapv_inplace(|count| count / denom);
    hist
}

/// Symmetric KL divergence `0.5 * (KL(p || q) + KL(q || p))` between two
/// discrete distributions given as equally indexed probability vectors.
///
/// Each probability is floored at `1e-6` before use so that clusters with
/// no members do not produce `ln(0)` or a division by zero. The floored
/// values are not renormalized.
fn symmetric_kl(p: &Array1<f64>, q: &Array1<f64>) -> f64 {
    const FLOOR: f64 = 1e-6;

    let sum = p.iter().zip(q.iter()).fold(0.0_f64, |acc, (&pi, &qi)| {
        let pf = pi.max(FLOOR);
        let qf = qi.max(FLOOR);
        acc + (pf * (pf / qf).ln() + qf * (qf / pf).ln())
    });

    0.5 * sum
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;

    /// Draw a `(rows, cols)` matrix of standard-normal samples.
    fn gaussian(rows: usize, cols: usize) -> Array2<f32> {
        Array2::<f32>::random((rows, cols), StandardNormal)
    }

    /// Comparing a matrix with itself must give (near) zero divergence.
    #[test]
    fn identical_query_embeddings_zero_drift() {
        let embeddings = gaussian(128, 8);
        let result = QueryDriftDetector::default()
            .detect(&embeddings.view(), &embeddings.view())
            .unwrap();
        assert!(result.score < 1e-3, "score={}", result.score);
        assert!(!result.exceeded);
    }

    /// A strongly translated current set must exceed the default threshold.
    #[test]
    fn shifted_query_distribution_flagged() {
        let baseline = gaussian(128, 8);
        // Push current embeddings far from baseline so they collapse into
        // the centroid nearest the shifted region.
        let shifted = gaussian(128, 8).mapv(|v| v + 5.0);
        let result = QueryDriftDetector::default()
            .detect(&baseline.view(), &shifted.view())
            .unwrap();
        assert!(result.exceeded, "expected drift, score={}", result.score);
    }

    /// Inputs with different column counts are rejected.
    #[test]
    fn rejects_dim_mismatch() {
        let narrow = Array2::<f32>::zeros((16, 4));
        let wide = Array2::<f32>::zeros((16, 8));
        let outcome = QueryDriftDetector::default().detect(&narrow.view(), &wide.view());
        assert!(outcome.is_err());
    }

    /// Four baseline rows cannot support the default eight clusters.
    #[test]
    fn rejects_too_few_baseline_samples() {
        let baseline = Array2::<f32>::zeros((4, 4));
        let current = Array2::<f32>::zeros((4, 4));
        let outcome = QueryDriftDetector::default().detect(&baseline.view(), &current.view());
        assert!(outcome.is_err());
    }
}