Skip to main content

ripvec_core/backend/
mod.rs

//! Embedding backend abstraction layer.
//!
//! Defines the [`EmbedBackend`] trait that all embedding backends (CPU, CUDA,
//! MLX, Metal) implement, plus the [`Encoding`] input type and [`BackendKind`]
//! discriminant. Use [`load_backend`] to construct a backend by kind, or
//! [`detect_backends`] to probe for every backend available at runtime.
6
7pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19
/// Pre-tokenized encoding ready for inference.
///
/// `input_ids`, `attention_mask`, and `token_type_ids` must all have the same
/// length. Sequences are truncated by the tokenizer before they reach a
/// backend. The limit is model-dependent (see [`EmbedBackend::max_tokens`]):
/// 512 for `ClassicBert`, and the config's position-embedding limit for
/// `ModernBERT`.
#[derive(Debug, Clone)]
pub struct Encoding {
    /// Token IDs produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask (1 for real tokens, 0 for padding).
    pub attention_mask: Vec<i64>,
    /// Token type IDs (0 for single-sequence models).
    pub token_type_ids: Vec<i64>,
}
34
35/// Trait for embedding backends.
36///
37/// Implementations must be [`Send`] so they can be moved across threads (e.g.
38/// into a ring-buffer pipeline). The trait is object-safe — callers use
39/// `&dyn EmbedBackend` or `Box<dyn EmbedBackend>`.
40///
41/// # GPU vs CPU scheduling
42///
43/// - **CPU backends** (`is_gpu() == false`): cloned per rayon thread via
44///   [`clone_backend`](EmbedBackend::clone_backend).
45/// - **GPU backends** (`is_gpu() == true`): use a ring-buffer pipeline with
46///   `RING_SIZE = 4` for bounded memory.
47pub trait EmbedBackend: Send + Sync {
48    /// Embed a batch of pre-tokenized inputs, returning L2-normalized vectors.
49    ///
50    /// Each inner `Vec<f32>` is the embedding for the corresponding
51    /// [`Encoding`]. Errors **must** propagate — never silently return
52    /// defaults.
53    ///
54    /// # Errors
55    ///
56    /// Returns an error if tensor construction or the forward pass fails.
57    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
58
59    /// Whether this backend supports cheap cloning for per-thread instances.
60    ///
61    /// CPU backends return `true`; GPU backends typically return `false`.
62    fn supports_clone(&self) -> bool;
63
64    /// Create a cheap clone of this backend for per-thread use in rayon.
65    ///
66    /// # Panics
67    ///
68    /// May panic if [`supports_clone`](EmbedBackend::supports_clone) returns
69    /// `false`. Callers must check `supports_clone()` first.
70    fn clone_backend(&self) -> Box<dyn EmbedBackend>;
71
72    /// Whether this backend runs on a GPU.
73    ///
74    /// GPU backends use a ring-buffer pipelined scheduler (`RING_SIZE = 4`)
75    /// for bounded memory usage.
76    fn is_gpu(&self) -> bool;
77
78    /// Maximum token count this model supports (position embedding limit).
79    ///
80    /// `ClassicBert`: 512. `ModernBERT`: up to model config. Tokens beyond this
81    /// are truncated during tokenization.
82    fn max_tokens(&self) -> usize {
83        512 // default for classic BERT models
84    }
85}
86
/// Which embedding backend implementation to use.
///
/// Defaults to [`BackendKind::Cpu`]. The [`Display`](std::fmt::Display) form
/// is the lowercase backend name: `"cuda"`, `"mlx"`, `"cpu"`, or `"metal"`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    /// NVIDIA GPUs through cudarc, cuBLAS, and custom kernels.
    Cuda,
    /// Apple Silicon through MLX (macOS only).
    Mlx,
    /// Portable CPU path (ndarray + system BLAS).
    #[default]
    Cpu,
    /// Apple Silicon through direct Metal GPU kernels (macOS only).
    Metal,
}

