use burn_core as burn;

use burn::config::Config;
use burn::module::Content;
use burn::module::DisplaySettings;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::ModuleDisplay;
use burn::module::Param;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Configuration to create a [LayerNorm](LayerNorm) layer using the
/// [init function](LayerNormConfig::init).
#[derive(Debug, Config)]
pub struct LayerNormConfig {
    /// The size of the input features.
    pub d_model: usize,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Whether to learn an additive bias (beta).
    #[config(default = true)]
    pub bias: bool,
}

/// Applies layer normalization over an input tensor, as described in
/// [Layer Normalization](https://arxiv.org/abs/1607.06450).
///
/// Should be created using [LayerNormConfig](LayerNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct LayerNorm<B: Backend> {
    /// The learnable scale (gamma).
    pub gamma: Param<Tensor<B, 1>>,
    /// The optional learnable bias (beta).
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// A value required for numerical stability.
    epsilon: f64,
}

impl LayerNormConfig {
    /// Initializes a new [LayerNorm](LayerNorm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {
        // Gamma starts at one and beta at zero, so a freshly initialized layer
        // applies the plain normalization with no extra scale or shift.
        let gamma = Initializer::Ones.init([self.d_model], device);
        let beta = if self.bias {
            Some(Initializer::Zeros.init([self.d_model], device))
        } else {
            None
        };

        LayerNorm {
            gamma,
            beta,
            epsilon: self.epsilon,
        }
    }
}

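// Usage sketch (illustrative; `MyBackend`, `device`, and `input` stand in for
// whatever the surrounding application provides):
//
//     let layer = LayerNormConfig::new(512)
//         .with_epsilon(1e-6)
//         .init::<MyBackend>(&device);
//     let output = layer.forward(input); // same shape as `input`
//
// `with_epsilon` and `with_bias` are the setters generated by
// `#[derive(Config)]` for fields that declare a `#[config(default = ...)]`.
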
impl<B: Backend> LayerNorm<B> {
    /// Applies the forward pass on the input tensor,
    /// normalizing over the last dimension:
    ///
    /// `y = (x - mean) / sqrt(var + epsilon) * gamma + beta`
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Biased variance and mean, both computed over the last dimension.
        let (var, mean) = input.clone().var_mean_bias(D - 1);

        let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());

        // `unsqueeze` broadcasts the 1-D parameters across the leading dimensions.
        let output = input_normalized.mul(self.gamma.val().unsqueeze());

        match &self.beta {
            Some(beta) => output.add(beta.val().unsqueeze()),
            None => output,
        }
    }
}

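// Worked example of the normalization itself: for a row x = [1, 2, 3, 4]
// (d_model = 4), mean = 2.5 and the biased variance is 1.25, so with gamma = 1,
// beta = 0, and a negligible epsilon the output is approximately
// (x - 2.5) / sqrt(1.25) = [-1.3416, -0.4472, 0.4472, 1.3416].
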
impl<B: Backend> ModuleDisplay for LayerNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [d_model] = self.gamma.shape().dims();
        content
            .add("d_model", &d_model)
            .add("epsilon", &self.epsilon)
            .add("bias", &self.beta.is_some())
            .optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    type FT = FloatElem<TestBackend>;

    #[test]
    fn layer_norm_forward() {
        let device = Default::default();
        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn layer_norm_forward_large_epsilon() {
        let device = Default::default();
        let module = LayerNormConfig::new(10)
            .with_epsilon(1e-1)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[
                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

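    // Added sketch: a property check that avoids precomputed outputs. With the
    // default ones/zeros initialization, the layer norm output should have an
    // (approximately) zero mean along the normalized dimension.
    #[test]
    fn layer_norm_forward_zero_mean() {
        let device = Default::default();
        let module = LayerNormConfig::new(4).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0, 3.0, 4.0], [-2.0, 0.0, 2.0, 4.0]]),
            &device,
        );

        let mean = module.forward(input).mean_dim(1);

        let expected = TensorData::zeros::<f32, _>(mean.shape());
        mean.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }
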
    #[cfg(feature = "std")]
    #[test]
    fn layer_norm_backward() {
        let device = Default::default();
        let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);
        let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[0.0, 1.0], [3.0, 4.0]]),
            &device,
        )
        .require_grad();
        let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(
            TensorData::from([[6.0, 7.0], [9.0, 10.0]]),
            &device,
        )
        .require_grad();

        let x = tensor_1.clone().matmul(tensor_2.clone());

        let output = module.forward(x);
        let grads = output.backward();

        let tensor_1_grad = tensor_1.grad(&grads).unwrap();
        let tensor_2_grad = tensor_2.grad(&grads).unwrap();
        let gamma_grad = module.gamma.grad(&grads).unwrap();
        let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap();

        let expected = TensorData::from([-2.0, 2.0]);
        gamma_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::from([2.0, 2.0]);
        beta_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());
        tensor_1_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());

        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());
        tensor_2_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = LayerNormConfig::new(6);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}"
        );
    }

    #[test]
    fn display_no_bias() {
        let config = LayerNormConfig::new(6).with_bias(false);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{layer_norm}"),
            "LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}"
        );
    }
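
    // Added sketch: disabling `bias` should omit the `beta` parameter entirely.
    #[test]
    fn layer_norm_no_bias_has_no_beta() {
        let config = LayerNormConfig::new(4).with_bias(false);
        let layer_norm = config.init::<TestBackend>(&Default::default());

        assert!(layer_norm.beta.is_none());
    }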
}