1pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19#[cfg(feature = "cuda")]
20pub mod nvrtc_cubin;
21
/// One tokenized input sequence, ready to be handed to an [`EmbedBackend`].
///
/// All three vectors are parallel: index `i` of each describes token `i`.
#[derive(Clone, Debug)]
pub struct Encoding {
    /// Token ids produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Attention mask aligned with `input_ids` (the tests use `1` for every real token).
    pub attention_mask: Vec<i64>,
    /// Segment / token-type ids aligned with `input_ids` (the tests use all zeros).
    pub token_type_ids: Vec<i64>,
}
36
37pub trait EmbedBackend: Send + Sync {
50 fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
60
61 fn supports_clone(&self) -> bool;
65
66 fn clone_backend(&self) -> Box<dyn EmbedBackend>;
73
74 fn is_gpu(&self) -> bool;
79
80 fn max_tokens(&self) -> usize {
85 512 }
87
88 fn name(&self) -> &'static str {
92 if self.is_gpu() { "GPU" } else { "CPU" }
93 }
94}
95
/// Which inference backend to load.
///
/// `Cpu` is the default. The [`Display`](std::fmt::Display) form is the
/// lowercase backend name (`"cuda"`, `"mlx"`, `"cpu"`, `"metal"`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BackendKind {
    /// NVIDIA CUDA backend (requires the `cuda` feature).
    Cuda,
    /// Apple MLX backend (requires the `mlx` feature).
    Mlx,
    /// CPU backend (requires `cpu` or `cpu-accelerate`).
    #[default]
    Cpu,
    /// Apple Metal backend (requires the `metal` feature).
    Metal,
}

impl std::fmt::Display for BackendKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Cuda => "cuda",
            Self::Mlx => "mlx",
            Self::Cpu => "cpu",
            Self::Metal => "metal",
        };
        f.write_str(label)
    }
}
120
/// Preferred device passed to backends that accept one.
///
/// Defaults to [`DeviceHint::Auto`].
#[derive(Clone, Copy, Debug, Default)]
pub enum DeviceHint {
    /// Let the backend choose.
    #[default]
    Auto,
    /// Request CPU execution.
    Cpu,
    /// Request GPU execution.
    Gpu,
}
135
/// Inference options; this struct currently has no fields.
#[derive(Clone, Debug, Default)]
pub struct InferenceOpts {}
142
143pub fn load_backend(
153 kind: BackendKind,
154 #[cfg_attr(
155 not(any(
156 feature = "cuda",
157 feature = "mlx",
158 feature = "cpu",
159 feature = "cpu-accelerate",
160 feature = "metal"
161 )),
162 expect(unused_variables, reason = "used when backend features are enabled")
163 )]
164 model_repo: &str,
165 #[cfg_attr(
166 not(any(
167 feature = "cuda",
168 feature = "mlx",
169 feature = "cpu",
170 feature = "cpu-accelerate",
171 feature = "metal"
172 )),
173 expect(unused_variables, reason = "used when backend features are enabled")
174 )]
175 device_hint: DeviceHint,
176) -> crate::Result<Box<dyn EmbedBackend>> {
177 match kind {
178 #[cfg(feature = "cuda")]
179 BackendKind::Cuda => {
180 if is_modernbert_model(model_repo) {
181 return load_modernbert_cuda(model_repo);
182 }
183 let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
184 Ok(Box::new(backend))
185 }
186 #[cfg(not(feature = "cuda"))]
187 BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
188 "cuda backend requires building with: cargo build --features cuda"
189 ))),
190 #[cfg(feature = "mlx")]
191 BackendKind::Mlx => {
192 let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
193 Ok(Box::new(backend))
194 }
195 #[cfg(not(feature = "mlx"))]
196 BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
197 "mlx backend requires building with: cargo build --features mlx"
198 ))),
199 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
200 BackendKind::Cpu => {
201 if is_modernbert_model(model_repo) {
202 return load_modernbert_cpu(model_repo);
203 }
204 #[cfg(feature = "cpu")]
205 {
206 let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
207 #[expect(
208 clippy::needless_return,
209 reason = "return needed before cfg(not) fallback"
210 )]
211 return Ok(Box::new(backend));
212 }
213 #[cfg(not(feature = "cpu"))]
214 Err(crate::Error::Other(anyhow::anyhow!(
215 "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
216 )))
217 }
218 #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
219 BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
220 "cpu backend requires building with: cargo build --features cpu"
221 ))),
222 #[cfg(feature = "metal")]
223 BackendKind::Metal => {
224 if is_modernbert_model(model_repo) {
226 return load_modernbert_metal(model_repo);
227 }
228 load_classic_metal(model_repo)
229 }
230 #[cfg(not(feature = "metal"))]
231 BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
232 "metal backend requires building with: cargo build --features metal"
233 ))),
234 }
235}
236
237pub fn detect_backends(
247 #[cfg_attr(
248 not(any(
249 feature = "cuda",
250 feature = "mlx",
251 feature = "cpu",
252 feature = "cpu-accelerate",
253 feature = "metal"
254 )),
255 expect(unused_variables, reason = "used when backend features are enabled")
256 )]
257 model_repo: &str,
258) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
259 #[cfg_attr(
260 not(any(
261 feature = "cuda",
262 feature = "mlx",
263 feature = "cpu",
264 feature = "cpu-accelerate",
265 feature = "metal"
266 )),
267 expect(unused_mut, reason = "mut needed when backend features are enabled")
268 )]
269 let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
270
271 #[cfg(feature = "cuda")]
273 {
274 if is_modernbert_model(model_repo) {
275 if let Ok(b) = load_modernbert_cuda(model_repo) {
276 backends.push(b);
277 }
278 } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
279 backends.push(Box::new(b));
280 }
281 }
282
283 #[cfg(feature = "metal")]
285 {
286 if is_modernbert_model(model_repo) {
288 if let Ok(b) = load_modernbert_metal(model_repo) {
289 backends.push(b);
290 }
291 } else if let Ok(b) = load_classic_metal(model_repo) {
292 backends.push(b);
293 }
294 }
295
296 #[cfg(feature = "mlx")]
298 if backends.is_empty()
299 && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
300 {
301 backends.push(Box::new(b));
302 }
303
304 #[cfg_attr(
309 not(any(feature = "cpu", feature = "cpu-accelerate")),
310 expect(unused_variables, reason = "used when cpu feature is enabled")
311 )]
312 let has_gpu = backends.iter().any(|b| b.is_gpu());
313 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
314 if !has_gpu {
315 if is_modernbert_model(model_repo) {
316 if let Ok(b) = load_modernbert_cpu(model_repo) {
317 backends.push(b);
318 }
319 } else {
320 #[cfg(feature = "cpu")]
321 if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
322 backends.push(Box::new(b));
323 }
324 }
325 }
326
327 if backends.is_empty() {
328 return Err(crate::Error::Other(anyhow::anyhow!(
329 "no embedding backends available"
330 )));
331 }
332
333 Ok(backends)
334}
335
/// Load a ModernBERT checkpoint from the Hugging Face Hub onto the Metal driver.
///
/// Downloads `config.json` and then `model.safetensors` for `model_repo`. It
/// parses the config into a metal `ModernBertConfig` and loads the weights
/// through `MetalDriver`. The result is wrapped in a `GenericBackend` constructed
/// with its GPU flag set to `true`.
///
/// # Errors
///
/// - `Error::Download` if the Hub API cannot be created or a file cannot be fetched.
/// - `Error::Io` if the downloaded config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Any error from config validation, driver creation or weight loading.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Every Hub download maps failures to `Error::Download`.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_path = fetch("config.json")?;
    let weights_path = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
