1use super::super::Encoding;
14use super::super::driver::{BatchInputs, Driver};
15use super::ModelArch;
16
/// Weights for a single ModernBERT encoder layer.
///
/// `T` is the driver's opaque tensor handle. Shapes noted below are inferred
/// from the GEMM dimensions used in the forward pass — confirm against the
/// checkpoint loader.
pub struct ModernBertLayerWeights<T> {
    /// Fused Q/K/V projection; used as the transposed operand of a
    /// `[total_tokens, hidden] x [3*hidden, hidden]^T` GEMM.
    pub qkv_weight: T,
    /// Attention output projection (`hidden -> hidden`, transposed GEMM).
    pub output_weight: T,
    /// Pre-attention LayerNorm weight; `None` means this layer applies no
    /// pre-attention norm and the raw hidden states feed the QKV projection.
    pub attn_norm_weight: Option<T>,
    /// FFN up-projection producing `2 * intermediate` features
    /// (value and gate halves for the GeGLU activation).
    pub mlp_wi_weight: T,
    /// FFN down-projection (`intermediate -> hidden`, transposed GEMM).
    pub mlp_wo_weight: T,
    /// Pre-FFN LayerNorm weight.
    pub mlp_norm_weight: T,
    /// `true` for global-attention layers (dense softmax, global RoPE cache);
    /// `false` for local layers (windowed softmax, local RoPE cache).
    pub is_global: bool,
}
43
/// Full set of model weights plus the hyperparameters the forward pass needs.
pub struct ModernBertWeights<T> {
    /// Token embedding table (row per vocabulary id; see `embedding_lookup`).
    pub tok_embeddings: T,
    /// LayerNorm weight applied right after the embedding lookup.
    pub emb_norm_weight: T,
    /// LayerNorm weight applied after the last encoder layer.
    pub final_norm_weight: T,
    /// Zero-filled tensor passed as the bias operand to every `layer_norm`
    /// call (the model's norms carry no bias, but the driver API expects one).
    pub zero_bias: T,
    /// Encoder layers, in execution order.
    pub layers: Vec<ModernBertLayerWeights<T>>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head dimension (`hidden_dim / num_heads` — TODO confirm).
    pub head_dim: usize,
    /// Model hidden size.
    pub hidden_dim: usize,
    /// FFN intermediate size (up-projection emits `2 *` this for GeGLU).
    pub intermediate_dim: usize,
    /// Epsilon used by every LayerNorm.
    pub layer_norm_eps: f32,
    /// Sliding-window size for local-attention layers.
    pub local_window: usize,
}
77
/// Precomputed rotary-embedding tables fed to the driver's RoPE kernels.
pub struct RopeCache<T> {
    /// Cosine table.
    pub cos: T,
    /// Sine table.
    pub sin: T,
}
88
/// ModernBERT architecture: weights plus the two RoPE caches.
///
/// Global and local attention layers use different rotary bases, hence the
/// two separate caches selected per layer via `is_global`.
pub struct ModernBertArch<T> {
    /// All model weights and hyperparameters.
    pub weights: ModernBertWeights<T>,
    /// RoPE tables for global-attention layers.
    pub global_rope: RopeCache<T>,
    /// RoPE tables for local (windowed) attention layers.
    pub local_rope: RopeCache<T>,
}
102
/// Per-batch geometry computed once in `forward` and threaded through the
/// sublayer helpers, so each helper takes one struct instead of a dozen
/// scalar parameters.
struct EncoderGeometry {
    // Number of sequences in the batch.
    batch: usize,
    // Length of the longest sequence.
    max_seq: usize,
    // Sum of all sequence lengths (packed/unpadded token count).
    total_tokens: usize,
    // batch * max_seq — token count after padding to a rectangle.
    padded_tokens: usize,
    // Per-sequence token counts, used by the pad/unpad kernels.
    seq_lengths: Vec<usize>,
    // Model hidden size.
    hidden: usize,
    // Attention head count.
    num_heads: usize,
    // Per-head dimension.
    head_dim: usize,
    // FFN intermediate size.
    intermediate: usize,
    // Sliding-window size for local attention layers.
    local_window: usize,
    // Attention logit scale, 1 / sqrt(head_dim).
    scale: f32,
    // LayerNorm epsilon.
    eps: f32,
}
125
126fn attn_prenorm_qkv<D: Driver>(
134 driver: &D,
135 hidden_states: &D::Tensor,
136 layer: &ModernBertLayerWeights<D::Tensor>,
137 g: &EncoderGeometry,
138 zero_bias: &D::Tensor,
139 rope: &RopeCache<D::Tensor>,
140) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
141 let normed = if let Some(ref norm_w) = layer.attn_norm_weight {
143 let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
144 driver.layer_norm(
145 &mut n,
146 hidden_states,
147 norm_w,
148 zero_bias,
149 g.total_tokens,
150 g.hidden,
151 g.eps,
152 )?;
153 n
154 } else {
155 driver.clone_tensor(hidden_states, g.total_tokens * g.hidden)?
157 };
158
159 let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
162 driver.gemm(
163 &normed,
164 &layer.qkv_weight,
165 &mut qkv,
166 g.total_tokens,
167 3 * g.hidden,
168 g.hidden,
169 true,
170 )?;
171
172 let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
175 driver.pad_to_batch(
176 &qkv,
177 &mut qkv_padded,
178 &g.seq_lengths,
179 g.max_seq,
180 3 * g.hidden,
181 )?;
182
183 let padded = g.padded_tokens;
185 let mut q = driver.alloc_zeros(padded * g.hidden)?;
186 let mut k = driver.alloc_zeros(padded * g.hidden)?;
187 let mut v = driver.alloc_zeros(padded * g.hidden)?;
188 driver.qkv_split(
189 &mut q,
190 &mut k,
191 &mut v,
192 &qkv_padded,
193 g.batch,
194 g.max_seq,
195 g.hidden,
196 g.num_heads,
197 g.head_dim,
198 )?;
199
200 let num_rows = g.batch * g.num_heads * g.max_seq;
202 driver.apply_rope(
203 &mut q,
204 &rope.cos,
205 &rope.sin,
206 num_rows,
207 g.max_seq,
208 g.head_dim,
209 g.num_heads,
210 )?;
211 driver.apply_rope(
212 &mut k,
213 &rope.cos,
214 &rope.sin,
215 num_rows,
216 g.max_seq,
217 g.head_dim,
218 g.num_heads,
219 )?;
220
221 Ok((q, k, v))
222}
223
224#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
230fn attn_scores_residual<D: Driver>(
231 driver: &D,
232 q: &D::Tensor,
233 k: &D::Tensor,
234 v: &D::Tensor,
235 hidden_states: &D::Tensor,
236 layer: &ModernBertLayerWeights<D::Tensor>,
237 inputs: &BatchInputs<D::Tensor>,
238 g: &EncoderGeometry,
239) -> crate::Result<D::Tensor> {
240 let batch_heads = g.batch * g.num_heads;
241 let stride_qk = g.max_seq * g.head_dim;
242
243 let mut scores = driver.alloc_zeros(batch_heads * g.max_seq * g.max_seq)?;
251 driver.gemm_batched(
252 q,
253 k,
254 &mut scores,
255 g.max_seq,
256 g.max_seq,
257 g.head_dim,
258 true,
259 stride_qk,
260 stride_qk,
261 g.max_seq * g.max_seq,
262 batch_heads,
263 )?;
264
265 if layer.is_global {
266 driver.fused_scale_mask_softmax(
267 &mut scores,
268 &inputs.float_mask,
269 g.batch,
270 g.num_heads,
271 g.max_seq,
272 g.scale,
273 )?;
274 } else {
275 driver.fused_scale_mask_softmax_windowed(
276 &mut scores,
277 &inputs.float_mask,
278 g.batch,
279 g.num_heads,
280 g.max_seq,
281 g.scale,
282 g.local_window,
283 )?;
284 }
285
286 let mut attn_out = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
287 driver.gemm_batched(
288 &scores,
289 v,
290 &mut attn_out,
291 g.max_seq,
292 g.head_dim,
293 g.max_seq,
294 false,
295 g.max_seq * g.max_seq,
296 stride_qk,
297 stride_qk,
298 batch_heads,
299 )?;
300
301 let mut context = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
303 driver.attn_reshape(
304 &mut context,
305 &attn_out,
306 g.batch,
307 g.max_seq,
308 g.num_heads,
309 g.head_dim,
310 )?;
311
312 let mut context_unpacked = driver.alloc_zeros(g.total_tokens * g.hidden)?;
316 driver.unpad_from_batch(
317 &context,
318 &mut context_unpacked,
319 &g.seq_lengths,
320 g.max_seq,
321 g.hidden,
322 )?;
323
324 let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
326 driver.gemm(
327 &context_unpacked,
328 &layer.output_weight,
329 &mut projected,
330 g.total_tokens,
331 g.hidden,
332 g.hidden,
333 true,
334 )?;
335
336 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
338 driver.residual_add(
339 &mut output,
340 &projected,
341 hidden_states,
342 g.total_tokens * g.hidden,
343 )?;
344 Ok(output)
345}
346
347fn ffn_sublayer<D: Driver>(
356 driver: &D,
357 attn_output: &D::Tensor,
358 layer: &ModernBertLayerWeights<D::Tensor>,
359 g: &EncoderGeometry,
360 zero_bias: &D::Tensor,
361) -> crate::Result<D::Tensor> {
362 let mut mlp_normed = driver.alloc_zeros(g.total_tokens * g.hidden)?;
364 driver.layer_norm(
365 &mut mlp_normed,
366 attn_output,
367 &layer.mlp_norm_weight,
368 zero_bias,
369 g.total_tokens,
370 g.hidden,
371 g.eps,
372 )?;
373
374 let double_inter = 2 * g.intermediate;
376 let mut wi_out = driver.alloc_zeros(g.total_tokens * double_inter)?;
377 driver.gemm(
378 &mlp_normed,
379 &layer.mlp_wi_weight,
380 &mut wi_out,
381 g.total_tokens,
382 double_inter,
383 g.hidden,
384 true,
385 )?;
386
387 let n_elements = g.total_tokens * g.intermediate;
389 let mut value = driver.alloc_zeros(n_elements)?;
390 let mut gate = driver.alloc_zeros(n_elements)?;
391 driver.split_gate_value(
392 &mut value,
393 &mut gate,
394 &wi_out,
395 g.total_tokens,
396 g.intermediate,
397 )?;
398
399 let mut activated = driver.alloc_zeros(n_elements)?;
401 driver.geglu(&value, &gate, &mut activated, n_elements)?;
402
403 let mut mlp_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
405 driver.gemm(
406 &activated,
407 &layer.mlp_wo_weight,
408 &mut mlp_out,
409 g.total_tokens,
410 g.hidden,
411 g.intermediate,
412 true,
413 )?;
414
415 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
417 driver.residual_add(
418 &mut output,
419 &mlp_out,
420 attn_output,
421 g.total_tokens * g.hidden,
422 )?;
423 Ok(output)
424}
425
426fn attn_prenorm_qkv_f16<D: Driver>(
436 driver: &D,
437 hidden_states: &D::Tensor,
438 layer: &ModernBertLayerWeights<D::Tensor>,
439 g: &EncoderGeometry,
440 zero_bias: &D::Tensor,
441 rope: &RopeCache<D::Tensor>,
442) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
443 let normed: Option<D::Tensor>;
446 let normed_ref = if let Some(ref norm_w) = layer.attn_norm_weight {
447 let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
448 driver.layer_norm_f16(
449 &mut n,
450 hidden_states,
451 norm_w,
452 zero_bias,
453 g.total_tokens,
454 g.hidden,
455 g.eps,
456 )?;
457 normed = Some(n);
458 normed.as_ref().unwrap()
459 } else {
460 hidden_states
462 };
463
464 let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
466 driver.gemm_f16(
467 normed_ref,
468 &layer.qkv_weight,
469 &mut qkv,
470 g.total_tokens,
471 3 * g.hidden,
472 g.hidden,
473 true,
474 )?;
475
476 let padded = g.padded_tokens;
479 let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
480 let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
481 let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
482 driver.fused_pad_qkv_split_f16(
483 &mut q,
484 &mut k,
485 &mut v,
486 &qkv,
487 &g.seq_lengths,
488 g.max_seq,
489 g.batch,
490 g.hidden,
491 g.num_heads,
492 g.head_dim,
493 )?;
494
495 let num_rows = g.batch * g.num_heads * g.max_seq;
497 driver.rope_encode_f16(
498 &mut q,
499 &rope.cos,
500 &rope.sin,
501 num_rows,
502 g.max_seq,
503 g.head_dim,
504 g.num_heads,
505 )?;
506 driver.rope_encode_f16(
507 &mut k,
508 &rope.cos,
509 &rope.sin,
510 num_rows,
511 g.max_seq,
512 g.head_dim,
513 g.num_heads,
514 )?;
515
516 Ok((q, k, v))
517}
518
519#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
528fn attn_scores_residual_f16<D: Driver>(
529 driver: &D,
530 q: &D::Tensor,
531 k: &D::Tensor,
532 v: &D::Tensor,
533 hidden_states: &D::Tensor,
534 layer: &ModernBertLayerWeights<D::Tensor>,
535 inputs: &BatchInputs<D::Tensor>,
536 g: &EncoderGeometry,
537) -> crate::Result<D::Tensor> {
538 let batch_heads = g.batch * g.num_heads;
539 let stride_qk = g.max_seq * g.head_dim;
540
541 let mut scores = driver.alloc_zeros_f16(batch_heads * g.max_seq * g.max_seq)?;
543 driver.gemm_batched_f16(
544 q,
545 k,
546 &mut scores,
547 g.max_seq,
548 g.max_seq,
549 g.head_dim,
550 true,
551 stride_qk,
552 stride_qk,
553 g.max_seq * g.max_seq,
554 batch_heads,
555 )?;
556
557 if layer.is_global {
559 driver.fused_scale_mask_softmax_f16(
560 &mut scores,
561 &inputs.float_mask,
562 g.batch,
563 g.num_heads,
564 g.max_seq,
565 g.scale,
566 )?;
567 } else {
568 driver.fused_scale_mask_softmax_windowed_f16(
569 &mut scores,
570 &inputs.float_mask,
571 g.batch,
572 g.num_heads,
573 g.max_seq,
574 g.scale,
575 g.local_window,
576 )?;
577 }
578
579 let mut attn_out = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
581 driver.gemm_batched_f16(
582 &scores,
583 v,
584 &mut attn_out,
585 g.max_seq,
586 g.head_dim,
587 g.max_seq,
588 false,
589 g.max_seq * g.max_seq,
590 stride_qk,
591 stride_qk,
592 batch_heads,
593 )?;
594
595 let mut context_unpacked = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
598 driver.fused_reshape_unpad_f16(
599 &mut context_unpacked,
600 &attn_out,
601 &g.seq_lengths,
602 g.max_seq,
603 g.batch,
604 g.num_heads,
605 g.head_dim,
606 )?;
607
608 let mut projected = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
610 driver.gemm_f16(
611 &context_unpacked,
612 &layer.output_weight,
613 &mut projected,
614 g.total_tokens,
615 g.hidden,
616 g.hidden,
617 true,
618 )?;
619
620 let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
622 driver.residual_add_f16(
623 &mut output,
624 &projected,
625 hidden_states,
626 g.total_tokens * g.hidden,
627 )?;
628 Ok(output)
629}
630
631fn ffn_sublayer_f16<D: Driver>(
639 driver: &D,
640 attn_output: &D::Tensor,
641 layer: &ModernBertLayerWeights<D::Tensor>,
642 g: &EncoderGeometry,
643 zero_bias: &D::Tensor,
644) -> crate::Result<D::Tensor> {
645 let mut mlp_normed = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
647 driver.layer_norm_f16(
648 &mut mlp_normed,
649 attn_output,
650 &layer.mlp_norm_weight,
651 zero_bias,
652 g.total_tokens,
653 g.hidden,
654 g.eps,
655 )?;
656
657 let double_inter = 2 * g.intermediate;
659 let mut wi_out = driver.alloc_zeros_f16(g.total_tokens * double_inter)?;
660 driver.gemm_f16(
661 &mlp_normed,
662 &layer.mlp_wi_weight,
663 &mut wi_out,
664 g.total_tokens,
665 double_inter,
666 g.hidden,
667 true,
668 )?;
669
670 let n_elements = g.total_tokens * g.intermediate;
674 let mut activated = driver.alloc_zeros_f16(n_elements)?;
675 driver.fused_split_geglu_f16(&mut activated, &wi_out, g.total_tokens, g.intermediate)?;
676
677 let mut mlp_out = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
679 driver.gemm_f16(
680 &activated,
681 &layer.mlp_wo_weight,
682 &mut mlp_out,
683 g.total_tokens,
684 g.hidden,
685 g.intermediate,
686 true,
687 )?;
688
689 let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
691 driver.residual_add_f16(
692 &mut output,
693 &mlp_out,
694 attn_output,
695 g.total_tokens * g.hidden,
696 )?;
697 Ok(output)
698}
699
700impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
705 #[expect(
706 clippy::cast_precision_loss,
707 reason = "head_dim is small (64); sqrt is exact at this size"
708 )]
709 #[expect(
710 clippy::many_single_char_names,
711 reason = "w, g are standard geometry names; q, k, v are standard attention names"
712 )]
713 fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
714 let w = &self.weights;
715 let batch = encodings.len();
716 let hidden = w.hidden_dim;
717
718 let inputs = driver.prepare_batch_unpadded(encodings)?;
719 let max_seq = inputs.max_seq;
720 let total_tokens = inputs.total_tokens;
721
722 driver.begin_batch()?;
724
725 let mut hidden_states =
727 driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
728 let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
729 driver.layer_norm(
730 &mut hidden_states,
731 &emb_input,
732 &w.emb_norm_weight,
733 &w.zero_bias,
734 total_tokens,
735 hidden,
736 w.layer_norm_eps,
737 )?;
738
739 let g = EncoderGeometry {
740 batch,
741 max_seq,
742 total_tokens,
743 padded_tokens: batch * max_seq,
744 seq_lengths: inputs.seq_lengths.clone(),
745 hidden,
746 num_heads: w.num_heads,
747 head_dim: w.head_dim,
748 intermediate: w.intermediate_dim,
749 local_window: w.local_window,
750 scale: 1.0 / (w.head_dim as f32).sqrt(),
751 eps: w.layer_norm_eps,
752 };
753
754 let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
762 || std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
763 let use_f16 = if force_fp32 {
764 false
765 } else {
766 driver.alloc_zeros_f16(1).map(|_| true).unwrap_or(false)
767 };
768
769 if use_f16 {
770 let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
774 driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
775
776 for layer in &w.layers {
778 let saved = driver.save_pool_cursor();
779
780 let rope = if layer.is_global {
781 &self.global_rope
782 } else {
783 &self.local_rope
784 };
785
786 let (q, k, v) =
787 attn_prenorm_qkv_f16(driver, &hidden_f16, layer, &g, &w.zero_bias, rope)?;
788 let attn_output =
789 attn_scores_residual_f16(driver, &q, &k, &v, &hidden_f16, layer, &inputs, &g)?;
790 hidden_f16 = ffn_sublayer_f16(driver, &attn_output, layer, &g, &w.zero_bias)?;
791 driver.restore_pool_cursor(saved);
792 }
793
794 let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
796 driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
797 hidden_states = hidden_f32;
798 } else {
799 for layer in &w.layers {
801 let saved = driver.save_pool_cursor();
802
803 let rope = if layer.is_global {
804 &self.global_rope
805 } else {
806 &self.local_rope
807 };
808
809 let (q, k, v) =
810 attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
811 let attn_output =
812 attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
813 hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
814
815 driver.restore_pool_cursor(saved);
816 }
817 }
818
819 let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
821 driver.layer_norm(
822 &mut hidden_states,
823 &final_input,
824 &w.final_norm_weight,
825 &w.zero_bias,
826 total_tokens,
827 hidden,
828 w.layer_norm_eps,
829 )?;
830
831 let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
833 driver.pad_to_batch(
834 &hidden_states,
835 &mut padded_for_pool,
836 &inputs.seq_lengths,
837 max_seq,
838 hidden,
839 )?;
840
841 let mut pooled = driver.alloc_zeros(batch * hidden)?;
843 driver.mean_pool(
844 &mut pooled,
845 &padded_for_pool,
846 &inputs.pooling_mask,
847 batch,
848 max_seq,
849 hidden,
850 )?;
851 driver.l2_normalize(&mut pooled, batch, hidden)?;
852
853 driver.end_batch()?;
855
856 driver.to_host(&pooled, batch, hidden)
857 }
858}