Skip to main content

ripvec_core/backend/
mod.rs

1//! Embedding backend abstraction layer.
2//!
3//! Defines the [`EmbedBackend`] trait that all embedding backends (CPU, CUDA,
4//! Metal, MLX) implement, plus the [`Encoding`] input type and [`BackendKind`]
5//! discriminant. Use [`load_backend`] to construct a backend by kind.
6
7pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19
/// Pre-tokenized encoding ready for inference.
///
/// Token IDs, attention mask, and token type IDs must all have the same
/// length. The token count must not exceed the target backend's
/// [`EmbedBackend::max_tokens`] limit (512 by default, the `ClassicBert`
/// position-embedding limit; `ModernBERT` backends take the limit from the
/// model config). The tokenizer truncates to that limit before the encoding
/// reaches the backend.
#[derive(Debug, Clone)]
pub struct Encoding {
    /// Token IDs produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask (1 for real tokens, 0 for padding).
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 for single-sequence models).
    pub token_type_ids: Vec<i64>,
}
34
35/// Trait for embedding backends.
36///
37/// Implementations must be [`Send`] so they can be moved across threads (e.g.
38/// into a ring-buffer pipeline). The trait is object-safe — callers use
39/// `&dyn EmbedBackend` or `Box<dyn EmbedBackend>`.
40///
41/// # GPU vs CPU scheduling
42///
43/// - **CPU backends** (`is_gpu() == false`): cloned per rayon thread via
44///   [`clone_backend`](EmbedBackend::clone_backend).
45/// - **GPU backends** (`is_gpu() == true`): use a ring-buffer pipeline with
46///   `RING_SIZE = 4` for bounded memory.
47pub trait EmbedBackend: Send + Sync {
48    /// Embed a batch of pre-tokenized inputs, returning L2-normalized vectors.
49    ///
50    /// Each inner `Vec<f32>` is the embedding for the corresponding
51    /// [`Encoding`]. Errors **must** propagate — never silently return
52    /// defaults.
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if tensor construction or the forward pass fails.
57    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
58
59    /// Whether this backend supports cheap cloning for per-thread instances.
60    ///
61    /// CPU backends return `true`; GPU backends typically return `false`.
62    fn supports_clone(&self) -> bool;
63
64    /// Create a cheap clone of this backend for per-thread use in rayon.
65    ///
66    /// # Panics
67    ///
68    /// May panic if [`supports_clone`](EmbedBackend::supports_clone) returns
69    /// `false`. Callers must check `supports_clone()` first.
70    fn clone_backend(&self) -> Box<dyn EmbedBackend>;
71
72    /// Whether this backend runs on a GPU.
73    ///
74    /// GPU backends use a ring-buffer pipelined scheduler (`RING_SIZE = 4`)
75    /// for bounded memory usage.
76    fn is_gpu(&self) -> bool;
77
78    /// Maximum token count this model supports (position embedding limit).
79    ///
80    /// `ClassicBert`: 512. `ModernBERT`: up to model config. Tokens beyond this
81    /// are truncated during tokenization.
82    fn max_tokens(&self) -> usize {
83        512 // default for classic BERT models
84    }
85}
86
/// Embedding backend implementations that can be requested.
///
/// [`Default`] yields [`BackendKind::Cpu`]; the [`Display`](std::fmt::Display)
/// form is the lowercase backend name.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BackendKind {
    /// NVIDIA GPUs through `cudarc`, using cuBLAS plus custom kernels.
    Cuda,
    /// Apple Silicon via MLX (macOS only).
    Mlx,
    /// Portable CPU path built on `ndarray` and the system BLAS.
    #[default]
    Cpu,
    /// Apple Silicon GPU driven directly through Metal (macOS only).
    Metal,
}

