1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Content;
5use burn::module::DisplaySettings;
6use burn::module::Initializer;
7use burn::module::Module;
8use burn::module::ModuleDisplay;
9use burn::module::Param;
10use burn::tensor::Tensor;
11use burn::tensor::TensorPrimitive;
12use burn::tensor::backend::Backend;
13
/// Configuration for creating a [LayerNorm] layer.
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features (the dimension being normalized).
    pub d_model: usize,
    /// A small value added to the variance for numerical stability. Default: 1e-5.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// If `true`, the layer has a learnable additive bias (`beta`). Default: `true`.
    #[config(default = true)]
    pub bias: bool,
}
26
/// Layer normalization module.
///
/// Normalizes the input, then applies a learnable per-element scale (`gamma`)
/// and an optional learnable shift (`beta`). Should be created with
/// [LayerNormConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// Learnable scale parameter of shape `[d_model]`, initialized to ones.
    pub gamma: Param<Tensor<B, 1>>,
    /// Optional learnable shift parameter of shape `[d_model]`, initialized to
    /// zeros; `None` when the config's `bias` flag was `false`.
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// Small constant added to the variance for numerical stability.
    epsilon: f64,
}
48
49impl LayerNormConfig {
50 pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
52 let gamma = Initializer::Ones.init([self.d_model], device);
53 let beta = if self.bias {
54 Some(Initializer::Zeros.init([self.d_model], device))
55 } else {
56 None
57 };
58
59 LayerNorm {
60 gamma,
61 beta,
62 epsilon: self.epsilon,
63 }
64 }
65}
66
67impl<B: Backend> LayerNorm<B> {
68 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
77 let gamma = self.gamma.val().into_primitive().tensor();
78 let beta = self
79 .beta
80 .as_ref()
81 .map(|b| b.val().into_primitive().tensor());
82
83 Tensor::from_primitive(TensorPrimitive::Float(B::layer_norm(
84 input.into_primitive().tensor(),
85 gamma,
86 beta,
87 self.epsilon,
88 )))
89 }
90}
91
92impl<B: Backend> ModuleDisplay for LayerNorm<B> {
93 fn custom_settings(&self) -> Option<DisplaySettings> {
94 DisplaySettings::new()
95 .with_new_line_after_attribute(false)
96 .optional()
97 }
98
99 fn custom_content(&self, content: Content) -> Option<Content> {
100 let [d_model] = self.gamma.shape().dims();
101 content
102 .add("d_model", &d_model)
103 .add("epsilon", &self.epsilon)
104 .add("bias", &self.beta.is_some())
105 .optional()
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    // Forward pass with default config: output should be the standardized input
    // (gamma = ones, beta = zeros leave it unchanged after normalization).
    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // A large epsilon (1e-1) must visibly dampen the normalized values,
    // so the expected outputs differ from the default-epsilon test above.
    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_epsilon(1e-1)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Without bias, the expected values equal the biased case because a
    // freshly-initialized beta is all zeros anyway; this checks the None
    // branch of `forward` runs correctly.
    #[test]
    fn layer_norm_forward_no_bias() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_bias(false)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Backward pass on the autodiff backend: checks gradients flow to both
    // parameters and to the inputs of an upstream matmul. Input gradients are
    // expected to be zero here because normalization is invariant to the
    // affine changes this particular input produces.
    #[cfg(feature = "std")]
    #[test]
    fn layer_norm_backward() {
        let device = Default::default();
        let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
        let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
            &device,
        )
        .require_grad();
        let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
            &device,
        )
        .require_grad();

        let x = tensor_1.clone().matmul(tensor_2.clone());

        let output = module.forward(x);
        let grads = output.backward();

        let tensor_1_grad = tensor_1.grad(&grads).unwrap();
        let tensor_2_grad = tensor_2.grad(&grads).unwrap();
        let gamma_grad = module.gamma.grad(&grads).unwrap();
        let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap();

        let expected = TensorData::from([-2.0, 2.0]);
        gamma_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([2.0, 2.0]);
        beta_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // Display output via the custom ModuleDisplay impl: 12 params = 6 gamma + 6 beta.
    #[test]
    fn display() {
        let config = LayerNormConfig::new(6);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}"
        );
    }

    // Same as above without bias: only the 6 gamma params remain.
    #[test]
    fn display_no_bias() {
        let config = LayerNormConfig::new(6).with_bias(false);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}"
        );
    }
}