Skip to main content

ripvec_core/backend/
mod.rs

1//! Embedding backend abstraction layer.
2//!
3//! Defines the [`EmbedBackend`] trait that all embedding backends (CPU, CUDA,
4//! Metal, MLX) implement, plus the [`Encoding`] input type and [`BackendKind`]
5//! discriminant. Use [`load_backend`] to construct a backend by kind.
6
7pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19#[cfg(feature = "cuda")]
20pub mod nvrtc_cubin;
21
/// Pre-tokenized encoding ready for inference.
///
/// The three vectors are parallel: `input_ids`, `attention_mask`, and
/// `token_type_ids` must all have the same length. The tokenizer is expected
/// to cap the token count (documented as `MODEL_MAX_TOKENS`, 512) before an
/// encoding reaches a backend; this type itself performs no validation.
///
/// Two encodings compare equal when all three vectors are element-wise equal,
/// which makes encodings directly usable with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Encoding {
    /// Token IDs produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask (1 for real tokens, 0 for padding).
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 for single-sequence models).
    pub token_type_ids: Vec<i64>,
}
36
/// Trait for embedding backends.
///
/// Implementations must be both [`Send`] and [`Sync`] (see the supertrait
/// bounds below): `Send` so a backend can be moved across threads (e.g. into
/// a ring-buffer pipeline), `Sync` so a `&dyn EmbedBackend` or
/// `Arc<dyn EmbedBackend>` can be shared between threads. The trait is
/// object-safe — callers use `&dyn EmbedBackend` or `Box<dyn EmbedBackend>`.
///
/// # GPU vs CPU scheduling
///
/// - **CPU backends** (`is_gpu() == false`): cloned per rayon thread via
///   [`clone_backend`](EmbedBackend::clone_backend).
/// - **GPU backends** (`is_gpu() == true`): use a ring-buffer pipeline with
///   `RING_SIZE = 4` for bounded memory.
pub trait EmbedBackend: Send + Sync {
    /// Embed a batch of pre-tokenized inputs, returning L2-normalized vectors.
    ///
    /// The `i`-th inner `Vec<f32>` is the embedding for `encodings[i]`.
    /// Errors **must** propagate — implementations must never silently
    /// return default vectors.
    ///
    /// # Errors
    ///
    /// Returns an error if tensor construction or the forward pass fails.
    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;

    /// Whether this backend supports cheap cloning for per-thread instances.
    ///
    /// CPU backends return `true`; GPU backends typically return `false`.
    /// Callers must consult this before calling
    /// [`clone_backend`](EmbedBackend::clone_backend).
    fn supports_clone(&self) -> bool;

    /// Create a cheap clone of this backend for per-thread use in rayon.
    ///
    /// # Panics
    ///
    /// May panic if [`supports_clone`](EmbedBackend::supports_clone) returns
    /// `false`. Callers must check `supports_clone()` first.
    fn clone_backend(&self) -> Box<dyn EmbedBackend>;

    /// Whether this backend runs on a GPU.
    ///
    /// GPU backends use a ring-buffer pipelined scheduler (`RING_SIZE = 4`)
    /// for bounded memory usage.
    fn is_gpu(&self) -> bool;

    /// Maximum token count this model supports (position embedding limit).
    ///
    /// `ClassicBert`: 512. `ModernBERT`: up to model config. Tokens beyond this
    /// are truncated during tokenization.
    ///
    /// The provided implementation returns 512; backends whose model has a
    /// different limit should override it.
    fn max_tokens(&self) -> usize {
        512 // default for classic BERT models
    }
}
88
/// The embedding backend implementations this crate can be asked to build.
///
/// [`BackendKind::Cpu`] is the default. The [`Display`](std::fmt::Display)
/// form of each variant is its lowercase name (`"cuda"`, `"mlx"`, `"cpu"`,
/// `"metal"`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BackendKind {
    /// CUDA (cudarc, NVIDIA GPUs via cuBLAS + custom kernels).
    Cuda,
    /// MLX (Apple Silicon, macOS only).
    Mlx,
    /// CPU (ndarray + system BLAS).
    #[default]
    Cpu,
    /// Metal (Apple Silicon, macOS only, direct Metal GPU).
    Metal,
}

