1pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19
/// A tokenized input sequence ready for `EmbedBackend::embed_batch`.
///
/// All three vectors are parallel: one entry per token.
#[derive(Debug, Clone)]
pub struct Encoding {
    // Token ids produced by the tokenizer.
    pub input_ids: Vec<i64>,
    // 1 for real tokens, 0 for padding (see tests below for examples).
    pub attention_mask: Vec<i64>,
    // Segment/token-type ids; all zeros for single-sequence inputs.
    pub token_type_ids: Vec<i64>,
}
34
/// Common interface implemented by every embedding backend
/// (CPU, CUDA, Metal, MLX). Object-safe, `Send + Sync`, so boxed
/// backends can be shared across threads (see tests below).
pub trait EmbedBackend: Send + Sync {
    /// Embeds a batch of encodings, returning one embedding vector per input.
    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;

    /// Whether this implementation supports `clone_backend`.
    fn supports_clone(&self) -> bool;

    /// Returns a boxed clone of this backend.
    ///
    /// NOTE(review): behavior when `supports_clone()` is false is not
    /// visible here — confirm against the concrete implementations.
    fn clone_backend(&self) -> Box<dyn EmbedBackend>;

    /// True when the backend runs on a GPU device.
    fn is_gpu(&self) -> bool;

    /// Maximum sequence length (in tokens) this backend accepts.
    /// Defaults to the classic BERT limit of 512.
    fn max_tokens(&self) -> usize {
        512
    }
}
86
/// Identifies which compute backend `load_backend` should construct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    // NVIDIA CUDA (requires the `cuda` cargo feature).
    Cuda,
    // Apple MLX (requires the `mlx` cargo feature).
    Mlx,
    // Portable CPU backend; the default kind.
    #[default]
    Cpu,
    // Apple Metal (requires the `metal` cargo feature).
    Metal,
}
100
101impl std::fmt::Display for BackendKind {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 Self::Cuda => write!(f, "cuda"),
105 Self::Mlx => write!(f, "mlx"),
106 Self::Cpu => write!(f, "cpu"),
107 Self::Metal => write!(f, "metal"),
108 }
109 }
110}
111
/// Caller preference for which device a backend should target.
#[derive(Debug, Clone, Copy, Default)]
pub enum DeviceHint {
    // Let the backend choose the best available device (default).
    #[default]
    Auto,
    // Force CPU execution.
    Cpu,
    // Prefer a GPU device.
    Gpu,
}
126
/// Inference-time options.
///
/// Currently empty; kept as a struct so the public API can grow new
/// options later without breaking function signatures.
#[derive(Debug, Clone, Default)]
pub struct InferenceOpts {}
133
/// Loads the requested backend `kind` for `model_repo`, honoring `device_hint`.
///
/// ModernBERT-style models (detected via `is_modernbert_model`) are routed to
/// the driver/arch loaders; everything else goes through the monolithic
/// per-backend loaders. Requesting a kind whose cargo feature was not compiled
/// in returns an error naming the feature to enable.
///
/// # Errors
/// Fails when the feature is missing or the underlying loader fails
/// (download, config parse, device initialization).
pub fn load_backend(
    kind: BackendKind,
    // The cfg_attr below silences unused-variable lints in feature-less builds.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    device_hint: DeviceHint,
) -> crate::Result<Box<dyn EmbedBackend>> {
    match kind {
        #[cfg(feature = "cuda")]
        BackendKind::Cuda => {
            // ModernBERT gets the dedicated driver/arch path.
            if is_modernbert_model(model_repo) {
                return load_modernbert_cuda(model_repo);
            }
            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
            Ok(Box::new(backend))
        }
        #[cfg(not(feature = "cuda"))]
        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
            "cuda backend requires building with: cargo build --features cuda"
        ))),
        #[cfg(feature = "mlx")]
        BackendKind::Mlx => {
            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
            Ok(Box::new(backend))
        }
        #[cfg(not(feature = "mlx"))]
        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
            "mlx backend requires building with: cargo build --features mlx"
        ))),
        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
        BackendKind::Cpu => {
            // ModernBERT works under either cpu feature flavor...
            if is_modernbert_model(model_repo) {
                return load_modernbert_cpu(model_repo);
            }
            // ...but classic BERT needs the full `cpu` feature.
            #[cfg(feature = "cpu")]
            {
                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
                #[expect(
                    clippy::needless_return,
                    reason = "return needed before cfg(not) fallback"
                )]
                return Ok(Box::new(backend));
            }
            #[cfg(not(feature = "cpu"))]
            Err(crate::Error::Other(anyhow::anyhow!(
                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
            )))
        }
        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
            "cpu backend requires building with: cargo build --features cpu"
        ))),
        #[cfg(feature = "metal")]
        BackendKind::Metal => {
            if is_modernbert_model(model_repo) {
                return load_modernbert_metal(model_repo);
            }
            load_classic_metal(model_repo)
        }
        #[cfg(not(feature = "metal"))]
        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
            "metal backend requires building with: cargo build --features metal"
        ))),
    }
}
227
/// Probes all compiled-in backends for `model_repo` and returns every one
/// that loads successfully.
///
/// Probe order: CUDA, then Metal, then MLX (only if nothing loaded yet),
/// then CPU (only if no GPU backend loaded). Individual load failures are
/// silently skipped; an error is returned only when no backend is available.
pub fn detect_backends(
    // cfg_attr silences the unused-variable lint in feature-less builds.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_mut, reason = "mut needed when backend features are enabled")
    )]
    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();

    // CUDA: ModernBERT via driver/arch, everything else via the monolithic
    // CUDA backend.
    #[cfg(feature = "cuda")]
    {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cuda(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
            backends.push(Box::new(b));
        }
    }

    // Metal: same routing as CUDA.
    #[cfg(feature = "metal")]
    {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_metal(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = load_classic_metal(model_repo) {
            backends.push(b);
        }
    }

    // MLX is only tried as a fallback when no other backend loaded.
    #[cfg(feature = "mlx")]
    if backends.is_empty()
        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
    {
        backends.push(Box::new(b));
    }

    // CPU is added only when no GPU backend is present.
    #[cfg_attr(
        not(any(feature = "cpu", feature = "cpu-accelerate")),
        expect(unused_variables, reason = "used when cpu feature is enabled")
    )]
    let has_gpu = backends.iter().any(|b| b.is_gpu());
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    if !has_gpu {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cpu(model_repo) {
                backends.push(b);
            }
        } else {
            // Classic BERT on CPU needs the full `cpu` feature.
            #[cfg(feature = "cpu")]
            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
                backends.push(Box::new(b));
            }
        }
    }

    if backends.is_empty() {
        return Err(crate::Error::Other(anyhow::anyhow!(
            "no embedding backends available"
        )));
    }

    Ok(backends)
}
326
/// Loads a ModernBERT model onto Metal via the driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub
/// (cached locally by `hf-hub`), parses the config, loads the weights on the
/// Metal device, and wraps everything in a `GenericBackend`.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve model artifacts through the hub cache.
    let repo = Api::new()
        .map_err(|e| crate::Error::Download(e.to_string()))?
        .model(model_repo.to_string());
    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT config.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the Metal device and load the weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    // `true`: Metal counts as a GPU backend.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
