Skip to main content

ripvec_core/backend/
mod.rs

1//! Embedding backend abstraction layer.
2//!
3//! Defines the [`EmbedBackend`] trait that all embedding backends (CPU, CUDA,
4//! Metal, MLX) implement, plus the [`Encoding`] input type and [`BackendKind`]
5//! discriminant. Use [`load_backend`] to construct a backend by kind.
6
7pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19#[cfg(feature = "cuda")]
20pub mod nvrtc_cubin;
21
/// Pre-tokenized encoding ready for inference.
///
/// `input_ids`, `attention_mask`, and `token_type_ids` must all have the same
/// length. The tokenizer truncates each sequence to the model's limit (see
/// [`EmbedBackend::max_tokens`]) before it reaches a backend: 512 for
/// `ClassicBert`, and up to the configured position-embedding limit for
/// `ModernBERT`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Encoding {
    /// Token IDs produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask (1 for real tokens, 0 for padding).
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 for single-sequence models).
    pub token_type_ids: Vec<i64>,
}
36
37/// Trait for embedding backends.
38///
39/// Implementations must be [`Send`] so they can be moved across threads (e.g.
40/// into a ring-buffer pipeline). The trait is object-safe — callers use
41/// `&dyn EmbedBackend` or `Box<dyn EmbedBackend>`.
42///
43/// # GPU vs CPU scheduling
44///
45/// - **CPU backends** (`is_gpu() == false`): cloned per rayon thread via
46///   [`clone_backend`](EmbedBackend::clone_backend).
47/// - **GPU backends** (`is_gpu() == true`): use a ring-buffer pipeline with
48///   `RING_SIZE = 4` for bounded memory.
49pub trait EmbedBackend: Send + Sync {
50    /// Embed a batch of pre-tokenized inputs, returning L2-normalized vectors.
51    ///
52    /// Each inner `Vec<f32>` is the embedding for the corresponding
53    /// [`Encoding`]. Errors **must** propagate — never silently return
54    /// defaults.
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if tensor construction or the forward pass fails.
59    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
60
61    /// Whether this backend supports cheap cloning for per-thread instances.
62    ///
63    /// CPU backends return `true`; GPU backends typically return `false`.
64    fn supports_clone(&self) -> bool;
65
66    /// Create a cheap clone of this backend for per-thread use in rayon.
67    ///
68    /// # Panics
69    ///
70    /// May panic if [`supports_clone`](EmbedBackend::supports_clone) returns
71    /// `false`. Callers must check `supports_clone()` first.
72    fn clone_backend(&self) -> Box<dyn EmbedBackend>;
73
74    /// Whether this backend runs on a GPU.
75    ///
76    /// GPU backends use a ring-buffer pipelined scheduler (`RING_SIZE = 4`)
77    /// for bounded memory usage.
78    fn is_gpu(&self) -> bool;
79
80    /// Maximum token count this model supports (position embedding limit).
81    ///
82    /// `ClassicBert`: 512. `ModernBERT`: up to model config. Tokens beyond this
83    /// are truncated during tokenization.
84    fn max_tokens(&self) -> usize {
85        512 // default for classic BERT models
86    }
87
88    /// Short human-readable label for this backend (e.g. "Metal", "CUDA",
89    /// "CPU (Accelerate)", "MLX"). Used by diagnostics and the `up_to_date`
90    /// MCP tool. Defaults to "GPU" or "CPU" based on [`is_gpu`].
91    fn name(&self) -> &'static str {
92        if self.is_gpu() { "GPU" } else { "CPU" }
93    }
94}
95
/// Selects which embedding backend implementation [`load_backend`] builds.
///
/// [`BackendKind::Cpu`] is the default. `Display` renders each variant as a
/// short lowercase label (`cuda`, `mlx`, `cpu`, `metal`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    /// CUDA (cudarc, NVIDIA GPUs via cuBLAS + custom kernels).
    Cuda,
    /// MLX (Apple Silicon, macOS only).
    Mlx,
    /// CPU (ndarray + system BLAS).
    #[default]
    Cpu,
    /// Metal (Apple Silicon, macOS only, direct Metal GPU).
    Metal,
}

