// ripvec_core/backend/mod.rs
1//! Embedding backend abstraction layer.
2//!
3//! Defines the [`EmbedBackend`] trait that all embedding backends (CPU, CUDA,
4//! Metal, MLX) implement, plus the [`Encoding`] input type and [`BackendKind`]
5//! discriminant. Use [`load_backend`] to construct a backend by kind.
6
7pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19
/// Pre-tokenized encoding ready for inference.
///
/// Token IDs, attention mask, and token type IDs must all have the same length.
/// Token count is capped at `MODEL_MAX_TOKENS` (512) by the tokenizer before
/// reaching the backend.
///
/// NOTE(review): the equal-length invariant is by convention only — nothing in
/// this type enforces it at construction; backends assume it holds. Verify it
/// is guaranteed at the tokenizer boundary.
#[derive(Debug, Clone)]
pub struct Encoding {
    /// Token IDs produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask (1 for real tokens, 0 for padding).
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 for single-sequence models).
    pub token_type_ids: Vec<i64>,
}
34
/// Trait for embedding backends.
///
/// Implementations must be [`Send`] + [`Sync`] (see the trait bound below) so
/// they can be moved and shared across threads (e.g. into a ring-buffer
/// pipeline). The trait is object-safe — callers use
/// `&dyn EmbedBackend` or `Box<dyn EmbedBackend>`.
///
/// # GPU vs CPU scheduling
///
/// - **CPU backends** (`is_gpu() == false`): cloned per rayon thread via
///   [`clone_backend`](EmbedBackend::clone_backend).
/// - **GPU backends** (`is_gpu() == true`): use a ring-buffer pipeline with
///   `RING_SIZE = 4` for bounded memory.
pub trait EmbedBackend: Send + Sync {
    /// Embed a batch of pre-tokenized inputs, returning L2-normalized vectors.
    ///
    /// Each inner `Vec<f32>` is the embedding for the corresponding
    /// [`Encoding`]. Errors **must** propagate — never silently return
    /// defaults.
    ///
    /// # Errors
    ///
    /// Returns an error if tensor construction or the forward pass fails.
    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;

    /// Whether this backend supports cheap cloning for per-thread instances.
    ///
    /// CPU backends return `true`; GPU backends typically return `false`.
    /// Gates [`clone_backend`](EmbedBackend::clone_backend): callers must
    /// check this before cloning.
    fn supports_clone(&self) -> bool;

    /// Create a cheap clone of this backend for per-thread use in rayon.
    ///
    /// # Panics
    ///
    /// May panic if [`supports_clone`](EmbedBackend::supports_clone) returns
    /// `false`. Callers must check `supports_clone()` first.
    fn clone_backend(&self) -> Box<dyn EmbedBackend>;

    /// Whether this backend runs on a GPU.
    ///
    /// GPU backends use a ring-buffer pipelined scheduler (`RING_SIZE = 4`)
    /// for bounded memory usage.
    fn is_gpu(&self) -> bool;

