1pub mod arch;
8pub mod blas_info;
9#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
15pub mod cpu;
16#[cfg(feature = "cuda")]
17pub mod cuda;
18pub mod driver;
19pub mod generic;
20#[cfg(feature = "metal")]
21pub mod metal_kernels;
22#[cfg(feature = "mlx")]
23pub mod mlx;
24#[cfg(feature = "cuda")]
25pub mod nvrtc_cubin;
26
/// Tokenized model input for a single sequence.
///
/// The tests in this module always build the three vectors with the same
/// length; presumably backends rely on that invariant — confirm against the
/// backend implementations.
///
/// Derives `PartialEq`/`Eq` so whole encodings can be compared directly
/// instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Encoding {
    /// Token ids fed to the model.
    pub input_ids: Vec<i64>,
    /// Attention mask; the tests in this module use `1` for every position.
    pub attention_mask: Vec<i64>,
    /// Token type ids; the tests in this module use `0` for every position.
    pub token_type_ids: Vec<i64>,
}
41
/// Object-safe interface implemented by every embedding backend.
///
/// The `Send + Sync` supertraits allow `Box<dyn EmbedBackend>` and
/// `Arc<dyn EmbedBackend>` to be sent across threads (asserted by the tests
/// in this module).
pub trait EmbedBackend: Send + Sync {
    /// Compute embeddings for a batch of tokenized inputs.
    ///
    /// The backend tests in this module expect one `Vec<f32>` per input
    /// encoding.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend fails to run the batch.
    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>>;

    /// Whether [`EmbedBackend::clone_backend`] is supported.
    ///
    /// NOTE(review): presumably callers should check this before calling
    /// `clone_backend`; the exact contract is not visible here — confirm
    /// against the implementors.
    fn supports_clone(&self) -> bool;

    /// Produce another boxed handle to this backend.
    ///
    /// NOTE(review): whether the clone shares or duplicates weights is decided
    /// by each implementor and is not visible from this trait.
    fn clone_backend(&self) -> Box<dyn EmbedBackend>;

    /// `true` when the backend runs on a GPU.
    fn is_gpu(&self) -> bool;

    /// Maximum sequence length in tokens; defaults to `512`.
    fn max_tokens(&self) -> usize {
        512
    }

    /// Short label for the backend; defaults to `"GPU"` or `"CPU"` based on
    /// [`EmbedBackend::is_gpu`].
    fn name(&self) -> &'static str {
        if self.is_gpu() { "GPU" } else { "CPU" }
    }
}
100
/// Object-safe interface for reranking backends (see `load_reranker_cpu`).
///
/// `Send + Sync` is required so boxed backends can be shared across threads.
pub trait RerankBackend: Send + Sync {
    /// Score a batch of tokenized inputs.
    ///
    /// NOTE(review): presumably returns one score per input encoding — confirm
    /// against the implementors.
    ///
    /// # Errors
    ///
    /// Returns an error when the backend fails to run the batch.
    fn score_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<f32>>;

    /// Maximum sequence length in tokens; defaults to `512`.
    fn max_tokens(&self) -> usize {
        512
    }

    /// `true` when the backend runs on a GPU.
    fn is_gpu(&self) -> bool;

