//! Transformer encoder module (`burn_nn/modules/transformer/encoder.rs`).
use burn_core as burn;

use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    cache::TensorCache,
};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).
#[derive(Config, Debug)]
pub struct TransformerEncoderConfig {
    /// The size of the model (embedding dimension of each token).
    pub d_model: usize,
    /// The size of the position-wise feed-forward network (hidden dimension).
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of stacked encoder layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules (pre-norm).
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    /// Default: Kaiming-uniform with gain 1/sqrt(3).
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
46
/// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer encoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerEncoderConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
    /// The transformer encoder layers, applied in sequence.
    pub layers: Vec<TransformerEncoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate. Default: 0.1
    pub dropout: f64,

    /// Layer norm will be applied first instead of after the other modules.
    pub norm_first: bool,

    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}
81
impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        // Keep the display compact: attributes stay on a single line.
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Surface the hyperparameters (not the layer weights) in the
        // module's Display output; order here defines the printed order.
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}
101
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
    // Input sequence, `[batch_size, seq_length, d_model]` (see `forward` docs).
    tensor: Tensor<B, 3>,
    // Optional padding mask, set via `mask_pad`.
    mask_pad: Option<Tensor<B, 2, Bool>>,
    // Optional attention mask, set via `mask_attn`.
    mask_attn: Option<Tensor<B, 3, Bool>>,
}
109
110impl<B: Backend> TransformerEncoderInput<B> {
111    /// Create a [transformer encoder](TransformerEncoder) input argument.
112    pub fn new(tensor: Tensor<B, 3>) -> Self {
113        Self {
114            tensor,
115            mask_pad: None,
116            mask_attn: None,
117        }
118    }
119
120    /// Register the padding mask.
121    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
122        self.mask_pad = Some(mask_pad);
123        self
124    }
125
126    /// Register the attention mask.
127    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
128        self.mask_attn = Some(mask_attn);
129        self
130    }
131}
132impl TransformerEncoderConfig {
133    /// Initialize a new [transformer encoder](TransformerEncoder) module.
134    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
135        let layers = (0..self.n_layers)
136            .map(|_| TransformerEncoderLayer::new(self, device))
137            .collect::<Vec<_>>();
138
139        TransformerEncoder {
140            layers,
141            d_model: self.d_model,
142            d_ff: self.d_ff,
143            n_heads: self.n_heads,
144            n_layers: self.n_layers,
145            dropout: self.dropout,
146            norm_first: self.norm_first,
147            quiet_softmax: self.quiet_softmax,
148        }
149    }
150}
151
152impl<B: Backend> TransformerEncoder<B> {
153    /// Applies the forward pass on the input tensor.
154    ///
155    /// # Shapes
156    ///
157    /// - tensor: `[batch_size, seq_length, d_model]`
158    /// - output: `[batch_size, seq_length, d_model]`
159    pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
160        let mut x = input.tensor;
161
162        for layer in self.layers.iter() {
163            x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
164        }
165
166        x
167    }
168    /// Applies the forward pass on the input tensor using autoregressive cache.
169    ///
170    /// # Shapes
171    ///
172    /// - tensor: `[batch_size, seq_length, d_model]`
173    /// - output: `[batch_size, seq_length, d_model]`
174    pub fn forward_autoregressive_inference(
175        &self,
176        input: TransformerEncoderInput<B>,
177        cache: &mut TransformerEncoderAutoregressiveCache<B>,
178    ) -> Tensor<B, 3> {
179        let mut x = input.tensor;
180
181        for i in 0..self.layers.len() {
182            let layer = self.layers.get(i).unwrap();
183            let cache = cache.layers.get_mut(i).unwrap();
184
185            x = layer.forward_autoregressive_inference(
186                x,
187                input.mask_pad.clone(),
188                input.mask_attn.clone(),
189                cache,
190            );
191        }
192
193        x
194    }
195
196    /// Create an empty autoregressive cache.
197    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
198        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
199    }
200}
201
/// Transformer encoder layer module.
///
/// One self-attention sub-layer followed by one feed-forward sub-layer,
/// each wrapped with a residual connection, dropout, and layer norm.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    /// Multi-head self-attention sub-layer.
    pub mha: MultiHeadAttention<B>,
    /// Position-wise feed-forward sub-layer.
    pub pwff: PositionWiseFeedForward<B>,
    /// Layer normalization applied around the feed-forward sub-layer.
    pub norm_1: LayerNorm<B>,
    /// Layer normalization applied around the attention sub-layer.
    pub norm_2: LayerNorm<B>,
    /// Dropout module applied to residual connections.
    pub dropout: Dropout,
    /// If `true`, apply layer normalization before sub-layers (pre-norm),
    /// otherwise apply it after (post-norm).
    pub norm_first: bool,
}
219
impl<B: Backend> TransformerEncoderLayer<B> {
    /// Build a single encoder layer from the shared encoder configuration.
    ///
    /// Dropout, initializer, and quiet-softmax settings are propagated to the
    /// attention and feed-forward sub-modules.
    fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .init(device);