    /// Maximum token count this model supports (position embedding limit).
    ///
    /// `ClassicBert`: 512. `ModernBERT`: up to model config. Tokens beyond this
    /// are truncated during tokenization. Backends with a larger/smaller limit
    /// override this default.
    fn max_tokens(&self) -> usize {
        512 // default for classic BERT models
    }
}
86
/// Available embedding backend implementations.
///
/// The `Display` impl renders the lowercase name used in CLI/config text
/// (`"cuda"`, `"mlx"`, `"cpu"`, `"metal"`). Pass a variant to
/// [`load_backend`] to construct the corresponding backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    /// CUDA (cudarc, NVIDIA GPUs via cuBLAS + custom kernels).
    Cuda,
    /// MLX (Apple Silicon, macOS only).
    Mlx,
    /// CPU (ndarray + system BLAS). The default kind.
    #[default]
    Cpu,
    /// Metal (Apple Silicon, macOS only, direct Metal GPU).
    Metal,
}
100
101impl std::fmt::Display for BackendKind {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        match self {
104            Self::Cuda => write!(f, "cuda"),
105            Self::Mlx => write!(f, "mlx"),
106            Self::Cpu => write!(f, "cpu"),
107            Self::Metal => write!(f, "metal"),
108        }
109    }
110}
111
/// Device hint passed to [`load_backend`].
///
/// Backends map this to their native device type. Not all backends support
/// all devices — unsupported combinations return an error.
///
/// Derives `PartialEq`/`Eq` for parity with [`BackendKind`], so callers and
/// tests can compare hints directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceHint {
    /// Automatically select the best available device.
    #[default]
    Auto,
    /// Force CPU inference.
    Cpu,
    /// Force GPU inference (Metal on macOS, CUDA on Linux/Windows).
    Gpu,
}
126
/// Inference optimization parameters passed through to model loading.
///
/// Currently empty — reserved for future optimizations (quantization,
/// distillation, etc.) that need to be configured at load time. Callers
/// should construct it via `InferenceOpts::default()` so that adding fields
/// later stays source-compatible.
#[derive(Debug, Clone, Default)]
pub struct InferenceOpts {}
133
/// Construct an embedding backend of the given kind.
///
/// Downloads model weights on first use via `hf-hub`. The `device_hint`
/// is advisory — backends that don't support GPU fall back to CPU.
///
/// # Errors
///
/// Returns an error if the requested backend was not compiled in (missing
/// feature flag) or if model loading fails.
pub fn load_backend(
    kind: BackendKind,
    // When no backend feature is compiled in, every match arm below is a
    // "not compiled" error and this parameter is never read; the cfg_attr
    // silences the unused-variable lint only in that configuration.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
    // Same reasoning as for `model_repo` above.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    device_hint: DeviceHint,
) -> crate::Result<Box<dyn EmbedBackend>> {
    // Each kind has a compiled-in arm and a feature-missing arm; exactly one
    // of the pair exists in any given build, keeping the match exhaustive.
    match kind {
        #[cfg(feature = "cuda")]
        BackendKind::Cuda => {
            // ModernBERT models route through the driver/arch loader.
            if is_modernbert_model(model_repo) {
                return load_modernbert_cuda(model_repo);
            }
            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
            Ok(Box::new(backend))
        }
        #[cfg(not(feature = "cuda"))]
        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
            "cuda backend requires building with: cargo build --features cuda"
        ))),
        #[cfg(feature = "mlx")]
        BackendKind::Mlx => {
            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
            Ok(Box::new(backend))
        }
        #[cfg(not(feature = "mlx"))]
        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
            "mlx backend requires building with: cargo build --features mlx"
        ))),
        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
        BackendKind::Cpu => {
            // ModernBERT works under either cpu feature flavor.
            if is_modernbert_model(model_repo) {
                return load_modernbert_cpu(model_repo);
            }
            // ClassicBert on CPU needs the full "cpu" feature.
            #[cfg(feature = "cpu")]
            {
                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
                #[expect(
                    clippy::needless_return,
                    reason = "return needed before cfg(not) fallback"
                )]
                return Ok(Box::new(backend));
            }
            #[cfg(not(feature = "cpu"))]
            Err(crate::Error::Other(anyhow::anyhow!(
                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
            )))
        }
        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
            "cpu backend requires building with: cargo build --features cpu"
        ))),
        #[cfg(feature = "metal")]
        BackendKind::Metal => {
            // All models route through the driver/arch system.
            if is_modernbert_model(model_repo) {
                return load_modernbert_metal(model_repo);
            }
            load_classic_metal(model_repo)
        }
        #[cfg(not(feature = "metal"))]
        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
            "metal backend requires building with: cargo build --features metal"
        ))),
    }
}
227
/// Detect all available backends and load them.
///
/// Probes for GPU backends (CUDA, MLX) first, then falls back to CPU.
/// Returns backends in priority order — the first entry is the primary
/// (used for query embedding in interactive mode).
///
/// Individual probe failures are swallowed on purpose (detection is
/// best-effort); only the "nothing loaded at all" case is an error.
///
/// # Errors
///
/// Returns an error if no backends can be loaded (not even CPU).
pub fn detect_backends(
    // Unused only when no backend feature is compiled in; the cfg_attr
    // silences the lint for that configuration alone.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_mut, reason = "mut needed when backend features are enabled")
    )]
    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();

    // Try CUDA (NVIDIA GPU)
    #[cfg(feature = "cuda")]
    {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cuda(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
            backends.push(Box::new(b));
        }
    }

    // Try Metal (Apple Silicon GPU, preferred over MLX)
    #[cfg(feature = "metal")]
    {
        // Route models through the driver/arch system by architecture.
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_metal(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = load_classic_metal(model_repo) {
            backends.push(b);
        }
    }

    // Try MLX (Apple Silicon GPU, fallback if Metal unavailable)
    #[cfg(feature = "mlx")]
    if backends.is_empty()
        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
    {
        backends.push(Box::new(b));
    }

    // Add CPU as fallback only when no GPU backend was loaded.
    // On Apple Silicon, running CPU + MLX concurrently is slower than
    // MLX alone because they share the same physical cores and memory.
    // On discrete GPU systems (CUDA), CPU would be a useful helper.
    #[cfg_attr(
        not(any(feature = "cpu", feature = "cpu-accelerate")),
        expect(unused_variables, reason = "used when cpu feature is enabled")
    )]
    let has_gpu = backends.iter().any(|b| b.is_gpu());
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    if !has_gpu {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cpu(model_repo) {
                backends.push(b);
            }
        } else {
            // ClassicBert on CPU needs the full "cpu" feature.
            #[cfg(feature = "cpu")]
            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
                backends.push(Box::new(b));
            }
        }
    }

    if backends.is_empty() {
        return Err(crate::Error::Other(anyhow::anyhow!(
            "no embedding backends available"
        )));
    }

    Ok(backends)
}
326
327// ---------------------------------------------------------------------------
328// ModernBERT loader (driver/arch system)
329// ---------------------------------------------------------------------------
330
/// Load a `ModernBERT` model on the Metal GPU backend.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download), parses the model configuration,
/// memory-maps the safetensors weights, and builds a [`GenericBackend`]
/// pairing a [`MetalDriver`](driver::metal::MetalDriver) with a
/// [`ModernBertArch`](arch::modern_bert::ModernBertArch).
///
/// # Errors
///
/// Returns an error if no Metal device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (and download on first use) the model artifacts.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let cfg_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT configuration.
    let raw = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the Metal device and load the (memory-mapped) weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
