use super::super::Encoding;
use super::super::driver::{BatchInputs, Driver};
use super::ModelArch;
/// Parameters of a single classic BERT encoder layer.
///
/// Generic over the tensor handle type `T`; the encoder code in this file
/// instantiates it with a driver's `Tensor` type.
pub struct ClassicBertLayerWeights<T> {
    // ---- self-attention sub-layer ----
    /// Weight used by the fused query/key/value projection GEMM.
    pub qkv_weight: T,
    /// Bias added to the fused query/key/value projection.
    pub qkv_bias: T,
    /// Weight used by the attention output projection GEMM.
    pub output_weight: T,
    /// Bias added to the attention output projection.
    pub output_bias: T,
    /// Layer-norm weight for the residual layer norm that closes the
    /// attention sub-layer.
    pub output_ln_weight: T,
    /// Layer-norm bias for the residual layer norm that closes the
    /// attention sub-layer.
    pub output_ln_bias: T,

    // ---- feed-forward sub-layer ----
    /// Weight of the first feed-forward GEMM (hidden -> intermediate).
    pub ffn_inter_weight: T,
    /// Bias passed to the fused bias + GELU step after the first
    /// feed-forward GEMM.
    pub ffn_inter_bias: T,
    /// Weight of the second feed-forward GEMM (intermediate -> hidden).
    pub ffn_out_weight: T,
    /// Bias added after the second feed-forward GEMM.
    pub ffn_out_bias: T,
    /// Layer-norm weight for the residual layer norm that closes the
    /// feed-forward sub-layer.
    pub ffn_ln_weight: T,
    /// Layer-norm bias for the residual layer norm that closes the
    /// feed-forward sub-layer.
    pub ffn_ln_bias: T,
}
/// Complete parameter set and shape configuration for a classic BERT encoder.
///
/// Generic over the tensor handle type `T`.
pub struct ClassicBertWeights<T> {
    // ---- embedding tables and embedding layer norm ----
    /// Table indexed by input token ids.
    pub word_embeddings: T,
    /// Table indexed by position ids.
    pub position_embeddings: T,
    /// Table indexed by token-type ids.
    pub token_type_embeddings: T,
    /// Weight of the layer norm applied to the summed embeddings.
    pub emb_ln_weight: T,
    /// Bias of the layer norm applied to the summed embeddings.
    pub emb_ln_bias: T,

    /// Encoder layers, applied in order.
    pub layers: Vec<ClassicBertLayerWeights<T>>,

