1pub mod arch;
8pub mod blas_info;
9#[cfg(feature = "cpu")]
10pub mod cpu;
11#[cfg(feature = "cuda")]
12pub mod cuda;
13pub mod driver;
14pub mod generic;
15#[cfg(feature = "metal")]
16pub mod metal_kernels;
17#[cfg(feature = "mlx")]
18pub mod mlx;
19#[cfg(feature = "cuda")]
20pub mod nvrtc_cubin;
21
/// One tokenized input sequence handed to an [`EmbedBackend`].
///
/// Callers in this module build all three vectors with one entry per token;
/// the type itself does not enforce equal lengths.
#[derive(Clone, Debug)]
pub struct Encoding {
    /// Token ids produced by the tokenizer.
    pub input_ids: Vec<i64>,
    /// Per-token attention mask (the tests here use `1` for every token).
    pub attention_mask: Vec<i64>,
    /// Per-token segment ids (the tests here use `0` throughout).
    pub token_type_ids: Vec<i64>,
}
36
37pub trait EmbedBackend: Send + Sync {
50 fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;
60
61 fn supports_clone(&self) -> bool;
65
66 fn clone_backend(&self) -> Box<dyn EmbedBackend>;
73
74 fn is_gpu(&self) -> bool;
79
80 fn max_tokens(&self) -> usize {
85 512 }
87}
88
/// Which inference backend [`load_backend`] should construct.
///
/// Defaults to [`BackendKind::Cpu`]. Variants whose cargo feature is not
/// enabled still exist; loading them returns an error naming the feature.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BackendKind {
    /// NVIDIA CUDA (feature `cuda`).
    Cuda,
    /// Apple MLX (feature `mlx`).
    Mlx,
    /// CPU (features `cpu` / `cpu-accelerate`); the default.
    #[default]
    Cpu,
    /// Apple Metal (feature `metal`).
    Metal,
}
102
103impl std::fmt::Display for BackendKind {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::Cuda => write!(f, "cuda"),
107 Self::Mlx => write!(f, "mlx"),
108 Self::Cpu => write!(f, "cpu"),
109 Self::Metal => write!(f, "metal"),
110 }
111 }
112}
113
/// Preferred device class, forwarded to backend loaders.
///
/// Defaults to [`DeviceHint::Auto`]. `load_backend` passes it to the CUDA,
/// MLX and monolithic CPU loaders. How each backend interprets it is
/// backend-specific.
#[derive(Clone, Copy, Debug, Default)]
pub enum DeviceHint {
    /// Let the backend decide.
    #[default]
    Auto,
    /// Ask for CPU execution.
    Cpu,
    /// Ask for GPU execution.
    Gpu,
}
128
/// Per-call inference options.
///
/// Currently carries no fields. NOTE(review): presumably a placeholder for
/// future settings — confirm with callers.
#[derive(Clone, Debug, Default)]
pub struct InferenceOpts {}
135
136pub fn load_backend(
146 kind: BackendKind,
147 #[cfg_attr(
148 not(any(
149 feature = "cuda",
150 feature = "mlx",
151 feature = "cpu",
152 feature = "cpu-accelerate",
153 feature = "metal"
154 )),
155 expect(unused_variables, reason = "used when backend features are enabled")
156 )]
157 model_repo: &str,
158 #[cfg_attr(
159 not(any(
160 feature = "cuda",
161 feature = "mlx",
162 feature = "cpu",
163 feature = "cpu-accelerate",
164 feature = "metal"
165 )),
166 expect(unused_variables, reason = "used when backend features are enabled")
167 )]
168 device_hint: DeviceHint,
169) -> crate::Result<Box<dyn EmbedBackend>> {
170 match kind {
171 #[cfg(feature = "cuda")]
172 BackendKind::Cuda => {
173 if is_modernbert_model(model_repo) {
174 return load_modernbert_cuda(model_repo);
175 }
176 let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
177 Ok(Box::new(backend))
178 }
179 #[cfg(not(feature = "cuda"))]
180 BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
181 "cuda backend requires building with: cargo build --features cuda"
182 ))),
183 #[cfg(feature = "mlx")]
184 BackendKind::Mlx => {
185 let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
186 Ok(Box::new(backend))
187 }
188 #[cfg(not(feature = "mlx"))]
189 BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
190 "mlx backend requires building with: cargo build --features mlx"
191 ))),
192 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
193 BackendKind::Cpu => {
194 if is_modernbert_model(model_repo) {
195 return load_modernbert_cpu(model_repo);
196 }
197 #[cfg(feature = "cpu")]
198 {
199 let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
200 #[expect(
201 clippy::needless_return,
202 reason = "return needed before cfg(not) fallback"
203 )]
204 return Ok(Box::new(backend));
205 }
206 #[cfg(not(feature = "cpu"))]
207 Err(crate::Error::Other(anyhow::anyhow!(
208 "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
209 )))
210 }
211 #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
212 BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
213 "cpu backend requires building with: cargo build --features cpu"
214 ))),
215 #[cfg(feature = "metal")]
216 BackendKind::Metal => {
217 if is_modernbert_model(model_repo) {
219 return load_modernbert_metal(model_repo);
220 }
221 load_classic_metal(model_repo)
222 }
223 #[cfg(not(feature = "metal"))]
224 BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
225 "metal backend requires building with: cargo build --features metal"
226 ))),
227 }
228}
229
230pub fn detect_backends(
240 #[cfg_attr(
241 not(any(
242 feature = "cuda",
243 feature = "mlx",
244 feature = "cpu",
245 feature = "cpu-accelerate",
246 feature = "metal"
247 )),
248 expect(unused_variables, reason = "used when backend features are enabled")
249 )]
250 model_repo: &str,
251) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
252 #[cfg_attr(
253 not(any(
254 feature = "cuda",
255 feature = "mlx",
256 feature = "cpu",
257 feature = "cpu-accelerate",
258 feature = "metal"
259 )),
260 expect(unused_mut, reason = "mut needed when backend features are enabled")
261 )]
262 let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
263
264 #[cfg(feature = "cuda")]
266 {
267 if is_modernbert_model(model_repo) {
268 if let Ok(b) = load_modernbert_cuda(model_repo) {
269 backends.push(b);
270 }
271 } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
272 backends.push(Box::new(b));
273 }
274 }
275
276 #[cfg(feature = "metal")]
278 {
279 if is_modernbert_model(model_repo) {
281 if let Ok(b) = load_modernbert_metal(model_repo) {
282 backends.push(b);
283 }
284 } else if let Ok(b) = load_classic_metal(model_repo) {
285 backends.push(b);
286 }
287 }
288
289 #[cfg(feature = "mlx")]
291 if backends.is_empty()
292 && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
293 {
294 backends.push(Box::new(b));
295 }
296
297 #[cfg_attr(
302 not(any(feature = "cpu", feature = "cpu-accelerate")),
303 expect(unused_variables, reason = "used when cpu feature is enabled")
304 )]
305 let has_gpu = backends.iter().any(|b| b.is_gpu());
306 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
307 if !has_gpu {
308 if is_modernbert_model(model_repo) {
309 if let Ok(b) = load_modernbert_cpu(model_repo) {
310 backends.push(b);
311 }
312 } else {
313 #[cfg(feature = "cpu")]
314 if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
315 backends.push(Box::new(b));
316 }
317 }
318 }
319
320 if backends.is_empty() {
321 return Err(crate::Error::Other(anyhow::anyhow!(
322 "no embedding backends available"
323 )));
324 }
325
326 Ok(backends)
327}
328
/// Loads a ModernBERT checkpoint from `model_repo` onto Metal via the driver/arch path.
///
/// Fetches `config.json` and then `model.safetensors` through the hf-hub
/// sync API. It parses the config as `ModernBertConfig` and uploads the
/// weights with `MetalDriver::load_modern_bert_weights`. The backend's token
/// limit is the config's `max_position_embeddings`.
///
/// # Errors
///
/// - `Error::Download` if the hub API or a file fetch fails.
/// - `Error::Io` if the config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Otherwise, whatever the config parser or Metal driver returns.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let remote = hub.model(model_repo.to_string());
    // Resolve one file of the repo (cached or downloaded), mapping hub errors.
    let fetch = |file: &str| {
        remote
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let cfg_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let cfg_text = std::fs::read_to_string(&cfg_file).map_err(|source| crate::Error::Io {
        path: cfg_file.display().to_string(),
        source,
    })?;
    let cfg_value: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&cfg_value)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    // NOTE(review): the `true` flag presumably marks this backend as GPU — confirm
    // against `GenericBackend::new`.
    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
