Skip to main content

ripvec_core/backend/
mod.rs

1//! Embedding backend abstraction layer.
2//!
3//! Defines the [`EmbedBackend`] trait that all embedding backends (CPU, CUDA,
4//! Metal, MLX) implement, plus the [`Encoding`] input type and [`BackendKind`]
5//! discriminant. Use [`load_backend`] to construct a backend by kind.
6
7pub mod arch;
8pub mod blas_info;
9// `cpu` covers both ClassicBert embedding and cross-encoder rerank;
10// both want BLAS from either openblas-src (`feature = "cpu"`) or
11// Accelerate (`feature = "cpu-accelerate"`). Widened from
12// `cfg(feature = "cpu")` so the macOS default build (which uses
13// `cpu-accelerate`) gets the reranker.
14#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
15pub mod cpu;
16#[cfg(feature = "cuda")]
17pub mod cuda;
18pub mod driver;
19pub mod generic;
20#[cfg(feature = "metal")]
21pub mod metal_kernels;
22#[cfg(feature = "mlx")]
23pub mod mlx;
24#[cfg(feature = "cuda")]
25pub mod nvrtc_cubin;
26
/// Pre-tokenized encoding ready for inference.
///
/// `input_ids`, `attention_mask`, and `token_type_ids` are parallel vectors
/// and must all have the same length (one entry per token). Sequences are
/// truncated by the tokenizer to the target model's limit before reaching
/// the backend — see [`EmbedBackend::max_tokens`]: 512 for `ClassicBert`,
/// taken from the model config for `ModernBERT`.
///
/// The same type carries query/document pairs to a [`RerankBackend`]; there
/// `token_type_ids` marks the query side with 0 and the doc side with 1.
#[derive(Debug, Clone)]
pub struct Encoding {
    /// Token IDs produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask (1 for real tokens, 0 for padding).
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 for single-sequence models; 0/1 for the query/doc
    /// halves of a cross-encoder pair).
    pub token_type_ids: Vec<i64>,
}
41
42/// Trait for embedding backends.
43///
44/// Implementations must be [`Send`] so they can be moved across threads (e.g.
45/// into a ring-buffer pipeline). The trait is object-safe — callers use
46/// `&dyn EmbedBackend` or `Box<dyn EmbedBackend>`.
47///
48/// # GPU vs CPU scheduling
49///
50/// - **CPU backends** (`is_gpu() == false`): cloned per rayon thread via
51///   [`clone_backend`](EmbedBackend::clone_backend).
52/// - **GPU backends** (`is_gpu() == true`): use a ring-buffer pipeline with
53///   `RING_SIZE = 4` for bounded memory.
54pub trait EmbedBackend: Send + Sync {
55    /// Embed a batch of pre-tokenized inputs, returning L2-normalized vectors.
56    ///
57    /// Each inner `Vec<f32>` is the embedding for the corresponding
58    /// [`Encoding`]. Errors **must** propagate — never silently return
59    /// defaults.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if tensor construction or the forward pass fails.
64    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
65
66    /// Whether this backend supports cheap cloning for per-thread instances.
67    ///
68    /// CPU backends return `true`; GPU backends typically return `false`.
69    fn supports_clone(&self) -> bool;
70
71    /// Create a cheap clone of this backend for per-thread use in rayon.
72    ///
73    /// # Panics
74    ///
75    /// May panic if [`supports_clone`](EmbedBackend::supports_clone) returns
76    /// `false`. Callers must check `supports_clone()` first.
77    fn clone_backend(&self) -> Box<dyn EmbedBackend>;
78
79    /// Whether this backend runs on a GPU.
80    ///
81    /// GPU backends use a ring-buffer pipelined scheduler (`RING_SIZE = 4`)
82    /// for bounded memory usage.
83    fn is_gpu(&self) -> bool;
84
85    /// Maximum token count this model supports (position embedding limit).
86    ///
87    /// `ClassicBert`: 512. `ModernBERT`: up to model config. Tokens beyond this
88    /// are truncated during tokenization.
89    fn max_tokens(&self) -> usize {
90        512 // default for classic BERT models
91    }
92
93    /// Short human-readable label for this backend (e.g. "Metal", "CUDA",
94    /// "CPU (Accelerate)", "MLX"). Used by diagnostics and the `up_to_date`
95    /// MCP tool. Defaults to "GPU" or "CPU" based on [`is_gpu`].
96    fn name(&self) -> &'static str {
97        if self.is_gpu() { "GPU" } else { "CPU" }
98    }
99}
100
101/// Trait for cross-encoder rerank backends.
102///
103/// Parallel to [`EmbedBackend`], but the forward pass terminates in a
104/// scalar relevance score per pair instead of a pooled vector. Used by
105/// the retrieve-then-rerank pipeline: a bi-encoder ([`EmbedBackend`])
106/// retrieves top-K cheaply, then [`RerankBackend`] re-scores those K
107/// candidates with the cross-encoder's higher-quality cross-attention
108/// over the concatenated `[CLS] query [SEP] doc [SEP]` sequence.
109///
110/// # Why a separate trait
111///
112/// Cross-encoders share BERT's trunk with bi-encoders, but the head and
113/// pooling differ: bi-encoder = CLS pool + L2-normalize, cross-encoder
114/// = CLS pool + linear(hidden → 1) + sigmoid. The two return shapes are
115/// incompatible (`Vec<Vec<f32>>` vs `Vec<f32>`), so unifying them under
116/// a single trait would force every caller to handle an awkward sum
117/// type. Sibling traits keep both call sites direct.
118pub trait RerankBackend: Send + Sync {
119    /// Score a batch of pre-tokenized pairs and return one score per
120    /// encoding. Scores are sigmoid-activated and lie in `[0, 1]`.
121    ///
122    /// The encoding's `token_type_ids` should mark the query side as
123    /// 0 and the doc side as 1 (standard BERT pair convention); this
124    /// is what `tokenizers::Tokenizer::encode((query, doc), ..)`
125    /// produces.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if tensor construction or the forward pass fails.
130    fn score_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<f32>>;
131
132    /// Maximum token count this model supports.
133    fn max_tokens(&self) -> usize {
134        512
135    }
136
137    /// Whether this backend runs on a GPU.
138    fn is_gpu(&self) -> bool;
139
140    /// Short human-readable label for this backend.
141    fn name(&self) -> &'static str {
142        if self.is_gpu() { "GPU" } else { "CPU" }
143    }
144}
145
/// Selects which embedding backend implementation [`load_backend`] builds.
///
/// Defaults to [`BackendKind::Cpu`]. Its `Display` form is the lowercase
/// backend name: `"cuda"`, `"mlx"`, `"cpu"` or `"metal"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    /// NVIDIA GPUs through cudarc: cuBLAS plus custom kernels.
    Cuda,
    /// Apple Silicon through MLX (macOS only).
    Mlx,
    /// Portable CPU path: ndarray on top of the system BLAS.
    #[default]
    Cpu,
    /// Apple Silicon driven directly through Metal (macOS only).
    Metal,
}

