1pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19
/// A tokenized input sequence, ready to be fed to an embedding model.
///
/// All three vectors are expected to have the same length (one entry per
/// token), as exercised by the tests in this module.
#[derive(Debug, Clone)]
pub struct Encoding {
    // Token ids produced by the tokenizer.
    pub input_ids: Vec<i64>,
    // Attention mask; presumably 1 for real tokens and 0 for padding — the
    // tests only use all-ones masks, so confirm against the tokenizer.
    pub attention_mask: Vec<i64>,
    // Segment ids (all zeros in every test fixture in this file).
    pub token_type_ids: Vec<i64>,
}
34
/// A compute backend that turns batches of [`Encoding`]s into embedding
/// vectors. Implementations must be usable from multiple threads.
pub trait EmbedBackend: Send + Sync {
    /// Embeds a batch of encodings, returning one `Vec<f32>` per input in
    /// the same order.
    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;

    /// Whether [`EmbedBackend::clone_backend`] is supported by this backend.
    fn supports_clone(&self) -> bool;

    /// Produces an independent boxed copy of this backend.
    fn clone_backend(&self) -> Box<dyn EmbedBackend>;

    /// Whether this backend executes on a GPU device.
    fn is_gpu(&self) -> bool;

    /// Maximum sequence length (in tokens) this backend accepts.
    /// Defaults to 512, the classic BERT context size.
    fn max_tokens(&self) -> usize {
        512 }
}
86
/// The kind of compute backend to load. Each variant (except `Cpu` in some
/// configurations) is only functional when its cargo feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    // NVIDIA GPU via the `cuda` feature.
    Cuda,
    // Apple MLX via the `mlx` feature.
    Mlx,
    // Portable CPU path (`cpu` / `cpu-accelerate` features); the default.
    #[default]
    Cpu,
    // Apple Metal GPU via the `metal` feature.
    Metal,
}
100
101impl std::fmt::Display for BackendKind {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 Self::Cuda => write!(f, "cuda"),
105 Self::Mlx => write!(f, "mlx"),
106 Self::Cpu => write!(f, "cpu"),
107 Self::Metal => write!(f, "metal"),
108 }
109 }
110}
111
/// Caller preference for device placement when loading a backend.
#[derive(Debug, Clone, Copy, Default)]
pub enum DeviceHint {
    // Let the backend choose the best available device (the default).
    #[default]
    Auto,
    // Request CPU execution.
    Cpu,
    // Request GPU execution.
    Gpu,
}
126
/// Options controlling inference behavior. Currently an empty placeholder;
/// fields can be added without breaking `Default`-based construction.
#[derive(Debug, Clone, Default)]
pub struct InferenceOpts {}
133
134pub fn load_backend(
144 kind: BackendKind,
145 #[cfg_attr(
146 not(any(
147 feature = "cuda",
148 feature = "mlx",
149 feature = "cpu",
150 feature = "cpu-accelerate",
151 feature = "metal"
152 )),
153 expect(unused_variables, reason = "used when backend features are enabled")
154 )]
155 model_repo: &str,
156 #[cfg_attr(
157 not(any(
158 feature = "cuda",
159 feature = "mlx",
160 feature = "cpu",
161 feature = "cpu-accelerate",
162 feature = "metal"
163 )),
164 expect(unused_variables, reason = "used when backend features are enabled")
165 )]
166 device_hint: DeviceHint,
167) -> crate::Result<Box<dyn EmbedBackend>> {
168 match kind {
169 #[cfg(feature = "cuda")]
170 BackendKind::Cuda => {
171 if is_modernbert_model(model_repo) {
172 return load_modernbert_cuda(model_repo, max_layers);
173 }
174 let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
175 Ok(Box::new(backend))
176 }
177 #[cfg(not(feature = "cuda"))]
178 BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
179 "cuda backend requires building with: cargo build --features cuda"
180 ))),
181 #[cfg(feature = "mlx")]
182 BackendKind::Mlx => {
183 let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
184 Ok(Box::new(backend))
185 }
186 #[cfg(not(feature = "mlx"))]
187 BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
188 "mlx backend requires building with: cargo build --features mlx"
189 ))),
190 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
191 BackendKind::Cpu => {
192 if is_modernbert_model(model_repo) {
193 return load_modernbert_cpu(model_repo);
194 }
195 #[cfg(feature = "cpu")]
196 {
197 let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
198 return Ok(Box::new(backend));
199 }
200 #[cfg(not(feature = "cpu"))]
201 Err(crate::Error::Other(anyhow::anyhow!(
202 "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
203 )))
204 }
205 #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
206 BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
207 "cpu backend requires building with: cargo build --features cpu"
208 ))),
209 #[cfg(feature = "metal")]
210 BackendKind::Metal => {
211 if is_modernbert_model(model_repo) {
213 return load_modernbert_metal(model_repo);
214 }
215 load_classic_metal(model_repo)
216 }
217 #[cfg(not(feature = "metal"))]
218 BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
219 "metal backend requires building with: cargo build --features metal"
220 ))),
221 }
222}
223
224pub fn detect_backends(
234 #[cfg_attr(
235 not(any(
236 feature = "cuda",
237 feature = "mlx",
238 feature = "cpu",
239 feature = "cpu-accelerate",
240 feature = "metal"
241 )),
242 expect(unused_variables, reason = "used when backend features are enabled")
243 )]
244 model_repo: &str,
245) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
246 #[cfg_attr(
247 not(any(
248 feature = "cuda",
249 feature = "mlx",
250 feature = "cpu",
251 feature = "cpu-accelerate",
252 feature = "metal"
253 )),
254 expect(unused_mut, reason = "mut needed when backend features are enabled")
255 )]
256 let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
257
258 #[cfg(feature = "cuda")]
260 {
261 if is_modernbert_model(model_repo) {
262 if let Ok(b) = load_modernbert_cuda(model_repo, max_layers) {
263 backends.push(b);
264 }
265 } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
266 backends.push(Box::new(b));
267 }
268 }
269
270 #[cfg(feature = "metal")]
272 {
273 if is_modernbert_model(model_repo) {
275 if let Ok(b) = load_modernbert_metal(model_repo) {
276 backends.push(b);
277 }
278 } else if let Ok(b) = load_classic_metal(model_repo) {
279 backends.push(b);
280 }
281 }
282
283 #[cfg(feature = "mlx")]
285 if backends.is_empty()
286 && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
287 {
288 backends.push(Box::new(b));
289 }
290
291 #[cfg_attr(
296 not(any(feature = "cpu", feature = "cpu-accelerate")),
297 expect(unused_variables, reason = "used when cpu feature is enabled")
298 )]
299 let has_gpu = backends.iter().any(|b| b.is_gpu());
300 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
301 if !has_gpu {
302 if is_modernbert_model(model_repo) {
303 if let Ok(b) = load_modernbert_cpu(model_repo) {
304 backends.push(b);
305 }
306 } else {
307 #[cfg(feature = "cpu")]
308 if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
309 backends.push(Box::new(b));
310 }
311 }
312 }
313
314 if backends.is_empty() {
315 return Err(crate::Error::Other(anyhow::anyhow!(
316 "no embedding backends available"
317 )));
318 }
319
320 Ok(backends)
321}
322
#[cfg(feature = "metal")]
/// Loads a ModernBERT model onto the Metal GPU via the driver/arch path.
///
/// # Errors
///
/// Fails when the Hugging Face download, config parsing, or weight upload
/// to the device fails.
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (downloading if necessary) the model files from the HF hub.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let config_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the driver and load the weights onto the device.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
381
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
/// Loads a ModernBERT model on the CPU via the driver/arch path.
///
/// # Errors
///
/// Fails when the Hugging Face download, config parsing, or weight loading
/// fails.
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (downloading if necessary) the model files from the HF hub.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let config_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the driver and load the weights.
    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch)"
    );

    // `false`: this backend is not GPU-resident.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, false, mmap,
    )))
}
424
#[cfg(feature = "cuda")]
/// Loads a ModernBERT model onto a CUDA GPU via the driver/arch path.
///
/// `max_layers` optionally truncates the transformer stack (useful for
/// debugging/profiling); `None` runs all layers.
///
/// # Errors
///
/// Fails when the Hugging Face download, config parsing, or weight upload
/// to the device fails.
pub fn load_modernbert_cuda(
    model_repo: &str,
    max_layers: Option<usize>,
) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (downloading if necessary) the model files from the HF hub.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let config_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed ModernBERT configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the driver, upload weights, and apply the layer cap.
    let driver = CudaDriver::new()?;
    let (mut arch, mmap) = driver.load_modern_bert_weights(&weights_path, &config)?;
    arch.max_layers = max_layers;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // Batch size capped at 32 for the CUDA path.
    Ok(Box::new(GenericBackend::with_max_batch(
        driver, arch, max_tokens, true, mmap, 32,
    )))
}
484
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
/// Returns `true` when `model_repo`'s `config.json` declares
/// `"model_type": "modernbert"`. Any failure along the way (hub access,
/// download, I/O, JSON parse) is treated as "not ModernBERT".
fn is_modernbert_model(model_repo: &str) -> bool {
    hf_hub::api::sync::Api::new()
        .ok()
        .and_then(|api| api.model(model_repo.to_string()).get("config.json").ok())
        .and_then(|config_path| std::fs::read_to_string(&config_path).ok())
        .and_then(|config_str| serde_json::from_str::<serde_json::Value>(&config_str).ok())
        .and_then(|json| {
            json.get("model_type")
                .and_then(serde_json::Value::as_str)
                .map(|t| t == "modernbert")
        })
        .unwrap_or(false)
}
513
#[cfg(feature = "metal")]
/// Loads a classic BERT model onto the Metal GPU via the driver/arch path.
///
/// # Errors
///
/// Fails when the Hugging Face download, config parsing, or weight upload
/// to the device fails.
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (downloading if necessary) the model files from the HF hub.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let config_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed classic BERT configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the driver and upload the weights.
    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, true, mmap,
    )))
}
573
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
/// Loads a classic BERT model on the CPU via the driver/arch path.
///
/// # Errors
///
/// Fails when the Hugging Face download, config parsing, or weight loading
/// fails.
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    // Resolve (downloading if necessary) the model files from the HF hub.
    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let model = hub.model(model_repo.to_string());
    let config_path = model
        .get("config.json")
        .map_err(|e| crate::Error::Download(e.to_string()))?;
    let weights_path = model
        .get("model.safetensors")
        .map_err(|e| crate::Error::Download(e.to_string()))?;

    // Parse config.json into a typed classic BERT configuration.
    let raw_config = std::fs::read_to_string(&config_path).map_err(|e| crate::Error::Io {
        path: config_path.display().to_string(),
        source: e,
    })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    // Bring up the driver and load the weights.
    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_path, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch)"
    );

    // `false`: this backend is not GPU-resident.
    Ok(Box::new(GenericBackend::new(
        driver, arch, max_tokens, false, mmap,
    )))
}
633
#[cfg(test)]
mod tests {
    use super::*;

