use crate::cpu::f32::{matmul, matmul_t, softmax, Tensor as F32Tensor};
use crate::nn::layers::{Embedding, LayerNorm, Linear};
use crate::traits::{Tensor, TensorOps};
use crate::SmeltError;

// Debug hook; its logging body is omitted here, so the `debug!` calls
// sprinkled through the forward passes expand to nothing.
macro_rules! debug {
    ($str: expr, $tensor: expr) => {};
}

/// Inference context for one sequence: the token ids plus every scratch
/// buffer the forward pass writes into, so no per-layer allocation is needed.
pub struct BertContext<T: Tensor> {
    input_ids: Vec<usize>,
    type_ids: Vec<usize>,
    position_ids: Vec<usize>,
    /// Current hidden states, `[sequence_length, hidden_dim]`.
    hidden_states: T,
    /// Scratch output buffer for the linear layers.
    hidden_states_copy: T,
    /// Attention output before the output projection.
    hidden_states_attn_output: T,
    /// Per-head Q/K/V buffers, `[num_heads, sequence_length, head_dim]`.
    q_cache: T,
    k_cache: T,
    v_cache: T,
    /// Attention scores, `[num_heads, sequence_length, sequence_length]`.
    qk: T,
    /// Attention-weighted values, `[num_heads, sequence_length, head_dim]`.
    qkv: T,
    /// MLP intermediate activations, `[sequence_length, intermediate_dim]`.
    intermediate_states: T,
    pool: T,
    pool_output: T,
    /// Final classification probabilities, `[1, num_classes]`.
    probs: T,
}

impl<T: Tensor> BertContext<T> {
    /// The classification probabilities produced by the last forward pass.
    pub fn probs(&self) -> &T {
        &self.probs
    }
}

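/// Reshapes a `[sequence_length, hidden_dim]` projection into the per-head
/// layout `[num_heads, sequence_length, head_dim]`: element
/// `(j, i * head_dim + k)` of the input becomes element `(i, j, k)` of the
/// output.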
fn split_heads(q: &F32Tensor, out_q: &mut F32Tensor) -> Result<(), SmeltError> {
    let num_heads = out_q.shape()[0];
    let sequence_length = out_q.shape()[1];
    let head_dim = out_q.shape()[2];
    let hidden_dim = head_dim * num_heads;

    (0..num_heads).for_each(|i| {
        (0..sequence_length).for_each(|j| {
            (0..head_dim).for_each(|k| {
                let index = j * hidden_dim + i * head_dim + k;
                let out_index = i * sequence_length * head_dim + j * head_dim + k;
                out_q.data_mut()[out_index] = q.data()[index];
            });
        });
    });
    Ok(())
}

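/// Scaled dot-product self-attention: projects Q, K and V, splits them into
/// heads, computes `softmax(Q·Kᵀ / sqrt(head_dim))·V` per head, then merges
/// the heads back into `[sequence_length, hidden_dim]`.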
fn attention<'data, 'ctx>(
    q_weights: &Linear<F32Tensor<'data>>,
    k_weights: &Linear<F32Tensor<'data>>,
    v_weights: &Linear<F32Tensor<'data>>,
    ctx: &mut BertContext<F32Tensor<'ctx>>,
) -> Result<(), SmeltError>
where
    'data: 'ctx,
{
    q_weights.forward(&ctx.hidden_states, &mut ctx.hidden_states_copy)?;
    split_heads(&ctx.hidden_states_copy, &mut ctx.q_cache)?;

    debug!("Q heads split", ctx.q_cache);

    k_weights.forward(&ctx.hidden_states, &mut ctx.hidden_states_copy)?;
    split_heads(&ctx.hidden_states_copy, &mut ctx.k_cache)?;

    debug!("K heads split", ctx.k_cache);

    v_weights.forward(&ctx.hidden_states, &mut ctx.hidden_states_copy)?;
    split_heads(&ctx.hidden_states_copy, &mut ctx.v_cache)?;

    debug!("V heads split", ctx.v_cache);

    matmul_t(&ctx.q_cache, &ctx.k_cache, &mut ctx.qk)?;

    let num_heads = ctx.q_cache.shape()[0];
    let sequence_length = ctx.q_cache.shape()[1];
    let head_dim = ctx.q_cache.shape()[2];
    let hidden_dim = head_dim * num_heads;

    // Scale the raw scores by sqrt(head_dim) before the softmax.
    let scale = (head_dim as f32).sqrt();
    ctx.qk.data_mut().iter_mut().for_each(|v| *v /= scale);

    softmax(&mut ctx.qk)?;
    debug!("attention_probs", ctx.qk);
    matmul(&ctx.qk, &ctx.v_cache, &mut ctx.qkv)?;
    debug!("qkv", ctx.qkv);

    // Merge the heads back: [num_heads, seq, head_dim] -> [seq, hidden_dim].
    let new_out = ctx.hidden_states_attn_output.data_mut();
    (0..num_heads).for_each(|i| {
        (0..sequence_length).for_each(|j| {
            (0..head_dim).for_each(|k| {
                let in_index = i * sequence_length * head_dim + j * head_dim + k;
                let out_index = j * hidden_dim + i * head_dim + k;
                new_out[out_index] = ctx.qkv.data()[in_index];
            });
        });
    });
    debug!("qkv (reshaped)", ctx.hidden_states_attn_output);

    Ok(())
}

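/// Backend hook: each tensor type supplies its own attention kernel so the
/// generic model code in this module stays backend-agnostic.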
pub trait TensorAttention<T: Tensor> {
    fn attention(
        query: &Linear<T>,
        key: &Linear<T>,
        value: &Linear<T>,
        ctx: &mut BertContext<T>,
    ) -> Result<(), SmeltError>;
}

impl<'a> TensorAttention<F32Tensor<'a>> for F32Tensor<'a> {
    fn attention(
        query: &Linear<F32Tensor<'a>>,
        key: &Linear<F32Tensor<'a>>,
        value: &Linear<F32Tensor<'a>>,
        ctx: &mut BertContext<F32Tensor<'a>>,
    ) -> Result<(), SmeltError> {
        attention(query, key, value, ctx)
    }
}

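/// Accessor for a tensor's raw `f32` buffer, useful when debugging.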
pub trait Debug<T: Tensor> {
    fn data(&self) -> &[f32];
}

impl<'a> Debug<F32Tensor<'a>> for F32Tensor<'a> {
    fn data(&self) -> &[f32] {
        self.data()
    }
}

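/// Convenience umbrella over every tensor capability the BERT graph needs.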
pub trait BertOps<T: Tensor>: TensorOps<T> + TensorAttention<T> + Debug<T> {}

impl<'a> BertOps<F32Tensor<'a>> for F32Tensor<'a> {}

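/// Multi-head self-attention block: Q/K/V projections, an output projection,
/// a residual connection, and the post-attention LayerNorm.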
#[derive(Clone)]
pub struct BertAttention<T: Tensor> {
    query: Linear<T>,
    key: Linear<T>,
    value: Linear<T>,
    output: Linear<T>,
    output_ln: LayerNorm<T>,
}

impl<T: Tensor + BertOps<T>> BertAttention<T> {
    pub fn new(
        query: Linear<T>,
        key: Linear<T>,
        value: Linear<T>,
        output: Linear<T>,
        output_ln: LayerNorm<T>,
    ) -> Self {
        Self {
            query,
            key,
            value,
            output,
            output_ln,
        }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        T::attention(&self.query, &self.key, &self.value, ctx)?;

        // Output projection, residual connection, then LayerNorm.
        self.output
            .forward(&ctx.hidden_states_attn_output, &mut ctx.hidden_states_copy)?;
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        self.output_ln.forward(&mut ctx.hidden_states)?;
        Ok(())
    }
}

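/// Position-wise feed-forward block: intermediate projection, GELU, output
/// projection, residual connection, and LayerNorm.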
#[derive(Clone)]
pub struct Mlp<T: Tensor> {
    intermediate: Linear<T>,
    output: Linear<T>,
    output_ln: LayerNorm<T>,
}

impl<T: Tensor + BertOps<T>> Mlp<T> {
    pub fn new(intermediate: Linear<T>, output: Linear<T>, output_ln: LayerNorm<T>) -> Self {
        Self {
            intermediate,
            output,
            output_ln,
        }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        debug!("Before MLP", ctx.hidden_states);
        self.intermediate
            .forward(&ctx.hidden_states, &mut ctx.intermediate_states)?;
        debug!("Intermediate", ctx.intermediate_states);
        T::gelu(&mut ctx.intermediate_states)?;
        debug!("Intermediate (gelu)", ctx.intermediate_states);
        self.output
            .forward(&ctx.intermediate_states, &mut ctx.hidden_states_copy)?;
        debug!("output", ctx.hidden_states_copy);
        // Residual connection followed by LayerNorm.
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        debug!("output (skip)", ctx.hidden_states);
        self.output_ln.forward(&mut ctx.hidden_states)?;
        debug!("output ln", ctx.hidden_states);
        Ok(())
    }
}

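/// A single encoder layer: self-attention followed by the feed-forward block.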
#[derive(Clone)]
pub struct BertLayer<T: Tensor> {
    attention: BertAttention<T>,
    mlp: Mlp<T>,
}

impl<T: Tensor + BertOps<T>> BertLayer<T> {
    pub fn new(attention: BertAttention<T>, mlp: Mlp<T>) -> Self {
        Self { attention, mlp }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        debug!("Before attention", ctx.hidden_states);
        self.attention.forward(ctx)?;
        debug!("After attention", ctx.hidden_states);
        self.mlp.forward(ctx)?;
        debug!("After mlp", ctx.hidden_states);
        Ok(())
    }
}

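/// The full encoder stack: [`BertLayer`]s applied in sequence, mutating the
/// context's hidden states in place.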
#[derive(Clone)]
pub struct BertEncoder<T: Tensor> {
    layers: Vec<BertLayer<T>>,
}

impl<T: Tensor + BertOps<T>> BertEncoder<T> {
    pub fn new(layers: Vec<BertLayer<T>>) -> Self {
        Self { layers }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        for layer in &self.layers {
            layer.forward(ctx)?;
        }
        Ok(())
    }
}

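/// Input embeddings: the sum of word, token-type, and position embeddings,
/// normalized by a final LayerNorm.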
#[derive(Clone)]
pub struct BertEmbeddings<T: Tensor> {
    input_embeddings: Embedding<T>,
    position_embeddings: Embedding<T>,
    type_embeddings: Embedding<T>,
    layer_norm: LayerNorm<T>,
}

impl<T: Tensor + BertOps<T>> BertEmbeddings<T> {
    pub fn new(
        input_embeddings: Embedding<T>,
        position_embeddings: Embedding<T>,
        type_embeddings: Embedding<T>,
        layer_norm: LayerNorm<T>,
    ) -> Self {
        Self {
            input_embeddings,
            position_embeddings,
            type_embeddings,
            layer_norm,
        }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        let input_ids = &ctx.input_ids;
        let position_ids = &ctx.position_ids;
        let type_ids = &ctx.type_ids;

        if input_ids.len() != position_ids.len() {
            return Err(SmeltError::InvalidLength {
                expected: input_ids.len(),
                got: position_ids.len(),
            });
        }
        if input_ids.len() != type_ids.len() {
            return Err(SmeltError::InvalidLength {
                expected: input_ids.len(),
                got: type_ids.len(),
            });
        }

        self.input_embeddings
            .forward(input_ids, &mut ctx.hidden_states)?;

        debug!("input embeddings", ctx.hidden_states);

        self.type_embeddings
            .forward(type_ids, &mut ctx.hidden_states_copy)?;
        debug!("type embeddings", ctx.hidden_states_copy);
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        debug!("After add type embeddings", ctx.hidden_states);

        self.position_embeddings
            .forward(position_ids, &mut ctx.hidden_states_copy)?;
        debug!("position embeddings", ctx.hidden_states_copy);
        T::add(&ctx.hidden_states_copy, &mut ctx.hidden_states)?;
        debug!("After add position embeddings", ctx.hidden_states);

        self.layer_norm.forward(&mut ctx.hidden_states)?;

        debug!("After embeddings", ctx.hidden_states);
        Ok(())
    }
}

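/// The core BERT model: embeddings followed by the encoder, without any
/// task-specific head.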
pub struct Bert<T: Tensor + BertOps<T>> {
    embeddings: BertEmbeddings<T>,
    encoder: BertEncoder<T>,
}

impl<T: Tensor + BertOps<T>> Bert<T> {
    pub fn new(embeddings: BertEmbeddings<T>, encoder: BertEncoder<T>) -> Self {
        Self {
            embeddings,
            encoder,
        }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        self.embeddings.forward(ctx)?;
        self.encoder.forward(ctx)
    }
}

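/// Pooler head: projects the first ([CLS]) token's hidden state through a
/// dense layer with tanh activation.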
#[derive(Clone)]
pub struct BertPooler<T: Tensor> {
    pooler: Linear<T>,
}

impl<T: Tensor + BertOps<T>> BertPooler<T> {
    pub fn new(pooler: Linear<T>) -> Self {
        Self { pooler }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        // Select the first ([CLS]) token, project it, and apply tanh.
        T::select(&[0], &ctx.hidden_states, &mut ctx.pool)?;
        self.pooler.forward(&ctx.pool, &mut ctx.pool_output)?;
        T::tanh(&mut ctx.pool_output)?;
        Ok(())
    }
}

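/// Sequence classifier: BERT plus the pooler plus a linear head whose output
/// is turned into probabilities with a softmax.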
pub struct BertClassifier<T: Tensor + BertOps<T>> {
    bert: Bert<T>,
    pooler: BertPooler<T>,
    classifier: Linear<T>,
}

impl<T: Tensor + BertOps<T>> BertClassifier<T> {
    pub fn new(bert: Bert<T>, pooler: BertPooler<T>, classifier: Linear<T>) -> Self {
        Self {
            bert,
            pooler,
            classifier,
        }
    }

    pub fn forward(&self, ctx: &mut BertContext<T>) -> Result<(), SmeltError> {
        self.bert.forward(ctx)?;
        self.pooler.forward(ctx)?;
        self.classifier.forward(&ctx.pool_output, &mut ctx.probs)?;
        T::softmax(&mut ctx.probs)?;
        Ok(())
    }

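    /// Allocates a [`BertContext`] for a single sequence, sizing every
    /// scratch buffer from the model's weight shapes and the input length.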
    pub fn new_context(
        &self,
        input_ids: Vec<usize>,
        position_ids: Vec<usize>,
        type_ids: Vec<usize>,
        num_heads: usize,
    ) -> BertContext<T> {
        let hidden_dim = self.bert.embeddings.input_embeddings.weight().shape()[1];
        let intermediate_dim = self.bert.encoder.layers[0]
            .mlp
            .intermediate
            .weight()
            .shape()[0];
        let num_classes = self.classifier.weight().shape()[0];
        let head_dim = hidden_dim / num_heads;
        let sequence_length = input_ids.len();

        let hidden_states = T::zeros(vec![sequence_length, hidden_dim]);
        let hidden_states_copy = T::zeros(vec![sequence_length, hidden_dim]);
        let hidden_states_attn_output = T::zeros(vec![sequence_length, hidden_dim]);
        let intermediate_states = T::zeros(vec![sequence_length, intermediate_dim]);
        let q_cache = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let k_cache = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let v_cache = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let qk = T::zeros(vec![num_heads, sequence_length, sequence_length]);
        let qkv = T::zeros(vec![num_heads, sequence_length, head_dim]);
        let pool = T::zeros(vec![1, hidden_dim]);
        let pool_output = T::zeros(vec![1, hidden_dim]);
        let probs = T::zeros(vec![1, num_classes]);
        BertContext {
            input_ids,
            position_ids,
            type_ids,
            hidden_states,
            hidden_states_copy,
            hidden_states_attn_output,
            intermediate_states,
            q_cache,
            k_cache,
            v_cache,
            qk,
            qkv,
            pool,
            pool_output,
            probs,
        }
    }
}
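
// A minimal end-to-end sketch (illustrative, not part of the library API):
// it assumes a `BertClassifier<F32Tensor>` has already been constructed with
// loaded weights, and uses only the methods defined above. The token ids and
// `num_heads = 12` are placeholder values.
#[allow(dead_code)]
fn classify_example(model: &BertClassifier<F32Tensor>) -> Result<(), SmeltError> {
    // Toy ids; real inputs come from a tokenizer.
    let input_ids = vec![101, 2023, 2003, 1037, 3231, 102];
    let position_ids: Vec<usize> = (0..input_ids.len()).collect();
    let type_ids = vec![0; input_ids.len()];

    // Allocate every scratch buffer once, then run the whole model in place.
    let mut ctx = model.new_context(input_ids, position_ids, type_ids, 12);
    model.forward(&mut ctx)?;

    // `ctx.probs()` now holds the `[1, num_classes]` class distribution.
    let _probs = ctx.probs();
    Ok(())
}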