burn_core/nn/norm/layer.rs

use crate as burn;
use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
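///
/// A minimal construction sketch (illustrative only; `MyBackend` and `device` stand in for
/// whatever [Backend](crate::tensor::backend::Backend) implementation and device are in scope):
///
/// ```ignore
/// // `with_epsilon` is generated by `#[derive(Config)]`; the default is 1e-5.
/// let config = LayerNormConfig::new(512).with_epsilon(1e-6);
/// let layer_norm = config.init::<MyBackend>(&device);
/// ```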
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features.
    pub d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
}

/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
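///
/// A usage sketch (illustrative; `MyBackend` and `device` are placeholders for a concrete
/// [Backend](crate::tensor::backend::Backend) and its device):
///
/// ```ignore
/// let layer_norm = LayerNormConfig::new(64).init::<MyBackend>(&device);
/// // Normalization is applied over the last dimension (`d_model = 64`).
/// let input = Tensor::<MyBackend, 3>::zeros([8, 32, 64], &device);
/// let output = layer_norm.forward(input); // shape: [8, 32, 64]
/// ```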
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// The learnable weight.
    pub gamma: Param<Tensor<B, 1>>,
    /// The learnable bias.
    pub beta: Param<Tensor<B, 1>>,
    /// A value required for numerical stability.
    epsilon: f64,
}

impl LayerNormConfig {
    /// Initialize a new [layer norm](LayerNorm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
        let gamma = Initializer::Ones.init([self.d_model], device);
        let beta = Initializer::Zeros.init([self.d_model], device);

        LayerNorm {
            gamma,
            beta,
            epsilon: self.epsilon,
        }
    }
}

impl<B: Backend> LayerNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See the [LayerNorm](LayerNorm) documentation for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
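    ///
    /// A shape-only sketch (illustrative; assumes a `layer_norm` built with `d_model = 10`
    /// and a placeholder backend `MyBackend`):
    ///
    /// ```ignore
    /// let input = Tensor::<MyBackend, 2>::zeros([2, 10], &device);
    /// let output = layer_norm.forward(input);
    /// assert_eq!(output.dims(), [2, 10]);
    /// ```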
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Biased variance and mean computed over the last dimension (the `d_model` axis).
        let (var, mean) = input.clone().var_mean_bias(D - 1);

        // Normalize, adding epsilon to the variance for numerical stability.
        let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());

        // Scale by the learnable weight and shift by the learnable bias.
        input_normalized
            .mul(self.gamma.val().unsqueeze())
            .add(self.beta.val().unsqueeze())
    }
}

impl<B: Backend> ModuleDisplay for LayerNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_model] = self.gamma.shape().dims();
        content
            .add("d_model", &d_model)
            .add("epsilon", &self.epsilon)
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_epsilon(1e-1)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[cfg(feature = "std")]
    #[test]
    fn layer_norm_backward() {
        let device = Default::default();
        let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
        let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
            &device,
        )
        .require_grad();
        let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
            &device,
        )
        .require_grad();

        let x = tensor_1.clone().matmul(tensor_2.clone());

        let output = module.forward(x);
        let grads = output.backward();

        let tensor_1_grad = tensor_1.grad(&grads).unwrap();
        let tensor_2_grad = tensor_2.grad(&grads).unwrap();
        let gamma_grad = module.gamma.grad(&grads).unwrap();
        let beta_grad = module.beta.grad(&grads).unwrap();

        let expected = TensorData::from([-2.0, 2.0]);
        gamma_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([2.0, 2.0]);
        beta_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = LayerNormConfig::new(6);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", layer_norm),
            "LayerNorm {d_model: 6, epsilon: 0.00001, params: 12}"
        );
    }
}