Skip to main content

ripvec_core/
index.rs

1//! In-memory search index for real-time re-ranking.
2//!
3//! Stores all chunk embeddings as a contiguous ndarray matrix so that
4//! re-ranking is a single BLAS matrix-vector multiply via [`crate::similarity::rank_all`].
5//!
//! Whenever the embedding dimension is even and at least 64, the index also keeps
//! a 4-bit [`TurboQuant`](crate::turbo_quant::PolarCodec) compressed copy of the
//! embeddings for fast approximate scanning at monorepo scale (100K+ chunks).
//! For 768-dim embeddings this shrinks each vector well below its 3072-byte FP32
//! size (the exact footprint depends on the codec's packing) and aims for a ~5×
//! faster scan via sequential memory access + centroid table lookup.
10
11use ndarray::{Array1, Array2};
12
13use crate::chunk::CodeChunk;
14use crate::turbo_quant::{CompressedCorpus, PolarCodec};
15
16/// Pre-computed embedding matrix for fast re-ranking.
17///
18/// Stores all chunk embeddings as a contiguous `[num_chunks, hidden_dim]`
19/// ndarray matrix. Re-ranking is a single BLAS matrix-vector multiply.
20///
21/// When constructed with a `cascade_dim`, also stores a truncated and
22/// re-normalized `[num_chunks, cascade_dim]` matrix for two-phase MRL
23/// cascade search: fast pre-filter at reduced dimension, then full-dim
24/// re-rank of the top candidates.
25pub struct SearchIndex {
26    /// All chunks with metadata.
27    pub chunks: Vec<CodeChunk>,
28    /// Embedding matrix `[num_chunks, hidden_dim]`.
29    embeddings: Array2<f32>,
30    /// Truncated + re-normalized embedding matrix for MRL cascade pre-filter.
31    /// `None` when cascade search is disabled.
32    truncated: Option<Array2<f32>>,
33    /// `TurboQuant`-compressed embeddings for fast approximate scanning.
34    /// At 4-bit: 386 bytes/vector vs 3072 bytes FP32 (8× compression).
35    /// Scan is ~5× faster than FP32 BLAS at 100K+ chunks.
36    compressed: Option<CompressedIndex>,
37    /// Hidden dimension size.
38    pub hidden_dim: usize,
39    /// Truncated dimension size, if cascade search is enabled.
40    truncated_dim: Option<usize>,
41}
42
/// `PolarQuant`-compressed embedding index for fast approximate scanning.
///
/// Pairs the codec that produced the encoding with the encoded corpus. The
/// corpus uses a flat SoA layout ([`CompressedCorpus`]) for cache-friendly
/// streaming scans.
struct CompressedIndex {
    /// The codec (holds rotation matrix + centroid tables). It is also used
    /// to prepare queries before scanning `corpus`.
    codec: PolarCodec,
    /// Flat SoA corpus: radii + indices packed contiguously. Its row count
    /// (`corpus.n`) is expected to match the number of indexed chunks.
    corpus: CompressedCorpus,
}
52
53impl SearchIndex {
54    /// Build an index from `embed_all` output.
55    ///
56    /// Flattens the per-chunk embedding vectors into a contiguous `Array2`
57    /// for BLAS-accelerated matrix-vector products at query time.
58    ///
59    /// When `cascade_dim` is `Some(d)`, also builds a truncated and
60    /// L2-re-normalized `[N, d]` matrix for two-phase MRL cascade search.
61    /// The truncated dimension is clamped to `hidden_dim`.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the flattened embedding data cannot form a valid
66    /// `[num_chunks, hidden_dim]` matrix (should never happen when
67    /// embeddings come from `embed_all`).
68    pub fn new(
69        chunks: Vec<CodeChunk>,
70        raw_embeddings: &[Vec<f32>],
71        cascade_dim: Option<usize>,
72    ) -> Self {
73        let hidden_dim = raw_embeddings.first().map_or(384, Vec::len);
74        let n = chunks.len();
75
76        // Flatten into contiguous array for BLAS
77        let mut flat = Vec::with_capacity(n * hidden_dim);
78        for emb in raw_embeddings {
79            if emb.len() == hidden_dim {
80                flat.extend_from_slice(emb);
81            } else {
82                // Pad/truncate to hidden_dim (shouldn't happen, but be safe)
83                flat.extend(emb.iter().take(hidden_dim));
84                flat.resize(flat.len() + hidden_dim.saturating_sub(emb.len()), 0.0);
85            }
86        }
87
88        let embeddings =
89            Array2::from_shape_vec((n, hidden_dim), flat).expect("embedding matrix shape mismatch");
90
91        // Build truncated + re-normalized matrix for MRL cascade pre-filter.
92        // Nomic MRL models require layer-norm before truncation:
93        //   1. Layer-norm over the FULL embedding (mean-center, scale by inv_std)
94        //   2. Truncate to first d dimensions
95        //   3. L2 renormalize the truncated slice
96        let truncated_dim = cascade_dim.map(|d| d.min(hidden_dim));
97        let truncated = truncated_dim.map(|d| {
98            let mut trunc = Array2::zeros((n, d));
99            for (i, row) in embeddings.rows().into_iter().enumerate() {
100                let full = row.as_slice().expect("embedding row not contiguous");
101
102                // Step 1: Layer-norm over FULL embedding
103                let len = full.len() as f32;
104                let mean: f32 = full.iter().sum::<f32>() / len;
105                let var: f32 = full.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len;
106                let inv_std = 1.0 / (var + 1e-5).sqrt();
107
108                // Step 2: Truncate first d dims of layer-normed embedding
109                // Step 3: L2 renormalize the truncated slice
110                let norm: f32 = full[..d]
111                    .iter()
112                    .map(|x| {
113                        let ln = (x - mean) * inv_std;
114                        ln * ln
115                    })
116                    .sum::<f32>()
117                    .sqrt()
118                    .max(1e-12);
119                for (j, &v) in full[..d].iter().enumerate() {
120                    trunc[[i, j]] = (v - mean) * inv_std / norm;
121                }
122            }
123            trunc
124        });
125
126        // Compress embeddings with PolarQuant (4-bit).
127        // At 768-dim: ~1920 bytes/vector vs 3072 FP32. 8× compression with bit-packing.
128        let compressed = if hidden_dim >= 64 && hidden_dim.is_multiple_of(2) {
129            let codec = PolarCodec::new(hidden_dim, 4, 42);
130            let corpus = codec.encode_batch(&embeddings);
131            Some(CompressedIndex { codec, corpus })
132        } else {
133            None
134        };
135
136        Self {
137            chunks,
138            embeddings,
139            truncated,
140            compressed,
141            hidden_dim,
142            truncated_dim,
143        }
144    }
145
146    /// Rank all chunks against a query embedding.
147    ///
148    /// Returns `(chunk_index, similarity_score)` pairs sorted by descending
149    /// score, filtered by `threshold`.
150    #[must_use]
151    pub fn rank(&self, query_embedding: &[f32], threshold: f32) -> Vec<(usize, f32)> {
152        if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
153            return vec![];
154        }
155        let query = Array1::from_vec(query_embedding.to_vec());
156        let scores = crate::similarity::rank_all(&self.embeddings, &query);
157
158        let mut results: Vec<(usize, f32)> = scores
159            .into_iter()
160            .enumerate()
161            .filter(|(_, score)| *score >= threshold)
162            .collect();
163        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
164        results
165    }
166
167    /// `TurboQuant`-accelerated ranking: compressed approximate scan → exact re-rank.
168    ///
169    /// 1. Estimate inner products for ALL vectors via `TurboQuant` (~5× faster than BLAS).
170    /// 2. Take top `pre_filter_k` approximate candidates.
171    /// 3. Re-rank with exact FP32 dot products on the full embedding matrix.
172    ///
173    /// Falls back to [`Self::rank`] when no compressed index is available.
174    #[must_use]
175    pub fn rank_turboquant(
176        &self,
177        query_embedding: &[f32],
178        top_k: usize,
179        threshold: f32,
180    ) -> Vec<(usize, f32)> {
181        let Some(ref comp) = self.compressed else {
182            return self.rank(query_embedding, threshold);
183        };
184
185        if comp.corpus.n != self.chunks.len() {
186            return self.rank(query_embedding, threshold);
187        }
188
189        // Phase 1: SoA corpus scan — sequential streaming, centroid table in L1.
190        let pre_filter_k = (top_k * 10).min(comp.corpus.n);
191        let query_state = comp.codec.prepare_query(query_embedding);
192        let scores = comp.codec.scan_corpus(&comp.corpus, &query_state);
193        let mut approx_scores: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
194        approx_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
195        approx_scores.truncate(pre_filter_k);
196
197        // Phase 2: exact re-rank top candidates
198        let query = Array1::from_vec(query_embedding.to_vec());
199        let mut results: Vec<(usize, f32)> = approx_scores
200            .iter()
201            .map(|&(idx, _)| {
202                let exact = self.embeddings.row(idx).dot(&query);
203                (idx, exact)
204            })
205            .filter(|(_, score)| *score >= threshold)
206            .collect();
207        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
208        results.truncate(top_k);
209        results
210    }
211
212    /// Two-phase MRL cascade ranking: fast pre-filter then full re-rank.
213    ///
214    /// 1. Layer-norms the query over its full dimension, truncates to
215    ///    `truncated_dim`, L2-normalizes, and computes dot products against
216    ///    the truncated matrix to find the top `pre_filter_k` candidates.
217    /// 2. Re-ranks those candidates using full-dimension dot products.
218    ///
219    /// Falls back to [`Self::rank`] when no truncated matrix is available.
220    #[must_use]
221    pub fn rank_cascade(
222        &self,
223        query_embedding: &[f32],
224        top_k: usize,
225        threshold: f32,
226    ) -> Vec<(usize, f32)> {
227        let Some(ref trunc_matrix) = self.truncated else {
228            return self.rank(query_embedding, threshold);
229        };
230        if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
231            return vec![];
232        }
233
234        let trunc_dim = trunc_matrix.shape()[1];
235        let pre_filter_k = 100_usize.max(top_k * 3); // over-retrieve for re-ranking
236
237        // Phase 1: fast pre-filter at truncated dimension
238        // Apply layer-norm over full query before truncation (matches corpus processing)
239        let len = query_embedding.len() as f32;
240        let mean: f32 = query_embedding.iter().sum::<f32>() / len;
241        let var: f32 = query_embedding
242            .iter()
243            .map(|x| (x - mean).powi(2))
244            .sum::<f32>()
245            / len;
246        let inv_std = 1.0 / (var + 1e-5).sqrt();
247        let trunc_query: Vec<f32> = query_embedding[..trunc_dim]
248            .iter()
249            .map(|x| (x - mean) * inv_std)
250            .collect();
251        let norm: f32 = trunc_query
252            .iter()
253            .map(|x| x * x)
254            .sum::<f32>()
255            .sqrt()
256            .max(1e-12);
257        let trunc_query_norm: Vec<f32> = trunc_query.iter().map(|x| x / norm).collect();
258        let trunc_q = Array1::from_vec(trunc_query_norm);
259        let scores = trunc_matrix.dot(&trunc_q);
260
261        // Get top pre_filter_k indices
262        let mut candidates: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
263        candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
264        candidates.truncate(pre_filter_k);
265
266        // Phase 2: re-rank candidates with full-dimension dot products
267        let query_arr = Array1::from_vec(query_embedding.to_vec());
268        let mut reranked: Vec<(usize, f32)> = candidates
269            .into_iter()
270            .map(|(idx, _)| {
271                let full_score = self.embeddings.row(idx).dot(&query_arr);
272                (idx, full_score)
273            })
274            .filter(|(_, s)| *s >= threshold)
275            .collect();
276        reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
277        reranked.truncate(top_k);
278        reranked
279    }
280
281    /// Return a clone of the embedding vector for chunk `idx`.
282    ///
283    /// Returns `None` if `idx` is out of bounds.
284    #[must_use]
285    pub fn embedding(&self, idx: usize) -> Option<Vec<f32>> {
286        if idx >= self.chunks.len() {
287            return None;
288        }
289        Some(self.embeddings.row(idx).to_vec())
290    }
291
292    /// Find duplicate or near-duplicate chunks by pairwise cosine similarity.
293    ///
294    /// Computes `embeddings @ embeddings.T` (a single BLAS GEMM) to get all
295    /// pairwise similarities, then extracts pairs above `threshold` from the
296    /// upper triangle (avoiding self-matches and symmetric duplicates).
297    ///
298    /// Returns `(chunk_a, chunk_b, similarity)` sorted by descending similarity.
299    /// Each pair appears only once (a < b).
300    #[must_use]
301    pub fn find_duplicates(&self, threshold: f32, max_pairs: usize) -> Vec<(usize, usize, f32)> {
302        let n = self.chunks.len();
303        if n < 2 {
304            return vec![];
305        }
306
307        // Single GEMM: [n, dim] × [dim, n] = [n, n] pairwise similarity matrix
308        let sim_matrix = self.embeddings.dot(&self.embeddings.t());
309
310        // Scan upper triangle for pairs above threshold
311        let mut pairs: Vec<(usize, usize, f32)> = Vec::new();
312        for i in 0..n {
313            for j in (i + 1)..n {
314                let score = sim_matrix[[i, j]];
315                if score >= threshold {
316                    pairs.push((i, j, score));
317                }
318            }
319        }
320
321        pairs.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
322        pairs.truncate(max_pairs);
323        pairs
324    }
325
326    /// Number of chunks in the index.
327    #[must_use]
328    pub fn len(&self) -> usize {
329        self.chunks.len()
330    }
331
332    /// Whether the index is empty.
333    #[must_use]
334    pub fn is_empty(&self) -> bool {
335        self.chunks.is_empty()
336    }
337
338    /// The truncated dimension used for cascade pre-filtering, if enabled.
339    #[must_use]
340    pub fn truncated_dim(&self) -> Option<usize> {
341        self.truncated_dim
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `CodeChunk` called `name`, used as index payload in tests.
    fn named_chunk(name: &str) -> CodeChunk {
        let body = format!("fn {name}() {{}}");
        CodeChunk {
            file_path: String::from("test.rs"),
            name: name.to_owned(),
            kind: String::from("function"),
            start_line: 1,
            end_line: 10,
            content: body.clone(),
            enriched_content: body,
        }
    }

    /// One `named_chunk` per entry of `names`, in order.
    fn named_chunks(names: &[&str]) -> Vec<CodeChunk> {
        names.iter().copied().map(named_chunk).collect()
    }

    /// Scale `v` to unit L2 length (guarding against a zero vector).
    fn normalize_l2(v: &mut [f32]) {
        let magnitude = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        v.iter_mut().for_each(|x| *x /= magnitude);
    }

    #[test]
    fn new_builds_correct_matrix_shape() {
        let identity: Vec<Vec<f32>> = (0..3)
            .map(|row| {
                (0..3)
                    .map(|col| if row == col { 1.0 } else { 0.0 })
                    .collect::<Vec<f32>>()
            })
            .collect();

        let index = SearchIndex::new(named_chunks(&["a", "b", "c"]), &identity, None);

        assert_eq!(index.len(), 3);
        assert_eq!(index.hidden_dim, 3);
        assert!(!index.is_empty());
    }

    #[test]
    fn rank_returns_sorted_results_above_threshold() {
        // Dot products with [1, 0] are 0.2, 0.9 and 0.5 respectively.
        let vectors = vec![vec![0.2, 0.8], vec![0.9, 0.1], vec![0.5, 0.5]];
        let index = SearchIndex::new(named_chunks(&["low", "high", "mid"]), &vectors, None);

        let ranked = index.rank(&[1.0, 0.0], 0.3);

        // chunk 0 (0.2) misses the 0.3 cut; the rest come back highest-first.
        assert_eq!(ranked.len(), 2);
        let order: Vec<usize> = ranked.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(order, [1, 2]);
        assert!(ranked[0].1 > ranked[1].1);
    }

    #[test]
    fn rank_with_wrong_dimension_returns_empty() {
        let index = SearchIndex::new(named_chunks(&["a"]), &[vec![1.0, 0.0, 0.0]], None);

        // A 2-dim query against a 3-dim index is rejected.
        assert!(index.rank(&[1.0, 0.0], 0.0).is_empty());
    }

    #[test]
    fn rank_with_empty_query_returns_empty() {
        let index = SearchIndex::new(named_chunks(&["a"]), &[vec![1.0, 0.0, 0.0]], None);

        assert!(index.rank(&[], 0.0).is_empty());
    }

    #[test]
    fn rank_handles_empty_index() {
        let index = SearchIndex::new(Vec::new(), &[], None);

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());

        // With no embeddings the dimension defaults to 384, so this query is
        // well-formed yet still finds nothing.
        assert!(index.rank(&[1.0; 384], 0.0).is_empty());
    }

    #[test]
    #[expect(
        clippy::cast_precision_loss,
        reason = "test values are small counts and indices"
    )]
    fn cascade_recall_at_10_vs_full_rank() {
        const N: usize = 200;
        const DIM: usize = 8;
        const CASCADE_DIM: usize = 4;

        // Deterministic sin-based unit vectors keep the test reproducible.
        let (chunks, embeddings): (Vec<CodeChunk>, Vec<Vec<f32>>) = (0..N)
            .map(|i| {
                let mut v: Vec<f32> = (0..DIM).map(|d| ((i * 7 + d * 13) as f32).sin()).collect();
                normalize_l2(&mut v);
                (named_chunk(&format!("chunk_{i}")), v)
            })
            .unzip();

        let mut query: Vec<f32> = (0..DIM).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        normalize_l2(&mut query);

        // Reference: exact full-dimension ranking.
        let reference: Vec<usize> = SearchIndex::new(chunks.clone(), &embeddings, None)
            .rank(&query, 0.0)
            .into_iter()
            .take(10)
            .map(|(idx, _)| idx)
            .collect();

        let cascade_index = SearchIndex::new(chunks, &embeddings, Some(CASCADE_DIM));
        assert_eq!(cascade_index.truncated_dim(), Some(CASCADE_DIM));
        let cascade: Vec<usize> = cascade_index
            .rank_cascade(&query, 10, 0.0)
            .into_iter()
            .map(|(idx, _)| idx)
            .collect();

        // Recall@10: share of the reference top-10 that the cascade also finds.
        let overlap = reference.iter().filter(|idx| cascade.contains(idx)).count();
        let recall = overlap as f32 / 10.0;

        assert!(
            recall >= 0.7,
            "cascade Recall@10 = {recall} ({overlap}/10), expected >= 0.7"
        );
    }

    #[test]
    fn cascade_falls_back_without_truncated_matrix() {
        let vectors = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
        let index = SearchIndex::new(named_chunks(&["a", "b"]), &vectors, None);
        let query: [f32; 2] = [1.0, 0.0];

        // Without a cascade_dim, rank_cascade must mirror rank.
        let via_cascade = index.rank_cascade(&query, 10, 0.0);
        let via_rank = index.rank(&query, 0.0);

        assert_eq!(via_cascade.len(), via_rank.len());
        for ((c_idx, c_score), (r_idx, r_score)) in via_cascade.into_iter().zip(via_rank) {
            assert_eq!(c_idx, r_idx);
            assert!((c_score - r_score).abs() < 1e-6);
        }
    }

    #[test]
    fn cascade_respects_threshold() {
        // chunk 0 points along the query; chunk 1 is orthogonal to it.
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let index = SearchIndex::new(named_chunks(&["high", "low"]), &vectors, Some(1));

        let hits = index.rank_cascade(&[1.0, 0.0], 10, 0.5);

        // Only chunk 0 clears the 0.5 threshold.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn turboquant_recall_vs_exact() {
        const DIM: usize = 768;
        const N: usize = 200;

        // 200 deterministic 768-dim unit vectors.
        let embeddings: Vec<Vec<f32>> = (0..N)
            .map(|i| {
                let mut v: Vec<f32> = (0..DIM).map(|d| ((i * 17 + d * 31) as f32).sin()).collect();
                normalize_l2(&mut v);
                v
            })
            .collect();
        let chunks: Vec<CodeChunk> = (0..N).map(|i| named_chunk(&format!("chunk_{i}"))).collect();

        let mut query: Vec<f32> = (0..DIM).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        normalize_l2(&mut query);

        let index = SearchIndex::new(chunks, &embeddings, None);

        let first_ten = |hits: Vec<(usize, f32)>| -> Vec<usize> {
            hits.into_iter().take(10).map(|(idx, _)| idx).collect()
        };
        let exact_top10 = first_ten(index.rank(&query, 0.0));
        let tq_top10 = first_ten(index.rank_turboquant(&query, 10, 0.0));

        // Recall@10: how many exact top-10 hits the TurboQuant path recovers.
        let recall = exact_top10.iter().filter(|idx| tq_top10.contains(idx)).count();
        eprintln!("TurboQuant Recall@10: {recall}/10");
        assert!(
            recall >= 7,
            "TurboQuant recall should be >= 7/10, got {recall}/10"
        );
    }
}