burn_core/nn/transformer/decoder.rs

use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};

use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::Bool;
use crate::{
    self as burn,
    nn::{Initializer, attention::MhaCache, cache::TensorCache},
};
use crate::{
    config::Config,
    nn::{
        Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
        attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    },
    tensor::{Tensor, backend::Backend},
};

/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).
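///
/// # Example
///
/// A minimal sketch; `MyBackend` and `device` are placeholders for any concrete
/// backend and its device (the `with_*` setters come from the [Config] derive,
/// as used in the tests below):
///
/// ```ignore
/// let config = TransformerDecoderConfig::new(512, 2048, 8, 6).with_norm_first(true);
/// let decoder = config.init::<MyBackend>(&device);
/// ```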
#[derive(Config)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules.
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
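    ///
    /// As described in the reference, quiet softmax adds one to the softmax
    /// denominator, allowing all attention weights for a head to approach zero:
    /// `softmax1(x)_i = exp(x_i) / (1 + sum_j exp(x_j))`.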
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The transformer decoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer decoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerDecoderConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// Transformer decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate.
    pub dropout: f64,

    /// Layer norm will be applied first instead of after the other modules.
    pub norm_first: bool,

    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

impl TransformerDecoderConfig {
    /// Initialize a new [Transformer Decoder](TransformerDecoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerDecoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

/// [Transformer Decoder](TransformerDecoder) forward pass input argument.
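///
/// Construct with [TransformerDecoderInput::new], then chain the optional mask
/// setters, e.g. `TransformerDecoderInput::new(target, memory).target_mask_attn(mask)`.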
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
    target: Tensor<B, 3>,
    target_mask_pad: Option<Tensor<B, 2, Bool>>,
    target_mask_attn: Option<Tensor<B, 3, Bool>>,
    memory: Tensor<B, 3>,
    memory_mask_pad: Option<Tensor<B, 2, Bool>>,
    memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerDecoderInput<B> {
    /// Create a [transformer decoder](TransformerDecoder) input argument.
    pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
        Self {
            target,
            target_mask_pad: None,
            target_mask_attn: None,
            memory,
            memory_mask_pad: None,
            memory_mask_attn: None,
        }
    }

    /// Register the memory padding mask.
    pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.memory_mask_pad = Some(mask_pad);
        self
    }

    /// Register the memory attention mask.
    pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.memory_mask_attn = Some(mask_attn);
        self
    }

    /// Register the target padding mask.
    pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.target_mask_pad = Some(mask_pad);
        self
    }

    /// Register the target attention mask.
    pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.target_mask_attn = Some(mask_attn);
        self
    }
}

/// [Transformer Decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    cross_attn: MultiHeadAttention<B>,
    self_attn: MultiHeadAttention<B>,
    pwff: PositionWiseFeedForward<B>,
    norm_1: LayerNorm<B>,
    norm_2: LayerNorm<B>,
    norm_3: LayerNorm<B>,
    dropout: Dropout,
    norm_first: bool,
}

struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
    cross_attn: MhaCache<B>,
    self_attn: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
    norm_3: TensorCache<B, 3>,
}

impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
    fn empty() -> Self {
        Self {
            cross_attn: MhaCache::autoregressive_cross_attention(),
            self_attn: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
            norm_3: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer.
///
/// To be used during inference when decoding tokens.
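///
/// Create via [TransformerDecoder::new_autoregressive_cache].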
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

impl<B: Backend> TransformerDecoderLayer<B> {
    fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);

        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let norm_3 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_dropout(config.dropout)
            .init(device);

        Self {
            cross_attn,
            self_attn,
            norm_1,
            norm_2,
            norm_3,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Applies the transformer decoder layer forward pass to the input tensor.
    fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = self.norm_3.forward(residual_path);
        }

        // Self attention.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.self_attn.forward(self_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross attention residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Cross attention.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.cross_attn.forward(cross_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_2.forward(x.clone())
        } else {
            x = self.norm_2.forward(x);
            x.clone()
        };

        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = self.norm_3.forward(x)
        }

        input.target = x;
        input
    }
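    /// Applies the layer forward pass, reusing cached states for autoregressive decoding.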
    fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
    ) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = cache
                .norm_3
                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
        }

        // Self attention.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .self_attn
            .forward_cache(self_attn_input, &mut cache.self_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross attention residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Cross attention.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .cross_attn
            .forward_cache(cross_attn_input, &mut cache.cross_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_2
                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
        } else {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
            x.clone()
        };

        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = cache
                .norm_3
                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
        }

        input.target = x;
        input
    }
}

impl<B: Backend> TransformerDecoder<B> {
    /// Applies the forward pass.
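    ///
    /// # Example
    ///
    /// A minimal sketch; `decoder`, `target`, and `memory` are placeholders for
    /// an initialized [TransformerDecoder] and `[batch_size, seq_length, d_model]`
    /// tensors (see the tests below for a complete setup):
    ///
    /// ```ignore
    /// let input = TransformerDecoderInput::new(target, memory);
    /// let output = decoder.forward(input); // [batch_size, seq_length, d_model]
    /// ```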
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
        for layer in self.layers.iter() {
            input = layer.forward(input);
        }

        input.target
    }

    /// Applies the forward pass on the input using the autoregressive cache.
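    ///
    /// # Example
    ///
    /// A decoding-loop sketch mirroring the test below; `decoder`, `target`,
    /// `memory`, and the dimension variables are assumed placeholders:
    ///
    /// ```ignore
    /// let mut cache = decoder.new_autoregressive_cache();
    /// for i in 1..seq_length + 1 {
    ///     let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let input = TransformerDecoderInput::new(target, memory.clone());
    ///     let output = decoder.forward_autoregressive_inference(input, &mut cache);
    ///     // The newly decoded position is the slice at `i - 1` along dim 1.
    /// }
    /// ```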
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        for (layer, cache) in self.layers.iter().zip(cache.layers.iter_mut()) {
            input = layer.forward_autoregressive_inference(input, cache);
        }

        input.target
    }

    /// Create an empty autoregressive cache.
    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
        TransformerDecoderAutoregressiveCache::empty(self.layers.len())
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::Device;

    use super::*;
    use crate::{TestBackend, nn::attention::generate_autoregressive_mask};

    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        TestBackend::seed(0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        TestBackend::seed(0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let transformer = config.init::<TestBackend>(&device);

        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(mask_attn);

        // Normal forward using masking.
        let output_1 = transformer.forward(input);

        // Forward using the autoregressive cache.
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);

            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
            let input = TransformerDecoderInput::new(target.clone(), memory.clone())
                .target_mask_attn(mask_attn);
            let next_tok = transformer // Output at the last decoded position.
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        // Should produce the same tokens.
        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::default());
    }

    #[test]
    fn display() {
        let config = TransformerDecoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}