impl std::fmt::Display for BackendKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Cuda => "cuda",
            Self::Mlx => "mlx",
            Self::Cpu => "cpu",
            Self::Metal => "metal",
        };
        f.write_str(label)
    }
}
170
/// Device hint passed to [`load_backend`].
///
/// Backends map this to their native device type. Not every backend supports
/// every device; whether an unsupported combination is rejected or falls
/// back to another device is up to the individual backend.
///
/// Derives `PartialEq`/`Eq` (like [`BackendKind`]) so callers can compare
/// hints directly instead of pattern matching.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum DeviceHint {
    /// Automatically select the best available device.
    #[default]
    Auto,
    /// Force CPU inference.
    Cpu,
    /// Force GPU inference (Metal on macOS, CUDA on Linux/Windows).
    Gpu,
}
185
/// Load-time inference tuning knobs forwarded to model loading.
///
/// Intentionally empty for now: this is a placeholder for future load-time
/// options such as quantization or distillation settings.
#[derive(Clone, Debug, Default)]
pub struct InferenceOpts {}
192
193/// Construct an embedding backend of the given kind.
194///
195/// Downloads model weights on first use via `hf-hub`. The `device_hint`
196/// is advisory — backends that don't support GPU fall back to CPU.
197///
198/// # Errors
199///
200/// Returns an error if the requested backend was not compiled in (missing
201/// feature flag) or if model loading fails.
202pub fn load_backend(
203    kind: BackendKind,
204    #[cfg_attr(
205        not(any(
206            feature = "cuda",
207            feature = "mlx",
208            feature = "cpu",
209            feature = "cpu-accelerate",
210            feature = "metal"
211        )),
212        expect(unused_variables, reason = "used when backend features are enabled")
213    )]
214    model_repo: &str,
215    #[cfg_attr(
216        not(any(
217            feature = "cuda",
218            feature = "mlx",
219            feature = "cpu",
220            feature = "cpu-accelerate",
221            feature = "metal"
222        )),
223        expect(unused_variables, reason = "used when backend features are enabled")
224    )]
225    device_hint: DeviceHint,
226) -> crate::Result<Box<dyn EmbedBackend>> {
227    match kind {
228        #[cfg(feature = "cuda")]
229        BackendKind::Cuda => {
230            if is_modernbert_model(model_repo) {
231                return load_modernbert_cuda(model_repo);
232            }
233            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
234            Ok(Box::new(backend))
235        }
236        #[cfg(not(feature = "cuda"))]
237        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
238            "cuda backend requires building with: cargo build --features cuda"
239        ))),
240        #[cfg(feature = "mlx")]
241        BackendKind::Mlx => {
242            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
243            Ok(Box::new(backend))
244        }
245        #[cfg(not(feature = "mlx"))]
246        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
247            "mlx backend requires building with: cargo build --features mlx"
248        ))),
249        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
250        BackendKind::Cpu => {
251            if is_modernbert_model(model_repo) {
252                return load_modernbert_cpu(model_repo);
253            }
254            #[cfg(feature = "cpu")]
255            {
256                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
257                #[expect(
258                    clippy::needless_return,
259                    reason = "return needed before cfg(not) fallback"
260                )]
261                return Ok(Box::new(backend));
262            }
263            #[cfg(not(feature = "cpu"))]
264            Err(crate::Error::Other(anyhow::anyhow!(
265                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
266            )))
267        }
268        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
269        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
270            "cpu backend requires building with: cargo build --features cpu"
271        ))),
272        #[cfg(feature = "metal")]
273        BackendKind::Metal => {
274            // All models route through the driver/arch system.
275            if is_modernbert_model(model_repo) {
276                return load_modernbert_metal(model_repo);
277            }
278            load_classic_metal(model_repo)
279        }
280        #[cfg(not(feature = "metal"))]
281        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
282            "metal backend requires building with: cargo build --features metal"
283        ))),
284    }
285}
286
287/// Detect all available backends and load them.
288///
289/// Probes for GPU backends (CUDA, MLX) first, then falls back to CPU.
290/// Returns backends in priority order — the first entry is the primary
291/// (used for query embedding in interactive mode).
292///
293/// # Errors
294///
295/// Returns an error if no backends can be loaded (not even CPU).
296pub fn detect_backends(
297    #[cfg_attr(
298        not(any(
299            feature = "cuda",
300            feature = "mlx",
301            feature = "cpu",
302            feature = "cpu-accelerate",
303            feature = "metal"
304        )),
305        expect(unused_variables, reason = "used when backend features are enabled")
306    )]
307    model_repo: &str,
308) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
309    #[cfg_attr(
310        not(any(
311            feature = "cuda",
312            feature = "mlx",
313            feature = "cpu",
314            feature = "cpu-accelerate",
315            feature = "metal"
316        )),
317        expect(unused_mut, reason = "mut needed when backend features are enabled")
318    )]
319    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
320
321    // Try CUDA (NVIDIA GPU)
322    #[cfg(feature = "cuda")]
323    {
324        if is_modernbert_model(model_repo) {
325            if let Ok(b) = load_modernbert_cuda(model_repo) {
326                backends.push(b);
327            }
328        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
329            backends.push(Box::new(b));
330        }
331    }
332
333    // Try Metal (Apple Silicon GPU, preferred over MLX)
334    #[cfg(feature = "metal")]
335    {
336        // Route models through the driver/arch system by architecture.
337        if is_modernbert_model(model_repo) {
338            if let Ok(b) = load_modernbert_metal(model_repo) {
339                backends.push(b);
340            }
341        } else if let Ok(b) = load_classic_metal(model_repo) {
342            backends.push(b);
343        }
344    }
345
346    // Try MLX (Apple Silicon GPU, fallback if Metal unavailable)
347    #[cfg(feature = "mlx")]
348    if backends.is_empty()
349        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
350    {
351        backends.push(Box::new(b));
352    }
353
354    // Add CPU as fallback only when no GPU backend was loaded.
355    // On Apple Silicon, running CPU + MLX concurrently is slower than
356    // MLX alone because they share the same physical cores and memory.
357    // On discrete GPU systems (CUDA), CPU would be a useful helper.
358    #[cfg_attr(
359        not(any(feature = "cpu", feature = "cpu-accelerate")),
360        expect(unused_variables, reason = "used when cpu feature is enabled")
361    )]
362    let has_gpu = backends.iter().any(|b| b.is_gpu());
363    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
364    if !has_gpu {
365        if is_modernbert_model(model_repo) {
366            if let Ok(b) = load_modernbert_cpu(model_repo) {
367                backends.push(b);
368            }
369        } else {
370            #[cfg(feature = "cpu")]
371            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
372                backends.push(Box::new(b));
373            }
374        }
375    }
376
377    if backends.is_empty() {
378        return Err(crate::Error::Other(anyhow::anyhow!(
379            "no embedding backends available"
380        )));
381    }
382
383    Ok(backends)
384}
385
// ---------------------------------------------------------------------------
// Shared Hub download for the driver/arch loaders
// ---------------------------------------------------------------------------