impl std::fmt::Display for BackendKind {
    /// Writes the lowercase backend name with no padding or decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it once.
        let label = match *self {
            Self::Cuda => "cuda",
            Self::Mlx => "mlx",
            Self::Cpu => "cpu",
            Self::Metal => "metal",
        };
        f.write_str(label)
    }
}
113
/// Device hint passed to [`load_backend`].
///
/// Backends map this to their native device type. Not all backends support
/// all devices — unsupported combinations return an error.
///
/// [`DeviceHint::Auto`] is the default. The type is comparable with `==`
/// (like [`BackendKind`]) so callers can branch on or assert a hint directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceHint {
    /// Automatically select the best available device.
    #[default]
    Auto,
    /// Force CPU inference.
    Cpu,
    /// Force GPU inference (Metal on macOS, CUDA on Linux/Windows).
    Gpu,
}
128
/// Load-time inference tuning knobs handed through to model loading.
///
/// The struct has no fields yet; it is a placeholder for future
/// optimizations (quantization, distillation, etc.) that must be chosen
/// when a model is loaded.
#[derive(Clone, Debug, Default)]
pub struct InferenceOpts {}
135
136/// Construct an embedding backend of the given kind.
137///
138/// Downloads model weights on first use via `hf-hub`. The `device_hint`
139/// is advisory — backends that don't support GPU fall back to CPU.
140///
141/// # Errors
142///
143/// Returns an error if the requested backend was not compiled in (missing
144/// feature flag) or if model loading fails.
145pub fn load_backend(
146    kind: BackendKind,
147    #[cfg_attr(
148        not(any(
149            feature = "cuda",
150            feature = "mlx",
151            feature = "cpu",
152            feature = "cpu-accelerate",
153            feature = "metal"
154        )),
155        expect(unused_variables, reason = "used when backend features are enabled")
156    )]
157    model_repo: &str,
158    #[cfg_attr(
159        not(any(
160            feature = "cuda",
161            feature = "mlx",
162            feature = "cpu",
163            feature = "cpu-accelerate",
164            feature = "metal"
165        )),
166        expect(unused_variables, reason = "used when backend features are enabled")
167    )]
168    device_hint: DeviceHint,
169) -> crate::Result<Box<dyn EmbedBackend>> {
170    match kind {
171        #[cfg(feature = "cuda")]
172        BackendKind::Cuda => {
173            if is_modernbert_model(model_repo) {
174                return load_modernbert_cuda(model_repo);
175            }
176            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
177            Ok(Box::new(backend))
178        }
179        #[cfg(not(feature = "cuda"))]
180        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
181            "cuda backend requires building with: cargo build --features cuda"
182        ))),
183        #[cfg(feature = "mlx")]
184        BackendKind::Mlx => {
185            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
186            Ok(Box::new(backend))
187        }
188        #[cfg(not(feature = "mlx"))]
189        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
190            "mlx backend requires building with: cargo build --features mlx"
191        ))),
192        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
193        BackendKind::Cpu => {
194            if is_modernbert_model(model_repo) {
195                return load_modernbert_cpu(model_repo);
196            }
197            #[cfg(feature = "cpu")]
198            {
199                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
200                #[expect(
201                    clippy::needless_return,
202                    reason = "return needed before cfg(not) fallback"
203                )]
204                return Ok(Box::new(backend));
205            }
206            #[cfg(not(feature = "cpu"))]
207            Err(crate::Error::Other(anyhow::anyhow!(
208                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
209            )))
210        }
211        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
212        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
213            "cpu backend requires building with: cargo build --features cpu"
214        ))),
215        #[cfg(feature = "metal")]
216        BackendKind::Metal => {
217            // All models route through the driver/arch system.
218            if is_modernbert_model(model_repo) {
219                return load_modernbert_metal(model_repo);
220            }
221            load_classic_metal(model_repo)
222        }
223        #[cfg(not(feature = "metal"))]
224        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
225            "metal backend requires building with: cargo build --features metal"
226        ))),
227    }
228}
229
/// Detect all available backends and load them.
///
/// Probes the compiled-in GPU backends in this order: CUDA, then Metal, then
/// MLX (MLX is attempted only when nothing has been loaded before it). A CPU
/// backend is added afterwards only when none of the loaded backends reports
/// `is_gpu()`. Returns backends in priority order — the first entry is the
/// primary (used for query embedding in interactive mode).
///
/// Individual load failures are discarded during probing; only the final
/// "nothing loaded" condition is reported.
///
/// # Errors
///
/// Returns an error if no backends can be loaded (not even CPU).
pub fn detect_backends(
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_mut, reason = "mut needed when backend features are enabled")
    )]
    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();

    // Try CUDA (NVIDIA GPU). ModernBERT models go through the driver/arch
    // loader; everything else uses the classic CUDA backend.
    #[cfg(feature = "cuda")]
    {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cuda(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
            backends.push(Box::new(b));
        }
    }

    // Try Metal (Apple Silicon GPU, preferred over MLX)
    #[cfg(feature = "metal")]
    {
        // Route models through the driver/arch system by architecture.
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_metal(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = load_classic_metal(model_repo) {
            backends.push(b);
        }
    }

    // Try MLX (Apple Silicon GPU) only if no earlier probe produced a backend.
    #[cfg(feature = "mlx")]
    if backends.is_empty()
        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
    {
        backends.push(Box::new(b));
    }

    // Add CPU as fallback only when no GPU backend was loaded.
    // On Apple Silicon, running CPU + MLX concurrently is slower than
    // MLX alone because they share the same physical cores and memory.
    // On discrete GPU systems (CUDA) a CPU backend could be a useful helper,
    // but this function does not add one when a GPU backend is present.
    #[cfg_attr(
        not(any(feature = "cpu", feature = "cpu-accelerate")),
        expect(unused_variables, reason = "used when cpu feature is enabled")
    )]
    let has_gpu = backends.iter().any(|b| b.is_gpu());
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    if !has_gpu {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cpu(model_repo) {
                backends.push(b);
            }
        } else {
            // The classic CPU backend exists only with the `cpu` feature; with
            // `cpu-accelerate` alone this branch adds nothing.
            #[cfg(feature = "cpu")]
            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
                backends.push(Box::new(b));
            }
        }
    }

    if backends.is_empty() {
        return Err(crate::Error::Other(anyhow::anyhow!(
            "no embedding backends available"
        )));
    }

    Ok(backends)
}
328
329// ---------------------------------------------------------------------------
330// ModernBERT loader (driver/arch system)
331// ---------------------------------------------------------------------------
332
/// Build a `ModernBERT` embedding backend that runs on the Metal GPU driver.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` through
/// `hf-hub`, parses the configuration, loads the weights through a
/// [`MetalDriver`](driver::metal::MetalDriver), and wraps driver and
/// architecture in a [`GenericBackend`](generic::GenericBackend). The token
/// limit handed to the backend is the config's `max_position_embeddings`.
///
/// # Errors
///
/// Returns an error if either file cannot be fetched, the configuration
/// cannot be read or parsed, the Metal driver cannot be created, or weight
/// loading fails.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Every hf-hub failure is reported as a download error carrying its text.
    fn download_err<E: std::fmt::Display>(err: E) -> crate::Error {
        crate::Error::Download(err.to_string())
    }

    // Locate the model files (config first, then weights).
    let hub = Api::new().map_err(download_err)?;
    let model = hub.model(model_repo.to_string());
    let config_path = model.get("config.json").map_err(download_err)?;
    let weights_path = model.get("model.safetensors").map_err(download_err)?;

    // Read and decode the model configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the GPU driver and let it load the architecture weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    // NOTE(review): the `true` argument is presumably the is-GPU flag —
    // confirm against `GenericBackend::new`.
    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
