Skip to main content

ripvec_core/
embed.rs

1//! Parallel batch embedding pipeline with streaming backpressure.
2//!
3//! Two pipeline modes:
4//!
5//! - **Batch mode** (< `STREAMING_THRESHOLD` files): walk, chunk all, tokenize
6//!   all, sort by length, embed. Simple and optimal for small corpora.
7//!
8//! - **Streaming mode** (>= `STREAMING_THRESHOLD` files): three-stage pipeline
9//!   with bounded channels. Chunks flow through: rayon chunk workers ->
10//!   tokenize+batch collector -> GPU embed consumer. The GPU starts after the
11//!   first `batch_size` encodings are ready (~50ms), not after all chunks are
12//!   done. Backpressure prevents unbounded memory growth.
13//!
14//! # Batch inference
15//!
16//! Instead of one forward pass per chunk, chunks are grouped into batches
17//! of configurable size (default 32). Each batch is tokenized, padded to
18//! the longest sequence, and run as a single forward pass with shape
19//! `[batch_size, max_seq_len]`. This amortizes per-call overhead and enables
20//! SIMD across the batch dimension.
21//!
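//! # Parallelism
//!
//! Chunking and tokenization run on the rayon pool. Embedding is spread
//! across every supplied backend by `embed_distributed`: each backend gets its
//! own worker, and workers claim batches from a shared cursor (GPU backends
//! claim larger batches than CPU backends). With a single backend no extra
//! threads are spawned, and a GPU device parallelizes each batch internally.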
27
28use std::path::Path;
29use std::sync::atomic::{AtomicUsize, Ordering};
30use std::time::Instant;
31
32use rayon::prelude::*;
33use tracing::{info_span, instrument, trace, warn};
34
35use crate::backend::{EmbedBackend, Encoding};
36use crate::chunk::{ChunkConfig, CodeChunk};
37
/// Number of chunks embedded per inference call when the caller does not
/// override [`SearchConfig::batch_size`].
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// Minimum file count at which [`embed_all`] switches to the streaming pipeline.
///
/// Smaller corpora take the batch path (chunk everything, tokenize everything,
/// sort globally by length, embed), which is simpler and gets the best padding
/// behaviour. Larger corpora stream, so the GPU is not left idle while
/// chunking and tokenization are still in progress.
const STREAMING_THRESHOLD: usize = 1_000;

/// Capacity, in batches, of the channel that feeds the embed stage.
///
/// Bounds memory to at most `RING_SIZE * batch_size` tokenized chunks waiting
/// for embedding. Matches the ring-buffer depth documented on [`EmbedBackend`].
const RING_SIZE: usize = 4;
53
54/// Runtime configuration for the search pipeline.
55///
56/// All tuning parameters that were previously compile-time constants are
57/// gathered here so they can be set from CLI arguments without recompiling.
58#[derive(Debug, Clone)]
59pub struct SearchConfig {
60    /// Chunks per inference call. Larger values amortize call overhead
61    /// but consume more memory. Default: 32.
62    pub batch_size: usize,
63    /// Maximum tokens fed to the model per chunk. `0` means no limit.
64    /// Capping tokens controls inference cost for minified or dense source.
65    /// BERT attention cost scales linearly with token count, and CLS pooling
66    /// means the first token's representation carries most semantic weight.
67    /// Default: 128 (7.7× faster than 512, with minimal quality loss).
68    pub max_tokens: usize,
69    /// Chunking parameters forwarded to the chunking phase.
70    pub chunk: ChunkConfig,
71    /// Force all files to be chunked as plain text (sliding windows only).
72    /// When `false` (default), files with recognized extensions use tree-sitter
73    /// semantic chunking, and unrecognized extensions fall back to sliding windows.
74    pub text_mode: bool,
75    /// MRL cascade pre-filter dimension.
76    ///
77    /// When set, [`SearchIndex`](crate::index::SearchIndex) stores a truncated
78    /// and L2-re-normalized copy of the embedding matrix at this dimension for
79    /// fast two-phase cascade search. `None` (default) disables cascade search.
80    pub cascade_dim: Option<usize>,
81    /// Optional file type filter (e.g. "rust", "python", "js").
82    ///
83    /// When set, only files matching this type (using ripgrep's built-in type
84    /// database) are collected during the walk phase.
85    pub file_type: Option<String>,
86    /// File extensions to exclude during the walk phase.
87    pub exclude_extensions: Vec<String>,
88    /// File extensions to include during the walk phase. Empty means
89    /// "no extension whitelist" (other filters still apply). Non-empty
90    /// restricts walking to files whose extension matches one of these
91    /// (normalized lowercase, with or without leading dot).
92    pub include_extensions: Vec<String>,
93    /// Additional `.gitignore`-style patterns to exclude during the walk phase.
94    pub ignore_patterns: Vec<String>,
95    /// Intent-shaped scope: code, docs, or all. Drives the default
96    /// extension whitelist when `include_extensions` is empty and the
97    /// rerank gate in the MCP layer (`docs` and `all`-on-mixed-corpus
98    /// fire rerank; `code` skips). See [`Scope`].
99    pub scope: Scope,
100    /// Search mode: hybrid (default), semantic, or keyword.
101    pub mode: crate::hybrid::SearchMode,
102}
103
104/// Intent-shaped scope for a search invocation.
105///
106/// Used as the user-facing axis for picking what kind of files
107/// participate in a search and whether the prose-tuned cross-encoder
108/// reranker fires. Maps internally to extension allow-lists and to
109/// the rerank gate's policy table.
110#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
111#[serde(rename_all = "lowercase")]
112pub enum Scope {
113    /// Only code-language files. Cross-encoder rerank is skipped — the
114    /// ms-marco model is out-of-domain for code chunks.
115    Code,
116    /// Only prose / documentation files (`md`, `rst`, `txt`, `adoc`,
117    /// `mdx`, `org`). Cross-encoder rerank fires by default on NL
118    /// queries.
119    Docs,
120    /// No extension whitelist; the rerank gate decides based on the
121    /// indexed corpus's prose fraction (see
122    /// `RipvecIndex::corpus_class`). Default.
123    #[default]
124    All,
125}
126
/// File extensions treated as prose / documentation.
///
/// `Scope::Docs` whitelists exactly these during the walk, and `Scope::Code`
/// excludes them. Must stay in sync with
/// [`crate::encoder::ripvec::ranking::is_prose_path`].
pub const PROSE_EXTENSIONS: &[&str] = &[
    // Markdown family
    "md", "markdown", "mdx",
    // reStructuredText and plain text
    "rst", "txt", "text",
    // AsciiDoc
    "adoc", "asciidoc",
    // Emacs Org mode
    "org",
];
132
133impl SearchConfig {
134    /// Convert search configuration into shared walk filters.
135    ///
136    /// Resolves the scope-implied extension whitelist:
137    ///
138    /// - Explicit `include_extensions` always wins.
139    /// - Otherwise `Scope::Docs` injects the canonical prose set
140    ///   ([`PROSE_EXTENSIONS`]).
141    /// - `Scope::Code` injects the canonical prose set as
142    ///   *exclusions* (so prose files are skipped during walk).
143    /// - `Scope::All` leaves the include set empty (no whitelist).
144    #[must_use]
145    pub fn walk_options(&self) -> crate::walk::WalkOptions {
146        let mut include = self.include_extensions.clone();
147        let mut exclude = self.exclude_extensions.clone();
148        if include.is_empty() {
149            match self.scope {
150                Scope::Docs => {
151                    include.extend(PROSE_EXTENSIONS.iter().map(|s| (*s).to_string()));
152                }
153                Scope::Code => {
154                    for ext in PROSE_EXTENSIONS {
155                        if !exclude.iter().any(|e| e.eq_ignore_ascii_case(ext)) {
156                            exclude.push((*ext).to_string());
157                        }
158                    }
159                }
160                Scope::All => {}
161            }
162        }
163        crate::walk::WalkOptions {
164            file_type: self.file_type.clone(),
165            include_extensions: include,
166            exclude_extensions: exclude,
167            ignore_patterns: self.ignore_patterns.clone(),
168        }
169    }
170
171    /// Merge ignore patterns from `.ripvec/config.toml`, if present.
172    pub fn apply_repo_config(&mut self, root: &Path) {
173        let Some((_, config)) = crate::cache::config::find_config(root) else {
174            return;
175        };
176        for pattern in config.ignore.patterns {
177            if !pattern.trim().is_empty() && !self.ignore_patterns.contains(&pattern) {
178                self.ignore_patterns.push(pattern);
179            }
180        }
181    }
182}
183
184impl Default for SearchConfig {
185    fn default() -> Self {
186        Self {
187            batch_size: DEFAULT_BATCH_SIZE,
188            max_tokens: 0,
189            chunk: ChunkConfig::default(),
190            text_mode: false,
191            cascade_dim: None,
192            file_type: None,
193            exclude_extensions: Vec::new(),
194            include_extensions: Vec::new(),
195            ignore_patterns: Vec::new(),
196            scope: Scope::All,
197            mode: crate::hybrid::SearchMode::Hybrid,
198        }
199    }
200}
201
202/// A search result pairing a code chunk with its similarity score.
203#[derive(Debug, Clone)]
204pub struct SearchResult {
205    /// The matched code chunk.
206    pub chunk: CodeChunk,
207    /// Cosine similarity to the query (0.0 to 1.0).
208    pub similarity: f32,
209}
210
211/// Walk, chunk, and embed all files in a directory.
212///
213/// Returns the chunks and their corresponding embedding vectors.
214/// This is the building block for both one-shot search and interactive mode.
215/// The caller handles query embedding and ranking.
216///
217/// Accepts multiple backends for hybrid scheduling — chunks are distributed
218/// across all backends via work-stealing (see `embed_distributed`).
219///
220/// Automatically selects between two pipeline modes:
221/// - **Batch** (< `STREAMING_THRESHOLD` files): chunk all, tokenize all, sort
222///   by length, embed. Optimal for small corpora.
223/// - **Streaming** (>= `STREAMING_THRESHOLD` files): three-stage pipeline with
224///   bounded channels. GPU starts after the first batch is ready, not after all
225///   chunks are done. Eliminates GPU idle time during chunking/tokenization.
226///
227/// # Errors
228///
229/// Returns an error if file walking, chunking, or embedding fails.
230#[instrument(skip_all, fields(root = %root.display(), batch_size = cfg.batch_size))]
231pub fn embed_all(
232    root: &Path,
233    backends: &[&dyn EmbedBackend],
234    tokenizer: &tokenizers::Tokenizer,
235    cfg: &SearchConfig,
236    profiler: &crate::profile::Profiler,
237) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
238    if backends.is_empty() {
239        return Err(crate::Error::Other(anyhow::anyhow!(
240            "no embedding backends provided"
241        )));
242    }
243
244    // Phase 1: Collect files (respects .gitignore, filters by extension)
245    let files = {
246        let _span = info_span!("walk").entered();
247        let guard = profiler.phase("walk");
248        let walk_options = cfg.walk_options();
249        let files = crate::walk::collect_files_with_options(root, &walk_options);
250        guard.set_detail(format!("{} files", files.len()));
251        files
252    };
253
254    if files.len() >= STREAMING_THRESHOLD {
255        // Compute total source bytes for byte-based progress (known after walk).
256        let total_bytes: u64 = files
257            .iter()
258            .filter_map(|p| p.metadata().ok())
259            .map(|m| m.len())
260            .sum();
261        embed_all_streaming(&files, total_bytes, backends, tokenizer, cfg, profiler)
262    } else {
263        embed_all_batch(&files, backends, tokenizer, cfg, profiler)
264    }
265}
266
267/// Batch pipeline: chunk all -> tokenize all -> sort by length -> embed.
268///
269/// Optimal for small corpora where the global sort-by-length optimization
270/// matters more than eliminating GPU idle time.
271fn embed_all_batch(
272    files: &[std::path::PathBuf],
273    backends: &[&dyn EmbedBackend],
274    tokenizer: &tokenizers::Tokenizer,
275    cfg: &SearchConfig,
276    profiler: &crate::profile::Profiler,
277) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
278    // Phase 2: Chunk all files in parallel.
279    let chunks: Vec<CodeChunk> = {
280        let _span = info_span!("chunk", file_count = files.len()).entered();
281        let chunk_start = Instant::now();
282        let text_mode = cfg.text_mode;
283        let result: Vec<CodeChunk> = files
284            .par_iter()
285            .flat_map(|path| {
286                let Some(source) = read_source(path) else {
287                    return vec![];
288                };
289                let chunks =
290                    crate::chunk::chunk_source_for_path(path, &source, text_mode, &cfg.chunk);
291                profiler.chunk_thread_report(chunks.len());
292                profiler.chunk_batch(&chunks);
293                chunks
294            })
295            .collect();
296        profiler.chunk_summary(result.len(), files.len(), chunk_start.elapsed());
297        result
298    };
299
300    // Phase 3: Pre-tokenize all chunks in parallel (CPU-bound, all rayon threads)
301    let bs = cfg.batch_size.max(1);
302    let max_tokens_cfg = cfg.max_tokens;
303    let model_max = backends[0].max_tokens();
304    let _span = info_span!("embed_chunks", chunk_count = chunks.len(), batch_size = bs).entered();
305    profiler.embed_begin(chunks.len());
306
307    let all_encodings: Vec<Option<Encoding>> = chunks
308        .par_iter()
309        .map(|chunk| {
310            tokenize(
311                &chunk.enriched_content,
312                tokenizer,
313                max_tokens_cfg,
314                model_max,
315            )
316            .inspect_err(|e| {
317                warn!(file = %chunk.file_path, err = %e, "tokenization failed, skipping chunk");
318            })
319            .ok()
320        })
321        .collect();
322
323    // Sort chunks and their encodings together by descending token count.
324    // This groups similar-length sequences into the same batch, minimizing
325    // padding waste (short chunks no longer get padded to a long neighbour).
326    let mut paired: Vec<(CodeChunk, Option<Encoding>)> =
327        chunks.into_iter().zip(all_encodings).collect();
328    paired.sort_by(|a, b| {
329        let len_a = a.1.as_ref().map_or(0, |e| e.input_ids.len());
330        let len_b = b.1.as_ref().map_or(0, |e| e.input_ids.len());
331        len_b.cmp(&len_a) // descending — longest first
332    });
333    let (chunks, sorted_encodings): (Vec<CodeChunk>, Vec<Option<Encoding>>) =
334        paired.into_iter().unzip();
335
336    // Phase 4: Distribute pre-tokenized batches across all backends
337    let embeddings = embed_distributed(&sorted_encodings, backends, bs, profiler)?;
338    profiler.embed_done();
339
340    // Filter out chunks whose tokenization failed (empty embedding vectors).
341    let (chunks, embeddings): (Vec<_>, Vec<_>) = chunks
342        .into_iter()
343        .zip(embeddings)
344        .filter(|(_, emb)| !emb.is_empty())
345        .unzip();
346
347    Ok((chunks, embeddings))
348}
349
350/// Streaming pipeline: chunk -> tokenize -> batch -> embed with backpressure.
351///
352/// Three concurrent stages connected by bounded channels:
353///
354/// 1. **Chunk producers** (rayon pool, in a scoped thread): read + parse files,
355///    send chunks to channel.
356/// 2. **Tokenize + batch collector** (scoped thread): tokenize chunks, sort
357///    within batch windows, send full batches to the embed channel.
358/// 3. **Embed consumer** (main thread): calls `embed_distributed` on each
359///    batch, collects results.
360///
361/// The bounded channels provide natural backpressure: if the GPU falls behind,
362/// the tokenize stage blocks, which blocks chunk producers via the chunk channel.
363/// If chunking is fast and the GPU is slow, at most
364/// `8 * batch_size + RING_SIZE * batch_size` items are in memory.
365///
366/// Uses `std::thread::scope` so all threads can borrow the caller's stack
367/// (`tokenizer`, `backends`, `profiler`) without `'static` bounds.
368#[expect(
369    clippy::too_many_lines,
370    reason = "streaming pipeline has inherent complexity in thread coordination"
371)]
372fn embed_all_streaming(
373    files: &[std::path::PathBuf],
374    total_bytes: u64,
375    backends: &[&dyn EmbedBackend],
376    tokenizer: &tokenizers::Tokenizer,
377    cfg: &SearchConfig,
378    profiler: &crate::profile::Profiler,
379) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
380    use crossbeam_channel::bounded;
381
382    let bs = cfg.batch_size.max(1);
383    let max_tokens_cfg = cfg.max_tokens;
384    let model_max = backends[0].max_tokens();
385    let file_count = files.len();
386    let text_mode = cfg.text_mode;
387    let chunk_config = cfg.chunk.clone();
388
389    // Bounded channel from chunk producers -> tokenize+batch stage.
390    // Factor of 8 gives enough buffering for rayon parallelism without
391    // unbounded growth (at most ~8 batches worth of chunks in flight).
392    let (chunk_tx, chunk_rx) = bounded::<CodeChunk>(bs * 8);
393
394    // Bounded channel from tokenize+batch stage -> embed consumer.
395    // RING_SIZE batches in flight provides enough pipeline depth for GPU
396    // to stay busy while the next batch is being tokenized.
397    let (batch_tx, batch_rx) = bounded::<Vec<(Encoding, CodeChunk)>>(RING_SIZE);
398
399    // Shared counters for profiling across the streaming pipeline.
400    let total_chunks_produced = AtomicUsize::new(0);
401    let bytes_chunked = AtomicUsize::new(0);
402    let chunk_start = Instant::now();
403
404    // All stages run inside std::thread::scope so they can borrow from the
405    // caller's stack (tokenizer, backends, profiler, files, etc.).
406    std::thread::scope(|scope| {
407        // --- Stage 1: Chunk producers (rayon inside a scoped thread) ---
408        //
409        // Spawns a scoped thread that drives rayon's par_iter. Each file is
410        // chunked independently and chunks are sent into the bounded channel.
411        // If the channel is full, rayon workers block, providing backpressure.
412        scope.spawn(|| {
413            let _span = info_span!("chunk_stream", file_count).entered();
414            files.par_iter().for_each(|path| {
415                let Some(source) = read_source(path) else {
416                    return;
417                };
418                let chunks =
419                    crate::chunk::chunk_source_for_path(path, &source, text_mode, &chunk_config);
420                let n = chunks.len();
421                let file_bytes = source.len();
422                profiler.chunk_batch(&chunks);
423                for chunk in chunks {
424                    // Channel disconnected means downstream errored; stop.
425                    if chunk_tx.send(chunk).is_err() {
426                        return;
427                    }
428                }
429                profiler.chunk_thread_report(n);
430                total_chunks_produced.fetch_add(n, Ordering::Relaxed);
431                bytes_chunked.fetch_add(file_bytes, Ordering::Relaxed);
432            });
433            // chunk_tx is dropped here, closing the channel — but the borrow
434            // of chunk_tx lives until the scoped thread ends. We need to
435            // explicitly drop it so the tokenize stage sees the channel close.
436            drop(chunk_tx);
437        });
438
439        // --- Stage 2: Tokenize + batch collector (scoped thread) ---
440        //
441        // Receives individual chunks, tokenizes each (HuggingFace tokenizer
442        // is Send + Sync), and accumulates into batch-sized buffers. Within
443        // each buffer, entries are sorted by descending token count — the same
444        // padding-reduction optimization as the batch path, applied locally.
445        let tokenize_handle = scope.spawn(move || -> crate::Result<()> {
446            let _span = info_span!("tokenize_stream").entered();
447            let mut buffer: Vec<(Encoding, CodeChunk)> = Vec::with_capacity(bs);
448
449            for chunk in &chunk_rx {
450                match tokenize(
451                    &chunk.enriched_content,
452                    tokenizer,
453                    max_tokens_cfg,
454                    model_max,
455                ) {
456                    Ok(encoding) => {
457                        buffer.push((encoding, chunk));
458                        if buffer.len() >= bs {
459                            // Sort within batch by descending token count.
460                            buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
461                            let batch = std::mem::replace(&mut buffer, Vec::with_capacity(bs));
462                            if batch_tx.send(batch).is_err() {
463                                // Embed consumer dropped; stop tokenizing.
464                                return Ok(());
465                            }
466                        }
467                    }
468                    Err(e) => {
469                        warn!(
470                            file = %chunk.file_path, err = %e,
471                            "tokenization failed, skipping chunk"
472                        );
473                    }
474                }
475            }
476
477            // Flush remaining partial batch.
478            if !buffer.is_empty() {
479                buffer.sort_by_key(|b| std::cmp::Reverse(b.0.input_ids.len()));
480                let _ = batch_tx.send(buffer);
481            }
482            // batch_tx drops here, closing the embed channel.
483
484            Ok(())
485        });
486
487        // --- Stage 3: Embed consumer (main thread within scope) ---
488        //
489        // Receives sorted batches, embeds via the backend(s), collects results.
490        // Profiler is driven from here since this thread owns the reference.
491        let _span = info_span!("embed_stream").entered();
492
493        // Total isn't known upfront in streaming mode; start at 0 and update.
494        profiler.embed_begin(0);
495
496        let mut all_chunks: Vec<CodeChunk> = Vec::new();
497        let mut all_embeddings: Vec<Vec<f32>> = Vec::new();
498        let mut embed_error: Option<crate::Error> = None;
499
500        let mut cumulative_done: usize = 0;
501        for batch in &batch_rx {
502            let batch_len = batch.len();
503            let (encodings, chunks): (Vec<Encoding>, Vec<CodeChunk>) = batch.into_iter().unzip();
504
505            // Wrap as Option<Encoding> for embed_distributed compatibility.
506            let opt_encodings: Vec<Option<Encoding>> = encodings.into_iter().map(Some).collect();
507
508            // Pass noop profiler to embed_distributed — its internal done counter
509            // resets per call (0→batch_size), which corrupts our global progress.
510            let noop = crate::profile::Profiler::noop();
511            match embed_distributed(&opt_encodings, backends, bs, &noop) {
512                Ok(batch_embeddings) => {
513                    profiler.embedding_batch(&batch_embeddings);
514                    cumulative_done += batch_len;
515                    // Byte-based progress: total_bytes known from walk, bytes_chunked
516                    // tracks how much source data has been processed through the pipeline.
517                    let processed = bytes_chunked.load(Ordering::Relaxed) as u64;
518                    profiler.embed_tick_bytes(cumulative_done, processed, total_bytes);
519
520                    for (chunk, emb) in chunks.into_iter().zip(batch_embeddings) {
521                        if !emb.is_empty() {
522                            all_chunks.push(chunk);
523                            all_embeddings.push(emb);
524                        }
525                    }
526                }
527                Err(e) => {
528                    embed_error = Some(e);
529                    // break exits the for loop; batch_rx drops naturally after.
530                    break;
531                }
532            }
533        }
534
535        // Report chunk summary now that all stages have completed (or errored).
536        let final_total = total_chunks_produced.load(Ordering::Relaxed);
537        profiler.chunk_summary(final_total, file_count, chunk_start.elapsed());
538        // Set the final total so embed_done shows the correct summary.
539        profiler.embed_begin_update_total(cumulative_done);
540        profiler.embed_tick(cumulative_done);
541        profiler.embed_done();
542
543        // Wait for tokenize thread and check for errors.
544        let tokenize_result = tokenize_handle.join();
545
546        // Error priority: embed > tokenize > thread panic.
547        if let Some(e) = embed_error {
548            return Err(e);
549        }
550        match tokenize_result {
551            Ok(Ok(())) => {}
552            Ok(Err(e)) => return Err(e),
553            Err(_) => {
554                return Err(crate::Error::Other(anyhow::anyhow!(
555                    "tokenize thread panicked"
556                )));
557            }
558        }
559
560        Ok((all_chunks, all_embeddings))
561    })
562}
563
564/// Search a directory for code chunks semantically similar to a query.
565///
566/// Walks the directory, chunks all supported files, embeds everything
567/// in parallel batches, and returns the top-k results ranked by similarity.
568///
569/// Accepts multiple backends for hybrid scheduling — the first backend
570/// (`backends[0]`) is used for query embedding.
571///
572/// All tuning parameters (batch size, token limit, chunk sizing) are
573/// controlled via [`SearchConfig`].
574///
575/// # Errors
576///
577/// Returns an error if the query cannot be tokenized or embedded.
578///
579/// # Panics
580///
581/// Panics if a per-thread backend clone fails during parallel embedding
582/// (should not happen if the backend loaded successfully).
583#[instrument(skip_all, fields(root = %root.display(), top_k, batch_size = cfg.batch_size))]
584pub fn search(
585    root: &Path,
586    query: &str,
587    backends: &[&dyn EmbedBackend],
588    tokenizer: &tokenizers::Tokenizer,
589    top_k: usize,
590    cfg: &SearchConfig,
591    profiler: &crate::profile::Profiler,
592) -> crate::Result<Vec<SearchResult>> {
593    if backends.is_empty() {
594        return Err(crate::Error::Other(anyhow::anyhow!(
595            "no embedding backends provided"
596        )));
597    }
598
599    // Phases 1, 2, 3, 4: walk, chunk, pre-tokenize, embed all files
600    let (chunks, embeddings) = embed_all(root, backends, tokenizer, cfg, profiler)?;
601
602    let t_query_start = std::time::Instant::now();
603
604    // Phase 5: Build hybrid index (semantic + BM25)
605    let hybrid = {
606        let _span = info_span!("build_hybrid_index").entered();
607        let _guard = profiler.phase("build_hybrid_index");
608        crate::hybrid::HybridIndex::new(chunks, &embeddings, cfg.cascade_dim)?
609    };
610
611    let mode = cfg.mode;
612    let effective_top_k = if top_k > 0 { top_k } else { usize::MAX };
613
614    // Phase 6: Embed query (skip for keyword-only mode)
615    let query_embedding = if mode == crate::hybrid::SearchMode::Keyword {
616        // Keyword mode: no embedding needed, use zero vector
617        let dim = hybrid.semantic.hidden_dim;
618        vec![0.0f32; dim]
619    } else {
620        let _span = info_span!("embed_query").entered();
621        let _guard = profiler.phase("embed_query");
622        let t_tok = std::time::Instant::now();
623        let enc = tokenize(query, tokenizer, cfg.max_tokens, backends[0].max_tokens())?;
624        let tok_ms = t_tok.elapsed().as_secs_f64() * 1000.0;
625        let t_emb = std::time::Instant::now();
626        let mut results = backends[0].embed_batch(&[enc])?;
627        let emb_ms = t_emb.elapsed().as_secs_f64() * 1000.0;
628        eprintln!(
629            "[search] query: tokenize={tok_ms:.1}ms embed={emb_ms:.1}ms total_since_embed_all={:.1}ms",
630            t_query_start.elapsed().as_secs_f64() * 1000.0
631        );
632        results.pop().ok_or_else(|| {
633            crate::Error::Other(anyhow::anyhow!("backend returned no embedding for query"))
634        })?
635    };
636
637    // Phase 7: Hybrid/semantic/keyword ranking
638    let ranked = {
639        let _span = info_span!("rank", chunk_count = hybrid.chunks().len()).entered();
640        let guard = profiler.phase("rank");
641        let threshold = 0.0; // all modes use 0.0; SearchIndex::rank applies its own
642        let results = hybrid.search(&query_embedding, query, effective_top_k, threshold, mode);
643        guard.set_detail(format!(
644            "{mode} top {} from {}",
645            effective_top_k.min(results.len()),
646            hybrid.chunks().len()
647        ));
648        results
649    };
650
651    let results: Vec<SearchResult> = ranked
652        .into_iter()
653        .map(|(idx, score)| SearchResult {
654            chunk: hybrid.chunks()[idx].clone(),
655            similarity: score,
656        })
657        .collect();
658
659    Ok(results)
660}
661
662/// Shared state for [`embed_distributed`] workers.
663struct DistributedState<'a> {
664    tokenized: &'a [Option<Encoding>],
665    cursor: std::sync::atomic::AtomicUsize,
666    error_flag: std::sync::atomic::AtomicBool,
667    first_error: std::sync::Mutex<Option<crate::Error>>,
668    done_counter: std::sync::atomic::AtomicUsize,
669    batch_size: usize,
670    profiler: &'a crate::profile::Profiler,
671}
672
673impl DistributedState<'_> {
674    /// Worker loop: claim batches from the shared cursor, embed, collect results.
675    fn run_worker(&self, backend: &dyn EmbedBackend) -> Vec<(usize, Vec<f32>)> {
676        use std::sync::atomic::Ordering;
677
678        let n = self.tokenized.len();
679        // GPU backends grab larger batches to amortize per-call overhead.
680        // MLX's lazy eval graph optimizer benefits from large matrices.
681        // Metal sub-batches internally via MAX_BATCH to limit padding waste.
682        let grab_size = if backend.is_gpu() {
683            self.batch_size * 4
684        } else {
685            self.batch_size
686        };
687        let mut results = Vec::new();
688
689        loop {
690            if self.error_flag.load(Ordering::Relaxed) {
691                break;
692            }
693
694            let start = self.cursor.fetch_add(grab_size, Ordering::Relaxed);
695            if start >= n {
696                break;
697            }
698            let end = (start + grab_size).min(n);
699            let batch = &self.tokenized[start..end];
700
701            // Separate valid encodings from Nones, tracking which indices succeeded
702            let mut valid = Vec::with_capacity(batch.len());
703            let mut valid_indices = Vec::with_capacity(batch.len());
704            for (i, enc) in batch.iter().enumerate() {
705                if let Some(e) = enc {
706                    // TODO(perf): cloning 3 Vecs per chunk; consider making
707                    // `EmbedBackend::embed_batch` accept `&[&Encoding]` to avoid this.
708                    valid.push(e.clone());
709                    valid_indices.push(start + i);
710                } else {
711                    results.push((start + i, vec![]));
712                }
713            }
714
715            if valid.is_empty() {
716                let done =
717                    self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
718                self.profiler.embed_tick(done);
719                continue;
720            }
721
722            match backend.embed_batch(&valid) {
723                Ok(batch_embeddings) => {
724                    self.profiler.embedding_batch(&batch_embeddings);
725                    for (idx, emb) in valid_indices.into_iter().zip(batch_embeddings) {
726                        results.push((idx, emb));
727                    }
728                    let done =
729                        self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
730                    self.profiler.embed_tick(done);
731                }
732                Err(e) => {
733                    self.error_flag.store(true, Ordering::Relaxed);
734                    if let Ok(mut guard) = self.first_error.lock()
735                        && guard.is_none()
736                    {
737                        *guard = Some(e);
738                    }
739                    break;
740                }
741            }
742        }
743
744        results
745    }
746}
747
748/// Distribute pre-tokenized chunks across multiple backends using work-stealing.
749///
750/// Each backend gets a dedicated worker thread. Workers compete on a shared
751/// `AtomicUsize` cursor to claim batches of chunks. GPU backends grab larger
752/// batches (`batch_size * 4`), CPU backends grab smaller ones (`batch_size`).
753/// Results are written by original chunk index — no merge step needed.
754///
755/// When `backends` has a single entry, no extra threads are spawned.
756///
757/// # Errors
758///
759/// Returns the first error from any backend. Other workers exit early
760/// when an error is detected.
761#[expect(
762    unsafe_code,
763    reason = "BLAS thread count must be set via env vars before spawning workers"
764)]
765pub(crate) fn embed_distributed(
766    tokenized: &[Option<Encoding>],
767    backends: &[&dyn EmbedBackend],
768    batch_size: usize,
769    profiler: &crate::profile::Profiler,
770) -> crate::Result<Vec<Vec<f32>>> {
771    let n = tokenized.len();
772    let state = DistributedState {
773        tokenized,
774        cursor: std::sync::atomic::AtomicUsize::new(0),
775        error_flag: std::sync::atomic::AtomicBool::new(false),
776        first_error: std::sync::Mutex::new(None),
777        done_counter: std::sync::atomic::AtomicUsize::new(0),
778        batch_size: batch_size.max(1),
779        profiler,
780    };
781
782    // Collect (index, embedding) pairs from all workers
783    let all_pairs: Vec<(usize, Vec<f32>)> =
784        if backends.len() == 1 && backends[0].supports_clone() && !backends[0].is_gpu() {
785            // Single cloneable CPU backend: spawn N workers with single-threaded BLAS.
786            //
787            // BLAS libraries (OpenBLAS, MKL) internally spawn threads for each matmul.
788            // For small matrices ([1,384]×[384,384]), this thread overhead dominates —
789            // profiling shows 80% of time in sched_yield (thread contention).
790            //
791            // Instead: force BLAS to single-thread per worker, parallelize across
792            // independent BERT inferences. Each worker gets its own cloned backend.
793            // Force BLAS libraries to single-threaded mode.
794            // We parallelize across independent BERT inferences instead.
795            // env vars don't always work (OpenBLAS may ignore after init),
796            // so also call the runtime API directly.
797            unsafe {
798                std::env::set_var("OPENBLAS_NUM_THREADS", "1");
799                std::env::set_var("MKL_NUM_THREADS", "1");
800                std::env::set_var("VECLIB_MAXIMUM_THREADS", "1"); // macOS Accelerate
801
802                // Direct FFI to set BLAS thread count — works even after init
803                #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
804                {
805                    unsafe extern "C" {
806                        fn openblas_set_num_threads(num: std::ffi::c_int);
807                    }
808                    openblas_set_num_threads(1);
809                }
810            }
811
812            let num_workers = rayon::current_num_threads().max(1);
813            std::thread::scope(|s| {
814                let handles: Vec<_> = (0..num_workers)
815                    .map(|_| {
816                        s.spawn(|| {
817                            // Per-thread: force single-threaded BLAS (thread-local setting).
818                            // On macOS 15+ this calls BLASSetThreading; on Linux openblas_set_num_threads.
819                            #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
820                            crate::backend::driver::cpu::force_single_threaded_blas();
821                            let cloned = backends[0].clone_backend();
822                            state.run_worker(cloned.as_ref())
823                        })
824                    })
825                    .collect();
826                let mut all = Vec::new();
827                for handle in handles {
828                    if let Ok(pairs) = handle.join() {
829                        all.extend(pairs);
830                    }
831                }
832                all
833            })
834        } else if backends.len() == 1 {
835            // Single non-cloneable backend (GPU or CPU ModernBERT): run on the calling thread.
836            // GPU backends handle parallelism internally; CPU uses BLAS internal
837            // multi-threading (Accelerate/OpenBLAS) for intra-GEMM parallelism.
838            state.run_worker(backends[0])
839        } else {
840            // Multiple backends: one thread per backend via std::thread::scope
841            std::thread::scope(|s| {
842                let handles: Vec<_> = backends
843                    .iter()
844                    .map(|&backend| {
845                        s.spawn(|| {
846                            // CPU backends that support cloning get a thread-local copy
847                            if backend.supports_clone() {
848                                let cloned = backend.clone_backend();
849                                state.run_worker(cloned.as_ref())
850                            } else {
851                                state.run_worker(backend)
852                            }
853                        })
854                    })
855                    .collect();
856
857                let mut all = Vec::new();
858                for handle in handles {
859                    if let Ok(pairs) = handle.join() {
860                        all.extend(pairs);
861                    } else {
862                        warn!("worker thread panicked");
863                        state
864                            .error_flag
865                            .store(true, std::sync::atomic::Ordering::Relaxed);
866                    }
867                }
868                all
869            })
870        };
871
872    // Check for errors before assembling results
873    if let Some(err) = state.first_error.into_inner().ok().flatten() {
874        return Err(err);
875    }
876
877    // Scatter results into output vec by original index
878    let mut embeddings: Vec<Vec<f32>> = vec![vec![]; n];
879    for (idx, emb) in all_pairs {
880        embeddings[idx] = emb;
881    }
882
883    Ok(embeddings)
884}
885
886/// Read a source file into a `String`, skipping binary files.
887///
888/// Reads the file as raw bytes first, checks for NUL bytes in the first 8 KB
889/// to detect binary files, then converts to UTF-8. Returns `None` (with a
890/// trace log) when the file cannot be read, is binary, or is not valid UTF-8.
891pub(crate) fn read_source(path: &Path) -> Option<String> {
892    let bytes = match std::fs::read(path) {
893        Ok(b) => b,
894        Err(e) => {
895            trace!(path = %path.display(), err = %e, "skipping file: read failed");
896            return None;
897        }
898    };
899
900    // Skip binary files: NUL byte anywhere in the first 8 KB is a reliable signal.
901    if memchr::memchr(0, &bytes[..bytes.len().min(8192)]).is_some() {
902        trace!(path = %path.display(), "skipping binary file");
903        return None;
904    }
905
906    match std::str::from_utf8(&bytes) {
907        Ok(s) => Some(s.to_string()),
908        Err(e) => {
909            trace!(path = %path.display(), err = %e, "skipping file: not valid UTF-8");
910            None
911        }
912    }
913}
914
915/// Tokenize text into an [`Encoding`] ready for model inference.
916///
917/// Delegates to [`crate::tokenize::tokenize_query`] for the core encoding,
918/// then applies an additional `max_tokens` truncation when non-zero.
919/// CLS pooling means the first token's representation carries most semantic
920/// weight, so truncation has minimal quality impact.
921fn tokenize(
922    text: &str,
923    tokenizer: &tokenizers::Tokenizer,
924    max_tokens: usize,
925    model_max_tokens: usize,
926) -> crate::Result<Encoding> {
927    let mut enc = crate::tokenize::tokenize_query(text, tokenizer, model_max_tokens)?;
928    if max_tokens > 0 {
929        let len = enc.input_ids.len().min(max_tokens);
930        enc.input_ids.truncate(len);
931        enc.attention_mask.truncate(len);
932        enc.token_type_ids.truncate(len);
933    }
934    Ok(enc)
935}
936
937/// Normalize similarity scores to `[0,1]` and apply a `PageRank` structural boost.
938///
939/// Each result's similarity is min-max normalized, then a weighted `PageRank`
940/// score is added: `final = normalized + alpha * pagerank`. This promotes
941/// architecturally important files (many dependents) in search results.
942///
943/// Called from the MCP search handler which has access to the `RepoGraph`,
944/// rather than from [`search`] directly.
945pub fn apply_structural_boost<S: ::std::hash::BuildHasher>(
946    results: &mut [SearchResult],
947    file_ranks: &std::collections::HashMap<String, f32, S>,
948    alpha: f32,
949) {
950    if results.is_empty() || alpha == 0.0 {
951        return;
952    }
953
954    let min = results
955        .iter()
956        .map(|r| r.similarity)
957        .fold(f32::INFINITY, f32::min);
958    let max = results
959        .iter()
960        .map(|r| r.similarity)
961        .fold(f32::NEG_INFINITY, f32::max);
962    let range = (max - min).max(1e-12);
963
964    for r in results.iter_mut() {
965        let normalized = (r.similarity - min) / range;
966        let pr = file_ranks.get(&r.chunk.file_path).copied().unwrap_or(0.0);
967        r.similarity = normalized + alpha * pr;
968    }
969}
970
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end search over this crate's own `src/` using the CPU backend.
    #[test]
    #[cfg(feature = "cpu")]
    #[ignore = "loads model + embeds full source tree; run with `cargo test -- --ignored`"]
    fn search_with_backend_trait() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();
        let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
        // `expect` rather than `assert!(is_ok())` so a failure reports the actual error.
        let results = search(
            &dir,
            "embedding model",
            &[backend.as_ref()],
            &tokenizer,
            1,
            &cfg,
            &profiler,
        )
        .expect("search over the crate's own source should succeed");
        assert!(!results.is_empty(), "search returned no results");
    }

    /// `embed_distributed` returns one full-width embedding per input encoding.
    #[test]
    #[cfg(feature = "cpu")]
    fn embed_distributed_produces_correct_count() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let profiler = crate::profile::Profiler::noop();

        // Tokenize a few strings
        let texts = ["fn hello() {}", "class Foo:", "func main() {}"];
        let encoded: Vec<Option<Encoding>> = texts
            .iter()
            .map(|t| super::tokenize(t, &tokenizer, 0, 512).ok())
            .collect();

        let results =
            super::embed_distributed(&encoded, &[backend.as_ref()], 32, &profiler).unwrap();

        assert_eq!(results.len(), 3);
        // All should be 384-dim (bge-small hidden size)
        for (i, emb) in results.iter().enumerate() {
            assert_eq!(emb.len(), 384, "embedding {i} should be 384-dim");
        }
    }

    /// Truncate an embedding to `dims` dimensions and L2-normalize.
    fn truncate_and_normalize(emb: &[f32], dims: usize) -> Vec<f32> {
        let trunc = &emb[..dims];
        let norm: f32 = trunc.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        trunc.iter().map(|x| x / norm).collect()
    }

    /// Rank corpus embeddings against a query, return top-K chunk indices.
    fn rank_topk(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<usize> {
        let mut scored: Vec<(usize, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();
        scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        scored.into_iter().take(k).map(|(i, _)| i).collect()
    }

    /// MRL retrieval recall test: does truncated search retrieve the same results?
    ///
    /// Embeds the ripvec codebase at full dimension, then tests whether
    /// truncating to fewer dimensions retrieves the same top-10 results.
    /// This is the real MRL quality test — per-vector cosine is trivially 1.0
    /// but retrieval recall can degrade if the first N dims don't preserve
    /// relative ordering between different vectors.
    #[test]
    #[ignore = "loads model + embeds; run with --nocapture"]
    #[expect(
        clippy::cast_precision_loss,
        reason = "top_k and overlap are small counts"
    )]
    fn mrl_retrieval_recall() {
        let model = "BAAI/bge-small-en-v1.5";
        let backends = crate::backend::detect_backends(model).unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model).unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();

        // Embed the ripvec source tree
        let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap()
            .parent()
            .unwrap();
        eprintln!("Embedding {}", root.display());
        let backend_refs: Vec<&dyn crate::backend::EmbedBackend> =
            backends.iter().map(std::convert::AsRef::as_ref).collect();
        let (chunks, embeddings) =
            embed_all(root, &backend_refs, &tokenizer, &cfg, &profiler).unwrap();
        // Fail with a clear message instead of an index panic on an empty corpus.
        assert!(
            !embeddings.is_empty(),
            "embedding {} produced no chunks",
            root.display()
        );
        let full_dim = embeddings[0].len();
        eprintln!(
            "Corpus: {} chunks, {full_dim}-dim embeddings\n",
            chunks.len()
        );

        // Test queries spanning different semantic intents
        let queries = [
            "error handling in the embedding pipeline",
            "tree-sitter chunking and AST parsing",
            "Metal GPU kernel dispatch",
            "file watcher for incremental reindex",
            "cosine similarity ranking",
        ];

        let top_k = 10;
        let mrl_dims: Vec<usize> = [32, 64, 128, 192, 256, full_dim]
            .into_iter()
            .filter(|&d| d <= full_dim)
            .collect();

        eprintln!("=== MRL Retrieval Recall@{top_k} (vs full {full_dim}-dim) ===\n");

        for query in &queries {
            // Embed query at full dim
            let enc = tokenize(query, &tokenizer, 0, backends[0].max_tokens()).unwrap();
            let query_emb = backends[0].embed_batch(&[enc]).unwrap().pop().unwrap();

            // Full-dim reference ranking
            let ref_topk = rank_topk(&query_emb, &embeddings, top_k);

            eprintln!("Query: \"{query}\"");
            eprintln!(
                "  Full-dim top-1: {} ({})",
                chunks[ref_topk[0]].name, chunks[ref_topk[0]].file_path
            );

            for &dims in &mrl_dims {
                // Truncate corpus and query
                let trunc_corpus: Vec<Vec<f32>> = embeddings
                    .iter()
                    .map(|e| truncate_and_normalize(e, dims))
                    .collect();
                let trunc_query = truncate_and_normalize(&query_emb, dims);

                let trunc_topk = rank_topk(&trunc_query, &trunc_corpus, top_k);

                // Recall@K: how many of the full-dim top-K appear in truncated top-K
                let overlap = ref_topk.iter().filter(|i| trunc_topk.contains(i)).count();
                let recall = overlap as f32 / top_k as f32;
                let marker = if dims == full_dim {
                    " (ref)"
                } else if recall >= 0.8 {
                    " ***"
                } else {
                    ""
                };
                eprintln!(
                    "  dims={dims:>3}: Recall@{top_k}={recall:.1} ({overlap}/{top_k}){marker}"
                );
            }
            eprintln!();
        }
    }

    /// Build a `SearchResult` for `file_path` with the given similarity and
    /// placeholder chunk metadata.
    fn make_result(file_path: &str, similarity: f32) -> SearchResult {
        SearchResult {
            chunk: CodeChunk {
                file_path: file_path.to_string(),
                name: "test".to_string(),
                kind: "function".to_string(),
                start_line: 1,
                end_line: 10,
                enriched_content: String::new(),
                content: String::new(),
            },
            similarity,
        }
    }

    #[test]
    fn structural_boost_normalizes_and_applies() {
        let mut results = vec![
            make_result("src/a.rs", 0.8),
            make_result("src/b.rs", 0.4),
            make_result("src/c.rs", 0.6),
        ];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 0.5);
        ranks.insert("src/b.rs".to_string(), 1.0);
        ranks.insert("src/c.rs".to_string(), 0.0);

        apply_structural_boost(&mut results, &ranks, 0.2);

        // a: normalized=(0.8-0.4)/0.4=1.0, boost=0.2*0.5=0.1 => 1.1
        assert!((results[0].similarity - 1.1).abs() < 1e-6);
        // b: normalized=(0.4-0.4)/0.4=0.0, boost=0.2*1.0=0.2 => 0.2
        assert!((results[1].similarity - 0.2).abs() < 1e-6);
        // c: normalized=(0.6-0.4)/0.4=0.5, boost=0.2*0.0=0.0 => 0.5
        assert!((results[2].similarity - 0.5).abs() < 1e-6);
    }

    #[test]
    fn structural_boost_noop_on_empty() {
        let mut results: Vec<SearchResult> = vec![];
        let ranks = std::collections::HashMap::new();
        apply_structural_boost(&mut results, &ranks, 0.2);
        assert!(results.is_empty());
    }

    #[test]
    fn structural_boost_noop_on_zero_alpha() {
        let mut results = vec![make_result("src/a.rs", 0.8)];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 1.0);
        apply_structural_boost(&mut results, &ranks, 0.0);
        // Should be unchanged
        assert!((results[0].similarity - 0.8).abs() < 1e-6);
    }
}