385
/// Loads a ModernBERT model onto the CPU via the driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub
/// (cached locally by `hf-hub`), parses the config, loads the weights, and
/// wraps everything in a `GenericBackend`.
///
/// # Errors
/// Returns `Error::Download` on hub failures, `Error::Io` when the config
/// file cannot be read, and `Error::Other` on config parse failures.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let repo = api.model(model_repo.to_string());

    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    let config_str = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let config_json: serde_json::Value = serde_json::from_str(&config_str)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&config_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        // Consistency fix: the Metal/CUDA/classic loaders all log this field.
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch)"
    );

    // `false`: CPU is not a GPU backend.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, false, mmap,
    )))
}
428
/// Loads a ModernBERT model onto CUDA via the driver/arch stack.
///
/// Mirrors `load_modernbert_metal`, but constructs the backend with an
/// explicit device batch cap.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve model artifacts through the hub cache.
    let repo = Api::new()
        .map_err(|e| crate::Error::Download(e.to_string()))?
        .model(model_repo.to_string());
    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT config.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the CUDA device and load the weights.
    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // GPU backend, capped at 32 sequences per device batch.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver, arch, max_tokens, true, mmap, 32,
    )))
}
484
/// Returns true when `model_repo`'s `config.json` declares
/// `"model_type": "modernbert"`. Any failure along the way (hub access,
/// file read, JSON parse, missing key) is treated as "not ModernBERT".
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    // Inner helper so each fallible step can use `?` and fall through to None.
    fn model_type(model_repo: &str) -> Option<String> {
        let api = hf_hub::api::sync::Api::new().ok()?;
        let config_path = api.model(model_repo.to_string()).get("config.json").ok()?;
        let raw = std::fs::read_to_string(&config_path).ok()?;
        let json: serde_json::Value = serde_json::from_str(&raw).ok()?;
        Some(json.get("model_type")?.as_str()?.to_owned())
    }
    model_type(model_repo).as_deref() == Some("modernbert")
}
513
/// Loads a classic BERT model onto Metal via the driver/arch stack.
///
/// Mirrors `load_modernbert_metal` but parses a `ClassicBertConfig` and
/// loads classic BERT weights.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve model artifacts through the hub cache.
    let repo = Api::new()
        .map_err(|e| crate::Error::Download(e.to_string()))?
        .model(model_repo.to_string());
    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed classic BERT config.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the Metal device and load the weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    // `true`: Metal counts as a GPU backend.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