    /// The trait must stay object-safe so it can be boxed.
    #[test]
    fn trait_is_object_safe() {
        fn _assert_object_safe(_: &dyn EmbedBackend) {}
    }

    #[test]
    fn trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Box<dyn EmbedBackend>>();
    }

    #[test]
    fn trait_object_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Box<dyn EmbedBackend>>();
    }

    #[test]
    fn arc_trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
    }

    #[test]
    fn encoding_construction() {
        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };
        assert_eq!(enc.input_ids.len(), 6);
        assert_eq!(enc.attention_mask.len(), 6);
        assert_eq!(enc.token_type_ids.len(), 6);
    }

    #[test]
    fn encoding_clone() {
        let enc = Encoding {
            input_ids: vec![101, 102],
            attention_mask: vec![1, 1],
            token_type_ids: vec![0, 0],
        };
        let cloned = enc.clone();
        assert_eq!(enc.input_ids, cloned.input_ids);
    }

    #[test]
    fn backend_kind_default_is_cpu() {
        assert_eq!(BackendKind::default(), BackendKind::Cpu);
    }

    #[test]
    fn backend_kind_display() {
        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
        // Previously missing variant coverage.
        assert_eq!(BackendKind::Metal.to_string(), "metal");
    }

    #[cfg(not(feature = "mlx"))]
    #[test]
    fn load_backend_mlx_not_compiled() {
        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
        assert!(result.is_err());
    }

    // NOTE(review): a second test with an identical body but a strictly
    // narrower cfg (`all(feature = "cpu", not(feature = "mlx"))`) was removed
    // as redundant — whenever it could run, this one runs too.
    #[cfg(feature = "cpu")]
    #[test]
    fn detect_backends_returns_at_least_one() {
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty(), "should detect at least one backend");
    }

    /// End-to-end ModernBERT smoke test on Metal, with staged debug output
    /// (embedding lookup, layer norm, QKV, attention scores), MRL truncation
    /// cosine checks, and a throughput measurement.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~570MB)"]
    fn modernbert_loads_and_embeds() {
        use crate::backend::driver::Driver;

        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        // Re-run the pipeline manually, stage by stage, for debugging.
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let inputs = driver.prepare_batch(&[enc.clone()], 8).unwrap();

        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
        eprintln!("input_ids: {:?}", &ids_host[0][..5]);

        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_path = repo.get("config.json").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_modern_bert_weights(&weights_path, &config)
            .unwrap();

        // Stage 0: raw token-embedding lookup.
        let hidden = driver
            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
            .unwrap();
        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "embedding: {nz}/{} nonzero, first 5: {:?}",
            h[0].len(),
            &h[0][..5]
        );

        // Model dimensions for the manual stages below.
        let total = 8; // padded sequence length
        let hd = 768; // hidden size
        let nh = 12; // attention heads
        let head_dim = 64;

        // Stage 1: embedding layer norm.
        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
        driver
            .layer_norm(
                &mut ln_out,
                &emb_clone,
                &arch.weights.emb_norm_weight,
                &arch.weights.zero_bias,
                total,
                hd,
                1e-5,
            )
            .unwrap();
        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);

        // Stage 2: fused QKV projection of layer 0.
        let layer0 = &arch.weights.layers[0];
        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
        driver
            .gemm(
                &ln_out,
                &layer0.qkv_weight,
                &mut qkv,
                total,
                3 * hd,
                hd,
                true,
            )
            .unwrap();
        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);

        // Stage 3: split the fused QKV into per-head Q/K/V.
        let mut q = driver.alloc_zeros(total * hd).unwrap();
        let mut k = driver.alloc_zeros(total * hd).unwrap();
        let mut v = driver.alloc_zeros(total * hd).unwrap();
        driver
            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
            .unwrap();
        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);

        // Stage 4: raw attention scores (Q·Kᵀ per head).
        // (was `1 * nh * 8 * 8` — clippy identity_op)
        let mut scores = driver.alloc_zeros(nh * 8 * 8).unwrap();
        driver
            .gemm_batched(
                &q,
                &k,
                &mut scores,
                8,
                8,
                head_dim,
                true,
                8 * head_dim,
                8 * head_dim,
                8 * 8,
                nh,
            )
            .unwrap();
        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);

        // Full batched forward through the arch for comparison.
        use crate::backend::arch::ModelArch;
        let enc2 = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let quick = arch.forward(&driver, &[enc2.clone()]).unwrap();
        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
            &quick[0][..3]
        );

        // MRL (Matryoshka) truncation: cosine of the truncated embedding
        // against itself is trivially 1.0 — printed for inspection.
        eprintln!("\n=== ModernBERT MRL Truncation ===");
        let full = arch.forward(&driver, &[enc2.clone()]).unwrap();
        let full_emb = &full[0];
        for dims in [64, 128, 256, 384, 512, 768] {
            let t: Vec<f32> = full_emb[..dims].to_vec();
            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            let f_norm: f32 = full_emb[..dims]
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt()
                .max(1e-12);
            let cos: f32 = t
                .iter()
                .zip(&full_emb[..dims])
                .map(|(a, b)| a * b)
                .sum::<f32>()
                / (t_norm * f_norm);
            eprintln!(" dims={dims:>3}: cosine={cos:.6}");
        }

        // Throughput: 32 sequences of increasing length.
        eprintln!("\n=== ModernBERT Throughput ===");
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            let mut ids = vec![1_i64]; // BOS-like token
            for j in 1..len - 1 {
                ids.push(100 + j as i64);
            }
            ids.push(2); // EOS-like token
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warm-up pass before timing.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(result.len(), 32);

        // Single-sequence latency for comparison.
        let single = vec![encs[0].clone()];
        let t1 = std::time::Instant::now();
        let _ = arch.forward(&driver, &single).unwrap();
        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
        eprintln!(" batch=1, time={single_ms:.1}ms");
    }

    /// Classic BERT on Metal: checks L2 normalization, cross-checks against
    /// the monolithic CPU backend, then measures throughput.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_driver_arch() {
        use crate::backend::arch::ModelArch;

        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
            &result[0][..3]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Cross-check against the monolithic CPU backend when available.
        #[cfg(feature = "cpu")]
        {
            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
                .expect("CPU load failed");
            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
            eprintln!("CPU first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("NEW first 5: {:?}", &result[0][..5]);
            // Both embeddings are unit-norm, so the dot product is the cosine.
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
            assert!(
                cosine > 0.95,
                "cosine similarity vs CPU should be >0.95, got {cosine}"
            );
        }

        // Throughput: load the arch directly and time a 32-sequence batch.
        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let config_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("config.json").unwrap()
        };
        let weights_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("model.safetensors").unwrap()
        };
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_classic_bert_weights(&weights_path, &config)
            .unwrap();

        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            let mut ids = vec![101_i64]; // [CLS]
            for j in 1..len - 1 {
                ids.push(100 + j as i64);
            }
            ids.push(102); // [SEP]
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Warm-up pass before timing.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let bench_result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(bench_result.len(), 32);
    }

    /// Classic BERT on CPU via driver/arch: checks L2 normalization and
    /// agreement with the monolithic CPU backend.
    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_cpu_driver_arch() {
        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
        assert!(!backend.is_gpu(), "CPU backend should not be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
            &result[0][..5]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Cross-check against the monolithic CPU backend when available.
        #[cfg(feature = "cpu")]
        {
            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
                .expect("monolithic CPU load failed");
            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("New first 5: {:?}", &result[0][..5]);
            // Both embeddings are unit-norm, so the dot product is the cosine.
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
            assert!(
                cosine > 0.999,
                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
            );
        }
    }
}