// burn_nn/modules/norm/layer.rs

use burn_core as burn;

use burn::config::Config;
use burn::module::Content;
use burn::module::DisplaySettings;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::ModuleDisplay;
use burn::module::Param;
use burn::tensor::Tensor;
use burn::tensor::TensorPrimitive;
use burn::tensor::backend::Backend;

/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
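///
/// A minimal construction sketch (illustrative; assumes a backend type `B` and a
/// `device` in scope):
///
/// ```ignore
/// let layer_norm = LayerNormConfig::new(512)
///     .with_epsilon(1e-6)
///     .with_bias(false)
///     .init::<B>(&device);
/// ```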
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features.
    pub d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Whether a bias (beta) should be applied during normalization. Default: true
    #[config(default = true)]
    pub bias: bool,
}

/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight (scale)
/// - `β` is the learnable bias (optional)
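///
/// Here `norm` standardizes over the last dimension (`d_model`):
/// `norm(X) = (X - E[X]) / sqrt(Var[X] + ε)`, where `ε` is the configured epsilon.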
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// The learnable weight (scale).
    pub gamma: Param<Tensor<B, 1>>,
    /// The learnable bias (optional).
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// A value required for numerical stability.
    epsilon: f64,
}

impl LayerNormConfig {
    /// Initialize a new [layer norm](LayerNorm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
        // γ is initialized to ones and β to zeros, so a freshly initialized layer
        // applies an identity affine transform after normalization.
        let gamma = Initializer::Ones.init([self.d_model], device);
        let beta = if self.bias {
            Some(Initializer::Zeros.init([self.d_model], device))
        } else {
            None
        };

        LayerNorm {
            gamma,
            beta,
            epsilon: self.epsilon,
        }
    }
}

impl<B: Backend> LayerNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See the [LayerNorm](LayerNorm) documentation for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
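    ///
    /// A minimal usage sketch (illustrative; assumes a module built via
    /// [LayerNormConfig::init] and an `input` tensor whose last dimension is `d_model`):
    ///
    /// ```ignore
    /// let layer_norm = LayerNormConfig::new(d_model).init::<B>(&device);
    /// let output = layer_norm.forward(input); // same shape as the input
    /// ```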
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let gamma = self.gamma.val().into_primitive().tensor();
        let beta = self
            .beta
            .as_ref()
            .map(|b| b.val().into_primitive().tensor());

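        // Delegate to the backend's `layer_norm` primitive so each backend can supply
        // a fused kernel; the public tensors are unwrapped to primitives above and the
        // result is re-wrapped below.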
        Tensor::from_primitive(TensorPrimitive::Float(B::layer_norm(
            input.into_primitive().tensor(),
            gamma,
            beta,
            self.epsilon,
        )))
    }
}

impl<B: Backend> ModuleDisplay for LayerNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_model] = self.gamma.shape().dims();
        content
            .add("d_model", &d_model)
            .add("epsilon", &self.epsilon)
            .add("bias", &self.beta.is_some())
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_epsilon(1e-1)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn layer_norm_forward_no_bias() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_bias(false)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        // With bias=false, the output matches the bias=true case (beta is
        // zero-initialized by default), confirming that the `None` branch of the
        // backend call produces the pre-beta result.
        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[cfg(feature = "std")]
    #[test]
    fn layer_norm_backward() {
        let device = Default::default();
        let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
        let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
            &device,
        )
        .require_grad();
        let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
            &device,
        )
        .require_grad();

        let x = tensor_1.clone().matmul(tensor_2.clone());

        let output = module.forward(x);
        let grads = output.backward();

        let tensor_1_grad = tensor_1.grad(&grads).unwrap();
        let tensor_2_grad = tensor_2.grad(&grads).unwrap();
        let gamma_grad = module.gamma.grad(&grads).unwrap();
        let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap();

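        // `backward()` seeds the output with ones, so the implicit loss is the sum of
        // all elements. With d_model = 2, each normalized row is approximately
        // [-1, 1], and summing over the two rows gives γ's gradient of about [-2, 2].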
        let expected = TensorData::from([-2.0, 2.0]);
        gamma_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

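        // Each β element sees `dy/dβ = 1` per row, so its gradient is the row count.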
        let expected = TensorData::from([2.0, 2.0]);
        beta_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

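        // norm(x) sums to zero along the normalized dimension for any input, so with
        // γ = 1 the summed loss is constant in x and both input gradients vanish.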
        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = LayerNormConfig::new(6);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}"
        );
    }

    #[test]
    fn display_no_bias() {
        let config = LayerNormConfig::new(6).with_bias(false);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}"
        );
    }
}