387
/// Loads a ModernBERT checkpoint from `model_repo` onto the CPU via the driver/arch path.
///
/// Fetches `config.json` and then `model.safetensors` through the hf-hub
/// sync API. It parses the config as `ModernBertConfig` and maps the weights
/// with `CpuDriver::load_modern_bert_weights`. The result is wrapped with
/// `GenericBackend::new_shared`. The token limit is the config's
/// `max_position_embeddings`.
///
/// # Errors
///
/// - `Error::Download` if the hub API or a file fetch fails.
/// - `Error::Io` if the config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Otherwise, whatever the config parser or CPU driver returns.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let remote = hub.model(model_repo.to_string());
    // Resolve one file of the repo (cached or downloaded), mapping hub errors.
    let fetch = |file: &str| {
        remote
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let cfg_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let cfg_text = std::fs::read_to_string(&cfg_file).map_err(|source| crate::Error::Io {
        path: cfg_file.display().to_string(),
        source,
    })?;
    let cfg_value: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&cfg_value)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // NOTE(review): the `false` flag presumably marks this backend as non-GPU —
    // confirm against `GenericBackend::new_shared`.
    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
430
/// Loads a ModernBERT checkpoint from `model_repo` onto CUDA via the driver/arch path.
///
/// Fetches `config.json` and then `model.safetensors` through the hf-hub
/// sync API. It parses the config as `ModernBertConfig` and uploads the
/// weights with `CudaDriver::load_modern_bert_weights`. The result is wrapped
/// with `GenericBackend::with_max_batch`, owning the weight mmap and using a
/// batch limit of 32.
///
/// # Errors
///
/// - `Error::Download` if the hub API or a file fetch fails.
/// - `Error::Io` if the config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Otherwise, whatever the config parser or CUDA driver returns.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let remote = hub.model(model_repo.to_string());
    // Resolve one file of the repo (cached or downloaded), mapping hub errors.
    let fetch = |file: &str| {
        remote
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let cfg_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let cfg_text = std::fs::read_to_string(&cfg_file).map_err(|source| crate::Error::Io {
        path: cfg_file.display().to_string(),
        source,
    })?;
    let cfg_value: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&cfg_value)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // Batch limit passed to `with_max_batch` (presumably the most encodings per
    // forward pass — confirm against `GenericBackend::with_max_batch`).
    let max_batch = 32;
    let backend = GenericBackend::with_max_batch(
        driver,
        arch,
        max_tokens,
        true,
        generic::MmapHolder::Owned(mmap),
        max_batch,
    );
    Ok(Box::new(backend))
}
491
/// Returns `true` if `model_repo`'s `config.json` declares `"model_type": "modernbert"`.
///
/// This is a best-effort probe. Any failure yields `false`: hub API setup,
/// fetching the config, reading it, parsing JSON, or a missing or non-string
/// `model_type`.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    let probe = || -> Option<bool> {
        let hub = hf_hub::api::sync::Api::new().ok()?;
        let config_path = hub.model(model_repo.to_string()).get("config.json").ok()?;
        let raw = std::fs::read_to_string(&config_path).ok()?;
        let config: serde_json::Value = serde_json::from_str(&raw).ok()?;
        let model_type = config.get("model_type")?.as_str()?;
        Some(model_type == "modernbert")
    };
    probe().unwrap_or(false)
}
520
/// Loads a ClassicBert (BERT-style) checkpoint from `model_repo` onto Metal via the driver/arch path.
///
/// Fetches `config.json` and then `model.safetensors` through the hf-hub
/// sync API. It parses the config as `ClassicBertConfig` and uploads the
/// weights with `MetalDriver::load_classic_bert_weights`. The token limit is
/// the config's `max_position_embeddings`.
///
/// # Errors
///
/// - `Error::Download` if the hub API or a file fetch fails.
/// - `Error::Io` if the config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Otherwise, whatever the config parser or Metal driver returns.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let remote = hub.model(model_repo.to_string());
    // Resolve one file of the repo (cached or downloaded), mapping hub errors.
    let fetch = |file: &str| {
        remote
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let cfg_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let cfg_text = std::fs::read_to_string(&cfg_file).map_err(|source| crate::Error::Io {
        path: cfg_file.display().to_string(),
        source,
    })?;
    let cfg_value: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&cfg_value)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
