use burn_core as burn;

use crate::{
    BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,
    LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig,
};
use burn::prelude::{Config, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Configuration for a [`Normalization`] layer.
///
/// Wraps the configurations of the supported normalization layers behind a
/// single type, so the normalization flavor can be chosen at runtime.
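///
/// # Example
///
/// A minimal usage sketch; the backend alias `MyBackend`, the `device`, and
/// the feature size are illustrative placeholders, not part of this module:
///
/// ```rust,ignore
/// // Select a concrete normalization at runtime via `From`/`Into`:
/// let config: NormalizationConfig = LayerNormConfig::new(64).into();
///
/// // Initialize the wrapped layer on a device:
/// let norm: Normalization<MyBackend> = config.init(&device);
/// assert_eq!(norm.num_features(), 64);
/// ```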
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum NormalizationConfig {
    /// [`BatchNorm`] configuration.
    Batch(BatchNormConfig),

    /// [`GroupNorm`] configuration.
    Group(GroupNormConfig),

    /// [`InstanceNorm`] configuration.
    Instance(InstanceNormConfig),

    /// [`LayerNorm`] configuration.
    Layer(LayerNormConfig),

    /// [`RmsNorm`] configuration.
    Rms(RmsNormConfig),
}

impl From<BatchNormConfig> for NormalizationConfig {
    fn from(config: BatchNormConfig) -> Self {
        Self::Batch(config)
    }
}

impl From<GroupNormConfig> for NormalizationConfig {
    fn from(config: GroupNormConfig) -> Self {
        Self::Group(config)
    }
}

impl From<InstanceNormConfig> for NormalizationConfig {
    fn from(config: InstanceNormConfig) -> Self {
        Self::Instance(config)
    }
}

impl From<LayerNormConfig> for NormalizationConfig {
    fn from(config: LayerNormConfig) -> Self {
        Self::Layer(config)
    }
}

impl From<RmsNormConfig> for NormalizationConfig {
    fn from(config: RmsNormConfig) -> Self {
        Self::Rms(config)
    }
}

impl NormalizationConfig {
    /// Initialize the corresponding [`Normalization`] layer on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Normalization<B> {
        match self {
            NormalizationConfig::Batch(config) => config.init(device).into(),
            NormalizationConfig::Group(config) => config.init(device).into(),
            NormalizationConfig::Instance(config) => config.init(device).into(),
            NormalizationConfig::Layer(config) => config.init(device).into(),
            NormalizationConfig::Rms(config) => config.init(device).into(),
        }
    }

    /// Return this configuration with the feature dimension replaced.
    ///
    /// The value maps onto the variant's own field: `num_features` for batch
    /// norm, `num_channels` for group/instance norm, and `d_model` for
    /// layer/RMS norm.
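    ///
    /// # Example
    ///
    /// A small sketch of re-targeting a config to a new width (the sizes are
    /// illustrative):
    ///
    /// ```rust,ignore
    /// let config: NormalizationConfig = GroupNormConfig::new(4, 16).into();
    /// // Keeps 4 groups, but now over 32 channels.
    /// let config = config.with_num_features(32);
    /// assert_eq!(config.num_features(), 32);
    /// ```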
    pub fn with_num_features(self, num_features: usize) -> Self {
        match self {
            NormalizationConfig::Batch(config) => BatchNormConfig {
                num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Group(config) => GroupNormConfig {
                num_channels: num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Instance(config) => InstanceNormConfig {
                num_channels: num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Layer(config) => LayerNormConfig {
                d_model: num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Rms(config) => RmsNormConfig {
                d_model: num_features,
                ..config
            }
            .into(),
        }
    }

    /// The feature dimension this configuration will produce.
    pub fn num_features(&self) -> usize {
        match self {
            NormalizationConfig::Batch(config) => config.num_features,
            NormalizationConfig::Group(config) => config.num_channels,
            NormalizationConfig::Instance(config) => config.num_channels,
            NormalizationConfig::Layer(config) => config.d_model,
            NormalizationConfig::Rms(config) => config.d_model,
        }
    }
}

/// A runtime-selectable normalization layer.
///
/// Wraps the supported normalization modules behind a single [`Module`], so a
/// model can swap normalization strategies through [`NormalizationConfig`].
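///
/// # Example
///
/// A construction sketch; `MyBackend` and `device` are illustrative:
///
/// ```rust,ignore
/// // Through the unified config...
/// let norm: Normalization<MyBackend> =
///     NormalizationConfig::Rms(RmsNormConfig::new(64)).init(&device);
///
/// // ...or by wrapping an already-initialized layer via `From`/`Into`:
/// let norm: Normalization<MyBackend> = RmsNormConfig::new(64).init(&device).into();
/// ```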
#[derive(Module, Debug)]
#[non_exhaustive]
pub enum Normalization<B: Backend> {
    /// [`BatchNorm`] layer.
    Batch(BatchNorm<B>),

    /// [`GroupNorm`] layer.
    Group(GroupNorm<B>),

    /// [`InstanceNorm`] layer.
    Instance(InstanceNorm<B>),

    /// [`LayerNorm`] layer.
    Layer(LayerNorm<B>),

    /// [`RmsNorm`] layer.
    Rms(RmsNorm<B>),
}

impl<B: Backend> From<BatchNorm<B>> for Normalization<B> {
    fn from(layer: BatchNorm<B>) -> Self {
        Self::Batch(layer)
    }
}

impl<B: Backend> From<GroupNorm<B>> for Normalization<B> {
    fn from(layer: GroupNorm<B>) -> Self {
        Self::Group(layer)
    }
}

impl<B: Backend> From<InstanceNorm<B>> for Normalization<B> {
    fn from(layer: InstanceNorm<B>) -> Self {
        Self::Instance(layer)
    }
}

impl<B: Backend> From<LayerNorm<B>> for Normalization<B> {
    fn from(layer: LayerNorm<B>) -> Self {
        Self::Layer(layer)
    }
}

impl<B: Backend> From<RmsNorm<B>> for Normalization<B> {
    fn from(layer: RmsNorm<B>) -> Self {
        Self::Rms(layer)
    }
}

impl<B: Backend> Normalization<B> {
    /// Apply the wrapped normalization layer to the input tensor.
    ///
    /// Batch, group, and instance normalization expect a channels-first input
    /// (`[batch, channels, ...]`), while layer and RMS normalization operate
    /// over the last dimension of the input.
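    ///
    /// # Example
    ///
    /// A shape-convention sketch; the backend, sizes, and the wrapper
    /// variables are illustrative:
    ///
    /// ```rust,ignore
    /// // Channels-first input for batch/group/instance norm:
    /// let x: Tensor<MyBackend, 4> = Tensor::ones([8, 16, 32, 32], &device);
    /// let y = batch_norm_wrapper.forward(x);
    ///
    /// // Features-last input for layer/RMS norm:
    /// let x: Tensor<MyBackend, 3> = Tensor::ones([8, 128, 64], &device);
    /// let y = layer_norm_wrapper.forward(x);
    /// ```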
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        match self {
            Normalization::Batch(norm) => norm.forward(input),
            Normalization::Group(norm) => norm.forward(input),
            Normalization::Instance(norm) => norm.forward(input),
            Normalization::Layer(norm) => norm.forward(input),
            Normalization::Rms(norm) => norm.forward(input),
        }
    }

    /// The feature dimension of the wrapped layer.
    ///
    /// For batch, layer, and RMS norm this is read from the learned `gamma`
    /// parameter; group and instance norm expose it directly.
    pub fn num_features(&self) -> usize {
        match self {
            Normalization::Batch(norm) => norm.gamma.shape().dims[0],
            Normalization::Group(norm) => norm.num_channels,
            Normalization::Instance(norm) => norm.num_channels,
            Normalization::Layer(norm) => norm.gamma.shape().dims[0],
            Normalization::Rms(norm) => norm.gamma.shape().dims[0],
        }
    }
}

#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    #[test]
    fn test_match_feature_size() {
        let config: NormalizationConfig = BatchNormConfig::new(0).into();
        assert_eq!(config.num_features(), 0);
        let config = config.with_num_features(12);
        assert_eq!(config.num_features(), 12);

        let config: NormalizationConfig = GroupNormConfig::new(4, 0).into();
        assert_eq!(config.num_features(), 0);
        let config = config.with_num_features(12);
        assert_eq!(config.num_features(), 12);

        let config: NormalizationConfig = InstanceNormConfig::new(0).into();
        assert_eq!(config.num_features(), 0);
        let config = config.with_num_features(12);
        assert_eq!(config.num_features(), 12);

        let config: NormalizationConfig = LayerNormConfig::new(0).into();
        assert_eq!(config.num_features(), 0);
        let config = config.with_num_features(12);
        assert_eq!(config.num_features(), 12);

        let config: NormalizationConfig = RmsNormConfig::new(0).into();
        assert_eq!(config.num_features(), 0);
        let config = config.with_num_features(12);
        assert_eq!(config.num_features(), 12);
    }

    #[test]
    fn test_batch_norm() {
        type B = TestAutodiffBackend;
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);

        let config: NormalizationConfig = BatchNormConfig::new(num_features).into();

        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let expected = match &layer {
            Normalization::Batch(inner) => inner.forward(input.clone()),
            _ => panic!("Unexpected layer type"),
        };

        let output = layer.forward(input);

        output.to_data().assert_eq(&expected.to_data(), true);
    }

    #[test]
    fn test_group_norm() {
        type B = TestAutodiffBackend;
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);

        let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into();

        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let expected = match &layer {
            Normalization::Group(inner) => inner.forward(input.clone()),
            _ => panic!("Unexpected layer type"),
        };

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    fn test_instance_norm() {
        type B = TestAutodiffBackend;
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);

        let config: NormalizationConfig = InstanceNormConfig::new(num_features).into();

        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let expected = match &layer {
            Normalization::Instance(inner) => inner.forward(input.clone()),
            _ => panic!("Unexpected layer type"),
        };

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    fn test_layer_norm() {
        type B = TestAutodiffBackend;
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);

        let config: NormalizationConfig = LayerNormConfig::new(num_features).into();

        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let expected = match &layer {
            Normalization::Layer(inner) => inner.forward(input.clone()),
            _ => panic!("Unexpected layer type"),
        };

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    fn test_rms_norm() {
        type B = TestAutodiffBackend;
        let device = Default::default();

        let num_features = 12;
        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);

        let config: NormalizationConfig = RmsNormConfig::new(num_features).into();

        let layer: Normalization<B> = config.init(&device);
        assert_eq!(layer.num_features(), 12);

        let expected = match &layer {
            Normalization::Rms(inner) => inner.forward(input.clone()),
            _ => panic!("Unexpected layer type"),
        };

        let output = layer.forward(input);

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }
}