burn_core/nn/transformer/encoder.rs

use crate::tensor::Bool;
use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::{
    self as burn,
    nn::{Initializer, attention::MhaCache, cache::TensorCache},
};
use crate::{
    config::Config,
    nn::{
        Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
        attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    },
    tensor::{Tensor, backend::Backend},
};

/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).
#[derive(Config)]
pub struct TransformerEncoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules.
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The transformer encoder module as described in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// # Params
///
/// - layers: transformer encoder layers with `d_model` input and output features.
///
/// Should be created using [TransformerEncoderConfig].
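///
/// # Example
///
/// A minimal construction sketch (illustrative sizes; assumes a backend type `B`
/// and its `device` are in scope):
///
/// ```ignore
/// let config = TransformerEncoderConfig::new(512, 2048, 8, 6);
/// let encoder: TransformerEncoder<B> = config.init(&device);
/// ```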
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
    /// The transformer encoder layers.
    pub layers: Vec<TransformerEncoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate. Default: 0.1
    pub dropout: f64,

    /// Layer norm will be applied first instead of after the other modules.
    pub norm_first: bool,

    /// Use "quiet softmax" instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
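///
/// # Example
///
/// A minimal sketch of attaching the optional masks (hypothetical tensors; assumes an
/// `embedded` tensor of shape `[batch_size, seq_length, d_model]`, a boolean `pad_mask`
/// of shape `[batch_size, seq_length]`, and a boolean `attn_mask` of shape
/// `[batch_size, seq_length, seq_length]`):
///
/// ```ignore
/// let input = TransformerEncoderInput::new(embedded)
///     .mask_pad(pad_mask)
///     .mask_attn(attn_mask);
/// ```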
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
    tensor: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerEncoderInput<B> {
    /// Create a [transformer encoder](TransformerEncoder) input argument.
    pub fn new(tensor: Tensor<B, 3>) -> Self {
        Self {
            tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register the padding mask.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register the attention mask.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

impl TransformerEncoderConfig {
    /// Initialize a new [transformer encoder](TransformerEncoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerEncoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerEncoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

impl<B: Backend> TransformerEncoder<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
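    ///
    /// # Example
    ///
    /// A minimal sketch (assumes an `encoder` created via [TransformerEncoderConfig::init]
    /// and an `embedded` tensor of shape `[batch_size, seq_length, d_model]`):
    ///
    /// ```ignore
    /// let output = encoder.forward(TransformerEncoderInput::new(embedded));
    /// ```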
    pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for layer in self.layers.iter() {
            x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
        }

        x
    }

    /// Applies the forward pass on the input tensor using an autoregressive cache.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
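    ///
    /// # Example
    ///
    /// A decoding-loop sketch (hypothetical names; assumes an `encoder` and, at each step,
    /// an `embedded_so_far` tensor containing every position produced so far, as in the
    /// autoregressive tests below):
    ///
    /// ```ignore
    /// let mut cache = encoder.new_autoregressive_cache();
    /// for _step in 0..max_new_tokens {
    ///     let input = TransformerEncoderInput::new(embedded_so_far.clone());
    ///     let output = encoder.forward_autoregressive_inference(input, &mut cache);
    ///     // Use the last position of `output` to pick the next token, then append
    ///     // its embedding to `embedded_so_far`.
    /// }
    /// ```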
    pub fn forward_autoregressive_inference(
        &self,
        input: TransformerEncoderInput<B>,
        cache: &mut TransformerEncoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for i in 0..self.layers.len() {
            let layer = self.layers.get(i).unwrap();
            let cache = cache.layers.get_mut(i).unwrap();

            x = layer.forward_autoregressive_inference(
                x,
                input.mask_pad.clone(),
                input.mask_attn.clone(),
                cache,
            );
        }

        x
    }

    /// Create an empty autoregressive cache.
    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
    }
}

/// Transformer encoder layer module.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    mha: MultiHeadAttention<B>,
    pwff: PositionWiseFeedForward<B>,
    norm_1: LayerNorm<B>,
    norm_2: LayerNorm<B>,
    dropout: Dropout,
    norm_first: bool,
}

impl<B: Backend> TransformerEncoderLayer<B> {
    fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .init(device);

        Self {
            mha,
            norm_1,
            norm_2,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    fn forward(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 3> {
        // Multi-head attention residual path.
        let x = input;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = self.norm_2.forward(residual_path)
        }

        // Multi-head attention.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward(input_mhs).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Feed forward.
        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = self.norm_2.forward(x)
        }

        x
    }

    fn forward_autoregressive_inference(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        // Multi-head attention residual path.
        let x = input;
        let mut residual_path = x.clone();

        // Normalize.
        if self.norm_first {
            residual_path = cache
                .norm_2
                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
        }

        // Multi-head attention.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed forward residual path.
        // Normalize.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Feed forward.
        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Main path.
        // Normalize.
        if !self.norm_first {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
        }

        x
    }
}

struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
    mha: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
}

impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
    fn empty() -> Self {
        Self {
            mha: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
///
/// To be used during inference when decoding tokens.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Distribution;
    use crate::{TestBackend, nn::attention::generate_autoregressive_mask};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        let output_1 = transformer.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(tensor.clone());
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    fn display() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", transformer),
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
}