Skip to main content

ripvec_core/
embed.rs

1//! Parallel batch embedding pipeline with streaming backpressure.
2//!
3//! Two pipeline modes:
4//!
5//! - **Batch mode** (< `STREAMING_THRESHOLD` files): walk, chunk all, tokenize
6//!   all, sort by length, embed. Simple and optimal for small corpora.
7//!
8//! - **Streaming mode** (>= `STREAMING_THRESHOLD` files): three-stage pipeline
9//!   with bounded channels. Chunks flow through: rayon chunk workers ->
10//!   tokenize+batch collector -> GPU embed consumer. The GPU starts after the
11//!   first `batch_size` encodings are ready (~50ms), not after all chunks are
12//!   done. Backpressure prevents unbounded memory growth.
13//!
14//! # Batch inference
15//!
16//! Instead of one forward pass per chunk, chunks are grouped into batches
17//! of configurable size (default 32). Each batch is tokenized, padded to
18//! the longest sequence, and run as a single forward pass with shape
19//! `[batch_size, max_seq_len]`. This amortizes per-call overhead and enables
20//! SIMD across the batch dimension.
21//!
22//! # Parallelism
23//!
24//! On CPU, each rayon thread gets its own backend clone (cheap — most
25//! backends use `Arc`'d weights internally). On GPU, batches run sequentially
26//! from Rust while the device parallelizes internally.
27
28use std::path::Path;
29use std::sync::atomic::{AtomicUsize, Ordering};
30use std::time::Instant;
31
32use rayon::prelude::*;
33use tracing::{debug, info_span, instrument, warn};
34
35use crate::backend::{EmbedBackend, Encoding};
36use crate::chunk::{ChunkConfig, CodeChunk};
37
/// Number of chunks embedded per inference call unless the caller overrides
/// [`SearchConfig::batch_size`].
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// Corpus size, in files, at which `embed_all` switches from the batch
/// pipeline to the streaming pipeline.
///
/// Smaller corpora take the batch path (chunk all -> tokenize all -> sort ->
/// embed), which can sort every chunk by length in one global pass. Larger
/// corpora stream, so embedding starts before chunking/tokenization finish.
const STREAMING_THRESHOLD: usize = 1_000;

