use super::super::Encoding;
use super::super::driver::{BatchInputs, Driver};
use super::ModelArch;
/// Weights for a single encoder layer, generic over the driver's tensor type `T`.
///
/// Every GEMM issued against these weights in this module passes the transpose
/// flag (NOTE(review): presumably weights are stored as `[out, in]` — confirm
/// against `Driver::gemm`).
pub struct ModernBertLayerWeights<T> {
    /// Fused Q/K/V projection: maps `hidden` inputs to `3 * hidden` outputs.
    pub qkv_weight: T,
    /// Attention output projection: `hidden` -> `hidden`.
    pub output_weight: T,
    /// Layer-norm weight applied before attention; `None` means the attention
    /// block consumes the layer input without normalizing it.
    pub attn_norm_weight: Option<T>,
    /// Feed-forward input projection: `hidden` -> `2 * intermediate`, later split
    /// into value and gate halves for `GeGLU`.
    pub mlp_wi_weight: T,
    /// Feed-forward output projection: `intermediate` -> `hidden`.
    pub mlp_wo_weight: T,
    /// Layer-norm weight applied before the feed-forward block.
    pub mlp_norm_weight: T,
    /// `true`: full attention using the global rotary tables.
    /// `false`: windowed attention using the local rotary tables.
    pub is_global: bool,
}
/// All weights and hyperparameters of the encoder, generic over the driver's
/// tensor type `T`.
pub struct ModernBertWeights<T> {
    /// Token embedding table, looked up with the batch's input ids.
    pub tok_embeddings: T,
    /// Layer-norm weight applied to the embedding output.
    pub emb_norm_weight: T,
    /// Layer-norm weight applied after the last encoder layer.
    pub final_norm_weight: T,
    /// Passed as the bias argument to every layer-norm call in this module.
    /// The name suggests it holds zeros (NOTE(review): presumably `hidden_dim`
    /// long — confirm where it is built).
    pub zero_bias: T,
    /// Encoder layers, applied in order.
    pub layers: Vec<ModernBertLayerWeights<T>>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head width; attention scores are scaled by `1 / sqrt(head_dim)`.
    pub head_dim: usize,
    /// Model (hidden) width.
    pub hidden_dim: usize,
    /// Feed-forward inner width; the input projection produces twice this.
    pub intermediate_dim: usize,
    /// Epsilon used by every layer norm.
    pub layer_norm_eps: f32,
    /// Window size handed to the windowed softmax for non-global layers.
    /// NOTE(review): whether this is a full or half width is defined by the
    /// driver kernel — confirm.
    pub local_window: usize,
}
/// Precomputed rotary position embedding tables, passed to the driver's
/// rotary kernels (`apply_rope` / `rope_encode_f16`).
pub struct RopeCache<T> {
    /// Cosine table.
    pub cos: T,
    /// Sine table.
    pub sin: T,
}
/// Encoder model implementing `ModelArch::forward`: the weights plus the two
/// rotary caches.
pub struct ModernBertArch<T> {
    /// Model weights and hyperparameters.
    pub weights: ModernBertWeights<T>,
    /// Rotary tables used by layers whose `is_global` flag is set.
    pub global_rope: RopeCache<T>,
    /// Rotary tables used by the remaining (windowed) layers.
    pub local_rope: RopeCache<T>,
}
/// Shape and numeric parameters for one `forward` call, shared by every layer.
///
/// "Padded" values describe the `[batch, max_seq]` layout used for attention.
/// The unpadded values describe the packed token rows used everywhere else.
struct EncoderGeometry {
    // Model dimensions.
    /// Hidden (model) width.
    hidden: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of each head.
    head_dim: usize,
    /// Feed-forward inner width.
    intermediate: usize,

