Skip to main content

ragdrift_core/detectors/
query.rs

1//! Query-pattern drift: cluster baseline queries, measure assignment shift via KL.
2
3use crate::error::RagDriftError;
4use crate::stats::kmeans::{assign, kmeans};
5use crate::types::{DriftDimension, DriftScore};
6use crate::Result;
7use ndarray::ArrayView2;
8
/// Detects drift in the *intent* mix of incoming queries.
///
/// Clusters the baseline embeddings with k-means, then re-assigns the current
/// embeddings to those baseline centroids. The score is the symmetric
/// Kullback–Leibler divergence between the two assignment distributions.
#[derive(Debug, Clone)]
pub struct QueryDriftDetector {
    /// Threshold forwarded to `DriftScore::new`; presumably scores beyond it set
    /// `DriftScore::exceeded` (confirm the exact comparison in `DriftScore`).
    threshold: f64,
    /// Number of k-means clusters fitted on the baseline embeddings.
    k: usize,
    /// Maximum number of Lloyd iterations handed to k-means.
    max_iters: usize,
    /// Tolerance handed to k-means as its convergence criterion.
    tol: f32,
    /// RNG seed handed to k-means.
    seed: u64,
    /// Additive epsilon applied to every cluster frequency before the KL
    /// divergence, so empty clusters do not produce `ln(0)`.
    smoothing: f64,
}
22
23impl QueryDriftDetector {
24    /// Create a detector with `k` clusters. Defaults: 50 Lloyd iters, tol 1e-4,
25    /// 1e-6 add-epsilon for KL smoothing.
26    pub fn new(threshold: f64, k: usize) -> Self {
27        Self {
28            threshold,
29            k,
30            max_iters: 50,
31            tol: 1e-4,
32            seed: 0,
33            smoothing: 1e-6,
34        }
35    }
36
37    /// Override the RNG seed.
38    pub fn with_seed(mut self, seed: u64) -> Self {
39        self.seed = seed;
40        self
41    }
42
43    /// Run the detector against baseline and current query embeddings.
44    pub fn detect(
45        &self,
46        baseline: ArrayView2<f32>,
47        current: ArrayView2<f32>,
48    ) -> Result<DriftScore> {
49        if baseline.ncols() != current.ncols() {
50            return Err(RagDriftError::DimensionMismatch {
51                expected: baseline.ncols(),
52                actual: current.ncols(),
53                context: "QueryDriftDetector::detect",
54            });
55        }
56        let res = kmeans(baseline, self.k, self.max_iters, self.tol, self.seed)?;
57        let baseline_freqs = freqs(&res.labels, self.k);
58
59        let curr_labels = assign(current, res.centroids.view());
60        let current_freqs = freqs(&curr_labels, self.k);
61
62        let kl = symmetric_kl(&baseline_freqs, &current_freqs, self.smoothing);
63        Ok(DriftScore::new(
64            DriftDimension::Query,
65            kl,
66            self.threshold,
67            "kmeans+sym_kl",
68        ))
69    }
70}
71
/// Convert cluster labels into a `k`-bin relative-frequency histogram.
///
/// Every label counts towards the denominator. Labels outside `0..k` land in
/// no bin, so the result then sums to less than one. An empty `labels` slice
/// yields `k` zeros.
fn freqs(labels: &[usize], k: usize) -> Vec<f64> {
    if labels.is_empty() {
        return vec![0.0; k];
    }
    let total = labels.len() as f64;

    let mut counts = vec![0_u64; k];
    labels
        .iter()
        .filter(|&&label| label < k)
        .for_each(|&label| counts[label] += 1);

    counts.iter().map(|&count| count as f64 / total).collect()
}
85
/// Symmetric Kullback–Leibler divergence `0.5 * (KL(p||q) + KL(q||p))`.
///
/// `eps` is added to every entry of both distributions before taking logs, so
/// zero bins do not produce `ln(0)`. The smoothed values are not renormalised.
/// Both slices are expected to have the same length; this is only checked in
/// debug builds, and otherwise the longer slice is truncated by `zip`.
fn symmetric_kl(p: &[f64], q: &[f64], eps: f64) -> f64 {
    debug_assert_eq!(p.len(), q.len());
    let (forward, backward) = p.iter().zip(q).fold(
        (0.0_f64, 0.0_f64),
        |(forward, backward), (&p_i, &q_i)| {
            let p_s = p_i + eps;
            let q_s = q_i + eps;
            (
                forward + p_s * (p_s / q_s).ln(),
                backward + q_s * (q_s / p_s).ln(),
            )
        },
    );
    0.5 * (forward + backward)
}
98
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn same_query_mix_scores_low() {
        // two distinct clusters at (0,0) and (10,10), 20 points each
        let mut points = Array2::<f32>::zeros((40, 2));
        for i in 0..20 {
            points[[i, 0]] = (i as f32) * 0.01;
            points[[i, 1]] = (i as f32) * 0.01;
            points[[i + 20, 0]] = 10.0 + (i as f32) * 0.01;
            points[[i + 20, 1]] = 10.0 + (i as f32) * 0.01;
        }
        let det = QueryDriftDetector::new(0.1, 2).with_seed(7);
        let s = det.detect(points.view(), points.view()).unwrap();
        assert!(s.score < 1e-3, "score was {}", s.score);
    }

    #[test]
    fn shifted_query_mix_flagged() {
        // baseline is 50/50 between two clusters; current is 100% in one cluster
        let mut baseline = Array2::<f32>::zeros((40, 2));
        for i in 0..20 {
            baseline[[i, 0]] = (i as f32) * 0.01;
            baseline[[i + 20, 0]] = 10.0 + (i as f32) * 0.01;
            baseline[[i + 20, 1]] = 10.0;
        }
        let mut current = Array2::<f32>::zeros((40, 2));
        for i in 0..40 {
            current[[i, 0]] = (i as f32) * 0.01; // all in cluster A
        }
        let det = QueryDriftDetector::new(0.1, 2).with_seed(7);
        let s = det.detect(baseline.view(), current.view()).unwrap();
        assert!(s.exceeded, "score was {}", s.score);
    }

    #[test]
    fn dimension_mismatch_is_rejected() {
        // The column-count check runs before k-means, so this exercises the
        // guard in isolation.
        let baseline = Array2::<f32>::zeros((4, 2));
        let current = Array2::<f32>::zeros((4, 3));
        let det = QueryDriftDetector::new(0.1, 2);
        let result = det.detect(baseline.view(), current.view());
        assert!(matches!(
            result,
            Err(RagDriftError::DimensionMismatch {
                expected: 2,
                actual: 3,
                ..
            })
        ));
    }
}