    /// Short label for the backend; defaults to `"GPU"` or `"CPU"` based on
    /// [`RerankBackend::is_gpu`].
    fn name(&self) -> &'static str {
        if self.is_gpu() { "GPU" } else { "CPU" }
    }
}
145
/// Identifies which embedding backend implementation to load.
///
/// The `Display` impl renders the lowercase name (`"cuda"`, `"mlx"`, `"cpu"`,
/// `"metal"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendKind {
    /// CUDA backend; `load_backend` needs the `cuda` Cargo feature for it.
    Cuda,
    /// MLX backend; `load_backend` needs the `mlx` Cargo feature for it.
    Mlx,
    /// CPU backend and the default variant; `load_backend` needs the `cpu` or
    /// `cpu-accelerate` Cargo feature for it.
    #[default]
    Cpu,
    /// Metal backend; `load_backend` needs the `metal` Cargo feature for it.
    Metal,
}
159
160impl std::fmt::Display for BackendKind {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 Self::Cuda => write!(f, "cuda"),
164 Self::Mlx => write!(f, "mlx"),
165 Self::Cpu => write!(f, "cpu"),
166 Self::Metal => write!(f, "metal"),
167 }
168 }
169}
170
/// Device preference handed to the backend loaders.
///
/// `load_backend` forwards the caller's hint to the CUDA, MLX and classic CPU
/// loaders; `detect_backends` passes `Gpu` to CUDA, `Auto` to MLX and `Cpu` to
/// the CPU loader. How a loader interprets the hint is not visible in this
/// file.
///
/// Derives `PartialEq`/`Eq` (like [`BackendKind`]) so hints can be compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceHint {
    /// No preference; the default.
    #[default]
    Auto,
    /// Prefer the CPU.
    Cpu,
    /// Prefer a GPU.
    Gpu,
}
185
/// Inference options.
///
/// Currently has no fields; nothing in this file reads it.
#[derive(Debug, Clone, Default)]
pub struct InferenceOpts {}
192
193pub fn load_backend(
203 kind: BackendKind,
204 #[cfg_attr(
205 not(any(
206 feature = "cuda",
207 feature = "mlx",
208 feature = "cpu",
209 feature = "cpu-accelerate",
210 feature = "metal"
211 )),
212 expect(unused_variables, reason = "used when backend features are enabled")
213 )]
214 model_repo: &str,
215 #[cfg_attr(
216 not(any(
217 feature = "cuda",
218 feature = "mlx",
219 feature = "cpu",
220 feature = "cpu-accelerate",
221 feature = "metal"
222 )),
223 expect(unused_variables, reason = "used when backend features are enabled")
224 )]
225 device_hint: DeviceHint,
226) -> crate::Result<Box<dyn EmbedBackend>> {
227 match kind {
228 #[cfg(feature = "cuda")]
229 BackendKind::Cuda => {
230 if is_modernbert_model(model_repo) {
231 return load_modernbert_cuda(model_repo);
232 }
233 let backend = cuda::CudaBackend::load(model_repo, &device_hint)?;
234 Ok(Box::new(backend))
235 }
236 #[cfg(not(feature = "cuda"))]
237 BackendKind::Cuda => Err(crate::Error::Other(anyhow::anyhow!(
238 "cuda backend requires building with: cargo build --features cuda"
239 ))),
240 #[cfg(feature = "mlx")]
241 BackendKind::Mlx => {
242 let backend = mlx::MlxBackend::load(model_repo, &device_hint)?;
243 Ok(Box::new(backend))
244 }
245 #[cfg(not(feature = "mlx"))]
246 BackendKind::Mlx => Err(crate::Error::Other(anyhow::anyhow!(
247 "mlx backend requires building with: cargo build --features mlx"
248 ))),
249 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
250 BackendKind::Cpu => {
251 if is_modernbert_model(model_repo) {
252 return load_modernbert_cpu(model_repo);
253 }
254 #[cfg(feature = "cpu")]
255 {
256 let backend = cpu::CpuBackend::load(model_repo, &device_hint)?;
257 #[expect(
258 clippy::needless_return,
259 reason = "return needed before cfg(not) fallback"
260 )]
261 return Ok(Box::new(backend));
262 }
263 #[cfg(not(feature = "cpu"))]
264 Err(crate::Error::Other(anyhow::anyhow!(
265 "ClassicBert CPU backend requires feature 'cpu'; only ModernBERT is available with 'cpu-accelerate'"
266 )))
267 }
268 #[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
269 BackendKind::Cpu => Err(crate::Error::Other(anyhow::anyhow!(
270 "cpu backend requires building with: cargo build --features cpu"
271 ))),
272 #[cfg(feature = "metal")]
273 BackendKind::Metal => {
274 if is_modernbert_model(model_repo) {
276 return load_modernbert_metal(model_repo);
277 }
278 load_classic_metal(model_repo)
279 }
280 #[cfg(not(feature = "metal"))]
281 BackendKind::Metal => Err(crate::Error::Other(anyhow::anyhow!(
282 "metal backend requires building with: cargo build --features metal"
283 ))),
284 }
285}
286
287pub fn detect_backends(
297 #[cfg_attr(
298 not(any(
299 feature = "cuda",
300 feature = "mlx",
301 feature = "cpu",
302 feature = "cpu-accelerate",
303 feature = "metal"
304 )),
305 expect(unused_variables, reason = "used when backend features are enabled")
306 )]
307 model_repo: &str,
308) -> crate::Result<Vec<Box<dyn EmbedBackend>>> {
309 #[cfg_attr(
310 not(any(
311 feature = "cuda",
312 feature = "mlx",
313 feature = "cpu",
314 feature = "cpu-accelerate",
315 feature = "metal"
316 )),
317 expect(unused_mut, reason = "mut needed when backend features are enabled")
318 )]
319 let mut backends: Vec<Box<dyn EmbedBackend>> = Vec::new();
320
321 #[cfg(feature = "cuda")]
323 {
324 if is_modernbert_model(model_repo) {
325 if let Ok(b) = load_modernbert_cuda(model_repo) {
326 backends.push(b);
327 }
328 } else if let Ok(b) = cuda::CudaBackend::load(model_repo, &DeviceHint::Gpu) {
329 backends.push(Box::new(b));
330 }
331 }
332
333 #[cfg(feature = "metal")]
335 {
336 if is_modernbert_model(model_repo) {
338 if let Ok(b) = load_modernbert_metal(model_repo) {
339 backends.push(b);
340 }
341 } else if let Ok(b) = load_classic_metal(model_repo) {
342 backends.push(b);
343 }
344 }
345
346 #[cfg(feature = "mlx")]
348 if backends.is_empty()
349 && let Ok(b) = mlx::MlxBackend::load(model_repo, &DeviceHint::Auto)
350 {
351 backends.push(Box::new(b));
352 }
353
354 #[cfg_attr(
359 not(any(feature = "cpu", feature = "cpu-accelerate")),
360 expect(unused_variables, reason = "used when cpu feature is enabled")
361 )]
362 let has_gpu = backends.iter().any(|b| b.is_gpu());
363 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
364 if !has_gpu {
365 if is_modernbert_model(model_repo) {
366 if let Ok(b) = load_modernbert_cpu(model_repo) {
367 backends.push(b);
368 }
369 } else {
370 #[cfg(feature = "cpu")]
371 if let Ok(b) = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu) {
372 backends.push(Box::new(b));
373 }
374 }
375 }
376
377 if backends.is_empty() {
378 return Err(crate::Error::Other(anyhow::anyhow!(
379 "no embedding backends available"
380 )));
381 }
382
383 Ok(backends)
384}
385
/// Load a ModernBERT model onto Metal through the driver/arch path.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` via the
/// `hf_hub` sync API, parses the config, loads the weights with `MetalDriver`
/// and wraps everything in a `GenericBackend`.
///
/// # Errors
///
/// Returns `Error::Download` when the hub API or a file fetch fails,
/// `Error::Io` when the config cannot be read, `Error::Other` when it is not
/// valid JSON, and propagates config, driver and weight-loading errors.
#[cfg(feature = "metal")]
pub fn load_modernbert_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{MetalDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    /// Wrap any hub failure as a download error carrying its message.
    fn download_err(cause: impl ToString) -> crate::Error {
        crate::Error::Download(cause.to_string())
    }

    let hub = Api::new().map_err(download_err)?;
    let files = hub.model(model_repo.to_string());

    let config_file = files.get("config.json").map_err(download_err)?;
    let weights_file = files.get("model.safetensors").map_err(download_err)?;

    let raw_config =
        std::fs::read_to_string(&config_file).map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on Metal (driver/arch)"
    );

    // `true` marks the backend as GPU (the tests assert `is_gpu()` on it).
    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