394
/// Load a ModernBERT checkpoint from the Hugging Face Hub onto the CPU driver.
///
/// Downloads `config.json` and then `model.safetensors` for `model_repo`. It
/// parses the config into a cpu `ModernBertConfig` and loads the weights through
/// `CpuDriver`. The result is wrapped via `GenericBackend::new_shared` with its
/// GPU flag set to `false`.
///
/// # Errors
///
/// - `Error::Download` if the Hub API cannot be created or a file cannot be fetched.
/// - `Error::Io` if the downloaded config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Any error from config validation, driver creation or weight loading.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Every Hub download maps failures to `Error::Download`.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_path = fetch("config.json")?;
    let weights_path = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
437
/// Load a ModernBERT checkpoint from the Hugging Face Hub onto the CUDA driver.
///
/// Downloads `config.json` and then `model.safetensors` for `model_repo`. It
/// parses the config into a cuda `ModernBertConfig` and loads the weights through
/// `CudaDriver`. The result is wrapped via `GenericBackend::with_max_batch` with
/// an owned mmap holder, its GPU flag set to `true`, and a batch limit of 32.
///
/// # Errors
///
/// - `Error::Download` if the Hub API cannot be created or a file cannot be fetched.
/// - `Error::Io` if the downloaded config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Any error from config validation, driver creation or weight loading.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Every Hub download maps failures to `Error::Download`.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_path = fetch("config.json")?;
    let weights_path = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // NOTE(review): presumably the maximum number of sequences per forward pass;
    // confirm against `GenericBackend::with_max_batch`.
    let max_batch = 32;
    let backend = GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        max_batch,
    );
    Ok(Box::new(backend))
}
498
/// Report whether `model_repo`'s `config.json` has `"model_type": "modernbert"`.
///
/// The config is fetched through the Hugging Face Hub API. Any failure along the
/// way (API construction, download, file read, JSON parse, missing or non-string
/// `model_type`) yields `false`, so callers fall through to their ClassicBert path.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    /// Returns `None` as soon as any step fails, otherwise whether the
    /// declared model type is `modernbert`.
    fn probe(model_repo: &str) -> Option<bool> {
        let api = hf_hub::api::sync::Api::new().ok()?;
        let config_path = api.model(model_repo.to_string()).get("config.json").ok()?;
        let raw = std::fs::read_to_string(&config_path).ok()?;
        let json = serde_json::from_str::<serde_json::Value>(&raw).ok()?;
        let model_type = json.get("model_type")?.as_str()?;
        Some(model_type == "modernbert")
    }

    probe(model_repo).unwrap_or(false)
}
527
/// Load a classic BERT checkpoint from the Hugging Face Hub onto the Metal driver.
///
/// Downloads `config.json` and then `model.safetensors` for `model_repo`. It
/// parses the config into a metal `ClassicBertConfig` and loads the weights
/// through `MetalDriver`. The result is wrapped in a `GenericBackend` constructed
/// with its GPU flag set to `true`.
///
/// # Errors
///
/// - `Error::Download` if the Hub API cannot be created or a file cannot be fetched.
/// - `Error::Io` if the downloaded config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Any error from config validation, driver creation or weight loading.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Every Hub download maps failures to `Error::Download`.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_path = fetch("config.json")?;
    let weights_path = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
