1pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19
/// Tokenized input for one sequence, ready to feed to a backend.
///
/// The three vectors are parallel: one entry per token position (the tests
/// in this module always construct them with equal lengths).
#[derive(Debug, Clone)]
pub struct Encoding {
    // Token ids produced by the tokenizer.
    pub input_ids: Vec<i64>,
    // Per-token mask; tests use 1 for real tokens (presumably 0 marks padding).
    pub attention_mask: Vec<i64>,
    // Segment ids; all zeros for the single-sequence inputs used here.
    pub token_type_ids: Vec<i64>,
}
34
35pub trait EmbedBackend: Send + Sync {
48 fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
58
59 fn supports_clone(&self) -> bool;
63
64 fn clone_backend(&self) -> Box<dyn EmbedBackend>;
71
72 fn is_gpu(&self) -> bool;
77
78 fn max_tokens(&self) -> usize {
83 512 }
85}
86
/// Selects which compute backend [`load_backend`] should construct.
///
/// `Cpu` is the default because it is the only universally available option.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    /// NVIDIA CUDA GPU backend (feature `cuda`).
    Cuda,
    /// Apple MLX backend (feature `mlx`).
    Mlx,
    /// Portable CPU backend (features `cpu` / `cpu-accelerate`).
    #[default]
    Cpu,
    /// Apple Metal GPU backend (feature `metal`).
    Metal,
}

