Skip to main content

ripvec_core/
index.rs

1//! In-memory search index for real-time re-ranking.
2//!
3//! Stores all chunk embeddings as a contiguous ndarray matrix so that
4//! re-ranking is a single BLAS matrix-vector multiply via [`crate::similarity::rank_all`].
5//!
6//! Optionally uses [`TurboQuant`](crate::turbo_quant::PolarCodec) compression for fast approximate
7//! scanning at monorepo scale (100K+ chunks). `TurboQuant` compresses 768-dim
8//! embeddings from 3072 bytes (FP32) to ~386 bytes (4-bit), giving ~5× faster
9//! scan via sequential memory access + centroid table lookup.
10
11use ndarray::{Array1, Array2};
12
13use crate::chunk::CodeChunk;
14use crate::turbo_quant::{CompressedCorpus, PolarCodec};
15
16/// Pre-computed embedding matrix for fast re-ranking.
17///
18/// Stores all chunk embeddings as a contiguous `[num_chunks, hidden_dim]`
19/// ndarray matrix. Re-ranking is a single BLAS matrix-vector multiply.
20///
21/// When constructed with a `cascade_dim`, also stores a truncated and
22/// re-normalized `[num_chunks, cascade_dim]` matrix for two-phase MRL
23/// cascade search: fast pre-filter at reduced dimension, then full-dim
24/// re-rank of the top candidates.
25pub struct SearchIndex {
26    /// All chunks with metadata.
27    pub chunks: Vec<CodeChunk>,
28    /// Embedding matrix `[num_chunks, hidden_dim]`.
29    embeddings: Array2<f32>,
30    /// Truncated + re-normalized embedding matrix for MRL cascade pre-filter.
31    /// `None` when cascade search is disabled.
32    truncated: Option<Array2<f32>>,
33    /// `TurboQuant`-compressed embeddings for fast approximate scanning.
34    /// At 4-bit: 386 bytes/vector vs 3072 bytes FP32 (8× compression).
35    /// Scan is ~5× faster than FP32 BLAS at 100K+ chunks.
36    compressed: Option<CompressedIndex>,
37    /// Hidden dimension size.
38    pub hidden_dim: usize,
39    /// Truncated dimension size, if cascade search is enabled.
40    truncated_dim: Option<usize>,
41}
42
43/// `PolarQuant`-compressed embedding index for fast approximate scanning.
44///
45/// Uses SoA flat layout ([`CompressedCorpus`]) for cache-friendly streaming scans.
46struct CompressedIndex {
47    /// The codec (holds rotation matrix + centroid tables).
48    codec: PolarCodec,
49    /// Flat SoA corpus: radii + indices packed contiguously.
50    corpus: CompressedCorpus,
51}
52
53impl SearchIndex {
54    /// Build an index from `embed_all` output.
55    ///
56    /// Flattens the per-chunk embedding vectors into a contiguous `Array2`
57    /// for BLAS-accelerated matrix-vector products at query time.
58    ///
59    /// When `cascade_dim` is `Some(d)`, also builds a truncated and
60    /// L2-re-normalized `[N, d]` matrix for two-phase MRL cascade search.
61    /// The truncated dimension is clamped to `hidden_dim`.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the flattened embedding data cannot form a valid
66    /// `[num_chunks, hidden_dim]` matrix (should never happen when
67    /// embeddings come from `embed_all`).
68    pub fn new(
69        chunks: Vec<CodeChunk>,
70        raw_embeddings: &[Vec<f32>],
71        cascade_dim: Option<usize>,
72    ) -> Self {
73        let hidden_dim = raw_embeddings.first().map_or(384, Vec::len);
74        let n = chunks.len();
75
76        // Flatten into contiguous array for BLAS
77        let mut flat = Vec::with_capacity(n * hidden_dim);
78        for emb in raw_embeddings {
79            if emb.len() == hidden_dim {
80                flat.extend_from_slice(emb);
81            } else {
82                // Pad/truncate to hidden_dim (shouldn't happen, but be safe)
83                flat.extend(emb.iter().take(hidden_dim));
84                flat.resize(flat.len() + hidden_dim.saturating_sub(emb.len()), 0.0);
85            }
86        }
87
88        let embeddings =
89            Array2::from_shape_vec((n, hidden_dim), flat).expect("embedding matrix shape mismatch");
90
91        // Build truncated + re-normalized matrix for MRL cascade pre-filter.
92        // Nomic MRL models require layer-norm before truncation:
93        //   1. Layer-norm over the FULL embedding (mean-center, scale by inv_std)
94        //   2. Truncate to first d dimensions
95        //   3. L2 renormalize the truncated slice
96        let truncated_dim = cascade_dim.map(|d| d.min(hidden_dim));
97        let truncated = truncated_dim.map(|d| {
98            let mut trunc = Array2::zeros((n, d));
99            for (i, row) in embeddings.rows().into_iter().enumerate() {
100                let full = row.as_slice().expect("embedding row not contiguous");
101
102                // Step 1: Layer-norm over FULL embedding
103                let len = full.len() as f32;
104                let mean: f32 = full.iter().sum::<f32>() / len;
105                let var: f32 = full.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len;
106                let inv_std = 1.0 / (var + 1e-5).sqrt();
107
108                // Step 2: Truncate first d dims of layer-normed embedding
109                // Step 3: L2 renormalize the truncated slice
110                let norm: f32 = full[..d]
111                    .iter()
112                    .map(|x| {
113                        let ln = (x - mean) * inv_std;
114                        ln * ln
115                    })
116                    .sum::<f32>()
117                    .sqrt()
118                    .max(1e-12);
119                for (j, &v) in full[..d].iter().enumerate() {
120                    trunc[[i, j]] = (v - mean) * inv_std / norm;
121                }
122            }
123            trunc
124        });
125
126        // Compress embeddings with PolarQuant (4-bit).
127        // At 768-dim: ~1920 bytes/vector vs 3072 FP32. 8× compression with bit-packing.
128        let compressed = if hidden_dim >= 64 && hidden_dim.is_multiple_of(2) {
129            let codec = PolarCodec::new(hidden_dim, 4, 42);
130            let corpus = codec.encode_batch(&embeddings);
131            Some(CompressedIndex { codec, corpus })
132        } else {
133            None
134        };
135
136        Self {
137            chunks,
138            embeddings,
139            truncated,
140            compressed,
141            hidden_dim,
142            truncated_dim,
143        }
144    }
145
146    /// Rank all chunks against a query embedding.
147    ///
148    /// Returns `(chunk_index, similarity_score)` pairs sorted by descending
149    /// score, filtered by `threshold`.
150    #[must_use]
151    pub fn rank(&self, query_embedding: &[f32], threshold: f32) -> Vec<(usize, f32)> {
152        if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
153            return vec![];
154        }
155        let query = Array1::from_vec(query_embedding.to_vec());
156        let scores = crate::similarity::rank_all(&self.embeddings, &query);
157
158        let mut results: Vec<(usize, f32)> = scores
159            .into_iter()
160            .enumerate()
161            .filter(|(_, score)| *score >= threshold)
162            .collect();
163        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
164        results
165    }
166
167    /// `TurboQuant`-accelerated ranking: compressed approximate scan → exact re-rank.
168    ///
169    /// 1. Estimate inner products for ALL vectors via `TurboQuant` (~5× faster than BLAS).
170    /// 2. Take top `pre_filter_k` approximate candidates.
171    /// 3. Re-rank with exact FP32 dot products on the full embedding matrix.
172    ///
173    /// Falls back to [`Self::rank`] when no compressed index is available.
174    #[must_use]
175    pub fn rank_turboquant(
176        &self,
177        query_embedding: &[f32],
178        top_k: usize,
179        threshold: f32,
180    ) -> Vec<(usize, f32)> {
181        let Some(ref comp) = self.compressed else {
182            return self.rank(query_embedding, threshold);
183        };
184
185        if comp.corpus.n != self.chunks.len() {
186            return self.rank(query_embedding, threshold);
187        }
188
189        // Phase 1: SoA corpus scan — sequential streaming, centroid table in L1.
190        // `saturating_mul` guards against overflow when a caller passes a huge
191        // top_k as a "no limit" sentinel; `.min(corpus.n)` caps to the corpus
192        // size either way.
193        let pre_filter_k = top_k.saturating_mul(10).min(comp.corpus.n);
194        let query_state = comp.codec.prepare_query(query_embedding);
195        let scores = comp.codec.scan_corpus(&comp.corpus, &query_state);
196        let mut approx_scores: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
197        approx_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
198        approx_scores.truncate(pre_filter_k);
199
200        // Phase 2: exact re-rank top candidates
201        let query = Array1::from_vec(query_embedding.to_vec());
202        let mut results: Vec<(usize, f32)> = approx_scores
203            .iter()
204            .map(|&(idx, _)| {
205                let exact = self.embeddings.row(idx).dot(&query);
206                (idx, exact)
207            })
208            .filter(|(_, score)| *score >= threshold)
209            .collect();
210        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
211        results.truncate(top_k);
212        results
213    }
214
215    /// Two-phase MRL cascade ranking: fast pre-filter then full re-rank.
216    ///
217    /// 1. Layer-norms the query over its full dimension, truncates to
218    ///    `truncated_dim`, L2-normalizes, and computes dot products against
219    ///    the truncated matrix to find the top `pre_filter_k` candidates.
220    /// 2. Re-ranks those candidates using full-dimension dot products.
221    ///
222    /// Falls back to [`Self::rank`] when no truncated matrix is available.
223    #[must_use]
224    pub fn rank_cascade(
225        &self,
226        query_embedding: &[f32],
227        top_k: usize,
228        threshold: f32,
229    ) -> Vec<(usize, f32)> {
230        let Some(ref trunc_matrix) = self.truncated else {
231            return self.rank(query_embedding, threshold);
232        };
233        if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
234            return vec![];
235        }
236
237        let trunc_dim = trunc_matrix.shape()[1];
238        let pre_filter_k = 100_usize.max(top_k * 3); // over-retrieve for re-ranking
239
240        // Phase 1: fast pre-filter at truncated dimension
241        // Apply layer-norm over full query before truncation (matches corpus processing)
242        let len = query_embedding.len() as f32;
243        let mean: f32 = query_embedding.iter().sum::<f32>() / len;
244        let var: f32 = query_embedding
245            .iter()
246            .map(|x| (x - mean).powi(2))
247            .sum::<f32>()
248            / len;
249        let inv_std = 1.0 / (var + 1e-5).sqrt();
250        let trunc_query: Vec<f32> = query_embedding[..trunc_dim]
251            .iter()
252            .map(|x| (x - mean) * inv_std)
253            .collect();
254        let norm: f32 = trunc_query
255            .iter()
256            .map(|x| x * x)
257            .sum::<f32>()
258            .sqrt()
259            .max(1e-12);
260        let trunc_query_norm: Vec<f32> = trunc_query.iter().map(|x| x / norm).collect();
261        let trunc_q = Array1::from_vec(trunc_query_norm);
262        let scores = trunc_matrix.dot(&trunc_q);
263
264        // Get top pre_filter_k indices
265        let mut candidates: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
266        candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
267        candidates.truncate(pre_filter_k);
268
269        // Phase 2: re-rank candidates with full-dimension dot products
270        let query_arr = Array1::from_vec(query_embedding.to_vec());
271        let mut reranked: Vec<(usize, f32)> = candidates
272            .into_iter()
273            .map(|(idx, _)| {
274                let full_score = self.embeddings.row(idx).dot(&query_arr);
275                (idx, full_score)
276            })
277            .filter(|(_, s)| *s >= threshold)
278            .collect();
279        reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
280        reranked.truncate(top_k);
281        reranked
282    }
283
284    /// Return a clone of the embedding vector for chunk `idx`.
285    ///
286    /// Returns `None` if `idx` is out of bounds.
287    #[must_use]
288    pub fn embedding(&self, idx: usize) -> Option<Vec<f32>> {
289        if idx >= self.chunks.len() {
290            return None;
291        }
292        Some(self.embeddings.row(idx).to_vec())
293    }
294
295    /// Find duplicate or near-duplicate chunks by pairwise cosine similarity.
296    ///
297    /// Computes `embeddings @ embeddings.T` (a single BLAS GEMM) to get all
298    /// pairwise similarities, then extracts pairs above `threshold` from the
299    /// upper triangle (avoiding self-matches and symmetric duplicates).
300    ///
301    /// Returns `(chunk_a, chunk_b, similarity)` sorted by descending similarity.
302    /// Each pair appears only once (a < b).
303    #[must_use]
304    pub fn find_duplicates(&self, threshold: f32, max_pairs: usize) -> Vec<(usize, usize, f32)> {
305        let n = self.chunks.len();
306        if n < 2 {
307            return vec![];
308        }
309
310        // Single GEMM: [n, dim] × [dim, n] = [n, n] pairwise similarity matrix
311        let sim_matrix = self.embeddings.dot(&self.embeddings.t());
312
313        // Scan upper triangle for pairs above threshold
314        let mut pairs: Vec<(usize, usize, f32)> = Vec::new();
315        for i in 0..n {
316            for j in (i + 1)..n {
317                let score = sim_matrix[[i, j]];
318                if score >= threshold {
319                    pairs.push((i, j, score));
320                }
321            }
322        }
323
324        pairs.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
325        pairs.truncate(max_pairs);
326        pairs
327    }
328
329    /// Number of chunks in the index.
330    #[must_use]
331    pub fn len(&self) -> usize {
332        self.chunks.len()
333    }
334
335    /// Whether the index is empty.
336    #[must_use]
337    pub fn is_empty(&self) -> bool {
338        self.chunks.is_empty()
339    }
340
341    /// The truncated dimension used for cascade pre-filtering, if enabled.
342    #[must_use]
343    pub fn truncated_dim(&self) -> Option<usize> {
344        self.truncated_dim
345    }
346}
347
348#[cfg(test)]
349mod tests {
350    use super::*;
351
352    /// Helper to create a dummy `CodeChunk` for testing.
353    fn dummy_chunk(name: &str) -> CodeChunk {
354        let content = format!("fn {name}() {{}}");
355        CodeChunk {
356            file_path: "test.rs".to_string(),
357            name: name.to_string(),
358            kind: "function".to_string(),
359            start_line: 1,
360            end_line: 10,
361            enriched_content: content.clone(),
362            content,
363        }
364    }
365
366    #[test]
367    fn new_builds_correct_matrix_shape() {
368        let chunks = vec![dummy_chunk("a"), dummy_chunk("b"), dummy_chunk("c")];
369        let embeddings = vec![
370            vec![1.0, 0.0, 0.0],
371            vec![0.0, 1.0, 0.0],
372            vec![0.0, 0.0, 1.0],
373        ];
374
375        let index = SearchIndex::new(chunks, &embeddings, None);
376
377        assert_eq!(index.len(), 3);
378        assert_eq!(index.hidden_dim, 3);
379        assert!(!index.is_empty());
380    }
381
382    #[test]
383    fn rank_returns_sorted_results_above_threshold() {
384        let chunks = vec![dummy_chunk("low"), dummy_chunk("high"), dummy_chunk("mid")];
385        // Embeddings designed so dot product with [1, 0] gives known scores:
386        // chunk 0: 0.2, chunk 1: 0.9, chunk 2: 0.5
387        let embeddings = vec![vec![0.2, 0.8], vec![0.9, 0.1], vec![0.5, 0.5]];
388
389        let index = SearchIndex::new(chunks, &embeddings, None);
390        let results = index.rank(&[1.0, 0.0], 0.3);
391
392        // Should exclude chunk 0 (score 0.2 < threshold 0.3)
393        assert_eq!(results.len(), 2);
394        // Should be sorted descending: chunk 1 (0.9), then chunk 2 (0.5)
395        assert_eq!(results[0].0, 1);
396        assert_eq!(results[1].0, 2);
397        assert!(results[0].1 > results[1].1);
398    }
399
400    #[test]
401    fn rank_with_wrong_dimension_returns_empty() {
402        let chunks = vec![dummy_chunk("a")];
403        let embeddings = vec![vec![1.0, 0.0, 0.0]];
404
405        let index = SearchIndex::new(chunks, &embeddings, None);
406        // Query has wrong dimension (2 instead of 3)
407        let results = index.rank(&[1.0, 0.0], 0.0);
408
409        assert!(results.is_empty());
410    }
411
412    #[test]
413    fn rank_with_empty_query_returns_empty() {
414        let chunks = vec![dummy_chunk("a")];
415        let embeddings = vec![vec![1.0, 0.0, 0.0]];
416
417        let index = SearchIndex::new(chunks, &embeddings, None);
418        let results = index.rank(&[], 0.0);
419
420        assert!(results.is_empty());
421    }
422
423    #[test]
424    fn rank_handles_empty_index() {
425        let index = SearchIndex::new(vec![], &[], None);
426
427        // hidden_dim defaults to 384 for empty input
428        assert!(index.is_empty());
429        assert_eq!(index.len(), 0);
430
431        let results = index.rank(&[1.0; 384], 0.0);
432        assert!(results.is_empty());
433    }
434
435    /// L2-normalize a vector in-place.
436    fn l2_normalize(v: &mut [f32]) {
437        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
438        for x in v.iter_mut() {
439            *x /= norm;
440        }
441    }
442
443    #[test]
444    #[expect(
445        clippy::cast_precision_loss,
446        reason = "test values are small counts and indices"
447    )]
448    fn cascade_recall_at_10_vs_full_rank() {
449        // Build 200 chunks with 8-dim random-ish embeddings (L2-normalized).
450        // Use a deterministic pattern so the test is reproducible.
451        let n = 200;
452        let dim = 8;
453        let cascade_dim = 4;
454
455        let mut chunks = Vec::with_capacity(n);
456        let mut embeddings = Vec::with_capacity(n);
457        for i in 0..n {
458            chunks.push(dummy_chunk(&format!("chunk_{i}")));
459            // Deterministic pseudo-random: use sin/cos of index
460            let mut emb: Vec<f32> = (0..dim).map(|d| ((i * 7 + d * 13) as f32).sin()).collect();
461            l2_normalize(&mut emb);
462            embeddings.push(emb);
463        }
464
465        // Query: L2-normalized
466        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
467        l2_normalize(&mut query);
468
469        // Build index without cascade (reference)
470        let index_full = SearchIndex::new(chunks.clone(), &embeddings, None);
471        let full_results = index_full.rank(&query, 0.0);
472        let full_top10: Vec<usize> = full_results.iter().take(10).map(|(idx, _)| *idx).collect();
473
474        // Build index with cascade
475        let index_cascade = SearchIndex::new(chunks, &embeddings, Some(cascade_dim));
476        assert_eq!(index_cascade.truncated_dim(), Some(cascade_dim));
477        let cascade_results = index_cascade.rank_cascade(&query, 10, 0.0);
478        let cascade_top10: Vec<usize> = cascade_results.iter().map(|(idx, _)| *idx).collect();
479
480        // Recall@10: how many of full-dim top-10 appear in cascade top-10
481        let overlap = full_top10
482            .iter()
483            .filter(|i| cascade_top10.contains(i))
484            .count();
485        let recall = overlap as f32 / 10.0;
486
487        assert!(
488            recall >= 0.7,
489            "cascade Recall@10 = {recall} ({overlap}/10), expected >= 0.7"
490        );
491    }
492
493    #[test]
494    fn cascade_falls_back_without_truncated_matrix() {
495        let chunks = vec![dummy_chunk("a"), dummy_chunk("b")];
496        let embeddings = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
497
498        // No cascade_dim → rank_cascade should behave like rank
499        let index = SearchIndex::new(chunks, &embeddings, None);
500        let cascade = index.rank_cascade(&[1.0, 0.0], 10, 0.0);
501        let plain = index.rank(&[1.0, 0.0], 0.0);
502
503        assert_eq!(cascade.len(), plain.len());
504        for (c, p) in cascade.iter().zip(plain.iter()) {
505            assert_eq!(c.0, p.0);
506            assert!((c.1 - p.1).abs() < 1e-6);
507        }
508    }
509
510    #[test]
511    fn cascade_respects_threshold() {
512        let chunks = vec![dummy_chunk("high"), dummy_chunk("low")];
513        // Embeddings: chunk 0 aligns with query, chunk 1 is orthogonal
514        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
515
516        let index = SearchIndex::new(chunks, &embeddings, Some(1));
517        let results = index.rank_cascade(&[1.0, 0.0], 10, 0.5);
518
519        // Only chunk 0 should pass the 0.5 threshold
520        assert_eq!(results.len(), 1);
521        assert_eq!(results[0].0, 0);
522    }
523
524    #[test]
525    fn turboquant_recall_vs_exact() {
526        // Generate 200 random 768-dim L2-normalized embeddings.
527        let dim = 768;
528        let n = 200;
529        let embeddings: Vec<Vec<f32>> = (0..n)
530            .map(|i| {
531                let mut v: Vec<f32> = (0..dim).map(|d| ((i * 17 + d * 31) as f32).sin()).collect();
532                l2_normalize(&mut v);
533                v
534            })
535            .collect();
536
537        let chunks: Vec<CodeChunk> = (0..n).map(|i| dummy_chunk(&format!("chunk_{i}"))).collect();
538        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
539        l2_normalize(&mut query);
540
541        let index = SearchIndex::new(chunks, &embeddings, None);
542
543        // Exact ranking
544        let exact = index.rank(&query, 0.0);
545        let exact_top10: Vec<usize> = exact.iter().take(10).map(|(idx, _)| *idx).collect();
546
547        // TurboQuant ranking
548        let tq = index.rank_turboquant(&query, 10, 0.0);
549        let tq_top10: Vec<usize> = tq.iter().take(10).map(|(idx, _)| *idx).collect();
550
551        // Recall@10: how many of exact top-10 appear in TQ top-10
552        let recall = exact_top10.iter().filter(|i| tq_top10.contains(i)).count();
553        eprintln!("TurboQuant Recall@10: {recall}/10");
554        assert!(
555            recall >= 7,
556            "TurboQuant recall should be >= 7/10, got {recall}/10"
557        );
558    }
559}