burn_nn/modules/transformer/pwff.rs

use burn_core as burn;

use crate::activation::{Activation, ActivationConfig};
use crate::{Dropout, DropoutConfig, Linear, LinearConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Tensor, backend::Backend};

/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).
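///
/// # Example
///
/// A minimal sketch of building the layer (sizes are illustrative; assumes a
/// backend `B` and a `device` in scope, and the `with_*` builders generated
/// by `#[derive(Config)]`):
///
/// ```rust,ignore
/// let pwff = PositionWiseFeedForwardConfig::new(512, 2048)
///     .with_dropout(0.1)
///     .init::<B>(&device);
/// ```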
#[derive(Config, Debug)]
pub struct PositionWiseFeedForwardConfig {
    /// The size of the input and output features.
    pub d_model: usize,
    /// The size of the hidden inner features.
    pub d_ff: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The initializer used for the parameters of both linear layers.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The activation function used between the two linear layers. Default: Gelu
    #[config(default = "ActivationConfig::Gelu")]
    pub activation: ActivationConfig,
}

/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).
///
/// # Params
///
/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
///
/// `FFN(x) = max(0, xW1 + b1)W2 + b2`
///
/// The formula above uses ReLU as in the original paper; this implementation
/// applies the configured activation (GELU by default) in its place.
///
/// Should be created using [PositionWiseFeedForwardConfig].
///
/// # Notes
///
/// The `activation` field is currently marked `#[module(skip)]` for backward
/// compatibility with records saved before this field was introduced (when
/// the activation was always `Gelu` and had no state). As a result, activation
/// state is **not persisted** when saving or loading records.
///
/// For stateless activations (GELU, ReLU, etc.) this has no effect.
/// **If you are using `SwiGLU`, its learnable parameters will not be saved or
/// loaded correctly.**
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
    /// Linear layer with `d_model` input features and `d_ff` output features.
    pub linear_inner: Linear<B>,
    /// Linear layer with `d_ff` input features and `d_model` output features.
    pub linear_outer: Linear<B>,
    /// Dropout layer.
    pub dropout: Dropout,
    /// Activation function.
    #[module(skip)] // for backward compatibility with previous `gelu` field name
    pub activation: Activation<B>,
}

impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Recover the hyperparameters from the inner linear layer's weight
        // shape, which is `[d_model, d_ff]`.
        let [d_model, d_ff] = self.linear_inner.weight.shape().dims();

        content
            .add("d_model", &d_model)
            .add("d_ff", &d_ff)
            .add("prob", &self.dropout.prob)
            .optional()
    }
}

impl PositionWiseFeedForwardConfig {
    /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
        PositionWiseFeedForward {
            linear_inner: LinearConfig::new(self.d_model, self.d_ff)
                .with_initializer(self.initializer.clone())
                .init(device),
            linear_outer: LinearConfig::new(self.d_ff, self.d_model)
                .with_initializer(self.initializer.clone())
                .init(device),
            dropout: DropoutConfig::new(self.dropout).init(),
            activation: self.activation.init(device),
        }
    }
}

impl<B: Backend> PositionWiseFeedForward<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
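    ///
    /// # Example
    ///
    /// A minimal sketch (sizes are illustrative; assumes a backend `B`, a
    /// `device`, and a `pwff` built from [PositionWiseFeedForwardConfig]):
    ///
    /// ```rust,ignore
    /// let input = Tensor::<B, 3>::zeros([2, 10, 512], &device);
    /// let output = pwff.forward(input); // same shape: `[2, 10, 512]`
    /// ```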
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Expand to the hidden size `d_ff`, apply the activation and dropout,
        // then project back to `d_model`.
        let x = self.linear_inner.forward(input);
        let x = self.activation.forward(x);
        let x = self.dropout.forward(x);

        self.linear_outer.forward(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn display() {
        let config = PositionWiseFeedForwardConfig::new(2, 4);
        let pwff = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pwff}"),
            "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
        );
    }
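
    // A minimal sketch of a shape check (illustrative, not from the original
    // tests): forward should map `[batch_size, seq_length, d_model]` back to
    // the same shape.
    #[test]
    fn forward_shape() {
        let device = Default::default();
        let pwff = PositionWiseFeedForwardConfig::new(2, 4).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::zeros([1, 3, 2], &device);
        let output = pwff.forward(input);

        let [batch, seq, d_model] = output.shape().dims();
        assert_eq!([batch, seq, d_model], [1, 3, 2]);
    }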
}