impl std::fmt::Display for BackendKind {
    // Renders the kind as its lowercase feature name ("cuda", "mlx", ...).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Cuda => "cuda",
            Self::Mlx => "mlx",
            Self::Cpu => "cpu",
            Self::Metal => "metal",
        };
        f.write_str(name)
    }
}
111
/// Caller preference for where inference should run.
///
/// `PartialEq`/`Eq` are derived for parity with [`BackendKind`] (which is
/// comparable); this is purely additive and lets callers and tests compare
/// hints directly instead of using `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceHint {
    /// Let the backend choose the best available device.
    #[default]
    Auto,
    /// Request CPU execution.
    Cpu,
    /// Request GPU execution.
    Gpu,
}
126
/// Inference-time options.
///
/// NOTE(review): currently has no fields — presumably a forward-compatible
/// extension point for future options; confirm intended use with callers.
#[derive(Debug, Clone, Default)]
pub struct InferenceOpts {}
133
/// Loads the embedding backend selected by `kind` for `model_repo`.
///
/// ModernBERT models (detected from `model_type` in the repo's
/// `config.json`) are routed to the driver/arch loaders; other models go
/// through the per-backend monolithic loaders. Requesting a backend whose
/// cargo feature was not compiled in yields an error naming the feature.
///
/// # Errors
/// Returns an error when the selected feature is not enabled, or when
/// download/initialization inside the chosen loader fails.
pub fn load_backend(
    kind: BackendKind,
    // cfg_attr: with no backend features enabled, every match arm is an
    // Err literal and these parameters would otherwise trip unused_variables.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    device_hint: DeviceHint,
) -> crate::Result<Box<dyn EmbedBackend>> {
    match kind {
        #[cfg(feature = "cuda")]
        BackendKind::Cuda => {
            // ModernBERT takes the driver/arch path; classic models the monolith.
            if is_modernbert_model(model_repo) {
                return load_modernbert_cuda(model_repo);
            }
            let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
            Ok(Box::new(backend))
        }
        // Feature compiled out: return an actionable build hint.
        #[cfg(not(feature = "cuda"))]
        BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
            "cuda backend requires building with: cargo build --features cuda"
        ))),
        #[cfg(feature = "mlx")]
        BackendKind::Mlx => {
            let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
            Ok(Box::new(backend))
        }
        #[cfg(not(feature = "mlx"))]
        BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
            "mlx backend requires building with: cargo build --features mlx"
        ))),
        // ModernBERT on CPU works with either cpu feature; the classic CPU
        // backend additionally needs the full "cpu" feature.
        #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
        BackendKind::Cpu => {
            if is_modernbert_model(model_repo) {
                return load_modernbert_cpu(model_repo);
            }
            #[cfg(feature = "cpu")]
            {
                let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
                #[expect(
                    clippy::needless_return,
                    reason = "return needed before cfg(not) fallback"
                )]
                return Ok(Box::new(backend));
            }
            // Reached only with cpu-accelerate but not cpu.
            #[cfg(not(feature = "cpu"))]
            Err(crate::Error::Other(anyhow::anyhow!(
                "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
            )))
        }
        #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
        BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
            "cpu backend requires building with: cargo build --features cpu"
        ))),
        #[cfg(feature = "metal")]
        BackendKind::Metal => {
            if is_modernbert_model(model_repo) {
                return load_modernbert_metal(model_repo);
            }
            load_classic_metal(model_repo)
        }
        #[cfg(not(feature = "metal"))]
        BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
            "metal backend requires building with: cargo build --features metal"
        ))),
    }
}
227
/// Probes the compiled-in backends and returns every one that loads.
///
/// Probe order: CUDA, then Metal, then MLX (MLX is tried only when nothing
/// loaded before it), and finally CPU — but CPU is added only when no GPU
/// backend is already present. Individual load failures are swallowed; an
/// error is returned only when the final list is empty.
///
/// # Errors
/// Returns an error when no backend could be loaded at all.
pub fn detect_backends(
    // cfg_attr: with no backend features enabled the parameter is never read.
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_variables, reason = "used when backend features are enabled")
    )]
    model_repo: &str,
) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
    #[cfg_attr(
        not(any(
            feature = "cuda",
            feature = "mlx",
            feature = "cpu",
            feature = "cpu-accelerate",
            feature = "metal"
        )),
        expect(unused_mut, reason = "mut needed when backend features are enabled")
    )]
    let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();

    // CUDA first: ModernBERT via driver/arch, classic via the monolith.
    #[cfg(feature = "cuda")]
    {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cuda(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
            backends.push(Box::new(b));
        }
    }

    #[cfg(feature = "metal")]
    {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_metal(model_repo) {
                backends.push(b);
            }
        } else if let Ok(b) = load_classic_metal(model_repo) {
            backends.push(b);
        }
    }

    // MLX only as a fallback when neither CUDA nor Metal produced a backend.
    #[cfg(feature = "mlx")]
    if backends.is_empty()
        && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
    {
        backends.push(Box::new(b));
    }

    // CPU is added only when no GPU backend made it into the list.
    #[cfg_attr(
        not(any(feature = "cpu", feature = "cpu-accelerate")),
        expect(unused_variables, reason = "used when cpu feature is enabled")
    )]
    let has_gpu = backends.iter().any(|b| b.is_gpu());
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    if !has_gpu {
        if is_modernbert_model(model_repo) {
            if let Ok(b) = load_modernbert_cpu(model_repo) {
                backends.push(b);
            }
        } else {
            // Classic CPU backend needs the full "cpu" feature.
            #[cfg(feature = "cpu")]
            if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
                backends.push(Box::new(b));
            }
        }
    }

    if backends.is_empty() {
        return Err(crate::Error::Other(anyhow::anyhow!(
            "no embedding backends available"
        )));
    }

    Ok(backends)
}
326
/// Loads a ModernBERT model onto Metal via the driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub,
/// parses the model configuration, loads the weights through [`driver::metal::MetalDriver`],
/// and wraps the result in a [`generic::GenericBackend`].
///
/// # Errors
/// Fails on hub download errors, config read/parse errors, or Metal driver
/// initialization / weight-loading errors.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Map any hub-side failure into our Download error variant.
    fn fetch<T, E: std::fmt::Display>(res: Result<T, E>) -> crate::Result<T> {
        res.map_err(|e| crate::Error::Download(e.to_string()))
    }

    let hub = fetch(Api::new())?;
    let repo = hub.model(model_repo.to_string());
    let cfg_path = fetch(repo.get("config.json"))?;
    let weights_path = fetch(repo.get("model.safetensors"))?;

    let cfg_text = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let cfg_json: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&cfg_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    // NOTE(review): the `true` flag appears to be the is_gpu marker (tests
    // assert `is_gpu()` for Metal backends); confirm in GenericBackend.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
