1use super::super::Encoding;
14use super::super::driver::{BatchInputs, Driver};
15use super::ModelArch;
16
17pub struct ModernBertLayerWeights<T> {
27 pub qkv_weight: T,
29 pub output_weight: T,
31 pub attn_norm_weight: Option<T>,
33 pub mlp_wi_weight: T,
36 pub mlp_wo_weight: T,
38 pub mlp_norm_weight: T,
40 pub is_global: bool,
42}
43
44pub struct ModernBertWeights<T> {
51 pub tok_embeddings: T,
53 pub emb_norm_weight: T,
55 pub final_norm_weight: T,
57 pub zero_bias: T,
62 pub layers: Vec<ModernBertLayerWeights<T>>,
64 pub num_heads: usize,
66 pub head_dim: usize,
68 pub hidden_dim: usize,
70 pub intermediate_dim: usize,
72 pub layer_norm_eps: f32,
74 pub local_window: usize,
76}
77
78pub struct RopeCache<T> {
83 pub cos: T,
85 pub sin: T,
87}
88
89pub struct ModernBertArch<T> {
95 pub weights: ModernBertWeights<T>,
97 pub global_rope: RopeCache<T>,
99 pub local_rope: RopeCache<T>,
101}
102
103struct EncoderGeometry {
109 batch: usize,
110 max_seq: usize,
111 total_tokens: usize,
113 padded_tokens: usize,
115 seq_lengths: Vec<usize>,
117 hidden: usize,
118 num_heads: usize,
119 head_dim: usize,
120 intermediate: usize,
121 local_window: usize,
122 scale: f32,
123 eps: f32,
124}
125
126fn attn_prenorm_qkv<D: Driver>(
134 driver: &D,
135 hidden_states: &D::Tensor,
136 layer: &ModernBertLayerWeights<D::Tensor>,
137 g: &EncoderGeometry,
138 zero_bias: &D::Tensor,
139 rope: &RopeCache<D::Tensor>,
140) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
141 let normed = if let Some(ref norm_w) = layer.attn_norm_weight {
143 let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
144 driver.layer_norm(
145 &mut n,
146 hidden_states,
147 norm_w,
148 zero_bias,
149 g.total_tokens,
150 g.hidden,
151 g.eps,
152 )?;
153 n
154 } else {
155 driver.clone_tensor(hidden_states, g.total_tokens * g.hidden)?
157 };
158
159 let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
162 driver.gemm(
163 &normed,
164 &layer.qkv_weight,
165 &mut qkv,
166 g.total_tokens,
167 3 * g.hidden,
168 g.hidden,
169 true,
170 )?;
171
172 let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
175 driver.pad_to_batch(
176 &qkv,
177 &mut qkv_padded,
178 &g.seq_lengths,
179 g.max_seq,
180 3 * g.hidden,
181 )?;
182
183 let padded = g.padded_tokens;
185 let mut q = driver.alloc_zeros(padded * g.hidden)?;
186 let mut k = driver.alloc_zeros(padded * g.hidden)?;
187 let mut v = driver.alloc_zeros(padded * g.hidden)?;
188 driver.qkv_split(
189 &mut q,
190 &mut k,
191 &mut v,
192 &qkv_padded,
193 g.batch,
194 g.max_seq,
195 g.hidden,
196 g.num_heads,
197 g.head_dim,
198 )?;
199
200 let num_rows = g.batch * g.num_heads * g.max_seq;
202 driver.apply_rope(
203 &mut q,
204 &rope.cos,
205 &rope.sin,
206 num_rows,
207 g.max_seq,
208 g.head_dim,
209 g.num_heads,
210 )?;
211 driver.apply_rope(
212 &mut k,
213 &rope.cos,
214 &rope.sin,
215 num_rows,
216 g.max_seq,
217 g.head_dim,
218 g.num_heads,
219 )?;
220
221 Ok((q, k, v))
222}
223
224#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
230fn attn_scores_residual<D: Driver>(
231 driver: &D,
232 q: &D::Tensor,
233 k: &D::Tensor,
234 v: &D::Tensor,
235 hidden_states: &D::Tensor,
236 layer: &ModernBertLayerWeights<D::Tensor>,
237 inputs: &BatchInputs<D::Tensor>,
238 g: &EncoderGeometry,
239) -> crate::Result<D::Tensor> {
240 let batch_heads = g.batch * g.num_heads;
241 let stride_qk = g.max_seq * g.head_dim;
242
243 let mut scores = driver.alloc_zeros(batch_heads * g.max_seq * g.max_seq)?;
251 driver.gemm_batched(
252 q,
253 k,
254 &mut scores,
255 g.max_seq,
256 g.max_seq,
257 g.head_dim,
258 true,
259 stride_qk,
260 stride_qk,
261 g.max_seq * g.max_seq,
262 batch_heads,
263 )?;
264
265 if layer.is_global {
266 driver.fused_scale_mask_softmax(
267 &mut scores,
268 &inputs.float_mask,
269 g.batch,
270 g.num_heads,
271 g.max_seq,
272 g.scale,
273 )?;
274 } else {
275 driver.fused_scale_mask_softmax_windowed(
276 &mut scores,
277 &inputs.float_mask,
278 g.batch,
279 g.num_heads,
280 g.max_seq,
281 g.scale,
282 g.local_window,
283 )?;
284 }
285
286 let mut attn_out = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
287 driver.gemm_batched(
288 &scores,
289 v,
290 &mut attn_out,
291 g.max_seq,
292 g.head_dim,
293 g.max_seq,
294 false,
295 g.max_seq * g.max_seq,
296 stride_qk,
297 stride_qk,
298 batch_heads,
299 )?;
300
301 let mut context = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
303 driver.attn_reshape(
304 &mut context,
305 &attn_out,
306 g.batch,
307 g.max_seq,
308 g.num_heads,
309 g.head_dim,
310 )?;
311
312 let mut context_unpacked = driver.alloc_zeros(g.total_tokens * g.hidden)?;
316 driver.unpad_from_batch(
317 &context,
318 &mut context_unpacked,
319 &g.seq_lengths,
320 g.max_seq,
321 g.hidden,
322 )?;
323
324 let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
326 driver.gemm(
327 &context_unpacked,
328 &layer.output_weight,
329 &mut projected,
330 g.total_tokens,
331 g.hidden,
332 g.hidden,
333 true,
334 )?;
335
336 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
338 driver.residual_add(
339 &mut output,
340 &projected,
341 hidden_states,
342 g.total_tokens * g.hidden,
343 )?;
344 Ok(output)
345}
346
347fn ffn_sublayer<D: Driver>(
356 driver: &D,
357 attn_output: &D::Tensor,
358 layer: &ModernBertLayerWeights<D::Tensor>,
359 g: &EncoderGeometry,
360 zero_bias: &D::Tensor,
361) -> crate::Result<D::Tensor> {
362 let mut mlp_normed = driver.alloc_zeros(g.total_tokens * g.hidden)?;
364 driver.layer_norm(
365 &mut mlp_normed,
366 attn_output,
367 &layer.mlp_norm_weight,
368 zero_bias,
369 g.total_tokens,
370 g.hidden,
371 g.eps,
372 )?;
373
374 let double_inter = 2 * g.intermediate;
376 let mut wi_out = driver.alloc_zeros(g.total_tokens * double_inter)?;
377 driver.gemm(
378 &mlp_normed,
379 &layer.mlp_wi_weight,
380 &mut wi_out,
381 g.total_tokens,
382 double_inter,
383 g.hidden,
384 true,
385 )?;
386
387 let n_elements = g.total_tokens * g.intermediate;
389 let mut value = driver.alloc_zeros(n_elements)?;
390 let mut gate = driver.alloc_zeros(n_elements)?;
391 driver.split_gate_value(
392 &mut value,
393 &mut gate,
394 &wi_out,
395 g.total_tokens,
396 g.intermediate,
397 )?;
398
399 let mut activated = driver.alloc_zeros(n_elements)?;
401 driver.geglu(&value, &gate, &mut activated, n_elements)?;
402
403 let mut mlp_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
405 driver.gemm(
406 &activated,
407 &layer.mlp_wo_weight,
408 &mut mlp_out,
409 g.total_tokens,
410 g.hidden,
411 g.intermediate,
412 true,
413 )?;
414
415 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
417 driver.residual_add(
418 &mut output,
419 &mlp_out,
420 attn_output,
421 g.total_tokens * g.hidden,
422 )?;
423 Ok(output)
424}
425
426fn attn_prenorm_qkv_f16<D: Driver>(
436 driver: &D,
437 hidden_states: &D::Tensor,
438 layer: &ModernBertLayerWeights<D::Tensor>,
439 g: &EncoderGeometry,
440 zero_bias: &D::Tensor,
441 rope: &RopeCache<D::Tensor>,
442) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
443 let normed: Option<D::Tensor>;
446 let normed_ref = if let Some(ref norm_w) = layer.attn_norm_weight {
447 let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
448 driver.layer_norm_f16(
449 &mut n,
450 hidden_states,
451 norm_w,
452 zero_bias,
453 g.total_tokens,
454 g.hidden,
455 g.eps,
456 )?;
457 normed = Some(n);
458 normed.as_ref().unwrap()
459 } else {
460 hidden_states
462 };
463
464 let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
466 driver.gemm_f16(
467 normed_ref,
468 &layer.qkv_weight,
469 &mut qkv,
470 g.total_tokens,
471 3 * g.hidden,
472 g.hidden,
473 true,
474 )?;
475
476 let padded = g.padded_tokens;
479 let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
480 let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
481 let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
482 driver.fused_pad_qkv_split_f16(
483 &mut q,
484 &mut k,
485 &mut v,
486 &qkv,
487 &g.seq_lengths,
488 g.max_seq,
489 g.batch,
490 g.hidden,
491 g.num_heads,
492 g.head_dim,
493 )?;
494
495 let num_rows = g.batch * g.num_heads * g.max_seq;
497 driver.rope_encode_f16(
498 &mut q,
499 &rope.cos,
500 &rope.sin,
501 num_rows,
502 g.max_seq,
503 g.head_dim,
504 g.num_heads,
505 )?;
506 driver.rope_encode_f16(
507 &mut k,
508 &rope.cos,
509 &rope.sin,
510 num_rows,
511 g.max_seq,
512 g.head_dim,
513 g.num_heads,
514 )?;
515
516 Ok((q, k, v))
517}
518
519#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
528fn attn_scores_residual_f16<D: Driver>(
529 driver: &D,
530 q: &D::Tensor,
531 k: &D::Tensor,
532 v: &D::Tensor,
533 hidden_states: &D::Tensor,
534 layer: &ModernBertLayerWeights<D::Tensor>,
535 inputs: &BatchInputs<D::Tensor>,
536 g: &EncoderGeometry,
537) -> crate::Result<D::Tensor> {
538 let batch_heads = g.batch * g.num_heads;
539 let stride_qk = g.max_seq * g.head_dim;
540
541 let mut scores = driver.alloc_zeros_f16(batch_heads * g.max_seq * g.max_seq)?;
543 driver.gemm_batched_f16(
544 q,
545 k,
546 &mut scores,
547 g.max_seq,
548 g.max_seq,
549 g.head_dim,
550 true,
551 stride_qk,
552 stride_qk,
553 g.max_seq * g.max_seq,
554 batch_heads,
555 )?;
556
557 if layer.is_global {
559 driver.fused_scale_mask_softmax_f16(
560 &mut scores,
561 &inputs.float_mask,
562 g.batch,
563 g.num_heads,
564 g.max_seq,
565 g.scale,
566 )?;
567 } else {
568 driver.fused_scale_mask_softmax_windowed_f16(
569 &mut scores,
570 &inputs.float_mask,
571 g.batch,
572 g.num_heads,
573 g.max_seq,
574 g.scale,
575 g.local_window,
576 )?;
577 }
578
579 let mut attn_out = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
581 driver.gemm_batched_f16(
582 &scores,
583 v,
584 &mut attn_out,
585 g.max_seq,
586 g.head_dim,
587 g.max_seq,
588 false,
589 g.max_seq * g.max_seq,
590 stride_qk,
591 stride_qk,
592 batch_heads,
593 )?;
594
595 let mut context_unpacked = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
598 driver.fused_reshape_unpad_f16(
599 &mut context_unpacked,
600 &attn_out,
601 &g.seq_lengths,
602 g.max_seq,
603 g.batch,
604 g.num_heads,
605 g.head_dim,
606 )?;
607
608 let mut projected = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
610 driver.gemm_f16(
611 &context_unpacked,
612 &layer.output_weight,
613 &mut projected,
614 g.total_tokens,
615 g.hidden,
616 g.hidden,
617 true,
618 )?;
619
620 let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
622 driver.residual_add_f16(
623 &mut output,
624 &projected,
625 hidden_states,
626 g.total_tokens * g.hidden,
627 )?;
628 Ok(output)
629}
630
631fn ffn_sublayer_f16<D: Driver>(
639 driver: &D,
640 attn_output: &D::Tensor,
641 layer: &ModernBertLayerWeights<D::Tensor>,
642 g: &EncoderGeometry,
643 zero_bias: &D::Tensor,
644) -> crate::Result<D::Tensor> {
645 let mut mlp_normed = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
647 driver.layer_norm_f16(
648 &mut mlp_normed,
649 attn_output,
650 &layer.mlp_norm_weight,
651 zero_bias,
652 g.total_tokens,
653 g.hidden,
654 g.eps,
655 )?;
656
657 let double_inter = 2 * g.intermediate;
659 let mut wi_out = driver.alloc_zeros_f16(g.total_tokens * double_inter)?;
660 driver.gemm_f16(
661 &mlp_normed,
662 &layer.mlp_wi_weight,
663 &mut wi_out,
664 g.total_tokens,
665 double_inter,
666 g.hidden,
667 true,
668 )?;
669
670 let n_elements = g.total_tokens * g.intermediate;
674 let mut activated = driver.alloc_zeros_f16(n_elements)?;
675 driver.fused_split_geglu_f16(&mut activated, &wi_out, g.total_tokens, g.intermediate)?;
676
677 let mut mlp_out = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
679 driver.gemm_f16(
680 &activated,
681 &layer.mlp_wo_weight,
682 &mut mlp_out,
683 g.total_tokens,
684 g.hidden,
685 g.intermediate,
686 true,
687 )?;
688
689 let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
691 driver.residual_add_f16(
692 &mut output,
693 &mlp_out,
694 attn_output,
695 g.total_tokens * g.hidden,
696 )?;
697 Ok(output)
698}
699
700impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
705 #[expect(
706 clippy::cast_precision_loss,
707 reason = "head_dim is small (64); sqrt is exact at this size"
708 )]
709 #[expect(
710 clippy::many_single_char_names,
711 reason = "w, g are standard geometry names; q, k, v are standard attention names"
712 )]
713 #[expect(
714 clippy::too_many_lines,
715 reason = "forward pass is a single logical unit"
716 )]
717 fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
718 let w = &self.weights;
719 let batch = encodings.len();
720 let hidden = w.hidden_dim;
721
722 let inputs = driver.prepare_batch_unpadded(encodings)?;
723 let max_seq = inputs.max_seq;
724 let total_tokens = inputs.total_tokens;
725
726 driver.begin_batch()?;
728
729 let mut hidden_states =
731 driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
732 let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
733 driver.layer_norm(
734 &mut hidden_states,
735 &emb_input,
736 &w.emb_norm_weight,
737 &w.zero_bias,
738 total_tokens,
739 hidden,
740 w.layer_norm_eps,
741 )?;
742
743 let g = EncoderGeometry {
744 batch,
745 max_seq,
746 total_tokens,
747 padded_tokens: batch * max_seq,
748 seq_lengths: inputs.seq_lengths.clone(),
749 hidden,
750 num_heads: w.num_heads,
751 head_dim: w.head_dim,
752 intermediate: w.intermediate_dim,
753 local_window: w.local_window,
754 scale: 1.0 / (w.head_dim as f32).sqrt(),
755 eps: w.layer_norm_eps,
756 };
757
758 let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
766 || std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
767 let use_f16 = if force_fp32 {
768 false
769 } else {
770 driver.alloc_zeros_f16(1).map(|_| true).unwrap_or(false)
771 };
772
773 if use_f16 {
774 let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
778 driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
779
780 for layer in &w.layers {
782 let saved = driver.save_pool_cursor();
783
784 let rope = if layer.is_global {
785 &self.global_rope
786 } else {
787 &self.local_rope
788 };
789
790 let (q, k, v) =
791 attn_prenorm_qkv_f16(driver, &hidden_f16, layer, &g, &w.zero_bias, rope)?;
792 let attn_output =
793 attn_scores_residual_f16(driver, &q, &k, &v, &hidden_f16, layer, &inputs, &g)?;
794 hidden_f16 = ffn_sublayer_f16(driver, &attn_output, layer, &g, &w.zero_bias)?;
795 driver.restore_pool_cursor(saved);
796 }
797
798 let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
800 driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
801 hidden_states = hidden_f32;
802 } else {
803 for layer in &w.layers {
805 let saved = driver.save_pool_cursor();
806
807 let rope = if layer.is_global {
808 &self.global_rope
809 } else {
810 &self.local_rope
811 };
812
813 let (q, k, v) =
814 attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
815 let attn_output =
816 attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
817 hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
818
819 driver.restore_pool_cursor(saved);
820 }
821 }
822
823 let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
825 driver.layer_norm(
826 &mut hidden_states,
827 &final_input,
828 &w.final_norm_weight,
829 &w.zero_bias,
830 total_tokens,
831 hidden,
832 w.layer_norm_eps,
833 )?;
834
835 let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
837 driver.pad_to_batch(
838 &hidden_states,
839 &mut padded_for_pool,
840 &inputs.seq_lengths,
841 max_seq,
842 hidden,
843 )?;
844
845 let mut pooled = driver.alloc_zeros(batch * hidden)?;
847 driver.mean_pool(
848 &mut pooled,
849 &padded_for_pool,
850 &inputs.pooling_mask,
851 batch,
852 max_seq,
853 hidden,
854 )?;
855 driver.l2_normalize(&mut pooled, batch, hidden)?;
856
857 driver.end_batch()?;
859
860 driver.to_host(&pooled, batch, hidden)
861 }
862}