1use std::path::Path;
29use std::sync::atomic::{AtomicUsize, Ordering};
30use std::time::Instant;
31
32use rayon::prelude::*;
33use tracing::{info_span, instrument, trace, warn};
34
35use crate::backend::{EmbedBackend, Encoding};
36use crate::chunk::{ChunkConfig, CodeChunk};
37
/// Default number of encodings per embedding batch; this is the value
/// `SearchConfig::default()` assigns to `batch_size`.
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// File-count cutoff used by `embed_all`: once the walk yields at least this
/// many files, the streaming pipeline is used instead of the batch pipeline.
const STREAMING_THRESHOLD: usize = 1000;

/// Capacity of the bounded channel that carries tokenized batches to the
/// embedding loop in the streaming pipeline.
const RING_SIZE: usize = 4;
53
54#[derive(Debug, Clone)]
59pub struct SearchConfig {
60 pub batch_size: usize,
63 pub max_tokens: usize,
69 pub chunk: ChunkConfig,
71 pub text_mode: bool,
75 pub cascade_dim: Option<usize>,
81 pub file_type: Option<String>,
86 pub exclude_extensions: Vec<String>,
88 pub ignore_patterns: Vec<String>,
90 pub mode: crate::hybrid::SearchMode,
92}
93
94impl SearchConfig {
95 #[must_use]
97 pub fn walk_options(&self) -> crate::walk::WalkOptions {
98 crate::walk::WalkOptions {
99 file_type: self.file_type.clone(),
100 exclude_extensions: self.exclude_extensions.clone(),
101 ignore_patterns: self.ignore_patterns.clone(),
102 }
103 }
104
105 pub fn apply_repo_config(&mut self, root: &Path) {
107 let Some((_, config)) = crate::cache::config::find_config(root) else {
108 return;
109 };
110 for pattern in config.ignore.patterns {
111 if !pattern.trim().is_empty() && !self.ignore_patterns.contains(&pattern) {
112 self.ignore_patterns.push(pattern);
113 }
114 }
115 }
116}
117
118impl Default for SearchConfig {
119 fn default() -> Self {
120 Self {
121 batch_size: DEFAULT_BATCH_SIZE,
122 max_tokens: 0,
123 chunk: ChunkConfig::default(),
124 text_mode: false,
125 cascade_dim: None,
126 file_type: None,
127 exclude_extensions: Vec::new(),
128 ignore_patterns: Vec::new(),
129 mode: crate::hybrid::SearchMode::Hybrid,
130 }
131 }
132}
133
/// One ranked hit produced by `search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The chunk this hit refers to.
    pub chunk: CodeChunk,
    /// Score assigned by the ranking step; `apply_structural_boost` rewrites it in place.
    pub similarity: f32,
}
142
143#[instrument(skip_all, fields(root = %root.display(), batch_size = cfg.batch_size))]
163pub fn embed_all(
164 root: &Path,
165 backends: &[&dyn EmbedBackend],
166 tokenizer: &tokenizers::Tokenizer,
167 cfg: &SearchConfig,
168 profiler: &crate::profile::Profiler,
169) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
170 if backends.is_empty() {
171 return Err(crate::Error::Other(anyhow::anyhow!(
172 "no embedding backends provided"
173 )));
174 }
175
176 let files = {
178 let _span = info_span!("walk").entered();
179 let guard = profiler.phase("walk");
180 let walk_options = cfg.walk_options();
181 let files = crate::walk::collect_files_with_options(root, &walk_options);
182 guard.set_detail(format!("{} files", files.len()));
183 files
184 };
185
186 if files.len() >= STREAMING_THRESHOLD {
187 let total_bytes: u64 = files
189 .iter()
190 .filter_map(|p| p.metadata().ok())
191 .map(|m| m.len())
192 .sum();
193 embed_all_streaming(&files, total_bytes, backends, tokenizer, cfg, profiler)
194 } else {
195 embed_all_batch(&files, backends, tokenizer, cfg, profiler)
196 }
197}
198
199fn embed_all_batch(
204 files: &[std::path::PathBuf],
205 backends: &[&dyn EmbedBackend],
206 tokenizer: &tokenizers::Tokenizer,
207 cfg: &SearchConfig,
208 profiler: &crate::profile::Profiler,
209) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
210 let chunks: Vec<CodeChunk> = {
212 let _span = info_span!("chunk", file_count = files.len()).entered();
213 let chunk_start = Instant::now();
214 let text_mode = cfg.text_mode;
215 let result: Vec<CodeChunk> = files
216 .par_iter()
217 .flat_map(|path| {
218 let Some(source) = read_source(path) else {
219 return vec![];
220 };
221 let chunks =
222 crate::chunk::chunk_source_for_path(path, &source, text_mode, &cfg.chunk);
223 profiler.chunk_thread_report(chunks.len());
224 profiler.chunk_batch(&chunks);
225 chunks
226 })
227 .collect();
228 profiler.chunk_summary(result.len(), files.len(), chunk_start.elapsed());
229 result
230 };
231
232 let bs = cfg.batch_size.max(1);
234 let max_tokens_cfg = cfg.max_tokens;
235 let model_max = backends[0].max_tokens();
236 let _span = info_span!("embed_chunks", chunk_count = chunks.len(), batch_size = bs).entered();
237 profiler.embed_begin(chunks.len());
238
239 let all_encodings: Vec<Option<Encoding>> = chunks
240 .par_iter()
241 .map(|chunk| {
242 tokenize(
243 &chunk.enriched_content,
244 tokenizer,
245 max_tokens_cfg,
246 model_max,
247 )
248 .inspect_err(|e| {
249 warn!(file = %chunk.file_path, err = %e, "tokenization failed, skipping chunk");
250 })
251 .ok()
252 })
253 .collect();
254
255 let mut paired: Vec<(CodeChunk, Option<Encoding>)> =
259 chunks.into_iter().zip(all_encodings).collect();
260 paired.sort_by(|a, b| {
261 let len_a = a.1.as_ref().map_or(0, |e| e.input_ids.len());
262 let len_b = b.1.as_ref().map_or(0, |e| e.input_ids.len());
263 len_b.cmp(&len_a) });
265 let (chunks, sorted_encodings): (Vec<CodeChunk>, Vec<Option<Encoding>>) =
266 paired.into_iter().unzip();
267
268 let embeddings = embed_distributed(&sorted_encodings, backends, bs, profiler)?;
270 profiler.embed_done();
271
272 let (chunks, embeddings): (Vec<_>, Vec<_>) = chunks
274 .into_iter()
275 .zip(embeddings)
276 .filter(|(_, emb)| !emb.is_empty())
277 .unzip();
278
279 Ok((chunks, embeddings))
280}
281
282#[expect(
301 clippy::too_many_lines,
302 reason = "streaming pipeline has inherent complexity in thread coordination"
303)]
304fn embed_all_streaming(
305 files: &[std::path::PathBuf],
306 total_bytes: u64,
307 backends: &[&dyn EmbedBackend],
308 tokenizer: &tokenizers::Tokenizer,
309 cfg: &SearchConfig,
310 profiler: &crate::profile::Profiler,
311) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
312 use crossbeam_channel::bounded;
313
314 let bs = cfg.batch_size.max(1);
315 let max_tokens_cfg = cfg.max_tokens;
316 let model_max = backends[0].max_tokens();
317 let file_count = files.len();
318 let text_mode = cfg.text_mode;
319 let chunk_config = cfg.chunk.clone();
320
321 let (chunk_tx, chunk_rx) = bounded::<CodeChunk>(bs * 8);
325
326 let (batch_tx, batch_rx) = bounded::<Vec<(Encoding, CodeChunk)>>(RING_SIZE);
330
331 let total_chunks_produced = AtomicUsize::new(0);
333 let bytes_chunked = AtomicUsize::new(0);
334 let chunk_start = Instant::now();
335
336 std::thread::scope(|scope| {
339 scope.spawn(|| {
345 let _span = info_span!("chunk_stream", file_count).entered();
346 files.par_iter().for_each(|path| {
347 let Some(source) = read_source(path) else {
348 return;
349 };
350 let chunks =
351 crate::chunk::chunk_source_for_path(path, &source, text_mode, &chunk_config);
352 let n = chunks.len();
353 let file_bytes = source.len();
354 profiler.chunk_batch(&chunks);
355 for chunk in chunks {
356 if chunk_tx.send(chunk).is_err() {
358 return;
359 }
360 }
361 profiler.chunk_thread_report(n);
362 total_chunks_produced.fetch_add(n, Ordering::Relaxed);
363 bytes_chunked.fetch_add(file_bytes, Ordering::Relaxed);
364 });
365 drop(chunk_tx);
369 });
370
371 let tokenize_handle = scope.spawn(move || -> crate::Result<()> {
378 let _span = info_span!("tokenize_stream").entered();
379 let mut buffer: Vec<(Encoding, CodeChunk)> = Vec::with_capacity(bs);
380
381 for chunk in &chunk_rx {
382 match tokenize(
383 &chunk.enriched_content,
384 tokenizer,
385 max_tokens_cfg,
386 model_max,
387 ) {
388 Ok(encoding) => {
389 buffer.push((encoding, chunk));
390 if buffer.len() >= bs {
391 buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
393 let batch = std::mem::replace(&mut buffer, Vec::with_capacity(bs));
394 if batch_tx.send(batch).is_err() {
395 return Ok(());
397 }
398 }
399 }
400 Err(e) => {
401 warn!(
402 file = %chunk.file_path, err = %e,
403 "tokenization failed, skipping chunk"
404 );
405 }
406 }
407 }
408
409 if !buffer.is_empty() {
411 buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
412 let _ = batch_tx.send(buffer);
413 }
414 Ok(())
417 });
418
419 let _span = info_span!("embed_stream").entered();
424
425 profiler.embed_begin(0);
427
428 let mut all_chunks: Vec<CodeChunk> = Vec::new();
429 let mut all_embeddings: Vec<Vec<f32>> = Vec::new();
430 let mut embed_error: Option<crate::Error> = None;
431
432 let mut cumulative_done: usize = 0;
433 for batch in &batch_rx {
434 let batch_len = batch.len();
435 let (encodings, chunks): (Vec<Encoding>, Vec<CodeChunk>) = batch.into_iter().unzip();
436
437 let opt_encodings: Vec<Option<Encoding>> = encodings.into_iter().map(Some).collect();
439
440 let noop = crate::profile::Profiler::noop();
443 match embed_distributed(&opt_encodings, backends, bs, &noop) {
444 Ok(batch_embeddings) => {
445 profiler.embedding_batch(&batch_embeddings);
446 cumulative_done += batch_len;
447 let processed = bytes_chunked.load(Ordering::Relaxed) as u64;
450 profiler.embed_tick_bytes(cumulative_done, processed, total_bytes);
451
452 for (chunk, emb) in chunks.into_iter().zip(batch_embeddings) {
453 if !emb.is_empty() {
454 all_chunks.push(chunk);
455 all_embeddings.push(emb);
456 }
457 }
458 }
459 Err(e) => {
460 embed_error = Some(e);
461 break;
463 }
464 }
465 }
466
467 let final_total = total_chunks_produced.load(Ordering::Relaxed);
469 profiler.chunk_summary(final_total, file_count, chunk_start.elapsed());
470 profiler.embed_begin_update_total(cumulative_done);
472 profiler.embed_tick(cumulative_done);
473 profiler.embed_done();
474
475 let tokenize_result = tokenize_handle.join();
477
478 if let Some(e) = embed_error {
480 return Err(e);
481 }
482 match tokenize_result {
483 Ok(Ok(())) => {}
484 Ok(Err(e)) => return Err(e),
485 Err(_) => {
486 return Err(crate::Error::Other(anyhow::anyhow!(
487 "tokenize thread panicked"
488 )));
489 }
490 }
491
492 Ok((all_chunks, all_embeddings))
493 })
494}
495
496#[instrument(skip_all, fields(root = %root.display(), top_k, batch_size = cfg.batch_size))]
516pub fn search(
517 root: &Path,
518 query: &str,
519 backends: &[&dyn EmbedBackend],
520 tokenizer: &tokenizers::Tokenizer,
521 top_k: usize,
522 cfg: &SearchConfig,
523 profiler: &crate::profile::Profiler,
524) -> crate::Result<Vec<SearchResult>> {
525 if backends.is_empty() {
526 return Err(crate::Error::Other(anyhow::anyhow!(
527 "no embedding backends provided"
528 )));
529 }
530
531 let (chunks, embeddings) = embed_all(root, backends, tokenizer, cfg, profiler)?;
533
534 let t_query_start = std::time::Instant::now();
535
536 let hybrid = {
538 let _span = info_span!("build_hybrid_index").entered();
539 let _guard = profiler.phase("build_hybrid_index");
540 crate::hybrid::HybridIndex::new(chunks, &embeddings, cfg.cascade_dim)?
541 };
542
543 let mode = cfg.mode;
544 let effective_top_k = if top_k > 0 { top_k } else { usize::MAX };
545
546 let query_embedding = if mode == crate::hybrid::SearchMode::Keyword {
548 let dim = hybrid.semantic.hidden_dim;
550 vec![0.0f32; dim]
551 } else {
552 let _span = info_span!("embed_query").entered();
553 let _guard = profiler.phase("embed_query");
554 let t_tok = std::time::Instant::now();
555 let enc = tokenize(query, tokenizer, cfg.max_tokens, backends[0].max_tokens())?;
556 let tok_ms = t_tok.elapsed().as_secs_f64() * 1000.0;
557 let t_emb = std::time::Instant::now();
558 let mut results = backends[0].embed_batch(&[enc])?;
559 let emb_ms = t_emb.elapsed().as_secs_f64() * 1000.0;
560 eprintln!(
561 "[search] query: tokenize={tok_ms:.1}ms embed={emb_ms:.1}ms total_since_embed_all={:.1}ms",
562 t_query_start.elapsed().as_secs_f64() * 1000.0
563 );
564 results.pop().ok_or_else(|| {
565 crate::Error::Other(anyhow::anyhow!("backend returned no embedding for query"))
566 })?
567 };
568
569 let ranked = {
571 let _span = info_span!("rank", chunk_count = hybrid.chunks().len()).entered();
572 let guard = profiler.phase("rank");
573 let threshold = 0.0; let results = hybrid.search(&query_embedding, query, effective_top_k, threshold, mode);
575 guard.set_detail(format!(
576 "{mode} top {} from {}",
577 effective_top_k.min(results.len()),
578 hybrid.chunks().len()
579 ));
580 results
581 };
582
583 let results: Vec<SearchResult> = ranked
584 .into_iter()
585 .map(|(idx, score)| SearchResult {
586 chunk: hybrid.chunks()[idx].clone(),
587 similarity: score,
588 })
589 .collect();
590
591 Ok(results)
592}
593
594struct DistributedState<'a> {
596 tokenized: &'a [Option<Encoding>],
597 cursor: std::sync::atomic::AtomicUsize,
598 error_flag: std::sync::atomic::AtomicBool,
599 first_error: std::sync::Mutex<Option<crate::Error>>,
600 done_counter: std::sync::atomic::AtomicUsize,
601 batch_size: usize,
602 profiler: &'a crate::profile::Profiler,
603}
604
605impl DistributedState<'_> {
606 fn run_worker(&self, backend: &dyn EmbedBackend) -> Vec<(usize, Vec<f32>)> {
608 use std::sync::atomic::Ordering;
609
610 let n = self.tokenized.len();
611 let grab_size = if backend.is_gpu() {
615 self.batch_size * 4
616 } else {
617 self.batch_size
618 };
619 let mut results = Vec::new();
620
621 loop {
622 if self.error_flag.load(Ordering::Relaxed) {
623 break;
624 }
625
626 let start = self.cursor.fetch_add(grab_size, Ordering::Relaxed);
627 if start >= n {
628 break;
629 }
630 let end = (start + grab_size).min(n);
631 let batch = &self.tokenized[start..end];
632
633 let mut valid = Vec::with_capacity(batch.len());
635 let mut valid_indices = Vec::with_capacity(batch.len());
636 for (i, enc) in batch.iter().enumerate() {
637 if let Some(e) = enc {
638 valid.push(e.clone());
641 valid_indices.push(start + i);
642 } else {
643 results.push((start + i, vec![]));
644 }
645 }
646
647 if valid.is_empty() {
648 let done =
649 self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
650 self.profiler.embed_tick(done);
651 continue;
652 }
653
654 match backend.embed_batch(&valid) {
655 Ok(batch_embeddings) => {
656 self.profiler.embedding_batch(&batch_embeddings);
657 for (idx, emb) in valid_indices.into_iter().zip(batch_embeddings) {
658 results.push((idx, emb));
659 }
660 let done =
661 self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
662 self.profiler.embed_tick(done);
663 }
664 Err(e) => {
665 self.error_flag.store(true, Ordering::Relaxed);
666 if let Ok(mut guard) = self.first_error.lock()
667 && guard.is_none()
668 {
669 *guard = Some(e);
670 }
671 break;
672 }
673 }
674 }
675
676 results
677 }
678}
679
680#[expect(
694 unsafe_code,
695 reason = "BLAS thread count must be set via env vars before spawning workers"
696)]
697pub(crate) fn embed_distributed(
698 tokenized: &[Option<Encoding>],
699 backends: &[&dyn EmbedBackend],
700 batch_size: usize,
701 profiler: &crate::profile::Profiler,
702) -> crate::Result<Vec<Vec<f32>>> {
703 let n = tokenized.len();
704 let state = DistributedState {
705 tokenized,
706 cursor: std::sync::atomic::AtomicUsize::new(0),
707 error_flag: std::sync::atomic::AtomicBool::new(false),
708 first_error: std::sync::Mutex::new(None),
709 done_counter: std::sync::atomic::AtomicUsize::new(0),
710 batch_size: batch_size.max(1),
711 profiler,
712 };
713
714 let all_pairs: Vec<(usize, Vec<f32>)> =
716 if backends.len() == 1 && backends[0].supports_clone() && !backends[0].is_gpu() {
717 unsafe {
730 std::env::set_var("OPENBLAS_NUM_THREADS", "1");
731 std::env::set_var("MKL_NUM_THREADS", "1");
732 std::env::set_var("VECLIB_MAXIMUM_THREADS", "1"); #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
736 {
737 unsafe extern "C" {
738 fn openblas_set_num_threads(num: std::ffi::c_int);
739 }
740 openblas_set_num_threads(1);
741 }
742 }
743
744 let num_workers = rayon::current_num_threads().max(1);
745 std::thread::scope(|s| {
746 let handles: Vec<_> = (0..num_workers)
747 .map(|_| {
748 s.spawn(|| {
749 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
752 crate::backend::driver::cpu::force_single_threaded_blas();
753 let cloned = backends[0].clone_backend();
754 state.run_worker(cloned.as_ref())
755 })
756 })
757 .collect();
758 let mut all = Vec::new();
759 for handle in handles {
760 if let Ok(pairs) = handle.join() {
761 all.extend(pairs);
762 }
763 }
764 all
765 })
766 } else if backends.len() == 1 {
767 state.run_worker(backends[0])
771 } else {
772 std::thread::scope(|s| {
774 let handles: Vec<_> = backends
775 .iter()
776 .map(|&backend| {
777 s.spawn(|| {
778 if backend.supports_clone() {
780 let cloned = backend.clone_backend();
781 state.run_worker(cloned.as_ref())
782 } else {
783 state.run_worker(backend)
784 }
785 })
786 })
787 .collect();
788
789 let mut all = Vec::new();
790 for handle in handles {
791 if let Ok(pairs) = handle.join() {
792 all.extend(pairs);
793 } else {
794 warn!("worker thread panicked");
795 state
796 .error_flag
797 .store(true, std::sync::atomic::Ordering::Relaxed);
798 }
799 }
800 all
801 })
802 };
803
804 if let Some(err) = state.first_error.into_inner().ok().flatten() {
806 return Err(err);
807 }
808
809 let mut embeddings: Vec<Vec<f32>> = vec![vec![]; n];
811 for (idx, emb) in all_pairs {
812 embeddings[idx] = emb;
813 }
814
815 Ok(embeddings)
816}
817
818pub(crate) fn read_source(path: &Path) -> Option<String> {
824 let bytes = match std::fs::read(path) {
825 Ok(b) => b,
826 Err(e) => {
827 trace!(path = %path.display(), err = %e, "skipping file: read failed");
828 return None;
829 }
830 };
831
832 if memchr::memchr(0, &bytes[..bytes.len().min(8192)]).is_some() {
834 trace!(path = %path.display(), "skipping binary file");
835 return None;
836 }
837
838 match std::str::from_utf8(&bytes) {
839 Ok(s) => Some(s.to_string()),
840 Err(e) => {
841 trace!(path = %path.display(), err = %e, "skipping file: not valid UTF-8");
842 None
843 }
844 }
845}
846
847fn tokenize(
854 text: &str,
855 tokenizer: &tokenizers::Tokenizer,
856 max_tokens: usize,
857 model_max_tokens: usize,
858) -> crate::Result<Encoding> {
859 let mut enc = crate::tokenize::tokenize_query(text, tokenizer, model_max_tokens)?;
860 if max_tokens > 0 {
861 let len = enc.input_ids.len().min(max_tokens);
862 enc.input_ids.truncate(len);
863 enc.attention_mask.truncate(len);
864 enc.token_type_ids.truncate(len);
865 }
866 Ok(enc)
867}
868
869pub fn apply_structural_boost<S: ::std::hash::BuildHasher>(
878 results: &mut [SearchResult],
879 file_ranks: &std::collections::HashMap<String, f32, S>,
880 alpha: f32,
881) {
882 if results.is_empty() || alpha == 0.0 {
883 return;
884 }
885
886 let min = results
887 .iter()
888 .map(|r| r.similarity)
889 .fold(f32::INFINITY, f32::min);
890 let max = results
891 .iter()
892 .map(|r| r.similarity)
893 .fold(f32::NEG_INFINITY, f32::max);
894 let range = (max - min).max(1e-12);
895
896 for r in results.iter_mut() {
897 let normalized = (r.similarity - min) / range;
898 let pr = file_ranks.get(&r.chunk.file_path).copied().unwrap_or(0.0);
899 r.similarity = normalized + alpha * pr;
900 }
901}
902
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check that `search` returns at least one hit for this
    /// crate's own `src` directory using the CPU backend.
    #[test]
    #[cfg(feature = "cpu")]
    #[ignore = "loads model + embeds full source tree; run with `cargo test -- --ignored`"]
    fn search_with_backend_trait() {
        let model_id = "BAAI/bge-small-en-v1.5";
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            model_id,
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model_id).unwrap();
        let src_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");

        let outcome = search(
            &src_dir,
            "embedding model",
            &[backend.as_ref()],
            &tokenizer,
            1,
            &SearchConfig::default(),
            &crate::profile::Profiler::noop(),
        );

        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().is_empty());
    }

    /// `embed_distributed` must return one 384-dim embedding per input.
    #[test]
    #[cfg(feature = "cpu")]
    fn embed_distributed_produces_correct_count() {
        let model_id = "BAAI/bge-small-en-v1.5";
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            model_id,
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model_id).unwrap();
        let profiler = crate::profile::Profiler::noop();

        let encoded: Vec<Option<Encoding>> = ["fn hello() {}", "class Foo:", "func main() {}"]
            .iter()
            .map(|text| super::tokenize(text, &tokenizer, 0, 512).ok())
            .collect();

        let embeddings =
            super::embed_distributed(&encoded, &[backend.as_ref()], 32, &profiler).unwrap();

        assert_eq!(embeddings.len(), 3);
        for (i, emb) in embeddings.iter().enumerate() {
            assert_eq!(emb.len(), 384, "embedding {i} should be 384-dim");
        }
    }

    /// Keeps the first `dims` components of `emb` and rescales them to unit
    /// L2 norm (the norm is floored at `1e-12`).
    fn truncate_and_normalize(emb: &[f32], dims: usize) -> Vec<f32> {
        let head = &emb[..dims];
        let sum_sq: f32 = head.iter().map(|v| v * v).sum();
        let norm = sum_sq.sqrt().max(1e-12);
        head.iter().map(|v| v / norm).collect()
    }

    /// Returns the indices of the `k` corpus vectors with the highest dot
    /// product against `query`, best first.
    fn rank_topk(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<usize> {
        let mut scored: Vec<(usize, f32)> = corpus
            .iter()
            .map(|emb| query.iter().zip(emb).map(|(a, b)| a * b).sum::<f32>())
            .enumerate()
            .collect();
        scored.sort_unstable_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
        scored.truncate(k);
        scored.into_iter().map(|(idx, _)| idx).collect()
    }

    /// Prints, for a few queries, how much of the full-dimension top-k is
    /// retained when embeddings are truncated to smaller dimensions.
    #[test]
    #[ignore = "loads model + embeds; run with --nocapture"]
    #[expect(
        clippy::cast_precision_loss,
        reason = "top_k and overlap are small counts"
    )]
    fn mrl_retrieval_recall() {
        let model = "BAAI/bge-small-en-v1.5";
        let backends = crate::backend::detect_backends(model).unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model).unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();

        // Two levels above the crate manifest directory.
        let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .ancestors()
            .nth(2)
            .unwrap();
        eprintln!("Embedding {}", root.display());

        let backend_refs: Vec<&dyn crate::backend::EmbedBackend> =
            backends.iter().map(std::convert::AsRef::as_ref).collect();
        let (chunks, embeddings) =
            embed_all(root, &backend_refs, &tokenizer, &cfg, &profiler).unwrap();
        let full_dim = embeddings[0].len();
        eprintln!(
            "Corpus: {} chunks, {full_dim}-dim embeddings\n",
            chunks.len()
        );

        let queries = [
            "error handling in the embedding pipeline",
            "tree-sitter chunking and AST parsing",
            "Metal GPU kernel dispatch",
            "file watcher for incremental reindex",
            "cosine similarity ranking",
        ];

        let top_k = 10;
        let mut mrl_dims: Vec<usize> = vec![32, 64, 128, 192, 256, full_dim];
        mrl_dims.retain(|&d| d <= full_dim);

        eprintln!("=== MRL Retrieval Recall@{top_k} (vs full {full_dim}-dim) ===\n");

        for query in &queries {
            let enc = tokenize(query, &tokenizer, 0, backends[0].max_tokens()).unwrap();
            let query_emb = backends[0].embed_batch(&[enc]).unwrap().pop().unwrap();

            // Reference ranking on the untruncated embeddings.
            let ref_topk = rank_topk(&query_emb, &embeddings, top_k);
            let best = &chunks[ref_topk[0]];

            eprintln!("Query: \"{query}\"");
            eprintln!(" Full-dim top-1: {} ({})", best.name, best.file_path);

            for &dims in &mrl_dims {
                let trunc_query = truncate_and_normalize(&query_emb, dims);
                let trunc_corpus: Vec<Vec<f32>> = embeddings
                    .iter()
                    .map(|e| truncate_and_normalize(e, dims))
                    .collect();
                let trunc_topk = rank_topk(&trunc_query, &trunc_corpus, top_k);

                let overlap = ref_topk.iter().filter(|i| trunc_topk.contains(i)).count();
                let recall = overlap as f32 / top_k as f32;
                let marker = if dims == full_dim {
                    " (ref)"
                } else if recall >= 0.8 {
                    " ***"
                } else {
                    ""
                };
                eprintln!(
                    " dims={dims:>3}: Recall@{top_k}={recall:.1} ({overlap}/{top_k}){marker}"
                );
            }
            eprintln!();
        }
    }

    /// Builds a minimal `SearchResult` for `file_path` with the given score.
    fn make_result(file_path: &str, similarity: f32) -> SearchResult {
        let chunk = CodeChunk {
            file_path: file_path.to_string(),
            name: "test".to_string(),
            kind: "function".to_string(),
            start_line: 1,
            end_line: 10,
            enriched_content: String::new(),
            content: String::new(),
        };
        SearchResult { chunk, similarity }
    }

    #[test]
    fn structural_boost_normalizes_and_applies() {
        let mut results = vec![
            make_result("src/a.rs", 0.8),
            make_result("src/b.rs", 0.4),
            make_result("src/c.rs", 0.6),
        ];
        let ranks = std::collections::HashMap::from([
            ("src/a.rs".to_string(), 0.5),
            ("src/b.rs".to_string(), 1.0),
            ("src/c.rs".to_string(), 0.0),
        ]);

        apply_structural_boost(&mut results, &ranks, 0.2);

        // Scores span 0.4..=0.8, so a -> 1.0, b -> 0.0, c -> 0.5 before the boost.
        // a: 1.0 + 0.2 * 0.5
        assert!((results[0].similarity - 1.1).abs() < 1e-6);
        // b: 0.0 + 0.2 * 1.0
        assert!((results[1].similarity - 0.2).abs() < 1e-6);
        // c: 0.5 + 0.2 * 0.0
        assert!((results[2].similarity - 0.5).abs() < 1e-6);
    }

    #[test]
    fn structural_boost_noop_on_empty() {
        let mut results: Vec<SearchResult> = Vec::new();
        let ranks = std::collections::HashMap::new();
        apply_structural_boost(&mut results, &ranks, 0.2);
        assert!(results.is_empty());
    }

    #[test]
    fn structural_boost_noop_on_zero_alpha() {
        let mut results = vec![make_result("src/a.rs", 0.8)];
        let ranks = std::collections::HashMap::from([("src/a.rs".to_string(), 1.0)]);
        apply_structural_boost(&mut results, &ranks, 0.0);
        // With alpha == 0 the score is left untouched (not even normalized).
        assert!((results[0].similarity - 0.8).abs() < 1e-6);
    }
}