burn_nn/modules/transformer/
pwff.rs

1use burn_core as burn;
2
3use crate::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig};
4use burn::config::Config;
5use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
6use burn::tensor::{Tensor, backend::Backend};
7
8/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).
9#[derive(Config, Debug)]
10pub struct PositionWiseFeedForwardConfig {
11    /// The size of the input and output features.
12    pub d_model: usize,
13    /// The size of the hidden inner features.
14    pub d_ff: usize,
15    /// The dropout rate. Default: 0.1
16    #[config(default = 0.1)]
17    pub dropout: f64,
18    /// The type of function used to initialize neural network parameters
19    #[config(
20        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
21    )]
22    pub initializer: Initializer,
23}
24
25/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).
26///
27/// # Params
28///
29/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.
30/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
31///
32/// `FFN(x) = max(0, xW1 + b1)W2 + b2`
33///
34/// Should be created using [PositionWiseFeedForwardConfig]
35#[derive(Module, Debug)]
36#[module(custom_display)]
37pub struct PositionWiseFeedForward<B: Backend> {
38    /// Linear layer with `d_model` input features and `d_ff` output features.
39    pub linear_inner: Linear<B>,
40    /// Linear layer with `d_ff` input features and `d_model` output features.
41    pub linear_outer: Linear<B>,
42    /// Dropout layer.
43    pub dropout: Dropout,
44    /// GELU activation function.
45    pub gelu: Gelu,
46}
47
48impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
49    fn custom_settings(&self) -> Option<DisplaySettings> {
50        DisplaySettings::new()
51            .with_new_line_after_attribute(false)
52            .optional()
53    }
54
55    fn custom_content(&self, content: Content) -> Option<Content> {
56        let [d_model, dff] = self.linear_inner.weight.shape().dims();
57
58        content
59            .add("d_model", &d_model)
60            .add("d_ff", &dff)
61            .add("prob", &self.dropout.prob)
62            .optional()
63    }
64}
65
66impl PositionWiseFeedForwardConfig {
67    /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
68    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
69        PositionWiseFeedForward {
70            linear_inner: LinearConfig::new(self.d_model, self.d_ff)
71                .with_initializer(self.initializer.clone())
72                .init(device),
73            linear_outer: LinearConfig::new(self.d_ff, self.d_model)
74                .with_initializer(self.initializer.clone())
75                .init(device),
76            dropout: DropoutConfig::new(self.dropout).init(),
77            gelu: Gelu::new(),
78        }
79    }
80}
81
82impl<B: Backend> PositionWiseFeedForward<B> {
83    /// Applies the forward pass on the input tensor.
84    ///
85    /// # Shapes
86    ///
87    /// - tensor: `[batch_size, seq_length, d_model]`
88    /// - output: `[batch_size, seq_length, d_model]`
89    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
90        let x = self.linear_inner.forward(input);
91        let x = self.gelu.forward(x);
92        let x = self.dropout.forward(x);
93
94        self.linear_outer.forward(x)
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// The display output lists the widths, the dropout probability and the parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let pwff = PositionWiseFeedForwardConfig::new(2, 4).init::<TestBackend>(&device);

        // 2*4 + 4 (inner) + 4*2 + 2 (outer) = 22 parameters.
        let expected = "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}";
        assert_eq!(alloc::format!("{pwff}"), expected);
    }
}