387
/// Build a `ModernBERT` embedding backend on the CPU driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` through
/// `hf-hub`, parses the configuration, loads the weights through a
/// [`CpuDriver`](driver::cpu::CpuDriver), and wraps the result with
/// `GenericBackend::new_shared`. The token limit handed to the backend is the
/// config's `max_position_embeddings`.
///
/// # Errors
///
/// Returns an error if either file cannot be fetched, the configuration
/// cannot be read or parsed, the CPU driver cannot be created, or weight
/// loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Every hf-hub failure is reported as a download error carrying its text.
    fn download_err<E: std::fmt::Display>(err: E) -> crate::Error {
        crate::Error::Download(err.to_string())
    }

    // Locate the model files (config first, then weights).
    let hub = Api::new().map_err(download_err)?;
    let model = hub.model(model_repo.to_string());
    let config_path = model.get("config.json").map_err(download_err)?;
    let weights_path = model.get("model.safetensors").map_err(download_err)?;

    // Read and decode the model configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // NOTE(review): the `false` argument is presumably the is-GPU flag —
    // confirm against `GenericBackend::new_shared`.
    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
430
/// Build a `ModernBERT` embedding backend on the CUDA driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` through
/// `hf-hub`, parses the configuration, loads the weights through a
/// [`CudaDriver`](driver::cuda::CudaDriver), and wraps the result in a
/// [`GenericBackend`](generic::GenericBackend) built with an explicit
/// maximum batch size. The token limit handed to the backend is the config's
/// `max_position_embeddings`.
///
/// # Errors
///
/// Returns an error if either file cannot be fetched, the configuration
/// cannot be read or parsed, the CUDA driver cannot be created, or weight
/// loading fails.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Every hf-hub failure is reported as a download error carrying its text.
    fn download_err<E: std::fmt::Display>(err: E) -> crate::Error {
        crate::Error::Download(err.to_string())
    }

    // Locate the model files (config first, then weights).
    let hub = Api::new().map_err(download_err)?;
    let model = hub.model(model_repo.to_string());
    let config_path = model.get("config.json").map_err(download_err)?;
    let weights_path = model.get("model.safetensors").map_err(download_err)?;

    // Read and decode the model configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the GPU driver and let it load the architecture weights.
    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // The trailing argument to `with_max_batch` is 32.
    // NOTE(review): the rationale recorded for this call was that CUDA
    // benefits from larger batches (128+) to saturate GPU SMs, while 32 is the
    // figure quoted for Metal — confirm that 32 is the intended CUDA value.
    let backend = GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        32,
    );
    Ok(Box::new(backend))
}
491
/// Report whether `model_repo` declares the `ModernBERT` architecture.
///
/// Fetches the repository's `config.json` through `hf-hub`, parses it as
/// JSON, and checks that the top-level `"model_type"` field is the string
/// `"modernbert"`.
///
/// Any failure along the way — API construction, fetch, file read, JSON
/// parse, or a missing / non-string `model_type` — yields `false`, so callers
/// treat such a model as non-ModernBERT.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    // `None` means "could not determine"; it collapses to `false` below.
    let probe = || -> Option<bool> {
        let hub = hf_hub::api::sync::Api::new().ok()?;
        let model = hub.model(model_repo.to_string());
        let config_path = model.get("config.json").ok()?;
        let raw_config = std::fs::read_to_string(&config_path).ok()?;
        let parsed = serde_json::from_str::<serde_json::Value>(&raw_config).ok()?;
        let model_type = parsed.get("model_type")?.as_str()?;
        Some(model_type == "modernbert")
    };
    probe().unwrap_or(false)
}
520
521// ---------------------------------------------------------------------------
// ClassicBert loader (Metal driver/arch system)
523// ---------------------------------------------------------------------------
524
/// Build a `ClassicBert` embedding backend (e.g. `BAAI/bge-small-en-v1.5`)
/// that runs on the Metal GPU driver.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` through
/// `hf-hub`, parses the configuration, loads the weights through a
/// [`MetalDriver`](driver::metal::MetalDriver), and wraps driver and
/// architecture in a [`GenericBackend`](generic::GenericBackend). The token
/// limit handed to the backend is the config's `max_position_embeddings`.
///
/// # Errors
///
/// Returns an error if either file cannot be fetched, the configuration
/// cannot be read or parsed, the Metal driver cannot be created, or weight
/// loading fails.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Every hf-hub failure is reported as a download error carrying its text.
    fn download_err<E: std::fmt::Display>(err: E) -> crate::Error {
        crate::Error::Download(err.to_string())
    }

    // Locate the model files (config first, then weights).
    let hub = Api::new().map_err(download_err)?;
    let model = hub.model(model_repo.to_string());
    let config_path = model.get("config.json").map_err(download_err)?;
    let weights_path = model.get("model.safetensors").map_err(download_err)?;

    // Read and decode the model configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the GPU driver and let it load the architecture weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    // NOTE(review): the `true` argument is presumably the is-GPU flag —
    // confirm against `GenericBackend::new`.
    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
