Skip to main content

ripvec_core/backend/arch/
modern_bert.rs

1//! `ModernBERT` architecture (`nomic-ai/modernbert-embed-base`).
2//!
3//! 22-layer transformer with alternating local/global attention, gated GELU
4//! (`GeGLU`) MLP, two `RoPE` frequency caches, and pre-norm layer structure.
5//! No biases anywhere, no position embeddings (`RoPE` only), mean pooling.
6//!
7//! Weight structures are generic over the tensor type `T`, which is
8//! [`Driver::Tensor`] when wired to a
9//! backend. The [`ModelArch`] implementation composes
10//! [`Driver`] primitives into the full forward
11//! pass.
12
13use super::super::Encoding;
14use super::super::driver::{BatchInputs, Driver};
15use super::ModelArch;
16
17// ---------------------------------------------------------------------------
18// Weight structures
19// ---------------------------------------------------------------------------
20
/// Parameter set for a single `ModernBERT` encoder block.
///
/// None of the projections carry a bias. Q, K and V share one fused
/// `[3*hidden, hidden]` matrix, and the MLP input matrix `Wi` is
/// `[2*intermediate, hidden]`, splitting into the value and gate halves
/// consumed by the `GeGLU` activation.
pub struct ModernBertLayerWeights<T> {
    /// Fused query/key/value projection, `[3*hidden, hidden]`, bias-free.
    pub qkv_weight: T,
    /// Projection applied to the attention output, `[hidden, hidden]`,
    /// bias-free.
    pub output_weight: T,
    /// `LayerNorm` scale `[hidden]` applied ahead of attention; `None` on
    /// layer 0, where the pre-attention norm is the identity.
    pub attn_norm_weight: Option<T>,
    /// Gated MLP input matrix `[2*intermediate, hidden]`; its output splits
    /// into `value[intermediate]` and `gate[intermediate]` for `GeGLU`.
    pub mlp_wi_weight: T,
    /// MLP output projection, `[hidden, intermediate]`, bias-free.
    pub mlp_wo_weight: T,
    /// `LayerNorm` scale `[hidden]` applied ahead of the MLP.
    pub mlp_norm_weight: T,
    /// `true` selects global (full) attention, `false` selects local
    /// (sliding-window) attention.
    pub is_global: bool,
}
43
44/// Full `ModernBERT` model weights, generic over tensor type.
45///
46/// Includes embedding table, per-layer encoder weights, final norm, and model
47/// geometry. The tensor type `T` becomes
48/// [`Driver::Tensor`] when loaded onto a
49/// specific backend.
50pub struct ModernBertWeights<T> {
51    /// Word embedding table `[vocab_size, hidden]`.
52    pub tok_embeddings: T,
53    /// Post-embedding `LayerNorm` weight `[hidden]` (no bias).
54    pub emb_norm_weight: T,
55    /// Final `LayerNorm` weight `[hidden]` applied before pooling (no bias).
56    pub final_norm_weight: T,
57    /// A zero-filled tensor `[hidden]` used as dummy bias for `LayerNorm` calls.
58    ///
59    /// The [`Driver::layer_norm`] API requires a bias tensor; `ModernBERT` has
60    /// none, so we pass this zero buffer instead.
61    pub zero_bias: T,
62    /// Per-layer encoder weights.
63    pub layers: Vec<ModernBertLayerWeights<T>>,
64    /// Number of attention heads (12 for modernbert-embed-base).
65    pub num_heads: usize,
66    /// Dimension per attention head (`hidden / num_heads`, 64).
67    pub head_dim: usize,
68    /// Hidden dimension (768 for modernbert-embed-base).
69    pub hidden_dim: usize,
70    /// MLP intermediate dimension (1152 for modernbert-embed-base).
71    pub intermediate_dim: usize,
72    /// Layer normalization epsilon (1e-5).
73    pub layer_norm_eps: f32,
74    /// Sliding window size for local attention layers (128).
75    pub local_window: usize,
76}
77
/// Cached `RoPE` cosine/sine tables for a single frequency base.
///
/// `ModernBERT` keeps two of these: global layers use theta=160000 and local
/// layers use theta=10000.
pub struct RopeCache<T> {
    /// Cosine values, `[max_seq, head_dim/2]`.
    pub cos: T,
    /// Sine values, `[max_seq, head_dim/2]`.
    pub sin: T,
}
88
89/// `ModernBERT` architecture: `nomic-ai/modernbert-embed-base`.
90///
91/// 22 layers, alternating local/global attention, `GeGLU` MLP, `RoPE` (two
92/// theta values), no biases, mean pooling. Composes [`Driver`] primitives into
93/// the full forward pass.
94pub struct ModernBertArch<T> {
95    /// Model weights on device.
96    pub weights: ModernBertWeights<T>,
97    /// `RoPE` cos/sin cache for global attention layers (theta=160000).
98    pub global_rope: RopeCache<T>,
99    /// `RoPE` cos/sin cache for local attention layers (theta=10000).
100    pub local_rope: RopeCache<T>,
101}
102
103// ---------------------------------------------------------------------------
104// Encoder geometry
105// ---------------------------------------------------------------------------
106
/// Shape and scalar parameters shared by every sublayer helper, bundled so
/// the helpers do not each take a dozen separate arguments.
struct EncoderGeometry {
    /// Number of sequences in the batch.
    batch: usize,
    /// Per-sequence length of the padded layout.
    max_seq: usize,
    /// Real token count summed over all sequences (padding excluded); the
    /// row count for the linear ops.
    total_tokens: usize,
    /// `batch * max_seq`; the row count of the padded attention layout.
    padded_tokens: usize,
    /// Length of each sequence, consumed by the pad/unpad kernels.
    seq_lengths: Vec<usize>,
    /// Hidden width.
    hidden: usize,
    /// Attention head count.
    num_heads: usize,
    /// Width of one attention head.
    head_dim: usize,
    /// MLP intermediate width.
    intermediate: usize,
    /// Sliding-window size handed to the windowed softmax on local layers.
    local_window: usize,
    /// Scale factor handed to the fused scale+mask+softmax kernels.
    scale: f32,
    /// Epsilon handed to every layer normalization.
    eps: f32,
}
125
126// ---------------------------------------------------------------------------
127// Attention sublayer — pre-norm + QKV + RoPE
128// ---------------------------------------------------------------------------
129
130/// Pre-norm + QKV projection + split + `RoPE`.
131///
132/// Returns `(q, k, v)` each `[batch*num_heads, seq, head_dim]`.
133fn attn_prenorm_qkv<D: Driver>(
134    driver: &D,
135    hidden_states: &D::Tensor,
136    layer: &ModernBertLayerWeights<D::Tensor>,
137    g: &EncoderGeometry,
138    zero_bias: &D::Tensor,
139    rope: &RopeCache<D::Tensor>,
140) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
141    // Pre-attention norm (identity for layer 0).
142    let normed = if let Some(ref norm_w) = layer.attn_norm_weight {
143        let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
144        driver.layer_norm(
145            &mut n,
146            hidden_states,
147            norm_w,
148            zero_bias,
149            g.total_tokens,
150            g.hidden,
151            g.eps,
152        )?;
153        n
154    } else {
155        // Layer 0: identity -- just clone the input.
156        driver.clone_tensor(hidden_states, g.total_tokens * g.hidden)?
157    };
158
159    // QKV projection: [total_tokens, hidden] @ [3*hidden, hidden]^T
160    // Uses total_tokens (unpadded) — no wasted compute on padding.
161    let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
162    driver.gemm(
163        &normed,
164        &layer.qkv_weight,
165        &mut qkv,
166        g.total_tokens,
167        3 * g.hidden,
168        g.hidden,
169        true,
170    )?;
171
172    // Pad QKV from [total_tokens, 3H] to [batch*max_seq, 3H] for attention.
173    // qkv_split needs the padded batch×seq layout to reshape into per-head tensors.
174    let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
175    driver.pad_to_batch(
176        &qkv,
177        &mut qkv_padded,
178        &g.seq_lengths,
179        g.max_seq,
180        3 * g.hidden,
181    )?;
182
183    // Split into Q, K, V each [batch * num_heads, seq, head_dim].
184    let padded = g.padded_tokens;
185    let mut q = driver.alloc_zeros(padded * g.hidden)?;
186    let mut k = driver.alloc_zeros(padded * g.hidden)?;
187    let mut v = driver.alloc_zeros(padded * g.hidden)?;
188    driver.qkv_split(
189        &mut q,
190        &mut k,
191        &mut v,
192        &qkv_padded,
193        g.batch,
194        g.max_seq,
195        g.hidden,
196        g.num_heads,
197        g.head_dim,
198    )?;
199
200    // Apply RoPE to Q and K with appropriate theta.
201    let num_rows = g.batch * g.num_heads * g.max_seq;
202    driver.apply_rope(
203        &mut q,
204        &rope.cos,
205        &rope.sin,
206        num_rows,
207        g.max_seq,
208        g.head_dim,
209        g.num_heads,
210    )?;
211    driver.apply_rope(
212        &mut k,
213        &rope.cos,
214        &rope.sin,
215        num_rows,
216        g.max_seq,
217        g.head_dim,
218        g.num_heads,
219    )?;
220
221    Ok((q, k, v))
222}
223
224// ---------------------------------------------------------------------------
225// Attention sublayer — scores + output projection + residual
226// ---------------------------------------------------------------------------
227
228/// Attention scores + output projection + residual add.
229#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
230fn attn_scores_residual<D: Driver>(
231    driver: &D,
232    q: &D::Tensor,
233    k: &D::Tensor,
234    v: &D::Tensor,
235    hidden_states: &D::Tensor,
236    layer: &ModernBertLayerWeights<D::Tensor>,
237    inputs: &BatchInputs<D::Tensor>,
238    g: &EncoderGeometry,
239) -> crate::Result<D::Tensor> {
240    let batch_heads = g.batch * g.num_heads;
241    let stride_qk = g.max_seq * g.head_dim;
242
243    // Full Q@K^T for all layers. Local layers mask via windowed softmax.
244    //
245    // Banded attention kernels (banded_qk/sv) compute 4× fewer elements for
246    // local layers but are scalar — 35% slower than hardware simdgroup GEMM.
247    // The GEMM + windowed mask approach wastes compute on masked positions but
248    // Apple's hardware matmul throughput more than compensates at seq≤512.
249    // TODO: banded attention wins at seq>2048 where O(seq²) dominates.
250    let mut scores = driver.alloc_zeros(batch_heads * g.max_seq * g.max_seq)?;
251    driver.gemm_batched(
252        q,
253        k,
254        &mut scores,
255        g.max_seq,
256        g.max_seq,
257        g.head_dim,
258        true,
259        stride_qk,
260        stride_qk,
261        g.max_seq * g.max_seq,
262        batch_heads,
263    )?;
264
265    if layer.is_global {
266        driver.fused_scale_mask_softmax(
267            &mut scores,
268            &inputs.float_mask,
269            g.batch,
270            g.num_heads,
271            g.max_seq,
272            g.scale,
273        )?;
274    } else {
275        driver.fused_scale_mask_softmax_windowed(
276            &mut scores,
277            &inputs.float_mask,
278            g.batch,
279            g.num_heads,
280            g.max_seq,
281            g.scale,
282            g.local_window,
283        )?;
284    }
285
286    let mut attn_out = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
287    driver.gemm_batched(
288        &scores,
289        v,
290        &mut attn_out,
291        g.max_seq,
292        g.head_dim,
293        g.max_seq,
294        false,
295        g.max_seq * g.max_seq,
296        stride_qk,
297        stride_qk,
298        batch_heads,
299    )?;
300
301    // Reshape heads back to [padded_tokens, hidden] (still padded).
302    let mut context = driver.alloc_zeros(g.padded_tokens * g.hidden)?;
303    driver.attn_reshape(
304        &mut context,
305        &attn_out,
306        g.batch,
307        g.max_seq,
308        g.num_heads,
309        g.head_dim,
310    )?;
311
312    // Unpad FIRST: [padded_tokens, H] → [total_tokens, H].
313    // Output projection is per-token — unpadding before GEMM is valid and
314    // avoids processing batch*max_seq rows when only total_tokens are real.
315    let mut context_unpacked = driver.alloc_zeros(g.total_tokens * g.hidden)?;
316    driver.unpad_from_batch(
317        &context,
318        &mut context_unpacked,
319        &g.seq_lengths,
320        g.max_seq,
321        g.hidden,
322    )?;
323
324    // Output projection on unpadded layout: [total_tokens, H] × [H, H].
325    let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
326    driver.gemm(
327        &context_unpacked,
328        &layer.output_weight,
329        &mut projected,
330        g.total_tokens,
331        g.hidden,
332        g.hidden,
333        true,
334    )?;
335
336    // Residual add (no bias in ModernBERT). Both are [total_tokens, H].
337    let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
338    driver.residual_add(
339        &mut output,
340        &projected,
341        hidden_states,
342        g.total_tokens * g.hidden,
343    )?;
344    Ok(output)
345}
346
347// ---------------------------------------------------------------------------
348// Feed-forward (GeGLU MLP) sublayer
349// ---------------------------------------------------------------------------
350
351/// Run the gated GELU MLP sublayer for one `ModernBERT` encoder layer.
352///
353/// Pre-MLP norm -> Wi projection -> split into value+gate -> `GeGLU` ->
354/// Wo projection -> residual add.
355fn ffn_sublayer<D: Driver>(
356    driver: &D,
357    attn_output: &D::Tensor,
358    layer: &ModernBertLayerWeights<D::Tensor>,
359    g: &EncoderGeometry,
360    zero_bias: &D::Tensor,
361) -> crate::Result<D::Tensor> {
362    // Pre-MLP LayerNorm.
363    let mut mlp_normed = driver.alloc_zeros(g.total_tokens * g.hidden)?;
364    driver.layer_norm(
365        &mut mlp_normed,
366        attn_output,
367        &layer.mlp_norm_weight,
368        zero_bias,
369        g.total_tokens,
370        g.hidden,
371        g.eps,
372    )?;
373
374    // Wi projection: [total_tokens, hidden] @ [2*inter, hidden]^T → [total_tokens, 2*inter].
375    let double_inter = 2 * g.intermediate;
376    let mut wi_out = driver.alloc_zeros(g.total_tokens * double_inter)?;
377    driver.gemm(
378        &mlp_normed,
379        &layer.mlp_wi_weight,
380        &mut wi_out,
381        g.total_tokens,
382        double_inter,
383        g.hidden,
384        true,
385    )?;
386
387    // Split Wi output into value [total_tokens, inter] and gate [total_tokens, inter].
388    let n_elements = g.total_tokens * g.intermediate;
389    let mut value = driver.alloc_zeros(n_elements)?;
390    let mut gate = driver.alloc_zeros(n_elements)?;
391    driver.split_gate_value(
392        &mut value,
393        &mut gate,
394        &wi_out,
395        g.total_tokens,
396        g.intermediate,
397    )?;
398
399    // GeGLU: output = gelu(value) * gate
400    let mut activated = driver.alloc_zeros(n_elements)?;
401    driver.geglu(&value, &gate, &mut activated, n_elements)?;
402
403    // Wo projection: [total_tokens, inter] @ [hidden, inter]^T => [total_tokens, hidden]
404    let mut mlp_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
405    driver.gemm(
406        &activated,
407        &layer.mlp_wo_weight,
408        &mut mlp_out,
409        g.total_tokens,
410        g.hidden,
411        g.intermediate,
412        true,
413    )?;
414
415    // Residual add (no bias).
416    let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
417    driver.residual_add(
418        &mut output,
419        &mlp_out,
420        attn_output,
421        g.total_tokens * g.hidden,
422    )?;
423    Ok(output)
424}
425
426// ---------------------------------------------------------------------------
427// FP16 attention sublayer — pre-norm + QKV + RoPE (all half precision)
428// ---------------------------------------------------------------------------
429
430fn debug_f16_tensor<D: Driver>(
431    driver: &D,
432    label: &str,
433    tensor: &D::Tensor,
434    rows: usize,
435    cols: usize,
436) -> crate::Result<()> {
437    let mut probe = driver.alloc_zeros(rows * cols)?;
438    driver.f16_to_f32(&mut probe, tensor, rows * cols)?;
439    driver.debug_tensor(label, &probe, rows, cols)
440}
441
442/// FP16 pre-norm + QKV projection + split + `RoPE`.
443///
444/// All tensors are half precision. RoPE cos/sin tables stay FP32 (the kernel
445/// reads half Q/K, does FP32 trig, writes half).
446/// Returns `(q, k, v)` each `[batch*num_heads, seq, head_dim]` in FP16.
447#[expect(
448    clippy::too_many_lines,
449    reason = "attention diagnostics intentionally keep stage probes adjacent to the operations"
450)]
451fn attn_prenorm_qkv_f16<D: Driver>(
452    driver: &D,
453    hidden_states: &D::Tensor,
454    layer: &ModernBertLayerWeights<D::Tensor>,
455    g: &EncoderGeometry,
456    zero_bias: &D::Tensor,
457    rope: &RopeCache<D::Tensor>,
458    layer_index: usize,
459    debug_tensors: bool,
460) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
461    // Pre-attention norm (identity for layer 0). FP16 in/out.
462    // Layer 0 uses hidden_states directly (GEMM is read-only, no clone needed).
463    let normed: Option<D::Tensor>;
464    let normed_ref = if let Some(ref norm_w) = layer.attn_norm_weight {
465        let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
466        driver.layer_norm_f16(
467            &mut n,
468            hidden_states,
469            norm_w,
470            zero_bias,
471            g.total_tokens,
472            g.hidden,
473            g.eps,
474        )?;
475        normed = Some(n);
476        normed.as_ref().unwrap()
477    } else {
478        // Layer 0: identity — pass through directly. GEMM reads, does not modify.
479        hidden_states
480    };
481
482    // QKV: [total_tokens, hidden] @ [3*hidden, hidden]^T — all FP16.
483    let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
484    driver.gemm_f16(
485        normed_ref,
486        &layer.qkv_weight,
487        &mut qkv,
488        g.total_tokens,
489        3 * g.hidden,
490        g.hidden,
491        true,
492    )?;
493    if debug_tensors && layer_index == 0 {
494        debug_f16_tensor(
495            driver,
496            "modernbert.layer_0.qkv_f16_as_f32",
497            &qkv,
498            g.total_tokens,
499            3 * g.hidden,
500        )?;
501    }
502
503    // Fused pad + QKV split: flat → Q, K, V in per-head layout directly.
504    // Eliminates the padded intermediate buffer and its 2 memory round-trips.
505    let padded = g.padded_tokens;
506    let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
507    let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
508    let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
509    driver.fused_pad_qkv_split_f16(
510        &mut q,
511        &mut k,
512        &mut v,
513        &qkv,
514        &g.seq_lengths,
515        g.max_seq,
516        g.batch,
517        g.hidden,
518        g.num_heads,
519        g.head_dim,
520    )?;
521    if debug_tensors && layer_index == 0 {
522        let rows = g.batch * g.num_heads * g.max_seq;
523        debug_f16_tensor(
524            driver,
525            "modernbert.layer_0.q_after_split_f16_as_f32",
526            &q,
527            rows,
528            g.head_dim,
529        )?;
530        debug_f16_tensor(
531            driver,
532            "modernbert.layer_0.k_after_split_f16_as_f32",
533            &k,
534            rows,
535            g.head_dim,
536        )?;
537        debug_f16_tensor(
538            driver,
539            "modernbert.layer_0.v_after_split_f16_as_f32",
540            &v,
541            rows,
542            g.head_dim,
543        )?;
544    }
545
546    // RoPE: half Q/K, float cos/sin tables.
547    let num_rows = g.batch * g.num_heads * g.max_seq;
548    driver.rope_encode_f16(
549        &mut q,
550        &rope.cos,
551        &rope.sin,
552        num_rows,
553        g.max_seq,
554        g.head_dim,
555        g.num_heads,
556    )?;
557    driver.rope_encode_f16(
558        &mut k,
559        &rope.cos,
560        &rope.sin,
561        num_rows,
562        g.max_seq,
563        g.head_dim,
564        g.num_heads,
565    )?;
566    if debug_tensors && layer_index == 0 {
567        let rows = g.batch * g.num_heads * g.max_seq;
568        debug_f16_tensor(
569            driver,
570            "modernbert.layer_0.q_after_rope_f16_as_f32",
571            &q,
572            rows,
573            g.head_dim,
574        )?;
575        debug_f16_tensor(
576            driver,
577            "modernbert.layer_0.k_after_rope_f16_as_f32",
578            &k,
579            rows,
580            g.head_dim,
581        )?;
582    }
583
584    Ok((q, k, v))
585}
586
587// ---------------------------------------------------------------------------
588// FP16 attention sublayer — scores + output projection + residual
589// ---------------------------------------------------------------------------
590
591/// FP16 attention scores + output projection + residual add.
592///
593/// All tensors FP16. The softmax kernel uses FP32 accumulators internally.
594/// The `float_mask` from `BatchInputs` stays FP32 (softmax kernel reads it).
595#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
596#[expect(
597    clippy::too_many_lines,
598    reason = "attention diagnostics intentionally keep stage probes adjacent to the operations"
599)]
600fn attn_scores_residual_f16<D: Driver>(
601    driver: &D,
602    q: &D::Tensor,
603    k: &D::Tensor,
604    v: &D::Tensor,
605    hidden_states: &D::Tensor,
606    layer: &ModernBertLayerWeights<D::Tensor>,
607    inputs: &BatchInputs<D::Tensor>,
608    g: &EncoderGeometry,
609    layer_index: usize,
610    debug_tensors: bool,
611) -> crate::Result<D::Tensor> {
612    let batch_heads = g.batch * g.num_heads;
613    let stride_qk = g.max_seq * g.head_dim;
614
615    // Q@K^T — FP16 batched GEMM.
616    let mut scores = driver.alloc_zeros_f16(batch_heads * g.max_seq * g.max_seq)?;
617    driver.gemm_batched_f16(
618        q,
619        k,
620        &mut scores,
621        g.max_seq,
622        g.max_seq,
623        g.head_dim,
624        true,
625        stride_qk,
626        stride_qk,
627        g.max_seq * g.max_seq,
628        batch_heads,
629    )?;
630    if debug_tensors && layer_index == 0 {
631        debug_f16_tensor(
632            driver,
633            "modernbert.layer_0.attn_scores_before_softmax_f16_as_f32",
634            &scores,
635            batch_heads * g.max_seq,
636            g.max_seq,
637        )?;
638    }
639
640    // Softmax — FP16 scores, FP32 mask, FP32 accumulators inside kernel.
641    if layer.is_global {
642        driver.fused_scale_mask_softmax_f16(
643            &mut scores,
644            &inputs.float_mask,
645            g.batch,
646            g.num_heads,
647            g.max_seq,
648            g.scale,
649        )?;
650    } else {
651        driver.fused_scale_mask_softmax_windowed_f16(
652            &mut scores,
653            &inputs.float_mask,
654            g.batch,
655            g.num_heads,
656            g.max_seq,
657            g.scale,
658            g.local_window,
659        )?;
660    }
661    if debug_tensors && layer_index == 0 {
662        debug_f16_tensor(
663            driver,
664            "modernbert.layer_0.attn_scores_after_softmax_f16_as_f32",
665            &scores,
666            batch_heads * g.max_seq,
667            g.max_seq,
668        )?;
669    }
670
671    // scores @ V — FP16 batched GEMM.
672    let mut attn_out = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
673    driver.gemm_batched_f16(
674        &scores,
675        v,
676        &mut attn_out,
677        g.max_seq,
678        g.head_dim,
679        g.max_seq,
680        false,
681        g.max_seq * g.max_seq,
682        stride_qk,
683        stride_qk,
684        batch_heads,
685    )?;
686    if debug_tensors && layer_index == 0 {
687        debug_f16_tensor(
688            driver,
689            "modernbert.layer_0.attn_heads_f16_as_f32",
690            &attn_out,
691            batch_heads * g.max_seq,
692            g.head_dim,
693        )?;
694    }
695
696    // Fused reshape + unpad: [batch*heads, max_seq, head_dim] → [total_tokens, hidden].
697    // Eliminates the padded context intermediate buffer.
698    let mut context_unpacked = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
699    driver.fused_reshape_unpad_f16(
700        &mut context_unpacked,
701        &attn_out,
702        &g.seq_lengths,
703        g.max_seq,
704        g.batch,
705        g.num_heads,
706        g.head_dim,
707    )?;
708    if debug_tensors && layer_index == 0 {
709        debug_f16_tensor(
710            driver,
711            "modernbert.layer_0.context_unpacked_f16_as_f32",
712            &context_unpacked,
713            g.total_tokens,
714            g.hidden,
715        )?;
716    }
717
718    // Output projection on unpadded — FP16: [total_tokens, H] × [H, H].
719    let mut projected = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
720    driver.gemm_f16(
721        &context_unpacked,
722        &layer.output_weight,
723        &mut projected,
724        g.total_tokens,
725        g.hidden,
726        g.hidden,
727        true,
728    )?;
729    if debug_tensors && layer_index == 0 {
730        debug_f16_tensor(
731            driver,
732            "modernbert.layer_0.attn_projected_f16_as_f32",
733            &projected,
734            g.total_tokens,
735            g.hidden,
736        )?;
737    }
738
739    // Residual add — FP16.
740    let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
741    driver.residual_add_f16(
742        &mut output,
743        &projected,
744        hidden_states,
745        g.total_tokens * g.hidden,
746    )?;
747    Ok(output)
748}
749
750// ---------------------------------------------------------------------------
751// FP16 feed-forward (GeGLU MLP) sublayer
752// ---------------------------------------------------------------------------
753
754/// FP16 gated GELU MLP sublayer.
755///
756/// All tensors FP16. `GeGLU` kernel uses FP32 GELU compute internally.
757fn ffn_sublayer_f16<D: Driver>(
758    driver: &D,
759    attn_output: &D::Tensor,
760    layer: &ModernBertLayerWeights<D::Tensor>,
761    g: &EncoderGeometry,
762    zero_bias: &D::Tensor,
763    layer_index: usize,
764    debug_tensors: bool,
765) -> crate::Result<D::Tensor> {
766    // Pre-MLP LayerNorm — FP16.
767    let mut mlp_normed = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
768    driver.layer_norm_f16(
769        &mut mlp_normed,
770        attn_output,
771        &layer.mlp_norm_weight,
772        zero_bias,
773        g.total_tokens,
774        g.hidden,
775        g.eps,
776    )?;
777    if debug_tensors && layer_index == 0 {
778        let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
779        driver.f16_to_f32(&mut probe, &mlp_normed, g.total_tokens * g.hidden)?;
780        driver.debug_tensor(
781            "modernbert.layer_0.ffn_mlp_normed_f16_as_f32",
782            &probe,
783            g.total_tokens,
784            g.hidden,
785        )?;
786    }
787
788    // Wi projection — FP16 GEMM.
789    let double_inter = 2 * g.intermediate;
790    let mut wi_out = driver.alloc_zeros_f16(g.total_tokens * double_inter)?;
791    driver.gemm_f16(
792        &mlp_normed,
793        &layer.mlp_wi_weight,
794        &mut wi_out,
795        g.total_tokens,
796        double_inter,
797        g.hidden,
798        true,
799    )?;
800    if debug_tensors && layer_index == 0 {
801        let mut probe = driver.alloc_zeros(g.total_tokens * double_inter)?;
802        driver.f16_to_f32(&mut probe, &wi_out, g.total_tokens * double_inter)?;
803        driver.debug_tensor(
804            "modernbert.layer_0.ffn_wi_out_f16_as_f32",
805            &probe,
806            g.total_tokens,
807            double_inter,
808        )?;
809    }
810
811    // Fused split + GeGLU — FP16.
812    // Reads [total_tokens, 2*intermediate], writes [total_tokens, intermediate].
813    // Eliminates two intermediate buffers and halves HBM bandwidth.
814    let n_elements = g.total_tokens * g.intermediate;
815    let mut activated = driver.alloc_zeros_f16(n_elements)?;
816    driver.fused_split_geglu_f16(&mut activated, &wi_out, g.total_tokens, g.intermediate)?;
817    if debug_tensors && layer_index == 0 {
818        let mut probe = driver.alloc_zeros(n_elements)?;
819        driver.f16_to_f32(&mut probe, &activated, n_elements)?;
820        driver.debug_tensor(
821            "modernbert.layer_0.ffn_activated_f16_as_f32",
822            &probe,
823            g.total_tokens,
824            g.intermediate,
825        )?;
826    }
827
828    // Wo projection — FP16 GEMM.
829    let mut mlp_out = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
830    driver.gemm_f16(
831        &activated,
832        &layer.mlp_wo_weight,
833        &mut mlp_out,
834        g.total_tokens,
835        g.hidden,
836        g.intermediate,
837        true,
838    )?;
839    if debug_tensors && layer_index == 0 {
840        let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
841        driver.f16_to_f32(&mut probe, &mlp_out, g.total_tokens * g.hidden)?;
842        driver.debug_tensor(
843            "modernbert.layer_0.ffn_mlp_out_f16_as_f32",
844            &probe,
845            g.total_tokens,
846            g.hidden,
847        )?;
848    }
849
850    // Residual add — FP16.
851    let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
852    driver.residual_add_f16(
853        &mut output,
854        &mlp_out,
855        attn_output,
856        g.total_tokens * g.hidden,
857    )?;
858    if debug_tensors && layer_index == 0 {
859        let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
860        driver.f16_to_f32(&mut probe, &output, g.total_tokens * g.hidden)?;
861        driver.debug_tensor(
862            "modernbert.layer_0.ffn_output_f16_as_f32",
863            &probe,
864            g.total_tokens,
865            g.hidden,
866        )?;
867    }
868    Ok(output)
869}
870
871// ---------------------------------------------------------------------------
872// ModelArch implementation
873// ---------------------------------------------------------------------------
874
875impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
876    #[expect(
877        clippy::cast_precision_loss,
878        reason = "head_dim is small (64); sqrt is exact at this size"
879    )]
880    #[expect(
881        clippy::many_single_char_names,
882        reason = "w, g are standard geometry names; q, k, v are standard attention names"
883    )]
884    #[expect(
885        clippy::too_many_lines,
886        reason = "forward pass is a single logical unit"
887    )]
888    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
889        let w = &self.weights;
890        let batch = encodings.len();
891        let hidden = w.hidden_dim;
892
893        let inputs = driver.prepare_batch_unpadded(encodings)?;
894        let max_seq = inputs.max_seq;
895        let total_tokens = inputs.total_tokens;
896
897        // Enter batched mode: all GPU ops encode into ONE command buffer.
898        driver.begin_batch()?;
899
900        // Embedding (FP32): tok_embeddings + LayerNorm.
901        let mut hidden_states =
902            driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
903        let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
904        driver.layer_norm(
905            &mut hidden_states,
906            &emb_input,
907            &w.emb_norm_weight,
908            &w.zero_bias,
909            total_tokens,
910            hidden,
911            w.layer_norm_eps,
912        )?;
913        driver.debug_tensor(
914            "modernbert.embedding_layer_norm",
915            &hidden_states,
916            total_tokens,
917            hidden,
918        )?;
919
920        let g = EncoderGeometry {
921            batch,
922            max_seq,
923            total_tokens,
924            padded_tokens: batch * max_seq,
925            seq_lengths: inputs.seq_lengths.clone(),
926            hidden,
927            num_heads: w.num_heads,
928            head_dim: w.head_dim,
929            intermediate: w.intermediate_dim,
930            local_window: w.local_window,
931            scale: 1.0 / (w.head_dim as f32).sqrt(),
932            eps: w.layer_norm_eps,
933        };
934
935        // FP16 path: f32_to_f16 ONCE → all layers in FP16 → f16_to_f32 ONCE.
936        // Falls back to FP32 if the driver doesn't support FP16 ops.
937        //
938        // MPS FP16 GEMM uses Apple's proprietary AMX coprocessor (72/s).
939        // RIPVEC_NO_MPS=1: force FP32 activations + compute GEMM path.
940        // The gemm_f16w_f32a_kernel uses native simdgroup ops with FP16 weights
941        // and FP32 activations — no MFA wrapper, no type conversion at store.
942        let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
943            || std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
944        let use_f16 = if force_fp32 {
945            false
946        } else {
947            driver.alloc_zeros_f16(1).map(|_| true).unwrap_or(false)
948        };
949
950        if use_f16 {
951            // === FP16 PATH: zero F32↔F16 conversions in layer loop ===
952            let debug_tensors = driver.debug_tensors_enabled();
953
954            // ONLY conversion #1: F32 → F16 after embedding LN.
955            let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
956            driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
957            if debug_tensors {
958                let mut initial_probe = driver.alloc_zeros(total_tokens * hidden)?;
959                driver.f16_to_f32(&mut initial_probe, &hidden_f16, total_tokens * hidden)?;
960                driver.debug_tensor(
961                    "modernbert.after_initial_f32_to_f16",
962                    &initial_probe,
963                    total_tokens,
964                    hidden,
965                )?;
966            }
967
968            // 22 layers — ALL in FP16.
969            for (layer_index, layer) in w.layers.iter().enumerate() {
970                let saved = driver.save_pool_cursor();
971
972                let rope = if layer.is_global {
973                    &self.global_rope
974                } else {
975                    &self.local_rope
976                };
977
978                let (q, k, v) = attn_prenorm_qkv_f16(
979                    driver,
980                    &hidden_f16,
981                    layer,
982                    &g,
983                    &w.zero_bias,
984                    rope,
985                    layer_index,
986                    debug_tensors,
987                )?;
988                let attn_output = attn_scores_residual_f16(
989                    driver,
990                    &q,
991                    &k,
992                    &v,
993                    &hidden_f16,
994                    layer,
995                    &inputs,
996                    &g,
997                    layer_index,
998                    debug_tensors,
999                )?;
1000                if debug_tensors && layer_index == 0 {
1001                    let mut probe = driver.alloc_zeros(total_tokens * hidden)?;
1002                    driver.f16_to_f32(&mut probe, &attn_output, total_tokens * hidden)?;
1003                    driver.debug_tensor(
1004                        "modernbert.layer_0.attn_output_f16_as_f32",
1005                        &probe,
1006                        total_tokens,
1007                        hidden,
1008                    )?;
1009                }
1010                hidden_f16 = ffn_sublayer_f16(
1011                    driver,
1012                    &attn_output,
1013                    layer,
1014                    &g,
1015                    &w.zero_bias,
1016                    layer_index,
1017                    debug_tensors,
1018                )?;
1019                driver.restore_pool_cursor(saved);
1020                if debug_tensors && (layer_index == 0 || layer_index + 1 == w.layers.len()) {
1021                    let mut probe = driver.alloc_zeros(total_tokens * hidden)?;
1022                    driver.f16_to_f32(&mut probe, &hidden_f16, total_tokens * hidden)?;
1023                    driver.debug_tensor(
1024                        &format!("modernbert.layer_{layer_index}.hidden_f16_as_f32"),
1025                        &probe,
1026                        total_tokens,
1027                        hidden,
1028                    )?;
1029                }
1030            }
1031
1032            // ONLY conversion #2: F16 → F32 before final LN + pooling.
1033            let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
1034            driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
1035            hidden_states = hidden_f32;
1036            driver.debug_tensor(
1037                "modernbert.after_f16_to_f32",
1038                &hidden_states,
1039                total_tokens,
1040                hidden,
1041            )?;
1042        } else {
1043            // === FP32 PATH ===
1044            for (layer_index, layer) in w.layers.iter().enumerate() {
1045                let saved = driver.save_pool_cursor();
1046
1047                let rope = if layer.is_global {
1048                    &self.global_rope
1049                } else {
1050                    &self.local_rope
1051                };
1052
1053                let (q, k, v) =
1054                    attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
1055                let attn_output =
1056                    attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
1057                hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
1058
1059                driver.restore_pool_cursor(saved);
1060                if layer_index == 0 || layer_index + 1 == w.layers.len() {
1061                    driver.debug_tensor(
1062                        &format!("modernbert.layer_{layer_index}.hidden_fp32"),
1063                        &hidden_states,
1064                        total_tokens,
1065                        hidden,
1066                    )?;
1067                }
1068            }
1069        }
1070
1071        // Final LayerNorm (FP32) before pooling.
1072        let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
1073        driver.layer_norm(
1074            &mut hidden_states,
1075            &final_input,
1076            &w.final_norm_weight,
1077            &w.zero_bias,
1078            total_tokens,
1079            hidden,
1080            w.layer_norm_eps,
1081        )?;
1082        driver.debug_tensor(
1083            "modernbert.final_layer_norm",
1084            &hidden_states,
1085            total_tokens,
1086            hidden,
1087        )?;
1088
1089        // Pad back to [batch, max_seq, hidden] for mean_pool kernel.
1090        let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
1091        driver.pad_to_batch(
1092            &hidden_states,
1093            &mut padded_for_pool,
1094            &inputs.seq_lengths,
1095            max_seq,
1096            hidden,
1097        )?;
1098
1099        // Mean pooling + L2 normalize (FP32).
1100        let mut pooled = driver.alloc_zeros(batch * hidden)?;
1101        driver.mean_pool(
1102            &mut pooled,
1103            &padded_for_pool,
1104            &inputs.pooling_mask,
1105            batch,
1106            max_seq,
1107            hidden,
1108        )?;
1109        driver.debug_tensor("modernbert.mean_pool", &pooled, batch, hidden)?;
1110        driver.l2_normalize(&mut pooled, batch, hidden)?;
1111        driver.debug_tensor("modernbert.l2_normalize", &pooled, batch, hidden)?;
1112
1113        // End batched mode — commit all GPU work, wait for completion.
1114        driver.end_batch()?;
1115
1116        driver.to_host(&pooled, batch, hidden)
1117    }
1118}