444
/// Load a ModernBERT model on the CPU through the driver/arch path.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` via the
/// `hf_hub` sync API, parses the config, loads the weights with `CpuDriver`
/// and wraps everything in a `GenericBackend` built with `new_shared`.
///
/// # Errors
///
/// Returns `Error::Download` when the hub API or a file fetch fails,
/// `Error::Io` when the config cannot be read, `Error::Other` when it is not
/// valid JSON, and propagates config, driver and weight-loading errors.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_modernbert_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{CpuDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    /// Wrap any hub failure as a download error carrying its message.
    fn download_err(cause: impl ToString) -> crate::Error {
        crate::Error::Download(cause.to_string())
    }

    let hub = Api::new().map_err(download_err)?;
    let files = hub.model(model_repo.to_string());

    let config_file = files.get("config.json").map_err(download_err)?;
    let weights_file = files.get("model.safetensors").map_err(download_err)?;

    let raw_config =
        std::fs::read_to_string(&config_file).map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        max_tokens,
        "ModernBERT loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // `false`: this backend does not report itself as a GPU.
    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
487
/// Load a ModernBERT model onto CUDA through the driver/arch path.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` via the
/// `hf_hub` sync API, parses the config, loads the weights with `CudaDriver`
/// and wraps everything in a `GenericBackend` built with `with_max_batch`.
///
/// # Errors
///
/// Returns `Error::Download` when the hub API or a file fetch fails,
/// `Error::Io` when the config cannot be read, `Error::Other` when it is not
/// valid JSON, and propagates config, driver and weight-loading errors.
#[cfg(feature = "cuda")]
pub fn load_modernbert_cuda(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cuda::{CudaDriver, ModernBertConfig};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    /// Wrap any hub failure as a download error carrying its message.
    fn download_err(cause: impl ToString) -> crate::Error {
        crate::Error::Download(cause.to_string())
    }

    let hub = Api::new().map_err(download_err)?;
    let files = hub.model(model_repo.to_string());

    let config_file = files.get("config.json").map_err(download_err)?;
    let weights_file = files.get("model.safetensors").map_err(download_err)?;

    let raw_config =
        std::fs::read_to_string(&config_file).map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ModernBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CudaDriver::new()?;
    let (arch, mmap) = driver.load_modern_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ModernBERT loaded on CUDA (driver/arch)"
    );

    // `true` marks the backend as GPU; the trailing `32` is the value handed
    // to `with_max_batch` (presumably the maximum batch size — confirm).
    let weights_holder = generic::MmapHolder::Owned(mmap);
    let backend =
        GenericBackend::with_max_batch(driver, arch, max_tokens, true, weights_holder, 32);
    Ok(Box::new(backend))
}
548
/// Report whether `model_repo`'s `config.json` declares
/// `"model_type": "modernbert"`.
///
/// Any failure along the way (hub API, fetch, read, JSON parse, missing or
/// non-string `model_type`) yields `false`.
#[cfg(any(
    feature = "cuda",
    feature = "metal",
    feature = "cpu",
    feature = "cpu-accelerate"
))]
fn is_modernbert_model(model_repo: &str) -> bool {
    /// `None` when the config cannot be obtained or parsed, otherwise whether
    /// its `model_type` equals `"modernbert"`.
    fn probe(model_repo: &str) -> Option<bool> {
        let hub = hf_hub::api::sync::Api::new().ok()?;
        let config_file = hub.model(model_repo.to_string()).get("config.json").ok()?;
        let raw = std::fs::read_to_string(&config_file).ok()?;
        let parsed = serde_json::from_str::<serde_json::Value>(&raw).ok()?;
        let model_type = parsed
            .get("model_type")
            .and_then(serde_json::Value::as_str);
        Some(model_type == Some("modernbert"))
    }

    probe(model_repo).unwrap_or(false)
}
577
/// Load a classic BERT model onto Metal through the driver/arch path.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` via the
/// `hf_hub` sync API, parses the config, loads the weights with `MetalDriver`
/// and wraps everything in a `GenericBackend`.
///
/// # Errors
///
/// Returns `Error::Download` when the hub API or a file fetch fails,
/// `Error::Io` when the config cannot be read, `Error::Other` when it is not
/// valid JSON, and propagates config, driver and weight-loading errors.
#[cfg(feature = "metal")]
pub fn load_classic_metal(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::metal::{ClassicBertConfig, MetalDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    /// Wrap any hub failure as a download error carrying its message.
    fn download_err(cause: impl ToString) -> crate::Error {
        crate::Error::Download(cause.to_string())
    }

    let hub = Api::new().map_err(download_err)?;
    let files = hub.model(model_repo.to_string());

    let config_file = files.get("config.json").map_err(download_err)?;
    let weights_file = files.get("model.safetensors").map_err(download_err)?;

    let raw_config =
        std::fs::read_to_string(&config_file).map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    let driver = MetalDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on Metal (driver/arch)"
    );

    // `true` marks the backend as GPU (the tests assert `is_gpu()` on it).
    let backend = GenericBackend::new(driver, arch, max_tokens, true, mmap);
    Ok(Box::new(backend))
}
637
/// Load a classic BERT model on the CPU through the driver/arch path.
///
/// Fetches `config.json` and `model.safetensors` for `model_repo` via the
/// `hf_hub` sync API, parses the config, loads the weights with `CpuDriver`
/// and wraps everything in a `GenericBackend` built with `new_shared`.
/// Within this file it is exercised only by the tests; `load_backend` uses
/// `cpu::CpuBackend` for classic models instead.
///
/// # Errors
///
/// Returns `Error::Download` when the hub API or a file fetch fails,
/// `Error::Io` when the config cannot be read, `Error::Other` when it is not
/// valid JSON, and propagates config, driver and weight-loading errors.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_classic_cpu(model_repo: &str) -> crate::Result<Box<dyn EmbedBackend>> {
    use driver::cpu::{ClassicBertConfig, CpuDriver};
    use generic::GenericBackend;
    use hf_hub::api::sync::Api;

    /// Wrap any hub failure as a download error carrying its message.
    fn download_err(cause: impl ToString) -> crate::Error {
        crate::Error::Download(cause.to_string())
    }

    let hub = Api::new().map_err(download_err)?;
    let files = hub.model(model_repo.to_string());

    let config_file = files.get("config.json").map_err(download_err)?;
    let weights_file = files.get("model.safetensors").map_err(download_err)?;

    let raw_config =
        std::fs::read_to_string(&config_file).map_err(|source| crate::Error::Io {
            path: config_file.display().to_string(),
            source,
        })?;
    let parsed: serde_json::Value = serde_json::from_str(&raw_config)
        .map_err(|e| crate::Error::Other(anyhow::anyhow!("config parse error: {e}")))?;
    let config = ClassicBertConfig::from_json(&parsed)?;
    let max_tokens = config.max_position_embeddings;

    let driver = CpuDriver::new()?;
    let (arch, mmap) = driver.load_classic_bert_weights(&weights_file, &config)?;

    tracing::info!(
        model_repo,
        hidden = config.hidden_size,
        layers = config.num_hidden_layers,
        heads = config.num_attention_heads,
        intermediate = config.intermediate_size,
        max_tokens,
        "ClassicBert loaded on CPU (driver/arch, zero-copy mmap)"
    );

    // `false`: this backend does not report itself as a GPU.
    let backend = GenericBackend::new_shared(driver, arch, max_tokens, false, mmap);
    Ok(Box::new(backend))
}
697
/// Load the CPU reranking backend for `model_repo`.
///
/// # Errors
///
/// Propagates any error from `cpu::CpuRerankBackend::load`.
#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub fn load_reranker_cpu(model_repo: &str) -> crate::Result<Box<dyn RerankBackend>> {
    Ok(Box::new(cpu::CpuRerankBackend::load(model_repo)?))
}
722
723#[cfg(not(any(feature = "cpu", feature = "cpu-accelerate")))]
724pub fn load_reranker_cpu(_model_repo: &str) -> crate::Result<Box<dyn RerankBackend>> {
725 Err(crate::Error::Other(anyhow::anyhow!(
726 "cross-encoder rerank requires building with --features cpu \
727 or --features cpu-accelerate"
728 )))
729}
730
731#[cfg(test)]
732mod tests {
733 use super::*;
734
735 #[test]
737 fn trait_is_object_safe() {
738 fn _assert_object_safe(_: &dyn EmbedBackend) {}
740 }
741
742 #[test]
744 fn trait_object_is_send() {
745 fn assert_send<T: Send>() {}
746 assert_send::<Box<dyn EmbedBackend>>();
747 }
748
749 #[test]
751 fn trait_object_is_sync() {
752 fn assert_sync<T: Sync>() {}
753 assert_sync::<Box<dyn EmbedBackend>>();
754 }
755
756 #[test]
758 fn arc_trait_object_is_send() {
759 fn assert_send<T: Send>() {}
760 assert_send::<std::sync::Arc<dyn EmbedBackend>>();
761 }
762
763 #[test]
764 fn encoding_construction() {
765 let enc = Encoding {
766 input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
767 attention_mask: vec![1, 1, 1, 1, 1, 1],
768 token_type_ids: vec![0, 0, 0, 0, 0, 0],
769 };
770 assert_eq!(enc.input_ids.len(), 6);
771 assert_eq!(enc.attention_mask.len(), 6);
772 assert_eq!(enc.token_type_ids.len(), 6);
773 }
774
775 #[test]
776 fn encoding_clone() {
777 let enc = Encoding {
778 input_ids: vec![101, 102],
779 attention_mask: vec![1, 1],
780 token_type_ids: vec![0, 0],
781 };
782 let cloned = enc.clone();
783 assert_eq!(enc.input_ids, cloned.input_ids);
784 }
785
786 #[test]
787 fn backend_kind_default_is_cpu() {
788 assert_eq!(BackendKind::default(), BackendKind::Cpu);
789 }
790
791 #[test]
792 fn backend_kind_display() {
793 assert_eq!(BackendKind::Cuda.to_string(), "cuda");
794 assert_eq!(BackendKind::Mlx.to_string(), "mlx");
795 assert_eq!(BackendKind::Cpu.to_string(), "cpu");
796 }
797
798 #[cfg(not(feature = "mlx"))]
799 #[test]
800 fn load_backend_mlx_not_compiled() {
801 let result = load_backend(BackendKind::Mlx, "test/model", DeviceHint::Cpu);
802 assert!(result.is_err());
803 }
804
805 #[cfg(feature = "cpu")]
806 #[test]
807 fn detect_backends_returns_at_least_one() {
808 let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
809 assert!(!backends.is_empty());
810 }
811
812 #[cfg(all(feature = "cpu", not(feature = "mlx")))]
813 #[test]
814 fn detect_backends_returns_at_least_one_backend() {
815 let backends = detect_backends("BAAI/bge-small-en-v1.5").unwrap();
816 assert!(!backends.is_empty(), "should detect at least one backend");
817 }
818
    /// End-to-end ModernBERT diagnostic on Metal.
    ///
    /// Loads `nomic-ai/modernbert-embed-base`, then replays the first stages of
    /// the forward pass by hand (embedding lookup, layer norm, QKV GEMM, QKV
    /// split, attention scores) and prints how many values are non-zero at each
    /// stage. It finishes with a full forward pass, a truncation printout and a
    /// throughput measurement. Apart from the `unwrap`s, only the `is_gpu`
    /// check and the final batch-size assertion can fail; everything else goes
    /// to stderr for manual inspection.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~570MB)"]
    #[expect(clippy::too_many_lines, reason = "end-to-end backend diagnostic test")]
    fn modernbert_loads_and_embeds() {
        use crate::backend::arch::ModelArch;
        use crate::backend::driver::Driver;

        // The boxed backend is only used for the `is_gpu` sanity check; the
        // rest of the test drives a separate `MetalDriver` directly.
        let backend = load_modernbert_metal("nomic-ai/modernbert-embed-base").expect("load failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        // NOTE(review): `8` is presumably the padded sequence length (later
        // stages size their buffers with `total = 8`) — confirm against
        // `Driver::prepare_batch`.
        let inputs = driver.prepare_batch(std::slice::from_ref(&enc), 8).unwrap();

        let ids_host = driver.to_host(&inputs.input_ids, 1, 8).unwrap();
        eprintln!("input_ids: {:?}", &ids_host[0][..5]);

        // Load config and weights again so the raw `arch` is available here.
        let api = hf_hub::api::sync::Api::new().unwrap();
        let repo = api.model("nomic-ai/modernbert-embed-base".to_string());
        let weights_path = repo.get("model.safetensors").unwrap();
        let config_path = repo.get("config.json").unwrap();
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ModernBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_modern_bert_weights(&weights_path, &config)
            .unwrap();

        // Stage 0: token embedding lookup.
        let hidden = driver
            .embedding_lookup(&inputs.input_ids, &arch.weights.tok_embeddings, 8, 768)
            .unwrap();
        let h = driver.to_host(&hidden, 1, 8 * 768).unwrap();
        let nz = h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "embedding: {nz}/{} nonzero, first 5: {:?}",
            h[0].len(),
            &h[0][..5]
        );

        // NOTE(review): these dimensions are hard-coded rather than read from
        // `config`; presumably they match this model (12 * 64 = 768) — confirm.
        let total = 8;
        let hd = 768;
        let nh = 12;
        let head_dim = 64;

        // Stage 1: layer norm over the embeddings.
        let emb_clone = driver.clone_tensor(&hidden, total * hd).unwrap();
        let mut ln_out = driver.alloc_zeros(total * hd).unwrap();
        driver
            .layer_norm(
                &mut ln_out,
                &emb_clone,
                &arch.weights.emb_norm_weight,
                &arch.weights.zero_bias,
                total,
                hd,
                1e-5,
            )
            .unwrap();
        let ln_h = driver.to_host(&ln_out, 1, total * hd).unwrap();
        let nz = ln_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 1 - emb+LN: {nz}/{} nonzero", total * hd);

        // Stage 2: QKV projection GEMM for layer 0.
        let layer0 = &arch.weights.layers[0];
        let mut qkv = driver.alloc_zeros(total * 3 * hd).unwrap();
        driver
            .gemm(
                &ln_out,
                &layer0.qkv_weight,
                &mut qkv,
                total,
                3 * hd,
                hd,
                true,
            )
            .unwrap();
        let qkv_h = driver.to_host(&qkv, 1, total * 3 * hd).unwrap();
        let nz = qkv_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 2 - QKV GEMM: {nz}/{} nonzero", total * 3 * hd);

        // Stage 3: split the fused QKV buffer into separate Q, K and V.
        let mut q = driver.alloc_zeros(total * hd).unwrap();
        let mut k = driver.alloc_zeros(total * hd).unwrap();
        let mut v = driver.alloc_zeros(total * hd).unwrap();
        driver
            .qkv_split(&mut q, &mut k, &mut v, &qkv, 1, 8, hd, nh, head_dim)
            .unwrap();
        let q_h = driver.to_host(&q, 1, total * hd).unwrap();
        let nz = q_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 3 - Q after split: {nz}/{} nonzero", total * hd);

        // Stage 4: batched GEMM producing the attention scores, one 8x8 block
        // per head.
        let mut scores = driver.alloc_zeros(nh * 8 * 8).unwrap();
        driver
            .gemm_batched(
                &q,
                &k,
                &mut scores,
                8,
                8,
                head_dim,
                true,
                8 * head_dim,
                8 * head_dim,
                8 * 8,
                nh,
            )
            .unwrap();
        let s_h = driver.to_host(&scores, 1, nh * 8 * 8).unwrap();
        let nz = s_h[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!("STAGE 4 - scores: {nz}/{} nonzero", nh * 8 * 8);

        // Same token ids as `enc`, fed through the complete forward pass.
        let enc2 = Encoding {
            input_ids: vec![1, 100, 200, 300, 2],
            attention_mask: vec![1; 5],
            token_type_ids: vec![0; 5],
        };

        let quick = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
        let l2: f32 = quick[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let nz = quick[0].iter().filter(|&&v| v.abs() > 1e-10).count();
        eprintln!(
            "BATCHED forward: L2={l2:.4}, nz={nz}/768, first 3: {:?}",
            &quick[0][..3]
        );

        eprintln!("\n=== ModernBERT MRL Truncation ===");
        let full = arch.forward(&driver, std::slice::from_ref(&enc2)).unwrap();
        let full_emb = &full[0];
        // NOTE(review): `t` and the slice it is compared against are the same
        // `full_emb[..dims]` prefix, so the printed cosine is 1.0 by
        // construction (unless the norm hits the 1e-12 floor). If the intent
        // was to compare against the full embedding, this needs rework.
        for dims in [64, 128, 256, 384, 512, 768] {
            let t: Vec<f32> = full_emb[..dims].to_vec();
            let t_norm: f32 = t.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            let f_norm: f32 = full_emb[..dims]
                .iter()
                .map(|x| x * x)
                .sum::<f32>()
                .sqrt()
                .max(1e-12);
            let cos: f32 = t
                .iter()
                .zip(&full_emb[..dims])
                .map(|(a, b)| a * b)
                .sum::<f32>()
                / (t_norm * f_norm);
            eprintln!("  dims={dims:>3}: cosine={cos:.6}");
        }

        eprintln!("\n=== ModernBERT Throughput ===");
        // 32 synthetic sequences of lengths 16, 20, ..., 140, each framed by
        // the ids 1 and 2.
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            let mut ids = vec![1_i64];
            for j in 1..len - 1 {
                ids.push(100 + i64::from(j));
            }
            ids.push(2);
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Untimed run on the first four inputs before measuring; its result is
        // deliberately ignored.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            "  batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(result.len(), 32);

        // Single-sequence latency for comparison.
        let single = vec![encs[0].clone()];
        let t1 = std::time::Instant::now();
        let _ = arch.forward(&driver, &single).unwrap();
        let single_ms = t1.elapsed().as_secs_f64() * 1000.0;
        eprintln!("  batch=1, time={single_ms:.1}ms");
    }
1025
    /// Classic BERT on Metal through the driver/arch path.
    ///
    /// Checks the embedding shape (one vector of 384 values) and that it is
    /// L2-normalized. With the `cpu` feature it also compares against the CPU
    /// backend, then measures batched throughput on 32 synthetic sequences.
    #[cfg(feature = "metal")]
    #[test]
    #[ignore = "requires model download (~33MB)"]
    fn classic_bert_driver_arch() {
        use crate::backend::arch::ModelArch;

        let model_repo = "BAAI/bge-small-en-v1.5";

        let backend = load_classic_metal(model_repo).expect("load_classic_metal failed");
        assert!(backend.is_gpu(), "Metal backend should be GPU");

        let enc = Encoding {
            input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
            attention_mask: vec![1, 1, 1, 1, 1, 1],
            token_type_ids: vec![0, 0, 0, 0, 0, 0],
        };

        let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);

        let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        eprintln!(
            "ClassicBert driver/arch: L2={l2:.4}, first 3: {:?}",
            &result[0][..3]
        );
        assert!(
            (l2 - 1.0).abs() < 0.01,
            "embedding should be L2-normalized, got L2={l2}"
        );

        // Cross-check against the CPU backend when it is compiled in.
        #[cfg(feature = "cpu")]
        {
            let cpu = load_backend(BackendKind::Cpu, model_repo, DeviceHint::Cpu)
                .expect("CPU load failed");
            let cpu_result = cpu.embed_batch(std::slice::from_ref(&enc)).unwrap();
            eprintln!("CPU first 5: {:?}", &cpu_result[0][..5]);
            eprintln!("NEW first 5: {:?}", &result[0][..5]);
            // Plain dot product; it equals the cosine only if both vectors are
            // unit length (asserted above for `result`, assumed for the CPU
            // output).
            let cosine: f32 = result[0]
                .iter()
                .zip(&cpu_result[0])
                .map(|(a, b)| a * b)
                .sum();
            eprintln!("cosine(driver/arch, CPU) = {cosine:.6}");
            assert!(
                cosine > 0.95,
                "cosine similarity vs CPU should be >0.95, got {cosine}"
            );
        }

        eprintln!("\n=== ClassicBert Driver/Arch Throughput ===");
        // Build a raw driver + arch pair so `forward` can be timed directly.
        let driver = crate::backend::driver::metal::MetalDriver::new().unwrap();
        let config_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("config.json").unwrap()
        };
        let weights_path = {
            let api = hf_hub::api::sync::Api::new().unwrap();
            let repo = api.model(model_repo.to_string());
            repo.get("model.safetensors").unwrap()
        };
        let config_str = std::fs::read_to_string(&config_path).unwrap();
        let config_json: serde_json::Value = serde_json::from_str(&config_str).unwrap();
        let config =
            crate::backend::driver::metal::ClassicBertConfig::from_json(&config_json).unwrap();
        let (arch, _mmap) = driver
            .load_classic_bert_weights(&weights_path, &config)
            .unwrap();

        // 32 synthetic sequences of lengths 16, 20, ..., 140, each framed by
        // the ids 101 and 102.
        let mut encs = Vec::new();
        for i in 0..32 {
            let len = 16 + (i * 4);
            let mut ids = vec![101_i64];
            for j in 1..len - 1 {
                ids.push(100 + i64::from(j));
            }
            ids.push(102);
            encs.push(Encoding {
                input_ids: ids.clone(),
                attention_mask: vec![1; ids.len()],
                token_type_ids: vec![0; ids.len()],
            });
        }

        // Untimed run on the first four inputs before measuring; its result is
        // deliberately ignored.
        let _ = arch.forward(&driver, &encs[..4]);

        let t0 = std::time::Instant::now();
        let bench_result = arch.forward(&driver, &encs).unwrap();
        let elapsed = t0.elapsed();
        let throughput = encs.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            "  batch={}, time={:.1}ms, throughput={:.1}/s",
            encs.len(),
            elapsed.as_secs_f64() * 1000.0,
            throughput
        );
        assert_eq!(bench_result.len(), 32);
    }