impl std::fmt::Display for BackendKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            BackendKind::Cuda => "cuda",
            BackendKind::Mlx => "mlx",
            BackendKind::Cpu => "cpu",
            BackendKind::Metal => "metal",
        };
        f.write_str(name)
    }
}
111
/// Advisory device preference handed to [`load_backend`].
///
/// Each backend translates the hint into its own native device type. Not
/// every backend supports every device. How an unsupported hint is handled
/// (error or fallback) is up to the individual backend.
#[derive(Debug, Default, Clone, Copy)]
pub enum DeviceHint {
    /// Let the backend pick the best device it can find (the default).
    #[default]
    Auto,
    /// Run inference on the CPU.
    Cpu,
    /// Run inference on the GPU (Metal on macOS, CUDA on Linux/Windows).
    Gpu,
}
126
/// Load-time inference tuning knobs forwarded to model loading.
///
/// Has no fields yet. It is reserved for future load-time optimizations such
/// as quantization or distillation.
#[derive(Clone, Debug, Default)]
pub struct InferenceOpts {}
133
134/// Construct an embedding backend of the given kind.
135///
136/// Downloads model weights on first use via `hf-hub`. The `device_hint`
137/// is advisory — backends that don't support GPU fall back to CPU.
138///
139/// # Errors
140///
141/// Returns an error if the requested backend was not compiled in (missing
142/// feature flag) or if model loading fails.
143pub fn load_backend(
144    kind: BackendKind,
145    #[cfg_attr(
146        not(any(
147            feature = "cuda",
148            feature = "mlx",
149            feature = "cpu",
150            feature = "cpu-accelerate",
151            feature = "metal"
152        )),
153        expect(unused_variables, reason = "used when backend features are enabled")
154    )]
155    model_repo: &str,
156    #[cfg_attr(
157        not(any(
158            feature = "cuda",
159            feature = "mlx",
160            feature = "cpu",
161            feature = "cpu-accelerate",
162            feature = "metal"
163        )),
164        expect(unused_variables, reason = "used when backend features are enabled")
165    )]
166    device_hint: DeviceHint,
167) -> crate::Result<Box<dyn EmbedBackend>> {
168    match kind {
169        #[cfg(feature = "cuda")]
170        BackendKind::Cuda => {
171            if is_modernbert_model(model_repo) {
172                return load_modernbert_cuda(model_repo, max_layers);
173            }
174            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
175            Ok(Box::new(backend))
176        }
177        #[cfg(not(feature = "cuda"))]
178        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
179            "cuda backend requires building with: cargo build --features cuda"
180        ))),
181        #[cfg(feature = "mlx")]
182        BackendKind::Mlx => {
183            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
184            Ok(Box::new(backend))
185        }
186        #[cfg(not(feature = "mlx"))]
187        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
188            "mlx backend requires building with: cargo build --features mlx"
189        ))),
190        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
191        BackendKind::Cpu => {
192            if is_modernbert_model(model_repo) {
193                return load_modernbert_cpu(model_repo);
194            }
195            #[cfg(feature = "cpu")]
196            {
197                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
198                return Ok(Box::new(backend));
199            }
200            #[cfg(not(feature = "cpu"))]
201            Err(crate::Error::Other(anyhow::anyhow!(
202                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
203            )))
204        }
205        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
206        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
207            "cpu backend requires building with: cargo build --features cpu"
208        ))),
209        #[cfg(feature = "metal")]
210        BackendKind::Metal => {
211            // All models route through the driver/arch system.
212            if is_modernbert_model(model_repo) {
213                return load_modernbert_metal(model_repo);
214            }
215            load_classic_metal(model_repo)
216        }
217        #[cfg(not(feature = "metal"))]
218        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
219            "metal backend requires building with: cargo build --features metal"
220        ))),
221    }
222}
223
224/// Detect all available backends and load them.
225///
226/// Probes for GPU backends (CUDA, MLX) first, then falls back to CPU.
227/// Returns backends in priority order — the first entry is the primary
228/// (used for query embedding in interactive mode).
229///
230/// # Errors
231///
232/// Returns an error if no backends can be loaded (not even CPU).
233pub fn detect_backends(
234    #[cfg_attr(
235        not(any(
236            feature = "cuda",
237            feature = "mlx",
238            feature = "cpu",
239            feature = "cpu-accelerate",
240            feature = "metal"
241        )),
242        expect(unused_variables, reason = "used when backend features are enabled")
243    )]
244    model_repo: &str,
245) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
246    #[cfg_attr(
247        not(any(
248            feature = "cuda",
249            feature = "mlx",
250            feature = "cpu",
251            feature = "cpu-accelerate",
252            feature = "metal"
253        )),
254        expect(unused_mut, reason = "mut needed when backend features are enabled")
255    )]
256    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
257
258    // Try CUDA (NVIDIA GPU)
259    #[cfg(feature = "cuda")]
260    {
261        if is_modernbert_model(model_repo) {
262            if let Ok(b) = load_modernbert_cuda(model_repo, max_layers) {
263                backends.push(b);
264            }
265        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
266            backends.push(Box::new(b));
267        }
268    }
269
270    // Try Metal (Apple Silicon GPU, preferred over MLX)
271    #[cfg(feature = "metal")]
272    {
273        // Route models through the driver/arch system by architecture.
274        if is_modernbert_model(model_repo) {
275            if let Ok(b) = load_modernbert_metal(model_repo) {
276                backends.push(b);
277            }
278        } else if let Ok(b) = load_classic_metal(model_repo) {
279            backends.push(b);
280        }
281    }
282
283    // Try MLX (Apple Silicon GPU, fallback if Metal unavailable)
284    #[cfg(feature = "mlx")]
285    if backends.is_empty()
286        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
287    {
288        backends.push(Box::new(b));
289    }
290
291    // Add CPU as fallback only when no GPU backend was loaded.
292    // On Apple Silicon, running CPU + MLX concurrently is slower than
293    // MLX alone because they share the same physical cores and memory.
294    // On discrete GPU systems (CUDA), CPU would be a useful helper.
295    #[cfg_attr(
296        not(any(feature = "cpu", feature = "cpu-accelerate")),
297        expect(unused_variables, reason = "used when cpu feature is enabled")
298    )]
299    let has_gpu = backends.iter().any(|b| b.is_gpu());
300    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
301    if !has_gpu {
302        if is_modernbert_model(model_repo) {
303            if let Ok(b) = load_modernbert_cpu(model_repo) {
304                backends.push(b);
305            }
306        } else {
307            #[cfg(feature = "cpu")]
308            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
309                backends.push(Box::new(b));
310            }
311        }
312    }
313
314    if backends.is_empty() {
315        return Err(crate::Error::Other(anyhow::anyhow!(
316            "no embedding backends available"
317        )));
318    }
319
320    Ok(backends)
321}
322
323// ---------------------------------------------------------------------------
324// ModernBERT loader (driver/arch system)
325// ---------------------------------------------------------------------------
326
/// Load a `ModernBERT` model onto the Metal GPU backend.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (served from the local cache after the first download) and parses the
/// config. It then loads the weights through a
/// [`MetalDriver`](driver::metal::MetalDriver) into a
/// [`ModernBertArch`](arch::modern_bert::ModernBertArch) and wraps the pair in
/// a [`GenericBackend`](generic::GenericBackend). The backend's token limit is
/// the config's `max_position_embeddings`.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the Hub API cannot be created;
/// - either file cannot be downloaded or read;
/// - the config cannot be parsed;
/// - no Metal device is available;
/// - the weights cannot be loaded.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Both files share one error mapping; the config is fetched first.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_file)
        .map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })
        .and_then(|text| {
            serde_json::from_str::<serde_json::Value>(&text)
                .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))
        })?;
    let config = ModernBertConfig::from_json(&raw_config)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