impl std::fmt::Display for BackendKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it in a single write.
        let label = match *self {
            Self::Cuda => "cuda",
            Self::Mlx => "mlx",
            Self::Cpu => "cpu",
            Self::Metal => "metal",
        };
        f.write_str(label)
    }
}
120
/// Device hint passed to [`load_backend`].
///
/// Backends map this to their native device type. Not every backend supports
/// every device; how an unsupported combination is handled is up to the
/// backend. NOTE(review): this previously said such combinations return an
/// error, while [`load_backend`] describes the hint as advisory with a CPU
/// fallback — confirm against the individual backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceHint {
    /// Automatically select the best available device.
    #[default]
    Auto,
    /// Force CPU inference.
    Cpu,
    /// Force GPU inference (Metal on macOS, CUDA on Linux/Windows).
    Gpu,
}
135
/// Load-time inference tuning options forwarded to model loading.
///
/// Deliberately has no fields yet. It reserves a slot for future options
/// (such as quantization or distillation settings) that must be chosen when
/// a model is loaded rather than per batch.
#[derive(Default, Clone, Debug)]
pub struct InferenceOpts {}
142
143/// Construct an embedding backend of the given kind.
144///
145/// Downloads model weights on first use via `hf-hub`. The `device_hint`
146/// is advisory — backends that don't support GPU fall back to CPU.
147///
148/// # Errors
149///
150/// Returns an error if the requested backend was not compiled in (missing
151/// feature flag) or if model loading fails.
152pub fn load_backend(
153    kind: BackendKind,
154    #[cfg_attr(
155        not(any(
156            feature = "cuda",
157            feature = "mlx",
158            feature = "cpu",
159            feature = "cpu-accelerate",
160            feature = "metal"
161        )),
162        expect(unused_variables, reason = "used when backend features are enabled")
163    )]
164    model_repo: &str,
165    #[cfg_attr(
166        not(any(
167            feature = "cuda",
168            feature = "mlx",
169            feature = "cpu",
170            feature = "cpu-accelerate",
171            feature = "metal"
172        )),
173        expect(unused_variables, reason = "used when backend features are enabled")
174    )]
175    device_hint: DeviceHint,
176) -> crate::Result<Box<dyn EmbedBackend>> {
177    match kind {
178        #[cfg(feature = "cuda")]
179        BackendKind::Cuda => {
180            if is_modernbert_model(model_repo) {
181                return load_modernbert_cuda(model_repo);
182            }
183            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
184            Ok(Box::new(backend))
185        }
186        #[cfg(not(feature = "cuda"))]
187        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
188            "cuda backend requires building with: cargo build --features cuda"
189        ))),
190        #[cfg(feature = "mlx")]
191        BackendKind::Mlx => {
192            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
193            Ok(Box::new(backend))
194        }
195        #[cfg(not(feature = "mlx"))]
196        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
197            "mlx backend requires building with: cargo build --features mlx"
198        ))),
199        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
200        BackendKind::Cpu => {
201            if is_modernbert_model(model_repo) {
202                return load_modernbert_cpu(model_repo);
203            }
204            #[cfg(feature = "cpu")]
205            {
206                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
207                #[expect(
208                    clippy::needless_return,
209                    reason = "return needed before cfg(not) fallback"
210                )]
211                return Ok(Box::new(backend));
212            }
213            #[cfg(not(feature = "cpu"))]
214            Err(crate::Error::Other(anyhow::anyhow!(
215                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
216            )))
217        }
218        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
219        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
220            "cpu backend requires building with: cargo build --features cpu"
221        ))),
222        #[cfg(feature = "metal")]
223        BackendKind::Metal => {
224            // All models route through the driver/arch system.
225            if is_modernbert_model(model_repo) {
226                return load_modernbert_metal(model_repo);
227            }
228            load_classic_metal(model_repo)
229        }
230        #[cfg(not(feature = "metal"))]
231        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
232            "metal backend requires building with: cargo build --features metal"
233        ))),
234    }
235}
236
237/// Detect all available backends and load them.
238///
239/// Probes for GPU backends (CUDA, MLX) first, then falls back to CPU.
240/// Returns backends in priority order — the first entry is the primary
241/// (used for query embedding in interactive mode).
242///
243/// # Errors
244///
245/// Returns an error if no backends can be loaded (not even CPU).
246pub fn detect_backends(
247    #[cfg_attr(
248        not(any(
249            feature = "cuda",
250            feature = "mlx",
251            feature = "cpu",
252            feature = "cpu-accelerate",
253            feature = "metal"
254        )),
255        expect(unused_variables, reason = "used when backend features are enabled")
256    )]
257    model_repo: &str,
258) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
259    #[cfg_attr(
260        not(any(
261            feature = "cuda",
262            feature = "mlx",
263            feature = "cpu",
264            feature = "cpu-accelerate",
265            feature = "metal"
266        )),
267        expect(unused_mut, reason = "mut needed when backend features are enabled")
268    )]
269    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
270
271    // Try CUDA (NVIDIA GPU)
272    #[cfg(feature = "cuda")]
273    {
274        if is_modernbert_model(model_repo) {
275            if let Ok(b) = load_modernbert_cuda(model_repo) {
276                backends.push(b);
277            }
278        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
279            backends.push(Box::new(b));
280        }
281    }
282
283    // Try Metal (Apple Silicon GPU, preferred over MLX)
284    #[cfg(feature = "metal")]
285    {
286        // Route models through the driver/arch system by architecture.
287        if is_modernbert_model(model_repo) {
288            if let Ok(b) = load_modernbert_metal(model_repo) {
289                backends.push(b);
290            }
291        } else if let Ok(b) = load_classic_metal(model_repo) {
292            backends.push(b);
293        }
294    }
295
296    // Try MLX (Apple Silicon GPU, fallback if Metal unavailable)
297    #[cfg(feature = "mlx")]
298    if backends.is_empty()
299        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
300    {
301        backends.push(Box::new(b));
302    }
303
304    // Add CPU as fallback only when no GPU backend was loaded.
305    // On Apple Silicon, running CPU + MLX concurrently is slower than
306    // MLX alone because they share the same physical cores and memory.
307    // On discrete GPU systems (CUDA), CPU would be a useful helper.
308    #[cfg_attr(
309        not(any(feature = "cpu", feature = "cpu-accelerate")),
310        expect(unused_variables, reason = "used when cpu feature is enabled")
311    )]
312    let has_gpu = backends.iter().any(|b| b.is_gpu());
313    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
314    if !has_gpu {
315        if is_modernbert_model(model_repo) {
316            if let Ok(b) = load_modernbert_cpu(model_repo) {
317                backends.push(b);
318            }
319        } else {
320            #[cfg(feature = "cpu")]
321            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
322                backends.push(Box::new(b));
323            }
324        }
325    }
326
327    if backends.is_empty() {
328        return Err(crate::Error::Other(anyhow::anyhow!(
329            "no embedding backends available"
330        )));
331    }
332
333    Ok(backends)
334}
335
336// ---------------------------------------------------------------------------
337// ModernBERT loader (driver/arch system)
338// ---------------------------------------------------------------------------
339
/// Build a `ModernBERT` embedding backend on the Metal GPU.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` through the
/// Hugging Face Hub client, which reuses its local cache after the first
/// download. It then parses the config into a
/// [`ModernBertConfig`](driver::metal::ModernBertConfig) and loads the weights
/// with a [`MetalDriver`](driver::metal::MetalDriver). The resulting
/// [`ModernBertArch`](arch::modern_bert::ModernBertArch) and weight mapping are
/// wrapped in a [`GenericBackend`](generic::GenericBackend) flagged as GPU.
///
/// # Errors
///
/// Returns an error if no Metal device is available, the model cannot be
/// downloaded, `config.json` cannot be read or parsed, or weight loading fails.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Every hub fetch maps its failure to the same `Download` error.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let config_path = fetch("config.json")?;
    let weights_path = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