385
/// Loads a ModernBERT model on the CPU driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub,
/// parses the configuration, loads the weights via [`driver::cpu::CpuDriver`],
/// and wraps them in a [`generic::GenericBackend`] built with the shared
/// constructor.
///
/// # Errors
/// Fails on hub download errors, config read/parse errors, or CPU driver
/// initialization / weight-loading errors.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Map any hub-side failure into our Download error variant.
    fn fetch<T, E: std::fmt::Display>(res: Result<T, E>) -> crate::Result<T> {
        res.map_err(|e| crate::Error::Download(e.to_string()))
    }

    let hub = fetch(Api::new())?;
    let repo = hub.model(model_repo.to_string());
    let cfg_path = fetch(repo.get("config.json"))?;
    let weights_path = fetch(repo.get("model.safetensors"))?;

    let cfg_text = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let cfg_json: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&cfg_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // `new_shared` + `false`: CPU (non-GPU) backend; semantics of the shared
    // constructor live in GenericBackend.
    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
428
/// Loads a ModernBERT model onto a CUDA GPU via the driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub,
/// parses the configuration, loads the weights via [`driver::cuda::CudaDriver`],
/// and wraps them in a [`generic::GenericBackend`] with a max-batch limit of 32.
///
/// # Errors
/// Fails on hub download errors, config read/parse errors, or CUDA driver
/// initialization / weight-loading errors.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Map any hub-side failure into our Download error variant.
    fn fetch<T, E: std::fmt::Display>(res: Result<T, E>) -> crate::Result<T> {
        res.map_err(|e| crate::Error::Download(e.to_string()))
    }

    let hub = fetch(Api::new())?;
    let repo = hub.model(model_repo.to_string());
    let cfg_path = fetch(repo.get("config.json"))?;
    let weights_path = fetch(repo.get("model.safetensors"))?;

    let cfg_text = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let cfg_json: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&cfg_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // Unlike the other loaders this one owns its mmap wrapper explicitly and
    // caps the batch size at 32 via `with_max_batch`.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        32,
    )))
}
489
/// Returns true when the repo's `config.json` declares
/// `model_type == "modernbert"`.
///
/// Any failure along the way (hub access, download, file read, JSON parse,
/// missing/non-string key) is treated as "not ModernBERT" rather than an
/// error, so callers can use this as a cheap routing predicate.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    // Fetch and parse config.json, yielding its model_type string if present.
    fn model_type_of(repo_name: &str) -> Option<String> {
        let api = hf_hub::api::sync::Api::new().ok()?;
        let cfg_path = api.model(repo_name.to_string()).get("config.json").ok()?;
        let text = std::fs::read_to_string(&cfg_path).ok()?;
        let json: serde_json::Value = serde_json::from_str(&text).ok()?;
        Some(json.get("model_type")?.as_str()?.to_owned())
    }
    model_type_of(model_repo).as_deref() == Some("modernbert")
}
518
/// Loads a classic BERT model onto Metal via the driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub,
/// parses the classic-BERT configuration, loads the weights through
/// [`driver::metal::MetalDriver`], and wraps them in a [`generic::GenericBackend`].
///
/// # Errors
/// Fails on hub download errors, config read/parse errors, or Metal driver
/// initialization / weight-loading errors.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Map any hub-side failure into our Download error variant.
    fn fetch<T, E: std::fmt::Display>(res: Result<T, E>) -> crate::Result<T> {
        res.map_err(|e| crate::Error::Download(e.to_string()))
    }

    let hub = fetch(Api::new())?;
    let repo = hub.model(model_repo.to_string());
    let cfg_path = fetch(repo.get("config.json"))?;
    let weights_path = fetch(repo.get("model.safetensors"))?;

    let cfg_text = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let cfg_json: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&cfg_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    // NOTE(review): the `true` flag appears to be the is_gpu marker (tests
    // assert `is_gpu()` for Metal backends); confirm in GenericBackend.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