381
/// Load `ModernBERT` on CPU via the driver/arch system.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download) and parses the config. It then
/// loads the weights through a [`CpuDriver`](driver::cpu::CpuDriver) and wraps
/// the result in a [`GenericBackend`](generic::GenericBackend). The backend's
/// token limit is the config's `max_position_embeddings`.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the Hub API cannot be created;
/// - either file cannot be downloaded or read;
/// - the config cannot be parsed;
/// - driver creation fails;
/// - the weights cannot be loaded.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Both files share one error mapping; the config is fetched first.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_file)
        .map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })
        .and_then(|text| {
            serde_json::from_str::<serde_json::Value>(&text)
                .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))
        })?;
    let config = ModernBertConfig::from_json(&raw_config)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
424
/// Load `ModernBERT` on CUDA via the driver/arch system.
///
/// Creates a [`CudaDriver`](driver::cuda::CudaDriver), loads safetensors weights
/// onto the GPU, pre-converts GEMM weights to FP16, builds RoPE caches, and
/// wraps the result in a [`GenericBackend`](generic::GenericBackend) with a
/// maximum batch size of 32.
///
/// `max_layers` is stored on the loaded architecture (`arch.max_layers`), and
/// `None` leaves it unset. NOTE(review): this presumably caps how many
/// transformer layers the forward pass runs. Confirm in `ModernBertArch`.
///
/// # Errors
///
/// Returns an error if no CUDA device is available, the model cannot be
/// downloaded, or weight loading fails.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(
    model_repo: &str,
    max_layers: Option<usize>,
) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json
    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (mut arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;
    arch.max_layers = max_layers;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // Maximum batch size for the CUDA backend is 32.
    // NOTE(review): large GPUs (e.g. an RTX 4090 with 128 SMs) may be better
    // saturated by batches of 128+. Benchmark memory use and throughput before
    // raising this value.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver, arch, max_tokens, true, mmap, 32,
    )))
}
484
/// Report whether `model_repo` is a `ModernBERT` checkpoint.
///
/// Fetches the repo's `config.json` through `hf-hub` and checks whether its
/// `"model_type"` string equals `"modernbert"`. Detection fails open: any
/// API, download, read, or JSON error yields `false`.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    /// Fetch and parse the repo's `config.json`, or `None` on any failure.
    fn config_json(model_repo: &str) -> Option<serde_json::Value> {
        let api = hf_hub::api::sync::Api::new().ok()?;
        let path = api.model(model_repo.to_string()).get("config.json").ok()?;
        let text = std::fs::read_to_string(path).ok()?;
        serde_json::from_str(&text).ok()
    }

    config_json(model_repo)
        .as_ref()
        .and_then(|json| json.get("model_type"))
        .and_then(serde_json::Value::as_str)
        == Some("modernbert")
}
513
514// ---------------------------------------------------------------------------
515// ClassicBert loader (driver/arch system)
516// ---------------------------------------------------------------------------
517
/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) onto the Metal
/// GPU backend.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download) and parses the config. It then
/// loads the weights through a [`MetalDriver`](driver::metal::MetalDriver)
/// into a [`ClassicBertArch`](arch::classic_bert::ClassicBertArch), which
/// fuses Q/K/V into a single tensor per layer. The pair is wrapped in a
/// [`GenericBackend`](generic::GenericBackend) whose token limit is the
/// config's `max_position_embeddings`.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the Hub API cannot be created;
/// - either file cannot be downloaded or read;
/// - the config cannot be parsed;
/// - no Metal device is available;
/// - the weights cannot be loaded.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Both files share one error mapping; the config is fetched first.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_file)
        .map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })
        .and_then(|text| {
            serde_json::from_str::<serde_json::Value>(&text)
                .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))
        })?;
    let config = ClassicBertConfig::from_json(&raw_config)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