580
581// ---------------------------------------------------------------------------
582// ClassicBert loader (CPU driver/arch system)
583// ---------------------------------------------------------------------------
584
/// Build a `ClassicBert` embedding backend (e.g. `BAAI/bge-small-en-v1.5`)
/// on the CPU driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` through
/// `hf-hub`, parses the configuration, loads the weights through a
/// [`CpuDriver`](driver::cpu::CpuDriver), and wraps the result with
/// `GenericBackend::new_shared`. The token limit handed to the backend is the
/// config's `max_position_embeddings`.
///
/// NOTE(review): the log line emitted here describes the weights as
/// "zero-copy mmap"; how the driver actually materializes tensors is not
/// visible from this function — confirm against `CpuDriver`.
///
/// # Errors
///
/// Returns an error if either file cannot be fetched, the configuration
/// cannot be read or parsed, the CPU driver cannot be created, or weight
/// loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Every hf-hub failure is reported as a download error carrying its text.
    fn download_err<E: std::fmt::Display>(err: E) -> crate::Error {
        crate::Error::Download(err.to_string())
    }

    // Locate the model files (config first, then weights).
    let hub = Api::new().map_err(download_err)?;
    let model = hub.model(model_repo.to_string());
    let config_path = model.get("config.json").map_err(download_err)?;
    let weights_path = model.get("model.safetensors").map_err(download_err)?;

    // Read and decode the model configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // NOTE(review): the `false` argument is presumably the is-GPU flag —
    // confirm against `GenericBackend::new_shared`.
    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
