//! `ClassicBert` architecture (BGE-small-en-v1.5).
//!
//! 12-layer BERT with learned position embeddings, GELU activation, fused QKV
//! projections, and CLS pooling. This is the original BERT architecture used
//! by BGE-small.
//!
//! Weight structures are generic over the tensor type `T`, which is
//! [`Driver::Tensor`](super::super::driver::Driver::Tensor) when wired to a
//! backend. The [`ModelArch`](super::ModelArch) implementation composes
//! [`Driver`](super::super::driver::Driver) primitives into the full forward
//! pass.

use super::super::Encoding;
use super::super::driver::{BatchInputs, Driver};
use super::ModelArch;
16
/// Weights for one `ClassicBert` encoder layer.
///
/// All projections include bias (unlike `ModernBERT`). The QKV weight is a fused
/// `[3*hidden, hidden]` matrix that produces Q, K, V in a single GEMM; the
/// fused tensor is later split per-head by `Driver::qkv_split`.
pub struct ClassicBertLayerWeights<T> {
    /// Fused Q+K+V projection weight `[3*hidden, hidden]`.
    pub qkv_weight: T,
    /// Fused Q+K+V projection bias `[3*hidden]`.
    pub qkv_bias: T,
    /// Attention output projection weight `[hidden, hidden]`.
    pub output_weight: T,
    /// Attention output projection bias `[hidden]`.
    pub output_bias: T,
    /// Post-attention `LayerNorm` weight `[hidden]`.
    pub output_ln_weight: T,
    /// Post-attention `LayerNorm` bias `[hidden]`.
    pub output_ln_bias: T,
    /// FFN intermediate (up-projection) weight `[intermediate, hidden]`.
    pub ffn_inter_weight: T,
    /// FFN intermediate projection bias `[intermediate]`; GELU is applied after it.
    pub ffn_inter_bias: T,
    /// FFN output (down-projection) weight `[hidden, intermediate]`.
    pub ffn_out_weight: T,
    /// FFN output projection bias `[hidden]`.
    pub ffn_out_bias: T,
    /// Post-FFN `LayerNorm` weight `[hidden]`.
    pub ffn_ln_weight: T,
    /// Post-FFN `LayerNorm` bias `[hidden]`.
    pub ffn_ln_bias: T,
}
47
/// Full `ClassicBert` model weights, generic over tensor type.
///
/// Includes embedding tables, per-layer encoder weights, and model geometry.
/// The tensor type `T` becomes [`Driver::Tensor`](super::super::driver::Driver::Tensor)
/// when loaded onto a specific backend.
pub struct ClassicBertWeights<T> {
    /// Word embedding table `[vocab_size, hidden]`.
    pub word_embeddings: T,
    /// Learned position embedding table `[max_position, hidden]`.
    pub position_embeddings: T,
    /// Token type (segment) embedding table `[2, hidden]`.
    pub token_type_embeddings: T,
    /// Post-embedding `LayerNorm` weight `[hidden]`.
    pub emb_ln_weight: T,
    /// Post-embedding `LayerNorm` bias `[hidden]`.
    pub emb_ln_bias: T,
    /// Per-layer encoder weights, in forward order.
    pub layers: Vec<ClassicBertLayerWeights<T>>,
    /// Number of attention heads (e.g., 12 for BGE-small).
    pub num_heads: usize,
    /// Dimension per attention head (`hidden / num_heads`).
    pub head_dim: usize,
    /// Hidden dimension (e.g., 384 for BGE-small).
    pub hidden_dim: usize,
    /// FFN intermediate dimension (e.g., 1536 for BGE-small).
    pub intermediate_dim: usize,
    /// Layer normalization epsilon (typically 1e-12).
    pub layer_norm_eps: f32,
}
77
/// `ClassicBert` architecture: BGE-small-en-v1.5.
///
/// 12 layers, learned position embeddings, GELU activation, CLS pooling.
/// Composes [`Driver`] primitives into the full forward pass; see the
/// [`ModelArch`] impl below for the pipeline.
pub struct ClassicBertArch<T> {
    /// Model weights on device.
    pub weights: ClassicBertWeights<T>,
}
86
/// Encoder geometry passed to sublayer helpers to avoid repeating fields.
///
/// Captured once per forward pass from the model weights and the prepared
/// batch, then shared read-only by every encoder layer.
struct EncoderGeometry {
    /// Number of sequences in the batch.
    batch: usize,
    /// Longest sequence length in the batch (padding target).
    max_seq: usize,
    /// Actual tokens across all sequences (no padding). Used for linear ops.
    total_tokens: usize,
    /// Padded total: `batch * max_seq`. Used for attention layout.
    padded_tokens: usize,
    /// Per-sequence lengths for pad/unpad.
    seq_lengths: Vec<usize>,
    /// Hidden dimension.
    hidden: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Dimension per attention head.
    head_dim: usize,
    /// FFN intermediate dimension.
    intermediate: usize,
    /// Attention logit scale, `1 / sqrt(head_dim)`.
    scale: f32,
    /// `LayerNorm` epsilon.
    eps: f32,
}
104
105/// QKV projection (unpadded) + bias + pad + split heads.
106///
107/// Returns `(q, k, v)` each `[batch*num_heads, seq, head_dim]` in padded layout.
108fn attn_qkv<D: Driver>(
109    driver: &D,
110    hidden_states: &D::Tensor,
111    layer: &ClassicBertLayerWeights<D::Tensor>,
112    g: &EncoderGeometry,
113) -> crate::Result<(D::Tensor, D::Tensor, D::Tensor)> {
114    // QKV projection: [total_tokens, hidden] @ [3*hidden, hidden]^T
115    // Uses total_tokens (unpadded) — no wasted compute on padding.
116    let mut qkv = driver.alloc_zeros(g.total_tokens * 3 * g.hidden)?;
117    driver.gemm(
118        hidden_states,
119        &layer.qkv_weight,
120        &mut qkv,
121        g.total_tokens,
122        3 * g.hidden,
123        g.hidden,
124        true,
125    )?;
126    driver.add_bias(&mut qkv, &layer.qkv_bias, g.total_tokens, 3 * g.hidden)?;
127
128    // Pad QKV from [total_tokens, 3H] to [batch*max_seq, 3H] for attention.
129    // qkv_split needs the padded batch×seq layout to reshape into per-head tensors.
130    let mut qkv_padded = driver.alloc_zeros(g.padded_tokens * 3 * g.hidden)?;
131    driver.pad_to_batch(
132        &qkv,
133        &mut qkv_padded,
134        &g.seq_lengths,
135        g.max_seq,
136        3 * g.hidden,
137    )?;
138
139    // Split into Q, K, V each [batch * num_heads, seq, head_dim].
140    let padded = g.padded_tokens;
141    let mut q = driver.alloc_zeros(padded * g.hidden)?;
142    let mut k = driver.alloc_zeros(padded * g.hidden)?;
143    let mut v = driver.alloc_zeros(padded * g.hidden)?;
144    driver.qkv_split(
145        &mut q,
146        &mut k,
147        &mut v,
148        &qkv_padded,
149        g.batch,
150        g.max_seq,
151        g.hidden,
152        g.num_heads,
153        g.head_dim,
154    )?;
155
156    Ok((q, k, v))
157}
158
159/// Attention scores + output projection (padded) + unpad + bias + residual + `LayerNorm`.
160#[expect(clippy::too_many_arguments, reason = "Q/K/V must be separate tensors")]
161fn attn_scores_residual<D: Driver>(
162    driver: &D,
163    q: &D::Tensor,
164    k: &D::Tensor,
165    v: &D::Tensor,
166    hidden_states: &D::Tensor,
167    layer: &ClassicBertLayerWeights<D::Tensor>,
168    inputs: &BatchInputs<D::Tensor>,
169    g: &EncoderGeometry,
170) -> crate::Result<D::Tensor> {
171    let padded = g.padded_tokens;
172
173    // Attention scores: Q @ K^T => [batch * num_heads, seq, seq]
174    let mut scores = driver.alloc_zeros(g.batch * g.num_heads * g.max_seq * g.max_seq)?;
175    driver.gemm_batched(
176        q,
177        k,
178        &mut scores,
179        g.max_seq,
180        g.max_seq,
181        g.head_dim,
182        true,
183        g.max_seq * g.head_dim,
184        g.max_seq * g.head_dim,
185        g.max_seq * g.max_seq,
186        g.batch * g.num_heads,
187    )?;
188    driver.fused_scale_mask_softmax(
189        &mut scores,
190        &inputs.float_mask,
191        g.batch,
192        g.num_heads,
193        g.max_seq,
194        g.scale,
195    )?;
196
197    // Weighted sum: scores @ V => [batch * num_heads, seq, head_dim]
198    let mut attn_out = driver.alloc_zeros(padded * g.hidden)?;
199    driver.gemm_batched(
200        &scores,
201        v,
202        &mut attn_out,
203        g.max_seq,
204        g.head_dim,
205        g.max_seq,
206        false,
207        g.max_seq * g.max_seq,
208        g.max_seq * g.head_dim,
209        g.max_seq * g.head_dim,
210        g.batch * g.num_heads,
211    )?;
212
213    // Reshape heads back to [padded_tokens, hidden] (still padded).
214    let mut context = driver.alloc_zeros(padded * g.hidden)?;
215    driver.attn_reshape(
216        &mut context,
217        &attn_out,
218        g.batch,
219        g.max_seq,
220        g.num_heads,
221        g.head_dim,
222    )?;
223
224    // Output projection on padded layout, then unpad.
225    let mut projected_padded = driver.alloc_zeros(padded * g.hidden)?;
226    driver.gemm(
227        &context,
228        &layer.output_weight,
229        &mut projected_padded,
230        padded,
231        g.hidden,
232        g.hidden,
233        true,
234    )?;
235
236    // Unpad: [padded_tokens, H] → [total_tokens, H]
237    let mut projected = driver.alloc_zeros(g.total_tokens * g.hidden)?;
238    driver.unpad_from_batch(
239        &projected_padded,
240        &mut projected,
241        &g.seq_lengths,
242        g.max_seq,
243        g.hidden,
244    )?;
245
246    driver.add_bias(&mut projected, &layer.output_bias, g.total_tokens, g.hidden)?;
247
248    let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
249    driver.fused_residual_layernorm(
250        &mut output,
251        &projected,
252        hidden_states,
253        &layer.output_ln_weight,
254        &layer.output_ln_bias,
255        g.total_tokens,
256        g.hidden,
257        g.eps,
258    )?;
259    Ok(output)
260}
261
262/// Run the feed-forward sublayer for one encoder layer.
263///
264/// Intermediate GEMM -> bias + GELU -> output GEMM -> bias + residual + `LayerNorm`.
265fn ffn_sublayer<D: Driver>(
266    driver: &D,
267    attn_output: &D::Tensor,
268    layer: &ClassicBertLayerWeights<D::Tensor>,
269    g: &EncoderGeometry,
270) -> crate::Result<D::Tensor> {
271    // Intermediate: [total_tokens, hidden] @ [inter, hidden]^T => [total_tokens, inter]
272    let mut intermediate = driver.alloc_zeros(g.total_tokens * g.intermediate)?;
273    driver.gemm(
274        attn_output,
275        &layer.ffn_inter_weight,
276        &mut intermediate,
277        g.total_tokens,
278        g.intermediate,
279        g.hidden,
280        true,
281    )?;
282    driver.fused_bias_gelu(
283        &mut intermediate,
284        &layer.ffn_inter_bias,
285        g.total_tokens,
286        g.intermediate,
287    )?;
288
289    // Output: [total_tokens, inter] @ [hidden, inter]^T => [total_tokens, hidden]
290    let mut ffn_out = driver.alloc_zeros(g.total_tokens * g.hidden)?;
291    driver.gemm(
292        &intermediate,
293        &layer.ffn_out_weight,
294        &mut ffn_out,
295        g.total_tokens,
296        g.hidden,
297        g.intermediate,
298        true,
299    )?;
300    driver.add_bias(&mut ffn_out, &layer.ffn_out_bias, g.total_tokens, g.hidden)?;
301
302    let mut output = driver.alloc_zeros(g.total_tokens * g.hidden)?;
303    driver.fused_residual_layernorm(
304        &mut output,
305        &ffn_out,
306        attn_output,
307        &layer.ffn_ln_weight,
308        &layer.ffn_ln_bias,
309        g.total_tokens,
310        g.hidden,
311        g.eps,
312    )?;
313    Ok(output)
314}
315
impl<D: Driver> ModelArch<D> for ClassicBertArch<D::Tensor> {
    /// Full `ClassicBert` forward pass: embeddings -> N encoder layers ->
    /// CLS pooling -> L2 normalize, returning one embedding vector per input
    /// encoding.
    ///
    /// All device work between `begin_batch` and `end_batch` is encoded into a
    /// single batch; `to_host` copies the final pooled embeddings back.
    ///
    /// # Errors
    ///
    /// Propagates any failure from the underlying [`Driver`] operations
    /// (allocation, kernel dispatch, host transfer).
    #[expect(
        clippy::cast_precision_loss,
        reason = "head_dim is small (32-64); sqrt is exact at these sizes"
    )]
    fn forward(&self, driver: &D, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
        let w = &self.weights;
        let batch = encodings.len();
        let hidden = w.hidden_dim;

        // Unpadded mode: tokens concatenated without padding.
        // Linear layers (GEMM, LN, GELU) process total_tokens rows — no wasted compute.
        // Attention pads/unpads around per-head operations via pad_to_batch/unpad_from_batch.
        let inputs = driver.prepare_batch_unpadded(encodings)?;
        let total_tokens = inputs.total_tokens;
        let max_seq = inputs.max_seq;

        // Enter batched mode: all GPU ops encode into ONE command buffer.
        driver.begin_batch()?;

        // Embedding: word + position + token_type, then LayerNorm.
        let mut hidden_states =
            driver.embedding_lookup(&inputs.input_ids, &w.word_embeddings, total_tokens, hidden)?;
        driver.add_embeddings(
            &mut hidden_states,
            &w.position_embeddings,
            &inputs.position_ids,
            total_tokens,
            hidden,
        )?;
        driver.add_embeddings(
            &mut hidden_states,
            &w.token_type_embeddings,
            &inputs.token_type_ids,
            total_tokens,
            hidden,
        )?;
        // layer_norm writes into hidden_states while reading its input tensor,
        // so snapshot the summed embeddings first.
        let emb_input = driver.clone_tensor(&hidden_states, total_tokens * hidden)?;
        driver.layer_norm(
            &mut hidden_states,
            &emb_input,
            &w.emb_ln_weight,
            &w.emb_ln_bias,
            total_tokens,
            hidden,
            w.layer_norm_eps,
        )?;

        // Shared geometry for the per-layer helpers (attn_qkv,
        // attn_scores_residual, ffn_sublayer).
        let g = EncoderGeometry {
            batch,
            max_seq,
            total_tokens,
            padded_tokens: batch * max_seq,
            seq_lengths: inputs.seq_lengths.clone(),
            hidden,
            num_heads: w.num_heads,
            head_dim: w.head_dim,
            intermediate: w.intermediate_dim,
            scale: 1.0 / (w.head_dim as f32).sqrt(),
            eps: w.layer_norm_eps,
        };

        // Encoder layers: attention sublayer then FFN sublayer, each with its
        // own residual + LayerNorm.
        for layer in &w.layers {
            let saved = driver.save_pool_cursor();
            let (q, k, v) = attn_qkv(driver, &hidden_states, layer, &g)?;
            let attn_output =
                attn_scores_residual(driver, &q, &k, &v, &hidden_states, layer, &inputs, &g)?;
            hidden_states = ffn_sublayer(driver, &attn_output, layer, &g)?;
            // All transient tensors (q, k, v, attn_output) dropped here.
            // Only hidden_states survives. Restore cursor so next layer reuses slots.
            driver.restore_pool_cursor(saved);
        }

        // Pad back to [batch, max_seq, hidden] for the cls_pool kernel.
        let mut padded_for_pool = driver.alloc_zeros(batch * max_seq * hidden)?;
        driver.pad_to_batch(
            &hidden_states,
            &mut padded_for_pool,
            &inputs.seq_lengths,
            max_seq,
            hidden,
        )?;

        // CLS pooling (first token per sequence) + L2 normalize.
        let mut pooled = driver.alloc_zeros(batch * hidden)?;
        driver.cls_pool(&mut pooled, &padded_for_pool, batch, max_seq, hidden)?;
        driver.l2_normalize(&mut pooled, batch, hidden)?;

        // End batched mode -- commit all GPU work, wait for completion.
        driver.end_batch()?;

        driver.to_host(&pooled, batch, hidden)
    }
}