580
/// Loads a ClassicBert (BERT-style) checkpoint from `model_repo` onto the CPU via the driver/arch path.
///
/// Fetches `config.json` and then `model.safetensors` through the hf-hub
/// sync API. It parses the config as `ClassicBertConfig` and maps the weights
/// with `CpuDriver::load_classic_bert_weights`. The result is wrapped with
/// `GenericBackend::new_shared`.
///
/// # Errors
///
/// - `Error::Download` if the hub API or a file fetch fails.
/// - `Error::Io` if the config cannot be read.
/// - `Error::Other` if the config is not valid JSON.
/// - Otherwise, whatever the config parser or CPU driver returns.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    let hub = Api::new().map_err(|e| crate::Error::Download(e.to_string()))?;
    let remote = hub.model(model_repo.to_string());
    // Resolve one file of the repo (cached or downloaded), mapping hub errors.
    let fetch = |file: &str| {
        remote
            .get(file)
            .map_err(|e| crate::Error::Download(e.to_string()))
    };

    let cfg_file = fetch("config.json")?;
    let weights_file = fetch("model.safetensors")?;

    let cfg_text = std::fs::read_to_string(&cfg_file).map_err(|source| crate::Error::Io {
        path: cfg_file.display().to_string(),
        source,
    })?;
    let cfg_value: serde_json::Value = serde_json::from_str(&cfg_text)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&cfg_value)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
