// burn_core/nn/transformer/pwff.rs

1use crate as burn;
2
3use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
4use crate::nn::Initializer;
5use crate::{
6    config::Config,
7    nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig},
8    tensor::{Tensor, backend::Backend},
9};
10
/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).
///
/// `d_model` and `d_ff` declare no default value, while `dropout` and `initializer`
/// each declare one through `#[config(default = ...)]`.
#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    /// The size of the input and output features.
    pub d_model: usize,
    /// The size of the hidden inner features.
    pub d_ff: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The type of function used to initialize neural network parameters.
    ///
    /// Default: Kaiming uniform with a gain of `1 / sqrt(3)` and `fan_out_only` disabled.
    /// The same initializer is handed to both linear layers when the module is built.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
27
/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).
///
/// # Params
///
/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
///
/// `FFN(x) = Dropout(GELU(xW1 + b1))W2 + b2`
///
/// Note: the paper writes the activation as `max(0, ·)` (ReLU); this module's forward pass
/// applies GELU instead, followed by dropout, before the outer linear layer.
///
/// Should be created using [PositionWiseFeedForwardConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
    /// Linear layer with `d_model` input features and `d_ff` output features.
    pub linear_inner: Linear<B>,
    /// Linear layer with `d_ff` input features and `d_model` output features.
    pub linear_outer: Linear<B>,
    /// Dropout layer, applied to the activated hidden features.
    pub dropout: Dropout,
    /// GELU activation function, applied to the output of the inner linear layer.
    pub gelu: Gelu,
}
50
51impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
52    fn custom_settings(&self) -> Option<DisplaySettings> {
53        DisplaySettings::new()
54            .with_new_line_after_attribute(false)
55            .optional()
56    }
57
58    fn custom_content(&self, content: Content) -> Option<Content> {
59        let [d_model, dff] = self.linear_inner.weight.shape().dims();
60
61        content
62            .add("d_model", &d_model)
63            .add("d_ff", &dff)
64            .add("prob", &self.dropout.prob)
65            .optional()
66    }
67}
68
69impl PositionWiseFeedForwardConfig {
70    /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
71    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
72        PositionWiseFeedForward {
73            linear_inner: LinearConfig::new(self.d_model, self.d_ff)
74                .with_initializer(self.initializer.clone())
75                .init(device),
76            linear_outer: LinearConfig::new(self.d_ff, self.d_model)
77                .with_initializer(self.initializer.clone())
78                .init(device),
79            dropout: DropoutConfig::new(self.dropout).init(),
80            gelu: Gelu::new(),
81        }
82    }
83}
84
85impl<B: Backend> PositionWiseFeedForward<B> {
86    /// Applies the forward pass on the input tensor.
87    ///
88    /// # Shapes
89    ///
90    /// - tensor: `[batch_size, seq_length, d_model]`
91    /// - output: `[batch_size, seq_length, d_model]`
92    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
93        let x = self.linear_inner.forward(input);
94        let x = self.gelu.forward(x);
95        let x = self.dropout.forward(x);
96
97        self.linear_outer.forward(x)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// The custom display prints the layer sizes, the dropout probability and the parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let module = PositionWiseFeedForwardConfig::new(2, 4).init::<TestBackend>(&device);

        let rendered = alloc::format!("{}", module);
        let expected = "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}";

        assert_eq!(rendered, expected);
    }
}