1use ndarray::{Array1, Array2};
12
13use crate::chunk::CodeChunk;
14use crate::turbo_quant::{CompressedCorpus, PolarCodec};
15
/// Corpus size (in chunks) at which `SearchIndex::new` starts building a
/// TurboQuant-compressed copy of the embeddings. Below this size no compressed
/// index exists and `rank_turboquant` falls back to exact ranking.
const TURBOQUANT_MIN_CHUNKS: usize = 4_096;
23
24pub struct SearchIndex {
34 pub chunks: Vec<CodeChunk>,
36 embeddings: Array2<f32>,
38 truncated: Option<Array2<f32>>,
41 compressed: Option<CompressedIndex>,
45 pub hidden_dim: usize,
47 truncated_dim: Option<usize>,
49}
50
51struct CompressedIndex {
55 codec: PolarCodec,
57 corpus: CompressedCorpus,
59}
60
61impl SearchIndex {
62 pub fn new(
77 chunks: Vec<CodeChunk>,
78 raw_embeddings: &[Vec<f32>],
79 cascade_dim: Option<usize>,
80 ) -> Self {
81 let hidden_dim = raw_embeddings.first().map_or(384, Vec::len);
82 let n = chunks.len();
83
84 let mut flat = Vec::with_capacity(n * hidden_dim);
86 for emb in raw_embeddings {
87 if emb.len() == hidden_dim {
88 flat.extend_from_slice(emb);
89 } else {
90 flat.extend(emb.iter().take(hidden_dim));
92 flat.resize(flat.len() + hidden_dim.saturating_sub(emb.len()), 0.0);
93 }
94 }
95
96 let embeddings =
97 Array2::from_shape_vec((n, hidden_dim), flat).expect("embedding matrix shape mismatch");
98
99 let truncated_dim = cascade_dim.map(|d| d.min(hidden_dim));
105 let truncated = truncated_dim.map(|d| {
106 let mut trunc = Array2::zeros((n, d));
107 for (i, row) in embeddings.rows().into_iter().enumerate() {
108 let full = row.as_slice().expect("embedding row not contiguous");
109
110 let len = full.len() as f32;
112 let mean: f32 = full.iter().sum::<f32>() / len;
113 let var: f32 = full.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len;
114 let inv_std = 1.0 / (var + 1e-5).sqrt();
115
116 let norm: f32 = full[..d]
119 .iter()
120 .map(|x| {
121 let ln = (x - mean) * inv_std;
122 ln * ln
123 })
124 .sum::<f32>()
125 .sqrt()
126 .max(1e-12);
127 for (j, &v) in full[..d].iter().enumerate() {
128 trunc[[i, j]] = (v - mean) * inv_std / norm;
129 }
130 }
131 trunc
132 });
133
134 let compressed =
138 if n >= TURBOQUANT_MIN_CHUNKS && hidden_dim >= 64 && hidden_dim.is_multiple_of(2) {
139 tracing::info!(
140 chunks = n,
141 hidden_dim,
142 "building TurboQuant compressed index"
143 );
144 let codec = PolarCodec::new(hidden_dim, 4, 42);
145 let corpus = codec.encode_batch(&embeddings);
146 Some(CompressedIndex { codec, corpus })
147 } else {
148 tracing::debug!(
149 chunks = n,
150 hidden_dim,
151 min_chunks = TURBOQUANT_MIN_CHUNKS,
152 "skipping TurboQuant compression for small corpus"
153 );
154 None
155 };
156
157 Self {
158 chunks,
159 embeddings,
160 truncated,
161 compressed,
162 hidden_dim,
163 truncated_dim,
164 }
165 }
166
167 #[must_use]
172 pub fn rank(&self, query_embedding: &[f32], threshold: f32) -> Vec<(usize, f32)> {
173 if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
174 return vec![];
175 }
176 let query = Array1::from_vec(query_embedding.to_vec());
177 let scores = crate::similarity::rank_all(&self.embeddings, &query);
178
179 let mut results: Vec<(usize, f32)> = scores
180 .into_iter()
181 .enumerate()
182 .filter(|(_, score)| *score >= threshold)
183 .collect();
184 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
185 results
186 }
187
188 #[must_use]
196 pub fn rank_turboquant(
197 &self,
198 query_embedding: &[f32],
199 top_k: usize,
200 threshold: f32,
201 ) -> Vec<(usize, f32)> {
202 let Some(ref comp) = self.compressed else {
203 return self.rank(query_embedding, threshold);
204 };
205
206 if comp.corpus.n != self.chunks.len() {
207 return self.rank(query_embedding, threshold);
208 }
209
210 let pre_filter_k = top_k.saturating_mul(10).min(comp.corpus.n);
215 let query_state = comp.codec.prepare_query(query_embedding);
216 let scores = comp.codec.scan_corpus(&comp.corpus, &query_state);
217 let mut approx_scores: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
218 approx_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
219 approx_scores.truncate(pre_filter_k);
220
221 let query = Array1::from_vec(query_embedding.to_vec());
223 let mut results: Vec<(usize, f32)> = approx_scores
224 .iter()
225 .map(|&(idx, _)| {
226 let exact = self.embeddings.row(idx).dot(&query);
227 (idx, exact)
228 })
229 .filter(|(_, score)| *score >= threshold)
230 .collect();
231 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
232 results.truncate(top_k);
233 results
234 }
235
236 #[must_use]
245 pub fn rank_cascade(
246 &self,
247 query_embedding: &[f32],
248 top_k: usize,
249 threshold: f32,
250 ) -> Vec<(usize, f32)> {
251 let Some(ref trunc_matrix) = self.truncated else {
252 return self.rank(query_embedding, threshold);
253 };
254 if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
255 return vec![];
256 }
257
258 let trunc_dim = trunc_matrix.shape()[1];
259 let pre_filter_k = 100_usize.max(top_k * 3); let len = query_embedding.len() as f32;
264 let mean: f32 = query_embedding.iter().sum::<f32>() / len;
265 let var: f32 = query_embedding
266 .iter()
267 .map(|x| (x - mean).powi(2))
268 .sum::<f32>()
269 / len;
270 let inv_std = 1.0 / (var + 1e-5).sqrt();
271 let trunc_query: Vec<f32> = query_embedding[..trunc_dim]
272 .iter()
273 .map(|x| (x - mean) * inv_std)
274 .collect();
275 let norm: f32 = trunc_query
276 .iter()
277 .map(|x| x * x)
278 .sum::<f32>()
279 .sqrt()
280 .max(1e-12);
281 let trunc_query_norm: Vec<f32> = trunc_query.iter().map(|x| x / norm).collect();
282 let trunc_q = Array1::from_vec(trunc_query_norm);
283 let scores = trunc_matrix.dot(&trunc_q);
284
285 let mut candidates: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
287 candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
288 candidates.truncate(pre_filter_k);
289
290 let query_arr = Array1::from_vec(query_embedding.to_vec());
292 let mut reranked: Vec<(usize, f32)> = candidates
293 .into_iter()
294 .map(|(idx, _)| {
295 let full_score = self.embeddings.row(idx).dot(&query_arr);
296 (idx, full_score)
297 })
298 .filter(|(_, s)| *s >= threshold)
299 .collect();
300 reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
301 reranked.truncate(top_k);
302 reranked
303 }
304
305 #[must_use]
309 pub fn embedding(&self, idx: usize) -> Option<Vec<f32>> {
310 if idx >= self.chunks.len() {
311 return None;
312 }
313 Some(self.embeddings.row(idx).to_vec())
314 }
315
316 #[must_use]
325 pub fn find_duplicates(&self, threshold: f32, max_pairs: usize) -> Vec<(usize, usize, f32)> {
326 let n = self.chunks.len();
327 if n < 2 {
328 return vec![];
329 }
330
331 let sim_matrix = self.embeddings.dot(&self.embeddings.t());
333
334 let mut pairs: Vec<(usize, usize, f32)> = Vec::new();
336 for i in 0..n {
337 for j in (i + 1)..n {
338 let score = sim_matrix[[i, j]];
339 if score >= threshold {
340 pairs.push((i, j, score));
341 }
342 }
343 }
344
345 pairs.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
346 pairs.truncate(max_pairs);
347 pairs
348 }
349
350 #[must_use]
352 pub fn len(&self) -> usize {
353 self.chunks.len()
354 }
355
356 #[must_use]
358 pub fn is_empty(&self) -> bool {
359 self.chunks.is_empty()
360 }
361
362 #[must_use]
364 pub fn truncated_dim(&self) -> Option<usize> {
365 self.truncated_dim
366 }
367}
368
369#[cfg(test)]
370mod tests {
371 use super::*;
372
373 fn dummy_chunk(name: &str) -> CodeChunk {
375 let content = format!("fn {name}() {{}}");
376 CodeChunk {
377 file_path: "test.rs".to_string(),
378 name: name.to_string(),
379 kind: "function".to_string(),
380 start_line: 1,
381 end_line: 10,
382 enriched_content: content.clone(),
383 content,
384 }
385 }
386
387 #[test]
388 fn new_builds_correct_matrix_shape() {
389 let chunks = vec![dummy_chunk("a"), dummy_chunk("b"), dummy_chunk("c")];
390 let embeddings = vec![
391 vec![1.0, 0.0, 0.0],
392 vec![0.0, 1.0, 0.0],
393 vec![0.0, 0.0, 1.0],
394 ];
395
396 let index = SearchIndex::new(chunks, &embeddings, None);
397
398 assert_eq!(index.len(), 3);
399 assert_eq!(index.hidden_dim, 3);
400 assert!(!index.is_empty());
401 }
402
403 #[test]
404 fn small_corpus_skips_turboquant_compression() {
405 let chunks = vec![dummy_chunk("a"), dummy_chunk("b")];
406 let embeddings = vec![vec![0.0; 768], vec![1.0; 768]];
407
408 let index = SearchIndex::new(chunks, &embeddings, None);
409
410 assert!(index.compressed.is_none());
411 }
412
413 #[test]
414 fn rank_returns_sorted_results_above_threshold() {
415 let chunks = vec![dummy_chunk("low"), dummy_chunk("high"), dummy_chunk("mid")];
416 let embeddings = vec![vec![0.2, 0.8], vec![0.9, 0.1], vec![0.5, 0.5]];
419
420 let index = SearchIndex::new(chunks, &embeddings, None);
421 let results = index.rank(&[1.0, 0.0], 0.3);
422
423 assert_eq!(results.len(), 2);
425 assert_eq!(results[0].0, 1);
427 assert_eq!(results[1].0, 2);
428 assert!(results[0].1 > results[1].1);
429 }
430
431 #[test]
432 fn rank_with_wrong_dimension_returns_empty() {
433 let chunks = vec![dummy_chunk("a")];
434 let embeddings = vec![vec![1.0, 0.0, 0.0]];
435
436 let index = SearchIndex::new(chunks, &embeddings, None);
437 let results = index.rank(&[1.0, 0.0], 0.0);
439
440 assert!(results.is_empty());
441 }
442
443 #[test]
444 fn rank_with_empty_query_returns_empty() {
445 let chunks = vec![dummy_chunk("a")];
446 let embeddings = vec![vec![1.0, 0.0, 0.0]];
447
448 let index = SearchIndex::new(chunks, &embeddings, None);
449 let results = index.rank(&[], 0.0);
450
451 assert!(results.is_empty());
452 }
453
454 #[test]
455 fn rank_handles_empty_index() {
456 let index = SearchIndex::new(vec![], &[], None);
457
458 assert!(index.is_empty());
460 assert_eq!(index.len(), 0);
461
462 let results = index.rank(&[1.0; 384], 0.0);
463 assert!(results.is_empty());
464 }
465
466 fn l2_normalize(v: &mut [f32]) {
468 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
469 for x in v.iter_mut() {
470 *x /= norm;
471 }
472 }
473
474 #[test]
475 #[expect(
476 clippy::cast_precision_loss,
477 reason = "test values are small counts and indices"
478 )]
479 fn cascade_recall_at_10_vs_full_rank() {
480 let n = 200;
483 let dim = 8;
484 let cascade_dim = 4;
485
486 let mut chunks = Vec::with_capacity(n);
487 let mut embeddings = Vec::with_capacity(n);
488 for i in 0..n {
489 chunks.push(dummy_chunk(&format!("chunk_{i}")));
490 let mut emb: Vec<f32> = (0..dim).map(|d| ((i * 7 + d * 13) as f32).sin()).collect();
492 l2_normalize(&mut emb);
493 embeddings.push(emb);
494 }
495
496 let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
498 l2_normalize(&mut query);
499
500 let index_full = SearchIndex::new(chunks.clone(), &embeddings, None);
502 let full_results = index_full.rank(&query, 0.0);
503 let full_top10: Vec<usize> = full_results.iter().take(10).map(|(idx, _)| *idx).collect();
504
505 let index_cascade = SearchIndex::new(chunks, &embeddings, Some(cascade_dim));
507 assert_eq!(index_cascade.truncated_dim(), Some(cascade_dim));
508 let cascade_results = index_cascade.rank_cascade(&query, 10, 0.0);
509 let cascade_top10: Vec<usize> = cascade_results.iter().map(|(idx, _)| *idx).collect();
510
511 let overlap = full_top10
513 .iter()
514 .filter(|i| cascade_top10.contains(i))
515 .count();
516 let recall = overlap as f32 / 10.0;
517
518 assert!(
519 recall >= 0.7,
520 "cascade Recall@10 = {recall} ({overlap}/10), expected >= 0.7"
521 );
522 }
523
524 #[test]
525 fn cascade_falls_back_without_truncated_matrix() {
526 let chunks = vec![dummy_chunk("a"), dummy_chunk("b")];
527 let embeddings = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
528
529 let index = SearchIndex::new(chunks, &embeddings, None);
531 let cascade = index.rank_cascade(&[1.0, 0.0], 10, 0.0);
532 let plain = index.rank(&[1.0, 0.0], 0.0);
533
534 assert_eq!(cascade.len(), plain.len());
535 for (c, p) in cascade.iter().zip(plain.iter()) {
536 assert_eq!(c.0, p.0);
537 assert!((c.1 - p.1).abs() < 1e-6);
538 }
539 }
540
541 #[test]
542 fn cascade_respects_threshold() {
543 let chunks = vec![dummy_chunk("high"), dummy_chunk("low")];
544 let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
546
547 let index = SearchIndex::new(chunks, &embeddings, Some(1));
548 let results = index.rank_cascade(&[1.0, 0.0], 10, 0.5);
549
550 assert_eq!(results.len(), 1);
552 assert_eq!(results[0].0, 0);
553 }
554
555 #[test]
556 fn turboquant_recall_vs_exact() {
557 let dim = 768;
559 let n = 200;
560 let embeddings: Vec<Vec<f32>> = (0..n)
561 .map(|i| {
562 let mut v: Vec<f32> = (0..dim).map(|d| ((i * 17 + d * 31) as f32).sin()).collect();
563 l2_normalize(&mut v);
564 v
565 })
566 .collect();
567
568 let chunks: Vec<CodeChunk> = (0..n).map(|i| dummy_chunk(&format!("chunk_{i}"))).collect();
569 let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
570 l2_normalize(&mut query);
571
572 let index = SearchIndex::new(chunks, &embeddings, None);
573
574 let exact = index.rank(&query, 0.0);
576 let exact_top10: Vec<usize> = exact.iter().take(10).map(|(idx, _)| *idx).collect();
577
578 let tq = index.rank_turboquant(&query, 10, 0.0);
580 let tq_top10: Vec<usize> = tq.iter().take(10).map(|(idx, _)| *idx).collect();
581
582 let recall = exact_top10.iter().filter(|i| tq_top10.contains(i)).count();
584 eprintln!("TurboQuant Recall@10: {recall}/10");
585 assert!(
586 recall >= 7,
587 "TurboQuant recall should be >= 7/10, got {recall}/10"
588 );
589 }
590}