1137
1138 #[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
1143 #[test]
1144 #[ignore = "requires model download (~33MB)"]
1145 fn classic_bert_cpu_driver_arch() {
1146 let model_repo = "BAAI/bge-small-en-v1.5";
1147
1148 let backend = load_classic_cpu(model_repo).expect("load_classic_cpu failed");
1150 assert!(!backend.is_gpu(), "CPU backend should not be GPU");
1151
1152 let enc = Encoding {
1153 input_ids: vec![101, 2023, 2003, 1037, 3231, 102],
1154 attention_mask: vec![1, 1, 1, 1, 1, 1],
1155 token_type_ids: vec![0, 0, 0, 0, 0, 0],
1156 };
1157
1158 let result = backend.embed_batch(std::slice::from_ref(&enc)).unwrap();
1160 assert_eq!(result.len(), 1);
1161 assert_eq!(result[0].len(), 384);
1162
1163 let l2: f32 = result[0].iter().map(|x| x * x).sum::<f32>().sqrt();
1164 eprintln!(
1165 "ClassicBert CPU driver/arch: L2={l2:.4}, first 5: {:?}",
1166 &result[0][..5]
1167 );
1168 assert!(
1169 (l2 - 1.0).abs() < 0.01,
1170 "embedding should be L2-normalized, got L2={l2}"
1171 );
1172
1173 #[cfg(feature = "cpu")]
1175 {
1176 let cpu_mono = cpu::CpuBackend::load(model_repo, &DeviceHint::Cpu)
1177 .expect("monolithic CPU load failed");
1178 let cpu_result = cpu_mono.embed_batch(&[enc]).unwrap();
1179 eprintln!("Mono first 5: {:?}", &cpu_result[0][..5]);
1180 eprintln!("New first 5: {:?}", &result[0][..5]);
1181 let cosine: f32 = result[0]
1182 .iter()
1183 .zip(&cpu_result[0])
1184 .map(|(a, b)| a * b)
1185 .sum();
1186 eprintln!("cosine(driver/arch, monolithic) = {cosine:.6}");
1187 assert!(
1188 cosine > 0.999,
1189 "cosine similarity vs monolithic CPU should be >0.999, got {cosine}"
1190 );
1191 }
1192 }
1193}