/// Capacity, in whole batches, of the channel feeding the streaming embed stage.
///
/// Bounds memory: at most `RING_SIZE * batch_size` tokenized chunks wait in
/// that channel at any moment. Matches the ring-buffer depth documented on
/// [`EmbedBackend`].
const RING_SIZE: usize = 4;
53
54/// Runtime configuration for the search pipeline.
55///
56/// All tuning parameters that were previously compile-time constants are
57/// gathered here so they can be set from CLI arguments without recompiling.
58#[derive(Debug, Clone)]
59pub struct SearchConfig {
60    /// Chunks per inference call. Larger values amortize call overhead
61    /// but consume more memory. Default: 32.
62    pub batch_size: usize,
63    /// Maximum tokens fed to the model per chunk. `0` means no limit.
64    /// Capping tokens controls inference cost for minified or dense source.
65    /// BERT attention cost scales linearly with token count, and CLS pooling
66    /// means the first token's representation carries most semantic weight.
67    /// Default: 128 (7.7× faster than 512, with minimal quality loss).
68    pub max_tokens: usize,
69    /// Chunking parameters forwarded to the chunking phase.
70    pub chunk: ChunkConfig,
71    /// Force all files to be chunked as plain text (sliding windows only).
72    /// When `false` (default), files with recognized extensions use tree-sitter
73    /// semantic chunking, and unrecognized extensions fall back to sliding windows.
74    pub text_mode: bool,
75    /// MRL cascade pre-filter dimension.
76    ///
77    /// When set, [`SearchIndex`](crate::index::SearchIndex) stores a truncated
78    /// and L2-re-normalized copy of the embedding matrix at this dimension for
79    /// fast two-phase cascade search. `None` (default) disables cascade search.
80    pub cascade_dim: Option<usize>,
81    /// Optional file type filter (e.g. "rust", "python", "js").
82    ///
83    /// When set, only files matching this type (using ripgrep's built-in type
84    /// database) are collected during the walk phase.
85    pub file_type: Option<String>,
86    /// Search mode: hybrid (default), semantic, or keyword.
87    pub mode: crate::hybrid::SearchMode,
88}
89
90impl Default for SearchConfig {
91    fn default() -> Self {
92        Self {
93            batch_size: DEFAULT_BATCH_SIZE,
94            max_tokens: 0,
95            chunk: ChunkConfig::default(),
96            text_mode: false,
97            cascade_dim: None,
98            file_type: None,
99            mode: crate::hybrid::SearchMode::Hybrid,
100        }
101    }
102}
103
104/// A search result pairing a code chunk with its similarity score.
105#[derive(Debug, Clone)]
106pub struct SearchResult {
107    /// The matched code chunk.
108    pub chunk: CodeChunk,
109    /// Cosine similarity to the query (0.0 to 1.0).
110    pub similarity: f32,
111}
112
113/// Walk, chunk, and embed all files in a directory.
114///
115/// Returns the chunks and their corresponding embedding vectors.
116/// This is the building block for both one-shot search and interactive mode.
117/// The caller handles query embedding and ranking.
118///
119/// Accepts multiple backends for hybrid scheduling — chunks are distributed
120/// across all backends via work-stealing (see [`embed_distributed`]).
121///
122/// Automatically selects between two pipeline modes:
123/// - **Batch** (< `STREAMING_THRESHOLD` files): chunk all, tokenize all, sort
124///   by length, embed. Optimal for small corpora.
125/// - **Streaming** (>= `STREAMING_THRESHOLD` files): three-stage pipeline with
126///   bounded channels. GPU starts after the first batch is ready, not after all
127///   chunks are done. Eliminates GPU idle time during chunking/tokenization.
128///
129/// # Errors
130///
131/// Returns an error if file walking, chunking, or embedding fails.
132#[instrument(skip_all, fields(root = %root.display(), batch_size = cfg.batch_size))]
133pub fn embed_all(
134    root: &Path,
135    backends: &[&dyn EmbedBackend],
136    tokenizer: &tokenizers::Tokenizer,
137    cfg: &SearchConfig,
138    profiler: &crate::profile::Profiler,
139) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
140    if backends.is_empty() {
141        return Err(crate::Error::Other(anyhow::anyhow!(
142            "no embedding backends provided"
143        )));
144    }
145
146    // Phase 1: Collect files (respects .gitignore, filters by extension)
147    let files = {
148        let _span = info_span!("walk").entered();
149        let guard = profiler.phase("walk");
150        let files = crate::walk::collect_files(root, cfg.file_type.as_deref());
151        guard.set_detail(format!("{} files", files.len()));
152        files
153    };
154
155    if files.len() >= STREAMING_THRESHOLD {
156        // Compute total source bytes for byte-based progress (known after walk).
157        let total_bytes: u64 = files
158            .iter()
159            .filter_map(|p| p.metadata().ok())
160            .map(|m| m.len())
161            .sum();
162        embed_all_streaming(&files, total_bytes, backends, tokenizer, cfg, profiler)
163    } else {
164        embed_all_batch(&files, backends, tokenizer, cfg, profiler)
165    }
166}
167
168/// Batch pipeline: chunk all -> tokenize all -> sort by length -> embed.
169///
170/// Optimal for small corpora where the global sort-by-length optimization
171/// matters more than eliminating GPU idle time.
172fn embed_all_batch(
173    files: &[std::path::PathBuf],
174    backends: &[&dyn EmbedBackend],
175    tokenizer: &tokenizers::Tokenizer,
176    cfg: &SearchConfig,
177    profiler: &crate::profile::Profiler,
178) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
179    // Phase 2: Chunk all files in parallel.
180    let chunks: Vec<CodeChunk> = {
181        let _span = info_span!("chunk", file_count = files.len()).entered();
182        let chunk_start = Instant::now();
183        let text_mode = cfg.text_mode;
184        let result: Vec<CodeChunk> = files
185            .par_iter()
186            .flat_map(|path| {
187                let Some(source) = read_source(path) else {
188                    return vec![];
189                };
190                let chunks = if text_mode {
191                    crate::chunk::chunk_text(path, &source, &cfg.chunk)
192                } else {
193                    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
194                    match crate::languages::config_for_extension(ext) {
195                        Some(lang_config) => {
196                            crate::chunk::chunk_file(path, &source, &lang_config, &cfg.chunk)
197                        }
198                        None => crate::chunk::chunk_text(path, &source, &cfg.chunk),
199                    }
200                };
201                profiler.chunk_thread_report(chunks.len());
202                chunks
203            })
204            .collect();
205        profiler.chunk_summary(result.len(), files.len(), chunk_start.elapsed());
206        result
207    };
208
209    // Phase 3: Pre-tokenize all chunks in parallel (CPU-bound, all rayon threads)
210    let bs = cfg.batch_size.max(1);
211    let max_tokens_cfg = cfg.max_tokens;
212    let model_max = backends[0].max_tokens();
213    let _span = info_span!("embed_chunks", chunk_count = chunks.len(), batch_size = bs).entered();
214    profiler.embed_begin(chunks.len());
215
216    let all_encodings: Vec<Option<Encoding>> = chunks
217        .par_iter()
218        .map(|chunk| {
219            tokenize(
220                &chunk.enriched_content,
221                tokenizer,
222                max_tokens_cfg,
223                model_max,
224            )
225            .inspect_err(|e| {
226                warn!(file = %chunk.file_path, err = %e, "tokenization failed, skipping chunk");
227            })
228            .ok()
229        })
230        .collect();
231
232    // Sort chunks and their encodings together by descending token count.
233    // This groups similar-length sequences into the same batch, minimizing
234    // padding waste (short chunks no longer get padded to a long neighbour).
235    let mut paired: Vec<(CodeChunk, Option<Encoding>)> =
236        chunks.into_iter().zip(all_encodings).collect();
237    paired.sort_by(|a, b| {
238        let len_a = a.1.as_ref().map_or(0, |e| e.input_ids.len());
239        let len_b = b.1.as_ref().map_or(0, |e| e.input_ids.len());
240        len_b.cmp(&len_a) // descending — longest first
241    });
242    let (chunks, sorted_encodings): (Vec<CodeChunk>, Vec<Option<Encoding>>) =
243        paired.into_iter().unzip();
244
245    // Phase 4: Distribute pre-tokenized batches across all backends
246    let embeddings = embed_distributed(&sorted_encodings, backends, bs, profiler)?;
247    profiler.embed_done();
248
249    // Filter out chunks whose tokenization failed (empty embedding vectors).
250    let (chunks, embeddings): (Vec<_>, Vec<_>) = chunks
251        .into_iter()
252        .zip(embeddings)
253        .filter(|(_, emb)| !emb.is_empty())
254        .unzip();
255
256    Ok((chunks, embeddings))
257}
258
259/// Streaming pipeline: chunk -> tokenize -> batch -> embed with backpressure.
260///
261/// Three concurrent stages connected by bounded channels:
262///
263/// 1. **Chunk producers** (rayon pool, in a scoped thread): read + parse files,
264///    send chunks to channel.
265/// 2. **Tokenize + batch collector** (scoped thread): tokenize chunks, sort
266///    within batch windows, send full batches to the embed channel.
267/// 3. **Embed consumer** (main thread): calls `embed_distributed` on each
268///    batch, collects results.
269///
270/// The bounded channels provide natural backpressure: if the GPU falls behind,
271/// the tokenize stage blocks, which blocks chunk producers via the chunk channel.
272/// If chunking is fast and the GPU is slow, at most
273/// `8 * batch_size + RING_SIZE * batch_size` items are in memory.
274///
275/// Uses `std::thread::scope` so all threads can borrow the caller's stack
276/// (`tokenizer`, `backends`, `profiler`) without `'static` bounds.
277#[expect(
278    clippy::too_many_lines,
279    reason = "streaming pipeline has inherent complexity in thread coordination"
280)]
281fn embed_all_streaming(
282    files: &[std::path::PathBuf],
283    total_bytes: u64,
284    backends: &[&dyn EmbedBackend],
285    tokenizer: &tokenizers::Tokenizer,
286    cfg: &SearchConfig,
287    profiler: &crate::profile::Profiler,
288) -> crate::Result<(Vec<CodeChunk>, Vec<Vec<f32>>)> {
289    use crossbeam_channel::bounded;
290
291    let bs = cfg.batch_size.max(1);
292    let max_tokens_cfg = cfg.max_tokens;
293    let model_max = backends[0].max_tokens();
294    let file_count = files.len();
295    let text_mode = cfg.text_mode;
296    let chunk_config = cfg.chunk.clone();
297
298    // Bounded channel from chunk producers -> tokenize+batch stage.
299    // Factor of 8 gives enough buffering for rayon parallelism without
300    // unbounded growth (at most ~8 batches worth of chunks in flight).
301    let (chunk_tx, chunk_rx) = bounded::<CodeChunk>(bs * 8);
302
303    // Bounded channel from tokenize+batch stage -> embed consumer.
304    // RING_SIZE batches in flight provides enough pipeline depth for GPU
305    // to stay busy while the next batch is being tokenized.
306    let (batch_tx, batch_rx) = bounded::<Vec<(Encoding, CodeChunk)>>(RING_SIZE);
307
308    // Shared counters for profiling across the streaming pipeline.
309    let total_chunks_produced = AtomicUsize::new(0);
310    let bytes_chunked = AtomicUsize::new(0);
311    let chunk_start = Instant::now();
312
313    // All stages run inside std::thread::scope so they can borrow from the
314    // caller's stack (tokenizer, backends, profiler, files, etc.).
315    std::thread::scope(|scope| {
316        // --- Stage 1: Chunk producers (rayon inside a scoped thread) ---
317        //
318        // Spawns a scoped thread that drives rayon's par_iter. Each file is
319        // chunked independently and chunks are sent into the bounded channel.
320        // If the channel is full, rayon workers block, providing backpressure.
321        scope.spawn(|| {
322            let _span = info_span!("chunk_stream", file_count).entered();
323            files.par_iter().for_each(|path| {
324                let Some(source) = read_source(path) else {
325                    return;
326                };
327                let chunks = if text_mode {
328                    crate::chunk::chunk_text(path, &source, &chunk_config)
329                } else {
330                    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
331                    match crate::languages::config_for_extension(ext) {
332                        Some(lang_config) => {
333                            crate::chunk::chunk_file(path, &source, &lang_config, &chunk_config)
334                        }
335                        None => crate::chunk::chunk_text(path, &source, &chunk_config),
336                    }
337                };
338                let n = chunks.len();
339                let file_bytes = source.len();
340                for chunk in chunks {
341                    // Channel disconnected means downstream errored; stop.
342                    if chunk_tx.send(chunk).is_err() {
343                        return;
344                    }
345                }
346                profiler.chunk_thread_report(n);
347                total_chunks_produced.fetch_add(n, Ordering::Relaxed);
348                bytes_chunked.fetch_add(file_bytes, Ordering::Relaxed);
349            });
350            // chunk_tx is dropped here, closing the channel — but the borrow
351            // of chunk_tx lives until the scoped thread ends. We need to
352            // explicitly drop it so the tokenize stage sees the channel close.
353            drop(chunk_tx);
354        });
355
356        // --- Stage 2: Tokenize + batch collector (scoped thread) ---
357        //
358        // Receives individual chunks, tokenizes each (HuggingFace tokenizer
359        // is Send + Sync), and accumulates into batch-sized buffers. Within
360        // each buffer, entries are sorted by descending token count — the same
361        // padding-reduction optimization as the batch path, applied locally.
362        let tokenize_handle = scope.spawn(move || -> crate::Result<()> {
363            let _span = info_span!("tokenize_stream").entered();
364            let mut buffer: Vec<(Encoding, CodeChunk)> = Vec::with_capacity(bs);
365
366            for chunk in &chunk_rx {
367                match tokenize(
368                    &chunk.enriched_content,
369                    tokenizer,
370                    max_tokens_cfg,
371                    model_max,
372                ) {
373                    Ok(encoding) => {
374                        buffer.push((encoding, chunk));
375                        if buffer.len() >= bs {
376                            // Sort within batch by descending token count.
377                            buffer.sort_by(|a, b| b.0.input_ids.len().cmp(&a.0.input_ids.len()));
378                            let batch = std::mem::replace(&mut buffer, Vec::with_capacity(bs));
379                            if batch_tx.send(batch).is_err() {
380                                // Embed consumer dropped; stop tokenizing.
381                                return Ok(());
382                            }
383                        }
384                    }
385                    Err(e) => {
386                        warn!(
387                            file = %chunk.file_path, err = %e,
388                            "tokenization failed, skipping chunk"
389                        );
390                    }
391                }
392            }
393
394            // Flush remaining partial batch.
395            if !buffer.is_empty() {
396                buffer.sort_by(|a, b| b.0.input_ids.len().cmp(&a.0.input_ids.len()));
397                let _ = batch_tx.send(buffer);
398            }
399            // batch_tx drops here, closing the embed channel.
400
401            Ok(())
402        });
403
404        // --- Stage 3: Embed consumer (main thread within scope) ---
405        //
406        // Receives sorted batches, embeds via the backend(s), collects results.
407        // Profiler is driven from here since this thread owns the reference.
408        let _span = info_span!("embed_stream").entered();
409
410        // Total isn't known upfront in streaming mode; start at 0 and update.
411        profiler.embed_begin(0);
412
413        let mut all_chunks: Vec<CodeChunk> = Vec::new();
414        let mut all_embeddings: Vec<Vec<f32>> = Vec::new();
415        let mut embed_error: Option<crate::Error> = None;
416
417        let mut cumulative_done: usize = 0;
418        for batch in &batch_rx {
419            let batch_len = batch.len();
420            let (encodings, chunks): (Vec<Encoding>, Vec<CodeChunk>) = batch.into_iter().unzip();
421
422            // Wrap as Option<Encoding> for embed_distributed compatibility.
423            let opt_encodings: Vec<Option<Encoding>> = encodings.into_iter().map(Some).collect();
424
425            // Pass noop profiler to embed_distributed — its internal done counter
426            // resets per call (0→batch_size), which corrupts our global progress.
427            let noop = crate::profile::Profiler::noop();
428            match embed_distributed(&opt_encodings, backends, bs, &noop) {
429                Ok(batch_embeddings) => {
430                    cumulative_done += batch_len;
431                    // Byte-based progress: total_bytes known from walk, bytes_chunked
432                    // tracks how much source data has been processed through the pipeline.
433                    let processed = bytes_chunked.load(Ordering::Relaxed) as u64;
434                    profiler.embed_tick_bytes(cumulative_done, processed, total_bytes);
435
436                    for (chunk, emb) in chunks.into_iter().zip(batch_embeddings) {
437                        if !emb.is_empty() {
438                            all_chunks.push(chunk);
439                            all_embeddings.push(emb);
440                        }
441                    }
442                }
443                Err(e) => {
444                    embed_error = Some(e);
445                    // break exits the for loop; batch_rx drops naturally after.
446                    break;
447                }
448            }
449        }
450
451        // Report chunk summary now that all stages have completed (or errored).
452        let final_total = total_chunks_produced.load(Ordering::Relaxed);
453        profiler.chunk_summary(final_total, file_count, chunk_start.elapsed());
454        // Set the final total so embed_done shows the correct summary.
455        profiler.embed_begin_update_total(cumulative_done);
456        profiler.embed_tick(cumulative_done);
457        profiler.embed_done();
458
459        // Wait for tokenize thread and check for errors.
460        let tokenize_result = tokenize_handle.join();
461
462        // Error priority: embed > tokenize > thread panic.
463        if let Some(e) = embed_error {
464            return Err(e);
465        }
466        match tokenize_result {
467            Ok(Ok(())) => {}
468            Ok(Err(e)) => return Err(e),
469            Err(_) => {
470                return Err(crate::Error::Other(anyhow::anyhow!(
471                    "tokenize thread panicked"
472                )));
473            }
474        }
475
476        Ok((all_chunks, all_embeddings))
477    })
478}
479
480/// Search a directory for code chunks semantically similar to a query.
481///
482/// Walks the directory, chunks all supported files, embeds everything
483/// in parallel batches, and returns the top-k results ranked by similarity.
484///
485/// Accepts multiple backends for hybrid scheduling — the first backend
486/// (`backends[0]`) is used for query embedding.
487///
488/// All tuning parameters (batch size, token limit, chunk sizing) are
489/// controlled via [`SearchConfig`].
490///
491/// # Errors
492///
493/// Returns an error if the query cannot be tokenized or embedded.
494///
495/// # Panics
496///
497/// Panics if a per-thread backend clone fails during parallel embedding
498/// (should not happen if the backend loaded successfully).
499#[instrument(skip_all, fields(root = %root.display(), top_k, batch_size = cfg.batch_size))]
500pub fn search(
501    root: &Path,
502    query: &str,
503    backends: &[&dyn EmbedBackend],
504    tokenizer: &tokenizers::Tokenizer,
505    top_k: usize,
506    cfg: &SearchConfig,
507    profiler: &crate::profile::Profiler,
508) -> crate::Result<Vec<SearchResult>> {
509    if backends.is_empty() {
510        return Err(crate::Error::Other(anyhow::anyhow!(
511            "no embedding backends provided"
512        )));
513    }
514
515    // Phases 1, 2, 3, 4: walk, chunk, pre-tokenize, embed all files
516    let (chunks, embeddings) = embed_all(root, backends, tokenizer, cfg, profiler)?;
517
518    let t_query_start = std::time::Instant::now();
519
520    // Phase 5: Build hybrid index (semantic + BM25)
521    let hybrid = {
522        let _span = info_span!("build_hybrid_index").entered();
523        let _guard = profiler.phase("build_hybrid_index");
524        crate::hybrid::HybridIndex::new(chunks, embeddings, cfg.cascade_dim)?
525    };
526
527    let mode = cfg.mode;
528    let effective_top_k = if top_k > 0 { top_k } else { usize::MAX };
529
530    // Phase 6: Embed query (skip for keyword-only mode)
531    let query_embedding = if mode == crate::hybrid::SearchMode::Keyword {
532        // Keyword mode: no embedding needed, use zero vector
533        let dim = hybrid.semantic.hidden_dim;
534        vec![0.0f32; dim]
535    } else {
536        let _span = info_span!("embed_query").entered();
537        let _guard = profiler.phase("embed_query");
538        let t_tok = std::time::Instant::now();
539        let enc = tokenize(query, tokenizer, cfg.max_tokens, backends[0].max_tokens())?;
540        let tok_ms = t_tok.elapsed().as_secs_f64() * 1000.0;
541        let t_emb = std::time::Instant::now();
542        let mut results = backends[0].embed_batch(&[enc])?;
543        let emb_ms = t_emb.elapsed().as_secs_f64() * 1000.0;
544        eprintln!(
545            "[search] query: tokenize={tok_ms:.1}ms embed={emb_ms:.1}ms total_since_embed_all={:.1}ms",
546            t_query_start.elapsed().as_secs_f64() * 1000.0
547        );
548        results.pop().ok_or_else(|| {
549            crate::Error::Other(anyhow::anyhow!("backend returned no embedding for query"))
550        })?
551    };
552
553    // Phase 7: Hybrid/semantic/keyword ranking
554    let ranked = {
555        let _span = info_span!("rank", chunk_count = hybrid.chunks().len()).entered();
556        let guard = profiler.phase("rank");
557        // Threshold only applies to semantic modes; keyword/hybrid use RRF scores
558        let threshold = if mode == crate::hybrid::SearchMode::Semantic {
559            0.0 // SearchIndex::rank applies its own threshold
560        } else {
561            0.0
562        };
563        let results = hybrid.search(&query_embedding, query, effective_top_k, threshold, mode);
564        guard.set_detail(format!(
565            "{mode} top {} from {}",
566            effective_top_k.min(results.len()),
567            hybrid.chunks().len()
568        ));
569        results
570    };
571
572    let results: Vec<SearchResult> = ranked
573        .into_iter()
574        .map(|(idx, score)| SearchResult {
575            chunk: hybrid.chunks()[idx].clone(),
576            similarity: score,
577        })
578        .collect();
579
580    Ok(results)
581}
582
583/// Shared state for [`embed_distributed`] workers.
584struct DistributedState<'a> {
585    tokenized: &'a [Option<Encoding>],
586    cursor: std::sync::atomic::AtomicUsize,
587    error_flag: std::sync::atomic::AtomicBool,
588    first_error: std::sync::Mutex<Option<crate::Error>>,
589    done_counter: std::sync::atomic::AtomicUsize,
590    batch_size: usize,
591    profiler: &'a crate::profile::Profiler,
592}
593
594impl DistributedState<'_> {
595    /// Worker loop: claim batches from the shared cursor, embed, collect results.
596    fn run_worker(&self, backend: &dyn EmbedBackend) -> Vec<(usize, Vec<f32>)> {
597        use std::sync::atomic::Ordering;
598
599        let n = self.tokenized.len();
600        // GPU backends grab larger batches to amortize per-call overhead.
601        // MLX's lazy eval graph optimizer benefits from large matrices.
602        // Metal sub-batches internally via MAX_BATCH to limit padding waste.
603        let grab_size = if backend.is_gpu() {
604            self.batch_size * 4
605        } else {
606            self.batch_size
607        };
608        let mut results = Vec::new();
609
610        loop {
611            if self.error_flag.load(Ordering::Relaxed) {
612                break;
613            }
614
615            let start = self.cursor.fetch_add(grab_size, Ordering::Relaxed);
616            if start >= n {
617                break;
618            }
619            let end = (start + grab_size).min(n);
620            let batch = &self.tokenized[start..end];
621
622            // Separate valid encodings from Nones, tracking which indices succeeded
623            let mut valid = Vec::with_capacity(batch.len());
624            let mut valid_indices = Vec::with_capacity(batch.len());
625            for (i, enc) in batch.iter().enumerate() {
626                if let Some(e) = enc {
627                    // TODO(perf): cloning 3 Vecs per chunk; consider making
628                    // `EmbedBackend::embed_batch` accept `&[&Encoding]` to avoid this.
629                    valid.push(e.clone());
630                    valid_indices.push(start + i);
631                } else {
632                    results.push((start + i, vec![]));
633                }
634            }
635
636            if valid.is_empty() {
637                let done =
638                    self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
639                self.profiler.embed_tick(done);
640                continue;
641            }
642
643            match backend.embed_batch(&valid) {
644                Ok(batch_embeddings) => {
645                    for (idx, emb) in valid_indices.into_iter().zip(batch_embeddings) {
646                        results.push((idx, emb));
647                    }
648                    let done =
649                        self.done_counter.fetch_add(batch.len(), Ordering::Relaxed) + batch.len();
650                    self.profiler.embed_tick(done);
651                }
652                Err(e) => {
653                    self.error_flag.store(true, Ordering::Relaxed);
654                    if let Ok(mut guard) = self.first_error.lock()
655                        && guard.is_none()
656                    {
657                        *guard = Some(e);
658                    }
659                    break;
660                }
661            }
662        }
663
664        results
665    }
666}
667
668/// Distribute pre-tokenized chunks across multiple backends using work-stealing.
669///
670/// Each backend gets a dedicated worker thread. Workers compete on a shared
671/// `AtomicUsize` cursor to claim batches of chunks. GPU backends grab larger
672/// batches (`batch_size * 4`), CPU backends grab smaller ones (`batch_size`).
673/// Results are written by original chunk index — no merge step needed.
674///
675/// When `backends` has a single entry, no extra threads are spawned.
676///
677/// # Errors
678///
679/// Returns the first error from any backend. Other workers exit early
680/// when an error is detected.
681#[expect(
682    unsafe_code,
683    reason = "BLAS thread count must be set via env vars before spawning workers"
684)]
685pub(crate) fn embed_distributed(
686    tokenized: &[Option<Encoding>],
687    backends: &[&dyn EmbedBackend],
688    batch_size: usize,
689    profiler: &crate::profile::Profiler,
690) -> crate::Result<Vec<Vec<f32>>> {
691    let n = tokenized.len();
692    let state = DistributedState {
693        tokenized,
694        cursor: std::sync::atomic::AtomicUsize::new(0),
695        error_flag: std::sync::atomic::AtomicBool::new(false),
696        first_error: std::sync::Mutex::new(None),
697        done_counter: std::sync::atomic::AtomicUsize::new(0),
698        batch_size: batch_size.max(1),
699        profiler,
700    };
701
702    // Collect (index, embedding) pairs from all workers
703    let all_pairs: Vec<(usize, Vec<f32>)> =
704        if backends.len() == 1 && backends[0].supports_clone() && !backends[0].is_gpu() {
705            // Single cloneable CPU backend: spawn N workers with single-threaded BLAS.
706            //
707            // BLAS libraries (OpenBLAS, MKL) internally spawn threads for each matmul.
708            // For small matrices ([1,384]×[384,384]), this thread overhead dominates —
709            // profiling shows 80% of time in sched_yield (thread contention).
710            //
711            // Instead: force BLAS to single-thread per worker, parallelize across
712            // independent BERT inferences. Each worker gets its own cloned backend.
713            // Force BLAS libraries to single-threaded mode.
714            // We parallelize across independent BERT inferences instead.
715            // env vars don't always work (OpenBLAS may ignore after init),
716            // so also call the runtime API directly.
717            unsafe {
718                std::env::set_var("OPENBLAS_NUM_THREADS", "1");
719                std::env::set_var("MKL_NUM_THREADS", "1");
720                std::env::set_var("VECLIB_MAXIMUM_THREADS", "1"); // macOS Accelerate
721
722                // Direct FFI to set BLAS thread count — works even after init
723                #[cfg(all(not(target_os = "macos"), feature = "cpu"))]
724                {
725                    unsafe extern "C" {
726                        fn openblas_set_num_threads(num: std::ffi::c_int);
727                    }
728                    openblas_set_num_threads(1);
729                }
730            }
731
732            let num_workers = rayon::current_num_threads().max(1);
733            std::thread::scope(|s| {
734                let handles: Vec<_> = (0..num_workers)
735                    .map(|_| {
736                        s.spawn(|| {
737                            // Per-thread: force single-threaded BLAS (thread-local setting).
738                            // On macOS 15+ this calls BLASSetThreading; on Linux openblas_set_num_threads.
739                            #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
740                            crate::backend::driver::cpu::force_single_threaded_blas();
741                            let cloned = backends[0].clone_backend();
742                            state.run_worker(cloned.as_ref())
743                        })
744                    })
745                    .collect();
746                let mut all = Vec::new();
747                for handle in handles {
748                    if let Ok(pairs) = handle.join() {
749                        all.extend(pairs);
750                    }
751                }
752                all
753            })
754        } else if backends.len() == 1 {
755            // Single non-cloneable backend (GPU or CPU ModernBERT): run on the calling thread.
756            // GPU backends handle parallelism internally; CPU uses BLAS internal
757            // multi-threading (Accelerate/OpenBLAS) for intra-GEMM parallelism.
758            state.run_worker(backends[0])
759        } else {
760            // Multiple backends: one thread per backend via std::thread::scope
761            std::thread::scope(|s| {
762                let handles: Vec<_> = backends
763                    .iter()
764                    .map(|&backend| {
765                        s.spawn(|| {
766                            // CPU backends that support cloning get a thread-local copy
767                            if backend.supports_clone() {
768                                let cloned = backend.clone_backend();
769                                state.run_worker(cloned.as_ref())
770                            } else {
771                                state.run_worker(backend)
772                            }
773                        })
774                    })
775                    .collect();
776
777                let mut all = Vec::new();
778                for handle in handles {
779                    if let Ok(pairs) = handle.join() {
780                        all.extend(pairs);
781                    } else {
782                        warn!("worker thread panicked");
783                        state
784                            .error_flag
785                            .store(true, std::sync::atomic::Ordering::Relaxed);
786                    }
787                }
788                all
789            })
790        };
791
792    // Check for errors before assembling results
793    if let Some(err) = state.first_error.into_inner().ok().flatten() {
794        return Err(err);
795    }
796
797    // Scatter results into output vec by original index
798    let mut embeddings: Vec<Vec<f32>> = vec![vec![]; n];
799    for (idx, emb) in all_pairs {
800        embeddings[idx] = emb;
801    }
802
803    Ok(embeddings)
804}
805
806/// Read a source file into a `String`, skipping binary files.
807///
808/// Reads the file as raw bytes first, checks for NUL bytes in the first 8 KB
809/// to detect binary files, then converts to UTF-8. Returns `None` (with a
810/// debug log) when the file cannot be read, is binary, or is not valid UTF-8.
811pub(crate) fn read_source(path: &Path) -> Option<String> {
812    let bytes = match std::fs::read(path) {
813        Ok(b) => b,
814        Err(e) => {
815            debug!(path = %path.display(), err = %e, "skipping file: read failed");
816            return None;
817        }
818    };
819
820    // Skip binary files: NUL byte anywhere in the first 8 KB is a reliable signal.
821    if memchr::memchr(0, &bytes[..bytes.len().min(8192)]).is_some() {
822        debug!(path = %path.display(), "skipping binary file");
823        return None;
824    }
825
826    match std::str::from_utf8(&bytes) {
827        Ok(s) => Some(s.to_string()),
828        Err(e) => {
829            debug!(path = %path.display(), err = %e, "skipping file: not valid UTF-8");
830            None
831        }
832    }
833}
834
835/// Tokenize text into an [`Encoding`] ready for model inference.
836///
837/// Delegates to [`crate::tokenize::tokenize_query`] for the core encoding,
838/// then applies an additional `max_tokens` truncation when non-zero.
839/// CLS pooling means the first token's representation carries most semantic
840/// weight, so truncation has minimal quality impact.
841fn tokenize(
842    text: &str,
843    tokenizer: &tokenizers::Tokenizer,
844    max_tokens: usize,
845    model_max_tokens: usize,
846) -> crate::Result<Encoding> {
847    let mut enc = crate::tokenize::tokenize_query(text, tokenizer, model_max_tokens)?;
848    if max_tokens > 0 {
849        let len = enc.input_ids.len().min(max_tokens);
850        enc.input_ids.truncate(len);
851        enc.attention_mask.truncate(len);
852        enc.token_type_ids.truncate(len);
853    }
854    Ok(enc)
855}
856
857/// Normalize similarity scores to `[0,1]` and apply a `PageRank` structural boost.
858///
859/// Each result's similarity is min-max normalized, then a weighted `PageRank`
860/// score is added: `final = normalized + alpha * pagerank`. This promotes
861/// architecturally important files (many dependents) in search results.
862///
863/// Called from the MCP search handler which has access to the `RepoGraph`,
864/// rather than from [`search`] directly.
865pub fn apply_structural_boost<S: ::std::hash::BuildHasher>(
866    results: &mut [SearchResult],
867    file_ranks: &std::collections::HashMap<String, f32, S>,
868    alpha: f32,
869) {
870    if results.is_empty() || alpha == 0.0 {
871        return;
872    }
873
874    let min = results
875        .iter()
876        .map(|r| r.similarity)
877        .fold(f32::INFINITY, f32::min);
878    let max = results
879        .iter()
880        .map(|r| r.similarity)
881        .fold(f32::NEG_INFINITY, f32::max);
882    let range = (max - min).max(1e-12);
883
884    for r in results.iter_mut() {
885        let normalized = (r.similarity - min) / range;
886        let pr = file_ranks.get(&r.chunk.file_path).copied().unwrap_or(0.0);
887        r.similarity = normalized + alpha * pr;
888    }
889}
890
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "cpu")]
    #[ignore = "loads model + embeds full source tree; run with `cargo test -- --ignored`"]
    fn search_with_backend_trait() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();
        let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src");
        let results = search(
            &dir,
            "embedding model",
            &[backend.as_ref()],
            &tokenizer,
            1,
            &cfg,
            &profiler,
        );
        assert!(results.is_ok());
        assert!(!results.unwrap().is_empty());
    }

    #[test]
    #[cfg(feature = "cpu")]
    fn embed_distributed_produces_correct_count() {
        let backend = crate::backend::load_backend(
            crate::backend::BackendKind::Cpu,
            "BAAI/bge-small-en-v1.5",
            crate::backend::DeviceHint::Cpu,
        )
        .unwrap();
        let tokenizer = crate::tokenize::load_tokenizer("BAAI/bge-small-en-v1.5").unwrap();
        let profiler = crate::profile::Profiler::noop();

        // Tokenize a few strings
        let texts = ["fn hello() {}", "class Foo:", "func main() {}"];
        let encoded: Vec<Option<Encoding>> = texts
            .iter()
            .map(|t| super::tokenize(t, &tokenizer, 0, 512).ok())
            .collect();

        let results =
            super::embed_distributed(&encoded, &[backend.as_ref()], 32, &profiler).unwrap();

        assert_eq!(results.len(), 3);
        // All should be 384-dim (bge-small hidden size)
        for (i, emb) in results.iter().enumerate() {
            assert_eq!(emb.len(), 384, "embedding {i} should be 384-dim");
        }
    }

    /// Truncate an embedding to `dims` dimensions and L2-normalize.
    fn truncate_and_normalize(emb: &[f32], dims: usize) -> Vec<f32> {
        let trunc = &emb[..dims];
        let norm: f32 = trunc.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        trunc.iter().map(|x| x / norm).collect()
    }

    /// Rank corpus embeddings against a query, return top-K chunk indices.
    fn rank_topk(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<usize> {
        let mut scored: Vec<(usize, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(i, emb)| {
                let dot: f32 = query.iter().zip(emb).map(|(a, b)| a * b).sum();
                (i, dot)
            })
            .collect();
        scored.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        scored.into_iter().take(k).map(|(i, _)| i).collect()
    }

    /// MRL retrieval recall test: does truncated search retrieve the same results?
    ///
    /// Embeds the ripvec codebase at full dimension, then tests whether
    /// truncating to fewer dimensions retrieves the same top-10 results.
    /// This is the real MRL quality test — per-vector cosine is trivially 1.0
    /// but retrieval recall can degrade if the first N dims don't preserve
    /// relative ordering between different vectors.
    #[test]
    #[ignore = "loads model + embeds; run with --nocapture"]
    #[expect(
        clippy::cast_precision_loss,
        reason = "top_k and overlap are small counts"
    )]
    fn mrl_retrieval_recall() {
        let model = "BAAI/bge-small-en-v1.5";
        let backends = crate::backend::detect_backends(model).unwrap();
        let tokenizer = crate::tokenize::load_tokenizer(model).unwrap();
        let cfg = SearchConfig::default();
        let profiler = crate::profile::Profiler::noop();

        // Embed the ripvec source tree
        let root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap()
            .parent()
            .unwrap();
        eprintln!("Embedding {}", root.display());
        let backend_refs: Vec<&dyn crate::backend::EmbedBackend> =
            backends.iter().map(std::convert::AsRef::as_ref).collect();
        let (chunks, embeddings) =
            embed_all(root, &backend_refs, &tokenizer, &cfg, &profiler).unwrap();
        let full_dim = embeddings[0].len();
        eprintln!(
            "Corpus: {} chunks, {full_dim}-dim embeddings\n",
            chunks.len()
        );

        // Test queries spanning different semantic intents
        let queries = [
            "error handling in the embedding pipeline",
            "tree-sitter chunking and AST parsing",
            "Metal GPU kernel dispatch",
            "file watcher for incremental reindex",
            "cosine similarity ranking",
        ];

        let top_k = 10;
        let mrl_dims: Vec<usize> = [32, 64, 128, 192, 256, full_dim]
            .into_iter()
            .filter(|&d| d <= full_dim)
            .collect();

        // The truncated corpus depends only on `dims`, not on the query, so
        // build each one once up front instead of once per (query, dims) pair.
        let trunc_corpora: Vec<(usize, Vec<Vec<f32>>)> = mrl_dims
            .iter()
            .map(|&dims| {
                let corpus = embeddings
                    .iter()
                    .map(|e| truncate_and_normalize(e, dims))
                    .collect();
                (dims, corpus)
            })
            .collect();

        eprintln!("=== MRL Retrieval Recall@{top_k} (vs full {full_dim}-dim) ===\n");

        for query in &queries {
            // Embed query at full dim
            let enc = tokenize(query, &tokenizer, 0, backends[0].max_tokens()).unwrap();
            let query_emb = backends[0].embed_batch(&[enc]).unwrap().pop().unwrap();

            // Full-dim reference ranking
            let ref_topk = rank_topk(&query_emb, &embeddings, top_k);

            eprintln!("Query: \"{query}\"");
            eprintln!(
                "  Full-dim top-1: {} ({})",
                chunks[ref_topk[0]].name, chunks[ref_topk[0]].file_path
            );

            for (dims, trunc_corpus) in &trunc_corpora {
                let dims = *dims;
                // Only the query needs truncating per iteration.
                let trunc_query = truncate_and_normalize(&query_emb, dims);

                let trunc_topk = rank_topk(&trunc_query, trunc_corpus, top_k);

                // Recall@K: how many of the full-dim top-K appear in truncated top-K
                let overlap = ref_topk.iter().filter(|i| trunc_topk.contains(i)).count();
                let recall = overlap as f32 / top_k as f32;
                let marker = if dims == full_dim {
                    " (ref)"
                } else if recall >= 0.8 {
                    " ***"
                } else {
                    ""
                };
                eprintln!(
                    "  dims={dims:>3}: Recall@{top_k}={recall:.1} ({overlap}/{top_k}){marker}"
                );
            }
            eprintln!();
        }
    }

    fn make_result(file_path: &str, similarity: f32) -> SearchResult {
        SearchResult {
            chunk: CodeChunk {
                file_path: file_path.to_string(),
                name: "test".to_string(),
                kind: "function".to_string(),
                start_line: 1,
                end_line: 10,
                enriched_content: String::new(),
                content: String::new(),
            },
            similarity,
        }
    }

    #[test]
    fn structural_boost_normalizes_and_applies() {
        let mut results = vec![
            make_result("src/a.rs", 0.8),
            make_result("src/b.rs", 0.4),
            make_result("src/c.rs", 0.6),
        ];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 0.5);
        ranks.insert("src/b.rs".to_string(), 1.0);
        ranks.insert("src/c.rs".to_string(), 0.0);

        apply_structural_boost(&mut results, &ranks, 0.2);

        // a: normalized=(0.8-0.4)/0.4=1.0, boost=0.2*0.5=0.1 => 1.1
        assert!((results[0].similarity - 1.1).abs() < 1e-6);
        // b: normalized=(0.4-0.4)/0.4=0.0, boost=0.2*1.0=0.2 => 0.2
        assert!((results[1].similarity - 0.2).abs() < 1e-6);
        // c: normalized=(0.6-0.4)/0.4=0.5, boost=0.2*0.0=0.0 => 0.5
        assert!((results[2].similarity - 0.5).abs() < 1e-6);
    }

    #[test]
    fn structural_boost_noop_on_empty() {
        let mut results: Vec<SearchResult> = vec![];
        let ranks = std::collections::HashMap::new();
        apply_structural_boost(&mut results, &ranks, 0.2);
        assert!(results.is_empty());
    }

    #[test]
    fn structural_boost_noop_on_zero_alpha() {
        let mut results = vec![make_result("src/a.rs", 0.8)];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 1.0);
        apply_structural_boost(&mut results, &ranks, 0.0);
        // Should be unchanged
        assert!((results[0].similarity - 0.8).abs() < 1e-6);
    }

    #[test]
    fn structural_boost_equal_similarities_and_unranked_files() {
        // All similarities equal: the clamped range makes every normalized
        // score 0.0, leaving only the PageRank term. Unranked files get 0.0.
        let mut results = vec![
            make_result("src/a.rs", 0.5),
            make_result("src/missing.rs", 0.5),
        ];
        let mut ranks = std::collections::HashMap::new();
        ranks.insert("src/a.rs".to_string(), 1.0);

        apply_structural_boost(&mut results, &ranks, 0.2);

        assert!((results[0].similarity - 0.2).abs() < 1e-6);
        assert!(results[1].similarity.abs() < 1e-6);
    }

    /// Write `bytes` to a per-process temp file named after `name`.
    fn write_temp(name: &str, bytes: &[u8]) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "ripvec-embed-test-{}-{name}",
            std::process::id()
        ));
        std::fs::write(&path, bytes).unwrap();
        path
    }

    #[test]
    fn read_source_accepts_utf8_text() {
        let path = write_temp("text.rs", b"fn main() {}\n");
        let got = read_source(&path);
        std::fs::remove_file(&path).unwrap();
        assert_eq!(got.as_deref(), Some("fn main() {}\n"));
    }

    #[test]
    fn read_source_skips_binary() {
        let path = write_temp("binary.bin", b"abc\0def");
        let got = read_source(&path);
        std::fs::remove_file(&path).unwrap();
        assert!(got.is_none());
    }

    #[test]
    fn read_source_only_sniffs_first_8k_for_nul() {
        // A NUL past the 8 KB window is not treated as a binary marker.
        let mut bytes = vec![b'a'; 8192];
        bytes.push(0);
        let path = write_temp("late-nul.txt", &bytes);
        let got = read_source(&path);
        std::fs::remove_file(&path).unwrap();
        assert_eq!(got.map(|s| s.len()), Some(8193));
    }

    #[test]
    fn read_source_skips_invalid_utf8() {
        let path = write_temp("latin1.txt", b"caf\xe9");
        let got = read_source(&path);
        std::fs::remove_file(&path).unwrap();
        assert!(got.is_none());
    }

    #[test]
    fn read_source_missing_file_is_none() {
        let path = std::env::temp_dir().join(format!(
            "ripvec-embed-test-{}-does-not-exist",
            std::process::id()
        ));
        assert!(read_source(&path).is_none());
    }
}