Skip to main content

ripvec_core/backend/arch/
modern_bert.rs

//! `ModernBERT` architecture (`nomic-ai/modernbert-embed-base`).
//!
//! A 22-layer transformer with alternating local/global attention, a gated GELU
//! (`GeGLU`) MLP, two `RoPE` frequency caches, and a pre-norm layer structure.
//! It has no biases anywhere and no position embeddings (`RoPE` only), and it
//! uses mean pooling.
//!
//! Weight structures are generic over the tensor type `T`, which becomes
//! [`Driver::Tensor`] when wired to a backend. The [`ModelArch`] implementation
//! composes [`Driver`] primitives into the full forward pass.
12
13use super::super::Encoding;
14use super::super::driver::{BatchInputs, Driver};
15use super::ModelArch;
16
17// ---------------------------------------------------------------------------
18// Weight structures
19// ---------------------------------------------------------------------------
20
/// Per-layer weights of a `ModernBERT` encoder block.
///
/// No projection in a layer carries a bias term. Q, K and V come from one fused
/// `[3*hidden, hidden]` matrix. The MLP input matrix `Wi` is
/// `[2*intermediate, hidden]`, and its output is later divided into the value
/// and gate halves consumed by `GeGLU`.
pub struct ModernBertLayerWeights<T> {
    /// Fused query/key/value projection, `[3*hidden, hidden]`, bias-free.
    pub qkv_weight: T,
    /// Projection applied to the attention context, `[hidden, hidden]`, bias-free.
    pub output_weight: T,
    /// `LayerNorm` scale applied before attention, `[hidden]`.
    ///
    /// `None` means the layer input is used unnormalized (identity, layer 0).
    pub attn_norm_weight: Option<T>,
    /// MLP input projection, `[2*intermediate, hidden]`. Its output splits into
    /// `value[intermediate]` and `gate[intermediate]` for `GeGLU`.
    pub mlp_wi_weight: T,
    /// MLP output projection, `[hidden, intermediate]`, bias-free.
    pub mlp_wo_weight: T,
    /// `LayerNorm` scale applied before the MLP, `[hidden]`.
    pub mlp_norm_weight: T,
    /// `true` selects full (global) attention; `false` selects sliding-window
    /// (local) attention.
    pub is_global: bool,
}
43
44/// Full `ModernBERT` model weights, generic over tensor type.
45///
46/// Includes embedding table, per-layer encoder weights, final norm, and model
47/// geometry. The tensor type `T` becomes
48/// [`Driver::Tensor`] when loaded onto a
49/// specific backend.
50pub struct ModernBertWeights<T> {
51    /// Word embedding table `[vocab_size, hidden]`.
52    pub tok_embeddings: T,
53    /// Post-embedding `LayerNorm` weight `[hidden]` (no bias).
54    pub emb_norm_weight: T,
55    /// Final `LayerNorm` weight `[hidden]` applied before pooling (no bias).
56    pub final_norm_weight: T,
57    /// A zero-filled tensor `[hidden]` used as dummy bias for `LayerNorm` calls.
58    ///
59    /// The [`Driver::layer_norm`] API requires a bias tensor; `ModernBERT` has
60    /// none, so we pass this zero buffer instead.
61    pub zero_bias: T,
62    /// Per-layer encoder weights.
63    pub layers: Vec<ModernBertLayerWeights<T>>,
64    /// Number of attention heads (12 for modernbert-embed-base).
65    pub num_heads: usize,
66    /// Dimension per attention head (`hidden / num_heads`, 64).
67    pub head_dim: usize,
68    /// Hidden dimension (768 for modernbert-embed-base).
69    pub hidden_dim: usize,
70    /// MLP intermediate dimension (1152 for modernbert-embed-base).
71    pub intermediate_dim: usize,
72    /// Layer normalization epsilon (1e-5).
73    pub layer_norm_eps: f32,
74    /// Sliding window size for local attention layers (128).
75    pub local_window: usize,
76}
77
/// Cosine/sine lookup tables for rotary position embeddings at one frequency base.
///
/// `ModernBERT` keeps two of these. Global-attention layers use theta=160000,
/// and local-attention layers use theta=10000.
pub struct RopeCache<T> {
    /// Cosine values, `[max_seq, head_dim/2]`.
    pub cos: T,
    /// Sine values, `[max_seq, head_dim/2]`.
    pub sin: T,
}
88
89/// `ModernBERT` architecture: `nomic-ai/modernbert-embed-base`.
90///
91/// 22 layers, alternating local/global attention, `GeGLU` MLP, `RoPE` (two
92/// theta values), no biases, mean pooling. Composes [`Driver`] primitives into
93/// the full forward pass.
94pub struct ModernBertArch<T> {
95    /// Model weights on device.
96    pub weights: ModernBertWeights<T>,
97    /// `RoPE` cos/sin cache for global attention layers (theta=160000).
98    pub global_rope: RopeCache<T>,
99    /// `RoPE` cos/sin cache for local attention layers (theta=10000).
100    pub local_rope: RopeCache<T>,
101}
102
103// ---------------------------------------------------------------------------
104// Encoder geometry
105// ---------------------------------------------------------------------------
106
/// Per-batch shape parameters shared by the sublayer helpers, so that each
/// helper takes one argument instead of a dozen scalars.
struct EncoderGeometry {
    /// Number of sequences in the batch.
    batch: usize,
    /// Padded per-sequence length; each sequence occupies `max_seq` rows in
    /// padded layouts.
    max_seq: usize,
    /// Real (non-padding) token count across all sequences. This is the row
    /// count for the per-token linear ops.
    total_tokens: usize,
    /// `batch * max_seq`: row count of the padded layout used by attention.
    padded_tokens: usize,
    /// Length of each sequence, used to pad and unpad between the two layouts.
    seq_lengths: Vec<usize>,
    /// Hidden size.
    hidden: usize,
    /// Attention head count.
    num_heads: usize,
    /// Width of one attention head.
    head_dim: usize,
    /// MLP intermediate size.
    intermediate: usize,
    /// Sliding-window size for local attention layers.
    local_window: usize,
    /// Scale factor handed to the softmax kernels.
    scale: f32,
    /// `LayerNorm` epsilon.
    eps: f32,
}
125
126// ---------------------------------------------------------------------------
127// Attention sublayer — pre-norm + QKV + RoPE
128// ---------------------------------------------------------------------------
129
130/// Pre-norm + QKV projection + split + `RoPE`.
131///
132/// Returns `(q, k, v)` each `[batch*num_heads, seq, head_dim]`.
133fn attn_prenorm_qkv<D: Driver>(
134    driver: &D,
135    hidden_states: &D::Tensor,
136    layer: &ModernBertLayerWeights<D::Tensor>,
137    g: &EncoderGeometry,
138    zero_bias: &D::Tensor,
139    rope: &RopeCache<D::Tensor>,
140) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
141    // Pre-attention norm (identity for layer 0).
142    let normed = if let Some(ref norm_w) = layer.attn_norm_weight {
143        let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
144        driver.layer_norm(
145            &mut n,
146            hidden_states,
147            norm_w,
148            zero_bias,
149            g.total_tokens,
150            g.hidden,
151            g.eps,
152        )?;
153        n
154    } else {
155        // Layer 0: identity -- just clone the input.
156        driver.clone_tensor(hidden_states, g.total_tokens * g.hidden)?
157    };
158
159    // QKV projection: [total_tokens, hidden] @ [3*hidden, hidden]^T
160    // Uses total_tokens (unpadded) — no wasted compute on padding.
161    let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
162    driver.gemm(
163        &normed,
164        &layer.qkv_weight,
165        &mut qkv,
166        g.total_tokens,
167        3 * g.hidden,
168        g.hidden,
169        true,
170    )?;
171
172    // Pad QKV from [total_tokens, 3H] to [batch*max_seq, 3H] for attention.
173    // qkv_split needs the padded batch×seq layout to reshape into per-head tensors.
174    let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
175    driver.pad_to_batch(
176        &qkv,
177        &mut qkv_padded,
178        &g.seq_lengths,
179        g.max_seq,
180        3 * g.hidden,
181    )?;
182
183    // Split into Q, K, V each [batch * num_heads, seq, head_dim].
184    let padded = g.padded_tokens;
185    let mut q = driver.alloc_zeros(padded * g.hidden)?;
186    let mut k = driver.alloc_zeros(padded * g.hidden)?;
187    let mut v = driver.alloc_zeros(padded * g.hidden)?;
188    driver.qkv_split(
189        &mut q,
190        &mut k,
191        &mut v,
192        &qkv_padded,
193        g.batch,
194        g.max_seq,
195        g.hidden,
196        g.num_heads,
197        g.head_dim,
198    )?;
199
200    // Apply RoPE to Q and K with appropriate theta.
201    let num_rows = g.batch * g.num_heads * g.max_seq;
202    driver.apply_rope(
203        &mut q,
204        &rope.cos,
205        &rope.sin,
206        num_rows,
207        g.max_seq,
208        g.head_dim,
209        g.num_heads,
210    )?;
211    driver.apply_rope(
212        &mut k,
213        &rope.cos,
214        &rope.sin,
215        num_rows,
216        g.max_seq,
217        g.head_dim,
218        g.num_heads,
219    )?;
220
221    Ok((q, k, v))
222}
223
224// ---------------------------------------------------------------------------
225// Attention sublayer — scores + output projection + residual
226// ---------------------------------------------------------------------------
227
228/// Attention scores + output projection + residual add.
229#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
230fn attn_scores_residual<D: Driver>(
231    driver: &D,
232    q: &D::Tensor,
233    k: &D::Tensor,
234    v: &D::Tensor,
235    hidden_states: &D::Tensor,
236    layer: &ModernBertLayerWeights<D::Tensor>,
237    inputs: &BatchInputs<D::Tensor>,
238    g: &EncoderGeometry,
239) -> crate::Result<D::Tensor> {
240    let batch_heads = g.batch * g.num_heads;
241    let stride_qk = g.max_seq * g.head_dim;
242
243    // Full Q@K^T for all layers. Local layers mask via windowed softmax.
244    //
245    // Banded attention kernels (banded_qk/sv) compute 4× fewer elements for
246    // local layers but are scalar — 35% slower than hardware simdgroup GEMM.
247    // The GEMM + windowed mask approach wastes compute on masked positions but
248    // Apple's hardware matmul throughput more than compensates at seq≤512.
249    // TODO: banded attention wins at seq>2048 where O(seq²) dominates.
250    let mut scores = driver.alloc_zeros(batch_heads * g.max_seq * g.max_seq)?;
251    driver.gemm_batched(
252        q,
253        k,
254        &mut scores,
255        g.max_seq,
256        g.max_seq,
257        g.head_dim,
258        true,
259        stride_qk,
260        stride_qk,
261        g.max_seq * g.max_seq,
262        batch_heads,
263    )?;
264
265    if layer.is_global {
266        driver.fused_scale_mask_softmax(
267            &mut scores,
268            &inputs.float_mask,
269            g.batch,
270            g.num_heads,
271            g.max_seq,
272            g.scale,
273        )?;
274    } else {
275        driver.fused_scale_mask_softmax_windowed(
276            &mut scores,
277            &inputs.float_mask,
278            g.batch,
279            g.num_heads,
280            g.max_seq,
281            g.scale,
282            g.local_window,
283        )?;
284    }
285
286    let mut attn_out = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
287    driver.gemm_batched(
288        &scores,
289        v,
290        &mut attn_out,
291        g.max_seq,
292        g.head_dim,
293        g.max_seq,
294        false,
295        g.max_seq * g.max_seq,
296        stride_qk,
297        stride_qk,
298        batch_heads,
299    )?;
300
301    // Reshape heads back to [padded_tokens, hidden] (still padded).
302    let mut context = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
303    driver.attn_reshape(
304        &mut context,
305        &attn_out,
306        g.batch,
307        g.max_seq,
308        g.num_heads,
309        g.head_dim,
310    )?;
311
312    // Unpad FIRST: [padded_tokens, H] → [total_tokens, H].
313    // Output projection is per-token — unpadding before GEMM is valid and
314    // avoids processing batch*max_seq rows when only total_tokens are real.
315    let mut context_unpacked = driver.alloc_zeros(g.total_tokens * g.hidden)?;
316    driver.unpad_from_batch(
317        &context,
318        &mut context_unpacked,
319        &g.seq_lengths,
320        g.max_seq,
321        g.hidden,
322    )?;
323
324    // Output projection on unpadded layout: [total_tokens, H] × [H, H].
325    let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
326    driver.gemm(
327        &context_unpacked,
328        &layer.output_weight,
329        &mut projected,
330        g.total_tokens,
331        g.hidden,
332        g.hidden,
333        true,
334    )?;
335
336    // Residual add (no bias in ModernBERT). Both are [total_tokens, H].
337    let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
338    driver.residual_add(
339        &mut output,
340        &projected,
341        hidden_states,
342        g.total_tokens * g.hidden,
343    )?;
344    Ok(output)
345}
346
347// ---------------------------------------------------------------------------
348// Feed-forward (GeGLU MLP) sublayer
349// ---------------------------------------------------------------------------
350
351/// Run the gated GELU MLP sublayer for one `ModernBERT` encoder layer.
352///
353/// Pre-MLP norm -> Wi projection -> split into value+gate -> `GeGLU` ->
354/// Wo projection -> residual add.
355fn ffn_sublayer<D: Driver>(
356    driver: &D,
357    attn_output: &D::Tensor,
358    layer: &ModernBertLayerWeights<D::Tensor>,
359    g: &EncoderGeometry,
360    zero_bias: &D::Tensor,
361) -> crate::Result<D::Tensor> {
362    // Pre-MLP LayerNorm.
363    let mut mlp_normed = driver.alloc_zeros(g.total_tokens * g.hidden)?;
364    driver.layer_norm(
365        &mut mlp_normed,
366        attn_output,
367        &layer.mlp_norm_weight,
368        zero_bias,
369        g.total_tokens,
370        g.hidden,
371        g.eps,
372    )?;
373
374    // Wi projection: [total_tokens, hidden] @ [2*inter, hidden]^T → [total_tokens, 2*inter].
375    let double_inter = 2 * g.intermediate;
376    let mut wi_out = driver.alloc_zeros(g.total_tokens * double_inter)?;
377    driver.gemm(
378        &mlp_normed,
379        &layer.mlp_wi_weight,
380        &mut wi_out,
381        g.total_tokens,
382        double_inter,
383        g.hidden,
384        true,
385    )?;
386
387    // Split Wi output into value [total_tokens, inter] and gate [total_tokens, inter].
388    let n_elements = g.total_tokens * g.intermediate;
389    let mut value = driver.alloc_zeros(n_elements)?;
390    let mut gate = driver.alloc_zeros(n_elements)?;
391    driver.split_gate_value(
392        &mut value,
393        &mut gate,
394        &wi_out,
395        g.total_tokens,
396        g.intermediate,
397    )?;
398
399    // GeGLU: output = gelu(value) * gate
400    let mut activated = driver.alloc_zeros(n_elements)?;
401    driver.geglu(&value, &gate, &mut activated, n_elements)?;
402
403    // Wo projection: [total_tokens, inter] @ [hidden, inter]^T => [total_tokens, hidden]
404    let mut mlp_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
405    driver.gemm(
406        &activated,
407        &layer.mlp_wo_weight,
408        &mut mlp_out,
409        g.total_tokens,
410        g.hidden,
411        g.intermediate,
412        true,
413    )?;
414
415    // Residual add (no bias).
416    let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
417    driver.residual_add(
418        &mut output,
419        &mlp_out,
420        attn_output,
421        g.total_tokens * g.hidden,
422    )?;
423    Ok(output)
424}
425
426// ---------------------------------------------------------------------------
427// FP16 attention sublayer — pre-norm + QKV + RoPE (all half precision)
428// ---------------------------------------------------------------------------
429
430/// FP16 pre-norm + QKV projection + split + `RoPE`.
431///
432/// All tensors are half precision. RoPE cos/sin tables stay FP32 (the kernel
433/// reads half Q/K, does FP32 trig, writes half).
434/// Returns `(q, k, v)` each `[batch*num_heads, seq, head_dim]` in FP16.
435fn attn_prenorm_qkv_f16<D: Driver>(
436    driver: &D,
437    hidden_states: &D::Tensor,
438    layer: &ModernBertLayerWeights<D::Tensor>,
439    g: &EncoderGeometry,
440    zero_bias: &D::Tensor,
441    rope: &RopeCache<D::Tensor>,
442) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
443    // Pre-attention norm (identity for layer 0). FP16 in/out.
444    // Layer 0 uses hidden_states directly (GEMM is read-only, no clone needed).
445    let normed: Option<D::Tensor>;
446    let normed_ref = if let Some(ref norm_w) = layer.attn_norm_weight {
447        let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
448        driver.layer_norm_f16(
449            &mut n,
450            hidden_states,
451            norm_w,
452            zero_bias,
453            g.total_tokens,
454            g.hidden,
455            g.eps,
456        )?;
457        normed = Some(n);
458        normed.as_ref().unwrap()
459    } else {
460        // Layer 0: identity — pass through directly. GEMM reads, does not modify.
461        hidden_states
462    };
463
464    // QKV: [total_tokens, hidden] @ [3*hidden, hidden]^T — all FP16.
465    let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
466    driver.gemm_f16(
467        normed_ref,
468        &layer.qkv_weight,
469        &mut qkv,
470        g.total_tokens,
471        3 * g.hidden,
472        g.hidden,
473        true,
474    )?;
475
476    // Fused pad + QKV split: flat → Q, K, V in per-head layout directly.
477    // Eliminates the padded intermediate buffer and its 2 memory round-trips.
478    let padded = g.padded_tokens;
479    let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
480    let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
481    let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
482    driver.fused_pad_qkv_split_f16(
483        &mut q,
484        &mut k,
485        &mut v,
486        &qkv,
487        &g.seq_lengths,
488        g.max_seq,
489        g.batch,
490        g.hidden,
491        g.num_heads,
492        g.head_dim,
493    )?;
494
495    // RoPE: half Q/K, float cos/sin tables.
496    let num_rows = g.batch * g.num_heads * g.max_seq;
497    driver.rope_encode_f16(
498        &mut q,
499        &rope.cos,
500        &rope.sin,
501        num_rows,
502        g.max_seq,
503        g.head_dim,
504        g.num_heads,
505    )?;
506    driver.rope_encode_f16(
507        &mut k,
508        &rope.cos,
509        &rope.sin,
510        num_rows,
511        g.max_seq,
512        g.head_dim,
513        g.num_heads,
514    )?;
515
516    Ok((q, k, v))
517}
518
519// ---------------------------------------------------------------------------
520// FP16 attention sublayer — scores + output projection + residual
521// ---------------------------------------------------------------------------
522
523/// FP16 attention scores + output projection + residual add.
524///
525/// All tensors FP16. The softmax kernel uses FP32 accumulators internally.
526/// The `float_mask` from `BatchInputs` stays FP32 (softmax kernel reads it).
527#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
528fn attn_scores_residual_f16<D: Driver>(
529    driver: &D,
530    q: &D::Tensor,
531    k: &D::Tensor,
532    v: &D::Tensor,
533    hidden_states: &D::Tensor,
534    layer: &ModernBertLayerWeights<D::Tensor>,
535    inputs: &BatchInputs<D::Tensor>,
536    g: &EncoderGeometry,
537) -> crate::Result<D::Tensor> {
538    let batch_heads = g.batch * g.num_heads;
539    let stride_qk = g.max_seq * g.head_dim;
540
541    // Q@K^T — FP16 batched GEMM.
542    let mut scores = driver.alloc_zeros_f16(batch_heads * g.max_seq * g.max_seq)?;
543    driver.gemm_batched_f16(
544        q,
545        k,
546        &mut scores,
547        g.max_seq,
548        g.max_seq,
549        g.head_dim,
550        true,
551        stride_qk,
552        stride_qk,
553        g.max_seq * g.max_seq,
554        batch_heads,
555    )?;
556
557    // Softmax — FP16 scores, FP32 mask, FP32 accumulators inside kernel.
558    if layer.is_global {
559        driver.fused_scale_mask_softmax_f16(
560            &mut scores,
561            &inputs.float_mask,
562            g.batch,
563            g.num_heads,
564            g.max_seq,
565            g.scale,
566        )?;
567    } else {
568        driver.fused_scale_mask_softmax_windowed_f16(
569            &mut scores,
570            &inputs.float_mask,
571            g.batch,
572            g.num_heads,
573            g.max_seq,
574            g.scale,
575            g.local_window,
576        )?;
577    }
578
579    // scores @ V — FP16 batched GEMM.
580    let mut attn_out = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
581    driver.gemm_batched_f16(
582        &scores,
583        v,
584        &mut attn_out,
585        g.max_seq,
586        g.head_dim,
587        g.max_seq,
588        false,
589        g.max_seq * g.max_seq,
590        stride_qk,
591        stride_qk,
592        batch_heads,
593    )?;
594
595    // Fused reshape + unpad: [batch*heads, max_seq, head_dim] → [total_tokens, hidden].
596    // Eliminates the padded context intermediate buffer.
597    let mut context_unpacked = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
598    driver.fused_reshape_unpad_f16(
599        &mut context_unpacked,
600        &attn_out,
601        &g.seq_lengths,
602        g.max_seq,
603        g.batch,
604        g.num_heads,
605        g.head_dim,
606    )?;
607
608    // Output projection on unpadded — FP16: [total_tokens, H] × [H, H].
609    let mut projected = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
610    driver.gemm_f16(
611        &context_unpacked,
612        &layer.output_weight,
613        &mut projected,
614        g.total_tokens,
615        g.hidden,
616        g.hidden,
617        true,
618    )?;
619
620    // Residual add — FP16.
621    let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
622    driver.residual_add_f16(
623        &mut output,
624        &projected,
625        hidden_states,
626        g.total_tokens * g.hidden,
627    )?;
628    Ok(output)
629}
630
631// ---------------------------------------------------------------------------
632// FP16 feed-forward (GeGLU MLP) sublayer
633// ---------------------------------------------------------------------------
634
635/// FP16 gated GELU MLP sublayer.
636///
637/// All tensors FP16. `GeGLU` kernel uses FP32 GELU compute internally.
638fn ffn_sublayer_f16<D: Driver>(
639    driver: &D,
640    attn_output: &D::Tensor,
641    layer: &ModernBertLayerWeights<D::Tensor>,
642    g: &EncoderGeometry,
643    zero_bias: &D::Tensor,
644) -> crate::Result<D::Tensor> {
645    // Pre-MLP LayerNorm — FP16.
646    let mut mlp_normed = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
647    driver.layer_norm_f16(
648        &mut mlp_normed,
649        attn_output,
650        &layer.mlp_norm_weight,
651        zero_bias,
652        g.total_tokens,
653        g.hidden,
654        g.eps,
655    )?;
656
657    // Wi projection — FP16 GEMM.
658    let double_inter = 2 * g.intermediate;
659    let mut wi_out = driver.alloc_zeros_f16(g.total_tokens * double_inter)?;
660    driver.gemm_f16(
661        &mlp_normed,
662        &layer.mlp_wi_weight,
663        &mut wi_out,
664        g.total_tokens,
665        double_inter,
666        g.hidden,
667        true,
668    )?;
669
670    // Fused split + GeGLU — FP16.
671    // Reads [total_tokens, 2*intermediate], writes [total_tokens, intermediate].
672    // Eliminates two intermediate buffers and halves HBM bandwidth.
673    let n_elements = g.total_tokens * g.intermediate;
674    let mut activated = driver.alloc_zeros_f16(n_elements)?;
675    driver.fused_split_geglu_f16(&mut activated, &wi_out, g.total_tokens, g.intermediate)?;
676
677    // Wo projection — FP16 GEMM.
678    let mut mlp_out = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
679    driver.gemm_f16(
680        &activated,
681        &layer.mlp_wo_weight,
682        &mut mlp_out,
683        g.total_tokens,
684        g.hidden,
685        g.intermediate,
686        true,
687    )?;
688
689    // Residual add — FP16.
690    let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
691    driver.residual_add_f16(
692        &mut output,
693        &mlp_out,
694        attn_output,
695        g.total_tokens * g.hidden,
696    )?;
697    Ok(output)
698}
699
700// ---------------------------------------------------------------------------
701// ModelArch implementation
702// ---------------------------------------------------------------------------
703
704impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
705    #[expect(
706        clippy::cast_precision_loss,
707        reason = "head_dim is small (64); sqrt is exact at this size"
708    )]
709    #[expect(
710        clippy::many_single_char_names,
711        reason = "w, g are standard geometry names; q, k, v are standard attention names"
712    )]
713    #[expect(
714        clippy::too_many_lines,
715        reason = "forward pass is a single logical unit"
716    )]
717    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
718        let w = &self.weights;
719        let batch = encodings.len();
720        let hidden = w.hidden_dim;
721
722        let inputs = driver.prepare_batch_unpadded(encodings)?;
723        let max_seq = inputs.max_seq;
724        let total_tokens = inputs.total_tokens;
725
726        // Enter batched mode: all GPU ops encode into ONE command buffer.
727        driver.begin_batch()?;
728
729        // Embedding (FP32): tok_embeddings + LayerNorm.
730        let mut hidden_states =
731            driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
732        let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
733        driver.layer_norm(
734            &mut hidden_states,
735            &emb_input,
736            &w.emb_norm_weight,
737            &w.zero_bias,
738            total_tokens,
739            hidden,
740            w.layer_norm_eps,
741        )?;
742
743        let g = EncoderGeometry {
744            batch,
745            max_seq,
746            total_tokens,
747            padded_tokens: batch * max_seq,
748            seq_lengths: inputs.seq_lengths.clone(),
749            hidden,
750            num_heads: w.num_heads,
751            head_dim: w.head_dim,
752            intermediate: w.intermediate_dim,
753            local_window: w.local_window,
754            scale: 1.0 / (w.head_dim as f32).sqrt(),
755            eps: w.layer_norm_eps,
756        };
757
758        // FP16 path: f32_to_f16 ONCE → all layers in FP16 → f16_to_f32 ONCE.
759        // Falls back to FP32 if the driver doesn't support FP16 ops.
760        //
761        // MPS FP16 GEMM uses Apple's proprietary AMX coprocessor (72/s).
762        // RIPVEC_NO_MPS=1: force FP32 activations + compute GEMM path.
763        // The gemm_f16w_f32a_kernel uses native simdgroup ops with FP16 weights
764        // and FP32 activations — no MFA wrapper, no type conversion at store.
765        let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
766            || std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
767        let use_f16 = if force_fp32 {
768            false
769        } else {
770            driver.alloc_zeros_f16(1).map(|_| true).unwrap_or(false)
771        };
772
773        if use_f16 {
774            // === FP16 PATH: zero F32↔F16 conversions in layer loop ===
775
776            // ONLY conversion #1: F32 → F16 after embedding LN.
777            let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
778            driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
779
780            // 22 layers — ALL in FP16.
781            for layer in &w.layers {
782                let saved = driver.save_pool_cursor();
783
784                let rope = if layer.is_global {
785                    &self.global_rope
786                } else {
787                    &self.local_rope
788                };
789
790                let (q, k, v) =
791                    attn_prenorm_qkv_f16(driver, &hidden_f16, layer, &g, &w.zero_bias, rope)?;
792                let attn_output =
793                    attn_scores_residual_f16(driver, &q, &k, &v, &hidden_f16, layer, &inputs, &g)?;
794                hidden_f16 = ffn_sublayer_f16(driver, &attn_output, layer, &g, &w.zero_bias)?;
795                driver.restore_pool_cursor(saved);
796            }
797
798            // ONLY conversion #2: F16 → F32 before final LN + pooling.
799            let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
800            driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
801            hidden_states = hidden_f32;
802        } else {
803            // === FP32 PATH ===
804            for layer in &w.layers {
805                let saved = driver.save_pool_cursor();
806
807                let rope = if layer.is_global {
808                    &self.global_rope
809                } else {
810                    &self.local_rope
811                };
812
813                let (q, k, v) =
814                    attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
815                let attn_output =
816                    attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
817                hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
818
819                driver.restore_pool_cursor(saved);
820            }
821        }
822
823        // Final LayerNorm (FP32) before pooling.
824        let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
825        driver.layer_norm(
826            &mut hidden_states,
827            &final_input,
828            &w.final_norm_weight,
829            &w.zero_bias,
830            total_tokens,
831            hidden,
832            w.layer_norm_eps,
833        )?;
834
835        // Pad back to [batch, max_seq, hidden] for mean_pool kernel.
836        let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
837        driver.pad_to_batch(
838            &hidden_states,
839            &mut padded_for_pool,
840            &inputs.seq_lengths,
841            max_seq,
842            hidden,
843        )?;
844
845        // Mean pooling + L2 normalize (FP32).
846        let mut pooled = driver.alloc_zeros(batch * hidden)?;
847        driver.mean_pool(
848            &mut pooled,
849            &padded_for_pool,
850            &inputs.pooling_mask,
851            batch,
852            max_seq,
853            hidden,
854        )?;
855        driver.l2_normalize(&mut pooled, batch, hidden)?;
856
857        // End batched mode — commit all GPU work, wait for completion.
858        driver.end_batch()?;
859
860        driver.to_host(&pooled, batch, hidden)
861    }
862}