burn_core/nn/attention/mha.rs

use crate as burn;

use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::cache::TensorCache;
use crate::nn::Initializer;
use crate::{
    config::Config,
    nn,
    tensor::{activation, backend::Backend, Bool, Tensor},
};

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init).
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
    /// The size of each linear layer.
    pub d_model: usize,
    /// The number of heads.
    pub n_heads: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The large negative value used to mask attention scores before calculating attention weights.
    /// Default: -1.0e4
    /// A value too low might result in NaN.
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The multihead attention module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - query: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - key: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - value: [Linear](nn::Linear) layer with `d_model` input and output features.
/// - output: [Linear](nn::Linear) layer with `d_model` input and output features.
///
/// Should be created with [MultiHeadAttentionConfig].
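///
/// # Example
///
/// A minimal usage sketch (not part of the crate's doctests), assuming a backend type
/// `MyBackend` and an input `tensor` of shape `[batch_size, seq_length, d_model]`:
///
/// ```rust,ignore
/// let device = Default::default();
/// let mha = MultiHeadAttentionConfig::new(512, 8).init::<MyBackend>(&device);
///
/// let input = MhaInput::self_attn(tensor);
/// let output = mha.forward(input);
/// // output.context: [batch_size, seq_length, 512]
/// // output.weights: [batch_size, 8, seq_length, seq_length]
/// ```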
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
    /// Linear layer to transform the input features into the query space.
    pub query: nn::Linear<B>,
    /// Linear layer to transform the input features into the key space.
    pub key: nn::Linear<B>,
    /// Linear layer to transform the input features into the value space.
    pub value: nn::Linear<B>,
    /// Linear layer to transform the output features back to the original space.
    pub output: nn::Linear<B>,
    /// Dropout layer.
    pub dropout: nn::Dropout,
    /// Activation function.
    pub activation: nn::Gelu,
    /// The size of each linear layer.
    pub d_model: usize,
    /// The number of heads.
    pub n_heads: usize,
    /// Size of the key and query vectors.
    pub d_k: usize,
    /// Large negative value used to mask attention scores.
    pub min_float: f64,
    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("n_heads", &self.n_heads)
            .add("d_k", &self.d_k)
            .add("dropout", &self.dropout.prob)
            .add("min_float", &self.min_float)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

/// [Multihead attention](MultiHeadAttention) forward pass input argument.
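///
/// A short sketch (illustrative, not a doctest) of building an input for cross-attention
/// with a padding mask; `query`, `key`, `value` and `mask_pad` are placeholder tensors:
///
/// ```rust,ignore
/// let input = MhaInput::new(query, key, value).mask_pad(mask_pad);
/// ```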
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
    /// Shape `[batch_size, seq_length_1, d_model]`
    query: Tensor<B, 3>,
    /// Shape `[batch_size, seq_length_2, d_model]`
    key: Tensor<B, 3>,
    /// Shape `[batch_size, seq_length_2, d_model]`
    value: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl MultiHeadAttentionConfig {
    /// Initialize a new [multihead attention](MultiHeadAttention) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
        let linear = |config: &Self| {
            nn::LinearConfig::new(config.d_model, config.d_model)
                .with_initializer(self.initializer.clone())
                .init(device)
        };

        MultiHeadAttention {
            query: linear(self),
            key: linear(self),
            value: linear(self),
            output: linear(self),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::Gelu::new(),
            n_heads: self.n_heads,
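            // Integer division: each head works on d_model / n_heads features, so d_model is
            // expected to be divisible by n_heads.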
            d_k: self.d_model / self.n_heads,
            min_float: self.min_float,
            quiet_softmax: self.quiet_softmax,
            d_model: self.d_model,
        }
    }
}

impl<B: Backend> MhaInput<B> {
    /// Create a [multihead attention](MultiHeadAttention) input argument
    /// by setting the query, key and value to the given tensor.
    ///
    /// # Shape
    /// - tensor: `[batch_size, seq_length, d_model]`
    pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
        Self {
            query: tensor.clone(),
            key: tensor.clone(),
            value: tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Create a [multihead attention](MultiHeadAttention) input argument.
    pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
        Self {
            query,
            key,
            value,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register the padding mask.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register the attention mask.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

/// [Multihead attention](MultiHeadAttention) outputs.
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
    /// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.
    pub weights: Tensor<B, 4>,
    /// The context tensor `[batch_size, seq_length_1, d_model]`.
    pub context: Tensor<B, 3>,
}

impl<B: Backend> MultiHeadAttention<B> {
    /// Applies the forward pass on the input tensors.
    ///
    /// See [MultiHeadAttention](MultiHeadAttention) for more information.
    ///
    /// # Shapes
    ///
    /// - query: `[batch_size, seq_length_1, d_model]`
    /// - key: `[batch_size, seq_length_2, d_model]`
    /// - value: `[batch_size, seq_length_2, d_model]`
    /// - output: `[batch_size, seq_length_1, d_model]`
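    ///
    /// A sketch of a self-attention call (tensor name is illustrative):
    ///
    /// ```rust,ignore
    /// let output = mha.forward(MhaInput::self_attn(tensor));
    /// let (context, weights) = (output.context, output.weights);
    /// ```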
    pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

        let query = self.attention_linear(input.query, &self.query);
        let key = self.attention_linear(input.key, &self.key);
        let value = self.attention_linear(input.value, &self.value);

        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = self.output.forward(context);

        MhaOutput { weights, context }
    }

    /// Applies the forward pass using a cache.
    ///
    /// # Shapes
    ///
    /// - query: `[batch_size, seq_length_1, d_model]`
    /// - key: `[batch_size, seq_length_2, d_model]`
    /// - value: `[batch_size, seq_length_2, d_model]`
    /// - output: `[batch_size, seq_length_1, d_model]`
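    ///
    /// A sketch of autoregressive decoding with a cache (loop bounds and tensor names are
    /// illustrative, mirroring the test at the bottom of this file):
    ///
    /// ```rust,ignore
    /// let mut cache = MhaCache::autoregressive();
    /// for i in 1..=seq_length {
    ///     let step = tokens.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let output = mha.forward_cache(MhaInput::self_attn(step), &mut cache);
    /// }
    /// ```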
    pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

        let query = cache
            .query
            .forward(input.query, |t| self.attention_linear(t, &self.query));
        let key = cache
            .key
            .forward(input.key, |t| self.attention_linear(t, &self.key));
        let value = cache
            .value
            .forward(input.value, |t| self.attention_linear(t, &self.value));

        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);

        let context = cache.output.forward(context, |t| self.output.forward(t));

        MhaOutput { weights, context }
    }

    fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
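        // Scaled dot-product attention scores: (Q · Kᵀ) / sqrt(d_k), with dropout applied to
        // the scores before the softmax.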
        let attn_scores = query
            .matmul(key.transpose())
            .div_scalar((self.d_k as f32).sqrt());

        self.dropout.forward(attn_scores)
    }

    fn attn_weights(
        &self,
        mut attn_scores: Tensor<B, 4>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 4> {
        if let Some(mask_pad) = mask_pad {
            let [batch_size, seq_length] = mask_pad.dims();

            attn_scores = attn_scores.mask_fill(
                mask_pad.reshape([batch_size, 1, 1, seq_length]),
                self.min_float,
            );
        }

        if let Some(mask_attn) = mask_attn {
            let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();

            attn_scores = attn_scores.mask_fill(
                mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
                self.min_float,
            );
        }

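        // Quiet softmax adds 1 to the softmax denominator, letting a head assign (near) zero
        // weight everywhere when nothing in the sequence is relevant to it (see the reference
        // linked from [MultiHeadAttentionConfig::quiet_softmax]).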
        if self.quiet_softmax {
            activation::quiet_softmax(attn_scores, 3)
        } else {
            activation::softmax(attn_scores, 3)
        }
    }

    fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
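        // Project the input, then split `d_model` into `n_heads` heads of size `d_k`:
        // [batch_size, seq_length, d_model] -> [batch_size, n_heads, seq_length, d_k].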
        let [batch_size, seq_length, _d_model] = x.dims();
        linear
            .forward(x)
            .reshape([batch_size, seq_length, self.n_heads, self.d_k])
            .swap_dims(1, 2)
    }
}

/// Cache for the [Multi Head Attention](MultiHeadAttention) layer.
///
/// To be used during inference when decoding tokens.
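///
/// Use [MhaCache::autoregressive] when the queries, keys and values all grow with the decoded
/// sequence (self-attention), and [MhaCache::autoregressive_cross_attention] when the keys and
/// values come from a fixed memory (cross-attention).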
pub struct MhaCache<B: Backend> {
    query: MhaLinearCache<B, 4>,
    key: MhaLinearCache<B, 4>,
    value: MhaLinearCache<B, 4>,
    output: MhaLinearCache<B, 3>,
}

enum MhaLinearCache<B: Backend, const D: usize> {
    Autoregressive(TensorCache<B, D>, usize),
    Full(TensorCache<B, D>),
}

impl<B: Backend> MhaCache<B> {
    /// Initialize a cache for autoregressive inference.
    pub fn autoregressive() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }

    /// Initialize a cache for autoregressive inference where the keys and values are computed
    /// once from a fixed memory and reused (cross-attention).
    pub fn autoregressive_cross_attention() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Full(TensorCache::empty()),
            value: MhaLinearCache::Full(TensorCache::empty()),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }
}

impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
    pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(
        &mut self,
        tensor: Tensor<B, 3>,
        func: F,
    ) -> Tensor<B, D> {
        match self {
            MhaLinearCache::Autoregressive(cache, dim) => {
                cache.forward_autoregressive(tensor, *dim, func)
            }
            MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Int;
    use crate::tensor::{Distribution, Shape};
    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use alloc::vec::Vec;

    #[test]
    fn test_self_attention_shapes() {
        let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
        let input = MhaInput::self_attn(Tensor::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        ));

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length, seq_length]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_generic_mha_shapes() {
        let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads)
            .init::<TestBackend>(&Default::default());
        let device = Default::default();
        let input = MhaInput::new(
            Tensor::random(
                [batch_size, seq_length_1, d_model],
                Distribution::Default,
                &device,
            ),
            Tensor::random(
                [batch_size, seq_length_2, d_model],
                Distribution::Default,
                &device,
            ),
            Tensor::random(
                [batch_size, seq_length_2, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length_1, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_self_attention_mask_pad() {
        let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        // Create a padding mask
        let mask_pad: Tensor<TestBackend, 2, Int> =
            Tensor::zeros([batch_size, seq_length], &device);
        let mask_pad = mask_pad.slice_assign(
            [0..batch_size, seq_length - num_padded..seq_length],
            Tensor::ones([batch_size, num_padded], &device),
        );
        let mask_pad = mask_pad.equal_elem(1).to_device(&device);

        let tensor_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        // Change the end of the tensor
        let tensor_2 = tensor_1.clone().slice_assign(
            [
                0..batch_size,
                seq_length - num_padded..seq_length,
                0..d_model,
            ],
            Tensor::random(
                [batch_size, num_padded, d_model],
                Distribution::Default,
                &device,
            ),
        );

        let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());
        let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);

        let output_1 = mha.forward(input_1);
        let output_2 = mha.forward(input_2);

        // Check that the non-padded portion of both outputs is the same
        output_1
            .context
            .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
            .into_data()
            .assert_approx_eq(
                &output_2
                    .context
                    .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
                    .into_data(),
                3,
            );
    }

    #[test]
    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
        let device = Default::default();
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);

        let output_1 = mha.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = MhaCache::autoregressive();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = MhaInput::self_attn(tensor);
            let next_tok = mha.forward_cache(input, &mut cache).context.slice([
                0..batch_size,
                i - 1..i,
                0..d_model,
            ]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .context
            .into_data()
            .assert_approx_eq(&output_2.into_data(), 3);
    }

    #[test]
    fn display() {
        let config = MultiHeadAttentionConfig::new(2, 4);
        let mha = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", mha),
            "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
            dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
        );
    }
}