573
/// Loads a classic BERT model onto the CPU via the driver/arch stack.
///
/// Mirrors `load_classic_metal` with the CPU driver.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve model artifacts through the hub cache.
    let repo = Api::new()
        .map_err(|e| crate::Error::Download(e.to_string()))?
        .model(model_repo.to_string());
    let config_path = repo
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = repo
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed classic BERT config.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the CPU driver and load the weights.
    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch)"
    );

    // `false`: CPU is not a GPU backend.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, false, mmap,
    )))
}
633
634#[cfg(test)]
635mod tests {
636 use super::*;
637
638 #[test]
640 fn trait_is_object_safe() {
641 fn _assert_object_safe(_: &dyn EmbedBackend) {}
643 }
644
645 #[test]
647 fn trait_object_is_send() {
648 fn assert_send<T: Send>() {}
649 assert_send::<Box<dyn EmbedBackend>>();
650 }
651
652 #[test]
654 fn trait_object_is_sync() {
655 fn assert_sync<T: Sync>() {}
656 assert_sync::<Box<dyn EmbedBackend>>();
657 }
658
659 #[test]
661 fn arc_trait_object_is_send() {
662 fn assert_send<T: Send>() {}
663 assert_send::<std::sync::Arc<dyn EmbedBackend>>();
664 }
665
666 #[test]
667 fn encoding_construction() {
668 let enc = Encoding {
669 input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
670 attention_mask: vec![1, 1, 1, 1, 1, 1],
671 token_type_ids: vec![0, 0, 0, 0, 0, 0],
672 };
673 assert_eq!(enc.input_ids.len(), 6);
674 assert_eq!(enc.attention_mask.len(), 6);
675 assert_eq!(enc.token_type_ids.len(), 6);
676 }
677
678 #[test]
679 fn encoding_clone() {
680 let enc = Encoding {
681 input_ids: vec![101, 102],
682 attention_mask: vec![1, 1],
683 token_type_ids: vec![0, 0],
684 };
685 let cloned = enc.clone();
686 assert_eq!(enc.input_ids, cloned.input_ids);
687 }
688
689 #[test]
690 fn backend_kind_default_is_cpu() {
691 assert_eq!(BackendKind::default(), BackendKind::Cpu);
692 }
693
694 #[test]
695 fn backend_kind_display() {
696 assert_eq!(BackendKind::Cuda.to_string(), "cuda");
697 assert_eq!(BackendKind::Mlx.to_string(), "mlx");
698 assert_eq!(BackendKind::Cpu.to_string(), "cpu");
699 }
700
701 #[cfg(not(feature = "mlx"))]
702 #[test]
703 fn load_backend_mlx_not_compiled() {
704 let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
705 assert!(result.is_err());
706 }
707
    #[cfg(feature = "cpu")]
    #[test]
    fn detect_backends_returns_at_least_one() {
        // With the cpu feature compiled in, detection must always yield a backend.
        // NOTE(review): downloads the model on first run but is not #[ignore]d
        // like the other download tests — presumably relies on a warm hub cache.
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty());
    }
