use super::super::Encoding;
use super::super::driver::{BatchInputs, Driver};
use super::ModelArch;
/// Weight tensors and per-layer configuration for one ModernBERT encoder layer.
///
/// `T` is the tensor handle type; the forward pass in this file instantiates it
/// with the driver's `D::Tensor`.
pub struct ModernBertLayerWeights<T> {
/// Weight of the fused Q/K/V projection; the GEMM using it writes
/// `3 * hidden` values per token.
pub qkv_weight: T,
/// Weight of the attention output projection applied to the unpadded context.
pub output_weight: T,
/// Layer-norm weight applied to the hidden states before the QKV projection.
/// When `None`, the hidden states are projected without this norm.
pub attn_norm_weight: Option<T>,
/// Weight of the MLP input projection; the GEMM using it writes
/// `2 * intermediate` values per token, which are then split into a value
/// half and a gate half.
pub mlp_wi_weight: T,
/// Weight of the MLP output projection, mapping the gated activation back to
/// `hidden` values per token.
pub mlp_wo_weight: T,
/// Layer-norm weight applied to the attention output before the MLP.
pub mlp_norm_weight: T,
/// `true`: the layer uses the global RoPE cache and the unwindowed softmax.
/// `false`: it uses the local RoPE cache and the windowed softmax
/// (window size taken from `ModernBertWeights::local_window`).
pub is_global: bool,
}
/// Model-level weights and hyper-parameters for a ModernBERT encoder.
///
/// `T` is the tensor handle type (the driver's `D::Tensor` in `forward`).
pub struct ModernBertWeights<T> {
/// Token embedding table passed to the driver's embedding lookup.
pub tok_embeddings: T,
/// Layer-norm weight applied to the embeddings before the first layer.
pub emb_norm_weight: T,
/// Layer-norm weight applied after the last encoder layer, before pooling.
pub final_norm_weight: T,
/// Tensor passed as the bias argument of every layer-norm call in this file.
/// NOTE(review): presumably all zeros with at least `hidden_dim` elements —
/// confirm against the weight loader.
pub zero_bias: T,
/// Encoder layers, applied in order.
pub layers: Vec<ModernBertLayerWeights<T>>,
/// Number of attention heads.
pub num_heads: usize,
/// Width of one attention head; also sets the softmax scale
/// `1 / sqrt(head_dim)`.
pub head_dim: usize,
/// Width of the hidden state per token.
pub hidden_dim: usize,
/// Width of the gated MLP activation per token (the `wi` projection produces
/// twice this many values).
pub intermediate_dim: usize,
/// Epsilon forwarded to every layer-norm call.
pub layer_norm_eps: f32,
/// Window size forwarded to the windowed softmax on non-global layers.
pub local_window: usize,
}
/// Cosine and sine tables handed to the driver's rotary-embedding kernels.
///
/// NOTE(review): the table layout (positions x rotary dims) is defined by the
/// driver and is not visible here — confirm against the cache builder.
pub struct RopeCache<T> {
/// Cosine table passed as the `cos` argument of the RoPE kernels.
pub cos: T,
/// Sine table passed as the `sin` argument of the RoPE kernels.
pub sin: T,
}
/// ModernBERT architecture: model weights plus the two RoPE caches.
///
/// `forward` picks `global_rope` for layers with `is_global == true` and
/// `local_rope` for all other layers.
pub struct ModernBertArch<T> {
/// Model weights and hyper-parameters.
pub weights: ModernBertWeights<T>,
/// RoPE tables used by global-attention layers.
pub global_rope: RopeCache<T>,
/// RoPE tables used by local (windowed) attention layers.
pub local_rope: RopeCache<T>,
}
/// Shape and scalar parameters shared by every sublayer of one forward pass.
///
/// Built once in `forward` and passed by reference to the sublayer helpers.
struct EncoderGeometry {
/// Number of sequences in the batch (`encodings.len()`).
batch: usize,
/// Longest sequence length in the batch, as reported by the prepared inputs.
max_seq: usize,
/// Total token count of the unpadded batch, as reported by the prepared inputs.
total_tokens: usize,
/// Token count of the padded layout: `batch * max_seq`.
padded_tokens: usize,
/// Per-sequence lengths, used by the pad/unpad kernels.
seq_lengths: Vec<usize>,
/// Hidden width per token (`hidden_dim`).
hidden: usize,
/// Number of attention heads.
num_heads: usize,
/// Width of one attention head.
head_dim: usize,
/// Width of the gated MLP activation (`intermediate_dim`).
intermediate: usize,
/// Window size for the windowed softmax on non-global layers.
local_window: usize,
/// Attention score scale: `1 / sqrt(head_dim)`.
scale: f32,
/// Layer-norm epsilon.
eps: f32,
}
/// FP32 front half of the attention sublayer: returns `(q, k, v)` with RoPE
/// already applied to `q` and `k`.
///
/// Driver calls, in order:
/// 1. layer-norm `hidden_states` with `layer.attn_norm_weight` when it is
///    present; otherwise the hidden states feed the projection directly,
/// 2. GEMM against `layer.qkv_weight` into a fused buffer of
///    `total_tokens * 3 * hidden` elements,
/// 3. pad that buffer to `padded_tokens * 3 * hidden` using `g.seq_lengths`,
/// 4. split it into separate Q, K and V tensors of `padded_tokens * hidden`
///    elements each,
/// 5. apply RoPE with `rope.cos` / `rope.sin` to Q and K (V is left untouched).
///
/// `zero_bias` is forwarded as the layer-norm bias.
///
/// # Errors
/// Propagates the first error returned by any driver call.
fn attn_prenorm_qkv<D: Driver>(
    driver: &D,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
    rope: &RopeCache<D::Tensor>,
) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
    // Only materialise a normalised tensor when the layer has a pre-attention
    // norm. Without one, borrow the input: the projection only reads it, so a
    // device-side copy would be wasted work (the f16 path borrows as well).
    let normed = match layer.attn_norm_weight.as_ref() {
        Some(norm_w) => {
            let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
            driver.layer_norm(
                &mut n,
                hidden_states,
                norm_w,
                zero_bias,
                g.total_tokens,
                g.hidden,
                g.eps,
            )?;
            Some(n)
        }
        None => None,
    };
    let proj_input = normed.as_ref().unwrap_or(hidden_states);

    // Fused QKV projection on the unpadded token stream.
    let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
    driver.gemm(
        proj_input,
        &layer.qkv_weight,
        &mut qkv,
        g.total_tokens,
        3 * g.hidden,
        g.hidden,
        true,
    )?;

    // Attention runs on the padded layout, so pad before splitting.
    let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
    driver.pad_to_batch(
        &qkv,
        &mut qkv_padded,
        &g.seq_lengths,
        g.max_seq,
        3 * g.hidden,
    )?;

    let padded = g.padded_tokens;
    let mut q = driver.alloc_zeros(padded * g.hidden)?;
    let mut k = driver.alloc_zeros(padded * g.hidden)?;
    let mut v = driver.alloc_zeros(padded * g.hidden)?;
    driver.qkv_split(
        &mut q,
        &mut k,
        &mut v,
        &qkv_padded,
        g.batch,
        g.max_seq,
        g.hidden,
        g.num_heads,
        g.head_dim,
    )?;

    // Rotary position encoding applies to queries and keys only.
    let num_rows = g.batch * g.num_heads * g.max_seq;
    driver.apply_rope(
        &mut q,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    driver.apply_rope(
        &mut k,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    Ok((q, k, v))
}
/// FP32 back half of the attention sublayer: scores, softmax, context,
/// output projection and residual connection.
///
/// Takes the padded `q`, `k`, `v` tensors and the layer input `hidden_states`,
/// and returns `hidden_states + projection(attention(q, k, v))` as a fresh
/// tensor of `total_tokens * hidden` elements.
///
/// Global layers use the plain fused scale/mask/softmax kernel; other layers
/// use the windowed variant with `g.local_window`. Both receive
/// `inputs.float_mask`.
///
/// # Errors
/// Propagates the first error returned by any driver call.
#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
fn attn_scores_residual<D: Driver>(
    driver: &D,
    q: &D::Tensor,
    k: &D::Tensor,
    v: &D::Tensor,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    inputs: &BatchInputs<D::Tensor>,
    g: &EncoderGeometry,
) -> crate::Result<D::Tensor> {
    // One GEMM batch entry per (sequence, head) pair.
    let head_batches = g.batch * g.num_heads;
    let per_head_qkv = g.max_seq * g.head_dim;
    let per_head_scores = g.max_seq * g.max_seq;
    let padded_len = g.padded_tokens * g.hidden;
    let packed_len = g.total_tokens * g.hidden;

    // Raw attention scores from Q and K.
    let mut scores = driver.alloc_zeros(head_batches * per_head_scores)?;
    driver.gemm_batched(
        q,
        k,
        &mut scores,
        g.max_seq,
        g.max_seq,
        g.head_dim,
        true,
        per_head_qkv,
        per_head_qkv,
        per_head_scores,
        head_batches,
    )?;

    // Scale, mask and normalise in place; local layers restrict the window.
    if layer.is_global {
        driver.fused_scale_mask_softmax(
            &mut scores,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            g.max_seq,
            g.scale,
        )?;
    } else {
        driver.fused_scale_mask_softmax_windowed(
            &mut scores,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            g.max_seq,
            g.scale,
            g.local_window,
        )?;
    }

    // Weighted sum of V per head.
    let mut per_head_ctx = driver.alloc_zeros(padded_len)?;
    driver.gemm_batched(
        &scores,
        v,
        &mut per_head_ctx,
        g.max_seq,
        g.head_dim,
        g.max_seq,
        false,
        per_head_scores,
        per_head_qkv,
        per_head_qkv,
        head_batches,
    )?;

    // Reshape the per-head output, then drop the padding rows.
    let mut merged_ctx = driver.alloc_zeros(padded_len)?;
    driver.attn_reshape(
        &mut merged_ctx,
        &per_head_ctx,
        g.batch,
        g.max_seq,
        g.num_heads,
        g.head_dim,
    )?;
    let mut packed_ctx = driver.alloc_zeros(packed_len)?;
    driver.unpad_from_batch(
        &merged_ctx,
        &mut packed_ctx,
        &g.seq_lengths,
        g.max_seq,
        g.hidden,
    )?;

    // Output projection followed by the residual connection.
    let mut proj = driver.alloc_zeros(packed_len)?;
    driver.gemm(
        &packed_ctx,
        &layer.output_weight,
        &mut proj,
        g.total_tokens,
        g.hidden,
        g.hidden,
        true,
    )?;
    let mut residual = driver.alloc_zeros(packed_len)?;
    driver.residual_add(&mut residual, &proj, hidden_states, packed_len)?;
    Ok(residual)
}
/// FP32 feed-forward sublayer with a GeGLU activation and residual connection.
///
/// Driver calls, in order: layer-norm of `attn_output` with
/// `layer.mlp_norm_weight`, the `wi` projection to `2 * intermediate` values
/// per token, a split into value and gate halves, the GeGLU kernel, the `wo`
/// projection back to `hidden`, and finally a residual add with `attn_output`.
///
/// Returns a fresh tensor of `total_tokens * hidden` elements.
///
/// # Errors
/// Propagates the first error returned by any driver call.
fn ffn_sublayer<D: Driver>(
    driver: &D,
    attn_output: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
) -> crate::Result<D::Tensor> {
    let tokens = g.total_tokens;
    let hidden_len = tokens * g.hidden;

    // Pre-MLP normalisation.
    let mut normed = driver.alloc_zeros(hidden_len)?;
    driver.layer_norm(
        &mut normed,
        attn_output,
        &layer.mlp_norm_weight,
        zero_bias,
        tokens,
        g.hidden,
        g.eps,
    )?;

    // The input projection yields the value and gate halves in one buffer.
    let fused_width = 2 * g.intermediate;
    let mut fused = driver.alloc_zeros(tokens * fused_width)?;
    driver.gemm(
        &normed,
        &layer.mlp_wi_weight,
        &mut fused,
        tokens,
        fused_width,
        g.hidden,
        true,
    )?;

    let gated_len = tokens * g.intermediate;
    let mut value_half = driver.alloc_zeros(gated_len)?;
    let mut gate_half = driver.alloc_zeros(gated_len)?;
    driver.split_gate_value(
        &mut value_half,
        &mut gate_half,
        &fused,
        tokens,
        g.intermediate,
    )?;
    let mut gated = driver.alloc_zeros(gated_len)?;
    driver.geglu(&value_half, &gate_half, &mut gated, gated_len)?;

    // Project back to the hidden width and add the residual.
    let mut down = driver.alloc_zeros(hidden_len)?;
    driver.gemm(
        &gated,
        &layer.mlp_wo_weight,
        &mut down,
        tokens,
        g.hidden,
        g.intermediate,
        true,
    )?;
    let mut residual = driver.alloc_zeros(hidden_len)?;
    driver.residual_add(&mut residual, &down, attn_output, hidden_len)?;
    Ok(residual)
}
/// Emits an f16 tensor through the driver's debug hook.
///
/// Allocates an FP32 scratch tensor of `rows * cols` elements, converts
/// `tensor` into it with `f16_to_f32`, and passes the result to
/// `debug_tensor` under `label` with the given `rows` x `cols` shape.
///
/// # Errors
/// Propagates any error from the allocation, the conversion or the debug hook.
fn debug_f16_tensor<D: Driver>(
    driver: &D,
    label: &str,
    tensor: &D::Tensor,
    rows: usize,
    cols: usize,
) -> crate::Result<()> {
    let element_count = rows * cols;
    let mut as_f32 = driver.alloc_zeros(element_count)?;
    driver.f16_to_f32(&mut as_f32, tensor, element_count)?;
    driver.debug_tensor(label, &as_f32, rows, cols)
}
/// F16 front half of the attention sublayer: returns `(q, k, v)` with RoPE
/// already applied to `q` and `k`.
///
/// Driver calls, in order:
/// 1. f16 layer-norm of `hidden_states` with `layer.attn_norm_weight` when it
///    is present; otherwise the hidden states feed the projection directly,
/// 2. f16 GEMM against `layer.qkv_weight` into a fused buffer of
///    `total_tokens * 3 * hidden` elements,
/// 3. a fused pad + split kernel producing Q, K and V of
///    `padded_tokens * hidden` elements each,
/// 4. f16 RoPE with `rope.cos` / `rope.sin` on Q and K (V is left untouched).
///
/// When `debug_tensors` is set and `layer_index == 0`, intermediate tensors are
/// converted to FP32 and emitted through the driver's debug hook after the
/// projection, after the split and after RoPE.
///
/// # Errors
/// Propagates the first error returned by any driver call.
#[expect(
    clippy::too_many_arguments,
    reason = "diagnostic layer index and flag ride along with the tensors and geometry"
)]
#[expect(
    clippy::too_many_lines,
    reason = "attention diagnostics intentionally keep stage probes adjacent to the operations"
)]
fn attn_prenorm_qkv_f16<D: Driver>(
    driver: &D,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
    rope: &RopeCache<D::Tensor>,
    layer_index: usize,
    debug_tensors: bool,
) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
    // Stage probes are only emitted for the first layer.
    let probe = debug_tensors && layer_index == 0;

    // Keep the normalised tensor alive in `normed` and fall back to borrowing
    // the input when the layer has no pre-attention norm; no unwrap needed.
    let normed = match layer.attn_norm_weight.as_ref() {
        Some(norm_w) => {
            let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
            driver.layer_norm_f16(
                &mut n,
                hidden_states,
                norm_w,
                zero_bias,
                g.total_tokens,
                g.hidden,
                g.eps,
            )?;
            Some(n)
        }
        None => None,
    };
    let normed_ref = normed.as_ref().unwrap_or(hidden_states);

    // Fused QKV projection on the unpadded token stream.
    let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
    driver.gemm_f16(
        normed_ref,
        &layer.qkv_weight,
        &mut qkv,
        g.total_tokens,
        3 * g.hidden,
        g.hidden,
        true,
    )?;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.qkv_f16_as_f32",
            &qkv,
            g.total_tokens,
            3 * g.hidden,
        )?;
    }

    // Pad to the batch layout and split into Q/K/V in a single kernel.
    let padded = g.padded_tokens;
    let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
    let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
    let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
    driver.fused_pad_qkv_split_f16(
        &mut q,
        &mut k,
        &mut v,
        &qkv,
        &g.seq_lengths,
        g.max_seq,
        g.batch,
        g.hidden,
        g.num_heads,
        g.head_dim,
    )?;

    // Row count shared by the RoPE kernels and the per-head debug probes.
    let num_rows = g.batch * g.num_heads * g.max_seq;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.q_after_split_f16_as_f32",
            &q,
            num_rows,
            g.head_dim,
        )?;
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.k_after_split_f16_as_f32",
            &k,
            num_rows,
            g.head_dim,
        )?;
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.v_after_split_f16_as_f32",
            &v,
            num_rows,
            g.head_dim,
        )?;
    }

    // Rotary position encoding applies to queries and keys only.
    driver.rope_encode_f16(
        &mut q,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    driver.rope_encode_f16(
        &mut k,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.q_after_rope_f16_as_f32",
            &q,
            num_rows,
            g.head_dim,
        )?;
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.k_after_rope_f16_as_f32",
            &k,
            num_rows,
            g.head_dim,
        )?;
    }
    Ok((q, k, v))
}
/// F16 back half of the attention sublayer: scores, softmax, context,
/// output projection and residual connection.
///
/// Takes the padded f16 `q`, `k`, `v` tensors and the layer input
/// `hidden_states`, and returns `hidden_states + projection(attention)` as a
/// fresh f16 tensor of `total_tokens * hidden` elements.
///
/// Global layers use the plain fused scale/mask/softmax kernel; other layers
/// use the windowed variant with `g.local_window`. Both receive
/// `inputs.float_mask`. The per-head output is reshaped and unpadded by a
/// single fused kernel.
///
/// When `debug_tensors` is set and `layer_index == 0`, each stage's output is
/// converted to FP32 and emitted through the driver's debug hook.
///
/// # Errors
/// Propagates the first error returned by any driver call.
#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
#[expect(
    clippy::too_many_lines,
    reason = "attention diagnostics intentionally keep stage probes adjacent to the operations"
)]
fn attn_scores_residual_f16<D: Driver>(
    driver: &D,
    q: &D::Tensor,
    k: &D::Tensor,
    v: &D::Tensor,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    inputs: &BatchInputs<D::Tensor>,
    g: &EncoderGeometry,
    layer_index: usize,
    debug_tensors: bool,
) -> crate::Result<D::Tensor> {
    // Stage probes are only emitted for the first layer.
    let probe = debug_tensors && layer_index == 0;

    // One GEMM batch entry per (sequence, head) pair.
    let head_batches = g.batch * g.num_heads;
    let per_head_qkv = g.max_seq * g.head_dim;
    let per_head_scores = g.max_seq * g.max_seq;
    let score_rows = head_batches * g.max_seq;
    let packed_len = g.total_tokens * g.hidden;

    // Raw attention scores from Q and K.
    let mut scores = driver.alloc_zeros_f16(score_rows * g.max_seq)?;
    driver.gemm_batched_f16(
        q,
        k,
        &mut scores,
        g.max_seq,
        g.max_seq,
        g.head_dim,
        true,
        per_head_qkv,
        per_head_qkv,
        per_head_scores,
        head_batches,
    )?;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.attn_scores_before_softmax_f16_as_f32",
            &scores,
            score_rows,
            g.max_seq,
        )?;
    }

    // Scale, mask and normalise in place; local layers restrict the window.
    if layer.is_global {
        driver.fused_scale_mask_softmax_f16(
            &mut scores,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            g.max_seq,
            g.scale,
        )?;
    } else {
        driver.fused_scale_mask_softmax_windowed_f16(
            &mut scores,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            g.max_seq,
            g.scale,
            g.local_window,
        )?;
    }
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.attn_scores_after_softmax_f16_as_f32",
            &scores,
            score_rows,
            g.max_seq,
        )?;
    }

    // Weighted sum of V per head.
    let mut per_head_ctx = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
    driver.gemm_batched_f16(
        &scores,
        v,
        &mut per_head_ctx,
        g.max_seq,
        g.head_dim,
        g.max_seq,
        false,
        per_head_scores,
        per_head_qkv,
        per_head_qkv,
        head_batches,
    )?;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.attn_heads_f16_as_f32",
            &per_head_ctx,
            score_rows,
            g.head_dim,
        )?;
    }

    // Reshape the per-head output and drop padding rows in one kernel.
    let mut packed_ctx = driver.alloc_zeros_f16(packed_len)?;
    driver.fused_reshape_unpad_f16(
        &mut packed_ctx,
        &per_head_ctx,
        &g.seq_lengths,
        g.max_seq,
        g.batch,
        g.num_heads,
        g.head_dim,
    )?;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.context_unpacked_f16_as_f32",
            &packed_ctx,
            g.total_tokens,
            g.hidden,
        )?;
    }

    // Output projection.
    let mut proj = driver.alloc_zeros_f16(packed_len)?;
    driver.gemm_f16(
        &packed_ctx,
        &layer.output_weight,
        &mut proj,
        g.total_tokens,
        g.hidden,
        g.hidden,
        true,
    )?;
    if probe {
        debug_f16_tensor(
            driver,
            "modernbert.layer_0.attn_projected_f16_as_f32",
            &proj,
            g.total_tokens,
            g.hidden,
        )?;
    }

    // Residual connection with the layer input.
    let mut residual = driver.alloc_zeros_f16(packed_len)?;
    driver.residual_add_f16(
        &mut residual,
        &proj,
        hidden_states,
        packed_len,
    )?;
    Ok(residual)
}
fn ffn_sublayer_f16<D: Driver>(
driver: &D,
attn_output: &D::Tensor,
layer: &ModernBertLayerWeights<D::Tensor>,
g: &EncoderGeometry,
zero_bias: &D::Tensor,
layer_index: usize,
debug_tensors: bool,
) -> crate::Result<D::Tensor> {
let mut mlp_normed = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
driver.layer_norm_f16(
&mut mlp_normed,
attn_output,
&layer.mlp_norm_weight,
zero_bias,
g.total_tokens,
g.hidden,
g.eps,
)?;
if debug_tensors && layer_index == 0 {
let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
driver.f16_to_f32(&mut probe, &mlp_normed, g.total_tokens * g.hidden)?;
driver.debug_tensor(
"modernbert.layer_0.ffn_mlp_normed_f16_as_f32",
&probe,
g.total_tokens,
g.hidden,
)?;
}
let double_inter = 2 * g.intermediate;
let mut wi_out = driver.alloc_zeros_f16(g.total_tokens * double_inter)?;
driver.gemm_f16(
&mlp_normed,
&layer.mlp_wi_weight,
&mut wi_out,
g.total_tokens,
double_inter,
g.hidden,
true,
)?;
if debug_tensors && layer_index == 0 {
let mut probe = driver.alloc_zeros(g.total_tokens * double_inter)?;
driver.f16_to_f32(&mut probe, &wi_out, g.total_tokens * double_inter)?;
driver.debug_tensor(
"modernbert.layer_0.ffn_wi_out_f16_as_f32",
&probe,
g.total_tokens,
double_inter,
)?;
}
let n_elements = g.total_tokens * g.intermediate;
let mut activated = driver.alloc_zeros_f16(n_elements)?;
driver.fused_split_geglu_f16(&mut activated, &wi_out, g.total_tokens, g.intermediate)?;
if debug_tensors && layer_index == 0 {
let mut probe = driver.alloc_zeros(n_elements)?;
driver.f16_to_f32(&mut probe, &activated, n_elements)?;
driver.debug_tensor(
"modernbert.layer_0.ffn_activated_f16_as_f32",
&probe,
g.total_tokens,
g.intermediate,
)?;
}
let mut mlp_out = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
driver.gemm_f16(
&activated,
&layer.mlp_wo_weight,
&mut mlp_out,
g.total_tokens,
g.hidden,
g.intermediate,
true,
)?;
if debug_tensors && layer_index == 0 {
let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
driver.f16_to_f32(&mut probe, &mlp_out, g.total_tokens * g.hidden)?;
driver.debug_tensor(
"modernbert.layer_0.ffn_mlp_out_f16_as_f32",
&probe,
g.total_tokens,
g.hidden,
)?;
}
let mut output = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
driver.residual_add_f16(
&mut output,
&mlp_out,
attn_output,
g.total_tokens * g.hidden,
)?;
if debug_tensors && layer_index == 0 {
let mut probe = driver.alloc_zeros(g.total_tokens * g.hidden)?;
driver.f16_to_f32(&mut probe, &output, g.total_tokens * g.hidden)?;
driver.debug_tensor(
"modernbert.layer_0.ffn_output_f16_as_f32",
&probe,
g.total_tokens,
g.hidden,
)?;
}
Ok(output)
}
impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
#[expect(
clippy::cast_precision_loss,
reason = "head_dim is small (64); sqrt is exact at this size"
)]
#[expect(
clippy::many_single_char_names,
reason = "w, g are standard geometry names; q, k, v are standard attention names"
)]
#[expect(
clippy::too_many_lines,
reason = "forward pass is a single logical unit"
)]
fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
let w = &self.weights;
let batch = encodings.len();
let hidden = w.hidden_dim;
let inputs = driver.prepare_batch_unpadded(encodings)?;
let max_seq = inputs.max_seq;
let total_tokens = inputs.total_tokens;
driver.begin_batch()?;
let mut hidden_states =
driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
driver.layer_norm(
&mut hidden_states,
&emb_input,
&w.emb_norm_weight,
&w.zero_bias,
total_tokens,
hidden,
w.layer_norm_eps,
)?;
driver.debug_tensor(
"modernbert.embedding_layer_norm",
&hidden_states,
total_tokens,
hidden,
)?;
let g = EncoderGeometry {
batch,
max_seq,
total_tokens,
padded_tokens: batch * max_seq,
seq_lengths: inputs.seq_lengths.clone(),
hidden,
num_heads: w.num_heads,
head_dim: w.head_dim,
intermediate: w.intermediate_dim,
local_window: w.local_window,
scale: 1.0 / (w.head_dim as f32).sqrt(),
eps: w.layer_norm_eps,
};
let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
|| std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
let use_f16 = if force_fp32 {
false
} else {
driver.alloc_zeros_f16(1).map(|_| true).unwrap_or(false)
};
if use_f16 {
let debug_tensors = driver.debug_tensors_enabled();
let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
if debug_tensors {
let mut initial_probe = driver.alloc_zeros(total_tokens * hidden)?;
driver.f16_to_f32(&mut initial_probe, &hidden_f16, total_tokens * hidden)?;
driver.debug_tensor(
"modernbert.after_initial_f32_to_f16",
&initial_probe,
total_tokens,
hidden,
)?;
}
for (layer_index, layer) in w.layers.iter().enumerate() {
let saved = driver.save_pool_cursor();
let rope = if layer.is_global {
&self.global_rope
} else {
&self.local_rope
};
let (q, k, v) = attn_prenorm_qkv_f16(
driver,
&hidden_f16,
layer,
&g,
&w.zero_bias,
rope,
layer_index,
debug_tensors,
)?;
let attn_output = attn_scores_residual_f16(
driver,
&q,
&k,
&v,
&hidden_f16,
layer,
&inputs,
&g,
layer_index,
debug_tensors,
)?;
if debug_tensors && layer_index == 0 {
let mut probe = driver.alloc_zeros(total_tokens * hidden)?;
driver.f16_to_f32(&mut probe, &attn_output, total_tokens * hidden)?;
driver.debug_tensor(
"modernbert.layer_0.attn_output_f16_as_f32",
&probe,
total_tokens,
hidden,
)?;
}
hidden_f16 = ffn_sublayer_f16(
driver,
&attn_output,
layer,
&g,
&w.zero_bias,
layer_index,
debug_tensors,
)?;
driver.restore_pool_cursor(saved);
if debug_tensors && (layer_index == 0 || layer_index + 1 == w.layers.len()) {
let mut probe = driver.alloc_zeros(total_tokens * hidden)?;
driver.f16_to_f32(&mut probe, &hidden_f16, total_tokens * hidden)?;
driver.debug_tensor(
&format!("modernbert.layer_{layer_index}.hidden_f16_as_f32"),
&probe,
total_tokens,
hidden,
)?;
}
}
let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
hidden_states = hidden_f32;
driver.debug_tensor(
"modernbert.after_f16_to_f32",
&hidden_states,
total_tokens,
hidden,
)?;
} else {
for (layer_index, layer) in w.layers.iter().enumerate() {
let saved = driver.save_pool_cursor();
let rope = if layer.is_global {
&self.global_rope
} else {
&self.local_rope
};
let (q, k, v) =
attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
let attn_output =
attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
driver.restore_pool_cursor(saved);
if layer_index == 0 || layer_index + 1 == w.layers.len() {
driver.debug_tensor(
&format!("modernbert.layer_{layer_index}.hidden_fp32"),
&hidden_states,
total_tokens,
hidden,
)?;
}
}
}
let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
driver.layer_norm(
&mut hidden_states,
&final_input,
&w.final_norm_weight,
&w.zero_bias,
total_tokens,
hidden,
w.layer_norm_eps,
)?;
driver.debug_tensor(
"modernbert.final_layer_norm",
&hidden_states,
total_tokens,
hidden,
)?;
let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
driver.pad_to_batch(
&hidden_states,
&mut padded_for_pool,
&inputs.seq_lengths,
max_seq,
hidden,
)?;
let mut pooled = driver.alloc_zeros(batch * hidden)?;
driver.mean_pool(
&mut pooled,
&padded_for_pool,
&inputs.pooling_mask,
batch,
max_seq,
hidden,
)?;
driver.debug_tensor("modernbert.mean_pool", &pooled, batch, hidden)?;
driver.l2_normalize(&mut pooled, batch, hidden)?;
driver.debug_tensor("modernbert.l2_normalize", &pooled, batch, hidden)?;
driver.end_batch()?;
driver.to_host(&pooled, batch, hidden)
}
}