ragdrift_core/detectors/
query.rs1use ndarray::{Array1, Array2, ArrayView2, Axis};
8use rand::seq::IndexedRandom;
9use rand::{Rng, SeedableRng};
10
11use crate::error::{RagDriftError, Result};
12use crate::types::{check_min_samples, check_same_cols, DriftDimension, DriftScore};
13
/// Configuration for [`QueryDriftDetector`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryDriftConfig {
    /// Threshold on the symmetric-KL drift score. It is forwarded unchanged to
    /// `DriftScore::new`, which presumably sets the score's `exceeded` flag by
    /// comparing against it.
    pub threshold: f64,
    /// Number of k-means clusters fitted on the baseline embeddings. The
    /// detector rejects values below 2, and the baseline must contain at least
    /// this many rows.
    pub n_clusters: usize,
    /// Upper bound on Lloyd iterations; fitting stops earlier as soon as an
    /// assignment pass moves no row to a different cluster.
    pub max_iter: usize,
    /// Seed for the k-means RNG (initial seeding and empty-cluster reseeding),
    /// so a fixed input always yields the same score.
    pub seed: u64,
}

impl Default for QueryDriftConfig {
    /// `threshold = 0.1`, `n_clusters = 8`, `max_iter = 25`, `seed = 0`.
    fn default() -> Self {
        Self {
            threshold: 0.1,
            n_clusters: 8,
            max_iter: 25,
            seed: 0,
        }
    }
}
37
38#[derive(Debug, Clone, Copy, Default)]
40pub struct QueryDriftDetector {
41 config: QueryDriftConfig,
42}
43
44impl QueryDriftDetector {
45 pub fn new(config: QueryDriftConfig) -> Self {
47 Self { config }
48 }
49
50 pub fn detect(
52 &self,
53 baseline: &ArrayView2<'_, f32>,
54 current: &ArrayView2<'_, f32>,
55 ) -> Result<DriftScore> {
56 check_same_cols(baseline, current)?;
57 check_min_samples(baseline.nrows(), self.config.n_clusters)?;
58 check_min_samples(current.nrows(), 1)?;
59 if self.config.n_clusters < 2 {
60 return Err(RagDriftError::InvalidConfig(
61 "n_clusters must be >= 2".into(),
62 ));
63 }
64
65 let centroids = kmeans_fit(
66 baseline,
67 self.config.n_clusters,
68 self.config.max_iter,
69 self.config.seed,
70 )?;
71
72 let p = assignment_dist(baseline, ¢roids);
73 let q = assignment_dist(current, ¢roids);
74 let kl = symmetric_kl(&p, &q);
75
76 Ok(DriftScore::new(
77 DriftDimension::Query,
78 kl,
79 self.config.threshold,
80 "kmeans-skl",
81 ))
82 }
83}
84
85#[allow(clippy::needless_range_loop)] fn kmeans_fit(
88 data: &ArrayView2<'_, f32>,
89 k: usize,
90 max_iter: usize,
91 seed: u64,
92) -> Result<Array2<f32>> {
93 let n = data.nrows();
94 let dim = data.ncols();
95 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
96
97 let mut centroids = Array2::<f32>::zeros((k, dim));
98
99 let first = rng.random_range(0..n);
101 centroids.row_mut(0).assign(&data.row(first));
102
103 let mut min_d2 = vec![f32::INFINITY; n];
104 for c in 1..k {
105 let last = centroids.row(c - 1);
107 for i in 0..n {
108 let mut d = 0.0_f32;
109 for (a, b) in data.row(i).iter().zip(last.iter()) {
110 let diff = a - b;
111 d += diff * diff;
112 }
113 if d < min_d2[i] {
114 min_d2[i] = d;
115 }
116 }
117 let total: f64 = min_d2.iter().map(|&d| d as f64).sum();
119 if total <= 0.0 {
120 let idx = rng.random_range(0..n);
122 centroids.row_mut(c).assign(&data.row(idx));
123 continue;
124 }
125 let target: f64 = rng.random_range(0.0..total);
126 let mut acc = 0.0_f64;
127 let mut chosen = n - 1;
128 for (i, &d) in min_d2.iter().enumerate() {
129 acc += d as f64;
130 if acc >= target {
131 chosen = i;
132 break;
133 }
134 }
135 centroids.row_mut(c).assign(&data.row(chosen));
136 }
137
138 let mut labels = vec![0_usize; n];
140 for _ in 0..max_iter {
141 let mut changed = false;
142 for i in 0..n {
143 let new = nearest_centroid(&data.row(i), ¢roids.view());
144 if new != labels[i] {
145 changed = true;
146 labels[i] = new;
147 }
148 }
149 if !changed {
150 break;
151 }
152 let mut new_centroids = Array2::<f32>::zeros((k, dim));
153 let mut counts = vec![0_usize; k];
154 for i in 0..n {
155 let label = labels[i];
156 let mut row = new_centroids.row_mut(label);
157 for (a, b) in row.iter_mut().zip(data.row(i).iter()) {
158 *a += *b;
159 }
160 counts[label] += 1;
161 }
162 for c in 0..k {
163 if counts[c] > 0 {
164 new_centroids
165 .row_mut(c)
166 .mapv_inplace(|x| x / counts[c] as f32);
167 } else {
168 let resample: Vec<usize> = (0..n).collect();
170 let idx = *resample.choose(&mut rng).unwrap();
171 new_centroids.row_mut(c).assign(&data.row(idx));
172 }
173 }
174 centroids = new_centroids;
175 }
176 Ok(centroids)
177}
178
179fn nearest_centroid(
180 point: &ndarray::ArrayView1<'_, f32>,
181 centroids: &ndarray::ArrayView2<'_, f32>,
182) -> usize {
183 let mut best = 0_usize;
184 let mut best_d = f32::INFINITY;
185 for (c, centroid) in centroids.axis_iter(Axis(0)).enumerate() {
186 let mut d = 0.0_f32;
187 for (a, b) in point.iter().zip(centroid.iter()) {
188 let diff = a - b;
189 d += diff * diff;
190 }
191 if d < best_d {
192 best_d = d;
193 best = c;
194 }
195 }
196 best
197}
198
199fn assignment_dist(data: &ArrayView2<'_, f32>, centroids: &Array2<f32>) -> Array1<f64> {
200 let k = centroids.nrows();
201 let mut counts = Array1::<f64>::zeros(k);
202 for row in data.axis_iter(Axis(0)) {
203 let c = nearest_centroid(&row, ¢roids.view());
204 counts[c] += 1.0;
205 }
206 let total = counts.sum().max(1.0);
207 counts.mapv_inplace(|x| x / total);
208 counts
209}
210
211fn symmetric_kl(p: &Array1<f64>, q: &Array1<f64>) -> f64 {
212 let eps = 1e-6;
214 let mut total = 0.0_f64;
215 for (pi, qi) in p.iter().zip(q.iter()) {
216 let p1 = pi.max(eps);
217 let q1 = qi.max(eps);
218 total += p1 * (p1 / q1).ln() + q1 * (q1 / p1).ln();
219 }
220 0.5 * total
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::StandardNormal;
    use ndarray_rand::RandomExt;

    /// Identical inputs produce identical cluster histograms, so the divergence is ~0.
    #[test]
    fn identical_query_embeddings_zero_drift() {
        let a = Array2::<f32>::random((128, 8), StandardNormal);
        let detector = QueryDriftDetector::default();
        let s = detector.detect(&a.view(), &a.view()).unwrap();
        assert!(s.score < 1e-3, "score={}", s.score);
        assert!(!s.exceeded);
    }

    /// A large mean shift concentrates the current rows in few clusters.
    #[test]
    fn shifted_query_distribution_flagged() {
        let a = Array2::<f32>::random((128, 8), StandardNormal);
        let mut b = Array2::<f32>::random((128, 8), StandardNormal);
        b.mapv_inplace(|v| v + 5.0);
        let detector = QueryDriftDetector::default();
        let s = detector.detect(&a.view(), &b.view()).unwrap();
        assert!(s.exceeded, "expected drift, score={}", s.score);
    }

    #[test]
    fn rejects_dim_mismatch() {
        let a = Array2::<f32>::zeros((16, 4));
        let b = Array2::<f32>::zeros((16, 8));
        let detector = QueryDriftDetector::default();
        assert!(detector.detect(&a.view(), &b.view()).is_err());
    }

    /// The default config uses 8 clusters, which needs at least 8 baseline rows.
    #[test]
    fn rejects_too_few_baseline_samples() {
        let a = Array2::<f32>::zeros((4, 4));
        let b = Array2::<f32>::zeros((4, 4));
        let detector = QueryDriftDetector::default();
        assert!(detector.detect(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn rejects_fewer_than_two_clusters() {
        let a = Array2::<f32>::random((16, 4), StandardNormal);
        let detector = QueryDriftDetector::new(QueryDriftConfig {
            n_clusters: 1,
            ..QueryDriftConfig::default()
        });
        let result = detector.detect(&a.view(), &a.view());
        assert!(matches!(result, Err(RagDriftError::InvalidConfig(_))));
    }

    /// k-means is seeded from the config, so repeated runs must agree bit-for-bit.
    #[test]
    fn same_inputs_give_same_score() {
        let a = Array2::<f32>::random((64, 8), StandardNormal);
        let b = Array2::<f32>::random((64, 8), StandardNormal);
        let detector = QueryDriftDetector::default();
        let first = detector.detect(&a.view(), &b.view()).unwrap();
        let second = detector.detect(&a.view(), &b.view()).unwrap();
        assert_eq!(first.score.to_bits(), second.score.to_bits());
    }
}