Skip to main content

ragdrift_core/detectors/
query.rs

1//! Query-pattern drift detector.
2//!
3//! Fits k-means on baseline query embeddings, assigns both baseline and
4//! current to the resulting clusters, and reports the symmetric KL
5//! divergence between the two assignment distributions.
6
7use ndarray::{Array1, Array2, ArrayView2, Axis};
8use rand::seq::IndexedRandom;
9use rand::{Rng, SeedableRng};
10
11use crate::error::{RagDriftError, Result};
12use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};
13
/// Tunable parameters consumed by [`QueryDriftDetector`].
#[derive(Clone, Copy, Debug)]
pub struct QueryDriftConfig {
    /// Cut-off that the symmetric KL score is compared against.
    pub threshold: f64,
    /// How many k-means clusters to fit on the baseline. Default: 8.
    pub n_clusters: usize,
    /// Upper bound on the number of Lloyd refinement passes. Default: 25.
    pub max_iter: usize,
    /// Seed for the random generator that drives k-means++ initialization.
    pub seed: u64,
}
26
27impl Default for QueryDriftConfig {
28    fn default() -> Self {
29        Self {
30            threshold: 0.1,
31            n_clusters: 8,
32            max_iter: 25,
33            seed: 0,
34        }
35    }
36}
37
38/// Detects shift in query distribution via k-means cluster assignment KL.
39#[derive(Debug, Clone, Copy, Default)]
40pub struct QueryDriftDetector {
41    config: QueryDriftConfig,
42}
43
44impl QueryDriftDetector {
45    /// Construct a detector from a custom config.
46    pub fn new(config: QueryDriftConfig) -> Self {
47        Self { config }
48    }
49
50    /// Detect drift between two `(n_samples, dim)` query embedding matrices.
51    pub fn detect(
52        &self,
53        baseline: &ArrayView2<'_, f32>,
54        current: &ArrayView2<'_, f32>,
55    ) -> Result<DriftScore> {
56        check_same_cols(baseline, current)?;
57        check_min_samples(baseline.nrows(), self.config.n_clusters)?;
58        check_min_samples(current.nrows(), 1)?;
59        if self.config.n_clusters < 2 {
60            return Err(RagDriftError::InvalidConfig(
61                "n_clusters must be >= 2".into(),
62            ));
63        }
64
65        let centroids = kmeans_fit(
66            baseline,
67            self.config.n_clusters,
68            self.config.max_iter,
69            self.config.seed,
70        )?;
71
72        let p = assignment_dist(baseline, &centroids);
73        let q = assignment_dist(current, &centroids);
74        let kl = symmetric_kl(&p, &q);
75
76        Ok(DriftScore::new(
77            DriftDimension::Query,
78            kl,
79            self.config.threshold,
80            "kmeans-skl",
81        ))
82    }
83}
84
85/// Fit k-means with k-means++ initialization. Returns `(k, dim)` centroids.
86#[allow(clippy::needless_range_loop)] // index-driven assignment is clearer here
87fn kmeans_fit(
88    data: &ArrayView2<'_, f32>,
89    k: usize,
90    max_iter: usize,
91    seed: u64,
92) -> Result<Array2<f32>> {
93    let n = data.nrows();
94    let dim = data.ncols();
95    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
96
97    let mut centroids = Array2::<f32>::zeros((k, dim));
98
99    // ++ init: first centroid uniformly random.
100    let first = rng.random_range(0..n);
101    centroids.row_mut(0).assign(&data.row(first));
102
103    let mut min_d2 = vec![f32::INFINITY; n];
104    for c in 1..k {
105        // Update min distances against the centroid placed last.
106        let last = centroids.row(c - 1);
107        for i in 0..n {
108            let mut d = 0.0_f32;
109            for (a, b) in data.row(i).iter().zip(last.iter()) {
110                let diff = a - b;
111                d += diff * diff;
112            }
113            if d < min_d2[i] {
114                min_d2[i] = d;
115            }
116        }
117        // Sample next centroid index ~ D^2 weighting.
118        let total: f64 = min_d2.iter().map(|&d| d as f64).sum();
119        if total <= 0.0 {
120            // All points coincide with chosen centroids; fall back to random.
121            let idx = rng.random_range(0..n);
122            centroids.row_mut(c).assign(&data.row(idx));
123            continue;
124        }
125        let target: f64 = rng.random_range(0.0..total);
126        let mut acc = 0.0_f64;
127        let mut chosen = n - 1;
128        for (i, &d) in min_d2.iter().enumerate() {
129            acc += d as f64;
130            if acc >= target {
131                chosen = i;
132                break;
133            }
134        }
135        centroids.row_mut(c).assign(&data.row(chosen));
136    }
137
138    // Lloyd iterations.
139    let mut labels = vec![0_usize; n];
140    for _ in 0..max_iter {
141        let mut changed = false;
142        for i in 0..n {
143            let new = nearest_centroid(&data.row(i), &centroids.view());
144            if new != labels[i] {
145                changed = true;
146                labels[i] = new;
147            }
148        }
149        if !changed {
150            break;
151        }
152        let mut new_centroids = Array2::<f32>::zeros((k, dim));
153        let mut counts = vec![0_usize; k];
154        for i in 0..n {
155            let label = labels[i];
156            let mut row = new_centroids.row_mut(label);
157            for (a, b) in row.iter_mut().zip(data.row(i).iter()) {
158                *a += *b;
159            }
160            counts[label] += 1;
161        }
162        for c in 0..k {
163            if counts[c] > 0 {
164                new_centroids
165                    .row_mut(c)
166                    .mapv_inplace(|x| x / counts[c] as f32);
167            } else {
168                // Empty cluster: re-seed at a random data point.
169                let resample: Vec<usize> = (0..n).collect();
170                let idx = *resample.choose(&mut rng).unwrap();
171                new_centroids.row_mut(c).assign(&data.row(idx));
172            }
173        }
174        centroids = new_centroids;
175    }
176    Ok(centroids)
177}
178
179fn nearest_centroid(
180    point: &ndarray::ArrayView1<'_, f32>,
181    centroids: &ndarray::ArrayView2<'_, f32>,
182) -> usize {
183    let mut best = 0_usize;
184    let mut best_d = f32::INFINITY;
185    for (c, centroid) in centroids.axis_iter(Axis(0)).enumerate() {
186        let mut d = 0.0_f32;
187        for (a, b) in point.iter().zip(centroid.iter()) {
188            let diff = a - b;
189            d += diff * diff;
190        }
191        if d < best_d {
192            best_d = d;
193            best = c;
194        }
195    }
196    best
197}
198
199fn assignment_dist(data: &ArrayView2<'_, f32>, centroids: &Array2<f32>) -> Array1<f64> {
200    let k = centroids.nrows();
201    let mut counts = Array1::<f64>::zeros(k);
202    for row in data.axis_iter(Axis(0)) {
203        let c = nearest_centroid(&row, &centroids.view());
204        counts[c] += 1.0;
205    }
206    let total = counts.sum().max(1.0);
207    counts.mapv_inplace(|x| x / total);
208    counts
209}
210
211fn symmetric_kl(p: &Array1<f64>, q: &Array1<f64>) -> f64 {
212    // Laplace smoothing so empty clusters don't blow up the log.
213    let eps = 1e-6;
214    let mut total = 0.0_f64;
215    for (pi, qi) in p.iter().zip(q.iter()) {
216        let p1 = pi.max(eps);
217        let q1 = qi.max(eps);
218        total += p1 * (p1 / q1).ln() + q1 * (q1 / p1).ln();
219    }
220    0.5 * total
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;

    /// Shape of the random fixtures: 128 samples of dimension 8.
    const FIXTURE_SHAPE: (usize, usize) = (128, 8);

    /// Fresh standard-normal embedding matrix (unseeded, differs per run).
    fn gaussian_embeddings() -> Array2<f32> {
        Array2::<f32>::random(FIXTURE_SHAPE, StandardNormal)
    }

    #[test]
    fn identical_query_embeddings_zero_drift() {
        // Same matrix on both sides => identical assignment histograms.
        let embeddings = gaussian_embeddings();
        let result = QueryDriftDetector::default()
            .detect(&embeddings.view(), &embeddings.view())
            .unwrap();
        assert!(result.score < 1e-3, "score={}", result.score);
        assert!(!result.exceeded);
    }

    #[test]
    fn shifted_query_distribution_flagged() {
        let baseline = gaussian_embeddings();
        // Push current embeddings far from baseline so they collapse into
        // the centroid nearest the shifted region.
        let current = gaussian_embeddings().mapv(|v| v + 5.0);
        let result = QueryDriftDetector::default()
            .detect(&baseline.view(), &current.view())
            .unwrap();
        assert!(result.exceeded, "expected drift, score={}", result.score);
    }

    #[test]
    fn rejects_dim_mismatch() {
        // 4 columns vs 8 columns.
        let narrow = Array2::<f32>::zeros((16, 4));
        let wide = Array2::<f32>::zeros((16, 8));
        let outcome = QueryDriftDetector::default().detect(&narrow.view(), &wide.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_too_few_baseline_samples() {
        // 4 baseline rows against the default of 8 clusters.
        let baseline = Array2::<f32>::zeros((4, 4));
        let current = Array2::<f32>::zeros((4, 4));
        let outcome = QueryDriftDetector::default().detect(&baseline.view(), &current.view());
        assert!(outcome.is_err());
    }
}
265}