1use ndarray::{Array1, Array2};
12
13use crate::chunk::CodeChunk;
14use crate::turbo_quant::{CompressedCorpus, PolarCodec};
15
16pub struct SearchIndex {
26 pub chunks: Vec<CodeChunk>,
28 embeddings: Array2<f32>,
30 truncated: Option<Array2<f32>>,
33 compressed: Option<CompressedIndex>,
37 pub hidden_dim: usize,
39 truncated_dim: Option<usize>,
41}
42
43struct CompressedIndex {
47 codec: PolarCodec,
49 corpus: CompressedCorpus,
51}
52
53impl SearchIndex {
54 pub fn new(
69 chunks: Vec<CodeChunk>,
70 raw_embeddings: &[Vec<f32>],
71 cascade_dim: Option<usize>,
72 ) -> Self {
73 let hidden_dim = raw_embeddings.first().map_or(384, Vec::len);
74 let n = chunks.len();
75
76 let mut flat = Vec::with_capacity(n * hidden_dim);
78 for emb in raw_embeddings {
79 if emb.len() == hidden_dim {
80 flat.extend_from_slice(emb);
81 } else {
82 flat.extend(emb.iter().take(hidden_dim));
84 flat.resize(flat.len() + hidden_dim.saturating_sub(emb.len()), 0.0);
85 }
86 }
87
88 let embeddings =
89 Array2::from_shape_vec((n, hidden_dim), flat).expect("embedding matrix shape mismatch");
90
91 let truncated_dim = cascade_dim.map(|d| d.min(hidden_dim));
97 let truncated = truncated_dim.map(|d| {
98 let mut trunc = Array2::zeros((n, d));
99 for (i, row) in embeddings.rows().into_iter().enumerate() {
100 let full = row.as_slice().expect("embedding row not contiguous");
101
102 let len = full.len() as f32;
104 let mean: f32 = full.iter().sum::<f32>() / len;
105 let var: f32 = full.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / len;
106 let inv_std = 1.0 / (var + 1e-5).sqrt();
107
108 let norm: f32 = full[..d]
111 .iter()
112 .map(|x| {
113 let ln = (x - mean) * inv_std;
114 ln * ln
115 })
116 .sum::<f32>()
117 .sqrt()
118 .max(1e-12);
119 for (j, &v) in full[..d].iter().enumerate() {
120 trunc[[i, j]] = (v - mean) * inv_std / norm;
121 }
122 }
123 trunc
124 });
125
126 let compressed = if hidden_dim >= 64 && hidden_dim.is_multiple_of(2) {
129 let codec = PolarCodec::new(hidden_dim, 4, 42);
130 let corpus = codec.encode_batch(&embeddings);
131 Some(CompressedIndex { codec, corpus })
132 } else {
133 None
134 };
135
136 Self {
137 chunks,
138 embeddings,
139 truncated,
140 compressed,
141 hidden_dim,
142 truncated_dim,
143 }
144 }
145
146 #[must_use]
151 pub fn rank(&self, query_embedding: &[f32], threshold: f32) -> Vec<(usize, f32)> {
152 if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
153 return vec![];
154 }
155 let query = Array1::from_vec(query_embedding.to_vec());
156 let scores = crate::similarity::rank_all(&self.embeddings, &query);
157
158 let mut results: Vec<(usize, f32)> = scores
159 .into_iter()
160 .enumerate()
161 .filter(|(_, score)| *score >= threshold)
162 .collect();
163 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
164 results
165 }
166
167 #[must_use]
175 pub fn rank_turboquant(
176 &self,
177 query_embedding: &[f32],
178 top_k: usize,
179 threshold: f32,
180 ) -> Vec<(usize, f32)> {
181 let Some(ref comp) = self.compressed else {
182 return self.rank(query_embedding, threshold);
183 };
184
185 if comp.corpus.n != self.chunks.len() {
186 return self.rank(query_embedding, threshold);
187 }
188
189 let pre_filter_k = (top_k * 10).min(comp.corpus.n);
191 let query_state = comp.codec.prepare_query(query_embedding);
192 let scores = comp.codec.scan_corpus(&comp.corpus, &query_state);
193 let mut approx_scores: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
194 approx_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
195 approx_scores.truncate(pre_filter_k);
196
197 let query = Array1::from_vec(query_embedding.to_vec());
199 let mut results: Vec<(usize, f32)> = approx_scores
200 .iter()
201 .map(|&(idx, _)| {
202 let exact = self.embeddings.row(idx).dot(&query);
203 (idx, exact)
204 })
205 .filter(|(_, score)| *score >= threshold)
206 .collect();
207 results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
208 results.truncate(top_k);
209 results
210 }
211
212 #[must_use]
221 pub fn rank_cascade(
222 &self,
223 query_embedding: &[f32],
224 top_k: usize,
225 threshold: f32,
226 ) -> Vec<(usize, f32)> {
227 let Some(ref trunc_matrix) = self.truncated else {
228 return self.rank(query_embedding, threshold);
229 };
230 if query_embedding.len() != self.hidden_dim || self.chunks.is_empty() {
231 return vec![];
232 }
233
234 let trunc_dim = trunc_matrix.shape()[1];
235 let pre_filter_k = 100_usize.max(top_k * 3); let len = query_embedding.len() as f32;
240 let mean: f32 = query_embedding.iter().sum::<f32>() / len;
241 let var: f32 = query_embedding
242 .iter()
243 .map(|x| (x - mean).powi(2))
244 .sum::<f32>()
245 / len;
246 let inv_std = 1.0 / (var + 1e-5).sqrt();
247 let trunc_query: Vec<f32> = query_embedding[..trunc_dim]
248 .iter()
249 .map(|x| (x - mean) * inv_std)
250 .collect();
251 let norm: f32 = trunc_query
252 .iter()
253 .map(|x| x * x)
254 .sum::<f32>()
255 .sqrt()
256 .max(1e-12);
257 let trunc_query_norm: Vec<f32> = trunc_query.iter().map(|x| x / norm).collect();
258 let trunc_q = Array1::from_vec(trunc_query_norm);
259 let scores = trunc_matrix.dot(&trunc_q);
260
261 let mut candidates: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
263 candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
264 candidates.truncate(pre_filter_k);
265
266 let query_arr = Array1::from_vec(query_embedding.to_vec());
268 let mut reranked: Vec<(usize, f32)> = candidates
269 .into_iter()
270 .map(|(idx, _)| {
271 let full_score = self.embeddings.row(idx).dot(&query_arr);
272 (idx, full_score)
273 })
274 .filter(|(_, s)| *s >= threshold)
275 .collect();
276 reranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
277 reranked.truncate(top_k);
278 reranked
279 }
280
281 #[must_use]
285 pub fn embedding(&self, idx: usize) -> Option<Vec<f32>> {
286 if idx >= self.chunks.len() {
287 return None;
288 }
289 Some(self.embeddings.row(idx).to_vec())
290 }
291
292 #[must_use]
294 pub fn len(&self) -> usize {
295 self.chunks.len()
296 }
297
298 #[must_use]
300 pub fn is_empty(&self) -> bool {
301 self.chunks.is_empty()
302 }
303
304 #[must_use]
306 pub fn truncated_dim(&self) -> Option<usize> {
307 self.truncated_dim
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal function chunk named `name` to use as index payload.
    fn dummy_chunk(name: &str) -> CodeChunk {
        let content = format!("fn {name}() {{}}");
        CodeChunk {
            file_path: "test.rs".to_string(),
            name: name.to_string(),
            kind: "function".to_string(),
            start_line: 1,
            end_line: 10,
            enriched_content: content.clone(),
            content,
        }
    }

    #[test]
    fn new_builds_correct_matrix_shape() {
        let chunks = vec![dummy_chunk("a"), dummy_chunk("b"), dummy_chunk("c")];
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];

        let index = SearchIndex::new(chunks, &embeddings, None);

        assert_eq!(index.len(), 3);
        assert_eq!(index.hidden_dim, 3);
        assert!(!index.is_empty());
    }

    #[test]
    fn rank_returns_sorted_results_above_threshold() {
        let chunks = vec![dummy_chunk("low"), dummy_chunk("high"), dummy_chunk("mid")];
        // Against the query [1, 0]: "high" scores best, "mid" next, and "low" falls
        // below the 0.3 threshold.
        let embeddings = vec![vec![0.2, 0.8], vec![0.9, 0.1], vec![0.5, 0.5]];

        let index = SearchIndex::new(chunks, &embeddings, None);
        let results = index.rank(&[1.0, 0.0], 0.3);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1);
        assert_eq!(results[1].0, 2);
        assert!(results[0].1 > results[1].1);
    }

    #[test]
    fn rank_with_wrong_dimension_returns_empty() {
        let chunks = vec![dummy_chunk("a")];
        let embeddings = vec![vec![1.0, 0.0, 0.0]];

        let index = SearchIndex::new(chunks, &embeddings, None);
        // 2-dim query against a 3-dim index.
        let results = index.rank(&[1.0, 0.0], 0.0);

        assert!(results.is_empty());
    }

    #[test]
    fn rank_with_empty_query_returns_empty() {
        let chunks = vec![dummy_chunk("a")];
        let embeddings = vec![vec![1.0, 0.0, 0.0]];

        let index = SearchIndex::new(chunks, &embeddings, None);
        let results = index.rank(&[], 0.0);

        assert!(results.is_empty());
    }

    #[test]
    fn rank_handles_empty_index() {
        let index = SearchIndex::new(vec![], &[], None);

        assert!(index.is_empty());
        assert_eq!(index.len(), 0);

        let results = index.rank(&[1.0; 384], 0.0);
        assert!(results.is_empty());
    }

    /// Scales `v` to unit L2 norm (leaves an all-zero vector unchanged).
    fn l2_normalize(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        for x in v.iter_mut() {
            *x /= norm;
        }
    }

    #[test]
    #[expect(
        clippy::cast_precision_loss,
        reason = "test values are small counts and indices"
    )]
    fn cascade_recall_at_10_vs_full_rank() {
        let n = 200;
        let dim = 8;
        let cascade_dim = 4;

        let mut chunks = Vec::with_capacity(n);
        let mut embeddings = Vec::with_capacity(n);
        for i in 0..n {
            chunks.push(dummy_chunk(&format!("chunk_{i}")));
            let mut emb: Vec<f32> = (0..dim).map(|d| ((i * 7 + d * 13) as f32).sin()).collect();
            l2_normalize(&mut emb);
            embeddings.push(emb);
        }

        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        l2_normalize(&mut query);

        let index_full = SearchIndex::new(chunks.clone(), &embeddings, None);
        let full_results = index_full.rank(&query, 0.0);
        let full_top10: Vec<usize> = full_results.iter().take(10).map(|(idx, _)| *idx).collect();

        let index_cascade = SearchIndex::new(chunks, &embeddings, Some(cascade_dim));
        assert_eq!(index_cascade.truncated_dim(), Some(cascade_dim));
        let cascade_results = index_cascade.rank_cascade(&query, 10, 0.0);
        let cascade_top10: Vec<usize> = cascade_results.iter().map(|(idx, _)| *idx).collect();

        let overlap = full_top10
            .iter()
            .filter(|i| cascade_top10.contains(i))
            .count();
        let recall = overlap as f32 / 10.0;

        assert!(
            recall >= 0.7,
            "cascade Recall@10 = {recall} ({overlap}/10), expected >= 0.7"
        );
    }

    #[test]
    fn cascade_falls_back_without_truncated_matrix() {
        let chunks = vec![dummy_chunk("a"), dummy_chunk("b")];
        let embeddings = vec![vec![0.9, 0.1], vec![0.1, 0.9]];

        // No cascade_dim: rank_cascade must behave like rank.
        let index = SearchIndex::new(chunks, &embeddings, None);
        let cascade = index.rank_cascade(&[1.0, 0.0], 10, 0.0);
        let plain = index.rank(&[1.0, 0.0], 0.0);

        assert_eq!(cascade.len(), plain.len());
        for (c, p) in cascade.iter().zip(plain.iter()) {
            assert_eq!(c.0, p.0);
            assert!((c.1 - p.1).abs() < 1e-6);
        }
    }

    #[test]
    fn cascade_respects_threshold() {
        let chunks = vec![dummy_chunk("high"), dummy_chunk("low")];
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let index = SearchIndex::new(chunks, &embeddings, Some(1));
        let results = index.rank_cascade(&[1.0, 0.0], 10, 0.5);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    #[expect(
        clippy::cast_precision_loss,
        reason = "test values are small counts and indices"
    )]
    fn turboquant_recall_vs_exact() {
        // Wide enough (even and >= 64) for the compressed index to be built.
        let dim = 768;
        let n = 200;
        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| {
                let mut v: Vec<f32> = (0..dim).map(|d| ((i * 17 + d * 31) as f32).sin()).collect();
                l2_normalize(&mut v);
                v
            })
            .collect();

        let chunks: Vec<CodeChunk> = (0..n).map(|i| dummy_chunk(&format!("chunk_{i}"))).collect();
        let mut query: Vec<f32> = (0..dim).map(|d| ((42 * 7 + d * 13) as f32).sin()).collect();
        l2_normalize(&mut query);

        let index = SearchIndex::new(chunks, &embeddings, None);

        let exact = index.rank(&query, 0.0);
        let exact_top10: Vec<usize> = exact.iter().take(10).map(|(idx, _)| *idx).collect();

        let tq = index.rank_turboquant(&query, 10, 0.0);
        let tq_top10: Vec<usize> = tq.iter().take(10).map(|(idx, _)| *idx).collect();

        let recall = exact_top10.iter().filter(|i| tq_top10.contains(i)).count();
        eprintln!("TurboQuant Recall@10: {recall}/10");
        assert!(
            recall >= 7,
            "TurboQuant recall should be >= 7/10, got {recall}/10"
        );
    }
}