
burn_nn/modules/transformer/decoder.rs

use burn_core as burn;

use alloc::vec::Vec;

use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

use crate::activation::ActivationConfig;
use crate::cache::TensorCache;
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
};

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};

/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).
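///
/// A minimal construction sketch; the dimensions below are illustrative
/// placeholders, not defaults:
///
/// ```ignore
/// let config = TransformerDecoderConfig::new(512, 2048, 8, 6)
///     .with_norm_first(true)
///     .with_dropout(0.1);
/// ```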
#[derive(Config, Debug)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules.
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
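    ///
    /// Quiet softmax adds one to the denominator, letting all attention weights
    /// go to zero: `softmax_1(x)_i = exp(x_i) / (1 + sum_j exp(x_j))`.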
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The activation function used in the position-wise feed-forward network. Default: Gelu
    #[config(default = "ActivationConfig::Gelu")]
    pub activation: ActivationConfig,
    /// The epsilon value for layer normalization. Default: 1e-5
    #[config(default = 1e-5)]
    pub layer_norm_eps: f64,
}

/// The transformer decoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer decoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerDecoderConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// Transformer decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate. Default: 0.1
    pub dropout: f64,

    /// Layer norm will be applied first instead of after the other modules.
    pub norm_first: bool,

    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

impl TransformerDecoderConfig {
    /// Initialize a new [Transformer Decoder](TransformerDecoder) module.
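    ///
    /// A minimal usage sketch; `MyBackend` is a placeholder backend type:
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let decoder: TransformerDecoder<MyBackend> =
    ///     TransformerDecoderConfig::new(512, 2048, 8, 6).init(&device);
    /// ```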
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerDecoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

/// [Transformer Decoder](TransformerDecoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
    target: Tensor<B, 3>,
    target_mask_pad: Option<Tensor<B, 2, Bool>>,
    target_mask_attn: Option<Tensor<B, 3, Bool>>,
    memory: Tensor<B, 3>,
    memory_mask_pad: Option<Tensor<B, 2, Bool>>,
    memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerDecoderInput<B> {
    /// Create a [transformer decoder](TransformerDecoder) input argument.
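    ///
    /// Masks are optional and can be registered with the builder methods below.
    /// An illustrative sketch (the mask tensors are placeholders):
    ///
    /// ```ignore
    /// let input = TransformerDecoderInput::new(target, memory)
    ///     .target_mask_attn(causal_mask)
    ///     .memory_mask_pad(padding_mask);
    /// ```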
    pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
        Self {
            target,
            target_mask_pad: None,
            target_mask_attn: None,
            memory,
            memory_mask_pad: None,
            memory_mask_attn: None,
        }
    }

    /// Register the memory padding mask.
    pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.memory_mask_pad = Some(mask_pad);
        self
    }

    /// Register the memory attention mask.
    pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.memory_mask_attn = Some(mask_attn);
        self
    }

    /// Register the target padding mask.
    pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.target_mask_pad = Some(mask_pad);
        self
    }

    /// Register the target attention mask.
    pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.target_mask_attn = Some(mask_attn);
        self
    }
}

/// [Transformer Decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    /// Cross-attention module.
    pub cross_attn: MultiHeadAttention<B>,
    /// Self-attention module.
    pub self_attn: MultiHeadAttention<B>,
    /// Position-wise feed-forward module.
    pub pwff: PositionWiseFeedForward<B>,
    /// First layer norm.
    pub norm_1: LayerNorm<B>,
    /// Second layer norm.
    pub norm_2: LayerNorm<B>,
    /// Third layer norm.
    pub norm_3: LayerNorm<B>,
    /// Dropout.
    pub dropout: Dropout,
    /// Whether to apply norm first.
    pub norm_first: bool,
}

/// Autoregressive cache for a single [Transformer Decoder Layer](TransformerDecoderLayer).
pub struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
    /// Cross-attention cache.
    pub cross_attn: MhaCache<B>,
    /// Self-attention cache.
    pub self_attn: MhaCache<B>,
    /// Position-wise feed-forward cache.
    pub pwff: TensorCache<B, 3>,
    /// First layer norm cache.
    pub norm_1: TensorCache<B, 3>,
    /// Second layer norm cache.
    pub norm_2: TensorCache<B, 3>,
    /// Third layer norm cache.
    pub norm_3: TensorCache<B, 3>,
}

impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
    /// Create an empty cache.
    pub fn empty() -> Self {
        Self {
            cross_attn: MhaCache::autoregressive_cross_attention(),
            self_attn: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
            norm_3: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer.
///
/// To be used during inference when decoding tokens.
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

impl<B: Backend> TransformerDecoderLayer<B> {
    /// Create a new [TransformerDecoderLayer](TransformerDecoderLayer).
    pub fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);

        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let norm_2 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let norm_3 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_activation(config.activation.clone())
            .init(device);

        Self {
            cross_attn,
            self_attn,
            norm_1,
            norm_2,
            norm_3,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Applies the transformer decoder layer forward pass to the input.
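    ///
    /// Each of the three sub-layers (self-attention, cross-attention,
    /// position-wise feed-forward) follows a residual pattern: with `norm_first`
    /// it computes `x = x + dropout(sublayer(norm(x)))`, otherwise
    /// `x = norm(x + dropout(sublayer(x)))`.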
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = self.norm_3.forward(residual_path);
        }

        // Self attention.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.self_attn.forward(self_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross attention residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Cross attention.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.cross_attn.forward(cross_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_2.forward(x.clone())
        } else {
            x = self.norm_2.forward(x);
            x.clone()
        };

        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = self.norm_3.forward(x)
        }

        input.target = x;
        input
    }

    /// Applies the forward pass using an autoregressive cache.
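    ///
    /// Per-module caches reuse results for positions that were already decoded,
    /// so each step only computes the projections for the newest position.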
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
    ) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = cache
                .norm_3
                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
        }

        // Self attention.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .self_attn
            .forward_cache(self_attn_input, &mut cache.self_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross attention residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Cross attention.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .cross_attn
            .forward_cache(cross_attn_input, &mut cache.cross_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_2
                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
        } else {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
            x.clone()
        };

        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = cache
                .norm_3
                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
        }

        input.target = x;
        input
    }
}

impl<B: Backend> TransformerDecoder<B> {
    /// Applies the forward pass.
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
        for layer in self.layers.iter() {
            input = layer.forward(input);
        }

        input.target
    }

    /// Applies the forward pass on the input using an autoregressive cache.
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        for i in 0..self.layers.len() {
            let layer = self.layers.get(i).unwrap();
            let cache = cache.layers.get_mut(i).unwrap();

            input = layer.forward_autoregressive_inference(input, cache);
        }

        input.target
    }

    /// Create an empty autoregressive cache.
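    ///
    /// An illustrative greedy decoding loop; `target`, `memory`, `max_len`, and
    /// the token handling are placeholders:
    ///
    /// ```ignore
    /// let mut cache = decoder.new_autoregressive_cache();
    /// for i in 1..=max_len {
    ///     let input = TransformerDecoderInput::new(target.clone(), memory.clone());
    ///     let output = decoder.forward_autoregressive_inference(input, &mut cache);
    ///     // decode the next token from the last position of `output`
    /// }
    /// ```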
    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
        TransformerDecoderAutoregressiveCache::empty(self.layers.len())
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Device;

    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};

    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let transformer = config.init::<TestBackend>(&device);

        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(mask_attn);

        // Normal forward using masking.
        let output_1 = transformer.forward(input);

        // Forward using the autoregressive cache.
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);

            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
            let input = TransformerDecoderInput::new(target.clone(), memory.clone())
                .target_mask_attn(mask_attn);
            let next_tok = transformer // Greedy sampling
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        // Should produce the same tokens.
        let tolerance = Tolerance::rel_abs(5e-3, 1e-4);
        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), tolerance);
    }

    #[test]
    fn display() {
        let config = TransformerDecoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}