/// Fetch `config.json` and `model.safetensors` for `model_repo` from the
/// Hugging Face Hub (cached after the first download) and parse the config
/// into a JSON value.
///
/// Every driver/arch loader below starts with exactly this sequence, so it
/// lives here once to keep their error mapping identical.
///
/// # Errors
///
/// - `crate::Error::Download` if the Hub client cannot be created or either
///   file cannot be fetched.
/// - `crate::Error::Io` if the downloaded config cannot be read.
/// - `crate::Error::Other` if the config is not valid JSON.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn fetch_model_files(
    model_repo: &str,
) -> crate::Result<(serde_json::Value, std::path::PathBuf)> {
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;

    Ok((config_json, weights_path))
}

// ---------------------------------------------------------------------------
// ModernBERT loader (driver/arch system)
// ---------------------------------------------------------------------------

/// Load a `ModernBERT` model on the Metal GPU backend.
///
/// Downloads the model from Hugging Face Hub (cached after first download),
/// memory-maps the safetensors weights, and builds a [`GenericBackend`]
/// pairing a [`MetalDriver`](driver::metal::MetalDriver) with a
/// [`ModernBertArch`](arch::modern_bert::ModernBertArch).
///
/// # Errors
///
/// Returns an error if no Metal device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;

    let (config_json, weights_path) = fetch_model_files(model_repo)?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}

