Skip to main content

llama_rs/model/
bert.rs

1//! BERT encoder-only model implementation
2//!
3//! Supports BERT, ModernBert, NomicBert, JinaBert, and similar
4//! encoder-only transformers for embedding generation.
5
6use crate::backend::Backend;
7use crate::tensor::{DType, Tensor};
8
9use super::config::ModelConfig;
10use super::error::{ModelError, ModelResult};
11use super::layers::{Linear, NormLayer};
12use super::{Architecture, InferenceContext, Model};
13
14/// BERT encoder-only transformer model
15pub struct BertModel {
16    config: ModelConfig,
17    token_embedding: Tensor,
18    position_embedding: Option<Tensor>,
19    token_type_embedding: Option<Tensor>,
20    embed_norm: Option<NormLayer>,
21    layers: Vec<BertLayer>,
22    architecture: Architecture,
23}
24
25/// A single BERT encoder layer
26pub struct BertLayer {
27    pub attn_norm: NormLayer,
28    pub wq: Linear,
29    pub wk: Linear,
30    pub wv: Linear,
31    pub wo: Linear,
32    pub num_heads: usize,
33    pub head_dim: usize,
34    pub ffn_norm: NormLayer,
35    pub ffn_up: Linear,
36    pub ffn_down: Linear,
37}
38
39impl BertLayer {
40    /// Bidirectional self-attention forward pass (no causal mask)
41    pub fn forward(
42        &self,
43        hiddens: &[Tensor],
44        backend: &dyn Backend,
45    ) -> ModelResult<Vec<Tensor>> {
46        let seq_len = hiddens.len();
47        let hidden_size = hiddens[0].shape()[0];
48
49        // Project Q, K, V for all positions
50        let mut all_q = Vec::with_capacity(seq_len);
51        let mut all_k = Vec::with_capacity(seq_len);
52        let mut all_v = Vec::with_capacity(seq_len);
53
54        for h in hiddens {
55            let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
56            self.attn_norm.forward(h, &mut normed, backend)?;
57
58            let q_size = self.num_heads * self.head_dim;
59            let k_size = self.num_heads * self.head_dim;
60            let v_size = self.num_heads * self.head_dim;
61
62            let mut q = Tensor::zeros(vec![q_size], DType::F32);
63            let mut k = Tensor::zeros(vec![k_size], DType::F32);
64            let mut v = Tensor::zeros(vec![v_size], DType::F32);
65
66            self.wq.forward(&normed, &mut q, backend)?;
67            self.wk.forward(&normed, &mut k, backend)?;
68            self.wv.forward(&normed, &mut v, backend)?;
69
70            all_q.push(q);
71            all_k.push(k);
72            all_v.push(v);
73        }
74
75        let scale = 1.0 / (self.head_dim as f32).sqrt();
76
77        // Compute bidirectional attention for each position
78        let mut attn_outputs = Vec::with_capacity(seq_len);
79
80        for i in 0..seq_len {
81            let q_data = all_q[i].as_f32()?;
82            let mut out = vec![0.0f32; self.num_heads * self.head_dim];
83
84            for head in 0..self.num_heads {
85                let q_offset = head * self.head_dim;
86                let q_head = &q_data[q_offset..q_offset + self.head_dim];
87
88                // Compute attention scores against ALL positions (bidirectional)
89                let mut scores = vec![0.0f32; seq_len];
90                for j in 0..seq_len {
91                    let k_data = all_k[j].as_f32()?;
92                    let k_head = &k_data[q_offset..q_offset + self.head_dim];
93                    let dot: f32 = q_head.iter().zip(k_head.iter()).map(|(a, b)| a * b).sum();
94                    scores[j] = dot * scale;
95                }
96
97                // Softmax
98                let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
99                let mut sum = 0.0f32;
100                for s in &mut scores {
101                    *s = (*s - max_score).exp();
102                    sum += *s;
103                }
104                for s in &mut scores {
105                    *s /= sum;
106                }
107
108                // Weighted sum of values
109                for j in 0..seq_len {
110                    let v_data = all_v[j].as_f32()?;
111                    let v_head = &v_data[q_offset..q_offset + self.head_dim];
112                    for d in 0..self.head_dim {
113                        out[q_offset + d] += scores[j] * v_head[d];
114                    }
115                }
116            }
117
118            let attn_flat = Tensor::from_f32(&out, vec![self.num_heads * self.head_dim])?;
119            let mut projected = Tensor::zeros(vec![hidden_size], DType::F32);
120            self.wo.forward(&attn_flat, &mut projected, backend)?;
121
122            // Residual
123            let proj_data = projected.as_f32_mut()?;
124            let h_data = hiddens[i].as_f32()?;
125            for (p, &h) in proj_data.iter_mut().zip(h_data.iter()) {
126                *p += h;
127            }
128
129            attn_outputs.push(projected);
130        }
131
132        // FFN with residual for each position
133        let mut outputs = Vec::with_capacity(seq_len);
134        for h in &attn_outputs {
135            let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
136            self.ffn_norm.forward(h, &mut normed, backend)?;
137
138            let intermediate_size = self.ffn_up.out_features;
139            let mut up = Tensor::zeros(vec![intermediate_size], DType::F32);
140            self.ffn_up.forward(&normed, &mut up, backend)?;
141
142            // GELU activation
143            {
144                let data = up.as_f32_mut()?;
145                for v in data.iter_mut() {
146                    let x = *v;
147                    *v = 0.5 * x * (1.0 + (0.797_884_6 * (x + 0.044715 * x * x * x)).tanh());
148                }
149            }
150
151            let mut down = Tensor::zeros(vec![hidden_size], DType::F32);
152            self.ffn_down.forward(&up, &mut down, backend)?;
153
154            // Residual
155            let d = down.as_f32_mut()?;
156            let h_data = h.as_f32()?;
157            for (dd, &hv) in d.iter_mut().zip(h_data.iter()) {
158                *dd += hv;
159            }
160
161            outputs.push(down);
162        }
163
164        Ok(outputs)
165    }
166}
167
168impl BertModel {
169    pub fn new(
170        config: ModelConfig,
171        token_embedding: Tensor,
172        position_embedding: Option<Tensor>,
173        token_type_embedding: Option<Tensor>,
174        embed_norm: Option<NormLayer>,
175        layers: Vec<BertLayer>,
176        architecture: Architecture,
177    ) -> ModelResult<Self> {
178        Ok(Self {
179            config,
180            token_embedding,
181            position_embedding,
182            token_type_embedding,
183            embed_norm,
184            layers,
185            architecture,
186        })
187    }
188
189    /// Encode tokens and return hidden states for all positions.
190    /// Returns the hidden state of the [CLS] token (first position) by default.
191    pub fn encode(&self, tokens: &[u32], backend: &dyn Backend) -> ModelResult<Vec<Tensor>> {
192        let hidden_size = self.config.hidden_size;
193        let vocab_size = self.config.vocab_size;
194
195        // Build embeddings
196        let emb_data = if self.token_embedding.dtype() == DType::F32 {
197            std::borrow::Cow::Borrowed(self.token_embedding.as_f32()?)
198        } else {
199            let numel = self.token_embedding.numel();
200            let mut dequant = Tensor::zeros(vec![numel], DType::F32);
201            backend.dequantize(&self.token_embedding, &mut dequant)?;
202            std::borrow::Cow::Owned(dequant.as_f32()?.to_vec())
203        };
204
205        let pos_data = self
206            .position_embedding
207            .as_ref()
208            .map(|p| p.as_f32())
209            .transpose()?;
210        let type_data = self
211            .token_type_embedding
212            .as_ref()
213            .map(|t| t.as_f32())
214            .transpose()?;
215
216        let mut hiddens: Vec<Tensor> = Vec::with_capacity(tokens.len());
217        for (i, &token) in tokens.iter().enumerate() {
218            let idx = token as usize;
219            if idx >= vocab_size {
220                return Err(ModelError::InvalidMetadata {
221                    key: "token".into(),
222                    message: format!("Token {} >= vocab_size {}", token, vocab_size),
223                });
224            }
225
226            let src = idx * hidden_size;
227            let mut h = emb_data[src..src + hidden_size].to_vec();
228
229            // Add position embedding
230            if let Some(ref pos) = pos_data {
231                let pos_src = i * hidden_size;
232                if pos_src + hidden_size <= pos.len() {
233                    for (j, val) in h.iter_mut().enumerate() {
234                        *val += pos[pos_src + j];
235                    }
236                }
237            }
238
239            // Add token type embedding (segment 0 for now)
240            if let Some(ref type_emb) = type_data {
241                for (j, val) in h.iter_mut().enumerate().take(hidden_size) {
242                    if j < type_emb.len() {
243                        *val += type_emb[j]; // segment 0
244                    }
245                }
246            }
247
248            let mut t = Tensor::from_f32(&h, vec![hidden_size])?;
249
250            // Apply embedding normalization if present
251            if let Some(ref norm) = self.embed_norm {
252                let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
253                norm.forward(&t, &mut normed, backend)?;
254                t = normed;
255            }
256
257            hiddens.push(t);
258        }
259
260        // Run through encoder layers
261        for layer in &self.layers {
262            hiddens = layer.forward(&hiddens, backend)?;
263        }
264
265        Ok(hiddens)
266    }
267}
268
269impl Model for BertModel {
270    fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor> {
271        let backend = ctx.backend.as_ref();
272        let hiddens = self.encode(tokens, backend)?;
273
274        // Return the CLS token (position 0) hidden state
275        // For actual embedding use, callers should use encode() directly
276        if hiddens.is_empty() {
277            return Err(ModelError::ConfigError("Empty input".into()));
278        }
279        Ok(hiddens[0].clone())
280    }
281
282    fn config(&self) -> &ModelConfig {
283        &self.config
284    }
285
286    fn architecture(&self) -> Architecture {
287        self.architecture
288    }
289}
290
291impl std::fmt::Debug for BertModel {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        f.debug_struct("BertModel")
294            .field("architecture", &self.architecture)
295            .field("layers", &self.layers.len())
296            .finish()
297    }
298}
299
300impl std::fmt::Debug for BertLayer {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        f.debug_struct("BertLayer")
303            .field("num_heads", &self.num_heads)
304            .field("head_dim", &self.head_dim)
305            .finish()
306    }
307}