640
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 32 synthetic encodings of lengths 16, 20, …, 140 tokens.
    ///
    /// Each one is framed by `cls` … `sep`, with filler ids `101..` in
    /// between. Used by the Metal throughput smoke tests.
    #[cfg(feature = "metal")]
    fn synthetic_batch(cls: i64, sep: i64) -> Vec<Encoding> {
        (0..32_i64)
            .map(|i| {
                let len = 16 + i * 4;
                let ids: Vec<i64> = std::iter::once(cls)
                    .chain((1..len - 1).map(|j| 100 + j))
                    .chain(std::iter::once(sep))
                    .collect();
                let n = ids.len();
                Encoding {
                    input_ids: ids,
                    attention_mask: vec![1; n],
                    token_type_ids: vec![0; n],
                }
            })
            .collect()
    }

    #[test]
    fn trait_is_object_safe() {
        fn _assert_object_safe(_: &dyn EmbedBackend) {}
    }

    #[test]
    fn trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Box<dyn EmbedBackend>>();
    }

    #[test]
    fn trait_object_is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Box<dyn EmbedBackend>>();
    }

    #[test]
    fn arc_trait_object_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<std::sync::Arc<dyn EmbedBackend>>();
    }

    #[test]
    fn encoding_construction() {
        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };
        assert_eq!(enc.input_ids.len(), 6);
        assert_eq!(enc.attention_mask.len(), 6);
        assert_eq!(enc.token_type_ids.len(), 6);
    }

    #[test]
    fn encoding_clone() {
        let enc = Encoding {
            input_ids: vec![101, 102],
            attention_mask: vec![1, 1],
            token_type_ids: vec![0, 0],
        };
        let cloned = enc.clone();
        assert_eq!(enc.input_ids, cloned.input_ids);
    }

    #[test]
    fn backend_kind_default_is_cpu() {
        assert_eq!(BackendKind::default(), BackendKind::Cpu);
    }

    #[test]
    fn backend_kind_display() {
        assert_eq!(BackendKind::Cuda.to_string(), "cuda");
        assert_eq!(BackendKind::Mlx.to_string(), "mlx");
        assert_eq!(BackendKind::Cpu.to_string(), "cpu");
        assert_eq!(BackendKind::Metal.to_string(), "metal");
    }

    #[cfg(not(feature = "mlx"))]
    #[test]
    fn load_backend_mlx_not_compiled() {
        let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
        assert!(result.is_err());
    }

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn load_backend_cuda_not_compiled() {
        let result = load_backend(BackendKind::Cuda, "test/model", DeviceHint::Gpu);
        assert!(result.is_err());
    }

    #[cfg(not(feature = "metal"))]
    #[test]
    fn load_backend_metal_not_compiled() {
        let result = load_backend(BackendKind::Metal, "test/model", DeviceHint::Gpu);
        assert!(result.is_err());
    }

    #[cfg(feature = "cpu")]
    #[test]
    fn detect_backends_returns_at_least_one() {
        let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
        assert!(!backends.is_empty(), "should detect at least one backend");
    }

    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~570MB)"]
    fn modernbert_loads_and_embeds() {
        use crate::backend::arch::ModelArch;
        use crate::backend::driver::Driver;

        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let inputs = driver
            .prepare_batch(std::slice::from_ref(&enc), 8)
            .unwrap();

        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
        eprintln!("input_ids: {:?}", &ids_host[0][..5]);

        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_path = repo.get("config.json").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_modern_bert_weights(&weights_path, &config)
            .unwrap();

        let hidden = driver
            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
            .unwrap();
        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "embedding: {nz}/{} nonzero, first 5: {:?}",
            h[0].len(),
            &h[0][..5]
        );

        // Shapes for a single sequence padded to 8 tokens on modernbert-embed-base.
        let total = 8;
        let hd = 768;
        let nh = 12;
        let head_dim = 64;

        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
        driver
            .layer_norm(
                &mut ln_out,
                &emb_clone,
                &arch.weights.emb_norm_weight,
                &arch.weights.zero_bias,
                total,
                hd,
                1e-5,
            )
            .unwrap();
        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);

        let layer0 = &arch.weights.layers[0];
        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
        driver
            .gemm(
                &ln_out,
                &layer0.qkv_weight,
                &mut qkv,
                total,
                3 * hd,
                hd,
                true,
            )
            .unwrap();
        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);

        let mut q = driver.alloc_zeros(total * hd).unwrap();
        let mut k = driver.alloc_zeros(total * hd).unwrap();
        let mut v = driver.alloc_zeros(total * hd).unwrap();
        driver
            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
            .unwrap();
        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);

        // One (8 x 8) score matrix per head for the single sequence.
        let mut scores = driver.alloc_zeros(nh * 8 * 8).unwrap();
        driver
            .gemm_batched(
                &q,
                &k,
                &mut scores,
                8,
                8,
                head_dim,
                true,
                8 * head_dim,
                8 * head_dim,
                8 * 8,
                nh,
            )
            .unwrap();
        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);

        let enc2 = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let quick = arch
            .forward(&driver, std::slice::from_ref(&enc2))
            .unwrap();
        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
            &quick[0][..3]
        );

        eprintln!("\n=== ModernBERT MRL Truncation ===");
        let full = arch
            .forward(&driver, std::slice::from_ref(&enc2))
            .unwrap();
        let full_emb = &full[0];
        // Cosine between the zero-padded `dims`-prefix and the *full* embedding.
        // This equals the share of the embedding's L2 norm kept after Matryoshka
        // truncation. Normalizing by the prefix's own norm on both sides would
        // always print 1.0.
        let full_norm = full_emb
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt()
            .max(1e-12);
        for dims in [64, 128, 256, 384, 512, 768] {
            let prefix = &full_emb[..dims];
            let prefix_norm = prefix
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt()
                .max(1e-12);
            let dot: f32 = prefix
                .iter()
                .zip(&full_emb[..dims])
                .map(|(a, b)| a * b)
                .sum();
            let cos = dot / (prefix_norm * full_norm);
            eprintln!(" dims={dims:>3}: cosine={cos:.6}");
        }

        eprintln!("\n=== ModernBERT Throughput ===");
        let encs = synthetic_batch(1, 2);

        // Warm-up pass so the timed run excludes one-time setup costs.
        arch.forward(&driver, &encs[..4])
            .expect("warm-up forward failed");

        let t0 = std::time::Instant::now();
        let result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(result.len(), 32);

        let single = std::slice::from_ref(&encs[0]);
        let t1 = std::time::Instant::now();
        let _ = arch.forward(&driver, single).unwrap();
        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
        eprintln!(" batch=1, time={single_ms:.1}ms");
    }

    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_driver_arch() {
        use crate::backend::arch::ModelArch;

        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
            &result[0][..3]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        #[cfg(feature = "cpu")]
        {
            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
                .expect("CPU load failed");
            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
            eprintln!("CPU first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("NEW first 5: {:?}", &result[0][..5]);
            // Both vectors are expected to be unit-length, so the dot product is the cosine.
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
            assert!(
                cosine > 0.95,
                "cosine similarity vs CPU should be >0.95, got {cosine}"
            );
        }

        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model(model_repo.to_string());
        let config_path = repo.get("config.json").unwrap();
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_classic_bert_weights(&weights_path, &config)
            .unwrap();

        let encs = synthetic_batch(101, 102);

        // Warm-up pass so the timed run excludes one-time setup costs.
        arch.forward(&driver, &encs[..4])
            .expect("warm-up forward failed");

        let t0 = std::time::Instant::now();
        let bench_result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            " batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(bench_result.len(), 32);
    }

    #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_cpu_driver_arch() {
        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
        assert!(!backend.is_gpu(), "CPU backend should not be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
            &result[0][..5]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        #[cfg(feature = "cpu")]
        {
            let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
                .expect("monolithic CPU load failed");
            let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
            eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("New first 5: {:?}", &result[0][..5]);
            // Both vectors are expected to be unit-length, so the dot product is the cosine.
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
            assert!(
                cosine > 0.999,
                "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
            );
        }
    }
}