burn_core/nn/attention/mha.rs

use crate as burn;

use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::Initializer;
use crate::nn::cache::TensorCache;
use crate::{
    config::Config,
    nn,
    tensor::{Bool, Tensor, activation, backend::Backend},
};

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init).
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
    /// The size of each linear layer.
    pub d_model: usize,
    /// The number of heads.
    pub n_heads: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The value used to mask attention scores before calculating the attention weights. Default: -1.0e4
    /// A value that is too low might result in NaN.
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The multihead attention module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - query: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - key: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - value: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - output: [Linear](nn::Linear) layer with `d_model` input and output features.
///
/// Should be created with [MultiHeadAttentionConfig].
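///
/// # Example
///
/// A minimal usage sketch; the backend type, device, and tensor sizes below are illustrative
/// placeholders:
///
/// ```ignore
/// let device = Default::default();
/// let mha = MultiHeadAttentionConfig::new(512, 8).init::<MyBackend>(&device);
///
/// // Self-attention over a `[batch_size, seq_length, d_model]` tensor.
/// let input = MhaInput::self_attn(Tensor::random([2, 16, 512], Distribution::Default, &device));
/// let output = mha.forward(input);
/// // `output.context`: `[2, 16, 512]`, `output.weights`: `[2, 8, 16, 16]`
/// ```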
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
    /// Linear layer to transform the input features into the query space.
    pub query: nn::Linear<B>,
    /// Linear layer to transform the input features into the key space.
    pub key: nn::Linear<B>,
    /// Linear layer to transform the input features into the value space.
    pub value: nn::Linear<B>,
    /// Linear layer to transform the output features back to the original space.
    pub output: nn::Linear<B>,
    /// Dropout layer.
    pub dropout: nn::Dropout,
    /// Activation function.
    pub activation: nn::Gelu,
    /// The size of each linear layer.
    pub d_model: usize,
    /// The number of heads.
    pub n_heads: usize,
    /// Size of the key and query vectors.
    pub d_k: usize,
    /// The value used to mask attention scores before calculating the attention weights.
    pub min_float: f64,
    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("n_heads", &self.n_heads)
            .add("d_k", &self.d_k)
            .add("dropout", &self.dropout.prob)
            .add("min_float", &self.min_float)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

/// [Multihead attention](MultiHeadAttention) forward pass input argument.
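///
/// A construction sketch (the query/key/value tensors and masks are placeholders):
///
/// ```ignore
/// // Cross-attention style input: the query length may differ from the key/value length.
/// let input = MhaInput::new(query, key, value)
///     .mask_pad(mask_pad)    // `[batch_size, seq_length_2]`
///     .mask_attn(mask_attn); // `[batch_size, seq_length_1, seq_length_2]`
/// ```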
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
    /// Shape `[batch_size, seq_length_1, d_model]`
    query: Tensor<B, 3>,
    /// Shape `[batch_size, seq_length_2, d_model]`
    key: Tensor<B, 3>,
    /// Shape `[batch_size, seq_length_2, d_model]`
    value: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl MultiHeadAttentionConfig {
    /// Initialize a new [multihead attention](MultiHeadAttention) module.
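    ///
    /// Note that `d_k` is computed as `d_model / n_heads` with integer division, so `d_model` is
    /// typically chosen as a multiple of `n_heads`.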
    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
        let linear = |config: &Self| {
            nn::LinearConfig::new(config.d_model, config.d_model)
                .with_initializer(self.initializer.clone())
                .init(device)
        };

        MultiHeadAttention {
            query: linear(self),
            key: linear(self),
            value: linear(self),
            output: linear(self),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::Gelu::new(),
            n_heads: self.n_heads,
            d_k: self.d_model / self.n_heads,
            min_float: self.min_float,
            quiet_softmax: self.quiet_softmax,
            d_model: self.d_model,
        }
    }
}

impl<B: Backend> MhaInput<B> {
    /// Create a [multihead attention](MultiHeadAttention) input argument
    /// by setting the query, key and value to the given tensor.
    ///
    /// # Shape
    /// - tensor: `[batch_size, seq_length, d_model]`
    pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
        Self {
            query: tensor.clone(),
            key: tensor.clone(),
            value: tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Create a [multihead attention](MultiHeadAttention) input argument.
    pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
        Self {
            query,
            key,
            value,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register the padding mask.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register the attention mask.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

/// [Multihead attention](MultiHeadAttention) outputs.
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
    /// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.
    pub weights: Tensor<B, 4>,
    /// The context tensor `[batch_size, seq_length_1, d_model]`.
    pub context: Tensor<B, 3>,
}

impl<B: Backend> MultiHeadAttention<B> {
    /// Applies the forward pass on the input tensors.
    ///
    /// See [MultiHeadAttention](MultiHeadAttention) for more information.
    ///
    /// # Shapes
    ///
    /// - query: `[batch_size, seq_length_1, d_model]`
    /// - key: `[batch_size, seq_length_2, d_model]`
    /// - value: `[batch_size, seq_length_2, d_model]`
    /// - output: `[batch_size, seq_length_1, d_model]`
    pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

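        // Project the inputs and split them into heads: `[batch_size, n_heads, seq_length, d_k]`.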
        let query = self.attention_linear(input.query, &self.query);
        let key = self.attention_linear(input.key, &self.key);
        let value = self.attention_linear(input.value, &self.value);

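        // Scaled dot-product scores, then masking and (quiet) softmax over the key dimension.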
        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

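        // Weighted sum of the values, then merge the heads back into
        // `[batch_size, seq_length_1, d_model]`.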
        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = self.output.forward(context);

        MhaOutput { weights, context }
    }

    /// Applies the forward pass using a cache.
    ///
    /// # Shapes
    ///
    /// - query: `[batch_size, seq_length_1, d_model]`
    /// - key: `[batch_size, seq_length_2, d_model]`
    /// - value: `[batch_size, seq_length_2, d_model]`
    /// - output: `[batch_size, seq_length_1, d_model]`
    pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

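        // Same as `forward`, except the linear projections go through the cache so that results
        // computed at previous decoding steps are reused.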
        let query = cache
            .query
            .forward(input.query, |t| self.attention_linear(t, &self.query));
        let key = cache
            .key
            .forward(input.key, |t| self.attention_linear(t, &self.key));
        let value = cache
            .value
            .forward(input.value, |t| self.attention_linear(t, &self.value));

        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);

        let context = cache.output.forward(context, |t| self.output.forward(t));

        MhaOutput { weights, context }
    }

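    /// Computes the raw attention scores `Q · K^T / sqrt(d_k)`, with dropout applied to the scores.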
    fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
        let attn_scores = query
            .matmul(key.transpose())
            .div_scalar((self.d_k as f32).sqrt());

        self.dropout.forward(attn_scores)
    }

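    /// Fills masked positions with `min_float`, then normalizes the scores with a regular or
    /// "quiet" softmax over the last dimension.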
    fn attn_weights(
        &self,
        mut attn_scores: Tensor<B, 4>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 4> {
        if let Some(mask_pad) = mask_pad {
            let [batch_size, seq_length] = mask_pad.dims();

            attn_scores = attn_scores.mask_fill(
                mask_pad.reshape([batch_size, 1, 1, seq_length]),
                self.min_float,
            );
        }

        if let Some(mask_attn) = mask_attn {
            let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();

            attn_scores = attn_scores.mask_fill(
                mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
                self.min_float,
            );
        }

        if self.quiet_softmax {
            activation::quiet_softmax(attn_scores, 3)
        } else {
            activation::softmax(attn_scores, 3)
        }
    }

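    /// Applies the linear projection and splits the result into heads:
    /// `[batch_size, seq_length, d_model]` -> `[batch_size, n_heads, seq_length, d_k]`.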
    fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
        let [batch_size, seq_length, _d_model] = x.dims();
        linear
            .forward(x)
            .reshape([batch_size, seq_length, self.n_heads, self.d_k])
            .swap_dims(1, 2)
    }
}

/// Cache for the [Multi Head Attention](MultiHeadAttention) layer.
///
/// To be used during inference when decoding tokens.
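///
/// A decoding-loop sketch (the module, tensors, and sizes are placeholders; the pattern mirrors
/// the autoregressive test in this module):
///
/// ```ignore
/// let mut cache = MhaCache::autoregressive();
/// for i in 1..=seq_length {
///     // Feed the tokens decoded so far; projections are cached across steps.
///     let step = tokens.clone().slice([0..batch_size, 0..i, 0..d_model]);
///     let output = mha.forward_cache(MhaInput::self_attn(step), &mut cache);
/// }
/// ```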
pub struct MhaCache<B: Backend> {
    query: MhaLinearCache<B, 4>,
    key: MhaLinearCache<B, 4>,
    value: MhaLinearCache<B, 4>,
    output: MhaLinearCache<B, 3>,
}

enum MhaLinearCache<B: Backend, const D: usize> {
    Autoregressive(TensorCache<B, D>, usize),
    Full(TensorCache<B, D>),
}

impl<B: Backend> MhaCache<B> {
    /// Initialize a cache for autoregressive inference.
    pub fn autoregressive() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }

    /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and
    /// values (cross-attention).
    pub fn autoregressive_cross_attention() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Full(TensorCache::empty()),
            value: MhaLinearCache::Full(TensorCache::empty()),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }
}

impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
    pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(
        &mut self,
        tensor: Tensor<B, 3>,
        func: F,
    ) -> Tensor<B, D> {
        match self {
            MhaLinearCache::Autoregressive(cache, dim) => {
                cache.forward_autoregressive(tensor, *dim, func)
            }
            MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Int;
    use crate::tensor::{Distribution, Shape};
    use crate::{TestBackend, nn::attention::generate_autoregressive_mask};
    use alloc::vec::Vec;
    use burn_tensor::Tolerance;
    use burn_tensor::ops::FloatElem;

    #[test]
    fn test_self_attention_shapes() {
        let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
        let input = MhaInput::self_attn(Tensor::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        ));

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length, seq_length]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_generic_mha_shapes() {
        let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads)
            .init::<TestBackend>(&Default::default());
        let device = Default::default();
        let input = MhaInput::new(
            Tensor::random(
                [batch_size, seq_length_1, d_model],
                Distribution::Default,
                &device,
            ),
            Tensor::random(
                [batch_size, seq_length_2, d_model],
                Distribution::Default,
                &device,
            ),
            Tensor::random(
                [batch_size, seq_length_2, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length_1, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_self_attention_mask_pad() {
        let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        // Create a padding mask
        let mask_pad: Tensor<TestBackend, 2, Int> =
            Tensor::zeros([batch_size, seq_length], &device);
        let mask_pad = mask_pad.slice_assign(
            [0..batch_size, seq_length - num_padded..seq_length],
            Tensor::ones([batch_size, num_padded], &device),
        );
        let mask_pad = mask_pad.equal_elem(1).to_device(&device);

        let tensor_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        // Change the end of the tensor
        let tensor_2 = tensor_1.clone().slice_assign(
            [
                0..batch_size,
                seq_length - num_padded..seq_length,
                0..d_model,
            ],
            Tensor::random(
                [batch_size, num_padded, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());
        let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);

        let output_1 = mha.forward(input_1);
        let output_2 = mha.forward(input_2);

        // Check that the beginning of each tensor is the same
        output_1
            .context
            .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
            .into_data()
            .assert_approx_eq(
                &output_2
                    .context
                    .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
                    .into_data(),
                Tolerance::<f32>::default(),
            );
    }

    #[test]
    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);

        let output_1 = mha.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = MhaCache::autoregressive();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = MhaInput::self_attn(tensor);
            let next_tok = mha.forward_cache(input, &mut cache).context.slice([
                0..batch_size,
                i - 1..i,
                0..d_model,
            ]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .context
            .into_data()
            .assert_approx_eq::<FloatElem<TestBackend>>(
                &output_2.into_data(),
                Tolerance::default(),
            );
    }

    #[test]
    fn display() {
        let config = MultiHeadAttentionConfig::new(2, 4);
        let mha = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{mha}"),
            "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
            dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
        );
    }
531}