use crate as burn;
use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features.
    pub d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
}

/// Applies Layer Normalization over an input tensor, as described in the paper
/// [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// `Y = norm(X) * γ + β`, where `γ` is a learnable scale and `β` a learnable shift.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// The learnable scale, initialized to ones.
    pub gamma: Param<Tensor<B, 1>>,
    /// The learnable shift, initialized to zeros.
    pub beta: Param<Tensor<B, 1>>,
    /// A small constant added to the variance for numerical stability.
    epsilon: f64,
}

impl LayerNormConfig {
    /// Initializes a new [LayerNorm](LayerNorm) module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
        let gamma = Initializer::Ones.init([self.d_model], device);
        let beta = Initializer::Zeros.init([self.d_model], device);

        LayerNorm {
            gamma,
            beta,
            epsilon: self.epsilon,
        }
    }
}

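// Example usage (a minimal sketch; `MyBackend` stands in for any concrete
// backend type and is not defined in this file):
//
//     let device = Default::default();
//     let layer = LayerNormConfig::new(768).init::<MyBackend>(&device);
//     let normalized = layer.forward(input); // input: Tensor<MyBackend, 3>
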
impl<B: Backend> LayerNorm<B> {
    /// Applies the forward pass on the input tensor, normalizing over the last dimension.
    ///
    /// # Shapes
    ///
    /// - input: `[..., d_model]`
    /// - output: `[..., d_model]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Biased variance and mean over the last dimension.
        let (var, mean) = input.clone().var_mean_bias(D - 1);

        // Normalize: (x - mean) / sqrt(var + epsilon).
        let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());

        // Scale and shift with the learnable parameters.
        input_normalized
            .mul(self.gamma.val().unsqueeze())
            .add(self.beta.val().unsqueeze())
    }
}
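// Worked example of the math above (explanatory note, not from the original
// source): for a row `[1.0, 2.0, 3.0]` with the default epsilon, the mean is
// 2.0 and the biased variance is 2/3, so the normalized output is
// approximately `[-1.2247, 0.0, 1.2247]` before gamma/beta are applied
// (which start at ones/zeros and therefore leave it unchanged).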

impl<B: Backend> ModuleDisplay for LayerNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_model] = self.gamma.shape().dims();
        content
            .add("d_model", &d_model)
            .add("epsilon", &self.epsilon)
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use alloc::format;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630,
                0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614,
            0.4915,
        ]]);
        output.to_data().assert_approx_eq(&expected, 3);
    }

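    // A property check, not part of the original suite: with the default
    // gamma/beta (ones/zeros), the output along the normalized dimension
    // should have approximately zero mean and unit variance. A minimal
    // sketch reusing the input from `layer_norm_forward`.
    #[test]
    fn layer_norm_forward_zero_mean_unit_variance() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630,
                0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);
        let (var, mean) = output.var_mean_bias(1);

        mean.to_data().assert_approx_eq(&TensorData::from([[0.0]]), 3);
        var.to_data().assert_approx_eq(&TensorData::from([[1.0]]), 3);
    }
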
    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_epsilon(1e-1)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630,
                0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548,
            0.4790,
        ]]);
        output.to_data().assert_approx_eq(&expected, 3);
    }

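    // Sketch of an initialization check, not part of the original suite:
    // `LayerNormConfig::init` uses `Initializer::Ones` for gamma and
    // `Initializer::Zeros` for beta, so fresh parameters should match.
    #[test]
    fn layer_norm_init_params() {
        let device = Default::default();
        let module = LayerNormConfig::new(4).init::<TestBackend>(&device);

        module
            .gamma
            .val()
            .to_data()
            .assert_approx_eq(&TensorData::from([1.0, 1.0, 1.0, 1.0]), 3);
        module
            .beta
            .val()
            .to_data()
            .assert_approx_eq(&TensorData::from([0.0, 0.0, 0.0, 0.0]), 3);
    }
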
145
146 #[cfg(feature = "std")]
147 #[test]
148 fn layer_norm_backward() {
149 let device = Default::default();
150 let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
151 let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
152 TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
153 &device,
154 )
155 .require_grad();
156 let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
157 TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
158 &device,
159 )
160 .require_grad();
161
162 let x = tensor_1.clone().matmul(tensor_2.clone());
163
164 let output = module.forward(x);
165 let grads = output.backward();
166
167 let tensor_1_grad = tensor_1.grad(&grads).unwrap();
168 let tensor_2_grad = tensor_2.grad(&grads).unwrap();
169 let gamma_grad = module.gamma.grad(&grads).unwrap();
170 let beta_grad = module.beta.grad(&grads).unwrap();
171
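        // Why these expectations hold (explanatory note, not in the original):
        // each row of `x` normalizes to `[-1.0, 1.0]`, so with output gradients
        // of ones, gamma_grad is the column sum of the normalized values and
        // beta_grad is the column sum of ones. The input gradients vanish
        // because the normalized values of each row always sum to zero.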
        let expected = TensorData::from([-2.0, 2.0]);
        gamma_grad.to_data().assert_approx_eq(&expected, 3);

        let expected = TensorData::from([2.0, 2.0]);
        beta_grad.to_data().assert_approx_eq(&expected, 3);

        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad.to_data().assert_approx_eq(&expected, 3);

        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad.to_data().assert_approx_eq(&expected, 3);
    }

    #[test]
    fn display() {
        let config = LayerNormConfig::new(6);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", layer_norm),
            "LayerNorm {d_model: 6, epsilon: 0.00001, params: 12}"
        );
    }
}