use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;

use rayon::prelude::*;
use tracing::{debug, info_span, instrument, warn};

use crate::backend::{EmbedBackend, Encoding};
use crate::chunk::{ChunkConfig, CodeChunk};

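/// Default number of chunks per embedding batch.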
pub const DEFAULT_BATCH_SIZE: usize = 32;

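/// File count at or above which `embed_all` uses the streaming pipeline
/// instead of chunking and tokenizing everything up front.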
const STREAMING_THRESHOLD: usize = 1000;

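/// Capacity of the bounded channel of tokenized batches between the
/// tokenizer thread and the embedding loop in the streaming pipeline.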
const RING_SIZE: usize = 4;

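/// Configuration for the embedding and search pipeline.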
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of chunks per embedding batch.
    pub batch_size: usize,
    /// Per-chunk token cap; `0` means no cap beyond the model maximum.
    pub max_tokens: usize,
    /// Chunking configuration.
    pub chunk: ChunkConfig,
    /// Treat every file as plain text instead of language-aware chunking.
    pub text_mode: bool,
    /// Optional truncated embedding dimension forwarded to the hybrid index.
    pub cascade_dim: Option<usize>,
    /// Optional file-type filter applied while collecting files.
    pub file_type: Option<String>,
    /// Search mode: hybrid, semantic-only, or keyword-only.
    pub mode: crate::hybrid::SearchMode,
}

impl Default for SearchConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            max_tokens: 0,
            chunk: ChunkConfig::default(),
            text_mode: false,
            cascade_dim: None,
            file_type: None,
            mode: crate::hybrid::SearchMode::Hybrid,
        }
    }
}

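/// A single ranked search hit.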
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched code chunk.
    pub chunk: CodeChunk,
    /// Similarity/ranking score for the query.
    pub similarity: f32,
}

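/// Walk `root`, chunk every collected file, and embed the chunks using the
/// provided backends. Switches to the streaming pipeline once the file count
/// reaches `STREAMING_THRESHOLD`. Returns the chunks and their embeddings in
/// matching order.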
#[instrument(skip_all, fields(root = %root.display(), batch_size = cfg.batch_size))]
pub fn embed_all(
    root: &Path,
    backends: &[&dyn EmbedBackend],
    tokenizer: &tokenizers::Tokenizer,
    cfg: &SearchConfig,
    profiler: &crate::profile::Profiler,
) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
    if backends.is_empty() {
        return Err(crate::Error::Other(anyhow::anyhow!(
            "no embedding backends provided"
        )));
    }

    let files = {
        let _span = info_span!("walk").entered();
        let guard = profiler.phase("walk");
        let files = crate::walk::collect_files(root, cfg.file_type.as_deref());
        guard.set_detail(format!("{} files", files.len()));
        files
    };

    if files.len() >= STREAMING_THRESHOLD {
        let total_bytes: u64 = files
            .iter()
            .filter_map(|p| p.metadata().ok())
            .map(|m| m.len())
            .sum();
        embed_all_streaming(&files, total_bytes, backends, tokenizer, cfg, profiler)
    } else {
        embed_all_batch(&files, backends, tokenizer, cfg, profiler)
    }
}

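/// In-memory path for smaller repositories: chunk all files in parallel,
/// tokenize every chunk, sort by descending token length (so batches hold
/// similar-length sequences), then embed.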
fn embed_all_batch(
    files: &[std::path::PathBuf],
    backends: &[&dyn EmbedBackend],
    tokenizer: &tokenizers::Tokenizer,
    cfg: &SearchConfig,
    profiler: &crate::profile::Profiler,
) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
    let chunks: Vec<CodeChunk> = {
        let _span = info_span!("chunk", file_count = files.len()).entered();
        let chunk_start = Instant::now();
        let text_mode = cfg.text_mode;
        let result: Vec<CodeChunk> = files
            .par_iter()
            .flat_map(|path| {
                let Some(source) = read_source(path) else {
                    return vec![];
                };
                let chunks = if text_mode {
                    crate::chunk::chunk_text(path, &source, &cfg.chunk)
                } else {
                    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
                    match crate::languages::config_for_extension(ext) {
                        Some(lang_config) => {
                            crate::chunk::chunk_file(path, &source, &lang_config, &cfg.chunk)
                        }
                        None => crate::chunk::chunk_text(path, &source, &cfg.chunk),
                    }
                };
                profiler.chunk_thread_report(chunks.len());
                chunks
            })
            .collect();
        profiler.chunk_summary(result.len(), files.len(), chunk_start.elapsed());
        result
    };

    let bs = cfg.batch_size.max(1);
    let max_tokens_cfg = cfg.max_tokens;
    let model_max = backends[0].max_tokens();
    let _span = info_span!("embed_chunks", chunk_count = chunks.len(), batch_size = bs).entered();
    profiler.embed_begin(chunks.len());

    let all_encodings: Vec<Option<Encoding>> = chunks
        .par_iter()
        .map(|chunk| {
            tokenize(
                &chunk.enriched_content,
                tokenizer,
                max_tokens_cfg,
                model_max,
            )
            .inspect_err(|e| {
                warn!(file = %chunk.file_path, err = %e, "tokenization failed, skipping chunk");
            })
            .ok()
        })
        .collect();

    let mut paired: Vec<(CodeChunk, Option<Encoding>)> =
        chunks.into_iter().zip(all_encodings).collect();
    paired.sort_by(|a, b| {
        let len_a = a.1.as_ref().map_or(0, |e| e.input_ids.len());
        let len_b = b.1.as_ref().map_or(0, |e| e.input_ids.len());
        len_b.cmp(&len_a)
    });
    let (chunks, sorted_encodings): (Vec<CodeChunk>, Vec<Option<Encoding>>) =
        paired.into_iter().unzip();

    let embeddings = embed_distributed(&sorted_encodings, backends, bs, profiler)?;
    profiler.embed_done();

    let (chunks, embeddings): (Vec<_>, Vec<_>) = chunks
        .into_iter()
        .zip(embeddings)
        .filter(|(_, emb)| !emb.is_empty())
        .unzip();

    Ok((chunks, embeddings))
}

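/// Streaming path for large repositories: files are chunked on the rayon
/// pool, chunks are tokenized and batched on a dedicated thread, and the
/// calling thread embeds the batches, so the three stages overlap instead of
/// materializing everything up front.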
#[expect(
    clippy::too_many_lines,
    reason = "streaming pipeline has inherent complexity in thread coordination"
)]
fn embed_all_streaming(
    files: &[std::path::PathBuf],
    total_bytes: u64,
    backends: &[&dyn EmbedBackend],
    tokenizer: &tokenizers::Tokenizer,
    cfg: &SearchConfig,
    profiler: &crate::profile::Profiler,
) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
    use crossbeam_channel::bounded;

    let bs = cfg.batch_size.max(1);
    let max_tokens_cfg = cfg.max_tokens;
    let model_max = backends[0].max_tokens();
    let file_count = files.len();
    let text_mode = cfg.text_mode;
    let chunk_config = cfg.chunk.clone();

    // Chunker -> tokenizer channel: individual chunks, bounded at eight
    // batches' worth.
    let (chunk_tx, chunk_rx) = bounded::<CodeChunk>(bs * 8);

    // Tokenizer -> embedder channel: length-sorted batches, bounded by RING_SIZE.
    let (batch_tx, batch_rx) = bounded::<Vec<(Encoding, CodeChunk)>>(RING_SIZE);

    let total_chunks_produced = AtomicUsize::new(0);
    let bytes_chunked = AtomicUsize::new(0);
    let chunk_start = Instant::now();

    std::thread::scope(|scope| {
        // Producer: chunk files in parallel and stream chunks downstream.
        scope.spawn(|| {
            let _span = info_span!("chunk_stream", file_count).entered();
            files.par_iter().for_each(|path| {
                let Some(source) = read_source(path) else {
                    return;
                };
                let chunks = if text_mode {
                    crate::chunk::chunk_text(path, &source, &chunk_config)
                } else {
                    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
                    match crate::languages::config_for_extension(ext) {
                        Some(lang_config) => {
                            crate::chunk::chunk_file(path, &source, &lang_config, &chunk_config)
                        }
                        None => crate::chunk::chunk_text(path, &source, &chunk_config),
                    }
                };
                let n = chunks.len();
                let file_bytes = source.len();
                for chunk in chunks {
                    if chunk_tx.send(chunk).is_err() {
                        // Receiver gone (downstream stopped); stop producing.
                        return;
                    }
                }
                profiler.chunk_thread_report(n);
                total_chunks_produced.fetch_add(n, Ordering::Relaxed);
                bytes_chunked.fetch_add(file_bytes, Ordering::Relaxed);
            });
            // All files chunked; closing the channel ends the tokenizer's loop.
            drop(chunk_tx);
        });

        // Tokenizer: drain chunks, tokenize, and emit length-sorted batches.
        let tokenize_handle = scope.spawn(move || -> crate::Result<()> {
            let _span = info_span!("tokenize_stream").entered();
            let mut buffer: Vec<(Encoding, CodeChunk)> = Vec::with_capacity(bs);

            for chunk in &chunk_rx {
                match tokenize(
                    &chunk.enriched_content,
                    tokenizer,
                    max_tokens_cfg,
                    model_max,
                ) {
                    Ok(encoding) => {
                        buffer.push((encoding, chunk));
                        if buffer.len() >= bs {
                            buffer.sort_by(|a, b| b.0.input_ids.len().cmp(&a.0.input_ids.len()));
                            let batch = std::mem::replace(&mut buffer, Vec::with_capacity(bs));
                            if batch_tx.send(batch).is_err() {
                                // Receiver gone (embedding stopped); exit early.
                                return Ok(());
                            }
                        }
                    }
                    Err(e) => {
                        warn!(
                            file = %chunk.file_path, err = %e,
                            "tokenization failed, skipping chunk"
                        );
                    }
                }
            }

            // Flush the final partial batch.
            if !buffer.is_empty() {
                buffer.sort_by(|a, b| b.0.input_ids.len().cmp(&a.0.input_ids.len()));
                let _ = batch_tx.send(buffer);
            }
            Ok(())
        });

        // Embedder (this thread): drain tokenized batches and embed them.
        let _span = info_span!("embed_stream").entered();

        // The total is unknown up front; it is corrected once all batches are done.
        profiler.embed_begin(0);

        let mut all_chunks: Vec<CodeChunk> = Vec::new();
        let mut all_embeddings: Vec<Vec<f32>> = Vec::new();
        let mut embed_error: Option<crate::Error> = None;

        let mut cumulative_done: usize = 0;
        for batch in &batch_rx {
            let batch_len = batch.len();
            let (encodings, chunks): (Vec<Encoding>, Vec<CodeChunk>) = batch.into_iter().unzip();

            let opt_encodings: Vec<Option<Encoding>> = encodings.into_iter().map(Some).collect();

            // Use a no-op profiler for the inner call; progress is reported
            // per batch in bytes below.
            let noop = crate::profile::Profiler::noop();
            match embed_distributed(&opt_encodings, backends, bs, &noop) {
                Ok(batch_embeddings) => {
                    cumulative_done += batch_len;
                    let processed = bytes_chunked.load(Ordering::Relaxed) as u64;
                    profiler.embed_tick_bytes(cumulative_done, processed, total_bytes);

                    for (chunk, emb) in chunks.into_iter().zip(batch_embeddings) {
                        if !emb.is_empty() {
                            all_chunks.push(chunk);
                            all_embeddings.push(emb);
                        }
                    }
                }
                Err(e) => {
                    // Remember the first embedding error and stop consuming batches.
                    embed_error = Some(e);
                    break;
                }
            }
        }

        let final_total = total_chunks_produced.load(Ordering::Relaxed);
        profiler.chunk_summary(final_total, file_count, chunk_start.elapsed());
        profiler.embed_begin_update_total(cumulative_done);
        profiler.embed_tick(cumulative_done);
        profiler.embed_done();

        let tokenize_result = tokenize_handle.join();

        if let Some(e) = embed_error {
            return Err(e);
        }
        match tokenize_result {
            Ok(Ok(())) => {}
            Ok(Err(e)) => return Err(e),
            Err(_) => {
                return Err(crate::Error::Other(anyhow::anyhow!(
                    "tokenize thread panicked"
                )));
            }
        }

        Ok((all_chunks, all_embeddings))
    })
}

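/// Embed everything under `root`, embed the query, and return the `top_k`
/// highest-ranked chunks (all chunks when `top_k` is 0) according to the
/// configured search mode.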
#[instrument(skip_all, fields(root = %root.display(), top_k, batch_size = cfg.batch_size))]
pub fn search(
    root: &Path,
    query: &str,
    backends: &[&dyn EmbedBackend],
    tokenizer: &tokenizers::Tokenizer,
    top_k: usize,
    cfg: &SearchConfig,
    profiler: &crate::profile::Profiler,
) -> crate::Result<Vec<SearchResult>> {
    if backends.is_empty() {
        return Err(crate::Error::Other(anyhow::anyhow!(
            "no embedding backends provided"
        )));
    }

    let (chunks, embeddings) = embed_all(root, backends, tokenizer, cfg, profiler)?;

    let t_query_start = std::time::Instant::now();

    let hybrid = {
        let _span = info_span!("build_hybrid_index").entered();
        let _guard = profiler.phase("build_hybrid_index");
        crate::hybrid::HybridIndex::new(chunks, embeddings, cfg.cascade_dim)?
    };

    let mode = cfg.mode;
    let effective_top_k = if top_k > 0 { top_k } else { usize::MAX };

    let query_embedding = if mode == crate::hybrid::SearchMode::Keyword {
        // Keyword-only search does not use the query embedding; a zero
        // vector of the right dimension is enough.
        let dim = hybrid.semantic.hidden_dim;
        vec![0.0f32; dim]
    } else {
        let _span = info_span!("embed_query").entered();
        let _guard = profiler.phase("embed_query");
        let t_tok = std::time::Instant::now();
        let enc = tokenize(query, tokenizer, cfg.max_tokens, backends[0].max_tokens())?;
        let tok_ms = t_tok.elapsed().as_secs_f64() * 1000.0;
        let t_emb = std::time::Instant::now();
        let mut results = backends[0].embed_batch(&[enc])?;
        let emb_ms = t_emb.elapsed().as_secs_f64() * 1000.0;
        eprintln!(
            "[search] query: tokenize={tok_ms:.1}ms embed={emb_ms:.1}ms total_since_embed_all={:.1}ms",
            t_query_start.elapsed().as_secs_f64() * 1000.0
        );
        results.pop().ok_or_else(|| {
            crate::Error::Other(anyhow::anyhow!("backend returned no embedding for query"))
        })?
    };

    let ranked = {
        let _span = info_span!("rank", chunk_count = hybrid.chunks().len()).entered();
        let guard = profiler.phase("rank");
        let threshold = if mode == crate::hybrid::SearchMode::Semantic {
            0.0
        } else {
            0.0
        };
        let results = hybrid.search(&query_embedding, query, effective_top_k, threshold, mode);
        guard.set_detail(format!(
            "{mode} top {} from {}",
            effective_top_k.min(results.len()),
            hybrid.chunks().len()
        ));
        results
    };

    let results: Vec<SearchResult> = ranked
        .into_iter()
        .map(|(idx, score)| SearchResult {
            chunk: hybrid.chunks()[idx].clone(),
            similarity: score,
        })
        .collect();

    Ok(results)
}

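/// Shared state for `embed_distributed` workers: an atomic cursor hands out
/// ranges of the tokenized input, and the first backend error is recorded so
/// all workers can stop early.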
struct DistributedState<'a> {
    tokenized: &'a [Option<Encoding>],
    cursor: std::sync::atomic::AtomicUsize,
    error_flag: std::sync::atomic::AtomicBool,
    first_error: std::sync::Mutex<Option<crate::Error>>,
    done_counter: std::sync::atomic::AtomicUsize,
    batch_size: usize,
    profiler: &'a crate::profile::Profiler,
}

impl DistributedState<'_> {
    /// Pull ranges off the shared cursor and embed them on `backend` until
    /// the work runs out or another worker reports an error.
    fn run_worker(&self, backend: &dyn EmbedBackend) -> Vec<(usize, Vec<f32>)> {
        use std::sync::atomic::Ordering;

        let n = self.tokenized.len();
        // GPU backends take a larger grab per fetch (4x the batch size).
        let grab_size = if backend.is_gpu() {
            self.batch_size * 4
        } else {
            self.batch_size
        };
        let mut results = Vec::new();

        loop {
            if self.error_flag.load(Ordering::Relaxed) {
                break;
            }

            let start = self.cursor.fetch_add(grab_size, Ordering::Relaxed);
            if start >= n {
                break;
            }
            let end = (start + grab_size).min(n);
            let batch = &self.tokenized[start..end];

            // Chunks whose tokenization failed get an empty embedding.
            let mut valid = Vec::with_capacity(batch.len());
            let mut valid_indices = Vec::with_capacity(batch.len());
            for (i, enc) in batch.iter().enumerate() {
                if let Some(e) = enc {
                    valid.push(e.clone());
                    valid_indices.push(start + i);
                } else {
                    results.push((start + i, vec![]));
                }
            }

            if valid.is_empty() {
                let done =
                    self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
                self.profiler.embed_tick(done);
                continue;
            }

            match backend.embed_batch(&valid) {
                Ok(batch_embeddings) => {
                    for (idx, emb) in valid_indices.into_iter().zip(batch_embeddings) {
                        results.push((idx, emb));
                    }
                    let done =
                        self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
                    self.profiler.embed_tick(done);
                }
                Err(e) => {
                    self.error_flag.store(true, Ordering::Relaxed);
                    if let Ok(mut guard) = self.first_error.lock()
                        && guard.is_none()
                    {
                        *guard = Some(e);
                    }
                    break;
                }
            }
        }

        results
    }
}

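/// Embed every tokenized chunk, distributing work across the given backends.
/// A single clonable CPU backend is fanned out to one worker per rayon thread
/// (with BLAS forced single-threaded); a single non-clonable or GPU backend
/// runs on the calling thread; multiple backends each get a worker thread.
/// Entries without an encoding come back as empty embeddings.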
#[expect(
    unsafe_code,
    reason = "BLAS thread count must be set via env vars before spawning workers"
)]
pub(crate) fn embed_distributed(
    tokenized: &[Option<Encoding>],
    backends: &[&dyn EmbedBackend],
    batch_size: usize,
    profiler: &crate::profile::Profiler,
) -> crate::Result<Vec<Vec<f32>>> {
    let n = tokenized.len();
    let state = DistributedState {
        tokenized,
        cursor: std::sync::atomic::AtomicUsize::new(0),
        error_flag: std::sync::atomic::AtomicBool::new(false),
        first_error: std::sync::Mutex::new(None),
        done_counter: std::sync::atomic::AtomicUsize::new(0),
        batch_size: batch_size.max(1),
        profiler,
    };

    let all_pairs: Vec<(usize, Vec<f32>)> =
        if backends.len() == 1 && backends[0].supports_clone() && !backends[0].is_gpu() {
            // One worker per rayon thread, each with its own backend clone.
            // Pin BLAS to a single thread so the workers do not oversubscribe
            // the CPU.
            unsafe {
                std::env::set_var("OPENBLAS_NUM_THREADS", "1");
                std::env::set_var("MKL_NUM_THREADS", "1");
                std::env::set_var("VECLIB_MAXIMUM_THREADS", "1");

                #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
                {
                    unsafe extern "C" {
                        fn openblas_set_num_threads(num: std::ffi::c_int);
                    }
                    openblas_set_num_threads(1);
                }
            }

            let num_workers = rayon::current_num_threads().max(1);
            std::thread::scope(|s| {
                let handles: Vec<_> = (0..num_workers)
                    .map(|_| {
                        s.spawn(|| {
                            #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
                            crate::backend::driver::cpu::force_single_threaded_blas();
                            let cloned = backends[0].clone_backend();
                            state.run_worker(cloned.as_ref())
                        })
                    })
                    .collect();
                let mut all = Vec::new();
                for handle in handles {
                    if let Ok(pairs) = handle.join() {
                        all.extend(pairs);
                    }
                }
                all
            })
        } else if backends.len() == 1 {
            // Single non-clonable (or GPU) backend: run on the calling thread.
            state.run_worker(backends[0])
        } else {
            // One worker thread per backend.
            std::thread::scope(|s| {
                let handles: Vec<_> = backends
                    .iter()
                    .map(|&backend| {
                        s.spawn(|| {
                            if backend.supports_clone() {
                                let cloned = backend.clone_backend();
                                state.run_worker(cloned.as_ref())
                            } else {
                                state.run_worker(backend)
                            }
                        })
                    })
                    .collect();

                let mut all = Vec::new();
                for handle in handles {
                    if let Ok(pairs) = handle.join() {
                        all.extend(pairs);
                    } else {
                        warn!("worker thread panicked");
                        state
                            .error_flag
                            .store(true, std::sync::atomic::Ordering::Relaxed);
                    }
                }
                all
            })
        };

    if let Some(err) = state.first_error.into_inner().ok().flatten() {
        return Err(err);
    }

    let mut embeddings: Vec<Vec<f32>> = vec![vec![]; n];
    for (idx, emb) in all_pairs {
        embeddings[idx] = emb;
    }

    Ok(embeddings)
}

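/// Read a file as UTF-8 source text. Returns `None` (logging at debug level)
/// if the file cannot be read, looks binary (NUL byte in the first 8 KiB),
/// or is not valid UTF-8.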
pub(crate) fn read_source(path: &Path) -> Option<String> {
    let bytes = match std::fs::read(path) {
        Ok(b) => b,
        Err(e) => {
            debug!(path = %path.display(), err = %e, "skipping file: read failed");
            return None;
        }
    };

    if memchr::memchr(0, &bytes[..bytes.len().min(8192)]).is_some() {
        debug!(path = %path.display(), "skipping binary file");
        return None;
    }

    match std::str::from_utf8(&bytes) {
        Ok(s) => Some(s.to_string()),
        Err(e) => {
            debug!(path = %path.display(), err = %e, "skipping file: not valid UTF-8");
            None
        }
    }
}

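/// Tokenize `text` with the model's maximum length, then truncate further to
/// `max_tokens` when a non-zero cap is configured.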
fn tokenize(
    text: &str,
    tokenizer: &tokenizers::Tokenizer,
    max_tokens: usize,
    model_max_tokens: usize,
) -> crate::Result<Encoding> {
    let mut enc = crate::tokenize::tokenize_query(text, tokenizer, model_max_tokens)?;
    if max_tokens > 0 {
        let len = enc.input_ids.len().min(max_tokens);
        enc.input_ids.truncate(len);
        enc.attention_mask.truncate(len);
        enc.token_type_ids.truncate(len);
    }
    Ok(enc)
}

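/// Min-max normalize the similarity scores to `[0, 1]`, then add `alpha`
/// times each result's file rank from `file_ranks` (missing files count as
/// zero). Does nothing when `results` is empty or `alpha` is zero.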
pub fn apply_structural_boost<S: ::std::hash::BuildHasher>(
    results: &mut [SearchResult],
    file_ranks: &std::collections::HashMap<String, f32, S>,
    alpha: f32,
) {
    if results.is_empty() || alpha == 0.0 {
        return;
    }

    let min = results
        .iter()
        .map(|r| r.similarity)
        .fold(f32::INFINITY, f32::min);
    let max = results
        .iter()
        .map(|r| r.similarity)
        .fold(f32::NEG_INFINITY, f32::max);
    let range = (max - min).max(1e-12);

    for r in results.iter_mut() {
        let normalized = (r.similarity - min) / range;
        let pr = file_ranks.get(&r.chunk.file_path).copied().unwrap_or(0.0);
        r.similarity = normalized + alpha * pr;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "cpu")]
    #[ignore = "loads model + embeds full source tree; run with `cargo test -- --ignored`"]
    fn search_with_backend_trait() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();
        let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
        let results = search(
            &dir,
            "embedding model",
            &[backend.as_ref()],
            &tokenizer,
            1,
            &cfg,
            &profiler,
        );
        assert!(results.is_ok());
        assert!(!results.unwrap().is_empty());
    }

    #[test]
    #[cfg(feature = "cpu")]
    fn embed_distributed_produces_correct_count() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let profiler = crate::profile::Profiler::noop();

        let texts = ["fn hello() {}", "class Foo:", "func main() {}"];
        let encoded: Vec<Option<Encoding>> = texts
            .iter()
            .map(|t| super::tokenize(t, &tokenizer, 0, 512).ok())
            .collect();

        let results =
            super::embed_distributed(&encoded, &[backend.as_ref()], 32, &profiler).unwrap();

        assert_eq!(results.len(), 3);
        for (i, emb) in results.iter().enumerate() {
            assert_eq!(emb.len(), 384, "embedding {i} should be 384-dim");
        }
    }

    /// Truncate an embedding to its first `dims` components and L2-normalize.
    fn truncate_and_normalize(emb: &[f32], dims: usize) -> Vec<f32> {
        let trunc = &emb[..dims];
        let norm: f32 = trunc.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        trunc.iter().map(|x| x / norm).collect()
    }

    /// Indices of the top-`k` corpus entries by dot product with `query`.
    fn rank_topk(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<usize> {
        let mut scored: Vec<(usize, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();
        scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        scored.into_iter().take(k).map(|(i, _)| i).collect()
    }

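    // Measures how much retrieval quality survives MRL truncation: Recall@k
    // overlap between truncated-dimension and full-dimension rankings.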
    #[test]
    #[ignore = "loads model + embeds; run with --nocapture"]
    #[expect(
        clippy::cast_precision_loss,
        reason = "top_k and overlap are small counts"
    )]
    fn mrl_retrieval_recall() {
        let model = "BAAI/bge-small-en-v1.5";
        let backends = crate::backend::detect_backends(model).unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model).unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();

        let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap()
            .parent()
            .unwrap();
        eprintln!("Embedding {}", root.display());
        let backend_refs: Vec<&dyn crate::backend::EmbedBackend> =
            backends.iter().map(std::convert::AsRef::as_ref).collect();
        let (chunks, embeddings) =
            embed_all(root, &backend_refs, &tokenizer, &cfg, &profiler).unwrap();
        let full_dim = embeddings[0].len();
        eprintln!(
            "Corpus: {} chunks, {full_dim}-dim embeddings\n",
            chunks.len()
        );

        let queries = [
            "error handling in the embedding pipeline",
            "tree-sitter chunking and AST parsing",
            "Metal GPU kernel dispatch",
            "file watcher for incremental reindex",
            "cosine similarity ranking",
        ];

        let top_k = 10;
        let mrl_dims: Vec<usize> = [32, 64, 128, 192, 256, full_dim]
            .into_iter()
            .filter(|&d| d <= full_dim)
            .collect();

        eprintln!("=== MRL Retrieval Recall@{top_k} (vs full {full_dim}-dim) ===\n");

        for query in &queries {
            let enc = tokenize(query, &tokenizer, 0, backends[0].max_tokens()).unwrap();
            let query_emb = backends[0].embed_batch(&[enc]).unwrap().pop().unwrap();

            let ref_topk = rank_topk(&query_emb, &embeddings, top_k);

            eprintln!("Query: \"{query}\"");
            eprintln!(
                " Full-dim top-1: {} ({})",
                chunks[ref_topk[0]].name, chunks[ref_topk[0]].file_path
            );

            for &dims in &mrl_dims {
                let trunc_corpus: Vec<Vec<f32>> = embeddings
                    .iter()
                    .map(|e| truncate_and_normalize(e, dims))
                    .collect();
                let trunc_query = truncate_and_normalize(&query_emb, dims);

                let trunc_topk = rank_topk(&trunc_query, &trunc_corpus, top_k);

                let overlap = ref_topk.iter().filter(|i| trunc_topk.contains(i)).count();
                let recall = overlap as f32 / top_k as f32;
                let marker = if dims == full_dim {
                    " (ref)"
                } else if recall >= 0.8 {
                    " ***"
                } else {
                    ""
                };
                eprintln!(
                    " dims={dims:>3}: Recall@{top_k}={recall:.1} ({overlap}/{top_k}){marker}"
                );
            }
            eprintln!();
        }
    }

    fn make_result(file_path: &str, similarity: f32) -> SearchResult {
        SearchResult {
            chunk: CodeChunk {
                file_path: file_path.to_string(),
                name: "test".to_string(),
                kind: "function".to_string(),
                start_line: 1,
                end_line: 10,
                enriched_content: String::new(),
                content: String::new(),
            },
            similarity,
        }
    }

    #[test]
    fn structural_boost_normalizes_and_applies() {
        let mut results = vec![
            make_result("src/a.rs", 0.8),
            make_result("src/b.rs", 0.4),
            make_result("src/c.rs", 0.6),
        ];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 0.5);
        ranks.insert("src/b.rs".to_string(), 1.0);
        ranks.insert("src/c.rs".to_string(), 0.0);

        apply_structural_boost(&mut results, &ranks, 0.2);

        // Normalized scores: a=1.0, b=0.0, c=0.5; boosts: 0.2 * {0.5, 1.0, 0.0}.
        assert!((results[0].similarity - 1.1).abs() < 1e-6);
        assert!((results[1].similarity - 0.2).abs() < 1e-6);
        assert!((results[2].similarity - 0.5).abs() < 1e-6);
    }

    #[test]
    fn structural_boost_noop_on_empty() {
        let mut results: Vec<SearchResult> = vec![];
        let ranks = std::collections::HashMap::new();
        apply_structural_boost(&mut results, &ranks, 0.2);
        assert!(results.is_empty());
    }

    #[test]
    fn structural_boost_noop_on_zero_alpha() {
        let mut results = vec![make_result("src/a.rs", 0.8)];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 1.0);
        apply_structural_boost(&mut results, &ranks, 0.0);
        assert!((results[0].similarity - 0.8).abs() < 1e-6);
    }
}