// ragdrift_core/detectors/query.rs
use crate::error::RagDriftError;
use crate::stats::kmeans::{assign, kmeans};
use crate::types::{DriftDimension, DriftScore};
use crate::Result;
use ndarray::ArrayView2;

/// Query-drift detector that compares how a baseline and a current matrix
/// distribute over `k` clusters.
///
/// Build one with [`QueryDriftDetector::new`] and run it with
/// [`QueryDriftDetector::detect`].
#[derive(Debug, Clone)]
pub struct QueryDriftDetector {
    /// Threshold handed to `DriftScore::new` together with the computed score.
    threshold: f64,
    /// Cluster count passed to `kmeans`; also the number of histogram bins.
    k: usize,
    /// Iteration limit passed to `kmeans`.
    max_iters: usize,
    /// Tolerance passed to `kmeans` (presumably its convergence criterion — confirm
    /// against `crate::stats::kmeans`).
    tol: f32,
    /// Seed passed to `kmeans`.
    seed: u64,
    /// Constant added to every histogram bin before the KL terms are evaluated.
    smoothing: f64,
}

23impl QueryDriftDetector {
24 pub fn new(threshold: f64, k: usize) -> Self {
27 Self {
28 threshold,
29 k,
30 max_iters: 50,
31 tol: 1e-4,
32 seed: 0,
33 smoothing: 1e-6,
34 }
35 }
36
37 pub fn with_seed(mut self, seed: u64) -> Self {
39 self.seed = seed;
40 self
41 }
42
43 pub fn detect(
45 &self,
46 baseline: ArrayView2<f32>,
47 current: ArrayView2<f32>,
48 ) -> Result<DriftScore> {
49 if baseline.ncols() != current.ncols() {
50 return Err(RagDriftError::DimensionMismatch {
51 expected: baseline.ncols(),
52 actual: current.ncols(),
53 context: "QueryDriftDetector::detect",
54 });
55 }
56 let res = kmeans(baseline, self.k, self.max_iters, self.tol, self.seed)?;
57 let baseline_freqs = freqs(&res.labels, self.k);
58
59 let curr_labels = assign(current, res.centroids.view());
60 let current_freqs = freqs(&curr_labels, self.k);
61
62 let kl = symmetric_kl(&baseline_freqs, ¤t_freqs, self.smoothing);
63 Ok(DriftScore::new(
64 DriftDimension::Query,
65 kl,
66 self.threshold,
67 "kmeans+sym_kl",
68 ))
69 }
70}
71
/// Converts cluster labels into a length-`k` vector of relative frequencies.
///
/// Bin `i` holds the number of labels equal to `i` divided by `labels.len()`.
/// Labels that are `>= k` are not assigned to any bin but still count towards
/// the denominator, so the result can sum to less than one in that case.
/// An empty `labels` slice yields `k` zeros rather than dividing by zero.
fn freqs(labels: &[usize], k: usize) -> Vec<f64> {
    let mut counts = vec![0_u64; k];
    for &label in labels {
        // `get_mut` skips out-of-range labels instead of panicking.
        if let Some(slot) = counts.get_mut(label) {
            *slot += 1;
        }
    }

    if labels.is_empty() {
        return vec![0.0; k];
    }
    let total = labels.len() as f64;
    counts.into_iter().map(|count| count as f64 / total).collect()
}

/// Smoothed symmetric KL divergence between two equally long histograms.
///
/// `eps` is added to every entry of `p` and `q` (without renormalising), then
/// the function returns `0.5 * (KL(p || q) + KL(q || p))` over the smoothed
/// values. A positive `eps` keeps empty bins from producing `ln(0)` or a
/// division by zero.
///
/// The length equality is checked only in debug builds; otherwise `zip`
/// stops at the shorter slice.
fn symmetric_kl(p: &[f64], q: &[f64], eps: f64) -> f64 {
    debug_assert_eq!(p.len(), q.len());
    let mut kl_pq = 0.0_f64;
    let mut kl_qp = 0.0_f64;
    for (pi, qi) in p.iter().zip(q) {
        let ps = pi + eps;
        let qs = qi + eps;
        kl_pq += ps * (ps / qs).ln();
        kl_qp += qs * (qs / ps).ln();
    }
    0.5 * (kl_pq + kl_qp)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Feeding the same matrix as baseline and current gives identical
    /// histograms, so the score must be close to zero.
    #[test]
    fn same_query_mix_scores_low() {
        let mut points = Array2::<f32>::zeros((40, 2));
        // Rows 0..20 sit near (0, 0); rows 20..40 sit near (10, 10).
        for i in 0..20 {
            points[[i, 0]] = (i as f32) * 0.01;
            points[[i, 1]] = (i as f32) * 0.01;
            points[[i + 20, 0]] = 10.0 + (i as f32) * 0.01;
            points[[i + 20, 1]] = 10.0 + (i as f32) * 0.01;
        }
        let det = QueryDriftDetector::new(0.1, 2).with_seed(7);
        let s = det.detect(points.view(), points.view()).unwrap();
        assert!(s.score < 1e-3, "score was {}", s.score);
    }

    /// The baseline is split evenly between two groups while every current
    /// row lies near the first group, so the score is expected to exceed the
    /// threshold.
    #[test]
    fn shifted_query_mix_flagged() {
        let mut baseline = Array2::<f32>::zeros((40, 2));
        // Rows 0..20 sit near (0, 0); rows 20..40 sit near (10, 10).
        for i in 0..20 {
            baseline[[i, 0]] = (i as f32) * 0.01;
            baseline[[i + 20, 0]] = 10.0 + (i as f32) * 0.01;
            baseline[[i + 20, 1]] = 10.0;
        }
        let mut current = Array2::<f32>::zeros((40, 2));
        // All 40 rows stay within [0, 0.39] on the first axis, with the second axis at 0.
        for i in 0..40 {
            current[[i, 0]] = (i as f32) * 0.01;
        }
        let det = QueryDriftDetector::new(0.1, 2).with_seed(7);
        let s = det.detect(baseline.view(), current.view()).unwrap();
        assert!(s.exceeded, "score was {}", s.score);
    }
}