//! # Layer Normalization
//!
//! Source: burn_nn/modules/norm/layer.rs
1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Content;
5use burn::module::DisplaySettings;
6use burn::module::Initializer;
7use burn::module::Module;
8use burn::module::ModuleDisplay;
9use burn::module::Param;
10use burn::tensor::Tensor;
11use burn::tensor::backend::Backend;
12
/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features.
    pub d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    ///
    /// Added to the variance before the square root in
    /// [LayerNorm::forward], so larger values damp the normalization.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// If a bias (beta) should be applied during the normalization. Default: true
    #[config(default = true)]
    pub bias: bool,
}
25
/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight (scale)
/// - `β` is the learnable bias (optional)
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// The learnable weight (scale). Shape `[d_model]`; initialized to ones.
    pub gamma: Param<Tensor<B, 1>>,
    /// The learnable bias (optional). Shape `[d_model]`; initialized to zeros
    /// when present, `None` when the config disables `bias`.
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// A value required for numerical stability.
    epsilon: f64,
}
47
48impl LayerNormConfig {
49    /// Initialize a new [layer norm](LayerNorm) module.
50    pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
51        let gamma = Initializer::Ones.init([self.d_model], device);
52        let beta = if self.bias {
53            Some(Initializer::Zeros.init([self.d_model], device))
54        } else {
55            None
56        };
57
58        LayerNorm {
59            gamma,
60            beta,
61            epsilon: self.epsilon,
62        }
63    }
64}
65
66impl<B: Backend> LayerNorm<B> {
67    /// Applies the forward pass on the input tensor.
68    ///
69    /// See the [LayerNorm](LayerNorm) documentation for more information.
70    ///
71    /// # Shapes
72    ///
73    /// - input: `[..., any, d_model]`
74    /// - output: `[..., any, d_model]`
75    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
76        let (var, mean) = input.clone().var_mean_bias(D - 1);
77
78        let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());
79
80        let output = input_normalized.mul(self.gamma.val().unsqueeze());
81
82        match &self.beta {
83            Some(beta) => output.add(beta.val().unsqueeze()),
84            None => output,
85        }
86    }
87}
88
89impl<B: Backend> ModuleDisplay for LayerNorm<B> {
90    fn custom_settings(&self) -> Option<DisplaySettings> {
91        DisplaySettings::new()
92            .with_new_line_after_attribute(false)
93            .optional()
94    }
95
96    fn custom_content(&self, content: Content) -> Option<Content> {
97        let [d_model] = self.gamma.shape().dims();
98        content
99            .add("d_model", &d_model)
100            .add("epsilon", &self.epsilon)
101            .add("bias", &self.beta.is_some())
102            .optional()
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    // Float element type of the test backend, used for tolerance comparisons.
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    // Without `std` the autodiff backend is unavailable; import only TestBackend.
    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// Forward pass with the default epsilon (1e-5) matches precomputed values.
    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Same input as above but with a large epsilon (1e-1): the expected
    /// outputs are slightly shrunk toward zero compared to the default case.
    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_epsilon(1e-1)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Gradients flow back through gamma, beta, and both matmul operands.
    /// For this fixture the expected input gradients are exactly zero.
    /// Requires `std` because it uses the autodiff backend.
    #[cfg(feature = "std")]
    #[test]
    fn layer_norm_backward() {
        let device = Default::default();
        let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
        let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
            &device,
        )
        .require_grad();
        let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
            &device,
        )
        .require_grad();

        let x = tensor_1.clone().matmul(tensor_2.clone());

        let output = module.forward(x);
        let grads = output.backward();

        let tensor_1_grad = tensor_1.grad(&grads).unwrap();
        let tensor_2_grad = tensor_2.grad(&grads).unwrap();
        let gamma_grad = module.gamma.grad(&grads).unwrap();
        let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap();

        let expected = TensorData::from([-2.0, 2.0]);
        gamma_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([2.0, 2.0]);
        beta_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Display output lists d_model, epsilon, the bias flag, and the
    /// parameter count (gamma + beta = 2 * d_model = 12).
    #[test]
    fn display() {
        let config = LayerNormConfig::new(6);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}"
        );
    }

    /// With bias disabled, the display reflects it and the parameter count
    /// drops to d_model (gamma only).
    #[test]
    fn display_no_bias() {
        let config = LayerNormConfig::new(6).with_bias(false);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}"
        );
    }
}