587
/// Load a classic BERT checkpoint from the Hugging Face Hub onto the CPU driver.
///
/// Downloads `config.json` and then `model.safetensors` for `model_repo`. It
/// parses the config into a cpu `ClassicBertConfig` and loads the weights through
/// `CpuDriver`. The result is wrapped via `GenericBackend::new_shared` with its
/// GPU flag set to `false`.
///
/// # Errors
///
/// - `Error::Download` if the Hub API cannot be created or a file cannot be fetched.
/// - `Error::Io` if the downloaded config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Any error from config validation, driver creation or weight loading.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    // Every Hub download maps failures to `Error::Download`.
    let fetch = |file: &str| {
        model
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };
    let config_path = fetch("config.json")?;
    let weights_path = fetch("model.safetensors")?;

    let raw_config = std::fs::read_to_string(&config_path).map_err(|source| crate::Error::Io {
        path: config_path.display().to_string(),
        source,
    })?;
    let config_json = serde_json::from_str::<serde_json::Value>(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
647
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn trait_is_object_safe() {
        fn _assert_object_safe(_: &dyn EmbedBackend) {}
    }

    #[test]
    fn trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Box<dyn EmbedBackend>>();
    }

    #[test]
    fn trait_object_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Box<dyn EmbedBackend>>();
    }

    #[test]
    fn arc_trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
    }

    #[test]
    fn encoding_construction() {
        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };
        assert_eq!(enc.input_ids.len(), 6);
        assert_eq!(enc.attention_mask.len(), 6);
        assert_eq!(enc.token_type_ids.len(), 6);
    }

    #[test]
    fn encoding_clone() {
        let enc = Encoding {
            input_ids: vec![101, 102],
            attention_mask: vec![1, 1],
            token_type_ids: vec![0, 0],
        };
        let cloned = enc.clone();
        assert_eq!(enc.input_ids, cloned.input_ids);
    }

    #[test]
    fn backend_kind_default_is_cpu() {
        assert_eq!(BackendKind::default(), BackendKind::Cpu);
    }

    #[test]
    fn backend_kind_display() {
        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
        assert_eq!(BackendKind::Metal.to_string(), "metal");
    }

    #[cfg(not(feature = "mlx"))]
    #[test]
    fn load_backend_mlx_not_compiled() {
        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
        assert!(result.is_err());
    }

    #[cfg(feature = "cpu")]
    #[test]
    fn detect_backends_returns_at_least_one() {
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty());
    }

    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
    #[test]
    fn detect_backends_returns_at_least_one_backend() {
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty(), "should detect at least one backend");
    }

    /// End-to-end diagnostic for the Metal ModernBERT path: runs the first
    /// stages of layer 0 by hand, then full forward passes, MRL truncation
    /// and a throughput measurement. Output goes to stderr.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~570MB)"]
    #[expect(clippy::too_many_lines, reason = "end-to-end backend diagnostic test")]
    fn modernbert_loads_and_embeds() {
        use crate::backend::arch::ModelArch;
        use crate::backend::driver::Driver;

        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let inputs = driver.prepare_batch(std::slice::from_ref(&enc), 8).unwrap();

        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
        eprintln!("input_ids: {:?}", &ids_host[0][..5]);

        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_path = repo.get("config.json").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_modern_bert_weights(&weights_path, &config)
            .unwrap();

        let hidden = driver
            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
            .unwrap();
        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "embedding: {nz}/{} nonzero, first 5: {:?}",
            h[0].len(),
            &h[0][..5]
        );

        let total = 8;
        let hd = 768;
        let nh = 12;
        let head_dim = 64;

        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
        driver
            .layer_norm(
                &mut ln_out,
                &emb_clone,
                &arch.weights.emb_norm_weight,
                &arch.weights.zero_bias,
                total,
                hd,
                1e-5,
            )
            .unwrap();
        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);

        let layer0 = &arch.weights.layers[0];
        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
        driver
            .gemm(
                &ln_out,
                &layer0.qkv_weight,
                &mut qkv,
                total,
                3 * hd,
                hd,
                true,
            )
            .unwrap();
        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);

        let mut q = driver.alloc_zeros(total * hd).unwrap();
        let mut k = driver.alloc_zeros(total * hd).unwrap();
        let mut v = driver.alloc_zeros(total * hd).unwrap();
        driver
            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
            .unwrap();
        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);

        let mut scores = driver.alloc_zeros(nh * 8 * 8).unwrap();
        driver
            .gemm_batched(
                &q,
                &k,
                &mut scores,
                8,
                8,
                head_dim,
                true,
                8 * head_dim,
                8 * head_dim,
                8 * 8,
                nh,
            )
            .unwrap();
        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);

        let enc2 = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let quick = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
            &quick[0][..3]
        );

        eprintln!("\n=== ModernBERT MRL Truncation ===");
        let full = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
        let full_emb = &full[0];
        // The truncation is compared against the *whole* embedding. The previous
        // version compared `full_emb[..dims]` with itself and always printed 1.0.
        let full_norm: f32 = full_emb
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt()
            .max(1e-12);
        for dims in [64, 128, 256, 384, 512, 768] {
            let truncated = &full_emb[..dims];
            let t_norm: f32 = truncated
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt()
                .max(1e-12);
            // Cosine between the zero-padded truncation and the full vector
            // (zip stops after `dims` elements).
            let cos: f32 = truncated
                .iter()
                .zip(full_emb)
                .map(|(a, b)| a * b)
                .sum::<f32>()
                / (t_norm * full_norm);
            eprintln!(" dims={dims:>3}: cosine={cos:.6}");
        }

        eprintln!("\n=== ModernBERT Throughput ===");
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            let mut ids = vec![1_i64];
            for j in 1..len - 1 {
                ids.push(100 + i64::from(j));
            }
            ids.push(2);
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warm-up pass; result intentionally ignored.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(result.len(), 32);

        let single = vec![encs[0].clone()];
        let t1 = std::time::Instant::now();
        let _ = arch.forward(&driver, &single).unwrap();
        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
        eprintln!(" batch=1, time={single_ms:.1}ms");
    }

    /// Metal ClassicBert through the driver/arch path. Checks output shape and L2
    /// normalisation and, when `cpu` is enabled, agreement with the CPU backend.
    /// It then reports batched throughput.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_driver_arch() {
        use crate::backend::arch::ModelArch;

        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
            &result[0][..3]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        #[cfg(feature = "cpu")]
        {
            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
                .expect("CPU load failed");
            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
            eprintln!("CPU first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("NEW first 5: {:?}", &result[0][..5]);
            // Both embeddings are asserted/assumed L2-normalised, so the dot
            // product is the cosine similarity.
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
            assert!(
                cosine > 0.95,
                "cosine similarity vs CPU should be >0.95, got {cosine}"
            );
        }

        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let (config_path, weights_path) = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            (
                repo.get("config.json").unwrap(),
                repo.get("model.safetensors").unwrap(),
            )
        };
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_classic_bert_weights(&weights_path, &config)
            .unwrap();

        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            let mut ids = vec![101_i64];
            for j in 1..len - 1 {
                ids.push(100 + i64::from(j));
            }
            ids.push(102);
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warm-up pass; result intentionally ignored.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let bench_result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(bench_result.len(), 32);
    }

    /// CPU ClassicBert through the driver/arch path. Checks output shape and L2
    /// normalisation and, when `cpu` is enabled, near-identity with the
    /// monolithic CPU backend.
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_cpu_driver_arch() {
        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
        assert!(!backend.is_gpu(), "CPU backend should not be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
            &result[0][..5]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        #[cfg(feature = "cpu")]
        {
            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
                .expect("monolithic CPU load failed");
            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("New first 5: {:?}", &result[0][..5]);
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
            assert!(
                cosine > 0.999,
                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
            );
        }
    }
}