1use std::path::Path;
29use std::sync::atomic::{AtomicUsize, Ordering};
30use std::time::Instant;
31
32use rayon::prelude::*;
33use tracing::{info_span, instrument, trace, warn};
34
35use crate::backend::{EmbedBackend, Encoding};
36use crate::chunk::{ChunkConfig, CodeChunk};
37
/// Default number of encodings handed to a backend per `embed_batch` call;
/// used as [`SearchConfig::batch_size`] by `SearchConfig::default()`.
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// File count at or above which [`embed_all`] switches from the
/// collect-then-embed path to the streaming pipeline.
const STREAMING_THRESHOLD: usize = 1000;

/// Capacity, in batches, of the bounded channel that carries tokenized
/// batches from the tokenizer thread to the embedding loop in the streaming
/// pipeline.
const RING_SIZE: usize = 4;
53
/// Tuning knobs shared by [`embed_all`] and [`search`].
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of encodings per backend `embed_batch` call. The embedding
    /// pipeline clamps it to at least 1.
    pub batch_size: usize,
    /// Extra per-chunk token cap applied after tokenization. `0` disables it,
    /// leaving only the model limit passed to the tokenizer.
    pub max_tokens: usize,
    /// Chunking parameters forwarded to `crate::chunk::chunk_source_for_path`.
    pub chunk: ChunkConfig,
    /// Forwarded to the chunker as its `text_mode` flag.
    /// NOTE(review): presumably selects plain-text chunking; confirm in `crate::chunk`.
    pub text_mode: bool,
    /// Forwarded to `crate::hybrid::HybridIndex::new`.
    /// NOTE(review): looks like a truncated embedding dimension for cascaded
    /// ranking; confirm in `crate::hybrid`.
    pub cascade_dim: Option<usize>,
    /// Optional file-type filter forwarded to `crate::walk::collect_files`.
    pub file_type: Option<String>,
    /// Ranking mode. `SearchMode::Keyword` skips embedding the query.
    pub mode: crate::hybrid::SearchMode,
}
89
90impl Default for SearchConfig {
91 fn default() -> Self {
92 Self {
93 batch_size: DEFAULT_BATCH_SIZE,
94 max_tokens: 0,
95 chunk: ChunkConfig::default(),
96 text_mode: false,
97 cascade_dim: None,
98 file_type: None,
99 mode: crate::hybrid::SearchMode::Hybrid,
100 }
101 }
102}
103
/// One ranked hit returned by [`search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched chunk, cloned out of the hybrid index.
    pub chunk: CodeChunk,
    /// Score from the hybrid ranker. [`apply_structural_boost`] rewrites it
    /// in place.
    pub similarity: f32,
}
112
113#[instrument(skip_all, fields(root = %root.display(), batch_size = cfg.batch_size))]
133pub fn embed_all(
134 root: &Path,
135 backends: &[&dyn EmbedBackend],
136 tokenizer: &tokenizers::Tokenizer,
137 cfg: &SearchConfig,
138 profiler: &crate::profile::Profiler,
139) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
140 if backends.is_empty() {
141 return Err(crate::Error::Other(anyhow::anyhow!(
142 "no embedding backends provided"
143 )));
144 }
145
146 let files = {
148 let _span = info_span!("walk").entered();
149 let guard = profiler.phase("walk");
150 let files = crate::walk::collect_files(root, cfg.file_type.as_deref());
151 guard.set_detail(format!("{} files", files.len()));
152 files
153 };
154
155 if files.len() >= STREAMING_THRESHOLD {
156 let total_bytes: u64 = files
158 .iter()
159 .filter_map(|p| p.metadata().ok())
160 .map(|m| m.len())
161 .sum();
162 embed_all_streaming(&files, total_bytes, backends, tokenizer, cfg, profiler)
163 } else {
164 embed_all_batch(&files, backends, tokenizer, cfg, profiler)
165 }
166}
167
168fn embed_all_batch(
173 files: &[std::path::PathBuf],
174 backends: &[&dyn EmbedBackend],
175 tokenizer: &tokenizers::Tokenizer,
176 cfg: &SearchConfig,
177 profiler: &crate::profile::Profiler,
178) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
179 let chunks: Vec<CodeChunk> = {
181 let _span = info_span!("chunk", file_count = files.len()).entered();
182 let chunk_start = Instant::now();
183 let text_mode = cfg.text_mode;
184 let result: Vec<CodeChunk> = files
185 .par_iter()
186 .flat_map(|path| {
187 let Some(source) = read_source(path) else {
188 return vec![];
189 };
190 let chunks =
191 crate::chunk::chunk_source_for_path(path, &source, text_mode, &cfg.chunk);
192 profiler.chunk_thread_report(chunks.len());
193 profiler.chunk_batch(&chunks);
194 chunks
195 })
196 .collect();
197 profiler.chunk_summary(result.len(), files.len(), chunk_start.elapsed());
198 result
199 };
200
201 let bs = cfg.batch_size.max(1);
203 let max_tokens_cfg = cfg.max_tokens;
204 let model_max = backends[0].max_tokens();
205 let _span = info_span!("embed_chunks", chunk_count = chunks.len(), batch_size = bs).entered();
206 profiler.embed_begin(chunks.len());
207
208 let all_encodings: Vec<Option<Encoding>> = chunks
209 .par_iter()
210 .map(|chunk| {
211 tokenize(
212 &chunk.enriched_content,
213 tokenizer,
214 max_tokens_cfg,
215 model_max,
216 )
217 .inspect_err(|e| {
218 warn!(file = %chunk.file_path, err = %e, "tokenization failed, skipping chunk");
219 })
220 .ok()
221 })
222 .collect();
223
224 let mut paired: Vec<(CodeChunk, Option<Encoding>)> =
228 chunks.into_iter().zip(all_encodings).collect();
229 paired.sort_by(|a, b| {
230 let len_a = a.1.as_ref().map_or(0, |e| e.input_ids.len());
231 let len_b = b.1.as_ref().map_or(0, |e| e.input_ids.len());
232 len_b.cmp(&len_a) });
234 let (chunks, sorted_encodings): (Vec<CodeChunk>, Vec<Option<Encoding>>) =
235 paired.into_iter().unzip();
236
237 let embeddings = embed_distributed(&sorted_encodings, backends, bs, profiler)?;
239 profiler.embed_done();
240
241 let (chunks, embeddings): (Vec<_>, Vec<_>) = chunks
243 .into_iter()
244 .zip(embeddings)
245 .filter(|(_, emb)| !emb.is_empty())
246 .unzip();
247
248 Ok((chunks, embeddings))
249}
250
251#[expect(
270 clippy::too_many_lines,
271 reason = "streaming pipeline has inherent complexity in thread coordination"
272)]
273fn embed_all_streaming(
274 files: &[std::path::PathBuf],
275 total_bytes: u64,
276 backends: &[&dyn EmbedBackend],
277 tokenizer: &tokenizers::Tokenizer,
278 cfg: &SearchConfig,
279 profiler: &crate::profile::Profiler,
280) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
281 use crossbeam_channel::bounded;
282
283 let bs = cfg.batch_size.max(1);
284 let max_tokens_cfg = cfg.max_tokens;
285 let model_max = backends[0].max_tokens();
286 let file_count = files.len();
287 let text_mode = cfg.text_mode;
288 let chunk_config = cfg.chunk.clone();
289
290 let (chunk_tx, chunk_rx) = bounded::<CodeChunk>(bs * 8);
294
295 let (batch_tx, batch_rx) = bounded::<Vec<(Encoding, CodeChunk)>>(RING_SIZE);
299
300 let total_chunks_produced = AtomicUsize::new(0);
302 let bytes_chunked = AtomicUsize::new(0);
303 let chunk_start = Instant::now();
304
305 std::thread::scope(|scope| {
308 scope.spawn(|| {
314 let _span = info_span!("chunk_stream", file_count).entered();
315 files.par_iter().for_each(|path| {
316 let Some(source) = read_source(path) else {
317 return;
318 };
319 let chunks =
320 crate::chunk::chunk_source_for_path(path, &source, text_mode, &chunk_config);
321 let n = chunks.len();
322 let file_bytes = source.len();
323 profiler.chunk_batch(&chunks);
324 for chunk in chunks {
325 if chunk_tx.send(chunk).is_err() {
327 return;
328 }
329 }
330 profiler.chunk_thread_report(n);
331 total_chunks_produced.fetch_add(n, Ordering::Relaxed);
332 bytes_chunked.fetch_add(file_bytes, Ordering::Relaxed);
333 });
334 drop(chunk_tx);
338 });
339
340 let tokenize_handle = scope.spawn(move || -> crate::Result<()> {
347 let _span = info_span!("tokenize_stream").entered();
348 let mut buffer: Vec<(Encoding, CodeChunk)> = Vec::with_capacity(bs);
349
350 for chunk in &chunk_rx {
351 match tokenize(
352 &chunk.enriched_content,
353 tokenizer,
354 max_tokens_cfg,
355 model_max,
356 ) {
357 Ok(encoding) => {
358 buffer.push((encoding, chunk));
359 if buffer.len() >= bs {
360 buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
362 let batch = std::mem::replace(&mut buffer, Vec::with_capacity(bs));
363 if batch_tx.send(batch).is_err() {
364 return Ok(());
366 }
367 }
368 }
369 Err(e) => {
370 warn!(
371 file = %chunk.file_path, err = %e,
372 "tokenization failed, skipping chunk"
373 );
374 }
375 }
376 }
377
378 if !buffer.is_empty() {
380 buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
381 let _ = batch_tx.send(buffer);
382 }
383 Ok(())
386 });
387
388 let _span = info_span!("embed_stream").entered();
393
394 profiler.embed_begin(0);
396
397 let mut all_chunks: Vec<CodeChunk> = Vec::new();
398 let mut all_embeddings: Vec<Vec<f32>> = Vec::new();
399 let mut embed_error: Option<crate::Error> = None;
400
401 let mut cumulative_done: usize = 0;
402 for batch in &batch_rx {
403 let batch_len = batch.len();
404 let (encodings, chunks): (Vec<Encoding>, Vec<CodeChunk>) = batch.into_iter().unzip();
405
406 let opt_encodings: Vec<Option<Encoding>> = encodings.into_iter().map(Some).collect();
408
409 let noop = crate::profile::Profiler::noop();
412 match embed_distributed(&opt_encodings, backends, bs, &noop) {
413 Ok(batch_embeddings) => {
414 profiler.embedding_batch(&batch_embeddings);
415 cumulative_done += batch_len;
416 let processed = bytes_chunked.load(Ordering::Relaxed) as u64;
419 profiler.embed_tick_bytes(cumulative_done, processed, total_bytes);
420
421 for (chunk, emb) in chunks.into_iter().zip(batch_embeddings) {
422 if !emb.is_empty() {
423 all_chunks.push(chunk);
424 all_embeddings.push(emb);
425 }
426 }
427 }
428 Err(e) => {
429 embed_error = Some(e);
430 break;
432 }
433 }
434 }
435
436 let final_total = total_chunks_produced.load(Ordering::Relaxed);
438 profiler.chunk_summary(final_total, file_count, chunk_start.elapsed());
439 profiler.embed_begin_update_total(cumulative_done);
441 profiler.embed_tick(cumulative_done);
442 profiler.embed_done();
443
444 let tokenize_result = tokenize_handle.join();
446
447 if let Some(e) = embed_error {
449 return Err(e);
450 }
451 match tokenize_result {
452 Ok(Ok(())) => {}
453 Ok(Err(e)) => return Err(e),
454 Err(_) => {
455 return Err(crate::Error::Other(anyhow::anyhow!(
456 "tokenize thread panicked"
457 )));
458 }
459 }
460
461 Ok((all_chunks, all_embeddings))
462 })
463}
464
465#[instrument(skip_all, fields(root = %root.display(), top_k, batch_size = cfg.batch_size))]
485pub fn search(
486 root: &Path,
487 query: &str,
488 backends: &[&dyn EmbedBackend],
489 tokenizer: &tokenizers::Tokenizer,
490 top_k: usize,
491 cfg: &SearchConfig,
492 profiler: &crate::profile::Profiler,
493) -> crate::Result<Vec<SearchResult>> {
494 if backends.is_empty() {
495 return Err(crate::Error::Other(anyhow::anyhow!(
496 "no embedding backends provided"
497 )));
498 }
499
500 let (chunks, embeddings) = embed_all(root, backends, tokenizer, cfg, profiler)?;
502
503 let t_query_start = std::time::Instant::now();
504
505 let hybrid = {
507 let _span = info_span!("build_hybrid_index").entered();
508 let _guard = profiler.phase("build_hybrid_index");
509 crate::hybrid::HybridIndex::new(chunks, &embeddings, cfg.cascade_dim)?
510 };
511
512 let mode = cfg.mode;
513 let effective_top_k = if top_k > 0 { top_k } else { usize::MAX };
514
515 let query_embedding = if mode == crate::hybrid::SearchMode::Keyword {
517 let dim = hybrid.semantic.hidden_dim;
519 vec![0.0f32; dim]
520 } else {
521 let _span = info_span!("embed_query").entered();
522 let _guard = profiler.phase("embed_query");
523 let t_tok = std::time::Instant::now();
524 let enc = tokenize(query, tokenizer, cfg.max_tokens, backends[0].max_tokens())?;
525 let tok_ms = t_tok.elapsed().as_secs_f64() * 1000.0;
526 let t_emb = std::time::Instant::now();
527 let mut results = backends[0].embed_batch(&[enc])?;
528 let emb_ms = t_emb.elapsed().as_secs_f64() * 1000.0;
529 eprintln!(
530 "[search] query: tokenize={tok_ms:.1}ms embed={emb_ms:.1}ms total_since_embed_all={:.1}ms",
531 t_query_start.elapsed().as_secs_f64() * 1000.0
532 );
533 results.pop().ok_or_else(|| {
534 crate::Error::Other(anyhow::anyhow!("backend returned no embedding for query"))
535 })?
536 };
537
538 let ranked = {
540 let _span = info_span!("rank", chunk_count = hybrid.chunks().len()).entered();
541 let guard = profiler.phase("rank");
542 let threshold = 0.0; let results = hybrid.search(&query_embedding, query, effective_top_k, threshold, mode);
544 guard.set_detail(format!(
545 "{mode} top {} from {}",
546 effective_top_k.min(results.len()),
547 hybrid.chunks().len()
548 ));
549 results
550 };
551
552 let results: Vec<SearchResult> = ranked
553 .into_iter()
554 .map(|(idx, score)| SearchResult {
555 chunk: hybrid.chunks()[idx].clone(),
556 similarity: score,
557 })
558 .collect();
559
560 Ok(results)
561}
562
/// Shared state for the workers spawned by `embed_distributed`. Each worker
/// claims contiguous slices of `tokenized` by advancing a shared atomic cursor.
struct DistributedState<'a> {
    /// Input slots. `None` slots are not embedded and yield an empty vector.
    tokenized: &'a [Option<Encoding>],
    /// Index of the next unclaimed slot; workers advance it with `fetch_add`.
    cursor: std::sync::atomic::AtomicUsize,
    /// Set when any worker hits an error, so the others stop claiming work.
    error_flag: std::sync::atomic::AtomicBool,
    /// First backend error observed. Later errors are discarded.
    first_error: std::sync::Mutex<Option<crate::Error>>,
    /// Running count of processed slots, fed to the profiler's progress ticks.
    done_counter: std::sync::atomic::AtomicUsize,
    /// Slots claimed per grab. GPU backends claim four times this many.
    batch_size: usize,
    /// Progress and metrics sink shared by all workers.
    profiler: &'a crate::profile::Profiler,
}
573
574impl DistributedState<'_> {
575 fn run_worker(&self, backend: &dyn EmbedBackend) -> Vec<(usize, Vec<f32>)> {
577 use std::sync::atomic::Ordering;
578
579 let n = self.tokenized.len();
580 let grab_size = if backend.is_gpu() {
584 self.batch_size * 4
585 } else {
586 self.batch_size
587 };
588 let mut results = Vec::new();
589
590 loop {
591 if self.error_flag.load(Ordering::Relaxed) {
592 break;
593 }
594
595 let start = self.cursor.fetch_add(grab_size, Ordering::Relaxed);
596 if start >= n {
597 break;
598 }
599 let end = (start + grab_size).min(n);
600 let batch = &self.tokenized[start..end];
601
602 let mut valid = Vec::with_capacity(batch.len());
604 let mut valid_indices = Vec::with_capacity(batch.len());
605 for (i, enc) in batch.iter().enumerate() {
606 if let Some(e) = enc {
607 valid.push(e.clone());
610 valid_indices.push(start + i);
611 } else {
612 results.push((start + i, vec![]));
613 }
614 }
615
616 if valid.is_empty() {
617 let done =
618 self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
619 self.profiler.embed_tick(done);
620 continue;
621 }
622
623 match backend.embed_batch(&valid) {
624 Ok(batch_embeddings) => {
625 self.profiler.embedding_batch(&batch_embeddings);
626 for (idx, emb) in valid_indices.into_iter().zip(batch_embeddings) {
627 results.push((idx, emb));
628 }
629 let done =
630 self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
631 self.profiler.embed_tick(done);
632 }
633 Err(e) => {
634 self.error_flag.store(true, Ordering::Relaxed);
635 if let Ok(mut guard) = self.first_error.lock()
636 && guard.is_none()
637 {
638 *guard = Some(e);
639 }
640 break;
641 }
642 }
643 }
644
645 results
646 }
647}
648
649#[expect(
663 unsafe_code,
664 reason = "BLAS thread count must be set via env vars before spawning workers"
665)]
666pub(crate) fn embed_distributed(
667 tokenized: &[Option<Encoding>],
668 backends: &[&dyn EmbedBackend],
669 batch_size: usize,
670 profiler: &crate::profile::Profiler,
671) -> crate::Result<Vec<Vec<f32>>> {
672 let n = tokenized.len();
673 let state = DistributedState {
674 tokenized,
675 cursor: std::sync::atomic::AtomicUsize::new(0),
676 error_flag: std::sync::atomic::AtomicBool::new(false),
677 first_error: std::sync::Mutex::new(None),
678 done_counter: std::sync::atomic::AtomicUsize::new(0),
679 batch_size: batch_size.max(1),
680 profiler,
681 };
682
683 let all_pairs: Vec<(usize, Vec<f32>)> =
685 if backends.len() == 1 && backends[0].supports_clone() && !backends[0].is_gpu() {
686 unsafe {
699 std::env::set_var("OPENBLAS_NUM_THREADS", "1");
700 std::env::set_var("MKL_NUM_THREADS", "1");
701 std::env::set_var("VECLIB_MAXIMUM_THREADS", "1"); #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
705 {
706 unsafe extern "C" {
707 fn openblas_set_num_threads(num: std::ffi::c_int);
708 }
709 openblas_set_num_threads(1);
710 }
711 }
712
713 let num_workers = rayon::current_num_threads().max(1);
714 std::thread::scope(|s| {
715 let handles: Vec<_> = (0..num_workers)
716 .map(|_| {
717 s.spawn(|| {
718 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
721 crate::backend::driver::cpu::force_single_threaded_blas();
722 let cloned = backends[0].clone_backend();
723 state.run_worker(cloned.as_ref())
724 })
725 })
726 .collect();
727 let mut all = Vec::new();
728 for handle in handles {
729 if let Ok(pairs) = handle.join() {
730 all.extend(pairs);
731 }
732 }
733 all
734 })
735 } else if backends.len() == 1 {
736 state.run_worker(backends[0])
740 } else {
741 std::thread::scope(|s| {
743 let handles: Vec<_> = backends
744 .iter()
745 .map(|&backend| {
746 s.spawn(|| {
747 if backend.supports_clone() {
749 let cloned = backend.clone_backend();
750 state.run_worker(cloned.as_ref())
751 } else {
752 state.run_worker(backend)
753 }
754 })
755 })
756 .collect();
757
758 let mut all = Vec::new();
759 for handle in handles {
760 if let Ok(pairs) = handle.join() {
761 all.extend(pairs);
762 } else {
763 warn!("worker thread panicked");
764 state
765 .error_flag
766 .store(true, std::sync::atomic::Ordering::Relaxed);
767 }
768 }
769 all
770 })
771 };
772
773 if let Some(err) = state.first_error.into_inner().ok().flatten() {
775 return Err(err);
776 }
777
778 let mut embeddings: Vec<Vec<f32>> = vec![vec![]; n];
780 for (idx, emb) in all_pairs {
781 embeddings[idx] = emb;
782 }
783
784 Ok(embeddings)
785}
786
787pub(crate) fn read_source(path: &Path) -> Option<String> {
793 let bytes = match std::fs::read(path) {
794 Ok(b) => b,
795 Err(e) => {
796 trace!(path = %path.display(), err = %e, "skipping file: read failed");
797 return None;
798 }
799 };
800
801 if memchr::memchr(0, &bytes[..bytes.len().min(8192)]).is_some() {
803 trace!(path = %path.display(), "skipping binary file");
804 return None;
805 }
806
807 match std::str::from_utf8(&bytes) {
808 Ok(s) => Some(s.to_string()),
809 Err(e) => {
810 trace!(path = %path.display(), err = %e, "skipping file: not valid UTF-8");
811 None
812 }
813 }
814}
815
816fn tokenize(
823 text: &str,
824 tokenizer: &tokenizers::Tokenizer,
825 max_tokens: usize,
826 model_max_tokens: usize,
827) -> crate::Result<Encoding> {
828 let mut enc = crate::tokenize::tokenize_query(text, tokenizer, model_max_tokens)?;
829 if max_tokens > 0 {
830 let len = enc.input_ids.len().min(max_tokens);
831 enc.input_ids.truncate(len);
832 enc.attention_mask.truncate(len);
833 enc.token_type_ids.truncate(len);
834 }
835 Ok(enc)
836}
837
838pub fn apply_structural_boost<S: ::std::hash::BuildHasher>(
847 results: &mut [SearchResult],
848 file_ranks: &std::collections::HashMap<String, f32, S>,
849 alpha: f32,
850) {
851 if results.is_empty() || alpha == 0.0 {
852 return;
853 }
854
855 let min = results
856 .iter()
857 .map(|r| r.similarity)
858 .fold(f32::INFINITY, f32::min);
859 let max = results
860 .iter()
861 .map(|r| r.similarity)
862 .fold(f32::NEG_INFINITY, f32::max);
863 let range = (max - min).max(1e-12);
864
865 for r in results.iter_mut() {
866 let normalized = (r.similarity - min) / range;
867 let pr = file_ranks.get(&r.chunk.file_path).copied().unwrap_or(0.0);
868 r.similarity = normalized + alpha * pr;
869 }
870}
871
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end search over this crate's own sources with the CPU backend.
    #[test]
    #[cfg(feature = "cpu")]
    #[ignore = "loads model + embeds full source tree; run with `cargo test -- --ignored`"]
    fn search_with_backend_trait() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();
        let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
        // `expect` surfaces the underlying error, where `assert!(is_ok())`
        // would only report "assertion failed".
        let results = search(
            &dir,
            "embedding model",
            &[backend.as_ref()],
            &tokenizer,
            1,
            &cfg,
            &profiler,
        )
        .expect("search over the crate's own sources should succeed");
        assert!(!results.is_empty());
    }

    /// `embed_distributed` returns exactly one 384-dim vector per input slot.
    #[test]
    #[cfg(feature = "cpu")]
    fn embed_distributed_produces_correct_count() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let profiler = crate::profile::Profiler::noop();

        let texts = ["fn hello() {}", "class Foo:", "func main() {}"];
        let encoded: Vec<Option<Encoding>> = texts
            .iter()
            .map(|t| super::tokenize(t, &tokenizer, 0, 512).ok())
            .collect();

        let results =
            super::embed_distributed(&encoded, &[backend.as_ref()], 32, &profiler).unwrap();

        assert_eq!(results.len(), 3);
        for (i, emb) in results.iter().enumerate() {
            assert_eq!(emb.len(), 384, "embedding {i} should be 384-dim");
        }
    }

    /// Keeps the first `dims` components and L2-normalizes them.
    fn truncate_and_normalize(emb: &[f32], dims: usize) -> Vec<f32> {
        let trunc = &emb[..dims];
        let norm: f32 = trunc.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        trunc.iter().map(|x| x / norm).collect()
    }

    /// Indices of the `k` corpus vectors with the highest dot product with `query`.
    fn rank_topk(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<usize> {
        let mut scored: Vec<(usize, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();
        scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        scored.into_iter().take(k).map(|(i, _)| i).collect()
    }

    /// Reports Recall@k of truncated embeddings against full-dimension
    /// rankings. Diagnostic only: it prints and does not assert on recall.
    #[test]
    #[ignore = "loads model + embeds; run with --nocapture"]
    #[expect(
        clippy::cast_precision_loss,
        reason = "top_k and overlap are small counts"
    )]
    fn mrl_retrieval_recall() {
        let model = "BAAI/bge-small-en-v1.5";
        let backends = crate::backend::detect_backends(model).unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model).unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();

        let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap()
            .parent()
            .unwrap();
        eprintln!("Embedding {}", root.display());
        let backend_refs: Vec<&dyn crate::backend::EmbedBackend> =
            backends.iter().map(std::convert::AsRef::as_ref).collect();
        let (chunks, embeddings) =
            embed_all(root, &backend_refs, &tokenizer, &cfg, &profiler).unwrap();
        assert!(
            !embeddings.is_empty(),
            "corpus under {} produced no embeddings",
            root.display()
        );
        let full_dim = embeddings[0].len();
        eprintln!(
            "Corpus: {} chunks, {full_dim}-dim embeddings\n",
            chunks.len()
        );

        let queries = [
            "error handling in the embedding pipeline",
            "tree-sitter chunking and AST parsing",
            "Metal GPU kernel dispatch",
            "file watcher for incremental reindex",
            "cosine similarity ranking",
        ];

        let top_k = 10;
        let mrl_dims: Vec<usize> = [32, 64, 128, 192, 256, full_dim]
            .into_iter()
            .filter(|&d| d <= full_dim)
            .collect();

        eprintln!("=== MRL Retrieval Recall@{top_k} (vs full {full_dim}-dim) ===\n");

        for query in &queries {
            let enc = tokenize(query, &tokenizer, 0, backends[0].max_tokens()).unwrap();
            let query_emb = backends[0].embed_batch(&[enc]).unwrap().pop().unwrap();

            let ref_topk = rank_topk(&query_emb, &embeddings, top_k);

            eprintln!("Query: \"{query}\"");
            eprintln!(
                "  Full-dim top-1: {} ({})",
                chunks[ref_topk[0]].name, chunks[ref_topk[0]].file_path
            );

            for &dims in &mrl_dims {
                let trunc_corpus: Vec<Vec<f32>> = embeddings
                    .iter()
                    .map(|e| truncate_and_normalize(e, dims))
                    .collect();
                let trunc_query = truncate_and_normalize(&query_emb, dims);

                let trunc_topk = rank_topk(&trunc_query, &trunc_corpus, top_k);

                let overlap = ref_topk.iter().filter(|i| trunc_topk.contains(i)).count();
                let recall = overlap as f32 / top_k as f32;
                let marker = if dims == full_dim {
                    " (ref)"
                } else if recall >= 0.8 {
                    " ***"
                } else {
                    ""
                };
                eprintln!(
                    "  dims={dims:>3}: Recall@{top_k}={recall:.1} ({overlap}/{top_k}){marker}"
                );
            }
            eprintln!();
        }
    }

    /// Builds a minimal `SearchResult` for the structural-boost tests.
    fn make_result(file_path: &str, similarity: f32) -> SearchResult {
        SearchResult {
            chunk: CodeChunk {
                file_path: file_path.to_string(),
                name: "test".to_string(),
                kind: "function".to_string(),
                start_line: 1,
                end_line: 10,
                enriched_content: String::new(),
                content: String::new(),
            },
            similarity,
        }
    }

    #[test]
    fn structural_boost_normalizes_and_applies() {
        let mut results = vec![
            make_result("src/a.rs", 0.8),
            make_result("src/b.rs", 0.4),
            make_result("src/c.rs", 0.6),
        ];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 0.5);
        ranks.insert("src/b.rs".to_string(), 1.0);
        ranks.insert("src/c.rs".to_string(), 0.0);

        apply_structural_boost(&mut results, &ranks, 0.2);

        // a: normalized 1.0 + 0.2 * 0.5
        assert!((results[0].similarity - 1.1).abs() < 1e-6);
        // b: normalized 0.0 + 0.2 * 1.0
        assert!((results[1].similarity - 0.2).abs() < 1e-6);
        // c: normalized 0.5 + 0.2 * 0.0
        assert!((results[2].similarity - 0.5).abs() < 1e-6);
    }

    #[test]
    fn structural_boost_equal_scores_leave_only_rank_term() {
        // A zero score range normalizes every similarity to 0, so only the
        // `alpha * rank` term remains; unranked files score 0.
        let mut results = vec![make_result("src/a.rs", 0.5), make_result("src/b.rs", 0.5)];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 1.0);
        apply_structural_boost(&mut results, &ranks, 0.5);
        assert!((results[0].similarity - 0.5).abs() < 1e-6);
        assert!(results[1].similarity.abs() < 1e-6);
    }

    #[test]
    fn structural_boost_noop_on_empty() {
        let mut results: Vec<SearchResult> = vec![];
        let ranks = std::collections::HashMap::new();
        apply_structural_boost(&mut results, &ranks, 0.2);
        assert!(results.is_empty());
    }

    #[test]
    fn structural_boost_noop_on_zero_alpha() {
        let mut results = vec![make_result("src/a.rs", 0.8)];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 1.0);
        apply_structural_boost(&mut results, &ranks, 0.0);
        // With alpha == 0 the score is left raw, not even normalized.
        assert!((results[0].similarity - 0.8).abs() < 1e-6);
    }
}
1103}