394
/// Load `ModernBERT` on CPU via the driver/arch system.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` from the
/// Hugging Face Hub (cached after the first download) and parses the config
/// into a [`ModernBertConfig`](driver::cpu::ModernBertConfig). It loads the
/// weights through a [`CpuDriver`](driver::cpu::CpuDriver) and wraps the
/// architecture and weight mapping in a
/// [`GenericBackend`](generic::GenericBackend) built with `new_shared` and
/// the GPU flag cleared (`false`).
///
/// # Errors
///
/// Returns an error if the model files cannot be downloaded, `config.json`
/// cannot be read or parsed, the CPU driver fails to initialize, or weight
/// loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json
    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
437
/// Load `ModernBERT` on CUDA via the driver/arch system.
///
/// Creates a [`CudaDriver`](driver::cuda::CudaDriver) and loads the safetensors
/// weights onto the GPU. The driver pre-converts GEMM weights to FP16 and
/// builds RoPE caches. The result is wrapped in a
/// [`GenericBackend`](generic::GenericBackend) flagged as GPU, with a maximum
/// batch size of 32.
///
/// # Errors
///
/// Returns an error if no CUDA device is available, the model cannot be
/// downloaded, `config.json` cannot be read or parsed, or weight loading fails.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json
    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // Max sequences per forward pass: 32. A previous comment here claimed CUDA
    // "can handle 128+ easily" to saturate its SMs, but the value actually used
    // is 32. NOTE(review): benchmark throughput and peak memory at long
    // `max_tokens` before raising it.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        32,
    )))
}
498
/// Report whether `model_repo` is a `ModernBERT` checkpoint.
///
/// Reads the repo's `config.json` through `hf-hub` (downloading it if it is
/// not cached yet) and tests whether its `"model_type"` string equals
/// `"modernbert"`. The result is `false` on any failure: API construction,
/// download, file read, JSON parse, or a missing or non-string field.
/// Detection therefore never blocks loading; callers take the `ClassicBert`
/// path instead.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    /// `Some(is_modernbert)` when the config could be inspected, else `None`.
    fn inspect(model_repo: &str) -> Option<bool> {
        let api = hf_hub::api::sync::Api::new().ok()?;
        let config_path = api.model(model_repo.to_string()).get("config.json").ok()?;
        let raw = std::fs::read_to_string(&config_path).ok()?;
        let json: serde_json::Value = serde_json::from_str(&raw).ok()?;
        Some(json.get("model_type")?.as_str()? == "modernbert")
    }

    inspect(model_repo).unwrap_or(false)
}
527
528// ---------------------------------------------------------------------------
529// ClassicBert loader (driver/arch system)
530// ---------------------------------------------------------------------------
531
/// Build a `ClassicBert` embedding backend (e.g. `BAAI/bge-small-en-v1.5`) on
/// the Metal GPU.
///
/// Fetches `config.json` and `model.safetensors` through the Hugging Face Hub
/// client (cached after the first download) and parses the config into a
/// [`ClassicBertConfig`](driver::metal::ClassicBertConfig). It hands the
/// weights to a [`MetalDriver`](driver::metal::MetalDriver); per the driver,
/// this fuses Q/K/V into one tensor per layer. The resulting
/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch) is wrapped in a
/// [`GenericBackend`](generic::GenericBackend) flagged as GPU.
///
/// # Errors
///
/// Returns an error if no Metal device is available, the model cannot be
/// downloaded, `config.json` cannot be read or parsed, or weight loading fails.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let download = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let config_path = download("config.json")?;
    let weights_path = download("model.safetensors")?;

    let text = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let value: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&value)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
