Skip to main content

ripvec_core/
index.rs

1//! In-memory search index for real-time re-ranking.
2//!
3//! Stores all chunk embeddings as a contiguous ndarray matrix so that
4//! re-ranking is a single BLAS matrix-vector multiply via [`crate::similarity::rank_all`].
5//!
6//! Optionally uses [`TurboQuant`](crate::turbo_quant::PolarCodec) compression for fast approximate
7//! scanning at monorepo scale (100K+ chunks). `TurboQuant` compresses 768-dim
8//! embeddings from 3072 bytes (FP32) to ~386 bytes (4-bit), giving ~5× faster
9//! scan via sequential memory access + centroid table lookup.
10
11use ndarray::{Array1, Array2};
12
13use crate::chunk::CodeChunk;
14use crate::turbo_quant::{CompressedCorpus, PolarCodec};
15
/// Corpus size (in chunks) below which the TurboQuant index is not built.
///
/// Building the compressed index costs a fixed amount up front: the codec's
/// rotation setup plus encoding every vector. For small projects exact BLAS
/// ranking is already fast, so skipping compression there avoids paying that
/// setup cost and the long index-build pauses that come with it.
const TURBOQUANT_MIN_CHUNKS: usize = 1 << 12;
23
24/// Pre-computed embedding matrix for fast re-ranking.
25///
26/// Stores all chunk embeddings as a contiguous `[num_chunks, hidden_dim]`
27/// ndarray matrix. Re-ranking is a single BLAS matrix-vector multiply.
28///
29/// When constructed with a `cascade_dim`, also stores a truncated and
30/// re-normalized `[num_chunks, cascade_dim]` matrix for two-phase MRL
31/// cascade search: fast pre-filter at reduced dimension, then full-dim
32/// re-rank of the top candidates.
33pub struct SearchIndex {
34    /// All chunks with metadata.
35    pub chunks: Vec<CodeChunk>,
36    /// Embedding matrix `[num_chunks, hidden_dim]`.
37    embeddings: Array2<f32>,
38    /// Truncated + re-normalized embedding matrix for MRL cascade pre-filter.
39    /// `None` when cascade search is disabled.
40    truncated: Option<Array2<f32>>,
41    /// `TurboQuant`-compressed embeddings for fast approximate scanning.
42    /// At 4-bit: 386 bytes/vector vs 3072 bytes FP32 (8× compression).
43    /// Scan is ~5× faster than FP32 BLAS at 100K+ chunks.
44    compressed: Option<CompressedIndex>,
45    /// Hidden dimension size.
46    pub hidden_dim: usize,
47    /// Truncated dimension size, if cascade search is enabled.
48    truncated_dim: Option<usize>,
49}
50
51/// `PolarQuant`-compressed embedding index for fast approximate scanning.
52///
53/// Uses SoA flat layout ([`CompressedCorpus`]) for cache-friendly streaming scans.
54struct CompressedIndex {
55    /// The codec (holds rotation matrix + centroid tables).
56    codec: PolarCodec,
57    /// Flat SoA corpus: radii + indices packed contiguously.
58    corpus: CompressedCorpus,
59}
60
61impl SearchIndex {
62    /// Build an index from `embed_all` output.
63    ///
64    /// Flattens the per-chunk embedding vectors into a contiguous `Array2`
65    /// for BLAS-accelerated matrix-vector products at query time.
66    ///
67    /// When `cascade_dim` is `Some(d)`, also builds a truncated and
68    /// L2-re-normalized `[N, d]` matrix for two-phase MRL cascade search.
69    /// The truncated dimension is clamped to `hidden_dim`.
70    ///
71    /// # Panics
72    ///
73    /// Panics if the flattened embedding data cannot form a valid
74    /// `[num_chunks, hidden_dim]` matrix (should never happen when
75    /// embeddings come from `embed_all`).
76    pub fn new(
77        chunks: Vec<CodeChunk>,
78        raw_embeddings: &[Vec<f32>],
79        cascade_dim: Option<usize>,
80    ) -> Self {
81        let hidden_dim = raw_embeddings.first().map_or(384, Vec::len);
82        let n = chunks.len();
83
84        // Flatten into contiguous array for BLAS
85        let mut flat = Vec::with_capacity(n * hidden_dim);
86        for emb in raw_embeddings {
87            if emb.len() == hidden_dim {
88                flat.extend_from_slice(emb);
89            } else {
90                // Pad/truncate to hidden_dim (shouldn't happen, but be safe)
91                flat.extend(emb.iter().take(hidden_dim));
92                flat.resize(flat.len() + hidden_dim.saturating_sub(emb.len()), 0.0);
93            }
94        }
95
96        let embeddings =
97            Array2::from_shape_vec((n, hidden_dim), flat).expect("embedding matrix shape mismatch");
98
99        // Build truncated + re-normalized matrix for MRL cascade pre-filter.
100        // Nomic MRL models require layer-norm before truncation:
101        //   1. Layer-norm over the FULL embedding (mean-center, scale by inv_std)
102        //   2. Truncate to first d dimensions
103        //   3. L2 renormalize the truncated slice
104        let truncated_dim = cascade_dim.map(|d| d.min(hidden_dim));
105        let truncated = truncated_dim.map(|d| {
106            let mut trunc = Array2::zeros((n, d));
107            for (i, row) in embeddings.rows().into_iter().enumerate() {
108                let full = row.as_slice().expect("embedding row not contiguous");
109
110                // Step 1: Layer-norm over FULL embedding
111                let len = full.len() as f32;
112                let mean: f32 = full.iter().sum::<f32>() / len;
113                let var: f32 = full.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len;
114                let inv_std = 1.0 / (var + 1e-5).sqrt();
115
116                // Step 2: Truncate first d dims of layer-normed embedding
117                // Step 3: L2 renormalize the truncated slice
118                let norm: f32 = full[..d]
119                    .iter()
120                    .map(|x| {
121                        let ln = (x - mean) * inv_std;
122                        ln * ln
123                    })
124                    .sum::<f32>()
125                    .sqrt()
126                    .max(1e-12);
127                for (j, &v) in full[..d].iter().enumerate() {
128                    trunc[[i, j]] = (v - mean) * inv_std / norm;
129                }
130            }
131            trunc
132        });
133
134        // Compress embeddings with PolarQuant (4-bit) only when the corpus is
135        // large enough to amortize the fixed rotation/encoding setup cost.
136        // At 768-dim: ~1920 bytes/vector vs 3072 FP32. 8× compression with bit-packing.
137        let compressed =
138            if n >= TURBOQUANT_MIN_CHUNKS && hidden_dim >= 64 && hidden_dim.is_multiple_of(2) {
139                tracing::info!(
140                    chunks = n,
141                    hidden_dim,
142                    "building TurboQuant compressed index"
143                );
144                let codec = PolarCodec::new(hidden_dim, 4, 42);
145                let corpus = codec.encode_batch(&embeddings);
146                Some(CompressedIndex { codec, corpus })
147            } else {
148                tracing::debug!(
149                    chunks = n,
150                    hidden_dim,
151                    min_chunks = TURBOQUANT_MIN_CHUNKS,
152                    "skipping TurboQuant compression for small corpus"
153                );
154                None
155            };
156
157        Self {
158            chunks,
159            embeddings,
160            truncated,
161            compressed,
162            hidden_dim,
163            truncated_dim,
164        }
165    }
166
167    /// Rank all chunks against a query embedding.
168    ///
169    /// Returns `(chunk_index, similarity_score)` pairs sorted by descending
170    /// score, filtered by `threshold`.
171    #[must_use]
172    pub fn rank(&self, query_embedding: &[f32], threshold: f32) -> Vec<(usize, f32)> {
173        if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
174            return vec![];
175        }
176        let query = Array1::from_vec(query_embedding.to_vec());
177        let scores = crate::similarity::rank_all(&self.embeddings, &query);
178
179        let mut results: Vec<(usize, f32)> = scores
180            .into_iter()
181            .enumerate()
182            .filter(|(_, score)| *score >= threshold)
183            .collect();
184        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
185        results
186    }
187
188    /// `TurboQuant`-accelerated ranking: compressed approximate scan → exact re-rank.
189    ///
190    /// 1. Estimate inner products for ALL vectors via `TurboQuant` (~5× faster than BLAS).
191    /// 2. Take top `pre_filter_k` approximate candidates.
192    /// 3. Re-rank with exact FP32 dot products on the full embedding matrix.
193    ///
194    /// Falls back to [`Self::rank`] when no compressed index is available.
195    #[must_use]
196    pub fn rank_turboquant(
197        &self,
198        query_embedding: &[f32],
199        top_k: usize,
200        threshold: f32,
201    ) -> Vec<(usize, f32)> {
202        let Some(ref comp) = self.compressed else {
203            return self.rank(query_embedding, threshold);
204        };
205
206        if comp.corpus.n != self.chunks.len() {
207            return self.rank(query_embedding, threshold);
208        }
209
210        // Phase 1: SoA corpus scan — sequential streaming, centroid table in L1.
211        // `saturating_mul` guards against overflow when a caller passes a huge
212        // top_k as a "no limit" sentinel; `.min(corpus.n)` caps to the corpus
213        // size either way.
214        let pre_filter_k = top_k.saturating_mul(10).min(comp.corpus.n);
215        let query_state = comp.codec.prepare_query(query_embedding);
216        let scores = comp.codec.scan_corpus(&comp.corpus, &query_state);
217        let mut approx_scores: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
218        approx_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
219        approx_scores.truncate(pre_filter_k);
220
221        // Phase 2: exact re-rank top candidates
222        let query = Array1::from_vec(query_embedding.to_vec());
223        let mut results: Vec<(usize, f32)> = approx_scores
224            .iter()
225            .map(|&(idx, _)| {
226                let exact = self.embeddings.row(idx).dot(&query);
227                (idx, exact)
228            })
229            .filter(|(_, score)| *score >= threshold)
230            .collect();
231        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
232        results.truncate(top_k);
233        results
234    }
235
236    /// Two-phase MRL cascade ranking: fast pre-filter then full re-rank.
237    ///
238    /// 1. Layer-norms the query over its full dimension, truncates to
239    ///    `truncated_dim`, L2-normalizes, and computes dot products against
240    ///    the truncated matrix to find the top `pre_filter_k` candidates.
241    /// 2. Re-ranks those candidates using full-dimension dot products.
242    ///
243    /// Falls back to [`Self::rank`] when no truncated matrix is available.
244    #[must_use]
245    pub fn rank_cascade(
246        &self,
247        query_embedding: &[f32],
248        top_k: usize,
249        threshold: f32,
250    ) -> Vec<(usize, f32)> {
251        let Some(ref trunc_matrix) = self.truncated else {
252            return self.rank(query_embedding, threshold);
253        };
254        if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
255            return vec![];
256        }
257
258        let trunc_dim = trunc_matrix.shape()[1];
259        let pre_filter_k = 100_usize.max(top_k * 3); // over-retrieve for re-ranking
260
261        // Phase 1: fast pre-filter at truncated dimension
262        // Apply layer-norm over full query before truncation (matches corpus processing)
263        let len = query_embedding.len() as f32;
264        let mean: f32 = query_embedding.iter().sum::<f32>() / len;
265        let var: f32 = query_embedding
266            .iter()
267            .map(|x| (x - mean).powi(2))
268            .sum::<f32>()
269            / len;
270        let inv_std = 1.0 / (var + 1e-5).sqrt();
271        let trunc_query: Vec<f32> = query_embedding[..trunc_dim]
272            .iter()
273            .map(|x| (x - mean) * inv_std)
274            .collect();
275        let norm: f32 = trunc_query
276            .iter()
277            .map(|x| x * x)
278            .sum::<f32>()
279            .sqrt()
280            .max(1e-12);
281        let trunc_query_norm: Vec<f32> = trunc_query.iter().map(|x| x / norm).collect();
282        let trunc_q = Array1::from_vec(trunc_query_norm);
283        let scores = trunc_matrix.dot(&trunc_q);
284
285        // Get top pre_filter_k indices
286        let mut candidates: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
287        candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
288        candidates.truncate(pre_filter_k);
289
290        // Phase 2: re-rank candidates with full-dimension dot products
291        let query_arr = Array1::from_vec(query_embedding.to_vec());
292        let mut reranked: Vec<(usize, f32)> = candidates
293            .into_iter()
294            .map(|(idx, _)| {
295                let full_score = self.embeddings.row(idx).dot(&query_arr);
296                (idx, full_score)
297            })
298            .filter(|(_, s)| *s >= threshold)
299            .collect();
300        reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
301        reranked.truncate(top_k);
302        reranked
303    }
304
305    /// Return a clone of the embedding vector for chunk `idx`.
306    ///
307    /// Returns `None` if `idx` is out of bounds.
308    #[must_use]
309    pub fn embedding(&self, idx: usize) -> Option<Vec<f32>> {
310        if idx >= self.chunks.len() {
311            return None;
312        }
313        Some(self.embeddings.row(idx).to_vec())
314    }
315
316    /// Find duplicate or near-duplicate chunks by pairwise cosine similarity.
317    ///
318    /// Computes `embeddings @ embeddings.T` (a single BLAS GEMM) to get all
319    /// pairwise similarities, then extracts pairs above `threshold` from the
320    /// upper triangle (avoiding self-matches and symmetric duplicates).
321    ///
322    /// Returns `(chunk_a, chunk_b, similarity)` sorted by descending similarity.
323    /// Each pair appears only once (a < b).
324    #[must_use]
325    pub fn find_duplicates(&self, threshold: f32, max_pairs: usize) -> Vec<(usize, usize, f32)> {
326        let n = self.chunks.len();
327        if n < 2 {
328            return vec![];
329        }
330
331        // Single GEMM: [n, dim] × [dim, n] = [n, n] pairwise similarity matrix
332        let sim_matrix = self.embeddings.dot(&self.embeddings.t());
333
334        // Scan upper triangle for pairs above threshold
335        let mut pairs: Vec<(usize, usize, f32)> = Vec::new();
336        for i in 0..n {
337            for j in (i + 1)..n {
338                let score = sim_matrix[[i, j]];
339                if score >= threshold {
340                    pairs.push((i, j, score));
341                }
342            }
343        }
344
345        pairs.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
346        pairs.truncate(max_pairs);
347        pairs
348    }
349
350    /// Number of chunks in the index.
351    #[must_use]
352    pub fn len(&self) -> usize {
353        self.chunks.len()
354    }
355
356    /// Whether the index is empty.
357    #[must_use]
358    pub fn is_empty(&self) -> bool {
359        self.chunks.is_empty()
360    }
361
362    /// The truncated dimension used for cascade pre-filtering, if enabled.
363    #[must_use]
364    pub fn truncated_dim(&self) -> Option<usize> {
365        self.truncated_dim
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to create a dummy `CodeChunk` for testing.
    fn dummy_chunk(name: &str) -> CodeChunk {
        let content = format!("fn {name}() {{}}");
        CodeChunk {
            file_path: "test.rs".to_string(),
            name: name.to_string(),
            kind: "function".to_string(),
            start_line: 1,
            end_line: 10,
            enriched_content: content.clone(),
            content,
        }
    }

    #[test]
    fn new_builds_correct_matrix_shape() {
        let chunks = vec![dummy_chunk("a"), dummy_chunk("b"), dummy_chunk("c")];
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SearchIndex::new(chunks, &embeddings, None);

        assert_eq!(index.len(), 3);
        assert_eq!(index.hidden_dim, 3);
        assert!(!index.is_empty());
    }

    #[test]
    fn small_corpus_skips_turboquant_compression() {
        let chunks = vec![dummy_chunk("a"), dummy_chunk("b")];
        let embeddings = vec![vec![0.0; 768], vec![1.0; 768]];

        let index = SearchIndex::new(chunks, &embeddings, None);

        assert!(index.compressed.is_none());
    }

    #[test]
    fn rank_returns_sorted_results_above_threshold() {
        let chunks = vec![dummy_chunk("low"), dummy_chunk("high"), dummy_chunk("mid")];
        // Embeddings designed so dot product with [1, 0] gives known scores:
        // chunk 0: 0.2, chunk 1: 0.9, chunk 2: 0.5
        let embeddings = vec![vec![0.2, 0.8], vec![0.9, 0.1], vec![0.5, 0.5]];

        let index = SearchIndex::new(chunks, &embeddings, None);
        let results = index.rank(&[1.0, 0.0], 0.3);

        // Should exclude chunk 0 (score 0.2 < threshold 0.3)
        assert_eq!(results.len(), 2);
        // Should be sorted descending: chunk 1 (0.9), then chunk 2 (0.5)
        assert_eq!(results[0].0, 1);
        assert_eq!(results[1].0, 2);
        assert!(results[0].1 > results[1].1);
    }

    #[test]
    fn rank_with_wrong_dimension_returns_empty() {
        let chunks = vec![dummy_chunk("a")];
        let embeddings = vec![vec![1.0, 0.0, 0.0]];

        let index = SearchIndex::new(chunks, &embeddings, None);
        // Query has wrong dimension (2 instead of 3)
        let results = index.rank(&[1.0, 0.0], 0.0);

        assert!(results.is_empty());
    }

    #[test]
    fn rank_with_empty_query_returns_empty() {
        let chunks = vec![dummy_chunk("a")];
        let embeddings = vec![vec![1.0, 0.0, 0.0]];

        let index = SearchIndex::new(chunks, &embeddings, None);
        let results = index.rank(&[], 0.0);

        assert!(results.is_empty());
    }

    #[test]
    fn rank_handles_empty_index() {
        let index = SearchIndex::new(vec![], &[], None);

        // hidden_dim defaults to 384 for empty input
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);

        let results = index.rank(&[1.0; 384], 0.0);
        assert!(results.is_empty());
    }

    /// L2-normalize a vector in-place.
    fn l2_normalize(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        for x in v.iter_mut() {
            *x /= norm;
        }
    }

    #[test]
    #[expect(
        clippy::cast_precision_loss,
        reason = "test values are small counts and indices"
    )]
    fn cascade_recall_at_10_vs_full_rank() {
        // Build 200 chunks with 8-dim random-ish embeddings (L2-normalized).
        // Use a deterministic pattern so the test is reproducible.
        let n = 200;
        let dim = 8;
        let cascade_dim = 4;

        let mut chunks = Vec::with_capacity(n);
        let mut embeddings = Vec::with_capacity(n);
        for i in 0..n {
            chunks.push(dummy_chunk(&format!("chunk_{i}")));
            // Deterministic pseudo-random: use sin/cos of index
            let mut emb: Vec<f32> = (0..dim).map(|d| ((i * 7 + d * 13) as f32).sin()).collect();
            l2_normalize(&mut emb);
            embeddings.push(emb);
        }

        // Query: L2-normalized
        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        l2_normalize(&mut query);

        // Build index without cascade (reference)
        let index_full = SearchIndex::new(chunks.clone(), &embeddings, None);
        let full_results = index_full.rank(&query, 0.0);
        let full_top10: Vec<usize> = full_results.iter().take(10).map(|(idx, _)| *idx).collect();

        // Build index with cascade
        let index_cascade = SearchIndex::new(chunks, &embeddings, Some(cascade_dim));
        assert_eq!(index_cascade.truncated_dim(), Some(cascade_dim));
        let cascade_results = index_cascade.rank_cascade(&query, 10, 0.0);
        let cascade_top10: Vec<usize> = cascade_results.iter().map(|(idx, _)| *idx).collect();

        // Recall@10: how many of full-dim top-10 appear in cascade top-10
        let overlap = full_top10
            .iter()
            .filter(|i| cascade_top10.contains(i))
            .count();
        let recall = overlap as f32 / 10.0;

        assert!(
            recall >= 0.7,
            "cascade Recall@10 = {recall} ({overlap}/10), expected >= 0.7"
        );
    }

    #[test]
    fn cascade_falls_back_without_truncated_matrix() {
        let chunks = vec![dummy_chunk("a"), dummy_chunk("b")];
        let embeddings = vec![vec![0.9, 0.1], vec![0.1, 0.9]];

        // No cascade_dim → rank_cascade should behave like rank
        let index = SearchIndex::new(chunks, &embeddings, None);
        let cascade = index.rank_cascade(&[1.0, 0.0], 10, 0.0);
        let plain = index.rank(&[1.0, 0.0], 0.0);

        assert_eq!(cascade.len(), plain.len());
        for (c, p) in cascade.iter().zip(plain.iter()) {
            assert_eq!(c.0, p.0);
            assert!((c.1 - p.1).abs() < 1e-6);
        }
    }

    #[test]
    fn cascade_respects_threshold() {
        let chunks = vec![dummy_chunk("high"), dummy_chunk("low")];
        // Embeddings: chunk 0 aligns with query, chunk 1 is orthogonal
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let index = SearchIndex::new(chunks, &embeddings, Some(1));
        let results = index.rank_cascade(&[1.0, 0.0], 10, 0.5);

        // Only chunk 0 should pass the 0.5 threshold
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn find_duplicates_returns_each_pair_once_sorted() {
        let chunks = vec![
            dummy_chunk("a"),
            dummy_chunk("b"),
            dummy_chunk("c"),
            dummy_chunk("d"),
        ];
        // a and b are identical (1.0), c is close to both (0.8), and d is
        // below threshold with everyone (0.0 with a/b, 0.6 with c).
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![1.0, 0.0],
            vec![0.8, 0.6],
            vec![0.0, 1.0],
        ];
        let index = SearchIndex::new(chunks, &embeddings, None);

        let pairs = index.find_duplicates(0.7, 10);
        assert_eq!(pairs.len(), 3);
        assert_eq!((pairs[0].0, pairs[0].1), (0, 1));
        assert!(pairs.iter().all(|&(a, b, _)| a < b));
        assert!(pairs.windows(2).all(|w| w[0].2 >= w[1].2));
        let mut ids: Vec<(usize, usize)> = pairs.iter().map(|&(a, b, _)| (a, b)).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![(0, 1), (0, 2), (1, 2)]);

        // `max_pairs` keeps only the strongest pair.
        let top = index.find_duplicates(0.7, 1);
        assert_eq!(top.len(), 1);
        assert_eq!((top[0].0, top[0].1), (0, 1));

        // Fewer than two chunks can never form a pair.
        let single = SearchIndex::new(vec![dummy_chunk("x")], &[vec![1.0, 0.0]], None);
        assert!(single.find_duplicates(0.0, 10).is_empty());
    }

    #[test]
    #[expect(
        clippy::cast_precision_loss,
        reason = "test values are small counts and indices"
    )]
    fn turboquant_recall_vs_exact() {
        // Generate 200 deterministic 768-dim L2-normalized embeddings.
        let dim = 768;
        let n = 200;
        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                let mut v: Vec<f32> = (0..dim).map(|d| ((i * 17 + d * 31) as f32).sin()).collect();
                l2_normalize(&mut v);
                v
            })
            .collect();

        let chunks: Vec<CodeChunk> = (0..n).map(|i| dummy_chunk(&format!("chunk_{i}"))).collect();
        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        l2_normalize(&mut query);

        let mut index = SearchIndex::new(chunks, &embeddings, None);

        // 200 chunks is below TURBOQUANT_MIN_CHUNKS, so `new` skips compression.
        // Left as-is, `rank_turboquant` would silently fall back to exact `rank`
        // and this test would compare `rank` with itself. Build the compressed
        // index by hand with the same codec parameters `new` uses, so the
        // approximate scan path is what's actually being measured.
        assert!(index.compressed.is_none());
        let codec = PolarCodec::new(dim, 4, 42);
        let corpus = codec.encode_batch(&index.embeddings);
        index.compressed = Some(CompressedIndex { codec, corpus });

        // Exact ranking
        let exact = index.rank(&query, 0.0);
        let exact_top10: Vec<usize> = exact.iter().take(10).map(|(idx, _)| *idx).collect();

        // TurboQuant ranking
        let tq = index.rank_turboquant(&query, 10, 0.0);
        let tq_top10: Vec<usize> = tq.iter().take(10).map(|(idx, _)| *idx).collect();

        // Recall@10: how many of exact top-10 appear in TQ top-10
        let recall = exact_top10.iter().filter(|i| tq_top10.contains(i)).count();
        eprintln!("TurboQuant Recall@10: {recall}/10");
        assert!(
            recall >= 7,
            "TurboQuant recall should be >= 7/10, got {recall}/10"
        );
    }
}