578
/// Loads a classic BERT model on the CPU driver/arch stack.
///
/// Fetches `config.json` and `model.safetensors` from the Hugging Face hub,
/// parses the classic-BERT configuration, loads the weights through
/// [`driver::cpu::CpuDriver`], and wraps them in a [`generic::GenericBackend`]
/// built with the shared constructor.
///
/// # Errors
/// Fails on hub download errors, config read/parse errors, or CPU driver
/// initialization / weight-loading errors.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Map any hub-side failure into our Download error variant.
    fn fetch<T, E: std::fmt::Display>(res: Result<T, E>) -> crate::Result<T> {
        res.map_err(|e| crate::Error::Download(e.to_string()))
    }

    let hub = fetch(Api::new())?;
    let repo = hub.model(model_repo.to_string());
    let cfg_path = fetch(repo.get("config.json"))?;
    let weights_path = fetch(repo.get("model.safetensors"))?;

    let cfg_text = std::fs::read_to_string(&cfg_path).map_err(|e| crate::Error::Io {
        path: cfg_path.display().to_string(),
        source: e,
    })?;
    let cfg_json: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&cfg_json)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // `new_shared` + `false`: CPU (non-GPU) backend; semantics of the shared
    // constructor live in GenericBackend.
    Ok(Box::new(GenericBackend::new_shared(
        driver, arch, max_tokens, false, mmap,
    )))
}
638
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: EmbedBackend can be used as a trait object.
    #[test]
    fn trait_is_object_safe() {
        fn _assert_object_safe(_: &dyn EmbedBackend) {}
    }

    // Boxed trait objects must be Send (Send supertrait on the trait).
    #[test]
    fn trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Box<dyn EmbedBackend>>();
    }

    // Boxed trait objects must be Sync (Sync supertrait on the trait).
    #[test]
    fn trait_object_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Box<dyn EmbedBackend>>();
    }

    // Arc-wrapped trait objects must also be Send.
    #[test]
    fn arc_trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
    }

    // Encoding can be built with parallel, equal-length vectors.
    #[test]
    fn encoding_construction() {
        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };
        assert_eq!(enc.input_ids.len(), 6);
        assert_eq!(enc.attention_mask.len(), 6);
        assert_eq!(enc.token_type_ids.len(), 6);
    }

    // Clone derive produces an independent, equal copy.
    #[test]
    fn encoding_clone() {
        let enc = Encoding {
            input_ids: vec![101, 102],
            attention_mask: vec![1, 1],
            token_type_ids: vec![0, 0],
        };
        let cloned = enc.clone();
        assert_eq!(enc.input_ids, cloned.input_ids);
    }

    #[test]
    fn backend_kind_default_is_cpu() {
        assert_eq!(BackendKind::default(), BackendKind::Cpu);
    }

    // Display renders the lowercase feature name for each kind.
    #[test]
    fn backend_kind_display() {
        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
    }

    // Without the mlx feature, requesting MLX must fail with an error.
    #[cfg(not(feature = "mlx"))]
    #[test]
    fn load_backend_mlx_not_compiled() {
        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
        assert!(result.is_err());
    }

    // With the cpu feature, detection should always yield a backend.
    #[cfg(feature = "cpu")]
    #[test]
    fn detect_backends_returns_at_least_one() {
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty());
    }

    #[cfg(all(feature = "cpu", not(feature = "mlx")))]
    #[test]
    fn detect_backends_returns_at_least_one_backend() {
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty(), "should detect at least one backend");
    }

    // End-to-end ModernBERT-on-Metal check plus staged debugging of the
    // forward pass (embedding lookup, layer norm, QKV, attention scores),
    // MRL truncation cosines, and a small throughput benchmark.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~570MB)"]
    fn modernbert_loads_and_embeds() {
        use crate::backend::driver::Driver;

        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        // Drive the raw Metal driver directly, padding the batch to 8 tokens.
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();

        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
        eprintln!("input_ids: {:?}", &ids_host[0][..5]);

        // Re-download config/weights to build the arch outside the backend.
        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_path = repo.get("config.json").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_modern_bert_weights(&weights_path, &config)
            .unwrap();

        // Stage 0: raw token-embedding lookup should be mostly nonzero.
        let hidden = driver
            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
            .unwrap();
        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "embedding: {nz}/{} nonzero, first 5: {:?}",
            h[0].len(),
            &h[0][..5]
        );

        // Hard-coded dims used below: 8 padded tokens, hidden 768, 12 heads,
        // 64-dim heads (presumably ModernBERT-base — matches config prints).
        let total = 8; let hd = 768;
        let nh = 12;
        let head_dim = 64;

        // Stage 1: embedding + layer norm.
        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
        driver
            .layer_norm(
                &mut ln_out,
                &emb_clone,
                &arch.weights.emb_norm_weight,
                &arch.weights.zero_bias,
                total,
                hd,
                1e-5,
            )
            .unwrap();
        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);

        // Stage 2: fused QKV projection of layer 0.
        let layer0 = &arch.weights.layers[0];
        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
        driver
            .gemm(
                &ln_out,
                &layer0.qkv_weight,
                &mut qkv,
                total,
                3 * hd,
                hd,
                true,
            )
            .unwrap();
        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);

        // Stage 3: split fused QKV into per-head Q/K/V tensors.
        let mut q = driver.alloc_zeros(total * hd).unwrap();
        let mut k = driver.alloc_zeros(total * hd).unwrap();
        let mut v = driver.alloc_zeros(total * hd).unwrap();
        driver
            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
            .unwrap();
        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);

        // Stage 4: batched Q·K^T attention scores per head.
        let mut scores = driver.alloc_zeros(1 * nh * 8 * 8).unwrap();
        driver
            .gemm_batched(
                &q,
                &k,
                &mut scores,
                8,
                8,
                head_dim,
                true,
                8 * head_dim,
                8 * head_dim,
                8 * 8,
                nh,
            )
            .unwrap();
        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);

        // Full batched forward pass through the arch for comparison.
        use crate::backend::arch::ModelArch;
        let enc2 = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
            &quick[0][..3]
        );

        // MRL truncation: cosine of truncated prefix vs itself (sanity: ~1.0).
        eprintln!("\n=== ModernBERT MRL Truncation ===");
        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
        let full_emb = &full[0];
        for dims in [64, 128, 256, 384, 512, 768] {
            let t: Vec<f32> = full_emb[..dims].to_vec();
            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            let f_norm: f32 = full_emb[..dims]
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt()
                .max(1e-12);
            let cos: f32 = t
                .iter()
                .zip(&full_emb[..dims])
                .map(|(a, b)| a * b)
                .sum::<f32>()
                / (t_norm * f_norm);
            eprintln!(" dims={dims:>3}: cosine={cos:.6}");
        }

        // Throughput: 32 synthetic sequences of increasing length.
        eprintln!("\n=== ModernBERT Throughput ===");
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4); let mut ids = vec![1_i64]; for j in 1..len - 1 {
                ids.push(100 + j as i64);
            }
            ids.push(2); encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warmup run before timing.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(result.len(), 32);

        // Single-item latency for comparison with the batched run.
        let single = vec![encs[0].clone()];
        let t1 = std::time::Instant::now();
        let _ = arch.forward(&driver, &single).unwrap();
        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
        eprintln!(" batch=1, time={single_ms:.1}ms");
    }

    // Classic BERT on Metal: checks normalized output, cross-checks against
    // the monolithic CPU backend (when available), then benchmarks.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_driver_arch() {
        use crate::backend::arch::ModelArch;

        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        // Output must be L2-normalized.
        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
            &result[0][..3]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Cross-check against the monolithic CPU backend via cosine similarity
        // (dot product suffices since both vectors are normalized).
        #[cfg(feature = "cpu")]
        {
            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
                .expect("CPU load failed");
            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
            eprintln!("CPU first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("NEW first 5: {:?}", &result[0][..5]);
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
            assert!(
                cosine > 0.95,
                "cosine similarity vs CPU should be >0.95, got {cosine}"
            );
        }

        // Throughput benchmark against a directly-constructed arch.
        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let config_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("config.json").unwrap()
        };
        let weights_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("model.safetensors").unwrap()
        };
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_classic_bert_weights(&weights_path, &config)
            .unwrap();

        // 32 synthetic sequences of increasing length.
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4); let mut ids = vec![101_i64]; for j in 1..len - 1 {
                ids.push(100 + j as i64);
            }
            ids.push(102); encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warmup run before timing.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let bench_result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(bench_result.len(), 32);
    }

    // Classic BERT on CPU driver/arch: normalized output plus a tight cosine
    // agreement with the monolithic CPU backend.
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_cpu_driver_arch() {
        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
        assert!(!backend.is_gpu(), "CPU backend should not be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        // Output must be L2-normalized.
        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
            &result[0][..5]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Same-device comparison, so the tolerance is much tighter (>0.999).
        #[cfg(feature = "cpu")]
        {
            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
                .expect("monolithic CPU load failed");
            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("New first 5: {:?}", &result[0][..5]);
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
            assert!(
                cosine > 0.999,
                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
            );
        }
    }
}