1use super::super::Encoding;
14use super::super::driver::{BatchInputs, Driver};
15use super::ModelArch;
16
/// Per-layer weights for one classic (post-LayerNorm) BERT encoder layer.
///
/// `T` is the driver's tensor handle type; exact layouts are whatever the
/// `Driver` kernels expect — see the forward helpers for how each tensor
/// is consumed.
pub struct ClassicBertLayerWeights<T> {
    /// Fused Q/K/V projection weight, consumed by a single GEMM producing
    /// `3 * hidden` output columns.
    pub qkv_weight: T,
    /// Fused Q/K/V projection bias (`3 * hidden` values).
    pub qkv_bias: T,
    /// Attention output projection weight (`hidden -> hidden`).
    pub output_weight: T,
    /// Attention output projection bias.
    pub output_bias: T,
    /// LayerNorm gamma applied after the attention residual add.
    pub output_ln_weight: T,
    /// LayerNorm beta applied after the attention residual add.
    pub output_ln_bias: T,
    /// FFN up-projection weight (`hidden -> intermediate`).
    pub ffn_inter_weight: T,
    /// FFN up-projection bias (fused with GELU in the forward pass).
    pub ffn_inter_bias: T,
    /// FFN down-projection weight (`intermediate -> hidden`).
    pub ffn_out_weight: T,
    /// FFN down-projection bias.
    pub ffn_out_bias: T,
    /// LayerNorm gamma applied after the FFN residual add.
    pub ffn_ln_weight: T,
    /// LayerNorm beta applied after the FFN residual add.
    pub ffn_ln_bias: T,
}
47
/// Full weight set plus static geometry for a classic BERT encoder stack.
pub struct ClassicBertWeights<T> {
    /// Token embedding table, indexed by input id.
    pub word_embeddings: T,
    /// Position embedding table, indexed by position id.
    pub position_embeddings: T,
    /// Segment (token-type) embedding table, indexed by token-type id.
    pub token_type_embeddings: T,
    /// LayerNorm gamma applied to the summed embeddings.
    pub emb_ln_weight: T,
    /// LayerNorm beta applied to the summed embeddings.
    pub emb_ln_bias: T,
    /// Encoder layers, applied in order.
    pub layers: Vec<ClassicBertLayerWeights<T>>,
    /// Number of attention heads per layer.
    pub num_heads: usize,
    /// Per-head dimension; attention scores are scaled by `1 / sqrt(head_dim)`.
    pub head_dim: usize,
    /// Model hidden size.
    pub hidden_dim: usize,
    /// FFN intermediate (up-projection) size.
    pub intermediate_dim: usize,
    /// Epsilon used by every LayerNorm in the model.
    pub layer_norm_eps: f32,
}
77
/// Classic BERT architecture: owns the weights and implements [`ModelArch`]
/// (see the `impl` below) to run the encoder forward pass.
pub struct ClassicBertArch<T> {
    /// Complete weight set and geometry for the encoder stack.
    pub weights: ClassicBertWeights<T>,
}
86
/// Per-batch shape constants threaded through the encoder helpers so each
/// function takes one parameter instead of ten.
struct EncoderGeometry {
    /// Number of sequences in the batch.
    batch: usize,
    /// Longest sequence length in the batch.
    max_seq: usize,
    /// Sum of all sequence lengths (unpadded token count).
    total_tokens: usize,
    /// `batch * max_seq` — token count after padding to a dense batch.
    padded_tokens: usize,
    /// Individual sequence lengths, used for pad/unpad scatter-gather.
    seq_lengths: Vec<usize>,
    /// Model hidden size.
    hidden: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Per-head dimension.
    head_dim: usize,
    /// FFN intermediate size.
    intermediate: usize,
    /// Attention score scale, `1 / sqrt(head_dim)`.
    scale: f32,
    /// LayerNorm epsilon.
    eps: f32,
}
104
105fn attn_qkv<D: Driver>(
109 driver: &D,
110 hidden_states: &D::Tensor,
111 layer: &ClassicBertLayerWeights<D::Tensor>,
112 g: &EncoderGeometry,
113) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
114 let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
117 driver.gemm(
118 hidden_states,
119 &layer.qkv_weight,
120 &mut qkv,
121 g.total_tokens,
122 3 * g.hidden,
123 g.hidden,
124 true,
125 )?;
126 driver.add_bias(&mut qkv, &layer.qkv_bias, g.total_tokens, 3 * g.hidden)?;
127
128 let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
131 driver.pad_to_batch(
132 &qkv,
133 &mut qkv_padded,
134 &g.seq_lengths,
135 g.max_seq,
136 3 * g.hidden,
137 )?;
138
139 let padded = g.padded_tokens;
141 let mut q = driver.alloc_zeros(padded * g.hidden)?;
142 let mut k = driver.alloc_zeros(padded * g.hidden)?;
143 let mut v = driver.alloc_zeros(padded * g.hidden)?;
144 driver.qkv_split(
145 &mut q,
146 &mut k,
147 &mut v,
148 &qkv_padded,
149 g.batch,
150 g.max_seq,
151 g.hidden,
152 g.num_heads,
153 g.head_dim,
154 )?;
155
156 Ok((q, k, v))
157}
158
159#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
161fn attn_scores_residual<D: Driver>(
162 driver: &D,
163 q: &D::Tensor,
164 k: &D::Tensor,
165 v: &D::Tensor,
166 hidden_states: &D::Tensor,
167 layer: &ClassicBertLayerWeights<D::Tensor>,
168 inputs: &BatchInputs<D::Tensor>,
169 g: &EncoderGeometry,
170) -> crate::Result<D::Tensor> {
171 let padded = g.padded_tokens;
172
173 let mut scores = driver.alloc_zeros(g.batch * g.num_heads * g.max_seq * g.max_seq)?;
175 driver.gemm_batched(
176 q,
177 k,
178 &mut scores,
179 g.max_seq,
180 g.max_seq,
181 g.head_dim,
182 true,
183 g.max_seq * g.head_dim,
184 g.max_seq * g.head_dim,
185 g.max_seq * g.max_seq,
186 g.batch * g.num_heads,
187 )?;
188 driver.fused_scale_mask_softmax(
189 &mut scores,
190 &inputs.float_mask,
191 g.batch,
192 g.num_heads,
193 g.max_seq,
194 g.scale,
195 )?;
196
197 let mut attn_out = driver.alloc_zeros(padded * g.hidden)?;
199 driver.gemm_batched(
200 &scores,
201 v,
202 &mut attn_out,
203 g.max_seq,
204 g.head_dim,
205 g.max_seq,
206 false,
207 g.max_seq * g.max_seq,
208 g.max_seq * g.head_dim,
209 g.max_seq * g.head_dim,
210 g.batch * g.num_heads,
211 )?;
212
213 let mut context = driver.alloc_zeros(padded * g.hidden)?;
215 driver.attn_reshape(
216 &mut context,
217 &attn_out,
218 g.batch,
219 g.max_seq,
220 g.num_heads,
221 g.head_dim,
222 )?;
223
224 let mut projected_padded = driver.alloc_zeros(padded * g.hidden)?;
226 driver.gemm(
227 &context,
228 &layer.output_weight,
229 &mut projected_padded,
230 padded,
231 g.hidden,
232 g.hidden,
233 true,
234 )?;
235
236 let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
238 driver.unpad_from_batch(
239 &projected_padded,
240 &mut projected,
241 &g.seq_lengths,
242 g.max_seq,
243 g.hidden,
244 )?;
245
246 driver.add_bias(&mut projected, &layer.output_bias, g.total_tokens, g.hidden)?;
247
248 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
249 driver.fused_residual_layernorm(
250 &mut output,
251 &projected,
252 hidden_states,
253 &layer.output_ln_weight,
254 &layer.output_ln_bias,
255 g.total_tokens,
256 g.hidden,
257 g.eps,
258 )?;
259 Ok(output)
260}
261
262fn ffn_sublayer<D: Driver>(
266 driver: &D,
267 attn_output: &D::Tensor,
268 layer: &ClassicBertLayerWeights<D::Tensor>,
269 g: &EncoderGeometry,
270) -> crate::Result<D::Tensor> {
271 let mut intermediate = driver.alloc_zeros(g.total_tokens * g.intermediate)?;
273 driver.gemm(
274 attn_output,
275 &layer.ffn_inter_weight,
276 &mut intermediate,
277 g.total_tokens,
278 g.intermediate,
279 g.hidden,
280 true,
281 )?;
282 driver.fused_bias_gelu(
283 &mut intermediate,
284 &layer.ffn_inter_bias,
285 g.total_tokens,
286 g.intermediate,
287 )?;
288
289 let mut ffn_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
291 driver.gemm(
292 &intermediate,
293 &layer.ffn_out_weight,
294 &mut ffn_out,
295 g.total_tokens,
296 g.hidden,
297 g.intermediate,
298 true,
299 )?;
300 driver.add_bias(&mut ffn_out, &layer.ffn_out_bias, g.total_tokens, g.hidden)?;
301
302 let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
303 driver.fused_residual_layernorm(
304 &mut output,
305 &ffn_out,
306 attn_output,
307 &layer.ffn_ln_weight,
308 &layer.ffn_ln_bias,
309 g.total_tokens,
310 g.hidden,
311 g.eps,
312 )?;
313 Ok(output)
314}
315
impl<D: Driver> ModelArch<D> for ClassicBertArch<D::Tensor> {
    #[expect(
        clippy::cast_precision_loss,
        reason = "head_dim is small (32-64); sqrt is exact at these sizes"
    )]
    /// Runs the full encoder over a batch of encodings and returns one
    /// L2-normalized pooled embedding (length `hidden_dim`) per input.
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
        let w = &self.weights;
        let batch = encodings.len();
        let hidden = w.hidden_dim;

        // Tokens are kept unpadded (packed back-to-back) for all token-wise
        // ops; attention helpers pad to [batch, max_seq] only where needed.
        let inputs = driver.prepare_batch_unpadded(encodings)?;
        let total_tokens = inputs.total_tokens;
        let max_seq = inputs.max_seq;

        driver.begin_batch()?;

        // Embeddings: word + position + token-type lookups summed in place.
        let mut hidden_states =
            driver.embedding_lookup(&inputs.input_ids, &w.word_embeddings, total_tokens, hidden)?;
        driver.add_embeddings(
            &mut hidden_states,
            &w.position_embeddings,
            &inputs.position_ids,
            total_tokens,
            hidden,
        )?;
        driver.add_embeddings(
            &mut hidden_states,
            &w.token_type_embeddings,
            &inputs.token_type_ids,
            total_tokens,
            hidden,
        )?;
        // layer_norm reads its input and output as separate tensors, so
        // snapshot the summed embeddings before normalizing in place.
        let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
        driver.layer_norm(
            &mut hidden_states,
            &emb_input,
            &w.emb_ln_weight,
            &w.emb_ln_bias,
            total_tokens,
            hidden,
            w.layer_norm_eps,
        )?;

        let g = EncoderGeometry {
            batch,
            max_seq,
            total_tokens,
            padded_tokens: batch * max_seq,
            seq_lengths: inputs.seq_lengths.clone(),
            hidden,
            num_heads: w.num_heads,
            head_dim: w.head_dim,
            intermediate: w.intermediate_dim,
            // Standard attention scaling: 1 / sqrt(head_dim).
            scale: 1.0 / (w.head_dim as f32).sqrt(),
            eps: w.layer_norm_eps,
        };

        for layer in &w.layers {
            // Checkpoint the scratch pool so each layer's temporaries are
            // reclaimed after the iteration.
            // NOTE(review): `hidden_states` produced by `ffn_sublayer` is
            // itself allocated after `saved`, so the restore appears to roll
            // the cursor back past it. This is only sound if the driver pool
            // keeps rolled-back memory valid until the next allocation
            // overwrites it, and the next layer consumes `hidden_states`
            // before reusing that region — confirm against the pool impl.
            let saved = driver.save_pool_cursor();
            let (q, k, v) = attn_qkv(driver, &hidden_states, layer, &g)?;
            let attn_output =
                attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
            hidden_states = ffn_sublayer(driver, &attn_output, layer, &g)?;
            driver.restore_pool_cursor(saved);
        }

        // cls_pool operates on the dense [batch, max_seq, hidden] layout, so
        // pad the packed tokens back out before pooling.
        let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
        driver.pad_to_batch(
            &hidden_states,
            &mut padded_for_pool,
            &inputs.seq_lengths,
            max_seq,
            hidden,
        )?;

        // Pool one vector per sequence (presumably the CLS/first token —
        // driver-defined), then L2-normalize for cosine-similarity use.
        let mut pooled = driver.alloc_zeros(batch * hidden)?;
        driver.cls_pool(&mut pooled, &padded_for_pool, batch, max_seq, hidden)?;
        driver.l2_normalize(&mut pooled, batch, hidden)?;

        driver.end_batch()?;

        driver.to_host(&pooled, batch, hidden)
    }
}