1use std::path::Path;
29use std::sync::atomic::{AtomicUsize, Ordering};
30use std::time::Instant;
31
32use rayon::prelude::*;
33use tracing::{info_span, instrument, trace, warn};
34
35use crate::backend::{EmbedBackend, Encoding};
36use crate::chunk::{ChunkConfig, CodeChunk};
37
/// Batch size used by [`SearchConfig::default`].
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// File-count cutoff in `embed_all`: at or above this many files the streaming
/// pipeline runs; below it the batch pipeline runs.
const STREAMING_THRESHOLD: usize = 1000;

/// Capacity of the bounded channel that carries tokenized batches to the
/// embedding loop in `embed_all_streaming`.
const RING_SIZE: usize = 4;
53
/// Options controlling file discovery, chunking, embedding and ranking for a search run.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of encodings a worker claims per grab; the pipelines clamp it to at least 1.
    pub batch_size: usize,
    /// Extra cap on tokens per encoding, applied after tokenization; `0` disables the cap.
    pub max_tokens: usize,
    /// Chunking settings forwarded to `crate::chunk::chunk_source_for_path`.
    pub chunk: ChunkConfig,
    /// Forwarded to the chunker as its text-mode flag.
    pub text_mode: bool,
    /// Forwarded to `HybridIndex::new`.
    /// NOTE(review): presumably a reduced dimension for a cascade pre-filter — confirm.
    pub cascade_dim: Option<usize>,
    /// Copied verbatim into `WalkOptions::file_type`.
    pub file_type: Option<String>,
    /// Extensions to exclude from the walk; `Scope::Code` appends the prose extensions.
    pub exclude_extensions: Vec<String>,
    /// Extensions to include in the walk; when non-empty, `scope` is not consulted.
    pub include_extensions: Vec<String>,
    /// Ignore patterns handed to the walker; `apply_repo_config` appends repo-level ones.
    pub ignore_patterns: Vec<String>,
    /// Code/docs/all filter applied by `walk_options` when no include list is given.
    pub scope: Scope,
    /// Ranking mode passed to `HybridIndex::search`.
    pub mode: crate::hybrid::SearchMode,
}
103
/// Which kinds of files a search covers when no explicit include list is configured.
///
/// Serialized by serde using lowercase variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
    /// Adds every entry of [`PROSE_EXTENSIONS`] to the walk's exclude list.
    Code,
    /// Restricts the walk's include list to [`PROSE_EXTENSIONS`].
    Docs,
    /// Applies no extension filtering of its own; this is the default variant.
    #[default]
    All,
}
126
/// File extensions (without the leading dot) that [`Scope::Docs`] includes and
/// [`Scope::Code`] excludes.
pub const PROSE_EXTENSIONS: &[&str] = &[
    "md", "markdown", "mdx", "rst", "txt", "text", "adoc", "asciidoc", "org",
];
132
133impl SearchConfig {
134 #[must_use]
145 pub fn walk_options(&self) -> crate::walk::WalkOptions {
146 let mut include = self.include_extensions.clone();
147 let mut exclude = self.exclude_extensions.clone();
148 if include.is_empty() {
149 match self.scope {
150 Scope::Docs => {
151 include.extend(PROSE_EXTENSIONS.iter().map(|s| (*s).to_string()));
152 }
153 Scope::Code => {
154 for ext in PROSE_EXTENSIONS {
155 if !exclude.iter().any(|e| e.eq_ignore_ascii_case(ext)) {
156 exclude.push((*ext).to_string());
157 }
158 }
159 }
160 Scope::All => {}
161 }
162 }
163 crate::walk::WalkOptions {
164 file_type: self.file_type.clone(),
165 include_extensions: include,
166 exclude_extensions: exclude,
167 ignore_patterns: self.ignore_patterns.clone(),
168 }
169 }
170
171 pub fn apply_repo_config(&mut self, root: &Path) {
173 let Some((_, config)) = crate::cache::config::find_config(root) else {
174 return;
175 };
176 for pattern in config.ignore.patterns {
177 if !pattern.trim().is_empty() && !self.ignore_patterns.contains(&pattern) {
178 self.ignore_patterns.push(pattern);
179 }
180 }
181 }
182}
183
184impl Default for SearchConfig {
185 fn default() -> Self {
186 Self {
187 batch_size: DEFAULT_BATCH_SIZE,
188 max_tokens: 0,
189 chunk: ChunkConfig::default(),
190 text_mode: false,
191 cascade_dim: None,
192 file_type: None,
193 exclude_extensions: Vec::new(),
194 include_extensions: Vec::new(),
195 ignore_patterns: Vec::new(),
196 scope: Scope::All,
197 mode: crate::hybrid::SearchMode::Hybrid,
198 }
199 }
200}
201
/// One ranked hit: a chunk together with the score it received.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched chunk.
    pub chunk: CodeChunk,
    /// Score assigned by the ranker; `apply_structural_boost` overwrites it with a
    /// min-max-normalized, rank-boosted value.
    pub similarity: f32,
}
210
211#[instrument(skip_all, fields(root = %root.display(), batch_size = cfg.batch_size))]
231pub fn embed_all(
232 root: &Path,
233 backends: &[&dyn EmbedBackend],
234 tokenizer: &tokenizers::Tokenizer,
235 cfg: &SearchConfig,
236 profiler: &crate::profile::Profiler,
237) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
238 if backends.is_empty() {
239 return Err(crate::Error::Other(anyhow::anyhow!(
240 "no embedding backends provided"
241 )));
242 }
243
244 let files = {
246 let _span = info_span!("walk").entered();
247 let guard = profiler.phase("walk");
248 let walk_options = cfg.walk_options();
249 let files = crate::walk::collect_files_with_options(root, &walk_options);
250 guard.set_detail(format!("{} files", files.len()));
251 files
252 };
253
254 if files.len() >= STREAMING_THRESHOLD {
255 let total_bytes: u64 = files
257 .iter()
258 .filter_map(|p| p.metadata().ok())
259 .map(|m| m.len())
260 .sum();
261 embed_all_streaming(&files, total_bytes, backends, tokenizer, cfg, profiler)
262 } else {
263 embed_all_batch(&files, backends, tokenizer, cfg, profiler)
264 }
265}
266
267fn embed_all_batch(
272 files: &[std::path::PathBuf],
273 backends: &[&dyn EmbedBackend],
274 tokenizer: &tokenizers::Tokenizer,
275 cfg: &SearchConfig,
276 profiler: &crate::profile::Profiler,
277) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
278 let chunks: Vec<CodeChunk> = {
280 let _span = info_span!("chunk", file_count = files.len()).entered();
281 let chunk_start = Instant::now();
282 let text_mode = cfg.text_mode;
283 let result: Vec<CodeChunk> = files
284 .par_iter()
285 .flat_map(|path| {
286 let Some(source) = read_source(path) else {
287 return vec![];
288 };
289 let chunks =
290 crate::chunk::chunk_source_for_path(path, &source, text_mode, &cfg.chunk);
291 profiler.chunk_thread_report(chunks.len());
292 profiler.chunk_batch(&chunks);
293 chunks
294 })
295 .collect();
296 profiler.chunk_summary(result.len(), files.len(), chunk_start.elapsed());
297 result
298 };
299
300 let bs = cfg.batch_size.max(1);
302 let max_tokens_cfg = cfg.max_tokens;
303 let model_max = backends[0].max_tokens();
304 let _span = info_span!("embed_chunks", chunk_count = chunks.len(), batch_size = bs).entered();
305 profiler.embed_begin(chunks.len());
306
307 let all_encodings: Vec<Option<Encoding>> = chunks
308 .par_iter()
309 .map(|chunk| {
310 tokenize(
311 &chunk.enriched_content,
312 tokenizer,
313 max_tokens_cfg,
314 model_max,
315 )
316 .inspect_err(|e| {
317 warn!(file = %chunk.file_path, err = %e, "tokenization failed, skipping chunk");
318 })
319 .ok()
320 })
321 .collect();
322
323 let mut paired: Vec<(CodeChunk, Option<Encoding>)> =
327 chunks.into_iter().zip(all_encodings).collect();
328 paired.sort_by(|a, b| {
329 let len_a = a.1.as_ref().map_or(0, |e| e.input_ids.len());
330 let len_b = b.1.as_ref().map_or(0, |e| e.input_ids.len());
331 len_b.cmp(&len_a) });
333 let (chunks, sorted_encodings): (Vec<CodeChunk>, Vec<Option<Encoding>>) =
334 paired.into_iter().unzip();
335
336 let embeddings = embed_distributed(&sorted_encodings, backends, bs, profiler)?;
338 profiler.embed_done();
339
340 let (chunks, embeddings): (Vec<_>, Vec<_>) = chunks
342 .into_iter()
343 .zip(embeddings)
344 .filter(|(_, emb)| !emb.is_empty())
345 .unzip();
346
347 Ok((chunks, embeddings))
348}
349
350#[expect(
369 clippy::too_many_lines,
370 reason = "streaming pipeline has inherent complexity in thread coordination"
371)]
372fn embed_all_streaming(
373 files: &[std::path::PathBuf],
374 total_bytes: u64,
375 backends: &[&dyn EmbedBackend],
376 tokenizer: &tokenizers::Tokenizer,
377 cfg: &SearchConfig,
378 profiler: &crate::profile::Profiler,
379) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
380 use crossbeam_channel::bounded;
381
382 let bs = cfg.batch_size.max(1);
383 let max_tokens_cfg = cfg.max_tokens;
384 let model_max = backends[0].max_tokens();
385 let file_count = files.len();
386 let text_mode = cfg.text_mode;
387 let chunk_config = cfg.chunk.clone();
388
389 let (chunk_tx, chunk_rx) = bounded::<CodeChunk>(bs * 8);
393
394 let (batch_tx, batch_rx) = bounded::<Vec<(Encoding, CodeChunk)>>(RING_SIZE);
398
399 let total_chunks_produced = AtomicUsize::new(0);
401 let bytes_chunked = AtomicUsize::new(0);
402 let chunk_start = Instant::now();
403
404 std::thread::scope(|scope| {
407 scope.spawn(|| {
413 let _span = info_span!("chunk_stream", file_count).entered();
414 files.par_iter().for_each(|path| {
415 let Some(source) = read_source(path) else {
416 return;
417 };
418 let chunks =
419 crate::chunk::chunk_source_for_path(path, &source, text_mode, &chunk_config);
420 let n = chunks.len();
421 let file_bytes = source.len();
422 profiler.chunk_batch(&chunks);
423 for chunk in chunks {
424 if chunk_tx.send(chunk).is_err() {
426 return;
427 }
428 }
429 profiler.chunk_thread_report(n);
430 total_chunks_produced.fetch_add(n, Ordering::Relaxed);
431 bytes_chunked.fetch_add(file_bytes, Ordering::Relaxed);
432 });
433 drop(chunk_tx);
437 });
438
439 let tokenize_handle = scope.spawn(move || -> crate::Result<()> {
446 let _span = info_span!("tokenize_stream").entered();
447 let mut buffer: Vec<(Encoding, CodeChunk)> = Vec::with_capacity(bs);
448
449 for chunk in &chunk_rx {
450 match tokenize(
451 &chunk.enriched_content,
452 tokenizer,
453 max_tokens_cfg,
454 model_max,
455 ) {
456 Ok(encoding) => {
457 buffer.push((encoding, chunk));
458 if buffer.len() >= bs {
459 buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
461 let batch = std::mem::replace(&mut buffer, Vec::with_capacity(bs));
462 if batch_tx.send(batch).is_err() {
463 return Ok(());
465 }
466 }
467 }
468 Err(e) => {
469 warn!(
470 file = %chunk.file_path, err = %e,
471 "tokenization failed, skipping chunk"
472 );
473 }
474 }
475 }
476
477 if !buffer.is_empty() {
479 buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
480 let _ = batch_tx.send(buffer);
481 }
482 Ok(())
485 });
486
487 let _span = info_span!("embed_stream").entered();
492
493 profiler.embed_begin(0);
495
496 let mut all_chunks: Vec<CodeChunk> = Vec::new();
497 let mut all_embeddings: Vec<Vec<f32>> = Vec::new();
498 let mut embed_error: Option<crate::Error> = None;
499
500 let mut cumulative_done: usize = 0;
501 for batch in &batch_rx {
502 let batch_len = batch.len();
503 let (encodings, chunks): (Vec<Encoding>, Vec<CodeChunk>) = batch.into_iter().unzip();
504
505 let opt_encodings: Vec<Option<Encoding>> = encodings.into_iter().map(Some).collect();
507
508 let noop = crate::profile::Profiler::noop();
511 match embed_distributed(&opt_encodings, backends, bs, &noop) {
512 Ok(batch_embeddings) => {
513 profiler.embedding_batch(&batch_embeddings);
514 cumulative_done += batch_len;
515 let processed = bytes_chunked.load(Ordering::Relaxed) as u64;
518 profiler.embed_tick_bytes(cumulative_done, processed, total_bytes);
519
520 for (chunk, emb) in chunks.into_iter().zip(batch_embeddings) {
521 if !emb.is_empty() {
522 all_chunks.push(chunk);
523 all_embeddings.push(emb);
524 }
525 }
526 }
527 Err(e) => {
528 embed_error = Some(e);
529 break;
531 }
532 }
533 }
534
535 let final_total = total_chunks_produced.load(Ordering::Relaxed);
537 profiler.chunk_summary(final_total, file_count, chunk_start.elapsed());
538 profiler.embed_begin_update_total(cumulative_done);
540 profiler.embed_tick(cumulative_done);
541 profiler.embed_done();
542
543 let tokenize_result = tokenize_handle.join();
545
546 if let Some(e) = embed_error {
548 return Err(e);
549 }
550 match tokenize_result {
551 Ok(Ok(())) => {}
552 Ok(Err(e)) => return Err(e),
553 Err(_) => {
554 return Err(crate::Error::Other(anyhow::anyhow!(
555 "tokenize thread panicked"
556 )));
557 }
558 }
559
560 Ok((all_chunks, all_embeddings))
561 })
562}
563
564#[instrument(skip_all, fields(root = %root.display(), top_k, batch_size = cfg.batch_size))]
584pub fn search(
585 root: &Path,
586 query: &str,
587 backends: &[&dyn EmbedBackend],
588 tokenizer: &tokenizers::Tokenizer,
589 top_k: usize,
590 cfg: &SearchConfig,
591 profiler: &crate::profile::Profiler,
592) -> crate::Result<Vec<SearchResult>> {
593 if backends.is_empty() {
594 return Err(crate::Error::Other(anyhow::anyhow!(
595 "no embedding backends provided"
596 )));
597 }
598
599 let (chunks, embeddings) = embed_all(root, backends, tokenizer, cfg, profiler)?;
601
602 let t_query_start = std::time::Instant::now();
603
604 let hybrid = {
606 let _span = info_span!("build_hybrid_index").entered();
607 let _guard = profiler.phase("build_hybrid_index");
608 crate::hybrid::HybridIndex::new(chunks, &embeddings, cfg.cascade_dim)?
609 };
610
611 let mode = cfg.mode;
612 let effective_top_k = if top_k > 0 { top_k } else { usize::MAX };
613
614 let query_embedding = if mode == crate::hybrid::SearchMode::Keyword {
616 let dim = hybrid.semantic.hidden_dim;
618 vec![0.0f32; dim]
619 } else {
620 let _span = info_span!("embed_query").entered();
621 let _guard = profiler.phase("embed_query");
622 let t_tok = std::time::Instant::now();
623 let enc = tokenize(query, tokenizer, cfg.max_tokens, backends[0].max_tokens())?;
624 let tok_ms = t_tok.elapsed().as_secs_f64() * 1000.0;
625 let t_emb = std::time::Instant::now();
626 let mut results = backends[0].embed_batch(&[enc])?;
627 let emb_ms = t_emb.elapsed().as_secs_f64() * 1000.0;
628 eprintln!(
629 "[search] query: tokenize={tok_ms:.1}ms embed={emb_ms:.1}ms total_since_embed_all={:.1}ms",
630 t_query_start.elapsed().as_secs_f64() * 1000.0
631 );
632 results.pop().ok_or_else(|| {
633 crate::Error::Other(anyhow::anyhow!("backend returned no embedding for query"))
634 })?
635 };
636
637 let ranked = {
639 let _span = info_span!("rank", chunk_count = hybrid.chunks().len()).entered();
640 let guard = profiler.phase("rank");
641 let threshold = 0.0; let results = hybrid.search(&query_embedding, query, effective_top_k, threshold, mode);
643 guard.set_detail(format!(
644 "{mode} top {} from {}",
645 effective_top_k.min(results.len()),
646 hybrid.chunks().len()
647 ));
648 results
649 };
650
651 let results: Vec<SearchResult> = ranked
652 .into_iter()
653 .map(|(idx, score)| SearchResult {
654 chunk: hybrid.chunks()[idx].clone(),
655 similarity: score,
656 })
657 .collect();
658
659 Ok(results)
660}
661
/// State shared by every worker that `embed_distributed` runs.
struct DistributedState<'a> {
    /// Encodings to embed; a `None` entry yields an empty embedding for its index.
    tokenized: &'a [Option<Encoding>],
    /// Index of the next unclaimed encoding; workers advance it with `fetch_add` to
    /// claim a contiguous range.
    cursor: std::sync::atomic::AtomicUsize,
    /// Set when a backend call fails; workers check it before claiming more work.
    error_flag: std::sync::atomic::AtomicBool,
    /// The first backend error recorded by any worker.
    first_error: std::sync::Mutex<Option<crate::Error>>,
    /// Running count of processed encodings, reported to the profiler.
    done_counter: std::sync::atomic::AtomicUsize,
    /// Encodings claimed per grab by a non-GPU worker; GPU workers claim four times this.
    batch_size: usize,
    /// Receives per-batch embedding and progress callbacks.
    profiler: &'a crate::profile::Profiler,
}
672
673impl DistributedState<'_> {
674 fn run_worker(&self, backend: &dyn EmbedBackend) -> Vec<(usize, Vec<f32>)> {
676 use std::sync::atomic::Ordering;
677
678 let n = self.tokenized.len();
679 let grab_size = if backend.is_gpu() {
683 self.batch_size * 4
684 } else {
685 self.batch_size
686 };
687 let mut results = Vec::new();
688
689 loop {
690 if self.error_flag.load(Ordering::Relaxed) {
691 break;
692 }
693
694 let start = self.cursor.fetch_add(grab_size, Ordering::Relaxed);
695 if start >= n {
696 break;
697 }
698 let end = (start + grab_size).min(n);
699 let batch = &self.tokenized[start..end];
700
701 let mut valid = Vec::with_capacity(batch.len());
703 let mut valid_indices = Vec::with_capacity(batch.len());
704 for (i, enc) in batch.iter().enumerate() {
705 if let Some(e) = enc {
706 valid.push(e.clone());
709 valid_indices.push(start + i);
710 } else {
711 results.push((start + i, vec![]));
712 }
713 }
714
715 if valid.is_empty() {
716 let done =
717 self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
718 self.profiler.embed_tick(done);
719 continue;
720 }
721
722 match backend.embed_batch(&valid) {
723 Ok(batch_embeddings) => {
724 self.profiler.embedding_batch(&batch_embeddings);
725 for (idx, emb) in valid_indices.into_iter().zip(batch_embeddings) {
726 results.push((idx, emb));
727 }
728 let done =
729 self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
730 self.profiler.embed_tick(done);
731 }
732 Err(e) => {
733 self.error_flag.store(true, Ordering::Relaxed);
734 if let Ok(mut guard) = self.first_error.lock()
735 && guard.is_none()
736 {
737 *guard = Some(e);
738 }
739 break;
740 }
741 }
742 }
743
744 results
745 }
746}
747
748#[expect(
762 unsafe_code,
763 reason = "BLAS thread count must be set via env vars before spawning workers"
764)]
765pub(crate) fn embed_distributed(
766 tokenized: &[Option<Encoding>],
767 backends: &[&dyn EmbedBackend],
768 batch_size: usize,
769 profiler: &crate::profile::Profiler,
770) -> crate::Result<Vec<Vec<f32>>> {
771 let n = tokenized.len();
772 let state = DistributedState {
773 tokenized,
774 cursor: std::sync::atomic::AtomicUsize::new(0),
775 error_flag: std::sync::atomic::AtomicBool::new(false),
776 first_error: std::sync::Mutex::new(None),
777 done_counter: std::sync::atomic::AtomicUsize::new(0),
778 batch_size: batch_size.max(1),
779 profiler,
780 };
781
782 let all_pairs: Vec<(usize, Vec<f32>)> =
784 if backends.len() == 1 && backends[0].supports_clone() && !backends[0].is_gpu() {
785 unsafe {
798 std::env::set_var("OPENBLAS_NUM_THREADS", "1");
799 std::env::set_var("MKL_NUM_THREADS", "1");
800 std::env::set_var("VECLIB_MAXIMUM_THREADS", "1"); #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
804 {
805 unsafe extern "C" {
806 fn openblas_set_num_threads(num: std::ffi::c_int);
807 }
808 openblas_set_num_threads(1);
809 }
810 }
811
812 let num_workers = rayon::current_num_threads().max(1);
813 std::thread::scope(|s| {
814 let handles: Vec<_> = (0..num_workers)
815 .map(|_| {
816 s.spawn(|| {
817 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
820 crate::backend::driver::cpu::force_single_threaded_blas();
821 let cloned = backends[0].clone_backend();
822 state.run_worker(cloned.as_ref())
823 })
824 })
825 .collect();
826 let mut all = Vec::new();
827 for handle in handles {
828 if let Ok(pairs) = handle.join() {
829 all.extend(pairs);
830 }
831 }
832 all
833 })
834 } else if backends.len() == 1 {
835 state.run_worker(backends[0])
839 } else {
840 std::thread::scope(|s| {
842 let handles: Vec<_> = backends
843 .iter()
844 .map(|&backend| {
845 s.spawn(|| {
846 if backend.supports_clone() {
848 let cloned = backend.clone_backend();
849 state.run_worker(cloned.as_ref())
850 } else {
851 state.run_worker(backend)
852 }
853 })
854 })
855 .collect();
856
857 let mut all = Vec::new();
858 for handle in handles {
859 if let Ok(pairs) = handle.join() {
860 all.extend(pairs);
861 } else {
862 warn!("worker thread panicked");
863 state
864 .error_flag
865 .store(true, std::sync::atomic::Ordering::Relaxed);
866 }
867 }
868 all
869 })
870 };
871
872 if let Some(err) = state.first_error.into_inner().ok().flatten() {
874 return Err(err);
875 }
876
877 let mut embeddings: Vec<Vec<f32>> = vec![vec![]; n];
879 for (idx, emb) in all_pairs {
880 embeddings[idx] = emb;
881 }
882
883 Ok(embeddings)
884}
885
886pub(crate) fn read_source(path: &Path) -> Option<String> {
892 let bytes = match std::fs::read(path) {
893 Ok(b) => b,
894 Err(e) => {
895 trace!(path = %path.display(), err = %e, "skipping file: read failed");
896 return None;
897 }
898 };
899
900 if memchr::memchr(0, &bytes[..bytes.len().min(8192)]).is_some() {
902 trace!(path = %path.display(), "skipping binary file");
903 return None;
904 }
905
906 match std::str::from_utf8(&bytes) {
907 Ok(s) => Some(s.to_string()),
908 Err(e) => {
909 trace!(path = %path.display(), err = %e, "skipping file: not valid UTF-8");
910 None
911 }
912 }
913}
914
915fn tokenize(
922 text: &str,
923 tokenizer: &tokenizers::Tokenizer,
924 max_tokens: usize,
925 model_max_tokens: usize,
926) -> crate::Result<Encoding> {
927 let mut enc = crate::tokenize::tokenize_query(text, tokenizer, model_max_tokens)?;
928 if max_tokens > 0 {
929 let len = enc.input_ids.len().min(max_tokens);
930 enc.input_ids.truncate(len);
931 enc.attention_mask.truncate(len);
932 enc.token_type_ids.truncate(len);
933 }
934 Ok(enc)
935}
936
937pub fn apply_structural_boost<S: ::std::hash::BuildHasher>(
946 results: &mut [SearchResult],
947 file_ranks: &std::collections::HashMap<String, f32, S>,
948 alpha: f32,
949) {
950 if results.is_empty() || alpha == 0.0 {
951 return;
952 }
953
954 let min = results
955 .iter()
956 .map(|r| r.similarity)
957 .fold(f32::INFINITY, f32::min);
958 let max = results
959 .iter()
960 .map(|r| r.similarity)
961 .fold(f32::NEG_INFINITY, f32::max);
962 let range = (max - min).max(1e-12);
963
964 for r in results.iter_mut() {
965 let normalized = (r.similarity - min) / range;
966 let pr = file_ranks.get(&r.chunk.file_path).copied().unwrap_or(0.0);
967 r.similarity = normalized + alpha * pr;
968 }
969}
970
971#[cfg(test)]
972mod tests {
973 use super::*;
974
975 #[test]
976 #[cfg(feature = "cpu")]
977 #[ignore = "loads model + embeds full source tree; run with `cargo test -- --ignored`"]
978 fn search_with_backend_trait() {
979 let backend = crate::backend::load_backend(
980 crate::backend::BackendKind::Cpu,
981 "BAAI/bge-small-en-v1.5",
982 crate::backend::DeviceHint::Cpu,
983 )
984 .unwrap();
985 let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
986 let cfg = SearchConfig::default();
987 let profiler = crate::profile::Profiler::noop();
988 let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
989 let results = search(
990 &dir,
991 "embedding model",
992 &[backend.as_ref()],
993 &tokenizer,
994 1,
995 &cfg,
996 &profiler,
997 );
998 assert!(results.is_ok());
999 assert!(!results.unwrap().is_empty());
1000 }
1001
1002 #[test]
1003 #[cfg(feature = "cpu")]
1004 fn embed_distributed_produces_correct_count() {
1005 let backend = crate::backend::load_backend(
1006 crate::backend::BackendKind::Cpu,
1007 "BAAI/bge-small-en-v1.5",
1008 crate::backend::DeviceHint::Cpu,
1009 )
1010 .unwrap();
1011 let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
1012 let profiler = crate::profile::Profiler::noop();
1013
1014 let texts = ["fn hello() {}", "class Foo:", "func main() {}"];
1016 let encoded: Vec<Option<Encoding>> = texts
1017 .iter()
1018 .map(|t| super::tokenize(t, &tokenizer, 0, 512).ok())
1019 .collect();
1020
1021 let results =
1022 super::embed_distributed(&encoded, &[backend.as_ref()], 32, &profiler).unwrap();
1023
1024 assert_eq!(results.len(), 3);
1025 for (i, emb) in results.iter().enumerate() {
1027 assert_eq!(emb.len(), 384, "embedding {i} should be 384-dim");
1028 }
1029 }
1030
1031 fn truncate_and_normalize(emb: &[f32], dims: usize) -> Vec<f32> {
1033 let trunc = &emb[..dims];
1034 let norm: f32 = trunc.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
1035 trunc.iter().map(|x| x / norm).collect()
1036 }
1037
1038 fn rank_topk(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<usize> {
1040 let mut scored: Vec<(usize, f32)> = corpus
1041 .iter()
1042 .enumerate()
1043 .map(|(i, emb)| {
1044 let dot: f32 = query.iter().zip(emb).map(|(a, b)| a * b).sum();
1045 (i, dot)
1046 })
1047 .collect();
1048 scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
1049 scored.into_iter().take(k).map(|(i, _)| i).collect()
1050 }
1051
1052 #[test]
1060 #[ignore = "loads model + embeds; run with --nocapture"]
1061 #[expect(
1062 clippy::cast_precision_loss,
1063 reason = "top_k and overlap are small counts"
1064 )]
1065 fn mrl_retrieval_recall() {
1066 let model = "BAAI/bge-small-en-v1.5";
1067 let backends = crate::backend::detect_backends(model).unwrap();
1068 let tokenizer = crate::tokenize::load_tokenizer(model).unwrap();
1069 let cfg = SearchConfig::default();
1070 let profiler = crate::profile::Profiler::noop();
1071
1072 let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
1074 .parent()
1075 .unwrap()
1076 .parent()
1077 .unwrap();
1078 eprintln!("Embedding {}", root.display());
1079 let backend_refs: Vec<&dyn crate::backend::EmbedBackend> =
1080 backends.iter().map(std::convert::AsRef::as_ref).collect();
1081 let (chunks, embeddings) =
1082 embed_all(root, &backend_refs, &tokenizer, &cfg, &profiler).unwrap();
1083 let full_dim = embeddings[0].len();
1084 eprintln!(
1085 "Corpus: {} chunks, {full_dim}-dim embeddings\n",
1086 chunks.len()
1087 );
1088
1089 let queries = [
1091 "error handling in the embedding pipeline",
1092 "tree-sitter chunking and AST parsing",
1093 "Metal GPU kernel dispatch",
1094 "file watcher for incremental reindex",
1095 "cosine similarity ranking",
1096 ];
1097
1098 let top_k = 10;
1099 let mrl_dims: Vec<usize> = [32, 64, 128, 192, 256, full_dim]
1100 .into_iter()
1101 .filter(|&d| d <= full_dim)
1102 .collect();
1103
1104 eprintln!("=== MRL Retrieval Recall@{top_k} (vs full {full_dim}-dim) ===\n");
1105
1106 for query in &queries {
1107 let enc = tokenize(query, &tokenizer, 0, backends[0].max_tokens()).unwrap();
1109 let query_emb = backends[0].embed_batch(&[enc]).unwrap().pop().unwrap();
1110
1111 let ref_topk = rank_topk(&query_emb, &embeddings, top_k);
1113
1114 eprintln!("Query: \"{query}\"");
1115 eprintln!(
1116 " Full-dim top-1: {} ({})",
1117 chunks[ref_topk[0]].name, chunks[ref_topk[0]].file_path
1118 );
1119
1120 for &dims in &mrl_dims {
1121 let trunc_corpus: Vec<Vec<f32>> = embeddings
1123 .iter()
1124 .map(|e| truncate_and_normalize(e, dims))
1125 .collect();
1126 let trunc_query = truncate_and_normalize(&query_emb, dims);
1127
1128 let trunc_topk = rank_topk(&trunc_query, &trunc_corpus, top_k);
1129
1130 let overlap = ref_topk.iter().filter(|i| trunc_topk.contains(i)).count();
1132 let recall = overlap as f32 / top_k as f32;
1133 let marker = if dims == full_dim {
1134 " (ref)"
1135 } else if recall >= 0.8 {
1136 " ***"
1137 } else {
1138 ""
1139 };
1140 eprintln!(
1141 " dims={dims:>3}: Recall@{top_k}={recall:.1} ({overlap}/{top_k}){marker}"
1142 );
1143 }
1144 eprintln!();
1145 }
1146 }
1147
1148 fn make_result(file_path: &str, similarity: f32) -> SearchResult {
1149 SearchResult {
1150 chunk: CodeChunk {
1151 file_path: file_path.to_string(),
1152 name: "test".to_string(),
1153 kind: "function".to_string(),
1154 start_line: 1,
1155 end_line: 10,
1156 enriched_content: String::new(),
1157 content: String::new(),
1158 },
1159 similarity,
1160 }
1161 }
1162
1163 #[test]
1164 fn structural_boost_normalizes_and_applies() {
1165 let mut results = vec![
1166 make_result("src/a.rs", 0.8),
1167 make_result("src/b.rs", 0.4),
1168 make_result("src/c.rs", 0.6),
1169 ];
1170 let mut ranks = std::collections::HashMap::new();
1171 ranks.insert("src/a.rs".to_string(), 0.5);
1172 ranks.insert("src/b.rs".to_string(), 1.0);
1173 ranks.insert("src/c.rs".to_string(), 0.0);
1174
1175 apply_structural_boost(&mut results, &ranks, 0.2);
1176
1177 assert!((results[0].similarity - 1.1).abs() < 1e-6);
1179 assert!((results[1].similarity - 0.2).abs() < 1e-6);
1181 assert!((results[2].similarity - 0.5).abs() < 1e-6);
1183 }
1184
1185 #[test]
1186 fn structural_boost_noop_on_empty() {
1187 let mut results: Vec<SearchResult> = vec![];
1188 let ranks = std::collections::HashMap::new();
1189 apply_structural_boost(&mut results, &ranks, 0.2);
1190 assert!(results.is_empty());
1191 }
1192
1193 #[test]
1194 fn structural_boost_noop_on_zero_alpha() {
1195 let mut results = vec![make_result("src/a.rs", 0.8)];
1196 let mut ranks = std::collections::HashMap::new();
1197 ranks.insert("src/a.rs".to_string(), 1.0);
1198 apply_structural_boost(&mut results, &ranks, 0.0);
1199 assert!((results[0].similarity - 0.8).abs() < 1e-6);
1201 }
1202}