    // ---- shape configuration ----
    /// Number of attention heads.
    pub num_heads: usize,
    /// Width of one attention head.
    // NOTE(review): the attention code presumably needs
    // `num_heads * head_dim == hidden_dim`; nothing in this file checks it.
    pub head_dim: usize,
    /// Width of the hidden representation.
    pub hidden_dim: usize,
    /// Width of the feed-forward intermediate activation.
    pub intermediate_dim: usize,
    /// Epsilon passed to every layer-norm call.
    pub layer_norm_eps: f32,
}
/// Classic BERT encoder architecture: a thin wrapper that owns the model's
/// weights, generic over the tensor handle type `T`.
pub struct ClassicBertArch<T> {
    /// Embedding tables, per-layer parameters and shape configuration.
    pub weights: ClassicBertWeights<T>,
}
/// Shape and scalar parameters shared by every encoder layer within one
/// forward pass. Built once per batch and passed by reference to the
/// per-layer helpers.
struct EncoderGeometry {
    /// Number of sequences in the batch.
    batch: usize,
    /// Sequence length that every sequence is padded to for attention.
    max_seq: usize,
    /// Token count of the unpadded batch, as reported by batch preparation.
    total_tokens: usize,
    /// Token count of the padded layout (`batch * max_seq`).
    padded_tokens: usize,
    /// Per-sequence lengths, handed to the pad / unpad driver calls.
    seq_lengths: Vec<usize>,
    /// Hidden width.
    hidden: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of one attention head.
    head_dim: usize,
    /// Width of the feed-forward intermediate activation.
    intermediate: usize,
    /// Attention score scale (`1 / sqrt(head_dim)`).
    scale: f32,
    /// Layer-norm epsilon.
    eps: f32,
}
/// Computes the query, key and value tensors for one encoder layer.
///
/// All work is issued through `driver`, in this order:
///
/// 1. A GEMM of `hidden_states` with `layer.qkv_weight` into a fresh buffer
///    of `g.total_tokens * 3 * g.hidden` elements, followed by adding
///    `layer.qkv_bias`.
/// 2. `pad_to_batch` into a buffer of `g.padded_tokens * 3 * g.hidden`
///    elements, using `g.seq_lengths` and `g.max_seq`.
/// 3. `qkv_split` of the padded buffer into three fresh buffers of
///    `g.padded_tokens * g.hidden` elements each.
///
/// Returns the three buffers as `(q, k, v)`.
///
/// # Errors
///
/// Returns the first error produced by any driver call.
fn attn_qkv<D: Driver>(
    driver: &D,
    hidden_states: &D::Tensor,
    layer: &ClassicBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
    let hidden = g.hidden;
    let tokens = g.total_tokens;

    // Fused projection over the unpadded token stream.
    let mut fused = driver.alloc_zeros(tokens * 3 * hidden)?;
    let fused_width = 3 * hidden;
    driver.gemm(
        hidden_states,
        &layer.qkv_weight,
        &mut fused,
        tokens,
        fused_width,
        hidden,
        true,
    )?;
    driver.add_bias(&mut fused, &layer.qkv_bias, tokens, fused_width)?;

    // Move to the padded layout before splitting.
    let mut fused_padded = driver.alloc_zeros(g.padded_tokens * 3 * hidden)?;
    driver.pad_to_batch(
        &fused,
        &mut fused_padded,
        &g.seq_lengths,
        g.max_seq,
        fused_width,
    )?;

    // Allocation order (query, key, value) is kept deliberately: the driver
    // may hand out buffers from a cursor-based pool.
    let per_tensor = g.padded_tokens * hidden;
    let mut query = driver.alloc_zeros(per_tensor)?;
    let mut key = driver.alloc_zeros(per_tensor)?;
    let mut value = driver.alloc_zeros(per_tensor)?;
    driver.qkv_split(
        &mut query,
        &mut key,
        &mut value,
        &fused_padded,
        g.batch,
        g.max_seq,
        hidden,
        g.num_heads,
        g.head_dim,
    )?;

    Ok((query, key, value))
}
/// Runs the attention core for one layer and closes the attention sub-layer
/// with its residual layer norm.
///
/// All work is issued through `driver`, in this order:
///
/// 1. Batched GEMM of `q` and `k` into a scores buffer of
///    `batch * num_heads * max_seq * max_seq` elements, then
///    `fused_scale_mask_softmax` with `inputs.float_mask` and `g.scale`.
/// 2. Batched GEMM of the scores with `v`, then `attn_reshape` into a
///    context buffer of `g.padded_tokens * g.hidden` elements.
/// 3. GEMM with `layer.output_weight` over the padded rows,
///    `unpad_from_batch` back to `g.total_tokens` rows, then adding
///    `layer.output_bias`.
/// 4. `fused_residual_layernorm` of the projection with `hidden_states` as
///    the residual, using `layer.output_ln_weight` / `layer.output_ln_bias`
///    and `g.eps`.
///
/// Returns the layer-normed tensor (`g.total_tokens * g.hidden` elements).
///
/// # Errors
///
/// Returns the first error produced by any driver call.
#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
fn attn_scores_residual<D: Driver>(
    driver: &D,
    q: &D::Tensor,
    k: &D::Tensor,
    v: &D::Tensor,
    hidden_states: &D::Tensor,
    layer: &ClassicBertLayerWeights<D::Tensor>,
    inputs: &BatchInputs<D::Tensor>,
    g: &EncoderGeometry,
) -> crate::Result<D::Tensor> {
    let seq = g.max_seq;
    let hidden = g.hidden;
    let tokens = g.total_tokens;
    let padded_len = g.padded_tokens * hidden;

    // Scores: Q against K with the transpose flag set.
    let mut scores = driver.alloc_zeros(g.batch * g.num_heads * seq * seq)?;
    let head_stride = seq * g.head_dim;
    let score_stride = seq * seq;
    let head_count = g.batch * g.num_heads;
    driver.gemm_batched(
        q,
        k,
        &mut scores,
        seq,
        seq,
        g.head_dim,
        true,
        head_stride,
        head_stride,
        score_stride,
        head_count,
    )?;
    driver.fused_scale_mask_softmax(
        &mut scores,
        &inputs.float_mask,
        g.batch,
        g.num_heads,
        seq,
        g.scale,
    )?;

    // Weighted values: scores against V, no transpose.
    let mut weighted = driver.alloc_zeros(padded_len)?;
    driver.gemm_batched(
        &scores,
        v,
        &mut weighted,
        seq,
        g.head_dim,
        seq,
        false,
        score_stride,
        head_stride,
        head_stride,
        head_count,
    )?;

    let mut context = driver.alloc_zeros(padded_len)?;
    driver.attn_reshape(
        &mut context,
        &weighted,
        g.batch,
        seq,
        g.num_heads,
        g.head_dim,
    )?;

    // Output projection runs on the padded layout; the bias is added only
    // after the result has been unpadded.
    let mut projected_padded = driver.alloc_zeros(padded_len)?;
    driver.gemm(
        &context,
        &layer.output_weight,
        &mut projected_padded,
        g.padded_tokens,
        hidden,
        hidden,
        true,
    )?;

    let mut projected = driver.alloc_zeros(tokens * hidden)?;
    driver.unpad_from_batch(
        &projected_padded,
        &mut projected,
        &g.seq_lengths,
        seq,
        hidden,
    )?;
    driver.add_bias(&mut projected, &layer.output_bias, tokens, hidden)?;

    let mut normed = driver.alloc_zeros(tokens * hidden)?;
    driver.fused_residual_layernorm(
        &mut normed,
        &projected,
        hidden_states,
        &layer.output_ln_weight,
        &layer.output_ln_bias,
        tokens,
        hidden,
        g.eps,
    )?;

    Ok(normed)
}
/// Runs the feed-forward sub-layer of one encoder layer.
///
/// All work is issued through `driver`, in this order:
///
/// 1. GEMM of `attn_output` with `layer.ffn_inter_weight` into a buffer of
///    `g.total_tokens * g.intermediate` elements, then `fused_bias_gelu`
///    with `layer.ffn_inter_bias`.
/// 2. GEMM with `layer.ffn_out_weight` back to `g.total_tokens * g.hidden`
///    elements, then adding `layer.ffn_out_bias`.
/// 3. `fused_residual_layernorm` with `attn_output` as the residual, using
///    `layer.ffn_ln_weight` / `layer.ffn_ln_bias` and `g.eps`.
///
/// Returns the layer-normed tensor (`g.total_tokens * g.hidden` elements).
///
/// # Errors
///
/// Returns the first error produced by any driver call.
fn ffn_sublayer<D: Driver>(
    driver: &D,
    attn_output: &D::Tensor,
    layer: &ClassicBertLayerWeights<D::Tensor>,
    g: &EncoderGeometry,
) -> crate::Result<D::Tensor> {
    let tokens = g.total_tokens;
    let hidden = g.hidden;
    let inner = g.intermediate;

    // Expand: hidden -> intermediate, with bias and GELU fused by the driver.
    let mut expanded = driver.alloc_zeros(tokens * inner)?;
    driver.gemm(
        attn_output,
        &layer.ffn_inter_weight,
        &mut expanded,
        tokens,
        inner,
        hidden,
        true,
    )?;
    driver.fused_bias_gelu(&mut expanded, &layer.ffn_inter_bias, tokens, inner)?;

    // Contract: intermediate -> hidden.
    let mut contracted = driver.alloc_zeros(tokens * hidden)?;
    driver.gemm(
        &expanded,
        &layer.ffn_out_weight,
        &mut contracted,
        tokens,
        hidden,
        inner,
        true,
    )?;
    driver.add_bias(&mut contracted, &layer.ffn_out_bias, tokens, hidden)?;

    // Residual connection around the whole feed-forward block.
    let mut normed = driver.alloc_zeros(tokens * hidden)?;
    driver.fused_residual_layernorm(
        &mut normed,
        &contracted,
        attn_output,
        &layer.ffn_ln_weight,
        &layer.ffn_ln_bias,
        tokens,
        hidden,
        g.eps,
    )?;

    Ok(normed)
}
impl<D: Driver> ModelArch<D> for ClassicBertArch<D::Tensor> {
    /// Encodes a batch of tokenised inputs into one L2-normalised vector per
    /// input.
    ///
    /// An empty `encodings` slice yields an empty result without touching
    /// the driver. Otherwise the pipeline is:
    ///
    /// 1. `prepare_batch_unpadded`, then `begin_batch`.
    /// 2. Embeddings: word lookup plus position and token-type embeddings,
    ///    followed by the embedding layer norm.
    /// 3. For every layer: Q/K/V projection, attention with residual layer
    ///    norm, feed-forward with residual layer norm. Each layer is
    ///    bracketed by `save_pool_cursor` / `restore_pool_cursor`.
    /// 4. Pad the final hidden states to `batch * max_seq` rows, `cls_pool`,
    ///    `l2_normalize`, `end_batch`, and copy to the host as `batch` rows
    ///    of `hidden_dim` values.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any driver call.
    // NOTE(review): an error after `begin_batch` returns early, so neither
    // `end_batch` nor the pending `restore_pool_cursor` runs on that path.
    // Whether the driver tolerates this is not visible from this file.
    #[expect(
        clippy::cast_precision_loss,
        reason = "head_dim is small (32-64); sqrt is exact at these sizes"
    )]
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
        // Nothing to embed: skip batch preparation and the zero-sized
        // allocations and kernel launches that would otherwise follow.
        if encodings.is_empty() {
            return Ok(Vec::new());
        }

        let w = &self.weights;
        let batch = encodings.len();
        let hidden = w.hidden_dim;

        let inputs = driver.prepare_batch_unpadded(encodings)?;
        let total_tokens = inputs.total_tokens;
        let max_seq = inputs.max_seq;

        driver.begin_batch()?;

        // Sum of word, position and token-type embeddings on the unpadded
        // token stream.
        let mut hidden_states =
            driver.embedding_lookup(&inputs.input_ids, &w.word_embeddings, total_tokens, hidden)?;
        driver.add_embeddings(
            &mut hidden_states,
            &w.position_embeddings,
            &inputs.position_ids,
            total_tokens,
            hidden,
        )?;
        driver.add_embeddings(
            &mut hidden_states,
            &w.token_type_embeddings,
            &inputs.token_type_ids,
            total_tokens,
            hidden,
        )?;

        // `layer_norm` takes distinct output and input tensors, so the
        // summed embeddings are copied before being normalised in place.
        let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
        driver.layer_norm(
            &mut hidden_states,
            &emb_input,
            &w.emb_ln_weight,
            &w.emb_ln_bias,
            total_tokens,
            hidden,
            w.layer_norm_eps,
        )?;

        let g = EncoderGeometry {
            batch,
            max_seq,
            total_tokens,
            padded_tokens: batch * max_seq,
            seq_lengths: inputs.seq_lengths.clone(),
            hidden,
            num_heads: w.num_heads,
            head_dim: w.head_dim,
            intermediate: w.intermediate_dim,
            scale: 1.0 / (w.head_dim as f32).sqrt(),
            eps: w.layer_norm_eps,
        };

        for layer in &w.layers {
            // Presumably lets the driver recycle per-layer scratch buffers;
            // confirm against the driver's pool implementation.
            let saved = driver.save_pool_cursor();
            let (q, k, v) = attn_qkv(driver, &hidden_states, layer, &g)?;
            let attn_output =
                attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
            hidden_states = ffn_sublayer(driver, &attn_output, layer, &g)?;
            driver.restore_pool_cursor(saved);
        }

        // Pooling works on the padded layout.
        let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
        driver.pad_to_batch(
            &hidden_states,
            &mut padded_for_pool,
            &inputs.seq_lengths,
            max_seq,
            hidden,
        )?;

        let mut pooled = driver.alloc_zeros(batch * hidden)?;
        driver.cls_pool(&mut pooled, &padded_for_pool, batch, max_seq, hidden)?;
        driver.l2_normalize(&mut pooled, batch, hidden)?;

        driver.end_batch()?;
        driver.to_host(&pooled, batch, hidden)
    }
}