use crate as burn;

use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::Initializer;
use crate::nn::cache::TensorCache;
use crate::{
    config::Config,
    nn,
    tensor::{Bool, Tensor, activation, backend::Backend},
};

#[cfg(not(feature = "std"))]
use num_traits::Float;
/// Configuration to create a [MultiHeadAttention] layer using the [init function](MultiHeadAttentionConfig::init).
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
    /// The size of each linear layer.
    pub d_model: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The minimum value a float can take, used to mask attention scores
    /// before computing the attention weights. Default: -1.0e4
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The multihead attention module as described in the paper
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [MultiHeadAttentionConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
    /// Linear layer projecting the input into the query space.
    pub query: nn::Linear<B>,
    /// Linear layer projecting the input into the key space.
    pub key: nn::Linear<B>,
    /// Linear layer projecting the input into the value space.
    pub value: nn::Linear<B>,
    /// Linear layer applied to the concatenated attention context.
    pub output: nn::Linear<B>,
    /// Dropout applied to the attention scores.
    pub dropout: nn::Dropout,
    /// Gelu activation function.
    pub activation: nn::Gelu,
    /// The size of each linear layer.
    pub d_model: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The size of each head: `d_model / n_heads`.
    pub d_k: usize,
    /// Minimum float value, used to mask attention scores.
    pub min_float: f64,
    /// Whether to use quiet softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("n_heads", &self.n_heads)
            .add("d_k", &self.d_k)
            .add("dropout", &self.dropout.prob)
            .add("min_float", &self.min_float)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

/// The [multihead attention](MultiHeadAttention) forward pass input argument.
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
    /// The query tensor `[batch_size, seq_length_1, d_model]`.
    query: Tensor<B, 3>,
    /// The key tensor `[batch_size, seq_length_2, d_model]`.
    key: Tensor<B, 3>,
    /// The value tensor `[batch_size, seq_length_2, d_model]`.
    value: Tensor<B, 3>,
    /// Optional padding mask `[batch_size, seq_length]`.
    mask_pad: Option<Tensor<B, 2, Bool>>,
    /// Optional attention mask `[batch_size, seq_length_1, seq_length_2]`.
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl MultiHeadAttentionConfig {
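    /// Initialize a new [MultiHeadAttention] module.
    ///
    /// A minimal usage sketch (illustrative only): `MyBackend` is a
    /// placeholder for any type implementing [Backend], and `with_dropout`
    /// stands for the builder setters generated by `#[derive(Config)]`.
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let mha = MultiHeadAttentionConfig::new(512, 8)
    ///     .with_dropout(0.2)
    ///     .init::<MyBackend>(&device);
    /// ```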
    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
        let linear = |config: &Self| {
            nn::LinearConfig::new(config.d_model, config.d_model)
                .with_initializer(self.initializer.clone())
                .init(device)
        };

        MultiHeadAttention {
            query: linear(self),
            key: linear(self),
            value: linear(self),
            output: linear(self),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::Gelu::new(),
            n_heads: self.n_heads,
            d_k: self.d_model / self.n_heads,
            min_float: self.min_float,
            quiet_softmax: self.quiet_softmax,
            d_model: self.d_model,
        }
    }
}

impl<B: Backend> MhaInput<B> {
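    /// Create a [MultiHeadAttention] input where the query, key, and value
    /// all come from the same tensor (self-attention).
    ///
    /// A short sketch (illustrative only; assumes `mha` and a `tensor` of
    /// shape `[batch_size, seq_length, d_model]` are in scope):
    ///
    /// ```ignore
    /// let output = mha.forward(MhaInput::self_attn(tensor));
    /// ```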
    pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
        Self {
            query: tensor.clone(),
            key: tensor.clone(),
            value: tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

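    /// Create a [MultiHeadAttention] input from distinct query, key, and
    /// value tensors (e.g. for cross-attention).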
    pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
        Self {
            query,
            key,
            value,
            mask_pad: None,
            mask_attn: None,
        }
    }

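    /// Register a padding mask of shape `[batch_size, seq_length]`: positions
    /// marked `true` are filled with `min_float` before the softmax, so they
    /// receive virtually no attention.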
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

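    /// Register an attention mask of shape
    /// `[batch_size, seq_length_1, seq_length_2]`: positions marked `true`
    /// are masked out of the attention scores (e.g. an autoregressive mask).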
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

/// The [multihead attention](MultiHeadAttention) outputs.
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
    /// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.
    pub weights: Tensor<B, 4>,
    /// The context tensor `[batch_size, seq_length_1, d_model]`.
    pub context: Tensor<B, 3>,
}

impl<B: Backend> MultiHeadAttention<B> {
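    /// Applies the forward pass on the input tensors.
    ///
    /// # Shapes
    ///
    /// - query: `[batch_size, seq_length_1, d_model]`
    /// - key: `[batch_size, seq_length_2, d_model]`
    /// - value: `[batch_size, seq_length_2, d_model]`
    /// - output context: `[batch_size, seq_length_1, d_model]`
    /// - output weights: `[batch_size, n_heads, seq_length_1, seq_length_2]`
    ///
    /// A short sketch (illustrative only; `query`, `key`, and `value` are
    /// assumed tensors with the shapes above):
    ///
    /// ```ignore
    /// let MhaOutput { weights, context } =
    ///     mha.forward(MhaInput::new(query, key, value));
    /// ```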
    pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

        let query = self.attention_linear(input.query, &self.query);
        let key = self.attention_linear(input.key, &self.key);
        let value = self.attention_linear(input.value, &self.value);

        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = self.output.forward(context);

        MhaOutput { weights, context }
    }

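    /// Applies the forward pass using a [cache](MhaCache), so that linear
    /// projections computed at earlier decoding steps are reused instead of
    /// being recomputed.
    ///
    /// A decoding-loop sketch (illustrative only; it mirrors the
    /// autoregressive test at the bottom of this file and assumes `mha`,
    /// `tensor`, and the dimension variables are in scope):
    ///
    /// ```ignore
    /// let mut cache = MhaCache::autoregressive();
    /// for i in 1..=seq_length {
    ///     let step = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let output = mha.forward_cache(MhaInput::self_attn(step), &mut cache);
    /// }
    /// ```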
    pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

        let query = cache
            .query
            .forward(input.query, |t| self.attention_linear(t, &self.query));
        let key = cache
            .key
            .forward(input.key, |t| self.attention_linear(t, &self.key));
        let value = cache
            .value
            .forward(input.value, |t| self.attention_linear(t, &self.value));

        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);

        let context = cache.output.forward(context, |t| self.output.forward(t));

        MhaOutput { weights, context }
    }

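    /// Computes the scaled dot-product attention scores,
    /// `scores = Q * K^T / sqrt(d_k)`, then applies dropout to them.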
    fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
        let attn_scores = query
            .matmul(key.transpose())
            .div_scalar((self.d_k as f32).sqrt());

        self.dropout.forward(attn_scores)
    }

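    /// Applies the optional padding and attention masks (masked positions are
    /// filled with `min_float` so they vanish after normalization), then
    /// normalizes the scores with softmax (or quiet softmax) over the last
    /// dimension.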
    fn attn_weights(
        &self,
        mut attn_scores: Tensor<B, 4>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 4> {
        if let Some(mask_pad) = mask_pad {
            let [batch_size, seq_length] = mask_pad.dims();

            attn_scores = attn_scores.mask_fill(
                mask_pad.reshape([batch_size, 1, 1, seq_length]),
                self.min_float,
            );
        }

        if let Some(mask_attn) = mask_attn {
            let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();

            attn_scores = attn_scores.mask_fill(
                mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
                self.min_float,
            );
        }

        if self.quiet_softmax {
            activation::quiet_softmax(attn_scores, 3)
        } else {
            activation::softmax(attn_scores, 3)
        }
    }

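    /// Projects the input with the given linear layer and splits it into
    /// heads: `[batch_size, seq_length, d_model]` becomes
    /// `[batch_size, n_heads, seq_length, d_k]`.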
    fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
        let [batch_size, seq_length, _d_model] = x.dims();
        linear
            .forward(x)
            .reshape([batch_size, seq_length, self.n_heads, self.d_k])
            .swap_dims(1, 2)
    }
}

/// Cache for the [multihead attention](MultiHeadAttention) module, used to
/// avoid recomputing the linear projections during autoregressive inference.
pub struct MhaCache<B: Backend> {
    query: MhaLinearCache<B, 4>,
    key: MhaLinearCache<B, 4>,
    value: MhaLinearCache<B, 4>,
    output: MhaLinearCache<B, 3>,
}

enum MhaLinearCache<B: Backend, const D: usize> {
    /// Cache the projection by concatenating new values along the given dimension.
    Autoregressive(TensorCache<B, D>, usize),
    /// Cache the projection in full, computing it only once.
    Full(TensorCache<B, D>),
}

impl<B: Backend> MhaCache<B> {
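    /// Initialize a cache suitable for autoregressive self-attention: the
    /// query, key, and value projections grow by concatenation along their
    /// sequence dimension (dim 2 for the 4-D per-head tensors, dim 1 for the
    /// 3-D output) as new tokens arrive.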
    pub fn autoregressive() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }

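    /// Initialize a cache suitable for autoregressive cross-attention: the
    /// query grows with each decoded token, while the key and value
    /// (typically computed from the encoder output) are computed once and
    /// reused in full.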
    pub fn autoregressive_cross_attention() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Full(TensorCache::empty()),
            value: MhaLinearCache::Full(TensorCache::empty()),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }
}

impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
    pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(
        &mut self,
        tensor: Tensor<B, 3>,
        func: F,
    ) -> Tensor<B, D> {
        match self {
            MhaLinearCache::Autoregressive(cache, dim) => {
                cache.forward_autoregressive(tensor, *dim, func)
            }
            MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Int;
    use crate::tensor::{Distribution, Shape};
    use crate::{TestBackend, nn::attention::generate_autoregressive_mask};
    use alloc::vec::Vec;
    use burn_tensor::Tolerance;
    use burn_tensor::ops::FloatElem;

    #[test]
    fn test_self_attention_shapes() {
        let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
        let input = MhaInput::self_attn(Tensor::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        ));

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length, seq_length]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_generic_mha_shapes() {
        let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
        let input = MhaInput::new(
            Tensor::random(
                [batch_size, seq_length_1, d_model],
                Distribution::Default,
                &device,
            ),
            Tensor::random(
                [batch_size, seq_length_2, d_model],
                Distribution::Default,
                &device,
            ),
            Tensor::random(
                [batch_size, seq_length_2, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length_1, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_self_attention_mask_pad() {
        let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        // Build a padding mask marking the last `num_padded` positions as padded.
        let mask_pad: Tensor<TestBackend, 2, Int> =
            Tensor::zeros([batch_size, seq_length], &device);
        let mask_pad = mask_pad.slice_assign(
            [0..batch_size, seq_length - num_padded..seq_length],
            Tensor::ones([batch_size, num_padded], &device),
        );
        let mask_pad = mask_pad.equal_elem(1).to_device(&device);

        let tensor_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        // Change the values at the padded positions only.
        let tensor_2 = tensor_1.clone().slice_assign(
            [
                0..batch_size,
                seq_length - num_padded..seq_length,
                0..d_model,
            ],
            Tensor::random(
                [batch_size, num_padded, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());
        let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);

        let output_1 = mha.forward(input_1);
        let output_2 = mha.forward(input_2);

        // The non-padded positions must be unaffected by the padded values.
        output_1
            .context
            .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
            .into_data()
            .assert_approx_eq(
                &output_2
                    .context
                    .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
                    .into_data(),
                Tolerance::<f32>::default(),
            );
    }

    #[test]
    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);

        let output_1 = mha.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = MhaCache::autoregressive();

        for i in 1..=seq_length {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = MhaInput::self_attn(tensor);
            let next_tok = mha.forward_cache(input, &mut cache).context.slice([
                0..batch_size,
                i - 1..i,
                0..d_model,
            ]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .context
            .into_data()
            .assert_approx_eq::<FloatElem<TestBackend>>(
                &output_2.into_data(),
                Tolerance::default(),
            );
    }

    #[test]
    fn display() {
        let config = MultiHeadAttentionConfig::new(2, 4);
        let mha = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{mha}"),
            "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
             dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
        );
    }
}