burn_nn/modules/transformer/encoder.rs

use burn_core as burn;

use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    activation::ActivationConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    cache::TensorCache,
};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).
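///
/// # Example
///
/// A minimal construction sketch; `MyBackend` and `device` are placeholders
/// for whatever backend and device the caller uses, not part of this module:
///
/// ```rust,ignore
/// // d_model = 512, d_ff = 2048, n_heads = 8, n_layers = 6.
/// let config = TransformerEncoderConfig::new(512, 2048, 8, 6)
///     .with_dropout(0.1)
///     .with_norm_first(true);
/// let encoder: TransformerEncoder<MyBackend> = config.init(&device);
/// ```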
#[derive(Config, Debug)]
pub struct TransformerEncoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules.
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
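    ///
    /// Per the reference, "quiet softmax" adds one to the softmax denominator
    /// so a head can assign (near-)zero total attention:
    /// `softmax_1(x)_i = exp(x_i) / (1 + sum_j exp(x_j))`.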
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The activation function used in the position-wise feed-forward network. Default: Gelu
    #[config(default = "ActivationConfig::Gelu")]
    pub activation: ActivationConfig,
    /// The epsilon value for layer normalization. Default: 1e-5
    #[config(default = 1e-5)]
    pub layer_norm_eps: f64,
}

/// The transformer encoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer encoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerEncoderConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
    /// The transformer encoder layers.
    pub layers: Vec<TransformerEncoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate. Default: 0.1
    pub dropout: f64,

    /// Layer norm will be applied first instead of after the other modules.
    pub norm_first: bool,

    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
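///
/// # Example
///
/// A construction sketch using the builder-style setters below; `tensor` and
/// both masks are placeholders:
///
/// ```rust,ignore
/// let input = TransformerEncoderInput::new(tensor) // [batch_size, seq_length, d_model]
///     .mask_pad(mask_pad)                          // [batch_size, seq_length]
///     .mask_attn(mask_attn);                       // [batch_size, seq_length, seq_length]
/// ```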
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
    tensor: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerEncoderInput<B> {
    /// Create a [transformer encoder](TransformerEncoder) input argument.
    pub fn new(tensor: Tensor<B, 3>) -> Self {
        Self {
            tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register the padding mask.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register the attention mask.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

impl TransformerEncoderConfig {
    /// Initialize a new [transformer encoder](TransformerEncoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerEncoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerEncoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

impl<B: Backend> TransformerEncoder<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
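    ///
    /// # Example
    ///
    /// A usage sketch; `encoder` is an initialized [TransformerEncoder] and
    /// `tensor` a placeholder input of shape `[batch_size, seq_length, d_model]`:
    ///
    /// ```rust,ignore
    /// let input = TransformerEncoderInput::new(tensor);
    /// let output = encoder.forward(input); // [batch_size, seq_length, d_model]
    /// ```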
    pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for layer in self.layers.iter() {
            x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
        }

        x
    }

    /// Applies the forward pass on the input tensor using autoregressive cache.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
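    ///
    /// # Example
    ///
    /// A decoding-loop sketch mirroring the autoregressive test below; the
    /// cache lets each step reuse computation from earlier positions:
    ///
    /// ```rust,ignore
    /// let mut cache = encoder.new_autoregressive_cache();
    /// for i in 1..=seq_length {
    ///     let step = tokens.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let input = TransformerEncoderInput::new(step);
    ///     let output = encoder.forward_autoregressive_inference(input, &mut cache);
    ///     // Only the last position of `output` is new at this step.
    /// }
    /// ```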
    pub fn forward_autoregressive_inference(
        &self,
        input: TransformerEncoderInput<B>,
        cache: &mut TransformerEncoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for (layer, cache) in self.layers.iter().zip(cache.layers.iter_mut()) {
            x = layer.forward_autoregressive_inference(
                x,
                input.mask_pad.clone(),
                input.mask_attn.clone(),
                cache,
            );
        }

        x
    }

    /// Create an empty autoregressive cache.
    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
    }
}

/// Transformer encoder layer module.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    /// Multi-head self-attention sub-layer.
    pub mha: MultiHeadAttention<B>,
    /// Position-wise feed-forward sub-layer.
    pub pwff: PositionWiseFeedForward<B>,
    /// First layer norm: applied before the feed-forward sub-layer in
    /// pre-norm, or after the attention residual in post-norm.
    pub norm_1: LayerNorm<B>,
    /// Second layer norm: applied before the attention sub-layer in
    /// pre-norm, or after the feed-forward residual in post-norm.
    pub norm_2: LayerNorm<B>,
    /// Dropout module applied to residual connections.
    pub dropout: Dropout,
    /// If `true`, apply layer normalization before sub-layers (pre-norm),
    /// otherwise apply it after (post-norm).
    pub norm_first: bool,
}

impl<B: Backend> TransformerEncoderLayer<B> {
    /// Create a new transformer encoder layer from the given configuration.
    pub fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let norm_2 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_activation(config.activation.clone())
            .init(device);

        Self {
            mha,
            norm_1,
            norm_2,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
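    ///
    /// For reference, the data flow implemented below, in both configurations:
    ///
    /// ```text
    /// pre-norm  (norm_first = true):
    ///     x = x + dropout(mha(norm_2(x)));  x = x + dropout(pwff(norm_1(x)))
    /// post-norm (norm_first = false):
    ///     x = norm_1(x + dropout(mha(x)));  x = norm_2(x + dropout(pwff(x)))
    /// ```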
    pub fn forward(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 3> {
        // Multi-head attention residual path.
        let x = input;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = self.norm_2.forward(residual_path)
        }

        // Multi-head attention.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward(input_mhs).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Feed forward.
        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = self.norm_2.forward(x)
        }

        x
    }

    /// Applies the forward pass using an autoregressive cache.
    pub fn forward_autoregressive_inference(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        // Multi-head attention residual path.
        let x = input;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = cache
                .norm_2
                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
        }

        // Multi-head attention.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Feed forward.
        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
        }

        x
    }
}

/// Autoregressive cache for a single [Transformer Encoder Layer](TransformerEncoderLayer).
pub struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
    /// Multi-head attention cache.
    pub mha: MhaCache<B>,
    /// Position-wise feed-forward cache.
    pub pwff: TensorCache<B, 3>,
    /// First layer norm cache.
    pub norm_1: TensorCache<B, 3>,
    /// Second layer norm cache.
    pub norm_2: TensorCache<B, 3>,
}

impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
    /// Create an empty cache.
    pub fn empty() -> Self {
        Self {
            mha: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) module.
///
/// To be used during inference when decoding tokens.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};
    use burn::tensor::Distribution;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        let output_1 = transformer.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(tensor.clone());
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::permissive());
    }

    #[test]
    fn display() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
}
489}