385
/// Load `ModernBERT` on CPU via the driver/arch system.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download), parses the configuration, and
/// pairs a [`CpuDriver`](driver::cpu::CpuDriver) with the loaded weights in a
/// shared [`GenericBackend`](generic::GenericBackend).
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (and download on first use) the model artifacts.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let cfg_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT configuration.
    let raw = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the CPU driver and load the weights.
    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
428
/// Load `ModernBERT` on CUDA via the driver/arch system.
///
/// Creates a [`CudaDriver`](driver::cuda::CudaDriver), loads safetensors weights
/// onto the GPU, pre-converts GEMM weights to FP16, builds RoPE caches, and
/// wraps the result in a [`GenericBackend`](generic::GenericBackend).
///
/// # Errors
///
/// Returns an error if no CUDA device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (and download on first use) the model artifacts from the Hub.
    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json
    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // CUDA benefits from larger batches to saturate GPU SMs (128 SMs on RTX 4090).
    // Metal uses 32 (AMX coprocessor-limited). CUDA can handle 128+ easily.
    // NOTE(review): the comment above says CUDA can take 128+, but the
    // max_batch value passed below is 32 — confirm whether this was meant to
    // be raised, or update the comment.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        32,
    )))
}
489
/// Check whether a model repo uses the `ModernBERT` architecture.
///
/// Downloads and inspects `config.json` and compares its `"model_type"` field
/// against `"modernbert"`. Any download or parse failure yields `false`
/// (fail-open for detection).
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    // Inner helper so `?` can short-circuit every fallible step to `None`.
    fn model_type(repo: &str) -> Option<String> {
        let api = hf_hub::api::sync::Api::new().ok()?;
        let config_path = api.model(repo.to_string()).get("config.json").ok()?;
        let text = std::fs::read_to_string(&config_path).ok()?;
        let json: serde_json::Value = serde_json::from_str(&text).ok()?;
        Some(json.get("model_type")?.as_str()?.to_owned())
    }
    model_type(model_repo).as_deref() == Some("modernbert")
}
518
519// ---------------------------------------------------------------------------
520// ClassicBert loader (driver/arch system)
521// ---------------------------------------------------------------------------
522
/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the Metal GPU backend.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download), memory-maps the safetensors
/// weights, fuses Q/K/V into a single tensor per layer, and builds a
/// [`GenericBackend`] pairing a [`MetalDriver`](driver::metal::MetalDriver)
/// with a [`ClassicBertArch`](arch::classic_bert::ClassicBertArch).
///
/// # Errors
///
/// Returns an error if no Metal device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (and download on first use) the model artifacts.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let cfg_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ClassicBert configuration.
    let raw = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the Metal device and load the (memory-mapped) weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