640
641#[cfg(test)]
642mod tests {
643    use super::*;
644
645    /// Verify that `EmbedBackend` is object-safe by constructing a trait object type.
646    #[test]
647    fn trait_is_object_safe() {
648        // If this compiles, the trait is object-safe.
649        fn _assert_object_safe(_: &dyn EmbedBackend) {}
650    }
651
652    /// Verify that `Box<dyn EmbedBackend>` is `Send`.
653    #[test]
654    fn trait_object_is_send() {
655        fn assert_send<T: Send>() {}
656        assert_send::<Box<dyn EmbedBackend>>();
657    }
658
659    /// Verify that `Box<dyn EmbedBackend>` is `Sync` (needed for `&dyn` across threads).
660    #[test]
661    fn trait_object_is_sync() {
662        fn assert_sync<T: Sync>() {}
663        assert_sync::<Box<dyn EmbedBackend>>();
664    }
665
666    /// Verify that `Arc<dyn EmbedBackend>` is `Send` (needed for ring-buffer pipeline).
667    #[test]
668    fn arc_trait_object_is_send() {
669        fn assert_send<T: Send>() {}
670        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
671    }
672
673    #[test]
674    fn encoding_construction() {
675        let enc = Encoding {
676            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
677            attention_mask: vec![1, 1, 1, 1, 1, 1],
678            token_type_ids: vec![0, 0, 0, 0, 0, 0],
679        };
680        assert_eq!(enc.input_ids.len(), 6);
681        assert_eq!(enc.attention_mask.len(), 6);
682        assert_eq!(enc.token_type_ids.len(), 6);
683    }
684
685    #[test]
686    fn encoding_clone() {
687        let enc = Encoding {
688            input_ids: vec![101, 102],
689            attention_mask: vec![1, 1],
690            token_type_ids: vec![0, 0],
691        };
692        let cloned = enc.clone();
693        assert_eq!(enc.input_ids, cloned.input_ids);
694    }
695
696    #[test]
697    fn backend_kind_default_is_cpu() {
698        assert_eq!(BackendKind::default(), BackendKind::Cpu);
699    }
700
701    #[test]
702    fn backend_kind_display() {
703        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
704        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
705        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
706    }
707
708    #[cfg(not(feature = "mlx"))]
709    #[test]
710    fn load_backend_mlx_not_compiled() {
711        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
712        assert!(result.is_err());
713    }
714
715    #[cfg(feature = "cpu")]
716    #[test]
717    fn detect_backends_returns_at_least_one() {
718        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
719        assert!(!backends.is_empty());
720    }
721
722    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
723    #[test]
724    fn detect_backends_returns_at_least_one_backend() {
725        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
726        assert!(!backends.is_empty(), "should detect at least one backend");
727    }
728
729    /// Load `ModernBERT` on Metal and embed a short token sequence.
730    ///
731    /// Verifies that the full pipeline (weight loading, forward pass, pooling,
732    /// L2 normalization) produces a 768-dim unit vector.
733    #[cfg(feature = "metal")]
734    #[test]
735    #[ignore = "requires model download (~570MB)"]
736    fn modernbert_loads_and_embeds() {
737        use crate::backend::driver::Driver;
738
739        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
740        assert!(backend.is_gpu(), "Metal backend should be GPU");
741
742        let enc = Encoding {
743            input_ids: vec![1, 100, 200, 300, 2],
744            attention_mask: vec![1; 5],
745            token_type_ids: vec![0; 5],
746        };
747
748        // Stage-by-stage diagnostic using the driver directly
749        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
750        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();
751
752        // Check: can we read back input_ids?
753        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
754        eprintln!("input_ids: {:?}", &ids_host[0][..5]);
755
756        // Check: embedding lookup
757        // Need the tok_embeddings weight — load weights directly
758        let api = hf_hub::api::sync::Api::new().unwrap();
759        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
760        let weights_path = repo.get("model.safetensors").unwrap();
761        let config_path = repo.get("config.json").unwrap();
762        let config_str = std::fs::read_to_string(&config_path).unwrap();
763        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
764        let config =
765            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
766        let (arch, _mmap) = driver
767            .load_modern_bert_weights(&weights_path, &config)
768            .unwrap();
769
770        let hidden = driver
771            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
772            .unwrap();
773        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
774        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
775        eprintln!(
776            "embedding: {nz}/{} nonzero, first 5: {:?}",
777            h[0].len(),
778            &h[0][..5]
779        );
780
781        // Stage-by-stage forward pass bisection
782        let total = 8; // padded seq
783        let hd = 768;
784        let nh = 12;
785        let head_dim = 64;
786
787        // After embedding LN
788        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
789        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
790        driver
791            .layer_norm(
792                &mut ln_out,
793                &emb_clone,
794                &arch.weights.emb_norm_weight,
795                &arch.weights.zero_bias,
796                total,
797                hd,
798                1e-5,
799            )
800            .unwrap();
801        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
802        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
803        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);
804
805        // Layer 0 QKV GEMM
806        let layer0 = &arch.weights.layers[0];
807        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
808        driver
809            .gemm(
810                &ln_out,
811                &layer0.qkv_weight,
812                &mut qkv,
813                total,
814                3 * hd,
815                hd,
816                true,
817            )
818            .unwrap();
819        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
820        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
821        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);
822
823        // QKV split
824        let mut q = driver.alloc_zeros(total * hd).unwrap();
825        let mut k = driver.alloc_zeros(total * hd).unwrap();
826        let mut v = driver.alloc_zeros(total * hd).unwrap();
827        driver
828            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
829            .unwrap();
830        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
831        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
832        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);
833
834        // Attention scores
835        let mut scores = driver.alloc_zeros(1 * nh * 8 * 8).unwrap();
836        driver
837            .gemm_batched(
838                &q,
839                &k,
840                &mut scores,
841                8,
842                8,
843                head_dim,
844                true,
845                8 * head_dim,
846                8 * head_dim,
847                8 * 8,
848                nh,
849            )
850            .unwrap();
851        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
852        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
853        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);
854
855        // Full forward pass
856        use crate::backend::arch::ModelArch;
857        let enc2 = Encoding {
858            input_ids: vec![1, 100, 200, 300, 2],
859            attention_mask: vec![1; 5],
860            token_type_ids: vec![0; 5],
861        };
862
863        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
864        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
865        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
866        eprintln!(
867            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
868            &quick[0][..3]
869        );
870
871        // MRL truncation
872        eprintln!("\n=== ModernBERT MRL Truncation ===");
873        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
874        let full_emb = &full[0];
875        for dims in [64, 128, 256, 384, 512, 768] {
876            let t: Vec<f32> = full_emb[..dims].to_vec();
877            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
878            let f_norm: f32 = full_emb[..dims]
879                .iter()
880                .map(|x| x * x)
881                .sum::<f32>()
882                .sqrt()
883                .max(1e-12);
884            let cos: f32 = t
885                .iter()
886                .zip(&full_emb[..dims])
887                .map(|(a, b)| a * b)
888                .sum::<f32>()
889                / (t_norm * f_norm);
890            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
891        }
892
893        // Throughput benchmark
894        eprintln!("\n=== ModernBERT Throughput ===");
895        // Build 32 encodings of varying length
896        let mut encs = Vec::new();
897        for i in 0..32 {
898            let len = 16 + (i * 4); // 16 to 140 tokens
899            let mut ids = vec![1_i64]; // CLS
900            for j in 1..len - 1 {
901                ids.push(100 + j as i64);
902            }
903            ids.push(2); // SEP
904            encs.push(Encoding {
905                input_ids: ids.clone(),
906                attention_mask: vec![1; ids.len()],
907                token_type_ids: vec![0; ids.len()],
908            });
909        }
910
911        // Warmup
912        let _ = arch.forward(&driver, &encs[..4]);
913
914        // Timed run
915        let t0 = std::time::Instant::now();
916        let result = arch.forward(&driver, &encs).unwrap();
917        let elapsed = t0.elapsed();
918        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
919        eprintln!(
920            "  batch={}, time={:.1}ms, throughput={:.1}/s",
921            encs.len(),
922            elapsed.as_secs_f64() * 1000.0,
923            throughput
924        );
925        assert_eq!(result.len(), 32);
926
927        // Batch=1 timing (critical — CLI query path)
928        let single = vec![encs[0].clone()];
929        let t1 = std::time::Instant::now();
930        let _ = arch.forward(&driver, &single).unwrap();
931        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
932        eprintln!("  batch=1, time={single_ms:.1}ms");
933    }
934
935    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on Metal via the driver/arch system.
936    ///
937    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector,
938    /// compares against the CPU backend for numerical equivalence, and measures
939    /// throughput (target: >=308/s matching monolithic).
940    #[cfg(feature = "metal")]
941    #[test]
942    #[ignore = "requires model download (~33MB)"]
943    fn classic_bert_driver_arch() {
944        use crate::backend::arch::ModelArch;
945
946        let model_repo = "BAAI/bge-small-en-v1.5";
947
948        // Load via driver/arch system
949        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
950        assert!(backend.is_gpu(), "Metal backend should be GPU");
951
952        let enc = Encoding {
953            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
954            attention_mask: vec![1, 1, 1, 1, 1, 1],
955            token_type_ids: vec![0, 0, 0, 0, 0, 0],
956        };
957
958        // Basic forward pass
959        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
960        assert_eq!(result.len(), 1);
961        assert_eq!(result[0].len(), 384);
962
963        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
964        eprintln!(
965            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
966            &result[0][..3]
967        );
968        assert!(
969            (l2 - 1.0).abs() < 0.01,
970            "embedding should be L2-normalized, got L2={l2}"
971        );
972
973        // Compare against CPU backend (reliable reference)
974        #[cfg(feature = "cpu")]
975        {
976            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
977                .expect("CPU load failed");
978            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
979            eprintln!("CPU  first 5: {:?}", &cpu_result[0][..5]);
980            eprintln!("NEW  first 5: {:?}", &result[0][..5]);
981            let cosine: f32 = result[0]
982                .iter()
983                .zip(&cpu_result[0])
984                .map(|(a, b)| a * b)
985                .sum();
986            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
987            assert!(
988                cosine > 0.95,
989                "cosine similarity vs CPU should be >0.95, got {cosine}"
990            );
991        }
992
993        // Throughput benchmark
994        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
995        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
996        let config_path = {
997            let api = hf_hub::api::sync::Api::new().unwrap();
998            let repo = api.model(model_repo.to_string());
999            repo.get("config.json").unwrap()
1000        };
1001        let weights_path = {
1002            let api = hf_hub::api::sync::Api::new().unwrap();
1003            let repo = api.model(model_repo.to_string());
1004            repo.get("model.safetensors").unwrap()
1005        };
1006        let config_str = std::fs::read_to_string(&config_path).unwrap();
1007        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
1008        let config =
1009            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
1010        let (arch, _mmap) = driver
1011            .load_classic_bert_weights(&weights_path, &config)
1012            .unwrap();
1013
1014        // Build 32 encodings of varying length
1015        let mut encs = Vec::new();
1016        for i in 0..32 {
1017            let len = 16 + (i * 4); // 16 to 140 tokens
1018            let mut ids = vec![101_i64]; // [CLS]
1019            for j in 1..len - 1 {
1020                ids.push(100 + j as i64);
1021            }
1022            ids.push(102); // [SEP]
1023            encs.push(Encoding {
1024                input_ids: ids.clone(),
1025                attention_mask: vec![1; ids.len()],
1026                token_type_ids: vec![0; ids.len()],
1027            });
1028        }
1029
1030        // Warmup
1031        let _ = arch.forward(&driver, &encs[..4]);
1032
1033        // Timed run
1034        let t0 = std::time::Instant::now();
1035        let bench_result = arch.forward(&driver, &encs).unwrap();
1036        let elapsed = t0.elapsed();
1037        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1038        eprintln!(
1039            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1040            encs.len(),
1041            elapsed.as_secs_f64() * 1000.0,
1042            throughput
1043        );
1044        assert_eq!(bench_result.len(), 32);
1045    }
1046
1047    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on CPU via the driver/arch system.
1048    ///
1049    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector
1050    /// and compares against the monolithic CPU backend for numerical equivalence.
1051    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1052    #[test]
1053    #[ignore = "requires model download (~33MB)"]
1054    fn classic_bert_cpu_driver_arch() {
1055        let model_repo = "BAAI/bge-small-en-v1.5";
1056
1057        // Load via new driver/arch system
1058        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1059        assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1060
1061        let enc = Encoding {
1062            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1063            attention_mask: vec![1, 1, 1, 1, 1, 1],
1064            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1065        };
1066
1067        // Basic forward pass
1068        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1069        assert_eq!(result.len(), 1);
1070        assert_eq!(result[0].len(), 384);
1071
1072        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1073        eprintln!(
1074            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1075            &result[0][..5]
1076        );
1077        assert!(
1078            (l2 - 1.0).abs() < 0.01,
1079            "embedding should be L2-normalized, got L2={l2}"
1080        );
1081
1082        // Compare against monolithic CPU backend (reference)
1083        #[cfg(feature = "cpu")]
1084        {
1085            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1086                .expect("monolithic CPU load failed");
1087            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1088            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1089            eprintln!("New  first 5: {:?}", &result[0][..5]);
1090            let cosine: f32 = result[0]
1091                .iter()
1092                .zip(&cpu_result[0])
1093                .map(|(a, b)| a * b)
1094                .sum();
1095            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1096            assert!(
1097                cosine > 0.999,
1098                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1099            );
1100        }
1101    }
1102}