/// Load `ModernBERT` on CPU via the driver/arch system.
///
/// Downloads the model from Hugging Face Hub (cached after first download),
/// loads the safetensors weights through a
/// [`CpuDriver`](driver::cpu::CpuDriver), and wraps the result in a shared
/// [`GenericBackend`](generic::GenericBackend).
///
/// # Errors
///
/// Returns an error if the model cannot be downloaded, its config cannot be
/// parsed, or weight loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;

    let (config_json, weights_path) = fetch_model_files(model_repo)?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}

/// Load `ModernBERT` on CUDA via the driver/arch system.
///
/// Creates a [`CudaDriver`](driver::cuda::CudaDriver), loads safetensors weights
/// onto the GPU, pre-converts GEMM weights to FP16, builds RoPE caches, and
/// wraps the result in a [`GenericBackend`](generic::GenericBackend).
///
/// # Errors
///
/// Returns an error if no CUDA device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;

    let (config_json, weights_path) = fetch_model_files(model_repo)?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // Batches are capped at 32 sequences. Larger batches may better saturate
    // the SMs of big GPUs, but 32 is the value currently shipped; raise it
    // only after profiling memory headroom on the target card.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        32,
    )))
}

/// Check whether a model repo uses the `ModernBERT` architecture.
///
/// Downloads only `config.json` (not the weights) and checks for
/// `"model_type": "modernbert"`. Returns `false` on any download or parse
/// error (fail-open for detection).
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    let Ok(api) = hf_hub::api::sync::Api::new() else {
        return false;
    };
    let repo = api.model(model_repo.to_string());
    let Ok(config_path) = repo.get("config.json") else {
        return false;
    };
    let Ok(config_str) = std::fs::read_to_string(&config_path) else {
        return false;
    };
    let Ok(json) = serde_json::from_str::<serde_json::Value>(&config_str) else {
        return false;
    };
    json.get("model_type")
        .and_then(serde_json::Value::as_str)
        .is_some_and(|t| t == "modernbert")
}

