Skip to main content

lean_ctx/core/
ann_cache.rs

1//! Process-wide cache for the HNSW [`AnnIndex`] used by dense semantic search.
2//!
3//! Building an HNSW graph is O(n log n) with a wide construction beam, so doing
4//! it per query would be slower than brute force. This cache keeps one built
5//! index keyed by a content fingerprint of the embedding set: repeated queries
6//! over the same corpus reuse the graph and get sub-linear search, while a
7//! changed corpus (different fingerprint) transparently triggers a rebuild.
8//!
9//! It is threshold-gated — corpora below [`ANN_MIN_VECTORS`] skip the cache and
10//! use exact SIMD brute-force top-k, which is both faster (no graph overhead)
11//! and *exact*. The threshold is deliberately high: at lean-ctx's typical scale
12//! (a few thousand chunks) exact brute force over int8/SIMD dot products is only
13//! ~1-2 ms, so HNSW's approximate recall is not worth trading. HNSW activates
14//! only for genuinely large corpora where exact scan would dominate latency.
15//! On any lock failure it falls back to brute force, so correctness never
16//! depends on the cache being available.
17
18use std::sync::{Mutex, OnceLock};
19
20use super::hnsw::{brute_force_topk, AnnIndex};
21
/// Corpus-size gate for the cached HNSW path.
///
/// A corpus with strictly fewer vectors than this is answered by exact brute
/// force and never touches the cache; a corpus of exactly this size (or larger)
/// builds and reuses the cached index. Below the gate, brute force is both
/// faster and exact, so there is no recall to trade away.
pub const ANN_MIN_VECTORS: usize = 50_000;
25
26struct Cached {
27    fingerprint: u64,
28    index: AnnIndex,
29}
30
31fn cache() -> &'static Mutex<Option<Cached>> {
32    static CACHE: OnceLock<Mutex<Option<Cached>>> = OnceLock::new();
33    CACHE.get_or_init(|| Mutex::new(None))
34}
35
36/// Returns the top-k `(index, similarity)` pairs for `query` over `embeddings`,
37/// sorted by descending similarity.
38///
39/// Small corpora use exact brute force. Large corpora build (once) and reuse a
40/// cached HNSW index. Falls back to brute force on lock failure.
41#[must_use]
42pub fn topk(embeddings: &[Vec<f32>], query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
43    topk_gated(embeddings, query, top_k, ANN_MIN_VECTORS)
44}
45
46/// Core implementation with an injectable gate so tests can exercise the HNSW
47/// path without materializing a 50k-vector corpus.
48fn topk_gated(
49    embeddings: &[Vec<f32>],
50    query: &[f32],
51    top_k: usize,
52    min_vectors: usize,
53) -> Vec<(usize, f32)> {
54    if embeddings.len() < min_vectors {
55        return brute_force_topk(embeddings, query, top_k);
56    }
57
58    let fp = fingerprint(embeddings);
59    let Ok(mut guard) = cache().lock() else {
60        return brute_force_topk(embeddings, query, top_k);
61    };
62
63    let needs_build = match guard.as_ref() {
64        Some(c) => c.fingerprint != fp,
65        None => true,
66    };
67    if needs_build {
68        *guard = Some(Cached {
69            fingerprint: fp,
70            index: AnnIndex::build(embeddings.to_vec()),
71        });
72    }
73
74    match guard.as_ref() {
75        Some(c) => c.index.search(query, top_k),
76        None => brute_force_topk(embeddings, query, top_k),
77    }
78}
79
/// Cheap, content-sensitive fingerprint of an embedding set.
///
/// Uses the FNV-1a constants and xor-then-multiply step, applied to whole
/// 64-bit words. The words fed in are: the number of vectors, then for each
/// vector its length, its position, and the raw bits of up to three sampled
/// components (first, middle at `len / 2`, last). Empty vectors contribute only
/// their length and position.
///
/// Because only three components per vector are sampled, a change confined to
/// other components does not alter the fingerprint.
fn fingerprint(embeddings: &[Vec<f32>]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    /// One hashing step: xor the word in, then multiply by the prime.
    fn absorb(state: u64, word: u64) -> u64 {
        (state ^ word).wrapping_mul(PRIME)
    }

    let mut state = absorb(OFFSET_BASIS, embeddings.len() as u64);
    for (position, vector) in embeddings.iter().enumerate() {
        state = absorb(state, vector.len() as u64);
        state = absorb(state, position as u64);

        // First, middle, last — in that order. Each is `None` for an empty
        // vector, and all three may alias the same element for short vectors.
        let probes = [
            vector.first(),
            vector.get(vector.len() / 2),
            vector.last(),
        ];
        for component in probes.iter().flatten() {
            state = absorb(state, u64::from(component.to_bits()));
        }
    }
    state
}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Gate passed to `topk_gated` so that modest corpora take the cached-index
    /// path. Per the original author's note, `AnnIndex` itself only switches to
    /// HNSW at 1000 vectors, so this value exercises the real graph — that
    /// threshold lives in the `hnsw` module and is not visible from this file.
    const TEST_GATE: usize = 1000;

    /// Multiplier of the tiny deterministic generator used for test data.
    const LCG_MUL: u64 = 6_364_136_223_846_793_005;

    /// Serializes every test that drives the cached-index path: the cache is a
    /// single process-wide slot, so interleaved tests would overwrite each
    /// other's index.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the serialization lock, recovering it if an earlier test panicked
    /// while holding it so one failure does not cascade into the others.
    fn serial() -> std::sync::MutexGuard<'static, ()> {
        match TEST_LOCK.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Fingerprint of whatever index is currently cached, or `None` if the slot
    /// is empty or its lock cannot be taken. Test-only introspection: as a child
    /// module, `tests` may read the parent's private state.
    fn cached_fingerprint() -> Option<u64> {
        let slot = cache().lock().ok()?;
        slot.as_ref().map(|entry| entry.fingerprint)
    }

    /// Advances the generator state by one step.
    fn lcg_step(state: u64) -> u64 {
        state.wrapping_mul(LCG_MUL).wrapping_add(1)
    }

    /// Maps a generator state to a float in roughly `[-1, 1]`.
    fn to_unit(state: u64) -> f32 {
        (state as f32 / u64::MAX as f32) * 2.0 - 1.0
    }

    /// Deterministic pseudo-random vector of `dim` components derived from `seed`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state = lcg_step(state);
                to_unit(state)
            })
            .collect()
    }

    /// A copy of `base` with small deterministic per-dimension noise of
    /// magnitude up to `scale`. Used to build dense, well-connected clusters,
    /// where approximate search is stable (a lone needle in random noise is
    /// something approximate search can legitimately miss).
    fn jitter(base: &[f32], seed: u64, scale: f32) -> Vec<f32> {
        let mut noisy = Vec::with_capacity(base.len());
        for (dim, &component) in base.iter().enumerate() {
            let state = lcg_step(seed.wrapping_add(dim as u64));
            noisy.push(component + to_unit(state) * scale);
        }
        noisy
    }

    /// Builds `n_clusters * per_cluster` vectors grouped tightly around random
    /// centers. Returns `(vectors, centers)`; vectors are laid out cluster by
    /// cluster in center order.
    fn clustered(
        n_clusters: usize,
        per_cluster: usize,
        dim: usize,
    ) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let mut centers = Vec::with_capacity(n_clusters);
        for c in 0..n_clusters {
            centers.push(random_vec(dim, (c as u64 + 1) * 1_000));
        }

        let vectors: Vec<Vec<f32>> = centers
            .iter()
            .enumerate()
            .flat_map(|(c, center)| {
                (0..per_cluster)
                    .map(move |j| jitter(center, (c * per_cluster + j) as u64 + 7, 0.02))
            })
            .collect();

        (vectors, centers)
    }

    #[test]
    fn small_corpus_matches_brute_force_exactly() {
        let corpus: Vec<Vec<f32>> = (0..200).map(|seed| random_vec(32, seed)).collect();
        let query = random_vec(32, 9_999);

        // 200 vectors is far below the production gate, so `topk` must take the
        // exact brute-force path.
        let actual = topk(&corpus, &query, 8);
        let expected = brute_force_topk(&corpus, &query, 8);

        assert_eq!(actual.len(), expected.len());
        let ids = |hits: &[(usize, f32)]| hits.iter().map(|&(i, _)| i).collect::<Vec<_>>();
        assert_eq!(
            ids(&actual),
            ids(&expected),
            "below threshold must be exact brute force"
        );
    }

    #[test]
    fn hnsw_path_recall_matches_brute_force_on_clusters() {
        let _serial = serial();
        let (vectors, centers) = clustered(24, 60, 32); // 24 * 60 = 1440 vectors
        let query = centers[5].clone();
        let k = 20;

        let exact = brute_force_topk(&vectors, &query, k);
        let ann = topk_gated(&vectors, &query, k, TEST_GATE); // above the test gate
        assert_eq!(ann.len(), k);

        let truth: HashSet<usize> = exact.iter().map(|&(i, _)| i).collect();
        let overlap = ann.iter().filter(|(i, _)| truth.contains(i)).count();
        // Require at least half of the exact top-k to be recovered.
        assert!(
            overlap * 100 >= k * 50,
            "HNSW recall@{k} too low: {overlap}/{k}"
        );
    }

    #[test]
    fn hnsw_path_results_are_descending() {
        let _serial = serial();
        let (vectors, centers) = clustered(20, 60, 24); // 20 * 60 = 1200 vectors
        let results = topk_gated(&vectors, &centers[3], 10, TEST_GATE);
        let descending = results.windows(2).all(|pair| pair[0].1 >= pair[1].1);
        assert!(
            descending,
            "results must be sorted by descending similarity"
        );
    }

    #[test]
    fn rebuilds_when_corpus_changes() {
        let _serial = serial();
        // Two different corpora compete for the one cache slot. After each
        // query the slot must carry the fingerprint of the corpus just queried.
        // Checking the fingerprint exercises the rebuild decision directly and
        // deterministically, independent of approximate-search recall.
        let (a, ca) = clustered(20, 55, 32); // 1100 vectors
        let (b, cb) = clustered(18, 60, 32); // 1080 vectors

        let steps: [(&[Vec<f32>], &[f32], &str); 3] = [
            (&a, &ca[7], "first query caches corpus A's index"),
            (&b, &cb[4], "a different corpus must force a rebuild to B"),
            (&a, &ca[7], "re-querying A must rebuild A — never serve stale B"),
        ];

        for &(corpus, query, why) in steps.iter() {
            let _ = topk_gated(corpus, query, 5, TEST_GATE);
            assert_eq!(cached_fingerprint(), Some(fingerprint(corpus)), "{why}");
        }
    }

    #[test]
    fn fingerprint_differs_on_content_change() {
        let original: Vec<Vec<f32>> = (0..10).map(|seed| random_vec(8, seed)).collect();
        let mut tweaked = original.clone();
        // Index 0 of a vector is one of the components the fingerprint samples.
        tweaked[3][0] += 0.5;
        assert_ne!(fingerprint(&original), fingerprint(&tweaked));
    }
}