714
    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
    #[test]
    fn detect_backends_returns_at_least_one_backend() {
        // NOTE(review): near-duplicate of `detect_backends_returns_at_least_one`
        // under a stricter cfg; consider merging the two tests.
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty(), "should detect at least one backend");
    }
721
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~570MB)"]
    fn modernbert_loads_and_embeds() {
        // End-to-end smoke/debug test for the Metal ModernBERT path: loads the
        // backend, then replays individual pipeline stages (embedding lookup,
        // LayerNorm, QKV GEMM, head split, attention scores) printing nonzero
        // counts so a dead stage is easy to spot, and finishes with MRL
        // truncation cosines and a throughput measurement.
        use crate::backend::driver::Driver;

        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        // Five real tokens; padded to sequence length 8 by prepare_batch below.
        let enc = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();

        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
        eprintln!("input_ids: {:?}", &ids_host[0][..5]);

        // Load the raw weights separately from the backend above so the
        // stage-by-stage driver calls can touch `arch.weights` directly.
        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_path = repo.get("config.json").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_modern_bert_weights(&weights_path, &config)
            .unwrap();

        // STAGE 0: token embedding lookup (8 tokens x 768 hidden).
        let hidden = driver
            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
            .unwrap();
        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "embedding: {nz}/{} nonzero, first 5: {:?}",
            h[0].len(),
            &h[0][..5]
        );

        // Dimensions used below: 1 sequence padded to 8 tokens, 768 hidden,
        // 12 heads x 64 head_dim (12 * 64 == 768).
        let total = 8;
        let hd = 768;
        let nh = 12;
        let head_dim = 64;

        // STAGE 1: embedding LayerNorm (zero bias).
        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
        driver
            .layer_norm(
                &mut ln_out,
                &emb_clone,
                &arch.weights.emb_norm_weight,
                &arch.weights.zero_bias,
                total,
                hd,
                1e-5,
            )
            .unwrap();
        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);

        // STAGE 2: fused QKV projection of layer 0.
        let layer0 = &arch.weights.layers[0];
        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
        driver
            .gemm(
                &ln_out,
                &layer0.qkv_weight,
                &mut qkv,
                total,
                3 * hd,
                hd,
                true,
            )
            .unwrap();
        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);

        // STAGE 3: split the fused QKV buffer into per-head Q/K/V.
        let mut q = driver.alloc_zeros(total * hd).unwrap();
        let mut k = driver.alloc_zeros(total * hd).unwrap();
        let mut v = driver.alloc_zeros(total * hd).unwrap();
        driver
            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
            .unwrap();
        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);

        // STAGE 4: batched Q x K^T attention scores (nh heads of 8x8).
        let mut scores = driver.alloc_zeros(1 * nh * 8 * 8).unwrap();
        driver
            .gemm_batched(
                &q,
                &k,
                &mut scores,
                8,
                8,
                head_dim,
                true,
                8 * head_dim,
                8 * head_dim,
                8 * 8,
                nh,
            )
            .unwrap();
        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);

        // Full batched forward pass through the arch for comparison.
        use crate::backend::arch::ModelArch;
        let enc2 = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
            &quick[0][..3]
        );

        // MRL (Matryoshka) truncation: cosine of a truncated prefix against
        // itself is 1.0 by construction; this prints the computed values.
        eprintln!("\n=== ModernBERT MRL Truncation ===");
        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
        let full_emb = &full[0];
        for dims in [64, 128, 256, 384, 512, 768] {
            let t: Vec<f32> = full_emb[..dims].to_vec();
            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            let f_norm: f32 = full_emb[..dims]
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt()
                .max(1e-12);
            let cos: f32 = t
                .iter()
                .zip(&full_emb[..dims])
                .map(|(a, b)| a * b)
                .sum::<f32>()
                / (t_norm * f_norm);
            eprintln!(" dims={dims:>3}: cosine={cos:.6}");
        }

        // Throughput: 32 sequences of increasing length (16, 20, 24, ...).
        eprintln!("\n=== ModernBERT Throughput ===");
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            // Leading special token id 1, trailing special token id 2.
            let mut ids = vec![1_i64];
            for j in 1..len - 1 {
                ids.push(100 + j as i64);
            }
            ids.push(2);
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warm-up pass before timing.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(result.len(), 32);

        // Single-sequence latency for comparison against the batched run.
        let single = vec![encs[0].clone()];
        let t1 = std::time::Instant::now();
        let _ = arch.forward(&driver, &single).unwrap();
        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
        eprintln!(" batch=1, time={single_ms:.1}ms");
    }