// ---------------------------------------------------------------------------
// ClassicBert loader (driver/arch system)
// ---------------------------------------------------------------------------

/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the Metal GPU backend.
///
/// Downloads the model from Hugging Face Hub (cached after first download),
/// memory-maps the safetensors weights, fuses Q/K/V into a single tensor per
/// layer, and builds a [`GenericBackend`] pairing a
/// [`MetalDriver`](driver::metal::MetalDriver) with a
/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch).
///
/// # Errors
///
/// Returns an error if no Metal device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;

    let (config_json, weights_path) = fetch_model_files(model_repo)?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}

// ---------------------------------------------------------------------------
// ClassicBert loader (CPU driver/arch system)
// ---------------------------------------------------------------------------

/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the CPU backend
/// via the driver/arch system.
///
/// Downloads the model from Hugging Face Hub (cached after first download),
/// loads the safetensors weights, fuses Q/K/V per layer, and builds a
/// [`GenericBackend`] pairing a [`CpuDriver`](driver::cpu::CpuDriver) with a
/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch).
///
/// # Errors
///
/// Returns an error if the model cannot be downloaded or weight loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;

    let (config_json, weights_path) = fetch_model_files(model_repo)?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
697
698/// Load a cross-encoder rerank model for CPU inference.
699///
700/// MS-MARCO family rerankers (the default
701/// `cross-encoder/ms-marco-MiniLM-L-6-v2`) are ClassicBert-shaped, so
702/// they route through [`cpu::CpuRerankBackend`] — same trunk as
703/// [`load_classic_cpu`], plus a `Linear(hidden -> 1)` classifier head.
704///
705/// Not feature-gated like the embedding backends: the rerank path is
706/// load-bearing for the document-search use case (cacheless prose
707/// queries) and must work in the default build. The underlying
708/// `CpuRerankBackend` uses the same ndarray BLAS setup as
709/// `CpuBackend`, so it works wherever the CPU embedding backend
710/// does — `feature = "cpu"` or `feature = "cpu-accelerate"`.
711///
712/// # Errors
713///
714/// Returns an error if the model cannot be downloaded, if it lacks a
715/// classifier head (i.e., the caller pointed at a bi-encoder by
716/// mistake), or if the weights fail to parse.
717#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
718pub fn load_reranker_cpu(model_repo: &str) -> crate::Result<Box<dyn RerankBackend>> {
719    let backend = cpu::CpuRerankBackend::load(model_repo)?;
720    Ok(Box::new(backend))
721}
722
723#[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
724pub fn load_reranker_cpu(_model_repo: &str) -> crate::Result<Box<dyn RerankBackend>> {
725    Err(crate::Error::Other(anyhow::anyhow!(
726        "cross-encoder rerank requires building with --features cpu \
727         or --features cpu-accelerate"
728    )))
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734
735    /// Verify that `EmbedBackend` is object-safe by constructing a trait object type.
736    #[test]
737    fn trait_is_object_safe() {
738        // If this compiles, the trait is object-safe.
739        fn _assert_object_safe(_: &dyn EmbedBackend) {}
740    }
741
742    /// Verify that `Box<dyn EmbedBackend>` is `Send`.
743    #[test]
744    fn trait_object_is_send() {
745        fn assert_send<T: Send>() {}
746        assert_send::<Box<dyn EmbedBackend>>();
747    }
748
749    /// Verify that `Box<dyn EmbedBackend>` is `Sync` (needed for `&dyn` across threads).
750    #[test]
751    fn trait_object_is_sync() {
752        fn assert_sync<T: Sync>() {}
753        assert_sync::<Box<dyn EmbedBackend>>();
754    }
755
756    /// Verify that `Arc<dyn EmbedBackend>` is `Send` (needed for ring-buffer pipeline).
757    #[test]
758    fn arc_trait_object_is_send() {
759        fn assert_send<T: Send>() {}
760        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
761    }
762
763    #[test]
764    fn encoding_construction() {
765        let enc = Encoding {
766            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
767            attention_mask: vec![1, 1, 1, 1, 1, 1],
768            token_type_ids: vec![0, 0, 0, 0, 0, 0],
769        };
770        assert_eq!(enc.input_ids.len(), 6);
771        assert_eq!(enc.attention_mask.len(), 6);
772        assert_eq!(enc.token_type_ids.len(), 6);
773    }
774
775    #[test]
776    fn encoding_clone() {
777        let enc = Encoding {
778            input_ids: vec![101, 102],
779            attention_mask: vec![1, 1],
780            token_type_ids: vec![0, 0],
781        };
782        let cloned = enc.clone();
783        assert_eq!(enc.input_ids, cloned.input_ids);
784    }
785
786    #[test]
787    fn backend_kind_default_is_cpu() {
788        assert_eq!(BackendKind::default(), BackendKind::Cpu);
789    }
790
791    #[test]
792    fn backend_kind_display() {
793        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
794        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
795        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
796    }
797
798    #[cfg(not(feature = "mlx"))]
799    #[test]
800    fn load_backend_mlx_not_compiled() {
801        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
802        assert!(result.is_err());
803    }
804
805    #[cfg(feature = "cpu")]
806    #[test]
807    fn detect_backends_returns_at_least_one() {
808        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
809        assert!(!backends.is_empty());
810    }
811
812    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
813    #[test]
814    fn detect_backends_returns_at_least_one_backend() {
815        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
816        assert!(!backends.is_empty(), "should detect at least one backend");
817    }
818
819    /// Load `ModernBERT` on Metal and embed a short token sequence.
820    ///
821    /// Verifies that the full pipeline (weight loading, forward pass, pooling,
822    /// L2 normalization) produces a 768-dim unit vector.
823    #[cfg(feature = "metal")]
824    #[test]
825    #[ignore = "requires model download (~570MB)"]
826    #[expect(clippy::too_many_lines, reason = "end-to-end backend diagnostic test")]
827    fn modernbert_loads_and_embeds() {
828        use crate::backend::arch::ModelArch;
829        use crate::backend::driver::Driver;
830
831        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
832        assert!(backend.is_gpu(), "Metal backend should be GPU");
833
834        let enc = Encoding {
835            input_ids: vec![1, 100, 200, 300, 2],
836            attention_mask: vec![1; 5],
837            token_type_ids: vec![0; 5],
838        };
839
840        // Stage-by-stage diagnostic using the driver directly
841        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
842        let inputs = driver.prepare_batch(std::slice::from_ref(&enc), 8).unwrap();
843
844        // Check: can we read back input_ids?
845        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
846        eprintln!("input_ids: {:?}", &ids_host[0][..5]);
847
848        // Check: embedding lookup
849        // Need the tok_embeddings weight — load weights directly
850        let api = hf_hub::api::sync::Api::new().unwrap();
851        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
852        let weights_path = repo.get("model.safetensors").unwrap();
853        let config_path = repo.get("config.json").unwrap();
854        let config_str = std::fs::read_to_string(&config_path).unwrap();
855        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
856        let config =
857            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
858        let (arch, _mmap) = driver
859            .load_modern_bert_weights(&weights_path, &config)
860            .unwrap();
861
862        let hidden = driver
863            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
864            .unwrap();
865        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
866        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
867        eprintln!(
868            "embedding: {nz}/{} nonzero, first 5: {:?}",
869            h[0].len(),
870            &h[0][..5]
871        );
872
873        // Stage-by-stage forward pass bisection
874        let total = 8; // padded seq
875        let hd = 768;
876        let nh = 12;
877        let head_dim = 64;
878
879        // After embedding LN
880        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
881        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
882        driver
883            .layer_norm(
884                &mut ln_out,
885                &emb_clone,
886                &arch.weights.emb_norm_weight,
887                &arch.weights.zero_bias,
888                total,
889                hd,
890                1e-5,
891            )
892            .unwrap();
893        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
894        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
895        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);
896
897        // Layer 0 QKV GEMM
898        let layer0 = &arch.weights.layers[0];
899        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
900        driver
901            .gemm(
902                &ln_out,
903                &layer0.qkv_weight,
904                &mut qkv,
905                total,
906                3 * hd,
907                hd,
908                true,
909            )
910            .unwrap();
911        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
912        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
913        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);
914
915        // QKV split
916        let mut q = driver.alloc_zeros(total * hd).unwrap();
917        let mut k = driver.alloc_zeros(total * hd).unwrap();
918        let mut v = driver.alloc_zeros(total * hd).unwrap();
919        driver
920            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
921            .unwrap();
922        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
923        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
924        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);
925
926        // Attention scores
927        let mut scores = driver.alloc_zeros(nh * 8 * 8).unwrap();
928        driver
929            .gemm_batched(
930                &q,
931                &k,
932                &mut scores,
933                8,
934                8,
935                head_dim,
936                true,
937                8 * head_dim,
938                8 * head_dim,
939                8 * 8,
940                nh,
941            )
942            .unwrap();
943        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
944        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
945        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);
946
947        // Full forward pass
948        let enc2 = Encoding {
949            input_ids: vec![1, 100, 200, 300, 2],
950            attention_mask: vec![1; 5],
951            token_type_ids: vec![0; 5],
952        };
953
954        let quick = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
955        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
956        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
957        eprintln!(
958            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
959            &quick[0][..3]
960        );
961
962        // MRL truncation
963        eprintln!("\n=== ModernBERT MRL Truncation ===");
964        let full = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
965        let full_emb = &full[0];
966        for dims in [64, 128, 256, 384, 512, 768] {
967            let t: Vec<f32> = full_emb[..dims].to_vec();
968            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
969            let f_norm: f32 = full_emb[..dims]
970                .iter()
971                .map(|x| x * x)
972                .sum::<f32>()
973                .sqrt()
974                .max(1e-12);
975            let cos: f32 = t
976                .iter()
977                .zip(&full_emb[..dims])
978                .map(|(a, b)| a * b)
979                .sum::<f32>()
980                / (t_norm * f_norm);
981            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
982        }
983
984        // Throughput benchmark
985        eprintln!("\n=== ModernBERT Throughput ===");
986        // Build 32 encodings of varying length
987        let mut encs = Vec::new();
988        for i in 0..32 {
989            let len = 16 + (i * 4); // 16 to 140 tokens
990            let mut ids = vec![1_i64]; // CLS
991            for j in 1..len - 1 {
992                ids.push(100 + i64::from(j));
993            }
994            ids.push(2); // SEP
995            encs.push(Encoding {
996                input_ids: ids.clone(),
997                attention_mask: vec![1; ids.len()],
998                token_type_ids: vec![0; ids.len()],
999            });
1000        }
1001
1002        // Warmup
1003        let _ = arch.forward(&driver, &encs[..4]);
1004
1005        // Timed run
1006        let t0 = std::time::Instant::now();
1007        let result = arch.forward(&driver, &encs).unwrap();
1008        let elapsed = t0.elapsed();
1009        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1010        eprintln!(
1011            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1012            encs.len(),
1013            elapsed.as_secs_f64() * 1000.0,
1014            throughput
1015        );
1016        assert_eq!(result.len(), 32);
1017
1018        // Batch=1 timing (critical — CLI query path)
1019        let single = vec![encs[0].clone()];
1020        let t1 = std::time::Instant::now();
1021        let _ = arch.forward(&driver, &single).unwrap();
1022        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
1023        eprintln!("  batch=1, time={single_ms:.1}ms");
1024    }
1025
1026    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on Metal via the driver/arch system.
1027    ///
1028    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector,
1029    /// compares against the CPU backend for numerical equivalence, and measures
1030    /// throughput (target: >=308/s matching monolithic).
1031    #[cfg(feature = "metal")]
1032    #[test]
1033    #[ignore = "requires model download (~33MB)"]
1034    fn classic_bert_driver_arch() {
1035        use crate::backend::arch::ModelArch;
1036
1037        let model_repo = "BAAI/bge-small-en-v1.5";
1038
1039        // Load via driver/arch system
1040        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
1041        assert!(backend.is_gpu(), "Metal backend should be GPU");
1042
1043        let enc = Encoding {
1044            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1045            attention_mask: vec![1, 1, 1, 1, 1, 1],
1046            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1047        };
1048
1049        // Basic forward pass
1050        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1051        assert_eq!(result.len(), 1);
1052        assert_eq!(result[0].len(), 384);
1053
1054        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1055        eprintln!(
1056            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
1057            &result[0][..3]
1058        );
1059        assert!(
1060            (l2 - 1.0).abs() < 0.01,
1061            "embedding should be L2-normalized, got L2={l2}"
1062        );
1063
1064        // Compare against CPU backend (reliable reference)
1065        #[cfg(feature = "cpu")]
1066        {
1067            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
1068                .expect("CPU load failed");
1069            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
1070            eprintln!("CPU  first 5: {:?}", &cpu_result[0][..5]);
1071            eprintln!("NEW  first 5: {:?}", &result[0][..5]);
1072            let cosine: f32 = result[0]
1073                .iter()
1074                .zip(&cpu_result[0])
1075                .map(|(a, b)| a * b)
1076                .sum();
1077            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
1078            assert!(
1079                cosine > 0.95,
1080                "cosine similarity vs CPU should be >0.95, got {cosine}"
1081            );
1082        }
1083
1084        // Throughput benchmark
1085        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
1086        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
1087        let config_path = {
1088            let api = hf_hub::api::sync::Api::new().unwrap();
1089            let repo = api.model(model_repo.to_string());
1090            repo.get("config.json").unwrap()
1091        };
1092        let weights_path = {
1093            let api = hf_hub::api::sync::Api::new().unwrap();
1094            let repo = api.model(model_repo.to_string());
1095            repo.get("model.safetensors").unwrap()
1096        };
1097        let config_str = std::fs::read_to_string(&config_path).unwrap();
1098        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
1099        let config =
1100            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
1101        let (arch, _mmap) = driver
1102            .load_classic_bert_weights(&weights_path, &config)
1103            .unwrap();
1104
1105        // Build 32 encodings of varying length
1106        let mut encs = Vec::new();
1107        for i in 0..32 {
1108            let len = 16 + (i * 4); // 16 to 140 tokens
1109            let mut ids = vec![101_i64]; // [CLS]
1110            for j in 1..len - 1 {
1111                ids.push(100 + i64::from(j));
1112            }
1113            ids.push(102); // [SEP]
1114            encs.push(Encoding {
1115                input_ids: ids.clone(),
1116                attention_mask: vec![1; ids.len()],
1117                token_type_ids: vec![0; ids.len()],
1118            });
1119        }
1120
1121        // Warmup
1122        let _ = arch.forward(&driver, &encs[..4]);
1123
1124        // Timed run
1125        let t0 = std::time::Instant::now();
1126        let bench_result = arch.forward(&driver, &encs).unwrap();
1127        let elapsed = t0.elapsed();
1128        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1129        eprintln!(
1130            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1131            encs.len(),
1132            elapsed.as_secs_f64() * 1000.0,
1133            throughput
1134        );
1135        assert_eq!(bench_result.len(), 32);
1136    }
1137
1138    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on CPU via the driver/arch system.
1139    ///
1140    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector
1141    /// and compares against the monolithic CPU backend for numerical equivalence.
1142    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1143    #[test]
1144    #[ignore = "requires model download (~33MB)"]
1145    fn classic_bert_cpu_driver_arch() {
1146        let model_repo = "BAAI/bge-small-en-v1.5";
1147
1148        // Load via new driver/arch system
1149        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1150        assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1151
1152        let enc = Encoding {
1153            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1154            attention_mask: vec![1, 1, 1, 1, 1, 1],
1155            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1156        };
1157
1158        // Basic forward pass
1159        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1160        assert_eq!(result.len(), 1);
1161        assert_eq!(result[0].len(), 384);
1162
1163        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1164        eprintln!(
1165            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1166            &result[0][..5]
1167        );
1168        assert!(
1169            (l2 - 1.0).abs() < 0.01,
1170            "embedding should be L2-normalized, got L2={l2}"
1171        );
1172
1173        // Compare against monolithic CPU backend (reference)
1174        #[cfg(feature = "cpu")]
1175        {
1176            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1177                .expect("monolithic CPU load failed");
1178            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1179            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1180            eprintln!("New  first 5: {:?}", &result[0][..5]);
1181            let cosine: f32 = result[0]
1182                .iter()
1183                .zip(&cpu_result[0])
1184                .map(|(a, b)| a * b)
1185                .sum();
1186            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1187            assert!(
1188                cosine > 0.999,
1189                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1190            );
1191        }
1192    }
1193}