578
579// ---------------------------------------------------------------------------
580// ClassicBert loader (CPU driver/arch system)
581// ---------------------------------------------------------------------------
582
/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the CPU backend
/// via the driver/arch system.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download), reads the safetensors weights,
/// fuses Q/K/V per layer, and builds a [`GenericBackend`] pairing a
/// [`CpuDriver`](driver::cpu::CpuDriver) with a
/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch).
///
/// # Errors
///
/// Returns an error if the model cannot be downloaded or weight loading fails.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (and download on first use) the model artifacts.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let cfg_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ClassicBert configuration.
    let raw = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the CPU driver and load the weights.
    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
638
639#[cfg(test)]
640mod tests {
641    use super::*;
642
643    /// Verify that `EmbedBackend` is object-safe by constructing a trait object type.
644    #[test]
645    fn trait_is_object_safe() {
646        // If this compiles, the trait is object-safe.
647        fn _assert_object_safe(_: &dyn EmbedBackend) {}
648    }
649
650    /// Verify that `Box<dyn EmbedBackend>` is `Send`.
651    #[test]
652    fn trait_object_is_send() {
653        fn assert_send<T: Send>() {}
654        assert_send::<Box<dyn EmbedBackend>>();
655    }
656
657    /// Verify that `Box<dyn EmbedBackend>` is `Sync` (needed for `&dyn` across threads).
658    #[test]
659    fn trait_object_is_sync() {
660        fn assert_sync<T: Sync>() {}
661        assert_sync::<Box<dyn EmbedBackend>>();
662    }
663
664    /// Verify that `Arc<dyn EmbedBackend>` is `Send` (needed for ring-buffer pipeline).
665    #[test]
666    fn arc_trait_object_is_send() {
667        fn assert_send<T: Send>() {}
668        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
669    }
670
671    #[test]
672    fn encoding_construction() {
673        let enc = Encoding {
674            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
675            attention_mask: vec![1, 1, 1, 1, 1, 1],
676            token_type_ids: vec![0, 0, 0, 0, 0, 0],
677        };
678        assert_eq!(enc.input_ids.len(), 6);
679        assert_eq!(enc.attention_mask.len(), 6);
680        assert_eq!(enc.token_type_ids.len(), 6);
681    }
682
683    #[test]
684    fn encoding_clone() {
685        let enc = Encoding {
686            input_ids: vec![101, 102],
687            attention_mask: vec![1, 1],
688            token_type_ids: vec![0, 0],
689        };
690        let cloned = enc.clone();
691        assert_eq!(enc.input_ids, cloned.input_ids);
692    }
693
694    #[test]
695    fn backend_kind_default_is_cpu() {
696        assert_eq!(BackendKind::default(), BackendKind::Cpu);
697    }
698
699    #[test]
700    fn backend_kind_display() {
701        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
702        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
703        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
704    }
705
706    #[cfg(not(feature = "mlx"))]
707    #[test]
708    fn load_backend_mlx_not_compiled() {
709        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
710        assert!(result.is_err());
711    }
712
    /// With the `cpu` feature, detection must yield at least one backend.
    #[cfg(feature = "cpu")]
    #[test]
    fn detect_backends_returns_at_least_one() {
        // NOTE(review): this hits the network (model download on a cold
        // cache) yet is not marked #[ignore], unlike the Metal test below —
        // confirm CI pre-caches this repo or add #[ignore].
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty());
    }
