burn_nn/modules/transformer/decoder.rs

use burn_core as burn;

use alloc::vec::Vec;

use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

use crate::cache::TensorCache;
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
};

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};

/// Configuration to create a [Transformer Decoder](TransformerDecoder) module using the [init function](TransformerDecoderConfig::init).
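///
/// A minimal construction sketch; the backend type `MyBackend` and the `device` value are
/// illustrative assumptions, not items defined in this module:
///
/// ```rust,ignore
/// // 12-dim model, 24-dim feed-forward, 2 heads, 3 layers, pre-norm variant.
/// let config = TransformerDecoderConfig::new(12, 24, 2, 3).with_norm_first(true);
/// let decoder = config.init::<MyBackend>(&device);
/// ```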
#[derive(Config, Debug)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules.
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
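    ///
    /// With raw attention scores `x`, quiet softmax replaces `exp(x_i) / sum_j exp(x_j)`
    /// with `exp(x_i) / (1 + sum_j exp(x_j))`, so every weight of a head can go to zero
    /// when the sequence holds nothing relevant to that head.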
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The transformer decoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer decoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerDecoderConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// Transformer decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate. Default: 0.1
    pub dropout: f64,

    /// Layer norm will be applied first instead of after the other modules.
    pub norm_first: bool,

    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

impl TransformerDecoderConfig {
    /// Initialize a new [Transformer Decoder](TransformerDecoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerDecoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

/// [Transformer Decoder](TransformerDecoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
    target: Tensor<B, 3>,
    target_mask_pad: Option<Tensor<B, 2, Bool>>,
    target_mask_attn: Option<Tensor<B, 3, Bool>>,
    memory: Tensor<B, 3>,
    memory_mask_pad: Option<Tensor<B, 2, Bool>>,
    memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerDecoderInput<B> {
    /// Create a [transformer decoder](TransformerDecoder) input argument.
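    ///
    /// A minimal sketch of building an input with optional masks; `target`, `memory` and the
    /// mask tensors are assumed to already exist with the expected shapes:
    ///
    /// ```rust,ignore
    /// let input = TransformerDecoderInput::new(target, memory)
    ///     .target_mask_attn(target_mask_attn)
    ///     .memory_mask_pad(memory_mask_pad);
    /// ```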
    pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
        Self {
            target,
            target_mask_pad: None,
            target_mask_attn: None,
            memory,
            memory_mask_pad: None,
            memory_mask_attn: None,
        }
    }

    /// Register the memory padding mask.
    pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.memory_mask_pad = Some(mask_pad);
        self
    }

    /// Register the memory attention mask.
    pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.memory_mask_attn = Some(mask_attn);
        self
    }

    /// Register the target padding mask.
    pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.target_mask_pad = Some(mask_pad);
        self
    }

    /// Register the target attention mask.
    pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.target_mask_attn = Some(mask_attn);
        self
    }
}

/// [Transformer Decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    cross_attn: MultiHeadAttention<B>,
    self_attn: MultiHeadAttention<B>,
    pwff: PositionWiseFeedForward<B>,
    norm_1: LayerNorm<B>,
    norm_2: LayerNorm<B>,
    norm_3: LayerNorm<B>,
    dropout: Dropout,
    norm_first: bool,
}

struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
    cross_attn: MhaCache<B>,
    self_attn: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
    norm_3: TensorCache<B, 3>,
}

impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
    fn empty() -> Self {
        Self {
            cross_attn: MhaCache::autoregressive_cross_attention(),
            self_attn: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
            norm_3: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) module.
///
/// To be used during inference when decoding tokens.
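///
/// A cache is typically created once per decoded sequence through
/// [TransformerDecoder::new_autoregressive_cache]; a brief sketch, assuming a `decoder`
/// value already exists:
///
/// ```rust,ignore
/// let mut cache = decoder.new_autoregressive_cache();
/// ```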
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

impl<B: Backend> TransformerDecoderLayer<B> {
    fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);

        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let norm_3 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_dropout(config.dropout)
            .init(device);

        Self {
            cross_attn,
            self_attn,
            norm_1,
            norm_2,
            norm_3,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Applies the TransformerDecoder forward pass to the input tensor.
    fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = self.norm_3.forward(residual_path);
        }

        // Self attention.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.self_attn.forward(self_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross attention residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Cross attention.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.cross_attn.forward(cross_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_2.forward(x.clone())
        } else {
            x = self.norm_2.forward(x);
            x.clone()
        };

        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = self.norm_3.forward(x)
        }

        input.target = x;
        input
    }

    fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
    ) -> TransformerDecoderInput<B> {
        // Self attention residual path.
        let x = input.target;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = cache
                .norm_3
                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
        }

        // Self attention.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .self_attn
            .forward_cache(self_attn_input, &mut cache.self_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross attention residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Cross attention.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .cross_attn
            .forward_cache(cross_attn_input, &mut cache.cross_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_2
                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
        } else {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
            x.clone()
        };

        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = cache
                .norm_3
                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
        }

        input.target = x;
        input
    }
}

impl<B: Backend> TransformerDecoder<B> {
    /// Applies the forward pass.
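    ///
    /// A usage sketch; `decoder`, `target` and `memory` (both `[batch_size, seq_length, d_model]`
    /// tensors) are assumed to already exist:
    ///
    /// ```rust,ignore
    /// let output = decoder.forward(TransformerDecoderInput::new(target, memory));
    /// ```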
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
        for layer in self.layers.iter() {
            input = layer.forward(input);
        }

        input.target
    }

    /// Applies the forward pass on the input using the autoregressive cache.
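    ///
    /// A token-by-token decoding sketch mirroring the test below; tensor construction and the
    /// loop bounds are illustrative assumptions:
    ///
    /// ```rust,ignore
    /// let mut cache = decoder.new_autoregressive_cache();
    /// for i in 1..=seq_length {
    ///     let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
    ///     let input = TransformerDecoderInput::new(target, memory.clone()).target_mask_attn(mask_attn);
    ///     let output = decoder.forward_autoregressive_inference(input, &mut cache);
    ///     // Only the last position of `output` is the newly decoded step.
    /// }
    /// ```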
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        for (layer, cache) in self.layers.iter().zip(cache.layers.iter_mut()) {
            input = layer.forward_autoregressive_inference(input, cache);
        }

        input.target
    }

    /// Create an empty autoregressive cache.
    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
        TransformerDecoderAutoregressiveCache::empty(self.layers.len())
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Device;

    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};

    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let transformer = config.init::<TestBackend>(&device);

        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(mask_attn);

        // Normal forward using masking.
        let output_1 = transformer.forward(input);

        // Forward using the autoregressive cache.
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);

            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
            let input = TransformerDecoderInput::new(target.clone(), memory.clone())
                .target_mask_attn(mask_attn);
            let next_tok = transformer // Keep only the newly decoded position.
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        // Both paths should produce the same outputs.
        let tolerance = Tolerance::rel_abs(5e-3, 1e-4);
        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), tolerance);
    }

    #[test]
    fn display() {
        let config = TransformerDecoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}