587
588// ---------------------------------------------------------------------------
589// ClassicBert loader (CPU driver/arch system)
590// ---------------------------------------------------------------------------
591
/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the CPU backend
/// via the driver/arch system.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached after the first download) and parses the config into a
/// [`ClassicBertConfig`](driver::cpu::ClassicBertConfig). It loads the
/// weights with a [`CpuDriver`](driver::cpu::CpuDriver), which returns the
/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch) together with a
/// weight-file mapping. Both are passed to
/// [`GenericBackend`](generic::GenericBackend)`::new_shared` with the GPU flag
/// cleared. The load log describes this path as zero-copy mmap.
///
/// NOTE(review): earlier docs said weights are read into `Vec<f32>` tensors,
/// which contradicts the mmap handoff above — confirm against
/// `CpuDriver::load_classic_bert_weights`.
///
/// # Errors
///
/// Returns an error if the model cannot be downloaded, `config.json` cannot be
/// read or parsed, the CPU driver fails to initialize, or weight loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json
    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
647
648#[cfg(test)]
649mod tests {
650    use super::*;
651
652    /// Verify that `EmbedBackend` is object-safe by constructing a trait object type.
653    #[test]
654    fn trait_is_object_safe() {
655        // If this compiles, the trait is object-safe.
656        fn _assert_object_safe(_: &dyn EmbedBackend) {}
657    }
658
659    /// Verify that `Box<dyn EmbedBackend>` is `Send`.
660    #[test]
661    fn trait_object_is_send() {
662        fn assert_send<T: Send>() {}
663        assert_send::<Box<dyn EmbedBackend>>();
664    }
665
666    /// Verify that `Box<dyn EmbedBackend>` is `Sync` (needed for `&dyn` across threads).
667    #[test]
668    fn trait_object_is_sync() {
669        fn assert_sync<T: Sync>() {}
670        assert_sync::<Box<dyn EmbedBackend>>();
671    }
672
673    /// Verify that `Arc<dyn EmbedBackend>` is `Send` (needed for ring-buffer pipeline).
674    #[test]
675    fn arc_trait_object_is_send() {
676        fn assert_send<T: Send>() {}
677        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
678    }
679
680    #[test]
681    fn encoding_construction() {
682        let enc = Encoding {
683            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
684            attention_mask: vec![1, 1, 1, 1, 1, 1],
685            token_type_ids: vec![0, 0, 0, 0, 0, 0],
686        };
687        assert_eq!(enc.input_ids.len(), 6);
688        assert_eq!(enc.attention_mask.len(), 6);
689        assert_eq!(enc.token_type_ids.len(), 6);
690    }
691
692    #[test]
693    fn encoding_clone() {
694        let enc = Encoding {
695            input_ids: vec![101, 102],
696            attention_mask: vec![1, 1],
697            token_type_ids: vec![0, 0],
698        };
699        let cloned = enc.clone();
700        assert_eq!(enc.input_ids, cloned.input_ids);
701    }
702
703    #[test]
704    fn backend_kind_default_is_cpu() {
705        assert_eq!(BackendKind::default(), BackendKind::Cpu);
706    }
707
708    #[test]
709    fn backend_kind_display() {
710        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
711        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
712        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
713    }
714
715    #[cfg(not(feature = "mlx"))]
716    #[test]
717    fn load_backend_mlx_not_compiled() {
718        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
719        assert!(result.is_err());
720    }
721
722    #[cfg(feature = "cpu")]
723    #[test]
724    fn detect_backends_returns_at_least_one() {
725        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
726        assert!(!backends.is_empty());
727    }
728
729    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
730    #[test]
731    fn detect_backends_returns_at_least_one_backend() {
732        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
733        assert!(!backends.is_empty(), "should detect at least one backend");
734    }
735
736    /// Load `ModernBERT` on Metal and embed a short token sequence.
737    ///
738    /// Verifies that the full pipeline (weight loading, forward pass, pooling,
739    /// L2 normalization) produces a 768-dim unit vector.
740    #[cfg(feature = "metal")]
741    #[test]
742    #[ignore = "requires model download (~570MB)"]
743    #[expect(clippy::too_many_lines, reason = "end-to-end backend diagnostic test")]
744    fn modernbert_loads_and_embeds() {
745        use crate::backend::arch::ModelArch;
746        use crate::backend::driver::Driver;
747
748        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
749        assert!(backend.is_gpu(), "Metal backend should be GPU");
750
751        let enc = Encoding {
752            input_ids: vec![1, 100, 200, 300, 2],
753            attention_mask: vec![1; 5],
754            token_type_ids: vec![0; 5],
755        };
756
757        // Stage-by-stage diagnostic using the driver directly
758        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
759        let inputs = driver.prepare_batch(std::slice::from_ref(&enc), 8).unwrap();
760
761        // Check: can we read back input_ids?
762        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
763        eprintln!("input_ids: {:?}", &ids_host[0][..5]);
764
765        // Check: embedding lookup
766        // Need the tok_embeddings weight — load weights directly
767        let api = hf_hub::api::sync::Api::new().unwrap();
768        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
769        let weights_path = repo.get("model.safetensors").unwrap();
770        let config_path = repo.get("config.json").unwrap();
771        let config_str = std::fs::read_to_string(&config_path).unwrap();
772        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
773        let config =
774            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
775        let (arch, _mmap) = driver
776            .load_modern_bert_weights(&weights_path, &config)
777            .unwrap();
778
779        let hidden = driver
780            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
781            .unwrap();
782        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
783        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
784        eprintln!(
785            "embedding: {nz}/{} nonzero, first 5: {:?}",
786            h[0].len(),
787            &h[0][..5]
788        );
789
790        // Stage-by-stage forward pass bisection
791        let total = 8; // padded seq
792        let hd = 768;
793        let nh = 12;
794        let head_dim = 64;
795
796        // After embedding LN
797        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
798        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
799        driver
800            .layer_norm(
801                &mut ln_out,
802                &emb_clone,
803                &arch.weights.emb_norm_weight,
804                &arch.weights.zero_bias,
805                total,
806                hd,
807                1e-5,
808            )
809            .unwrap();
810        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
811        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
812        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);
813
814        // Layer 0 QKV GEMM
815        let layer0 = &arch.weights.layers[0];
816        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
817        driver
818            .gemm(
819                &ln_out,
820                &layer0.qkv_weight,
821                &mut qkv,
822                total,
823                3 * hd,
824                hd,
825                true,
826            )
827            .unwrap();
828        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
829        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
830        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);
831
832        // QKV split
833        let mut q = driver.alloc_zeros(total * hd).unwrap();
834        let mut k = driver.alloc_zeros(total * hd).unwrap();
835        let mut v = driver.alloc_zeros(total * hd).unwrap();
836        driver
837            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
838            .unwrap();
839        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
840        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
841        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);
842
843        // Attention scores
844        let mut scores = driver.alloc_zeros(nh * 8 * 8).unwrap();
845        driver
846            .gemm_batched(
847                &q,
848                &k,
849                &mut scores,
850                8,
851                8,
852                head_dim,
853                true,
854                8 * head_dim,
855                8 * head_dim,
856                8 * 8,
857                nh,
858            )
859            .unwrap();
860        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
861        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
862        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);
863
864        // Full forward pass
865        let enc2 = Encoding {
866            input_ids: vec![1, 100, 200, 300, 2],
867            attention_mask: vec![1; 5],
868            token_type_ids: vec![0; 5],
869        };
870
871        let quick = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
872        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
873        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
874        eprintln!(
875            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
876            &quick[0][..3]
877        );
878
879        // MRL truncation
880        eprintln!("\n=== ModernBERT MRL Truncation ===");
881        let full = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
882        let full_emb = &full[0];
883        for dims in [64, 128, 256, 384, 512, 768] {
884            let t: Vec<f32> = full_emb[..dims].to_vec();
885            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
886            let f_norm: f32 = full_emb[..dims]
887                .iter()
888                .map(|x| x * x)
889                .sum::<f32>()
890                .sqrt()
891                .max(1e-12);
892            let cos: f32 = t
893                .iter()
894                .zip(&full_emb[..dims])
895                .map(|(a, b)| a * b)
896                .sum::<f32>()
897                / (t_norm * f_norm);
898            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
899        }
900
901        // Throughput benchmark
902        eprintln!("\n=== ModernBERT Throughput ===");
903        // Build 32 encodings of varying length
904        let mut encs = Vec::new();
905        for i in 0..32 {
906            let len = 16 + (i * 4); // 16 to 140 tokens
907            let mut ids = vec![1_i64]; // CLS
908            for j in 1..len - 1 {
909                ids.push(100 + i64::from(j));
910            }
911            ids.push(2); // SEP
912            encs.push(Encoding {
913                input_ids: ids.clone(),
914                attention_mask: vec![1; ids.len()],
915                token_type_ids: vec![0; ids.len()],
916            });
917        }
918
919        // Warmup
920        let _ = arch.forward(&driver, &encs[..4]);
921
922        // Timed run
923        let t0 = std::time::Instant::now();
924        let result = arch.forward(&driver, &encs).unwrap();
925        let elapsed = t0.elapsed();
926        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
927        eprintln!(
928            "  batch={}, time={:.1}ms, throughput={:.1}/s",
929            encs.len(),
930            elapsed.as_secs_f64() * 1000.0,
931            throughput
932        );
933        assert_eq!(result.len(), 32);
934
935        // Batch=1 timing (critical — CLI query path)
936        let single = vec![encs[0].clone()];
937        let t1 = std::time::Instant::now();
938        let _ = arch.forward(&driver, &single).unwrap();
939        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
940        eprintln!("  batch=1, time={single_ms:.1}ms");
941    }
942
943    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on Metal via the driver/arch system.
944    ///
945    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector,
946    /// compares against the CPU backend for numerical equivalence, and measures
947    /// throughput (target: >=308/s matching monolithic).
948    #[cfg(feature = "metal")]
949    #[test]
950    #[ignore = "requires model download (~33MB)"]
951    fn classic_bert_driver_arch() {
952        use crate::backend::arch::ModelArch;
953
954        let model_repo = "BAAI/bge-small-en-v1.5";
955
956        // Load via driver/arch system
957        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
958        assert!(backend.is_gpu(), "Metal backend should be GPU");
959
960        let enc = Encoding {
961            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
962            attention_mask: vec![1, 1, 1, 1, 1, 1],
963            token_type_ids: vec![0, 0, 0, 0, 0, 0],
964        };
965
966        // Basic forward pass
967        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
968        assert_eq!(result.len(), 1);
969        assert_eq!(result[0].len(), 384);
970
971        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
972        eprintln!(
973            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
974            &result[0][..3]
975        );
976        assert!(
977            (l2 - 1.0).abs() < 0.01,
978            "embedding should be L2-normalized, got L2={l2}"
979        );
980
981        // Compare against CPU backend (reliable reference)
982        #[cfg(feature = "cpu")]
983        {
984            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
985                .expect("CPU load failed");
986            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
987            eprintln!("CPU  first 5: {:?}", &cpu_result[0][..5]);
988            eprintln!("NEW  first 5: {:?}", &result[0][..5]);
989            let cosine: f32 = result[0]
990                .iter()
991                .zip(&cpu_result[0])
992                .map(|(a, b)| a * b)
993                .sum();
994            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
995            assert!(
996                cosine > 0.95,
997                "cosine similarity vs CPU should be >0.95, got {cosine}"
998            );
999        }
1000
1001        // Throughput benchmark
1002        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
1003        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
1004        let config_path = {
1005            let api = hf_hub::api::sync::Api::new().unwrap();
1006            let repo = api.model(model_repo.to_string());
1007            repo.get("config.json").unwrap()
1008        };
1009        let weights_path = {
1010            let api = hf_hub::api::sync::Api::new().unwrap();
1011            let repo = api.model(model_repo.to_string());
1012            repo.get("model.safetensors").unwrap()
1013        };
1014        let config_str = std::fs::read_to_string(&config_path).unwrap();
1015        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
1016        let config =
1017            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
1018        let (arch, _mmap) = driver
1019            .load_classic_bert_weights(&weights_path, &config)
1020            .unwrap();
1021
1022        // Build 32 encodings of varying length
1023        let mut encs = Vec::new();
1024        for i in 0..32 {
1025            let len = 16 + (i * 4); // 16 to 140 tokens
1026            let mut ids = vec![101_i64]; // [CLS]
1027            for j in 1..len - 1 {
1028                ids.push(100 + i64::from(j));
1029            }
1030            ids.push(102); // [SEP]
1031            encs.push(Encoding {
1032                input_ids: ids.clone(),
1033                attention_mask: vec![1; ids.len()],
1034                token_type_ids: vec![0; ids.len()],
1035            });
1036        }
1037
1038        // Warmup
1039        let _ = arch.forward(&driver, &encs[..4]);
1040
1041        // Timed run
1042        let t0 = std::time::Instant::now();
1043        let bench_result = arch.forward(&driver, &encs).unwrap();
1044        let elapsed = t0.elapsed();
1045        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1046        eprintln!(
1047            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1048            encs.len(),
1049            elapsed.as_secs_f64() * 1000.0,
1050            throughput
1051        );
1052        assert_eq!(bench_result.len(), 32);
1053    }
1054
1055    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on CPU via the driver/arch system.
1056    ///
1057    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector
1058    /// and compares against the monolithic CPU backend for numerical equivalence.
1059    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1060    #[test]
1061    #[ignore = "requires model download (~33MB)"]
1062    fn classic_bert_cpu_driver_arch() {
1063        let model_repo = "BAAI/bge-small-en-v1.5";
1064
1065        // Load via new driver/arch system
1066        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1067        assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1068
1069        let enc = Encoding {
1070            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1071            attention_mask: vec![1, 1, 1, 1, 1, 1],
1072            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1073        };
1074
1075        // Basic forward pass
1076        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1077        assert_eq!(result.len(), 1);
1078        assert_eq!(result[0].len(), 384);
1079
1080        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1081        eprintln!(
1082            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1083            &result[0][..5]
1084        );
1085        assert!(
1086            (l2 - 1.0).abs() < 0.01,
1087            "embedding should be L2-normalized, got L2={l2}"
1088        );
1089
1090        // Compare against monolithic CPU backend (reference)
1091        #[cfg(feature = "cpu")]
1092        {
1093            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1094                .expect("monolithic CPU load failed");
1095            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1096            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1097            eprintln!("New  first 5: {:?}", &result[0][..5]);
1098            let cosine: f32 = result[0]
1099                .iter()
1100                .zip(&cpu_result[0])
1101                .map(|(a, b)| a * b)
1102                .sum();
1103            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1104            assert!(
1105                cosine > 0.999,
1106                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1107            );
1108        }
1109    }
1110}