        Self {
            mha,
            norm_1,
            norm_2,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Forward pass of one encoder layer.
    ///
    /// Note the norm wiring: `norm_2` wraps the attention sub-layer and
    /// `norm_1` wraps the feed-forward sub-layer (see the struct field docs).
    /// With `norm_first` the norms are applied before each sub-layer
    /// (pre-norm); otherwise after the residual add (post-norm).
    fn forward(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 3> {
        // Multi-head attention residual path.
        let x = input;
        let mut residual_path = x.clone();

        // Normalize.
        // Pre-norm: normalize the attention input; `x` keeps the raw values
        // for the residual add below.
        if self.norm_first {
            residual_path = self.norm_2.forward(residual_path)
        }

        // Multi-head attention.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward(input_mhs).context;

        // Residual add with dropout on the sub-layer output.
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        // Pre-norm feeds a normalized copy into the FFN while `x` stays raw;
        // post-norm normalizes `x` itself right after the attention residual.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Feed forward.
        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        // Post-norm applies the final (attention-side) norm at the end.
        if !self.norm_first {
            x = self.norm_2.forward(x)
        }

        x
    }

    /// Same computation as [`Self::forward`], but every sub-module runs
    /// through its autoregressive cache so only the newly appended token(s)
    /// are computed; cached results are concatenated along dim 1 (the
    /// sequence axis, per the `[batch_size, seq_length, d_model]` layout).
    fn forward_autoregressive_inference(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        // Multi-head attention residual path.
        let x = input;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = cache
                .norm_2
                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
        }

        // Multi-head attention.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        // Attention keeps its own key/value cache inside `cache.mha`.
        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Feed forward.
        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
        }

        x
    }
}
358
/// Per-layer autoregressive cache: one entry per cached sub-module of a
/// [`TransformerEncoderLayer`].
struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
    // Key/value cache for the attention sub-layer.
    mha: MhaCache<B>,
    // Output cache for the feed-forward sub-layer.
    pwff: TensorCache<B, 3>,
    // Output caches for the two layer norms.
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
}
365
366impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
367    fn empty() -> Self {
368        Self {
369            mha: MhaCache::autoregressive(),
370            pwff: TensorCache::empty(),
371            norm_1: TensorCache::empty(),
372            norm_2: TensorCache::empty(),
373        }
374    }
375}
376
/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
///
/// To be used during inference when decoding tokens.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
    // One cache entry per encoder layer, in layer order.
    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}
383
384impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
385    fn empty(num_layers: usize) -> Self {
386        Self {
387            layers: (0..num_layers)
388                .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
389                .collect(),
390        }
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};
    use burn::tensor::Distribution;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    /// Checks that token-by-token cached inference matches the full
    /// (uncached) forward pass under an autoregressive attention mask.
    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        // The mask makes each position attend only to itself and earlier
        // positions, so the full pass is comparable to incremental decoding.
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        // Reference: one full forward pass over the whole sequence.
        let output_1 = transformer.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        // Incremental: feed growing prefixes and keep only the newest
        // position's output from each step.
        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(tensor.clone());
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        // Concatenate the per-step outputs along the sequence axis.
        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::permissive());
    }

    #[test]
    fn display() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
}