927
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_driver_arch() {
        // Classic BERT on Metal: checks embedding shape and L2 normalization,
        // cross-checks against the monolithic CPU backend when available, and
        // measures throughput of the raw driver/arch forward pass.
        use crate::backend::arch::ModelArch;

        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        // bge-small produces 384-dimensional embeddings.
        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
            &result[0][..3]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Cross-check: cosine vs the monolithic CPU backend (both outputs are
        // unit-norm, so the dot product is the cosine similarity).
        #[cfg(feature = "cpu")]
        {
            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
                .expect("CPU load failed");
            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
            eprintln!("CPU first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("NEW first 5: {:?}", &result[0][..5]);
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
            assert!(
                cosine > 0.95,
                "cosine similarity vs CPU should be >0.95, got {cosine}"
            );
        }

        // Throughput of the raw driver/arch forward (bypassing the backend).
        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let config_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("config.json").unwrap()
        };
        let weights_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("model.safetensors").unwrap()
        };
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_classic_bert_weights(&weights_path, &config)
            .unwrap();

        // 32 sequences of increasing length (16, 20, 24, ...).
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            // 101/102 are BERT's [CLS]/[SEP] ids (as used in `enc` above).
            let mut ids = vec![101_i64];
            for j in 1..len - 1 {
                ids.push(100 + j as i64);
            }
            ids.push(102);
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warm-up pass before timing.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let bench_result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(bench_result.len(), 32);
    }
1039
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_cpu_driver_arch() {
        // Classic BERT on the CPU driver/arch path: checks shape and L2
        // normalization, then cross-checks against the monolithic CPU backend
        // with a tight cosine bound (same device, so results should match).
        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
        assert!(!backend.is_gpu(), "CPU backend should not be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        // bge-small produces 384-dimensional embeddings.
        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
            &result[0][..5]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Both outputs are unit-norm, so the dot product is the cosine.
        #[cfg(feature = "cpu")]
        {
            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
                .expect("monolithic CPU load failed");
            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("New first 5: {:?}", &result[0][..5]);
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
            assert!(
                cosine > 0.999,
                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
            );
        }
    }
1095}