1use std::sync::{Mutex, OnceLock};
19
20use super::hnsw::{brute_force_topk, AnnIndex};
21
22pub const ANN_MIN_VECTORS: usize = 50_000;
25
26struct Cached {
27 fingerprint: u64,
28 index: AnnIndex,
29}
30
31fn cache() -> &'static Mutex<Option<Cached>> {
32 static CACHE: OnceLock<Mutex<Option<Cached>>> = OnceLock::new();
33 CACHE.get_or_init(|| Mutex::new(None))
34}
35
36#[must_use]
42pub fn topk(embeddings: &[Vec<f32>], query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
43 topk_gated(embeddings, query, top_k, ANN_MIN_VECTORS)
44}
45
46fn topk_gated(
49 embeddings: &[Vec<f32>],
50 query: &[f32],
51 top_k: usize,
52 min_vectors: usize,
53) -> Vec<(usize, f32)> {
54 if embeddings.len() < min_vectors {
55 return brute_force_topk(embeddings, query, top_k);
56 }
57
58 let fp = fingerprint(embeddings);
59 let Ok(mut guard) = cache().lock() else {
60 return brute_force_topk(embeddings, query, top_k);
61 };
62
63 let needs_build = match guard.as_ref() {
64 Some(c) => c.fingerprint != fp,
65 None => true,
66 };
67 if needs_build {
68 *guard = Some(Cached {
69 fingerprint: fp,
70 index: AnnIndex::build(embeddings.to_vec()),
71 });
72 }
73
74 match guard.as_ref() {
75 Some(c) => c.index.search(query, top_k),
76 None => brute_force_topk(embeddings, query, top_k),
77 }
78}
79
/// Cheap identity hash of a corpus, used to decide whether the cached ANN index
/// still matches `embeddings`.
///
/// Uses FNV-1a-style mixing (64-bit offset basis and prime, xor then multiply)
/// over whole `u64` words. It mixes the row count first. Then, for every row,
/// it mixes the row's length, its index, and the bit patterns of its first,
/// middle (`len / 2`) and last elements. Absent elements of empty rows are
/// skipped.
///
/// Only those three coordinates per row are sampled, so the cost is O(rows)
/// rather than O(rows × dim). An edit confined to any other coordinate does not
/// change the fingerprint.
fn fingerprint(embeddings: &[Vec<f32>]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let step = |acc: u64, word: u64| (acc ^ word).wrapping_mul(FNV_PRIME);

    let seeded = step(FNV_OFFSET_BASIS, embeddings.len() as u64);
    embeddings
        .iter()
        .enumerate()
        .fold(seeded, |acc, (row, values)| {
            let acc = step(step(acc, values.len() as u64), row as u64);
            let samples = [values.first(), values.get(values.len() / 2), values.last()];
            samples
                .iter()
                .flatten()
                .fold(acc, |acc, x| step(acc, u64::from(x.to_bits())))
        })
}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// ANN gate for the tests below. It is low enough that ~1k-vector fixtures
    /// take the cached-index path instead of brute force.
    const TEST_GATE: usize = 1000;

    /// The ANN cache is a process-wide static and the harness runs tests on
    /// parallel threads, so every test that goes through the cache holds this.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`TEST_LOCK`], ignoring poisoning so one failed test does not
    /// cascade into failures in every later cache test.
    fn serial() -> std::sync::MutexGuard<'static, ()> {
        TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Fingerprint of the corpus whose index is currently cached. Returns `None`
    /// if nothing is cached or the cache lock is poisoned.
    fn cached_fingerprint() -> Option<u64> {
        cache()
            .lock()
            .ok()
            .and_then(|g| g.as_ref().map(|c| c.fingerprint))
    }

    /// Deterministic pseudo-random vector of length `dim` with components in
    /// `[-1, 1]`, driven by a 64-bit LCG seeded with `seed`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        let mut s = seed;
        for _ in 0..dim {
            s = s.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            v.push((s as f32 / u64::MAX as f32) * 2.0 - 1.0);
        }
        v
    }

    /// Copy of `base` with each component shifted by a deterministic offset in
    /// `[-scale, scale]`.
    fn jitter(base: &[f32], seed: u64, scale: f32) -> Vec<f32> {
        base.iter()
            .enumerate()
            .map(|(i, &b)| {
                let s = seed
                    .wrapping_add(i as u64)
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1);
                b + ((s as f32 / u64::MAX as f32) * 2.0 - 1.0) * scale
            })
            .collect()
    }

    /// Builds `n_clusters` random centers plus `per_cluster` tightly jittered
    /// points around each. Returns `(points, centers)`, with points grouped by
    /// cluster.
    fn clustered(
        n_clusters: usize,
        per_cluster: usize,
        dim: usize,
    ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let centers: Vec<Vec<f32>> = (0..n_clusters)
            .map(|c| random_vec(dim, (c as u64 + 1) * 1_000))
            .collect();
        let mut vectors = Vec::with_capacity(n_clusters * per_cluster);
        for (c, center) in centers.iter().enumerate() {
            for j in 0..per_cluster {
                vectors.push(jitter(center, (c * per_cluster + j) as u64 + 7, 0.02));
            }
        }
        (vectors, centers)
    }

    #[test]
    fn small_corpus_matches_brute_force_exactly() {
        // 200 vectors is far below ANN_MIN_VECTORS, so `topk` never touches the
        // shared cache and needs no serialization.
        let vectors: Vec<Vec<f32>> = (0..200).map(|i| random_vec(32, i)).collect();
        let query = random_vec(32, 9_999);

        let via_cache = topk(&vectors, &query, 8);
        let exact = brute_force_topk(&vectors, &query, 8);

        assert_eq!(via_cache.len(), exact.len());
        for (a, b) in via_cache.iter().zip(exact.iter()) {
            assert_eq!(a.0, b.0, "below threshold must be exact brute force");
        }
    }

    #[test]
    fn hnsw_path_recall_matches_brute_force_on_clusters() {
        let _serial = serial();
        let (vectors, centers) = clustered(24, 60, 32);
        let query = centers[5].clone();
        let k = 20;

        let ann = topk_gated(&vectors, &query, k, TEST_GATE);
        let exact = brute_force_topk(&vectors, &query, k);
        assert_eq!(ann.len(), k);

        // Require at least 50% overlap with the exact top-k.
        let exact_set: HashSet<usize> = exact.iter().map(|(i, _)| *i).collect();
        let overlap = ann.iter().filter(|(i, _)| exact_set.contains(i)).count();
        assert!(
            overlap * 100 >= k * 50,
            "HNSW recall@{k} too low: {overlap}/{k}"
        );
    }

    #[test]
    fn hnsw_path_results_are_descending() {
        let _serial = serial();
        let (vectors, centers) = clustered(20, 60, 24);
        let results = topk_gated(&vectors, &centers[3], 10, TEST_GATE);
        for w in results.windows(2) {
            assert!(
                w[0].1 >= w[1].1,
                "results must be sorted by descending similarity"
            );
        }
    }

    #[test]
    fn rebuilds_when_corpus_changes() {
        let _serial = serial();
        // Two corpora above TEST_GATE with different sizes, hence different fingerprints.
        let (a, ca) = clustered(20, 55, 32);
        let (b, cb) = clustered(18, 60, 32);

        let _ = topk_gated(&a, &ca[7], 5, TEST_GATE);
        assert_eq!(
            cached_fingerprint(),
            Some(fingerprint(&a)),
            "first query caches corpus A's index"
        );

        let _ = topk_gated(&b, &cb[4], 5, TEST_GATE);
        assert_eq!(
            cached_fingerprint(),
            Some(fingerprint(&b)),
            "a different corpus must force a rebuild to B"
        );

        let _ = topk_gated(&a, &ca[7], 5, TEST_GATE);
        assert_eq!(
            cached_fingerprint(),
            Some(fingerprint(&a)),
            "re-querying A must rebuild A — never serve stale B"
        );
    }

    #[test]
    fn fingerprint_differs_on_content_change() {
        let a: Vec<Vec<f32>> = (0..10).map(|i| random_vec(8, i)).collect();
        let mut b = a.clone();
        // Index 0 is one of the sampled coordinates (first element).
        b[3][0] += 0.5;
        assert_ne!(fingerprint(&a), fingerprint(&b));
    }
}