    // Batch layout.
    /// Number of sequences in the batch.
    batch: usize,
    /// Longest sequence in the batch; attention runs at this length.
    max_seq: usize,
    /// Total count of real (unpadded) tokens across the batch.
    total_tokens: usize,
    /// `batch * max_seq`: row count of the padded attention layout.
    padded_tokens: usize,
    /// Per-sequence token counts, used to move between padded and packed layouts.
    seq_lengths: Vec<usize>,

    // Numeric parameters.
    /// Window handed to the windowed softmax on non-global layers.
    local_window: usize,
    /// Attention score scale (`1 / sqrt(head_dim)`).
    scale: f32,
    /// Layer-norm epsilon.
    eps: f32,
}
/// Pre-attention half of a layer on the FP32 path.
///
/// Steps:
/// 1. Optional layer norm.
/// 2. Fused QKV projection.
/// 3. Re-padding to `[batch, max_seq]`.
/// 4. Head split.
/// 5. Rotary embedding on `q` and `k`.
///
/// `hidden_states` holds the packed token rows (`total_tokens * hidden`
/// elements). When the layer has no attention norm weight, the projection reads
/// `hidden_states` directly. It is only ever read here, so no copy is needed;
/// this matches the FP16 path.
///
/// Returns `(q, k, v)`, each `padded_tokens * hidden` elements, with the rotary
/// embedding already applied to `q` and `k`. The per-head layout is whatever
/// `Driver::qkv_split` produces. NOTE(review): presumably
/// `[batch, heads, max_seq, head_dim]`, given the `batch * num_heads * max_seq`
/// row count passed to the rotary kernel — confirm.
///
/// # Errors
/// Propagates any error from the driver's allocation or kernel calls.
fn attn_prenorm_qkv<D: Driver>(
    driver: &D,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
    rope: &RopeCache<D::Tensor>,
) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
    // Only materialize a normalized tensor when the layer has a norm weight.
    let normed = match layer.attn_norm_weight.as_ref() {
        Some(norm_w) => {
            let mut n = driver.alloc_zeros(g.total_tokens * g.hidden)?;
            driver.layer_norm(
                &mut n,
                hidden_states,
                norm_w,
                zero_bias,
                g.total_tokens,
                g.hidden,
                g.eps,
            )?;
            Some(n)
        }
        None => None,
    };
    let proj_input = normed.as_ref().unwrap_or(hidden_states);
    let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
    driver.gemm(
        proj_input,
        &layer.qkv_weight,
        &mut qkv,
        g.total_tokens,
        3 * g.hidden,
        g.hidden,
        true,
    )?;
    // Attention runs on the padded [batch, max_seq] layout.
    let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
    driver.pad_to_batch(
        &qkv,
        &mut qkv_padded,
        &g.seq_lengths,
        g.max_seq,
        3 * g.hidden,
    )?;
    let padded = g.padded_tokens;
    let mut q = driver.alloc_zeros(padded * g.hidden)?;
    let mut k = driver.alloc_zeros(padded * g.hidden)?;
    let mut v = driver.alloc_zeros(padded * g.hidden)?;
    driver.qkv_split(
        &mut q,
        &mut k,
        &mut v,
        &qkv_padded,
        g.batch,
        g.max_seq,
        g.hidden,
        g.num_heads,
        g.head_dim,
    )?;
    // Rotary embedding is applied to queries and keys only.
    let num_rows = g.batch * g.num_heads * g.max_seq;
    driver.apply_rope(
        &mut q,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    driver.apply_rope(
        &mut k,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    Ok((q, k, v))
}
/// Attention core for one layer on the FP32 path.
///
/// Steps:
/// 1. Per-head scores `q · kᵀ`.
/// 2. The driver's fused scale/mask/softmax. When `layer.is_global` is false,
///    the windowed variant with `g.local_window` is used.
/// 3. Multiply by `v`.
/// 4. Merge heads and strip padding rows.
/// 5. Output projection, then add `hidden_states` as the residual.
///
/// `q`, `k`, `v` are the padded per-head tensors from `attn_prenorm_qkv`.
/// `hidden_states` is the packed layer input (`total_tokens * hidden` elements).
/// Returns a newly allocated packed tensor of that same size.
///
/// # Errors
/// Propagates any error from the driver's allocation or kernel calls.
#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
fn attn_scores_residual<D: Driver>(
    driver: &D,
    q: &D::Tensor,
    k: &D::Tensor,
    v: &D::Tensor,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    inputs: &BatchInputs<D::Tensor>,
    g: &EncoderGeometry,
) -> crate::Result<D::Tensor> {
    let seq = g.max_seq;
    let dh = g.head_dim;
    let n_mats = g.batch * g.num_heads;
    let head_stride = seq * dh;
    let score_stride = seq * seq;
    let token_elems = g.total_tokens * g.hidden;
    let padded_elems = g.padded_tokens * g.hidden;

    // One `seq x seq` score matrix per (sequence, head) pair.
    let mut probs = driver.alloc_zeros(n_mats * score_stride)?;
    driver.gemm_batched(
        q,
        k,
        &mut probs,
        seq,
        seq,
        dh,
        true,
        head_stride,
        head_stride,
        score_stride,
        n_mats,
    )?;

    // Scale, mask and softmax in place; local layers also restrict to the window.
    if layer.is_global {
        driver.fused_scale_mask_softmax(
            &mut probs,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            seq,
            g.scale,
        )?;
    } else {
        driver.fused_scale_mask_softmax_windowed(
            &mut probs,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            seq,
            g.scale,
            g.local_window,
        )?;
    }

    // Weighted sum over values, still in the padded per-head layout.
    let mut per_head_out = driver.alloc_zeros(padded_elems)?;
    driver.gemm_batched(
        &probs,
        v,
        &mut per_head_out,
        seq,
        dh,
        seq,
        false,
        score_stride,
        head_stride,
        head_stride,
        n_mats,
    )?;

    // Merge heads back into `hidden`-wide rows, then drop the padding rows.
    let mut merged = driver.alloc_zeros(padded_elems)?;
    driver.attn_reshape(&mut merged, &per_head_out, g.batch, seq, g.num_heads, dh)?;
    let mut merged_packed = driver.alloc_zeros(token_elems)?;
    driver.unpad_from_batch(&merged, &mut merged_packed, &g.seq_lengths, seq, g.hidden)?;

    // Output projection followed by the residual connection.
    let mut out_proj = driver.alloc_zeros(token_elems)?;
    driver.gemm(
        &merged_packed,
        &layer.output_weight,
        &mut out_proj,
        g.total_tokens,
        g.hidden,
        g.hidden,
        true,
    )?;
    let mut with_residual = driver.alloc_zeros(token_elems)?;
    driver.residual_add(&mut with_residual, &out_proj, hidden_states, token_elems)?;
    Ok(with_residual)
}
/// Feed-forward half of a layer on the FP32 path.
///
/// Steps:
/// 1. Layer norm.
/// 2. Input projection to `2 * intermediate`.
/// 3. Split into value and gate halves.
/// 4. `GeGLU`.
/// 5. Output projection back to `hidden`, then add `attn_output` as the residual.
///
/// `attn_output` is the packed attention result (`total_tokens * hidden`
/// elements). Returns a newly allocated packed tensor of that same size.
///
/// # Errors
/// Propagates any error from the driver's allocation or kernel calls.
fn ffn_sublayer<D: Driver>(
    driver: &D,
    attn_output: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
) -> crate::Result<D::Tensor> {
    let rows = g.total_tokens;
    let token_elems = rows * g.hidden;
    let wide = 2 * g.intermediate;
    let inner_elems = rows * g.intermediate;

    let mut pre_ffn = driver.alloc_zeros(token_elems)?;
    driver.layer_norm(
        &mut pre_ffn,
        attn_output,
        &layer.mlp_norm_weight,
        zero_bias,
        rows,
        g.hidden,
        g.eps,
    )?;

    // Single projection yields both halves consumed by the gated activation.
    let mut up = driver.alloc_zeros(rows * wide)?;
    driver.gemm(&pre_ffn, &layer.mlp_wi_weight, &mut up, rows, wide, g.hidden, true)?;

    let mut value_half = driver.alloc_zeros(inner_elems)?;
    let mut gate_half = driver.alloc_zeros(inner_elems)?;
    driver.split_gate_value(&mut value_half, &mut gate_half, &up, rows, g.intermediate)?;

    let mut gated = driver.alloc_zeros(inner_elems)?;
    driver.geglu(&value_half, &gate_half, &mut gated, inner_elems)?;

    let mut down = driver.alloc_zeros(token_elems)?;
    driver.gemm(
        &gated,
        &layer.mlp_wo_weight,
        &mut down,
        rows,
        g.hidden,
        g.intermediate,
        true,
    )?;

    let mut summed = driver.alloc_zeros(token_elems)?;
    driver.residual_add(&mut summed, &down, attn_output, token_elems)?;
    Ok(summed)
}
/// Pre-attention half of a layer on the FP16 path.
///
/// Steps:
/// 1. Optional layer norm.
/// 2. Fused QKV projection.
/// 3. A single fused kernel that pads to `[batch, max_seq]` and splits heads.
/// 4. Rotary embedding on `q` and `k`.
///
/// `hidden_states` is the packed layer input (`total_tokens * hidden`
/// elements). The layer loop in `forward` passes the FP16-converted hidden
/// state here. Without an attention norm weight, the projection reads
/// `hidden_states` directly.
///
/// Returns `(q, k, v)`, each `padded_tokens * hidden` elements, allocated via
/// `alloc_zeros_f16`, with the rotary embedding applied to `q` and `k`.
///
/// # Errors
/// Propagates any error from the driver's allocation or kernel calls.
fn attn_prenorm_qkv_f16<D: Driver>(
    driver: &D,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
    rope: &RopeCache<D::Tensor>,
) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
    // Own the normalized tensor (if any) and borrow whichever input applies;
    // avoids a deferred-init binding followed by `unwrap()`.
    let normed = match layer.attn_norm_weight.as_ref() {
        Some(norm_w) => {
            let mut n = driver.alloc_zeros_f16(g.total_tokens * g.hidden)?;
            driver.layer_norm_f16(
                &mut n,
                hidden_states,
                norm_w,
                zero_bias,
                g.total_tokens,
                g.hidden,
                g.eps,
            )?;
            Some(n)
        }
        None => None,
    };
    let normed_ref = normed.as_ref().unwrap_or(hidden_states);
    let mut qkv = driver.alloc_zeros_f16(g.total_tokens * 3 * g.hidden)?;
    driver.gemm_f16(
        normed_ref,
        &layer.qkv_weight,
        &mut qkv,
        g.total_tokens,
        3 * g.hidden,
        g.hidden,
        true,
    )?;
    let padded = g.padded_tokens;
    let mut q = driver.alloc_zeros_f16(padded * g.hidden)?;
    let mut k = driver.alloc_zeros_f16(padded * g.hidden)?;
    let mut v = driver.alloc_zeros_f16(padded * g.hidden)?;
    // One fused kernel replaces the FP32 path's separate pad + split steps.
    driver.fused_pad_qkv_split_f16(
        &mut q,
        &mut k,
        &mut v,
        &qkv,
        &g.seq_lengths,
        g.max_seq,
        g.batch,
        g.hidden,
        g.num_heads,
        g.head_dim,
    )?;
    // Rotary embedding is applied to queries and keys only.
    let num_rows = g.batch * g.num_heads * g.max_seq;
    driver.rope_encode_f16(
        &mut q,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    driver.rope_encode_f16(
        &mut k,
        &rope.cos,
        &rope.sin,
        num_rows,
        g.max_seq,
        g.head_dim,
        g.num_heads,
    )?;
    Ok((q, k, v))
}
/// Attention core for one layer on the FP16 path.
///
/// Steps:
/// 1. Per-head scores `q · kᵀ`.
/// 2. Fused scale/mask/softmax. When `layer.is_global` is false, the windowed
///    variant with `g.local_window` is used.
/// 3. Multiply by `v`.
/// 4. A single fused kernel that merges heads and strips padding.
/// 5. Output projection, then add `hidden_states` as the residual.
///
/// `q`, `k`, `v` are the padded per-head tensors from `attn_prenorm_qkv_f16`.
/// `hidden_states` is the packed layer input (`total_tokens * hidden` elements).
/// Returns a newly allocated packed tensor of that same size.
///
/// # Errors
/// Propagates any error from the driver's allocation or kernel calls.
#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
fn attn_scores_residual_f16<D: Driver>(
    driver: &D,
    q: &D::Tensor,
    k: &D::Tensor,
    v: &D::Tensor,
    hidden_states: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    inputs: &BatchInputs<D::Tensor>,
    g: &EncoderGeometry,
) -> crate::Result<D::Tensor> {
    let seq = g.max_seq;
    let dh = g.head_dim;
    let n_mats = g.batch * g.num_heads;
    let head_stride = seq * dh;
    let score_stride = seq * seq;
    let token_elems = g.total_tokens * g.hidden;

    // One `seq x seq` score matrix per (sequence, head) pair.
    let mut probs = driver.alloc_zeros_f16(n_mats * score_stride)?;
    driver.gemm_batched_f16(
        q,
        k,
        &mut probs,
        seq,
        seq,
        dh,
        true,
        head_stride,
        head_stride,
        score_stride,
        n_mats,
    )?;

    // Scale, mask and softmax in place; local layers also restrict to the window.
    if layer.is_global {
        driver.fused_scale_mask_softmax_f16(
            &mut probs,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            seq,
            g.scale,
        )?;
    } else {
        driver.fused_scale_mask_softmax_windowed_f16(
            &mut probs,
            &inputs.float_mask,
            g.batch,
            g.num_heads,
            seq,
            g.scale,
            g.local_window,
        )?;
    }

    // Weighted sum over values, still in the padded per-head layout.
    let mut per_head_out = driver.alloc_zeros_f16(g.padded_tokens * g.hidden)?;
    driver.gemm_batched_f16(
        &probs,
        v,
        &mut per_head_out,
        seq,
        dh,
        seq,
        false,
        score_stride,
        head_stride,
        head_stride,
        n_mats,
    )?;

    // Head merge and unpadding happen in one fused kernel on this path.
    let mut merged_packed = driver.alloc_zeros_f16(token_elems)?;
    driver.fused_reshape_unpad_f16(
        &mut merged_packed,
        &per_head_out,
        &g.seq_lengths,
        seq,
        g.batch,
        g.num_heads,
        dh,
    )?;

    // Output projection followed by the residual connection.
    let mut out_proj = driver.alloc_zeros_f16(token_elems)?;
    driver.gemm_f16(
        &merged_packed,
        &layer.output_weight,
        &mut out_proj,
        g.total_tokens,
        g.hidden,
        g.hidden,
        true,
    )?;
    let mut with_residual = driver.alloc_zeros_f16(token_elems)?;
    driver.residual_add_f16(&mut with_residual, &out_proj, hidden_states, token_elems)?;
    Ok(with_residual)
}
/// Feed-forward half of a layer on the FP16 path.
///
/// Steps:
/// 1. Layer norm.
/// 2. Input projection to `2 * intermediate`.
/// 3. A single fused kernel that splits into value and gate halves and applies
///    `GeGLU`.
/// 4. Output projection back to `hidden`, then add `attn_output` as the residual.
///
/// Returns a newly allocated packed tensor of `total_tokens * hidden` elements.
///
/// # Errors
/// Propagates any error from the driver's allocation or kernel calls.
fn ffn_sublayer_f16<D: Driver>(
    driver: &D,
    attn_output: &D::Tensor,
    layer: &ModernBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
    zero_bias: &D::Tensor,
) -> crate::Result<D::Tensor> {
    let rows = g.total_tokens;
    let token_elems = rows * g.hidden;
    let wide = 2 * g.intermediate;

    let mut pre_ffn = driver.alloc_zeros_f16(token_elems)?;
    driver.layer_norm_f16(
        &mut pre_ffn,
        attn_output,
        &layer.mlp_norm_weight,
        zero_bias,
        rows,
        g.hidden,
        g.eps,
    )?;

    // Single projection yields both halves consumed by the fused gated activation.
    let mut up = driver.alloc_zeros_f16(rows * wide)?;
    driver.gemm_f16(&pre_ffn, &layer.mlp_wi_weight, &mut up, rows, wide, g.hidden, true)?;

    let mut gated = driver.alloc_zeros_f16(rows * g.intermediate)?;
    driver.fused_split_geglu_f16(&mut gated, &up, rows, g.intermediate)?;

    let mut down = driver.alloc_zeros_f16(token_elems)?;
    driver.gemm_f16(
        &gated,
        &layer.mlp_wo_weight,
        &mut down,
        rows,
        g.hidden,
        g.intermediate,
        true,
    )?;

    let mut summed = driver.alloc_zeros_f16(token_elems)?;
    driver.residual_add_f16(&mut summed, &down, attn_output, token_elems)?;
    Ok(summed)
}
impl<D: Driver> ModelArch<D> for ModernBertArch<D::Tensor> {
    /// Encodes a batch of tokenized inputs into pooled, L2-normalized embeddings.
    ///
    /// Stages:
    /// 1. Token embedding lookup.
    /// 2. Embedding layer norm.
    /// 3. The encoder layers. These run in FP16 when the driver can allocate
    ///    FP16 tensors and neither `RIPVEC_NO_MPS=1` nor `RIPVEC_FP32=1` is set;
    ///    otherwise they run in FP32.
    /// 4. Final layer norm.
    /// 5. Re-padding to `[batch, max_seq]`.
    /// 6. Masked mean pooling.
    /// 7. L2 normalization.
    ///
    /// The host result comes from `Driver::to_host` with `batch` rows of
    /// `hidden_dim` values. NOTE(review): presumably one row per encoding in
    /// input order — confirm against `to_host`.
    ///
    /// An empty `encodings` slice returns an empty result without issuing any
    /// driver work.
    ///
    /// # Errors
    /// Propagates any error from batch preparation, allocation, or kernel calls.
    #[expect(
        clippy::cast_precision_loss,
        reason = "head_dim is small (64); sqrt is exact at this size"
    )]
    #[expect(
        clippy::many_single_char_names,
        reason = "w, g are standard geometry names; q, k, v are standard attention names"
    )]
    #[expect(
        clippy::too_many_lines,
        reason = "forward pass is a single logical unit"
    )]
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
        // Nothing to encode: return before any driver work, which would
        // otherwise be asked for zero-sized buffers.
        if encodings.is_empty() {
            return Ok(Vec::new());
        }
        let w = &self.weights;
        let batch = encodings.len();
        let hidden = w.hidden_dim;
        let inputs = driver.prepare_batch_unpadded(encodings)?;
        let max_seq = inputs.max_seq;
        let total_tokens = inputs.total_tokens;
        // NOTE(review): an error propagated by `?` anywhere below returns without
        // calling `end_batch` (and, inside the layer loops, without
        // `restore_pool_cursor`); confirm the driver tolerates that.
        driver.begin_batch()?;
        let mut hidden_states =
            driver.embedding_lookup(&inputs.input_ids, &w.tok_embeddings, total_tokens, hidden)?;
        // Layer norm writes into a separate output, so normalize from a copy.
        let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
        driver.layer_norm(
            &mut hidden_states,
            &emb_input,
            &w.emb_norm_weight,
            &w.zero_bias,
            total_tokens,
            hidden,
            w.layer_norm_eps,
        )?;
        let g = EncoderGeometry {
            batch,
            max_seq,
            total_tokens,
            padded_tokens: batch * max_seq,
            seq_lengths: inputs.seq_lengths.clone(),
            hidden,
            num_heads: w.num_heads,
            head_dim: w.head_dim,
            intermediate: w.intermediate_dim,
            local_window: w.local_window,
            scale: 1.0 / (w.head_dim as f32).sqrt(),
            eps: w.layer_norm_eps,
        };
        let force_fp32 = std::env::var("RIPVEC_NO_MPS").is_ok_and(|v| v == "1")
            || std::env::var("RIPVEC_FP32").is_ok_and(|v| v == "1");
        // Probe FP16 support with a one-element allocation. `&&` short-circuits,
        // so no probe is issued when FP32 is forced; a failed probe means FP32.
        let use_f16 = !force_fp32 && driver.alloc_zeros_f16(1).is_ok();
        if use_f16 {
            let mut hidden_f16 = driver.alloc_zeros_f16(total_tokens * hidden)?;
            driver.f32_to_f16(&mut hidden_f16, &hidden_states, total_tokens * hidden)?;
            for layer in &w.layers {
                // NOTE(review): presumably rewinds per-layer pool allocations;
                // the layer output is allocated after `saved` and must survive
                // until the next layer has read it — confirm pool semantics.
                let saved = driver.save_pool_cursor();
                let rope = if layer.is_global {
                    &self.global_rope
                } else {
                    &self.local_rope
                };
                let (q, k, v) =
                    attn_prenorm_qkv_f16(driver, &hidden_f16, layer, &g, &w.zero_bias, rope)?;
                let attn_output =
                    attn_scores_residual_f16(driver, &q, &k, &v, &hidden_f16, layer, &inputs, &g)?;
                hidden_f16 = ffn_sublayer_f16(driver, &attn_output, layer, &g, &w.zero_bias)?;
                driver.restore_pool_cursor(saved);
            }
            let mut hidden_f32 = driver.alloc_zeros(total_tokens * hidden)?;
            driver.f16_to_f32(&mut hidden_f32, &hidden_f16, total_tokens * hidden)?;
            hidden_states = hidden_f32;
        } else {
            for layer in &w.layers {
                let saved = driver.save_pool_cursor();
                let rope = if layer.is_global {
                    &self.global_rope
                } else {
                    &self.local_rope
                };
                let (q, k, v) =
                    attn_prenorm_qkv(driver, &hidden_states, layer, &g, &w.zero_bias, rope)?;
                let attn_output =
                    attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
                hidden_states = ffn_sublayer(driver, &attn_output, layer, &g, &w.zero_bias)?;
                driver.restore_pool_cursor(saved);
            }
        }
        let final_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
        driver.layer_norm(
            &mut hidden_states,
            &final_input,
            &w.final_norm_weight,
            &w.zero_bias,
            total_tokens,
            hidden,
            w.layer_norm_eps,
        )?;
        // Pooling works on the padded layout with the pooling mask.
        let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
        driver.pad_to_batch(
            &hidden_states,
            &mut padded_for_pool,
            &inputs.seq_lengths,
            max_seq,
            hidden,
        )?;
        let mut pooled = driver.alloc_zeros(batch * hidden)?;
        driver.mean_pool(
            &mut pooled,
            &padded_for_pool,
            &inputs.pooling_mask,
            batch,
            max_seq,
            hidden,
        )?;
        driver.l2_normalize(&mut pooled, batch, hidden)?;
        driver.end_batch()?;
        driver.to_host(&pooled, batch, hidden)
    }
}