1use super::super::Encoding;
14use super::super::driver::{BatchInputs, Driver};
15use super::ModelArch;
16
/// Weight tensors for a single ModernBERT encoder layer.
///
/// `T` is the tensor handle type; the encoder in this file instantiates it with the driver's
/// `Tensor` associated type.
pub struct ModernBertLayerWeights<T> {
    /// Fused Q/K/V projection weight. The forward pass uses it in a GEMM that maps `hidden`
    /// input columns to `3 * hidden` output columns.
    pub qkv_weight: T,
    /// Attention output projection weight (`hidden` columns in, `hidden` columns out).
    pub output_weight: T,
    /// Scale for the layer norm applied before the QKV projection. When `None`, the layer's
    /// input hidden states are projected without a pre-attention norm.
    pub attn_norm_weight: Option<T>,
    /// Feed-forward input projection weight. Its GEMM produces `2 * intermediate` columns,
    /// which are then split into a value half and a gate half.
    pub mlp_wi_weight: T,
    /// Feed-forward output projection weight (`intermediate` columns in, `hidden` columns out).
    pub mlp_wo_weight: T,
    /// Scale for the layer norm applied before the feed-forward sub-layer.
    pub mlp_norm_weight: T,
    /// `true` selects the global RoPE table and the un-windowed softmax for this layer;
    /// `false` selects the local RoPE table and the windowed softmax.
    pub is_global: bool,
}
43
/// Model-level ModernBERT weights plus the dimensions the forward pass needs.
pub struct ModernBertWeights<T> {
    /// Token embedding table used by the embedding lookup.
    pub tok_embeddings: T,
    /// Scale for the layer norm applied right after the embedding lookup.
    pub emb_norm_weight: T,
    /// Scale for the layer norm applied after the last encoder layer.
    pub final_norm_weight: T,
    /// Tensor passed as the bias argument of every layer-norm call in this file.
    // NOTE(review): presumably an all-zero vector of length `hidden_dim` — confirm at the loader.
    pub zero_bias: T,
    /// Per-layer weights, iterated in order by the forward pass.
    pub layers: Vec<ModernBertLayerWeights<T>>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Width of one attention head; also determines the softmax scale `1 / sqrt(head_dim)`.
    pub head_dim: usize,
    /// Width of the hidden state.
    pub hidden_dim: usize,
    /// Width of the feed-forward intermediate activation (after the gate/value split).
    pub intermediate_dim: usize,
    /// Epsilon passed to every layer-norm call.
    pub layer_norm_eps: f32,
    /// Window size passed to the windowed softmax on non-global layers.
    pub local_window: usize,
}
77
/// Cosine/sine tables handed to the driver's RoPE kernels.
pub struct RopeCache<T> {
    /// Cosine table, passed as the `cos` argument of the RoPE call.
    pub cos: T,
    /// Sine table, passed as the `sin` argument of the RoPE call.
    pub sin: T,
}
88
/// ModernBERT architecture: weights plus the two RoPE tables selected per layer.
pub struct ModernBertArch<T> {
    /// Model weights and dimensions.
    pub weights: ModernBertWeights<T>,
    /// RoPE tables used by layers whose `is_global` flag is set.
    pub global_rope: RopeCache<T>,
    /// RoPE tables used by layers whose `is_global` flag is clear.
    pub local_rope: RopeCache<T>,
}
102
/// Per-batch sizes and scalar constants shared by every sub-layer helper in this file.
///
/// Built once per forward pass from the prepared batch and the model weights.
struct EncoderGeometry {
    /// Number of sequences in the batch (`encodings.len()`).
    batch: usize,
    /// `max_seq` reported by the prepared batch.
    max_seq: usize,
    /// `total_tokens` reported by the prepared (unpadded) batch.
    total_tokens: usize,
    /// `batch * max_seq`: token count of the padded layout used inside attention.
    padded_tokens: usize,
    /// Per-sequence lengths, used by the pad/unpad kernels.
    seq_lengths: Vec<usize>,
    /// Hidden state width.
    hidden: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of one attention head.
    head_dim: usize,
    /// Feed-forward intermediate width.
    intermediate: usize,
    /// Window size for the windowed softmax on local layers.
    local_window: usize,
    /// Attention score scale, `1 / sqrt(head_dim)`.
    scale: f32,
    /// Layer-norm epsilon.
    eps: f32,
}
125
126fn attn_prenorm_qkv<D: Driver>(
134 driver: &D,
135 hidden_states: &D::Tensor,
136 layer: &ModernBertLayerWeights<D::Tensor>,
137 g: &EncoderGeometry,
138 zero_bias: &D::Tensor,
139 rope: &RopeCache<D::Tensor>,
140) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
141 let normed = if let Some(ref norm_w) = layer.attn_norm_weight {
143 let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
144 driver.layer_norm(
145 &mut n,
146 hidden_states,
147 norm_w,
148 zero_bias,
149 g.total_tokens,
150 g.hidden,
151 g.eps,
152 )?;
153 n
154 } else {
155 driver.clone_tensor(hidden_states, g.total_tokens * g.hidden)?
157 };
158
159 let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
162 driver.gemm(
163 &normed,
164 &layer.qkv_weight,
165 &mut qkv,
166 g.total_tokens,
167 3 * g.hidden,
168 g.hidden,
169 true,
170 )?;
171
172 let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
175 driver.pad_to_batch(
176 &qkv,
177 &mut qkv_padded,
178 &g.seq_lengths,
179 g.max_seq,
180 3 * g.hidden,
181 )?;
182
183 let padded = g.padded_tokens;
185 let mut q = driver.alloc_zeros(padded * g.hidden)?;
186 let mut k = driver.alloc_zeros(padded * g.hidden)?;
187 let mut v = driver.alloc_zeros(padded * g.hidden)?;
188 driver.qkv_split(
189 &mut q,
190 &mut k,
191 &mut v,
192 &qkv_padded,
193 g.batch,
194 g.max_seq,
195 g.hidden,
196 g.num_heads,
197 g.head_dim,
198 )?;
199
200 let num_rows = g.batch * g.num_heads * g.max_seq;
202 driver.apply_rope(
203 &mut q,
204 &rope.cos,
205 &rope.sin,
206 num_rows,
207 g.max_seq,
208 g.head_dim,
209 g.num_heads,
210 )?;
211 driver.apply_rope(
212 &mut k,
213 &rope.cos,
214 &rope.sin,
215 num_rows,
216 g.max_seq,
217 g.head_dim,
218 g.num_heads,
219 )?;
220
221 Ok((q, k, v))
222}
223
224#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
230fn attn_scores_residual<D: Driver>(
231 driver: &D,
232 q: &D::Tensor,
233 k: &D::Tensor,
234 v: &D::Tensor,
235 hidden_states: &D::Tensor,
236 layer: &ModernBertLayerWeights<D::Tensor>,
237 inputs: &BatchInputs<D::Tensor>,
238 g: &EncoderGeometry,
239) -> crate::Result<D::Tensor> {
240 let batch_heads = g.batch * g.num_heads;
241 let stride_qk = g.max_seq * g.head_dim;
242
243 let mut scores = driver.alloc_zeros(batch_heads * g.max_seq * g.max_seq)?;
251 driver.gemm_batched(
252 q,
253 k,
254 &mut scores,
255 g.max_seq,
256 g.max_seq,
257 g.head_dim,
258 true,
259 stride_qk,
260 stride_qk,
261 g.max_seq * g.max_seq,
262 batch_heads,
263 )?;
264
265 if layer.is_global {
266 driver.fused_scale_mask_softmax(
267 &mut scores,
268 &inputs.float_mask,
269 g.batch,
270 g.num_heads,
271 g.max_seq,
272 g.scale,
273 )?;
274 } else {
275 driver.fused_scale_mask_softmax_windowed(
276 &mut scores,
277 &inputs.float_mask,
278 g.batch,
279 g.num_heads,
280 g.max_seq,
281 g.scale,
282 g.local_window,
283 )?;
284 }
285
286 let mut attn_out = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
287 driver.gemm_batched(
288 &scores,
289 v,
290 &mut attn_out,
291 g.max_seq,
292 g.head_dim,
293 g.max_seq,
294 false,
295 g.max_seq * g.max_seq,
296 stride_qk,
297 stride_qk,
298 batch_heads,
299 )?;
300
301 let mut context = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
303 driver.attn_reshape(
304 &mut context,
305 &attn_out,
306 g.batch,
307 g.max_seq,
308 g.num_heads,
309 g.head_dim,
310 )?;
311
312 let mut context_unpacked = driver.alloc_zeros(g.total_tokens * g.hidden)?;
316 driver.unpad_from_batch(
317 &context,
318 &mut context_unpacked,
319 &g.seq_lengths,
320 g.max_seq,
321 g.hidden,
322 )?;
323
324 let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
326 driver.gemm(
327 &context_unpacked,
328 &layer.output_weight,
329 &mut projected,
330 g.total_tokens,
331 g.hidden,
332 g.hidden,
333 true,
334 )?;
335
336 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
338 driver.residual_add(
339 &mut output,
340 &projected,
341 hidden_states,
342 g.total_tokens * g.hidden,
343 )?;
344 Ok(output)
345}
346
347fn ffn_sublayer<D: Driver>(
356 driver: &D,
357 attn_output: &D::Tensor,
358 layer: &ModernBertLayerWeights<D::Tensor>,
359 g: &EncoderGeometry,
360 zero_bias: &D::Tensor,
361) -> crate::Result<D::Tensor> {
362 let mut mlp_normed = driver.alloc_zeros(g.total_tokens * g.hidden)?;
364 driver.layer_norm(
365 &mut mlp_normed,
366 attn_output,
367 &layer.mlp_norm_weight,
368 zero_bias,
369 g.total_tokens,
370 g.hidden,
371 g.eps,
372 )?;
373
374 let double_inter = 2 * g.intermediate;
376 let mut wi_out = driver.alloc_zeros(g.total_tokens * double_inter)?;
377 driver.gemm(
378 &mlp_normed,
379 &layer.mlp_wi_weight,
380 &mut wi_out,
381 g.total_tokens,
382 double_inter,
383 g.hidden,
384 true,
385 )?;
386
387 let n_elements = g.total_tokens * g.intermediate;
389 let mut value = driver.alloc_zeros(n_elements)?;
390 let mut gate = driver.alloc_zeros(n_elements)?;
391 driver.split_gate_value(
392 &mut value,
393 &mut gate,
394 &wi_out,
395 g.total_tokens,
396 g.intermediate,
397 )?;
398
399 let mut activated = driver.alloc_zeros(n_elements)?;
401 driver.geglu(&value, &gate, &mut activated, n_elements)?;
402
403 let mut mlp_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
405 driver.gemm(
406 &activated,
407 &layer.mlp_wo_weight,
408 &mut mlp_out,
409 g.total_tokens,
410 g.hidden,
411 g.intermediate,
412 true,
413 )?;
414
415 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
417 driver.residual_add(
418 &mut output,
419 &mlp_out,
420 attn_output,
421 g.total_tokens * g.hidden,
422 )?;
423 Ok(output)
424}
425
426fn debug_f16_tensor<D: Driver>(
431 driver: &D,
432 label: &str,
433 tensor: &D::Tensor,
434 rows: usize,
435 cols: usize,
436) -> crate::Result<()> {
437 let mut probe = driver.alloc_zeros(rows * cols)?;
438 driver.f16_to_f32(&mut probe, tensor, rows * cols)?;
439 driver.debug_tensor(label, &probe, rows, cols)
440}
441
442#[expect(
448 clippy::too_many_lines,
449 reason = "attention diagnostics intentionally keep stage probes adjacent to the operations"
450)]
451fn attn_prenorm_qkv_f16<D: Driver>(
452 driver: &D,
453 hidden_states: &D::Tensor,
454 layer: &ModernBertLayerWeights<D::Tensor>,
455 g: &EncoderGeometry,
456 zero_bias: &D::Tensor,
457 rope: &RopeCache<D::Tensor>,
458 layer_index: usize,
459 debug_tensors: bool,
460) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
461 let normed: Option<D::Tensor>;
464 let normed_ref = if let Some(ref norm_w) = layer.attn_norm_weight {
465 let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
466 driver.layer_norm_f16(
467 &mut n,
468 hidden_states,
469 norm_w,
470 zero_bias,
471 g.total_tokens,
472 g.hidden,
473 g.eps,
474 )?;
475 normed = Some(n);
476 normed.as_ref().unwrap()
477 } else {
478 hidden_states
480 };
481
482 let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
484 driver.gemm_f16(
485 normed_ref,
486 &layer.qkv_weight,
487 &mut qkv,
488 g.total_tokens,
489 3 * g.hidden,
490 g.hidden,
491 true,
492 )?;
493 if debug_tensors && layer_index == 0 {
494 debug_f16_tensor(
495 driver,
496 "modernbert.layer_0.qkv_f16_as_f32",
497 &qkv,
498 g.total_tokens,
499 3 * g.hidden,
500 )?;
501 }
502
503 let padded = g.padded_tokens;
506 let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
507 let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
508 let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
509 driver.fused_pad_qkv_split_f16(
510 &mut q,
511 &mut k,
512 &mut v,
513 &qkv,
514 &g.seq_lengths,
515 g.max_seq,
516 g.batch,
517 g.hidden,
518 g.num_heads,
519 g.head_dim,
520 )?;
521 if debug_tensors && layer_index == 0 {
522 let rows = g.batch * g.num_heads * g.max_seq;
523 debug_f16_tensor(
524 driver,
525 "modernbert.layer_0.q_after_split_f16_as_f32",
526 &q,
527 rows,
528 g.head_dim,
529 )?;
530 debug_f16_tensor(
531 driver,
532 "modernbert.layer_0.k_after_split_f16_as_f32",
533 &k,
534 rows,
535 g.head_dim,
536 )?;
537 debug_f16_tensor(
538 driver,
539 "modernbert.layer_0.v_after_split_f16_as_f32",
540 &v,
541 rows,
542 g.head_dim,
543 )?;
544 }
545
546 let num_rows = g.batch * g.num_heads * g.max_seq;
548 driver.rope_encode_f16(
549 &mut q,
550 &rope.cos,
551 &rope.sin,
552 num_rows,
553 g.max_seq,
554 g.head_dim,
555 g.num_heads,
556 )?;
557 driver.rope_encode_f16(
558 &mut k,
559 &rope.cos,
560 &rope.sin,
561 num_rows,
562 g.max_seq,
563 g.head_dim,
564 g.num_heads,
565 )?;
566 if debug_tensors && layer_index == 0 {
567 let rows = g.batch * g.num_heads * g.max_seq;
568 debug_f16_tensor(
569 driver,
570 "modernbert.layer_0.q_after_rope_f16_as_f32",
571 &q,
572 rows,
573 g.head_dim,
574 )?;
575 debug_f16_tensor(
576 driver,
577 "modernbert.layer_0.k_after_rope_f16_as_f32",
578 &k,
579 rows,
580 g.head_dim,
581 )?;
582 }
583
584 Ok((q, k, v))
585}
586
587#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
596#[expect(
597 clippy::too_many_lines,
598 reason = "attention diagnostics intentionally keep stage probes adjacent to the operations"
599)]
600fn attn_scores_residual_f16<D: Driver>(
601 driver: &D,
602 q: &D::Tensor,
603 k: &D::Tensor,
604 v: &D::Tensor,
605 hidden_states: &D::Tensor,
606 layer: &ModernBertLayerWeights<D::Tensor>,
607 inputs: &BatchInputs<D::Tensor>,
608 g: &EncoderGeometry,
609 layer_index: usize,
610 debug_tensors: bool,
611) -> crate::Result<D::Tensor> {
612 let batch_heads = g.batch * g.num_heads;
613 let stride_qk = g.max_seq * g.head_dim;
614
615 let mut scores = driver.alloc_zeros_f16(batch_heads * g.max_seq * g.max_seq)?;
617 driver.gemm_batched_f16(
618 q,
619 k,
620 &mut scores,
621 g.max_seq,
622 g.max_seq,
623 g.head_dim,
624 true,
625 stride_qk,
626 stride_qk,
627 g.max_seq * g.max_seq,
628 batch_heads,
629 )?;
630 if debug_tensors && layer_index == 0 {
631 debug_f16_tensor(
632 driver,
633 "modernbert.layer_0.attn_scores_before_softmax_f16_as_f32",
634 &scores,
635 batch_heads * g.max_seq,
636 g.max_seq,
637 )?;
638 }
639
640 if layer.is_global {
642 driver.fused_scale_mask_softmax_f16(
643 &mut scores,
644 &inputs.float_mask,
645 g.batch,
646 g.num_heads,
647 g.max_seq,
648 g.scale,
649 )?;
650 } else {
651 driver.fused_scale_mask_softmax_windowed_f16(
652 &mut scores,
653 &inputs.float_mask,
654 g.batch,
655 g.num_heads,
656 g.max_seq,
657 g.scale,
658 g.local_window,
659 )?;
660 }
661 if debug_tensors && layer_index == 0 {
662 debug_f16_tensor(
663 driver,
664 "modernbert.layer_0.attn_scores_after_softmax_f16_as_f32",
665 &scores,
666 batch_heads * g.max_seq,
667 g.max_seq,
668 )?;
669 }
670
671 let mut attn_out = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
673 driver.gemm_batched_f16(
674 &scores,
675 v,
676 &mut attn_out,
677 g.max_seq,
678 g.head_dim,
679 g.max_seq,
680 false,
681 g.max_seq * g.max_seq,
682 stride_qk,
683 stride_qk,
684 batch_heads,
685 )?;
686 if debug_tensors && layer_index == 0 {
687 debug_f16_tensor(
688 driver,
689 "modernbert.layer_0.attn_heads_f16_as_f32",
690 &attn_out,
691 batch_heads * g.max_seq,
692 g.head_dim,
693 )?;
694 }
695
696 let mut context_unpacked = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
699 driver.fused_reshape_unpad_f16(
700 &mut context_unpacked,
701 &attn_out,
702 &g.seq_lengths,
703 g.max_seq,
704 g.batch,
705 g.num_heads,
706 g.head_dim,
707 )?;
708 if debug_tensors && layer_index == 0 {
709 debug_f16_tensor(
710 driver,
711 "modernbert.layer_0.context_unpacked_f16_as_f32",
712 &context_unpacked,
713 g.total_tokens,
714 g.hidden,
715 )?;
716 }
717
718 let mut projected = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
720 driver.gemm_f16(
721 &context_unpacked,
722 &layer.output_weight,
723 &mut projected,
724 g.total_tokens,
725 g.hidden,
726 g.hidden,
727 true,
728 )?;
729 if debug_tensors && layer_index == 0 {
730 debug_f16_tensor(
731 driver,
732 "modernbert.layer_0.attn_projected_f16_as_f32",
733 &projected,
734 g.total_tokens,
735 g.hidden,
736 )?;
737 }
738
739 let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
741 driver.residual_add_f16(
742 &mut output,
743 &projected,
744 hidden_states,
745 g.total_tokens * g.hidden,
746 )?;
747 Ok(output)
748}
749
750fn ffn_sublayer_f16<D: Driver>(
758 driver: &D,
759 attn_output: &D::Tensor,
760 layer: &ModernBertLayerWeights<D::Tensor>,
761 g: &EncoderGeometry,
762 zero_bias: &D::Tensor,
763 layer_index: usize,
764 debug_tensors: bool,
765) -> crate::Result<D::Tensor> {
766 let mut mlp_normed = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
768 driver.layer_norm_f16(
769 &mut mlp_normed,
770 attn_output,
771 &layer.mlp_norm_weight,
772 zero_bias,
773 g.total_tokens,
774 g.hidden,
775 g.eps,
776 )?;
777 if debug_tensors && layer_index == 0 {
778 let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
779 driver.f16_to_f32(&mut probe, &mlp_normed, g.total_tokens * g.hidden)?;
780 driver.debug_tensor(
781 "modernbert.layer_0.ffn_mlp_normed_f16_as_f32",
782 &probe,
783 g.total_tokens,
784 g.hidden,
785 )?;
786 }
787
788 let double_inter = 2 * g.intermediate;
790 let mut wi_out = driver.alloc_zeros_f16(g.total_tokens * double_inter)?;
791 driver.gemm_f16(
792 &mlp_normed,
793 &layer.mlp_wi_weight,
794 &mut wi_out,
795 g.total_tokens,
796 double_inter,
797 g.hidden,
798 true,
799 )?;
800 if debug_tensors && layer_index == 0 {
801 let mut probe = driver.alloc_zeros(g.total_tokens * double_inter)?;
802 driver.f16_to_f32(&mut probe, &wi_out, g.total_tokens * double_inter)?;
803 driver.debug_tensor(
804 "modernbert.layer_0.ffn_wi_out_f16_as_f32",
805 &probe,
806 g.total_tokens,
807 double_inter,
808 )?;
809 }
810
811 let n_elements = g.total_tokens * g.intermediate;
815 let mut activated = driver.alloc_zeros_f16(n_elements)?;
816 driver.fused_split_geglu_f16(&mut activated, &wi_out, g.total_tokens, g.intermediate)?;
817 if debug_tensors && layer_index == 0 {
818 let mut probe = driver.alloc_zeros(n_elements)?;
819 driver.f16_to_f32(&mut probe, &activated, n_elements)?;
820 driver.debug_tensor(
821 "modernbert.layer_0.ffn_activated_f16_as_f32",
822 &probe,
823 g.total_tokens,
824 g.intermediate,
825 )?;
826 }
827
828 let mut mlp_out = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
830 driver.gemm_f16(
831 &activated,
832 &layer.mlp_wo_weight,
833 &mut mlp_out,
834 g.total_tokens,
835 g.hidden,
836 g.intermediate,
837 true,
838 )?;
839 if debug_tensors && layer_index == 0 {
840 let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
841 driver.f16_to_f32(&mut probe, &mlp_out, g.total_tokens * g.hidden)?;
842 driver.debug_tensor(
843 "modernbert.layer_0.ffn_mlp_out_f16_as_f32",
844 &probe,
845 g.total_tokens,
846 g.hidden,
847 )?;
848 }
849
850 let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
852 driver.residual_add_f16(
853 &mut output,
854 &mlp_out,
855 attn_output,
856 g.total_tokens * g.hidden,
857 )?;
858 if debug_tensors && layer_index == 0 {
859 let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
860 driver.f16_to_f32(&mut probe, &output, g.total_tokens * g.hidden)?;
861 driver.debug_tensor(
862 "modernbert.layer_0.ffn_output_f16_as_f32",
863 &probe,
864 g.total_tokens,
865 g.hidden,
866 )?;
867 }
868 Ok(output)
869}
870
871impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
876 #[expect(
877 clippy::cast_precision_loss,
878 reason = "head_dim is small (64); sqrt is exact at this size"
879 )]
880 #[expect(
881 clippy::many_single_char_names,
882 reason = "w, g are standard geometry names; q, k, v are standard attention names"
883 )]
884 #[expect(
885 clippy::too_many_lines,
886 reason = "forward pass is a single logical unit"
887 )]
888 fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
889 let w = &self.weights;
890 let batch = encodings.len();
891 let hidden = w.hidden_dim;
892
893 let inputs = driver.prepare_batch_unpadded(encodings)?;
894 let max_seq = inputs.max_seq;
895 let total_tokens = inputs.total_tokens;
896
897 driver.begin_batch()?;
899
900 let mut hidden_states =
902 driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
903 let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
904 driver.layer_norm(
905 &mut hidden_states,
906 &emb_input,
907 &w.emb_norm_weight,
908 &w.zero_bias,
909 total_tokens,
910 hidden,
911 w.layer_norm_eps,
912 )?;
913 driver.debug_tensor(
914 "modernbert.embedding_layer_norm",
915 &hidden_states,
916 total_tokens,
917 hidden,
918 )?;
919
920 let g = EncoderGeometry {
921 batch,
922 max_seq,
923 total_tokens,
924 padded_tokens: batch * max_seq,
925 seq_lengths: inputs.seq_lengths.clone(),
926 hidden,
927 num_heads: w.num_heads,
928 head_dim: w.head_dim,
929 intermediate: w.intermediate_dim,
930 local_window: w.local_window,
931 scale: 1.0 / (w.head_dim as f32).sqrt(),
932 eps: w.layer_norm_eps,
933 };
934
935 let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
943 || std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
944 let use_f16 = if force_fp32 {
945 false
946 } else {
947 driver.alloc_zeros_f16(1).map(|_| true).unwrap_or(false)
948 };
949
950 if use_f16 {
951 let debug_tensors = driver.debug_tensors_enabled();
953
954 let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
956 driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
957 if debug_tensors {
958 let mut initial_probe = driver.alloc_zeros(total_tokens * hidden)?;
959 driver.f16_to_f32(&mut initial_probe, &hidden_f16, total_tokens * hidden)?;
960 driver.debug_tensor(
961 "modernbert.after_initial_f32_to_f16",
962 &initial_probe,
963 total_tokens,
964 hidden,
965 )?;
966 }
967
968 for (layer_index, layer) in w.layers.iter().enumerate() {
970 let saved = driver.save_pool_cursor();
971
972 let rope = if layer.is_global {
973 &self.global_rope
974 } else {
975 &self.local_rope
976 };
977
978 let (q, k, v) = attn_prenorm_qkv_f16(
979 driver,
980 &hidden_f16,
981 layer,
982 &g,
983 &w.zero_bias,
984 rope,
985 layer_index,
986 debug_tensors,
987 )?;
988 let attn_output = attn_scores_residual_f16(
989 driver,
990 &q,
991 &k,
992 &v,
993 &hidden_f16,
994 layer,
995 &inputs,
996 &g,
997 layer_index,
998 debug_tensors,
999 )?;
1000 if debug_tensors && layer_index == 0 {
1001 let mut probe = driver.alloc_zeros(total_tokens * hidden)?;
1002 driver.f16_to_f32(&mut probe, &attn_output, total_tokens * hidden)?;
1003 driver.debug_tensor(
1004 "modernbert.layer_0.attn_output_f16_as_f32",
1005 &probe,
1006 total_tokens,
1007 hidden,
1008 )?;
1009 }
1010 hidden_f16 = ffn_sublayer_f16(
1011 driver,
1012 &attn_output,
1013 layer,
1014 &g,
1015 &w.zero_bias,
1016 layer_index,
1017 debug_tensors,
1018 )?;
1019 driver.restore_pool_cursor(saved);
1020 if debug_tensors && (layer_index == 0 || layer_index + 1 == w.layers.len()) {
1021 let mut probe = driver.alloc_zeros(total_tokens * hidden)?;
1022 driver.f16_to_f32(&mut probe, &hidden_f16, total_tokens * hidden)?;
1023 driver.debug_tensor(
1024 &format!("modernbert.layer_{layer_index}.hidden_f16_as_f32"),
1025 &probe,
1026 total_tokens,
1027 hidden,
1028 )?;
1029 }
1030 }
1031
1032 let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
1034 driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
1035 hidden_states = hidden_f32;
1036 driver.debug_tensor(
1037 "modernbert.after_f16_to_f32",
1038 &hidden_states,
1039 total_tokens,
1040 hidden,
1041 )?;
1042 } else {
1043 for (layer_index, layer) in w.layers.iter().enumerate() {
1045 let saved = driver.save_pool_cursor();
1046
1047 let rope = if layer.is_global {
1048 &self.global_rope
1049 } else {
1050 &self.local_rope
1051 };
1052
1053 let (q, k, v) =
1054 attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
1055 let attn_output =
1056 attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
1057 hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
1058
1059 driver.restore_pool_cursor(saved);
1060 if layer_index == 0 || layer_index + 1 == w.layers.len() {
1061 driver.debug_tensor(
1062 &format!("modernbert.layer_{layer_index}.hidden_fp32"),
1063 &hidden_states,
1064 total_tokens,
1065 hidden,
1066 )?;
1067 }
1068 }
1069 }
1070
1071 let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
1073 driver.layer_norm(
1074 &mut hidden_states,
1075 &final_input,
1076 &w.final_norm_weight,
1077 &w.zero_bias,
1078 total_tokens,
1079 hidden,
1080 w.layer_norm_eps,
1081 )?;
1082 driver.debug_tensor(
1083 "modernbert.final_layer_norm",
1084 &hidden_states,
1085 total_tokens,
1086 hidden,
1087 )?;
1088
1089 let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
1091 driver.pad_to_batch(
1092 &hidden_states,
1093 &mut padded_for_pool,
1094 &inputs.seq_lengths,
1095 max_seq,
1096 hidden,
1097 )?;
1098
1099 let mut pooled = driver.alloc_zeros(batch * hidden)?;
1101 driver.mean_pool(
1102 &mut pooled,
1103 &padded_for_pool,
1104 &inputs.pooling_mask,
1105 batch,
1106 max_seq,
1107 hidden,
1108 )?;
1109 driver.debug_tensor("modernbert.mean_pool", &pooled, batch, hidden)?;
1110 driver.l2_normalize(&mut pooled, batch, hidden)?;
1111 driver.debug_tensor("modernbert.l2_normalize", &pooled, batch, hidden)?;
1112
1113 driver.end_batch()?;
1115
1116 driver.to_host(&pooled, batch, hidden)
1117 }
1118}