719
    /// Same assertion as above under a narrower cfg.
    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
    #[test]
    fn detect_backends_returns_at_least_one_backend() {
        // NOTE(review): whenever this cfg holds, the sibling test
        // `detect_backends_returns_at_least_one` also runs with the same
        // assertion — this test appears redundant; consider removing one.
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty(), "should detect at least one backend");
    }
726
727    /// Load `ModernBERT` on Metal and embed a short token sequence.
728    ///
729    /// Verifies that the full pipeline (weight loading, forward pass, pooling,
730    /// L2 normalization) produces a 768-dim unit vector.
731    #[cfg(feature = "metal")]
732    #[test]
733    #[ignore = "requires model download (~570MB)"]
734    fn modernbert_loads_and_embeds() {
735        use crate::backend::driver::Driver;
736
737        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
738        assert!(backend.is_gpu(), "Metal backend should be GPU");
739
740        let enc = Encoding {
741            input_ids: vec![1, 100, 200, 300, 2],
742            attention_mask: vec![1; 5],
743            token_type_ids: vec![0; 5],
744        };
745
746        // Stage-by-stage diagnostic using the driver directly
747        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
748        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();
749
750        // Check: can we read back input_ids?
751        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
752        eprintln!("input_ids: {:?}", &ids_host[0][..5]);
753
754        // Check: embedding lookup
755        // Need the tok_embeddings weight — load weights directly
756        let api = hf_hub::api::sync::Api::new().unwrap();
757        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
758        let weights_path = repo.get("model.safetensors").unwrap();
759        let config_path = repo.get("config.json").unwrap();
760        let config_str = std::fs::read_to_string(&config_path).unwrap();
761        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
762        let config =
763            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
764        let (arch, _mmap) = driver
765            .load_modern_bert_weights(&weights_path, &config)
766            .unwrap();
767
768        let hidden = driver
769            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
770            .unwrap();
771        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
772        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
773        eprintln!(
774            "embedding: {nz}/{} nonzero, first 5: {:?}",
775            h[0].len(),
776            &h[0][..5]
777        );
778
779        // Stage-by-stage forward pass bisection
780        let total = 8; // padded seq
781        let hd = 768;
782        let nh = 12;
783        let head_dim = 64;
784
785        // After embedding LN
786        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
787        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
788        driver
789            .layer_norm(
790                &mut ln_out,
791                &emb_clone,
792                &arch.weights.emb_norm_weight,
793                &arch.weights.zero_bias,
794                total,
795                hd,
796                1e-5,
797            )
798            .unwrap();
799        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
800        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
801        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);
802
803        // Layer 0 QKV GEMM
804        let layer0 = &arch.weights.layers[0];
805        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
806        driver
807            .gemm(
808                &ln_out,
809                &layer0.qkv_weight,
810                &mut qkv,
811                total,
812                3 * hd,
813                hd,
814                true,
815            )
816            .unwrap();
817        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
818        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
819        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);
820
821        // QKV split
822        let mut q = driver.alloc_zeros(total * hd).unwrap();
823        let mut k = driver.alloc_zeros(total * hd).unwrap();
824        let mut v = driver.alloc_zeros(total * hd).unwrap();
825        driver
826            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
827            .unwrap();
828        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
829        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
830        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);
831
832        // Attention scores
833        let mut scores = driver.alloc_zeros(1 * nh * 8 * 8).unwrap();
834        driver
835            .gemm_batched(
836                &q,
837                &k,
838                &mut scores,
839                8,
840                8,
841                head_dim,
842                true,
843                8 * head_dim,
844                8 * head_dim,
845                8 * 8,
846                nh,
847            )
848            .unwrap();
849        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
850        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
851        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);
852
853        // Full forward pass
854        use crate::backend::arch::ModelArch;
855        let enc2 = Encoding {
856            input_ids: vec![1, 100, 200, 300, 2],
857            attention_mask: vec![1; 5],
858            token_type_ids: vec![0; 5],
859        };
860
861        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
862        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
863        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
864        eprintln!(
865            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
866            &quick[0][..3]
867        );
868
869        // MRL truncation
870        eprintln!("\n=== ModernBERT MRL Truncation ===");
871        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
872        let full_emb = &full[0];
873        for dims in [64, 128, 256, 384, 512, 768] {
874            let t: Vec<f32> = full_emb[..dims].to_vec();
875            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
876            let f_norm: f32 = full_emb[..dims]
877                .iter()
878                .map(|x| x * x)
879                .sum::<f32>()
880                .sqrt()
881                .max(1e-12);
882            let cos: f32 = t
883                .iter()
884                .zip(&full_emb[..dims])
885                .map(|(a, b)| a * b)
886                .sum::<f32>()
887                / (t_norm * f_norm);
888            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
889        }
890
891        // Throughput benchmark
892        eprintln!("\n=== ModernBERT Throughput ===");
893        // Build 32 encodings of varying length
894        let mut encs = Vec::new();
895        for i in 0..32 {
896            let len = 16 + (i * 4); // 16 to 140 tokens
897            let mut ids = vec![1_i64]; // CLS
898            for j in 1..len - 1 {
899                ids.push(100 + j as i64);
900            }
901            ids.push(2); // SEP
902            encs.push(Encoding {
903                input_ids: ids.clone(),
904                attention_mask: vec![1; ids.len()],
905                token_type_ids: vec![0; ids.len()],
906            });
907        }
908
909        // Warmup
910        let _ = arch.forward(&driver, &encs[..4]);
911
912        // Timed run
913        let t0 = std::time::Instant::now();
914        let result = arch.forward(&driver, &encs).unwrap();
915        let elapsed = t0.elapsed();
916        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
917        eprintln!(
918            "  batch={}, time={:.1}ms, throughput={:.1}/s",
919            encs.len(),
920            elapsed.as_secs_f64() * 1000.0,
921            throughput
922        );
923        assert_eq!(result.len(), 32);
924
925        // Batch=1 timing (critical — CLI query path)
926        let single = vec![encs[0].clone()];
927        let t1 = std::time::Instant::now();
928        let _ = arch.forward(&driver, &single).unwrap();
929        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
930        eprintln!("  batch=1, time={single_ms:.1}ms");
931    }
932
933    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on Metal via the driver/arch system.
934    ///
935    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector,
936    /// compares against the CPU backend for numerical equivalence, and measures
937    /// throughput (target: >=308/s matching monolithic).
938    #[cfg(feature = "metal")]
939    #[test]
940    #[ignore = "requires model download (~33MB)"]
941    fn classic_bert_driver_arch() {
942        use crate::backend::arch::ModelArch;
943
944        let model_repo = "BAAI/bge-small-en-v1.5";
945
946        // Load via driver/arch system
947        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
948        assert!(backend.is_gpu(), "Metal backend should be GPU");
949
950        let enc = Encoding {
951            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
952            attention_mask: vec![1, 1, 1, 1, 1, 1],
953            token_type_ids: vec![0, 0, 0, 0, 0, 0],
954        };
955
956        // Basic forward pass
957        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
958        assert_eq!(result.len(), 1);
959        assert_eq!(result[0].len(), 384);
960
961        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
962        eprintln!(
963            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
964            &result[0][..3]
965        );
966        assert!(
967            (l2 - 1.0).abs() < 0.01,
968            "embedding should be L2-normalized, got L2={l2}"
969        );
970
971        // Compare against CPU backend (reliable reference)
972        #[cfg(feature = "cpu")]
973        {
974            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
975                .expect("CPU load failed");
976            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
977            eprintln!("CPU  first 5: {:?}", &cpu_result[0][..5]);
978            eprintln!("NEW  first 5: {:?}", &result[0][..5]);
979            let cosine: f32 = result[0]
980                .iter()
981                .zip(&cpu_result[0])
982                .map(|(a, b)| a * b)
983                .sum();
984            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
985            assert!(
986                cosine > 0.95,
987                "cosine similarity vs CPU should be >0.95, got {cosine}"
988            );
989        }
990
991        // Throughput benchmark
992        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
993        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
994        let config_path = {
995            let api = hf_hub::api::sync::Api::new().unwrap();
996            let repo = api.model(model_repo.to_string());
997            repo.get("config.json").unwrap()
998        };
999        let weights_path = {
1000            let api = hf_hub::api::sync::Api::new().unwrap();
1001            let repo = api.model(model_repo.to_string());
1002            repo.get("model.safetensors").unwrap()
1003        };
1004        let config_str = std::fs::read_to_string(&config_path).unwrap();
1005        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
1006        let config =
1007            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
1008        let (arch, _mmap) = driver
1009            .load_classic_bert_weights(&weights_path, &config)
1010            .unwrap();
1011
1012        // Build 32 encodings of varying length
1013        let mut encs = Vec::new();
1014        for i in 0..32 {
1015            let len = 16 + (i * 4); // 16 to 140 tokens
1016            let mut ids = vec![101_i64]; // [CLS]
1017            for j in 1..len - 1 {
1018                ids.push(100 + j as i64);
1019            }
1020            ids.push(102); // [SEP]
1021            encs.push(Encoding {
1022                input_ids: ids.clone(),
1023                attention_mask: vec![1; ids.len()],
1024                token_type_ids: vec![0; ids.len()],
1025            });
1026        }
1027
1028        // Warmup
1029        let _ = arch.forward(&driver, &encs[..4]);
1030
1031        // Timed run
1032        let t0 = std::time::Instant::now();
1033        let bench_result = arch.forward(&driver, &encs).unwrap();
1034        let elapsed = t0.elapsed();
1035        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1036        eprintln!(
1037            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1038            encs.len(),
1039            elapsed.as_secs_f64() * 1000.0,
1040            throughput
1041        );
1042        assert_eq!(bench_result.len(), 32);
1043    }
1044
1045    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on CPU via the driver/arch system.
1046    ///
1047    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector
1048    /// and compares against the monolithic CPU backend for numerical equivalence.
1049    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1050    #[test]
1051    #[ignore = "requires model download (~33MB)"]
1052    fn classic_bert_cpu_driver_arch() {
1053        let model_repo = "BAAI/bge-small-en-v1.5";
1054
1055        // Load via new driver/arch system
1056        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1057        assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1058
1059        let enc = Encoding {
1060            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1061            attention_mask: vec![1, 1, 1, 1, 1, 1],
1062            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1063        };
1064
1065        // Basic forward pass
1066        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1067        assert_eq!(result.len(), 1);
1068        assert_eq!(result[0].len(), 384);
1069
1070        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1071        eprintln!(
1072            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1073            &result[0][..5]
1074        );
1075        assert!(
1076            (l2 - 1.0).abs() < 0.01,
1077            "embedding should be L2-normalized, got L2={l2}"
1078        );
1079
1080        // Compare against monolithic CPU backend (reference)
1081        #[cfg(feature = "cpu")]
1082        {
1083            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1084                .expect("monolithic CPU load failed");
1085            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1086            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1087            eprintln!("New  first 5: {:?}", &result[0][..5]);
1088            let cosine: f32 = result[0]
1089                .iter()
1090                .zip(&cpu_result[0])
1091                .map(|(a, b)| a * b)
1092                .sum();
1093            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1094            assert!(
1095                cosine > 0.999,
1096                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1097            );
1098        }
1099    }
1100}