impl std::fmt::Display for BackendKind {
    /// Writes the lowercase backend name: `cuda`, `mlx`, `cpu` or `metal`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Cuda => "cuda",
            Self::Mlx => "mlx",
            Self::Cpu => "cpu",
            Self::Metal => "metal",
        };
        f.write_str(label)
    }
}
111
/// Device hint passed to [`load_backend`].
///
/// Backends map this to their native device type. Not all backends support
/// all devices — unsupported combinations return an error.
///
/// Derives `PartialEq`/`Eq` (like [`BackendKind`]) so callers can compare
/// hints directly instead of pattern-matching.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum DeviceHint {
    /// Automatically select the best available device.
    #[default]
    Auto,
    /// Force CPU inference.
    Cpu,
    /// Force GPU inference (Metal on macOS, CUDA on Linux/Windows).
    Gpu,
}
126
/// Load-time inference tuning knobs forwarded to model loading.
///
/// Holds no fields yet. It reserves a place for future options (for example
/// quantization or distillation) that must be chosen when a model is loaded.
#[derive(Clone, Debug, Default)]
pub struct InferenceOpts {}
133
134/// Construct an embedding backend of the given kind.
135///
136/// Downloads model weights on first use via `hf-hub`. The `device_hint`
137/// is advisory — backends that don't support GPU fall back to CPU.
138///
139/// # Errors
140///
141/// Returns an error if the requested backend was not compiled in (missing
142/// feature flag) or if model loading fails.
143pub fn load_backend(
144    kind: BackendKind,
145    #[cfg_attr(
146        not(any(
147            feature = "cuda",
148            feature = "mlx",
149            feature = "cpu",
150            feature = "cpu-accelerate",
151            feature = "metal"
152        )),
153        expect(unused_variables, reason = "used when backend features are enabled")
154    )]
155    model_repo: &str,
156    #[cfg_attr(
157        not(any(
158            feature = "cuda",
159            feature = "mlx",
160            feature = "cpu",
161            feature = "cpu-accelerate",
162            feature = "metal"
163        )),
164        expect(unused_variables, reason = "used when backend features are enabled")
165    )]
166    device_hint: DeviceHint,
167) -> crate::Result<Box<dyn EmbedBackend>> {
168    match kind {
169        #[cfg(feature = "cuda")]
170        BackendKind::Cuda => {
171            if is_modernbert_model(model_repo) {
172                return load_modernbert_cuda(model_repo);
173            }
174            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
175            Ok(Box::new(backend))
176        }
177        #[cfg(not(feature = "cuda"))]
178        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
179            "cuda backend requires building with: cargo build --features cuda"
180        ))),
181        #[cfg(feature = "mlx")]
182        BackendKind::Mlx => {
183            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
184            Ok(Box::new(backend))
185        }
186        #[cfg(not(feature = "mlx"))]
187        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
188            "mlx backend requires building with: cargo build --features mlx"
189        ))),
190        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
191        BackendKind::Cpu => {
192            if is_modernbert_model(model_repo) {
193                return load_modernbert_cpu(model_repo);
194            }
195            #[cfg(feature = "cpu")]
196            {
197                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
198                #[expect(
199                    clippy::needless_return,
200                    reason = "return needed before cfg(not) fallback"
201                )]
202                return Ok(Box::new(backend));
203            }
204            #[cfg(not(feature = "cpu"))]
205            Err(crate::Error::Other(anyhow::anyhow!(
206                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
207            )))
208        }
209        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
210        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
211            "cpu backend requires building with: cargo build --features cpu"
212        ))),
213        #[cfg(feature = "metal")]
214        BackendKind::Metal => {
215            // All models route through the driver/arch system.
216            if is_modernbert_model(model_repo) {
217                return load_modernbert_metal(model_repo);
218            }
219            load_classic_metal(model_repo)
220        }
221        #[cfg(not(feature = "metal"))]
222        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
223            "metal backend requires building with: cargo build --features metal"
224        ))),
225    }
226}
227
228/// Detect all available backends and load them.
229///
230/// Probes for GPU backends (CUDA, MLX) first, then falls back to CPU.
231/// Returns backends in priority order — the first entry is the primary
232/// (used for query embedding in interactive mode).
233///
234/// # Errors
235///
236/// Returns an error if no backends can be loaded (not even CPU).
237pub fn detect_backends(
238    #[cfg_attr(
239        not(any(
240            feature = "cuda",
241            feature = "mlx",
242            feature = "cpu",
243            feature = "cpu-accelerate",
244            feature = "metal"
245        )),
246        expect(unused_variables, reason = "used when backend features are enabled")
247    )]
248    model_repo: &str,
249) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
250    #[cfg_attr(
251        not(any(
252            feature = "cuda",
253            feature = "mlx",
254            feature = "cpu",
255            feature = "cpu-accelerate",
256            feature = "metal"
257        )),
258        expect(unused_mut, reason = "mut needed when backend features are enabled")
259    )]
260    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
261
262    // Try CUDA (NVIDIA GPU)
263    #[cfg(feature = "cuda")]
264    {
265        if is_modernbert_model(model_repo) {
266            if let Ok(b) = load_modernbert_cuda(model_repo) {
267                backends.push(b);
268            }
269        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
270            backends.push(Box::new(b));
271        }
272    }
273
274    // Try Metal (Apple Silicon GPU, preferred over MLX)
275    #[cfg(feature = "metal")]
276    {
277        // Route models through the driver/arch system by architecture.
278        if is_modernbert_model(model_repo) {
279            if let Ok(b) = load_modernbert_metal(model_repo) {
280                backends.push(b);
281            }
282        } else if let Ok(b) = load_classic_metal(model_repo) {
283            backends.push(b);
284        }
285    }
286
287    // Try MLX (Apple Silicon GPU, fallback if Metal unavailable)
288    #[cfg(feature = "mlx")]
289    if backends.is_empty()
290        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
291    {
292        backends.push(Box::new(b));
293    }
294
295    // Add CPU as fallback only when no GPU backend was loaded.
296    // On Apple Silicon, running CPU + MLX concurrently is slower than
297    // MLX alone because they share the same physical cores and memory.
298    // On discrete GPU systems (CUDA), CPU would be a useful helper.
299    #[cfg_attr(
300        not(any(feature = "cpu", feature = "cpu-accelerate")),
301        expect(unused_variables, reason = "used when cpu feature is enabled")
302    )]
303    let has_gpu = backends.iter().any(|b| b.is_gpu());
304    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
305    if !has_gpu {
306        if is_modernbert_model(model_repo) {
307            if let Ok(b) = load_modernbert_cpu(model_repo) {
308                backends.push(b);
309            }
310        } else {
311            #[cfg(feature = "cpu")]
312            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
313                backends.push(Box::new(b));
314            }
315        }
316    }
317
318    if backends.is_empty() {
319        return Err(crate::Error::Other(anyhow::anyhow!(
320            "no embedding backends available"
321        )));
322    }
323
324    Ok(backends)
325}
326
327// ---------------------------------------------------------------------------
328// ModernBERT loader (driver/arch system)
329// ---------------------------------------------------------------------------
330
331/// Load a `ModernBERT` model on the Metal GPU backend.
332///
333/// Downloads the model from Hugging Face Hub (cached after first download),
334/// memory-maps the safetensors weights, and builds a [`GenericBackend`]
335/// pairing a [`MetalDriver`](driver::metal::MetalDriver) with a
336/// [`ModernBertArch`](arch::modern_bert::ModernBertArch).
337///
338/// # Errors
339///
340/// Returns an error if no Metal device is available, the model cannot be
341/// downloaded, or weight loading fails.
342#[cfg(feature = "metal")]
343pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
344    use driver::metal::{MetalDriver, ModernBertConfig};
345    use generic::GenericBackend;
346    use hf_hub::api::sync::Api;
347
348    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
349    let repo = api.model(model_repo.to_string());
350
351    let config_path = repo
352        .get("config.json")
353        .map_err(|e| crate::Error::Download(e.to_string()))?;
354    let weights_path = repo
355        .get("model.safetensors")
356        .map_err(|e| crate::Error::Download(e.to_string()))?;
357
358    // Parse config.json
359    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
360        path: config_path.display().to_string(),
361        source: e,
362    })?;
363    let config_json: serde_json::Value = serde_json::from_str(&config_str)
364        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
365    let config = ModernBertConfig::from_json(&config_json)?;
366    let max_tokens = config.max_position_embeddings;
367
368    let driver = MetalDriver::new()?;
369    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;
370
371    tracing::info!(
372        model_repo,
373        hidden = config.hidden_size,
374        layers = config.num_hidden_layers,
375        heads = config.num_attention_heads,
376        intermediate = config.intermediate_size,
377        max_tokens,
378        "ModernBERT loaded on Metal (driver/arch)"
379    );
380
381    Ok(Box::new(GenericBackend::new(
382        driver, arch, max_tokens, true, mmap,
383    )))
384}
385
386/// Load `ModernBERT` on CPU via the driver/arch system.
387#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
388pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
389    use driver::cpu::{CpuDriver, ModernBertConfig};
390    use generic::GenericBackend;
391    use hf_hub::api::sync::Api;
392
393    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
394    let repo = api.model(model_repo.to_string());
395
396    let config_path = repo
397        .get("config.json")
398        .map_err(|e| crate::Error::Download(e.to_string()))?;
399    let weights_path = repo
400        .get("model.safetensors")
401        .map_err(|e| crate::Error::Download(e.to_string()))?;
402
403    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
404        path: config_path.display().to_string(),
405        source: e,
406    })?;
407    let config_json: serde_json::Value = serde_json::from_str(&config_str)
408        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
409    let config = ModernBertConfig::from_json(&config_json)?;
410    let max_tokens = config.max_position_embeddings;
411
412    let driver = CpuDriver::new()?;
413    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;
414
415    tracing::info!(
416        model_repo,
417        hidden = config.hidden_size,
418        layers = config.num_hidden_layers,
419        heads = config.num_attention_heads,
420        max_tokens,
421        "ModernBERT loaded on CPU (driver/arch)"
422    );
423
424    Ok(Box::new(GenericBackend::new(
425        driver, arch, max_tokens, false, mmap,
426    )))
427}
428
429/// Load `ModernBERT` on CUDA via the driver/arch system.
430///
431/// Creates a [`CudaDriver`](driver::cuda::CudaDriver), loads safetensors weights
432/// onto the GPU, pre-converts GEMM weights to FP16, builds RoPE caches, and
433/// wraps the result in a [`GenericBackend`](generic::GenericBackend).
434///
435/// # Errors
436///
437/// Returns an error if no CUDA device is available, the model cannot be
438/// downloaded, or weight loading fails.
439#[cfg(feature = "cuda")]
440pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
441    use driver::cuda::{CudaDriver, ModernBertConfig};
442    use generic::GenericBackend;
443    use hf_hub::api::sync::Api;
444
445    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
446    let repo = api.model(model_repo.to_string());
447
448    let config_path = repo
449        .get("config.json")
450        .map_err(|e| crate::Error::Download(e.to_string()))?;
451    let weights_path = repo
452        .get("model.safetensors")
453        .map_err(|e| crate::Error::Download(e.to_string()))?;
454
455    // Parse config.json
456    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
457        path: config_path.display().to_string(),
458        source: e,
459    })?;
460    let config_json: serde_json::Value = serde_json::from_str(&config_str)
461        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
462    let config = ModernBertConfig::from_json(&config_json)?;
463    let max_tokens = config.max_position_embeddings;
464
465    let driver = CudaDriver::new()?;
466    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;
467
468    tracing::info!(
469        model_repo,
470        hidden = config.hidden_size,
471        layers = config.num_hidden_layers,
472        heads = config.num_attention_heads,
473        intermediate = config.intermediate_size,
474        max_tokens,
475        "ModernBERT loaded on CUDA (driver/arch)"
476    );
477
478    // CUDA benefits from larger batches to saturate GPU SMs (128 SMs on RTX 4090).
479    // Metal uses 32 (AMX coprocessor-limited). CUDA can handle 128+ easily.
480    Ok(Box::new(GenericBackend::with_max_batch(
481        driver, arch, max_tokens, true, mmap, 32,
482    )))
483}
484
485/// Check whether a model repo uses the `ModernBERT` architecture.
486///
487/// Downloads and inspects `config.json` to check for `"model_type": "modernbert"`.
488/// Returns `false` on any download or parse error (fail-open for detection).
489#[cfg(any(
490    feature = "cuda",
491    feature = "metal",
492    feature = "cpu",
493    feature = "cpu-accelerate"
494))]
495fn is_modernbert_model(model_repo: &str) -> bool {
496    let Ok(api) = hf_hub::api::sync::Api::new() else {
497        return false;
498    };
499    let repo = api.model(model_repo.to_string());
500    let Ok(config_path) = repo.get("config.json") else {
501        return false;
502    };
503    let Ok(config_str) = std::fs::read_to_string(&config_path) else {
504        return false;
505    };
506    let Ok(json) = serde_json::from_str::<serde_json::Value>(&config_str) else {
507        return false;
508    };
509    json.get("model_type")
510        .and_then(serde_json::Value::as_str)
511        .is_some_and(|t| t == "modernbert")
512}
513
514// ---------------------------------------------------------------------------
515// ClassicBert loader (driver/arch system)
516// ---------------------------------------------------------------------------
517
518/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the Metal GPU backend.
519///
520/// Downloads the model from Hugging Face Hub (cached after first download),
521/// memory-maps the safetensors weights, fuses Q/K/V into a single tensor per
522/// layer, and builds a [`GenericBackend`] pairing a
523/// [`MetalDriver`](driver::metal::MetalDriver) with a
524/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch).
525///
526/// # Errors
527///
528/// Returns an error if no Metal device is available, the model cannot be
529/// downloaded, or weight loading fails.
530#[cfg(feature = "metal")]
531pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
532    use driver::metal::{ClassicBertConfig, MetalDriver};
533    use generic::GenericBackend;
534    use hf_hub::api::sync::Api;
535
536    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
537    let repo = api.model(model_repo.to_string());
538
539    let config_path = repo
540        .get("config.json")
541        .map_err(|e| crate::Error::Download(e.to_string()))?;
542    let weights_path = repo
543        .get("model.safetensors")
544        .map_err(|e| crate::Error::Download(e.to_string()))?;
545
546    // Parse config.json
547    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
548        path: config_path.display().to_string(),
549        source: e,
550    })?;
551    let config_json: serde_json::Value = serde_json::from_str(&config_str)
552        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
553    let config = ClassicBertConfig::from_json(&config_json)?;
554    let max_tokens = config.max_position_embeddings;
555
556    let driver = MetalDriver::new()?;
557    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;
558
559    tracing::info!(
560        model_repo,
561        hidden = config.hidden_size,
562        layers = config.num_hidden_layers,
563        heads = config.num_attention_heads,
564        intermediate = config.intermediate_size,
565        max_tokens,
566        "ClassicBert loaded on Metal (driver/arch)"
567    );
568
569    Ok(Box::new(GenericBackend::new(
570        driver, arch, max_tokens, true, mmap,
571    )))
572}
573
574// ---------------------------------------------------------------------------
575// ClassicBert loader (CPU driver/arch system)
576// ---------------------------------------------------------------------------
577
578/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the CPU backend
579/// via the driver/arch system.
580///
581/// Downloads the model from Hugging Face Hub (cached after first download),
582/// reads safetensors weights into `Vec<f32>` tensors, fuses Q/K/V per layer,
583/// and builds a [`GenericBackend`] pairing a
584/// [`CpuDriver`](driver::cpu::CpuDriver) with a
585/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch).
586///
587/// # Errors
588///
589/// Returns an error if the model cannot be downloaded or weight loading fails.
590#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
591pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
592    use driver::cpu::{ClassicBertConfig, CpuDriver};
593    use generic::GenericBackend;
594    use hf_hub::api::sync::Api;
595
596    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
597    let repo = api.model(model_repo.to_string());
598
599    let config_path = repo
600        .get("config.json")
601        .map_err(|e| crate::Error::Download(e.to_string()))?;
602    let weights_path = repo
603        .get("model.safetensors")
604        .map_err(|e| crate::Error::Download(e.to_string()))?;
605
606    // Parse config.json
607    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
608        path: config_path.display().to_string(),
609        source: e,
610    })?;
611    let config_json: serde_json::Value = serde_json::from_str(&config_str)
612        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
613    let config = ClassicBertConfig::from_json(&config_json)?;
614    let max_tokens = config.max_position_embeddings;
615
616    let driver = CpuDriver::new()?;
617    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;
618
619    tracing::info!(
620        model_repo,
621        hidden = config.hidden_size,
622        layers = config.num_hidden_layers,
623        heads = config.num_attention_heads,
624        intermediate = config.intermediate_size,
625        max_tokens,
626        "ClassicBert loaded on CPU (driver/arch)"
627    );
628
629    Ok(Box::new(GenericBackend::new(
630        driver, arch, max_tokens, false, mmap,
631    )))
632}
633
634#[cfg(test)]
635mod tests {
636    use super::*;
637
638    /// Verify that `EmbedBackend` is object-safe by constructing a trait object type.
639    #[test]
640    fn trait_is_object_safe() {
641        // If this compiles, the trait is object-safe.
642        fn _assert_object_safe(_: &dyn EmbedBackend) {}
643    }
644
645    /// Verify that `Box<dyn EmbedBackend>` is `Send`.
646    #[test]
647    fn trait_object_is_send() {
648        fn assert_send<T: Send>() {}
649        assert_send::<Box<dyn EmbedBackend>>();
650    }
651
652    /// Verify that `Box<dyn EmbedBackend>` is `Sync` (needed for `&dyn` across threads).
653    #[test]
654    fn trait_object_is_sync() {
655        fn assert_sync<T: Sync>() {}
656        assert_sync::<Box<dyn EmbedBackend>>();
657    }
658
659    /// Verify that `Arc<dyn EmbedBackend>` is `Send` (needed for ring-buffer pipeline).
660    #[test]
661    fn arc_trait_object_is_send() {
662        fn assert_send<T: Send>() {}
663        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
664    }
665
666    #[test]
667    fn encoding_construction() {
668        let enc = Encoding {
669            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
670            attention_mask: vec![1, 1, 1, 1, 1, 1],
671            token_type_ids: vec![0, 0, 0, 0, 0, 0],
672        };
673        assert_eq!(enc.input_ids.len(), 6);
674        assert_eq!(enc.attention_mask.len(), 6);
675        assert_eq!(enc.token_type_ids.len(), 6);
676    }
677
678    #[test]
679    fn encoding_clone() {
680        let enc = Encoding {
681            input_ids: vec![101, 102],
682            attention_mask: vec![1, 1],
683            token_type_ids: vec![0, 0],
684        };
685        let cloned = enc.clone();
686        assert_eq!(enc.input_ids, cloned.input_ids);
687    }
688
689    #[test]
690    fn backend_kind_default_is_cpu() {
691        assert_eq!(BackendKind::default(), BackendKind::Cpu);
692    }
693
694    #[test]
695    fn backend_kind_display() {
696        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
697        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
698        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
699    }
700
701    #[cfg(not(feature = "mlx"))]
702    #[test]
703    fn load_backend_mlx_not_compiled() {
704        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
705        assert!(result.is_err());
706    }
707
708    #[cfg(feature = "cpu")]
709    #[test]
710    fn detect_backends_returns_at_least_one() {
711        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
712        assert!(!backends.is_empty());
713    }
714
715    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
716    #[test]
717    fn detect_backends_returns_at_least_one_backend() {
718        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
719        assert!(!backends.is_empty(), "should detect at least one backend");
720    }
721
722    /// Load `ModernBERT` on Metal and embed a short token sequence.
723    ///
724    /// Verifies that the full pipeline (weight loading, forward pass, pooling,
725    /// L2 normalization) produces a 768-dim unit vector.
726    #[cfg(feature = "metal")]
727    #[test]
728    #[ignore = "requires model download (~570MB)"]
729    fn modernbert_loads_and_embeds() {
730        use crate::backend::driver::Driver;
731
732        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
733        assert!(backend.is_gpu(), "Metal backend should be GPU");
734
735        let enc = Encoding {
736            input_ids: vec![1, 100, 200, 300, 2],
737            attention_mask: vec![1; 5],
738            token_type_ids: vec![0; 5],
739        };
740
741        // Stage-by-stage diagnostic using the driver directly
742        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
743        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();
744
745        // Check: can we read back input_ids?
746        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
747        eprintln!("input_ids: {:?}", &ids_host[0][..5]);
748
749        // Check: embedding lookup
750        // Need the tok_embeddings weight — load weights directly
751        let api = hf_hub::api::sync::Api::new().unwrap();
752        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
753        let weights_path = repo.get("model.safetensors").unwrap();
754        let config_path = repo.get("config.json").unwrap();
755        let config_str = std::fs::read_to_string(&config_path).unwrap();
756        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
757        let config =
758            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
759        let (arch, _mmap) = driver
760            .load_modern_bert_weights(&weights_path, &config)
761            .unwrap();
762
763        let hidden = driver
764            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
765            .unwrap();
766        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
767        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
768        eprintln!(
769            "embedding: {nz}/{} nonzero, first 5: {:?}",
770            h[0].len(),
771            &h[0][..5]
772        );
773
774        // Stage-by-stage forward pass bisection
775        let total = 8; // padded seq
776        let hd = 768;
777        let nh = 12;
778        let head_dim = 64;
779
780        // After embedding LN
781        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
782        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
783        driver
784            .layer_norm(
785                &mut ln_out,
786                &emb_clone,
787                &arch.weights.emb_norm_weight,
788                &arch.weights.zero_bias,
789                total,
790                hd,
791                1e-5,
792            )
793            .unwrap();
794        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
795        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
796        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);
797
798        // Layer 0 QKV GEMM
799        let layer0 = &arch.weights.layers[0];
800        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
801        driver
802            .gemm(
803                &ln_out,
804                &layer0.qkv_weight,
805                &mut qkv,
806                total,
807                3 * hd,
808                hd,
809                true,
810            )
811            .unwrap();
812        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
813        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
814        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);
815
816        // QKV split
817        let mut q = driver.alloc_zeros(total * hd).unwrap();
818        let mut k = driver.alloc_zeros(total * hd).unwrap();
819        let mut v = driver.alloc_zeros(total * hd).unwrap();
820        driver
821            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
822            .unwrap();
823        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
824        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
825        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);
826
827        // Attention scores
828        let mut scores = driver.alloc_zeros(1 * nh * 8 * 8).unwrap();
829        driver
830            .gemm_batched(
831                &q,
832                &k,
833                &mut scores,
834                8,
835                8,
836                head_dim,
837                true,
838                8 * head_dim,
839                8 * head_dim,
840                8 * 8,
841                nh,
842            )
843            .unwrap();
844        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
845        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
846        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);
847
848        // Full forward pass
849        use crate::backend::arch::ModelArch;
850        let enc2 = Encoding {
851            input_ids: vec![1, 100, 200, 300, 2],
852            attention_mask: vec![1; 5],
853            token_type_ids: vec![0; 5],
854        };
855
856        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
857        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
858        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
859        eprintln!(
860            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
861            &quick[0][..3]
862        );
863
864        // MRL truncation
865        eprintln!("\n=== ModernBERT MRL Truncation ===");
866        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
867        let full_emb = &full[0];
868        for dims in [64, 128, 256, 384, 512, 768] {
869            let t: Vec<f32> = full_emb[..dims].to_vec();
870            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
871            let f_norm: f32 = full_emb[..dims]
872                .iter()
873                .map(|x| x * x)
874                .sum::<f32>()
875                .sqrt()
876                .max(1e-12);
877            let cos: f32 = t
878                .iter()
879                .zip(&full_emb[..dims])
880                .map(|(a, b)| a * b)
881                .sum::<f32>()
882                / (t_norm * f_norm);
883            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
884        }
885
886        // Throughput benchmark
887        eprintln!("\n=== ModernBERT Throughput ===");
888        // Build 32 encodings of varying length
889        let mut encs = Vec::new();
890        for i in 0..32 {
891            let len = 16 + (i * 4); // 16 to 140 tokens
892            let mut ids = vec![1_i64]; // CLS
893            for j in 1..len - 1 {
894                ids.push(100 + j as i64);
895            }
896            ids.push(2); // SEP
897            encs.push(Encoding {
898                input_ids: ids.clone(),
899                attention_mask: vec![1; ids.len()],
900                token_type_ids: vec![0; ids.len()],
901            });
902        }
903
904        // Warmup
905        let _ = arch.forward(&driver, &encs[..4]);
906
907        // Timed run
908        let t0 = std::time::Instant::now();
909        let result = arch.forward(&driver, &encs).unwrap();
910        let elapsed = t0.elapsed();
911        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
912        eprintln!(
913            "  batch={}, time={:.1}ms, throughput={:.1}/s",
914            encs.len(),
915            elapsed.as_secs_f64() * 1000.0,
916            throughput
917        );
918        assert_eq!(result.len(), 32);
919
920        // Batch=1 timing (critical — CLI query path)
921        let single = vec![encs[0].clone()];
922        let t1 = std::time::Instant::now();
923        let _ = arch.forward(&driver, &single).unwrap();
924        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
925        eprintln!("  batch=1, time={single_ms:.1}ms");
926    }
927
928    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on Metal via the driver/arch system.
929    ///
930    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector,
931    /// compares against the CPU backend for numerical equivalence, and measures
932    /// throughput (target: >=308/s matching monolithic).
933    #[cfg(feature = "metal")]
934    #[test]
935    #[ignore = "requires model download (~33MB)"]
936    fn classic_bert_driver_arch() {
937        use crate::backend::arch::ModelArch;
938
939        let model_repo = "BAAI/bge-small-en-v1.5";
940
941        // Load via driver/arch system
942        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
943        assert!(backend.is_gpu(), "Metal backend should be GPU");
944
945        let enc = Encoding {
946            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
947            attention_mask: vec![1, 1, 1, 1, 1, 1],
948            token_type_ids: vec![0, 0, 0, 0, 0, 0],
949        };
950
951        // Basic forward pass
952        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
953        assert_eq!(result.len(), 1);
954        assert_eq!(result[0].len(), 384);
955
956        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
957        eprintln!(
958            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
959            &result[0][..3]
960        );
961        assert!(
962            (l2 - 1.0).abs() < 0.01,
963            "embedding should be L2-normalized, got L2={l2}"
964        );
965
966        // Compare against CPU backend (reliable reference)
967        #[cfg(feature = "cpu")]
968        {
969            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
970                .expect("CPU load failed");
971            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
972            eprintln!("CPU  first 5: {:?}", &cpu_result[0][..5]);
973            eprintln!("NEW  first 5: {:?}", &result[0][..5]);
974            let cosine: f32 = result[0]
975                .iter()
976                .zip(&cpu_result[0])
977                .map(|(a, b)| a * b)
978                .sum();
979            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
980            assert!(
981                cosine > 0.95,
982                "cosine similarity vs CPU should be >0.95, got {cosine}"
983            );
984        }
985
986        // Throughput benchmark
987        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
988        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
989        let config_path = {
990            let api = hf_hub::api::sync::Api::new().unwrap();
991            let repo = api.model(model_repo.to_string());
992            repo.get("config.json").unwrap()
993        };
994        let weights_path = {
995            let api = hf_hub::api::sync::Api::new().unwrap();
996            let repo = api.model(model_repo.to_string());
997            repo.get("model.safetensors").unwrap()
998        };
999        let config_str = std::fs::read_to_string(&config_path).unwrap();
1000        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
1001        let config =
1002            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
1003        let (arch, _mmap) = driver
1004            .load_classic_bert_weights(&weights_path, &config)
1005            .unwrap();
1006
1007        // Build 32 encodings of varying length
1008        let mut encs = Vec::new();
1009        for i in 0..32 {
1010            let len = 16 + (i * 4); // 16 to 140 tokens
1011            let mut ids = vec![101_i64]; // [CLS]
1012            for j in 1..len - 1 {
1013                ids.push(100 + j as i64);
1014            }
1015            ids.push(102); // [SEP]
1016            encs.push(Encoding {
1017                input_ids: ids.clone(),
1018                attention_mask: vec![1; ids.len()],
1019                token_type_ids: vec![0; ids.len()],
1020            });
1021        }
1022
1023        // Warmup
1024        let _ = arch.forward(&driver, &encs[..4]);
1025
1026        // Timed run
1027        let t0 = std::time::Instant::now();
1028        let bench_result = arch.forward(&driver, &encs).unwrap();
1029        let elapsed = t0.elapsed();
1030        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1031        eprintln!(
1032            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1033            encs.len(),
1034            elapsed.as_secs_f64() * 1000.0,
1035            throughput
1036        );
1037        assert_eq!(bench_result.len(), 32);
1038    }
1039
1040    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on CPU via the driver/arch system.
1041    ///
1042    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector
1043    /// and compares against the monolithic CPU backend for numerical equivalence.
1044    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1045    #[test]
1046    #[ignore = "requires model download (~33MB)"]
1047    fn classic_bert_cpu_driver_arch() {
1048        let model_repo = "BAAI/bge-small-en-v1.5";
1049
1050        // Load via new driver/arch system
1051        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1052        assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1053
1054        let enc = Encoding {
1055            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1056            attention_mask: vec![1, 1, 1, 1, 1, 1],
1057            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1058        };
1059
1060        // Basic forward pass
1061        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1062        assert_eq!(result.len(), 1);
1063        assert_eq!(result[0].len(), 384);
1064
1065        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1066        eprintln!(
1067            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1068            &result[0][..5]
1069        );
1070        assert!(
1071            (l2 - 1.0).abs() < 0.01,
1072            "embedding should be L2-normalized, got L2={l2}"
1073        );
1074
1075        // Compare against monolithic CPU backend (reference)
1076        #[cfg(feature = "cpu")]
1077        {
1078            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1079                .expect("monolithic CPU load failed");
1080            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1081            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1082            eprintln!("New  first 5: {:?}", &result[0][..5]);
1083            let cosine: f32 = result[0]
1084                .iter()
1085                .zip(&cpu_result[0])
1086                .map(|(a, b)| a * b)
1087                .sum();
1088            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1089            assert!(
1090                cosine > 0.999,
1091                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1092            );
1093        }
1094    }
1095}