1use crate::backend::Backend;
7use crate::tensor::{DType, Tensor};
8
9use super::config::ModelConfig;
10use super::error::{ModelError, ModelResult};
11use super::layers::{Linear, NormLayer};
12use super::{Architecture, InferenceContext, Model};
13
/// Encoder-only model: an embedding stage followed by a stack of [`BertLayer`]s.
///
/// `encode` produces one hidden-state tensor per input token; the `Model`
/// implementation returns only the hidden state of the first token.
pub struct BertModel {
    /// Model hyper-parameters; `encode` reads `hidden_size` and `vocab_size` from it.
    config: ModelConfig,
    /// Token embedding table, read as a flat buffer in which the row for token
    /// `t` starts at `t * hidden_size`. Dequantized through the backend when it
    /// is not stored as `F32`.
    token_embedding: Tensor,
    /// Optional position embedding table, indexed by token position in the
    /// same flat row layout. Read via `as_f32`.
    position_embedding: Option<Tensor>,
    /// Optional token-type embedding. Only its leading `hidden_size` values are
    /// added to each token by `encode`.
    token_type_embedding: Option<Tensor>,
    /// Optional normalization applied to each summed embedding before the layers run.
    embed_norm: Option<NormLayer>,
    /// Encoder layers, applied in order.
    layers: Vec<BertLayer>,
    /// Architecture tag reported by `Model::architecture`.
    architecture: Architecture,
}
24
/// One encoder layer: multi-head self-attention followed by a feed-forward
/// network, each preceded by a normalization and followed by a residual add.
pub struct BertLayer {
    /// Normalization applied to the layer input before the Q/K/V projections.
    pub attn_norm: NormLayer,
    /// Query projection; its output is treated as `num_heads * head_dim` values.
    pub wq: Linear,
    /// Key projection; its output is treated as `num_heads * head_dim` values.
    pub wk: Linear,
    /// Value projection; its output is treated as `num_heads * head_dim` values.
    pub wv: Linear,
    /// Output projection mapping the concatenated heads back to the hidden size.
    pub wo: Linear,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Width of each attention head; also sets the `1 / sqrt(head_dim)` score scale.
    pub head_dim: usize,
    /// Normalization applied to the attention output before the feed-forward network.
    pub ffn_norm: NormLayer,
    /// Feed-forward up-projection; its `out_features` sizes the intermediate buffer.
    pub ffn_up: Linear,
    /// Feed-forward down-projection back to the hidden size.
    pub ffn_down: Linear,
}
38
39impl BertLayer {
40 pub fn forward(
42 &self,
43 hiddens: &[Tensor],
44 backend: &dyn Backend,
45 ) -> ModelResult<Vec<Tensor>> {
46 let seq_len = hiddens.len();
47 let hidden_size = hiddens[0].shape()[0];
48
49 let mut all_q = Vec::with_capacity(seq_len);
51 let mut all_k = Vec::with_capacity(seq_len);
52 let mut all_v = Vec::with_capacity(seq_len);
53
54 for h in hiddens {
55 let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
56 self.attn_norm.forward(h, &mut normed, backend)?;
57
58 let q_size = self.num_heads * self.head_dim;
59 let k_size = self.num_heads * self.head_dim;
60 let v_size = self.num_heads * self.head_dim;
61
62 let mut q = Tensor::zeros(vec![q_size], DType::F32);
63 let mut k = Tensor::zeros(vec![k_size], DType::F32);
64 let mut v = Tensor::zeros(vec![v_size], DType::F32);
65
66 self.wq.forward(&normed, &mut q, backend)?;
67 self.wk.forward(&normed, &mut k, backend)?;
68 self.wv.forward(&normed, &mut v, backend)?;
69
70 all_q.push(q);
71 all_k.push(k);
72 all_v.push(v);
73 }
74
75 let scale = 1.0 / (self.head_dim as f32).sqrt();
76
77 let mut attn_outputs = Vec::with_capacity(seq_len);
79
80 for i in 0..seq_len {
81 let q_data = all_q[i].as_f32()?;
82 let mut out = vec![0.0f32; self.num_heads * self.head_dim];
83
84 for head in 0..self.num_heads {
85 let q_offset = head * self.head_dim;
86 let q_head = &q_data[q_offset..q_offset + self.head_dim];
87
88 let mut scores = vec![0.0f32; seq_len];
90 for j in 0..seq_len {
91 let k_data = all_k[j].as_f32()?;
92 let k_head = &k_data[q_offset..q_offset + self.head_dim];
93 let dot: f32 = q_head.iter().zip(k_head.iter()).map(|(a, b)| a * b).sum();
94 scores[j] = dot * scale;
95 }
96
97 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
99 let mut sum = 0.0f32;
100 for s in &mut scores {
101 *s = (*s - max_score).exp();
102 sum += *s;
103 }
104 for s in &mut scores {
105 *s /= sum;
106 }
107
108 for j in 0..seq_len {
110 let v_data = all_v[j].as_f32()?;
111 let v_head = &v_data[q_offset..q_offset + self.head_dim];
112 for d in 0..self.head_dim {
113 out[q_offset + d] += scores[j] * v_head[d];
114 }
115 }
116 }
117
118 let attn_flat = Tensor::from_f32(&out, vec![self.num_heads * self.head_dim])?;
119 let mut projected = Tensor::zeros(vec![hidden_size], DType::F32);
120 self.wo.forward(&attn_flat, &mut projected, backend)?;
121
122 let proj_data = projected.as_f32_mut()?;
124 let h_data = hiddens[i].as_f32()?;
125 for (p, &h) in proj_data.iter_mut().zip(h_data.iter()) {
126 *p += h;
127 }
128
129 attn_outputs.push(projected);
130 }
131
132 let mut outputs = Vec::with_capacity(seq_len);
134 for h in &attn_outputs {
135 let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
136 self.ffn_norm.forward(h, &mut normed, backend)?;
137
138 let intermediate_size = self.ffn_up.out_features;
139 let mut up = Tensor::zeros(vec![intermediate_size], DType::F32);
140 self.ffn_up.forward(&normed, &mut up, backend)?;
141
142 {
144 let data = up.as_f32_mut()?;
145 for v in data.iter_mut() {
146 let x = *v;
147 *v = 0.5 * x * (1.0 + (0.797_884_6 * (x + 0.044715 * x * x * x)).tanh());
148 }
149 }
150
151 let mut down = Tensor::zeros(vec![hidden_size], DType::F32);
152 self.ffn_down.forward(&up, &mut down, backend)?;
153
154 let d = down.as_f32_mut()?;
156 let h_data = h.as_f32()?;
157 for (dd, &hv) in d.iter_mut().zip(h_data.iter()) {
158 *dd += hv;
159 }
160
161 outputs.push(down);
162 }
163
164 Ok(outputs)
165 }
166}
167
168impl BertModel {
169 pub fn new(
170 config: ModelConfig,
171 token_embedding: Tensor,
172 position_embedding: Option<Tensor>,
173 token_type_embedding: Option<Tensor>,
174 embed_norm: Option<NormLayer>,
175 layers: Vec<BertLayer>,
176 architecture: Architecture,
177 ) -> ModelResult<Self> {
178 Ok(Self {
179 config,
180 token_embedding,
181 position_embedding,
182 token_type_embedding,
183 embed_norm,
184 layers,
185 architecture,
186 })
187 }
188
189 pub fn encode(&self, tokens: &[u32], backend: &dyn Backend) -> ModelResult<Vec<Tensor>> {
192 let hidden_size = self.config.hidden_size;
193 let vocab_size = self.config.vocab_size;
194
195 let emb_data = if self.token_embedding.dtype() == DType::F32 {
197 std::borrow::Cow::Borrowed(self.token_embedding.as_f32()?)
198 } else {
199 let numel = self.token_embedding.numel();
200 let mut dequant = Tensor::zeros(vec![numel], DType::F32);
201 backend.dequantize(&self.token_embedding, &mut dequant)?;
202 std::borrow::Cow::Owned(dequant.as_f32()?.to_vec())
203 };
204
205 let pos_data = self
206 .position_embedding
207 .as_ref()
208 .map(|p| p.as_f32())
209 .transpose()?;
210 let type_data = self
211 .token_type_embedding
212 .as_ref()
213 .map(|t| t.as_f32())
214 .transpose()?;
215
216 let mut hiddens: Vec<Tensor> = Vec::with_capacity(tokens.len());
217 for (i, &token) in tokens.iter().enumerate() {
218 let idx = token as usize;
219 if idx >= vocab_size {
220 return Err(ModelError::InvalidMetadata {
221 key: "token".into(),
222 message: format!("Token {} >= vocab_size {}", token, vocab_size),
223 });
224 }
225
226 let src = idx * hidden_size;
227 let mut h = emb_data[src..src + hidden_size].to_vec();
228
229 if let Some(ref pos) = pos_data {
231 let pos_src = i * hidden_size;
232 if pos_src + hidden_size <= pos.len() {
233 for (j, val) in h.iter_mut().enumerate() {
234 *val += pos[pos_src + j];
235 }
236 }
237 }
238
239 if let Some(ref type_emb) = type_data {
241 for (j, val) in h.iter_mut().enumerate().take(hidden_size) {
242 if j < type_emb.len() {
243 *val += type_emb[j]; }
245 }
246 }
247
248 let mut t = Tensor::from_f32(&h, vec![hidden_size])?;
249
250 if let Some(ref norm) = self.embed_norm {
252 let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
253 norm.forward(&t, &mut normed, backend)?;
254 t = normed;
255 }
256
257 hiddens.push(t);
258 }
259
260 for layer in &self.layers {
262 hiddens = layer.forward(&hiddens, backend)?;
263 }
264
265 Ok(hiddens)
266 }
267}
268
269impl Model for BertModel {
270 fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor> {
271 let backend = ctx.backend.as_ref();
272 let hiddens = self.encode(tokens, backend)?;
273
274 if hiddens.is_empty() {
277 return Err(ModelError::ConfigError("Empty input".into()));
278 }
279 Ok(hiddens[0].clone())
280 }
281
282 fn config(&self) -> &ModelConfig {
283 &self.config
284 }
285
286 fn architecture(&self) -> Architecture {
287 self.architecture
288 }
289}
290
291impl std::fmt::Debug for BertModel {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.debug_struct("BertModel")
294 .field("architecture", &self.architecture)
295 .field("layers", &self.layers.len())
296 .finish()
297 }
298}
299
300impl std::fmt::Debug for BertLayer {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("BertLayer")
303 .field("num_heads", &self.num_heads)
304 .field("head_dim", &self.head_dim)
305 .finish()
306 }
307}