573
574// ---------------------------------------------------------------------------
575// ClassicBert loader (CPU driver/arch system)
576// ---------------------------------------------------------------------------
577
/// Load a `ClassicBert` model (e.g. `BAAI/bge-small-en-v1.5`) on the CPU
/// backend via the driver/arch system.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face Hub
/// (cached locally after the first download) and parses the config. It then
/// loads the weights through a [`CpuDriver`](driver::cpu::CpuDriver) into a
/// [`ClassicBertArch`](arch::classic_bert::ClassicBertArch), reading them into
/// `Vec<f32>` tensors and fusing Q/K/V per layer. The pair is wrapped in a
/// [`GenericBackend`](generic::GenericBackend) whose token limit is the
/// config's `max_position_embeddings`.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the Hub API cannot be created;
/// - either file cannot be downloaded or read;
/// - the config cannot be parsed;
/// - driver creation fails;
/// - the weights cannot be loaded.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Both files share one error mapping; the config is fetched first.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_file)
        .map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })
        .and_then(|text| {
            serde_json::from_str::<serde_json::Value>(&text)
                .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))
        })?;
    let config = ClassicBertConfig::from_json(&raw_config)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
633
634#[cfg(test)]
635mod tests {
636    use super::*;
637
638    /// Verify that `EmbedBackend` is object-safe by constructing a trait object type.
639    #[test]
640    fn trait_is_object_safe() {
641        // If this compiles, the trait is object-safe.
642        fn _assert_object_safe(_: &dyn EmbedBackend) {}
643    }
644
645    /// Verify that `Box<dyn EmbedBackend>` is `Send`.
646    #[test]
647    fn trait_object_is_send() {
648        fn assert_send<T: Send>() {}
649        assert_send::<Box<dyn EmbedBackend>>();
650    }
651
652    /// Verify that `Box<dyn EmbedBackend>` is `Sync` (needed for `&dyn` across threads).
653    #[test]
654    fn trait_object_is_sync() {
655        fn assert_sync<T: Sync>() {}
656        assert_sync::<Box<dyn EmbedBackend>>();
657    }
658
659    /// Verify that `Arc<dyn EmbedBackend>` is `Send` (needed for ring-buffer pipeline).
660    #[test]
661    fn arc_trait_object_is_send() {
662        fn assert_send<T: Send>() {}
663        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
664    }
665
666    #[test]
667    fn encoding_construction() {
668        let enc = Encoding {
669            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
670            attention_mask: vec![1, 1, 1, 1, 1, 1],
671            token_type_ids: vec![0, 0, 0, 0, 0, 0],
672        };
673        assert_eq!(enc.input_ids.len(), 6);
674        assert_eq!(enc.attention_mask.len(), 6);
675        assert_eq!(enc.token_type_ids.len(), 6);
676    }
677
678    #[test]
679    fn encoding_clone() {
680        let enc = Encoding {
681            input_ids: vec![101, 102],
682            attention_mask: vec![1, 1],
683            token_type_ids: vec![0, 0],
684        };
685        let cloned = enc.clone();
686        assert_eq!(enc.input_ids, cloned.input_ids);
687    }
688
689    #[test]
690    fn backend_kind_default_is_cpu() {
691        assert_eq!(BackendKind::default(), BackendKind::Cpu);
692    }
693
694    #[test]
695    fn backend_kind_display() {
696        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
697        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
698        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
699    }
700
701    #[cfg(not(feature = "mlx"))]
702    #[test]
703    fn load_backend_mlx_not_compiled() {
704        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
705        assert!(result.is_err());
706    }
707
708    #[cfg(feature = "cpu")]
709    #[test]
710    fn detect_backends_returns_at_least_one() {
711        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
712        assert!(!backends.is_empty());
713    }
714
715    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
716    #[test]
717    fn detect_backends_returns_at_least_one_backend() {
718        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
719        assert!(!backends.is_empty(), "should detect at least one backend");
720    }
721
722    /// Load `ModernBERT` on Metal and embed a short token sequence.
723    ///
724    /// Verifies that the full pipeline (weight loading, forward pass, pooling,
725    /// L2 normalization) produces a 768-dim unit vector.
726    #[cfg(feature = "metal")]
727    #[test]
728    #[ignore = "requires model download (~570MB)"]
729    fn modernbert_loads_and_embeds() {
730        use crate::backend::driver::Driver;
731
732        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
733        assert!(backend.is_gpu(), "Metal backend should be GPU");
734
735        let enc = Encoding {
736            input_ids: vec![1, 100, 200, 300, 2],
737            attention_mask: vec![1; 5],
738            token_type_ids: vec![0; 5],
739        };
740
741        // Stage-by-stage diagnostic using the driver directly
742        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
743        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();
744
745        // Check: can we read back input_ids?
746        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
747        eprintln!("input_ids: {:?}", &ids_host[0][..5]);
748
749        // Check: embedding lookup
750        // Need the tok_embeddings weight — load weights directly
751        let api = hf_hub::api::sync::Api::new().unwrap();
752        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
753        let weights_path = repo.get("model.safetensors").unwrap();
754        let config_path = repo.get("config.json").unwrap();
755        let config_str = std::fs::read_to_string(&config_path).unwrap();
756        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
757        let config =
758            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
759        let (arch, _mmap) = driver
760            .load_modern_bert_weights(&weights_path, &config)
761            .unwrap();
762
763        let hidden = driver
764            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
765            .unwrap();
766        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
767        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
768        eprintln!(
769            "embedding: {nz}/{} nonzero, first 5: {:?}",
770            h[0].len(),
771            &h[0][..5]
772        );
773
774        // Stage-by-stage forward pass bisection
775        let total = 8; // padded seq
776        let hd = 768;
777        let nh = 12;
778        let head_dim = 64;
779
780        // After embedding LN
781        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
782        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
783        driver
784            .layer_norm(
785                &mut ln_out,
786                &emb_clone,
787                &arch.weights.emb_norm_weight,
788                &arch.weights.zero_bias,
789                total,
790                hd,
791                1e-5,
792            )
793            .unwrap();
794        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
795        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
796        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);
797
798        // Layer 0 QKV GEMM
799        let layer0 = &arch.weights.layers[0];
800        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
801        driver
802            .gemm(
803                &ln_out,
804                &layer0.qkv_weight,
805                &mut qkv,
806                total,
807                3 * hd,
808                hd,
809                true,
810            )
811            .unwrap();
812        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
813        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
814        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);
815
816        // QKV split
817        let mut q = driver.alloc_zeros(total * hd).unwrap();
818        let mut k = driver.alloc_zeros(total * hd).unwrap();
819        let mut v = driver.alloc_zeros(total * hd).unwrap();
820        driver
821            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
822            .unwrap();
823        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
824        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
825        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);
826
827        // Attention scores
828        let mut scores = driver.alloc_zeros(1 * nh * 8 * 8).unwrap();
829        driver
830            .gemm_batched(
831                &q,
832                &k,
833                &mut scores,
834                8,
835                8,
836                head_dim,
837                true,
838                8 * head_dim,
839                8 * head_dim,
840                8 * 8,
841                nh,
842            )
843            .unwrap();
844        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
845        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
846        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);
847
848        // Full forward pass
849        use crate::backend::arch::ModelArch;
850        let enc2 = Encoding {
851            input_ids: vec![1, 100, 200, 300, 2],
852            attention_mask: vec![1; 5],
853            token_type_ids: vec![0; 5],
854        };
855
856        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
857        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
858        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
859        eprintln!(
860            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
861            &quick[0][..3]
862        );
863
864        // MRL truncation
865        eprintln!("\n=== ModernBERT MRL Truncation ===");
866        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
867        let full_emb = &full[0];
868        for dims in [64, 128, 256, 384, 512, 768] {
869            let t: Vec<f32> = full_emb[..dims].to_vec();
870            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
871            let f_norm: f32 = full_emb[..dims]
872                .iter()
873                .map(|x| x * x)
874                .sum::<f32>()
875                .sqrt()
876                .max(1e-12);
877            let cos: f32 = t
878                .iter()
879                .zip(&full_emb[..dims])
880                .map(|(a, b)| a * b)
881                .sum::<f32>()
882                / (t_norm * f_norm);
883            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
884        }
885
886        // Throughput benchmark
887        eprintln!("\n=== ModernBERT Throughput ===");
888        // Build 32 encodings of varying length
889        let mut encs = Vec::new();
890        for i in 0..32 {
891            let len = 16 + (i * 4); // 16 to 140 tokens
892            let mut ids = vec![1_i64]; // CLS
893            for j in 1..len - 1 {
894                ids.push(100 + j as i64);
895            }
896            ids.push(2); // SEP
897            encs.push(Encoding {
898                input_ids: ids.clone(),
899                attention_mask: vec![1; ids.len()],
900                token_type_ids: vec![0; ids.len()],
901            });
902        }
903
904        // Warmup
905        let _ = arch.forward(&driver, &encs[..4]);
906
907        // Timed run
908        let t0 = std::time::Instant::now();
909        let result = arch.forward(&driver, &encs).unwrap();
910        let elapsed = t0.elapsed();
911        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
912        eprintln!(
913            "  batch={}, time={:.1}ms, throughput={:.1}/s",
914            encs.len(),
915            elapsed.as_secs_f64() * 1000.0,
916            throughput
917        );
918        assert_eq!(result.len(), 32);
919
920        // Batch=1 timing (critical — CLI query path)
921        let single = vec![encs[0].clone()];
922        let t1 = std::time::Instant::now();
923        let _ = arch.forward(&driver, &single).unwrap();
924        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
925        eprintln!("  batch=1, time={single_ms:.1}ms");
926    }
927
928    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on Metal via the driver/arch system.
929    ///
930    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector,
931    /// compares against the CPU backend for numerical equivalence, and measures
932    /// throughput (target: >=308/s matching monolithic).
933    #[cfg(feature = "metal")]
934    #[test]
935    #[ignore = "requires model download (~33MB)"]
936    fn classic_bert_driver_arch() {
937        use crate::backend::arch::ModelArch;
938
939        let model_repo = "BAAI/bge-small-en-v1.5";
940
941        // Load via driver/arch system
942        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
943        assert!(backend.is_gpu(), "Metal backend should be GPU");
944
945        let enc = Encoding {
946            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
947            attention_mask: vec![1, 1, 1, 1, 1, 1],
948            token_type_ids: vec![0, 0, 0, 0, 0, 0],
949        };
950
951        // Basic forward pass
952        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
953        assert_eq!(result.len(), 1);
954        assert_eq!(result[0].len(), 384);
955
956        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
957        eprintln!(
958            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
959            &result[0][..3]
960        );
961        assert!(
962            (l2 - 1.0).abs() < 0.01,
963            "embedding should be L2-normalized, got L2={l2}"
964        );
965
966        // Compare against CPU backend (reliable reference)
967        #[cfg(feature = "cpu")]
968        {
969            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
970                .expect("CPU load failed");
971            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
972            eprintln!("CPU  first 5: {:?}", &cpu_result[0][..5]);
973            eprintln!("NEW  first 5: {:?}", &result[0][..5]);
974            let cosine: f32 = result[0]
975                .iter()
976                .zip(&cpu_result[0])
977                .map(|(a, b)| a * b)
978                .sum();
979            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
980            assert!(
981                cosine > 0.95,
982                "cosine similarity vs CPU should be >0.95, got {cosine}"
983            );
984        }
985
986        // Throughput benchmark
987        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
988        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
989        let config_path = {
990            let api = hf_hub::api::sync::Api::new().unwrap();
991            let repo = api.model(model_repo.to_string());
992            repo.get("config.json").unwrap()
993        };
994        let weights_path = {
995            let api = hf_hub::api::sync::Api::new().unwrap();
996            let repo = api.model(model_repo.to_string());
997            repo.get("model.safetensors").unwrap()
998        };
999        let config_str = std::fs::read_to_string(&config_path).unwrap();
1000        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
1001        let config =
1002            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
1003        let (arch, _mmap) = driver
1004            .load_classic_bert_weights(&weights_path, &config)
1005            .unwrap();
1006
1007        // Build 32 encodings of varying length
1008        let mut encs = Vec::new();
1009        for i in 0..32 {
1010            let len = 16 + (i * 4); // 16 to 140 tokens
1011            let mut ids = vec![101_i64]; // [CLS]
1012            for j in 1..len - 1 {
1013                ids.push(100 + j as i64);
1014            }
1015            ids.push(102); // [SEP]
1016            encs.push(Encoding {
1017                input_ids: ids.clone(),
1018                attention_mask: vec![1; ids.len()],
1019                token_type_ids: vec![0; ids.len()],
1020            });
1021        }
1022
1023        // Warmup
1024        let _ = arch.forward(&driver, &encs[..4]);
1025
1026        // Timed run
1027        let t0 = std::time::Instant::now();
1028        let bench_result = arch.forward(&driver, &encs).unwrap();
1029        let elapsed = t0.elapsed();
1030        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
1031        eprintln!(
1032            "  batch={}, time={:.1}ms, throughput={:.1}/s",
1033            encs.len(),
1034            elapsed.as_secs_f64() * 1000.0,
1035            throughput
1036        );
1037        assert_eq!(bench_result.len(), 32);
1038    }
1039
1040    /// Load `ClassicBert` (`BAAI/bge-small-en-v1.5`) on CPU via the driver/arch system.
1041    ///
1042    /// Verifies that the full pipeline produces a 384-dim L2-normalized vector
1043    /// and compares against the monolithic CPU backend for numerical equivalence.
1044    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1045    #[test]
1046    #[ignore = "requires model download (~33MB)"]
1047    fn classic_bert_cpu_driver_arch() {
1048        let model_repo = "BAAI/bge-small-en-v1.5";
1049
1050        // Load via new driver/arch system
1051        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1052        assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1053
1054        let enc = Encoding {
1055            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1056            attention_mask: vec![1, 1, 1, 1, 1, 1],
1057            token_type_ids: vec![0, 0, 0, 0, 0, 0],
1058        };
1059
1060        // Basic forward pass
1061        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1062        assert_eq!(result.len(), 1);
1063        assert_eq!(result[0].len(), 384);
1064
1065        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1066        eprintln!(
1067            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1068            &result[0][..5]
1069        );
1070        assert!(
1071            (l2 - 1.0).abs() < 0.01,
1072            "embedding should be L2-normalized, got L2={l2}"
1073        );
1074
1075        // Compare against monolithic CPU backend (reference)
1076        #[cfg(feature = "cpu")]
1077        {
1078            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1079                .expect("monolithic CPU load failed");
1080            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1081            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1082            eprintln!("New  first 5: {:?}", &result[0][..5]);
1083            let cosine: f32 = result[0]
1084                .iter()
1085                .zip(&cpu_result[0])
1086                .map(|(a, b)| a * b)
1087                .sum();
1088            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1089            assert!(
1090                cosine > 0.999,
1091                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1092            );
1093        }
1094    }
1095}