1use ndarray::{Array1, Array2};
12
13use crate::chunk::CodeChunk;
14use crate::turbo_quant::{CompressedCorpus, PolarCodec};
15
16pub struct SearchIndex {
26 pub chunks: Vec<CodeChunk>,
28 embeddings: Array2<f32>,
30 truncated: Option<Array2<f32>>,
33 compressed: Option<CompressedIndex>,
37 pub hidden_dim: usize,
39 truncated_dim: Option<usize>,
41}
42
43struct CompressedIndex {
47 codec: PolarCodec,
49 corpus: CompressedCorpus,
51}
52
53impl SearchIndex {
54 pub fn new(
69 chunks: Vec<CodeChunk>,
70 raw_embeddings: &[Vec<f32>],
71 cascade_dim: Option<usize>,
72 ) -> Self {
73 let hidden_dim = raw_embeddings.first().map_or(384, Vec::len);
74 let n = chunks.len();
75
76 let mut flat = Vec::with_capacity(n * hidden_dim);
78 for emb in raw_embeddings {
79 if emb.len() == hidden_dim {
80 flat.extend_from_slice(emb);
81 } else {
82 flat.extend(emb.iter().take(hidden_dim));
84 flat.resize(flat.len() + hidden_dim.saturating_sub(emb.len()), 0.0);
85 }
86 }
87
88 let embeddings =
89 Array2::from_shape_vec((n, hidden_dim), flat).expect("embedding matrix shape mismatch");
90
91 let truncated_dim = cascade_dim.map(|d| d.min(hidden_dim));
97 let truncated = truncated_dim.map(|d| {
98 let mut trunc = Array2::zeros((n, d));
99 for (i, row) in embeddings.rows().into_iter().enumerate() {
100 let full = row.as_slice().expect("embedding row not contiguous");
101
102 let len = full.len() as f32;
104 let mean: f32 = full.iter().sum::<f32>() / len;
105 let var: f32 = full.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len;
106 let inv_std = 1.0 / (var + 1e-5).sqrt();
107
108 let norm: f32 = full[..d]
111 .iter()
112 .map(|x| {
113 let ln = (x - mean) * inv_std;
114 ln * ln
115 })
116 .sum::<f32>()
117 .sqrt()
118 .max(1e-12);
119 for (j, &v) in full[..d].iter().enumerate() {
120 trunc[[i, j]] = (v - mean) * inv_std / norm;
121 }
122 }
123 trunc
124 });
125
126 let compressed = if hidden_dim >= 64 && hidden_dim.is_multiple_of(2) {
129 let codec = PolarCodec::new(hidden_dim, 4, 42);
130 let corpus = codec.encode_batch(&embeddings);
131 Some(CompressedIndex { codec, corpus })
132 } else {
133 None
134 };
135
136 Self {
137 chunks,
138 embeddings,
139 truncated,
140 compressed,
141 hidden_dim,
142 truncated_dim,
143 }
144 }
145
146 #[must_use]
151 pub fn rank(&self, query_embedding: &[f32], threshold: f32) -> Vec<(usize, f32)> {
152 if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
153 return vec![];
154 }
155 let query = Array1::from_vec(query_embedding.to_vec());
156 let scores = crate::similarity::rank_all(&self.embeddings, &query);
157
158 let mut results: Vec<(usize, f32)> = scores
159 .into_iter()
160 .enumerate()
161 .filter(|(_, score)| *score >= threshold)
162 .collect();
163 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
164 results
165 }
166
167 #[must_use]
175 pub fn rank_turboquant(
176 &self,
177 query_embedding: &[f32],
178 top_k: usize,
179 threshold: f32,
180 ) -> Vec<(usize, f32)> {
181 let Some(ref comp) = self.compressed else {
182 return self.rank(query_embedding, threshold);
183 };
184
185 if comp.corpus.n != self.chunks.len() {
186 return self.rank(query_embedding, threshold);
187 }
188
189 let pre_filter_k = top_k.saturating_mul(10).min(comp.corpus.n);
194 let query_state = comp.codec.prepare_query(query_embedding);
195 let scores = comp.codec.scan_corpus(&comp.corpus, &query_state);
196 let mut approx_scores: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
197 approx_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
198 approx_scores.truncate(pre_filter_k);
199
200 let query = Array1::from_vec(query_embedding.to_vec());
202 let mut results: Vec<(usize, f32)> = approx_scores
203 .iter()
204 .map(|&(idx, _)| {
205 let exact = self.embeddings.row(idx).dot(&query);
206 (idx, exact)
207 })
208 .filter(|(_, score)| *score >= threshold)
209 .collect();
210 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
211 results.truncate(top_k);
212 results
213 }
214
215 #[must_use]
224 pub fn rank_cascade(
225 &self,
226 query_embedding: &[f32],
227 top_k: usize,
228 threshold: f32,
229 ) -> Vec<(usize, f32)> {
230 let Some(ref trunc_matrix) = self.truncated else {
231 return self.rank(query_embedding, threshold);
232 };
233 if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
234 return vec![];
235 }
236
237 let trunc_dim = trunc_matrix.shape()[1];
238 let pre_filter_k = 100_usize.max(top_k * 3); let len = query_embedding.len() as f32;
243 let mean: f32 = query_embedding.iter().sum::<f32>() / len;
244 let var: f32 = query_embedding
245 .iter()
246 .map(|x| (x - mean).powi(2))
247 .sum::<f32>()
248 / len;
249 let inv_std = 1.0 / (var + 1e-5).sqrt();
250 let trunc_query: Vec<f32> = query_embedding[..trunc_dim]
251 .iter()
252 .map(|x| (x - mean) * inv_std)
253 .collect();
254 let norm: f32 = trunc_query
255 .iter()
256 .map(|x| x * x)
257 .sum::<f32>()
258 .sqrt()
259 .max(1e-12);
260 let trunc_query_norm: Vec<f32> = trunc_query.iter().map(|x| x / norm).collect();
261 let trunc_q = Array1::from_vec(trunc_query_norm);
262 let scores = trunc_matrix.dot(&trunc_q);
263
264 let mut candidates: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
266 candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
267 candidates.truncate(pre_filter_k);
268
269 let query_arr = Array1::from_vec(query_embedding.to_vec());
271 let mut reranked: Vec<(usize, f32)> = candidates
272 .into_iter()
273 .map(|(idx, _)| {
274 let full_score = self.embeddings.row(idx).dot(&query_arr);
275 (idx, full_score)
276 })
277 .filter(|(_, s)| *s >= threshold)
278 .collect();
279 reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
280 reranked.truncate(top_k);
281 reranked
282 }
283
284 #[must_use]
288 pub fn embedding(&self, idx: usize) -> Option<Vec<f32>> {
289 if idx >= self.chunks.len() {
290 return None;
291 }
292 Some(self.embeddings.row(idx).to_vec())
293 }
294
295 #[must_use]
304 pub fn find_duplicates(&self, threshold: f32, max_pairs: usize) -> Vec<(usize, usize, f32)> {
305 let n = self.chunks.len();
306 if n < 2 {
307 return vec![];
308 }
309
310 let sim_matrix = self.embeddings.dot(&self.embeddings.t());
312
313 let mut pairs: Vec<(usize, usize, f32)> = Vec::new();
315 for i in 0..n {
316 for j in (i + 1)..n {
317 let score = sim_matrix[[i, j]];
318 if score >= threshold {
319 pairs.push((i, j, score));
320 }
321 }
322 }
323
324 pairs.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));
325 pairs.truncate(max_pairs);
326 pairs
327 }
328
329 #[must_use]
331 pub fn len(&self) -> usize {
332 self.chunks.len()
333 }
334
335 #[must_use]
337 pub fn is_empty(&self) -> bool {
338 self.chunks.is_empty()
339 }
340
341 #[must_use]
343 pub fn truncated_dim(&self) -> Option<usize> {
344 self.truncated_dim
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal function chunk called `name` whose content and
    /// enriched content are the same one-line function stub.
    fn dummy_chunk(name: &str) -> CodeChunk {
        let body = format!("fn {name}() {{}}");
        CodeChunk {
            file_path: String::from("test.rs"),
            name: name.to_owned(),
            kind: String::from("function"),
            start_line: 1,
            end_line: 10,
            content: body.clone(),
            enriched_content: body,
        }
    }

    /// Scales `v` in place to unit L2 norm (norm floored at `1e-12`).
    fn l2_normalize(v: &mut [f32]) {
        let scale: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        v.iter_mut().for_each(|x| *x /= scale);
    }

    #[test]
    fn new_builds_correct_matrix_shape() {
        let chunks: Vec<CodeChunk> = ["a", "b", "c"].iter().copied().map(dummy_chunk).collect();
        let unit_rows = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SearchIndex::new(chunks, &unit_rows, None);

        assert!(!index.is_empty());
        assert_eq!(index.len(), 3);
        assert_eq!(index.hidden_dim, 3);
    }

    #[test]
    fn rank_returns_sorted_results_above_threshold() {
        let chunks: Vec<CodeChunk> = ["low", "high", "mid"]
            .iter()
            .copied()
            .map(dummy_chunk)
            .collect();
        let rows = vec![vec![0.2, 0.8], vec![0.9, 0.1], vec![0.5, 0.5]];
        let index = SearchIndex::new(chunks, &rows, None);

        let hits = index.rank(&[1.0, 0.0], 0.3);

        // "high" (row 1) first, then "mid" (row 2); "low" is below the threshold.
        let order: Vec<usize> = hits.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(order, [1, 2]);
        assert!(hits[0].1 > hits[1].1);
    }

    #[test]
    fn rank_with_wrong_dimension_returns_empty() {
        let index = SearchIndex::new(vec![dummy_chunk("a")], &[vec![1.0, 0.0, 0.0]], None);

        // Two components against a three-dimensional index.
        let hits = index.rank(&[1.0, 0.0], 0.0);

        assert!(hits.is_empty());
    }

    #[test]
    fn rank_with_empty_query_returns_empty() {
        let index = SearchIndex::new(vec![dummy_chunk("a")], &[vec![1.0, 0.0, 0.0]], None);

        let hits = index.rank(&[], 0.0);

        assert!(hits.is_empty());
    }

    #[test]
    fn rank_handles_empty_index() {
        let index = SearchIndex::new(vec![], &[], None);

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());

        let hits = index.rank(&[1.0; 384], 0.0);
        assert!(hits.is_empty());
    }

    #[test]
    #[expect(
        clippy::cast_precision_loss,
        reason = "test values are small counts and indices"
    )]
    fn cascade_recall_at_10_vs_full_rank() {
        let n: usize = 200;
        let dim = 8;
        let cascade_dim = 4;

        // Deterministic pseudo-random unit vectors.
        let chunks: Vec<CodeChunk> = (0..n).map(|i| dummy_chunk(&format!("chunk_{i}"))).collect();
        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                let mut emb: Vec<f32> =
                    (0..dim).map(|d| ((i * 7 + d * 13) as f32).sin()).collect();
                l2_normalize(&mut emb);
                emb
            })
            .collect();

        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        l2_normalize(&mut query);

        // Ground truth: exhaustive ranking.
        let exhaustive = SearchIndex::new(chunks.clone(), &embeddings, None);
        let full_top10: Vec<usize> = exhaustive
            .rank(&query, 0.0)
            .iter()
            .take(10)
            .map(|&(idx, _)| idx)
            .collect();

        // Candidate: cascade ranking on the same data.
        let cascaded = SearchIndex::new(chunks, &embeddings, Some(cascade_dim));
        assert_eq!(cascaded.truncated_dim(), Some(cascade_dim));
        let cascade_top10: Vec<usize> = cascaded
            .rank_cascade(&query, 10, 0.0)
            .iter()
            .map(|&(idx, _)| idx)
            .collect();

        let overlap = full_top10
            .iter()
            .filter(|idx| cascade_top10.contains(idx))
            .count();
        let recall = overlap as f32 / 10.0;

        assert!(
            recall >= 0.7,
            "cascade Recall@10 = {recall} ({overlap}/10), expected >= 0.7"
        );
    }

    #[test]
    fn cascade_falls_back_without_truncated_matrix() {
        let rows = vec![vec![0.9, 0.1], vec![0.1, 0.9]];
        let index = SearchIndex::new(vec![dummy_chunk("a"), dummy_chunk("b")], &rows, None);
        let query = [1.0, 0.0];

        // Without a cascade dimension both entry points must agree.
        let via_cascade = index.rank_cascade(&query, 10, 0.0);
        let via_rank = index.rank(&query, 0.0);

        assert_eq!(via_cascade.len(), via_rank.len());
        for ((c_idx, c_score), (p_idx, p_score)) in via_cascade.iter().zip(&via_rank) {
            assert_eq!(c_idx, p_idx);
            assert!((c_score - p_score).abs() < 1e-6);
        }
    }

    #[test]
    fn cascade_respects_threshold() {
        let rows = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let index = SearchIndex::new(vec![dummy_chunk("high"), dummy_chunk("low")], &rows, Some(1));

        let hits = index.rank_cascade(&[1.0, 0.0], 10, 0.5);

        // Only the chunk aligned with the query clears the 0.5 threshold.
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn turboquant_recall_vs_exact() {
        let dim = 768;
        let n = 200;

        // Deterministic pseudo-random unit vectors, one per chunk.
        let mut chunks = Vec::with_capacity(n);
        let mut embeddings = Vec::with_capacity(n);
        for i in 0..n {
            let mut v: Vec<f32> = (0..dim).map(|d| ((i * 17 + d * 31) as f32).sin()).collect();
            l2_normalize(&mut v);
            embeddings.push(v);
            chunks.push(dummy_chunk(&format!("chunk_{i}")));
        }

        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        l2_normalize(&mut query);

        let index = SearchIndex::new(chunks, &embeddings, None);

        let exact_top10: Vec<usize> = index
            .rank(&query, 0.0)
            .iter()
            .take(10)
            .map(|&(idx, _)| idx)
            .collect();
        let tq_top10: Vec<usize> = index
            .rank_turboquant(&query, 10, 0.0)
            .iter()
            .take(10)
            .map(|&(idx, _)| idx)
            .collect();

        let recall = exact_top10
            .iter()
            .filter(|idx| tq_top10.contains(idx))
            .count();
        eprintln!("TurboQuant Recall@10: {recall}/10");
        assert!(
            recall >= 7,
            "TurboQuant recall should be >= 7/10, got {recall}/10"
        );
    }
}