burn_nn/modules/norm/normalization_wrapper.rs

use burn_core as burn;

use crate::{
    BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,
    LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig,
};
use burn::prelude::{Config, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// [`Normalization`] Configuration.
///
/// The enum is non-exhaustive to prepare for future additions.
///
/// Can be used as a generic configuration for normalization layers:
/// * Construct a config with arbitrary input features (we suggest `0`).
/// * Clone and match that config to the target input layer,
///   using the [`NormalizationConfig::with_num_features()`] method,
///   as sketched in the example below.
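///
/// A minimal sketch of that pattern; the backend alias `MyBackend`, the device,
/// and the feature count are placeholders, not part of this crate:
///
/// ```ignore
/// // Build a generic template config with a placeholder feature count.
/// let template: NormalizationConfig = LayerNormConfig::new(0).into();
///
/// // Clone and specialize the template for a concrete layer width, then initialize.
/// let hidden = 256;
/// let device = Default::default();
/// let norm: Normalization<MyBackend> = template.clone().with_num_features(hidden).init(&device);
/// assert_eq!(norm.num_features(), hidden);
/// ```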
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum NormalizationConfig {
    /// [`BatchNorm`] Configuration.
    Batch(BatchNormConfig),

    /// [`GroupNorm`] Configuration.
    Group(GroupNormConfig),

    /// [`InstanceNorm`] Configuration.
    Instance(InstanceNormConfig),

    /// [`LayerNorm`] Configuration.
    Layer(LayerNormConfig),

    /// [`RmsNorm`] Configuration.
    Rms(RmsNormConfig),
}

impl From<BatchNormConfig> for NormalizationConfig {
    fn from(config: BatchNormConfig) -> Self {
        Self::Batch(config)
    }
}

impl From<GroupNormConfig> for NormalizationConfig {
    fn from(config: GroupNormConfig) -> Self {
        Self::Group(config)
    }
}

impl From<InstanceNormConfig> for NormalizationConfig {
    fn from(config: InstanceNormConfig) -> Self {
        Self::Instance(config)
    }
}

impl From<LayerNormConfig> for NormalizationConfig {
    fn from(config: LayerNormConfig) -> Self {
        Self::Layer(config)
    }
}

impl From<RmsNormConfig> for NormalizationConfig {
    fn from(config: RmsNormConfig) -> Self {
        Self::Rms(config)
    }
}

impl NormalizationConfig {
    /// Initialize a [`Normalization`] layer.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Normalization<B> {
        match self {
            NormalizationConfig::Batch(config) => config.init(device).into(),
            NormalizationConfig::Group(config) => config.init(device).into(),
            NormalizationConfig::Instance(config) => config.init(device).into(),
            NormalizationConfig::Layer(config) => config.init(device).into(),
            NormalizationConfig::Rms(config) => config.init(device).into(),
        }
    }

    /// Set the number of features.
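    ///
    /// Which field this updates depends on the variant: `num_features` for
    /// batch norm, `num_channels` for group and instance norm, and `d_model`
    /// for layer and RMS norm. A small sketch with arbitrary sizes:
    ///
    /// ```ignore
    /// let config: NormalizationConfig = GroupNormConfig::new(4, 0).into();
    /// // For the Group variant this sets `num_channels`.
    /// let config = config.with_num_features(32);
    /// assert_eq!(config.num_features(), 32);
    /// ```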
    pub fn with_num_features(self, num_features: usize) -> Self {
        match self {
            NormalizationConfig::Batch(config) => BatchNormConfig {
                num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Group(config) => GroupNormConfig {
                num_channels: num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Instance(config) => InstanceNormConfig {
                num_channels: num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Layer(config) => LayerNormConfig {
                d_model: num_features,
                ..config
            }
            .into(),
            NormalizationConfig::Rms(config) => RmsNormConfig {
                d_model: num_features,
                ..config
            }
            .into(),
        }
    }

    /// Get the number of features.
    pub fn num_features(&self) -> usize {
        match self {
            NormalizationConfig::Batch(config) => config.num_features,
            NormalizationConfig::Group(config) => config.num_channels,
            NormalizationConfig::Instance(config) => config.num_channels,
            NormalizationConfig::Layer(config) => config.d_model,
            NormalizationConfig::Rms(config) => config.d_model,
        }
    }
}

/// Normalization Layer Wrapper
///
/// Provides support for the built-in `burn::nn::norm` layers:
/// * [`Normalization::Batch`] - [`BatchNorm`]
/// * [`Normalization::Group`] - [`GroupNorm`]
/// * [`Normalization::Instance`] - [`InstanceNorm`]
/// * [`Normalization::Layer`] - [`LayerNorm`]
/// * [`Normalization::Rms`] - [`RmsNorm`]
///
/// The enum is non-exhaustive to prepare for future additions.
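///
/// A usage sketch; the backend alias `MyBackend`, the device, and the sizes
/// are placeholders:
///
/// ```ignore
/// let device = Default::default();
/// let config: NormalizationConfig = LayerNormConfig::new(64).into();
/// let norm: Normalization<MyBackend> = config.init(&device);
///
/// // The forward pass dispatches to the wrapped layer.
/// let input: Tensor<MyBackend, 3> = Tensor::ones([8, 16, 64], &device);
/// let output = norm.forward(input);
/// ```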
#[derive(Module, Debug)]
#[non_exhaustive]
pub enum Normalization<B: Backend> {
    /// [`BatchNorm`] layer.
    Batch(BatchNorm<B>),

    /// [`GroupNorm`] layer.
    Group(GroupNorm<B>),

    /// [`InstanceNorm`] layer.
    Instance(InstanceNorm<B>),

    /// [`LayerNorm`] layer.
    Layer(LayerNorm<B>),

    /// [`RmsNorm`] layer.
    Rms(RmsNorm<B>),
}

impl<B: Backend> From<BatchNorm<B>> for Normalization<B> {
    fn from(layer: BatchNorm<B>) -> Self {
        Self::Batch(layer)
    }
}

impl<B: Backend> From<GroupNorm<B>> for Normalization<B> {
    fn from(layer: GroupNorm<B>) -> Self {
        Self::Group(layer)
    }
}

impl<B: Backend> From<InstanceNorm<B>> for Normalization<B> {
    fn from(layer: InstanceNorm<B>) -> Self {
        Self::Instance(layer)
    }
}

impl<B: Backend> From<LayerNorm<B>> for Normalization<B> {
    fn from(layer: LayerNorm<B>) -> Self {
        Self::Layer(layer)
    }
}

impl<B: Backend> From<RmsNorm<B>> for Normalization<B> {
    fn from(layer: RmsNorm<B>) -> Self {
        Self::Rms(layer)
    }
}

impl<B: Backend> Normalization<B> {
    /// Applies normalization to a tensor.
    ///
    /// The exact normalization contract depends on the wrapped norm layer,
    /// but all norm layers expect an input of at least rank 2 and produce
    /// an output of the same rank and shape.
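    ///
    /// A shape-contract sketch; the backend alias `MyBackend`, the `norm`
    /// value, and the sizes are illustrative only:
    ///
    /// ```ignore
    /// // A rank-4 input of shape [batch, channels, height, width].
    /// let input: Tensor<MyBackend, 4> = Tensor::ones([2, 12, 3, 4], &device);
    /// let output = norm.forward(input.clone());
    /// assert_eq!(output.dims(), input.dims());
    /// ```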
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        match self {
            Normalization::Batch(norm) => norm.forward(input),
            Normalization::Group(norm) => norm.forward(input),
            Normalization::Instance(norm) => norm.forward(input),
            Normalization::Layer(norm) => norm.forward(input),
            Normalization::Rms(norm) => norm.forward(input),
        }
    }

    /// Get the number of features.
    pub fn num_features(&self) -> usize {
        match self {
            Normalization::Batch(norm) => norm.gamma.shape().dims[0],
            Normalization::Group(norm) => norm.num_channels,
            Normalization::Instance(norm) => norm.num_channels,
            Normalization::Layer(norm) => norm.gamma.shape().dims[0],
            Normalization::Rms(norm) => norm.gamma.shape().dims[0],
        }
    }
}

210#[cfg(feature = "std")]
211#[cfg(test)]
212mod tests {
213    use super::*;
214    use crate::TestAutodiffBackend;
215    use burn::tensor::{Tolerance, ops::FloatElem};
216    type FT = FloatElem<TestAutodiffBackend>;
217
218    #[test]
219    fn test_match_feature_size() {
220        let config: NormalizationConfig = BatchNormConfig::new(0).into();
221        assert_eq!(config.num_features(), 0);
222        let config = config.with_num_features(12);
223        assert_eq!(config.num_features(), 12);
224
225        let config: NormalizationConfig = GroupNormConfig::new(4, 0).into();
226        assert_eq!(config.num_features(), 0);
227        let config = config.with_num_features(12);
228        assert_eq!(config.num_features(), 12);
229
230        let config: NormalizationConfig = InstanceNormConfig::new(0).into();
231        assert_eq!(config.num_features(), 0);
232        let config = config.with_num_features(12);
233        assert_eq!(config.num_features(), 12);
234
235        let config: NormalizationConfig = LayerNormConfig::new(0).into();
236        assert_eq!(config.num_features(), 0);
237        let config = config.with_num_features(12);
238        assert_eq!(config.num_features(), 12);
239
240        let config: NormalizationConfig = RmsNormConfig::new(0).into();
241        assert_eq!(config.num_features(), 0);
242        let config = config.with_num_features(12);
243        assert_eq!(config.num_features(), 12);
244    }
245
246    #[test]
247    fn test_batch_norm() {
248        type B = TestAutodiffBackend;
249        let device = Default::default();
250
251        let num_features = 12;
252        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
253
254        let config: NormalizationConfig = BatchNormConfig::new(12).into();
255
256        let layer: Normalization<B> = config.init(&device);
257        assert_eq!(layer.num_features(), 12);
258
259        let expected = match &layer {
260            Normalization::Batch(inner) => inner.forward(input.clone()),
261            _ => panic!("Unexpected layer type"),
262        };
263
264        let output = layer.forward(input);
265
266        output.to_data().assert_eq(&expected.to_data(), true);
267    }
268
269    #[test]
270    fn test_group_norm() {
271        type B = TestAutodiffBackend;
272        let device = Default::default();
273
274        let num_features = 12;
275        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
276
277        let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into();
278
279        let layer: Normalization<B> = config.init(&device);
280        assert_eq!(layer.num_features(), 12);
281
282        let expected = match &layer {
283            Normalization::Group(inner) => inner.forward(input.clone()),
284            _ => panic!("Unexpected layer type"),
285        };
286
287        let output = layer.forward(input);
288
289        output
290            .to_data()
291            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
292    }
293
294    #[test]
295    fn test_instance_norm() {
296        type B = TestAutodiffBackend;
297        let device = Default::default();
298
299        let num_features = 12;
300        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
301
302        let config: NormalizationConfig = InstanceNormConfig::new(num_features).into();
303
304        let layer: Normalization<B> = config.init(&device);
305        assert_eq!(layer.num_features(), 12);
306
307        let expected = match &layer {
308            Normalization::Instance(inner) => inner.forward(input.clone()),
309            _ => panic!("Unexpected layer type"),
310        };
311
312        let output = layer.forward(input);
313
314        output
315            .to_data()
316            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
317    }
318
319    #[test]
320    fn test_layer_norm() {
321        type B = TestAutodiffBackend;
322        let device = Default::default();
323
324        let num_features = 12;
325        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);
326
327        let config: NormalizationConfig = LayerNormConfig::new(num_features).into();
328
329        let layer: Normalization<B> = config.init(&device);
330        assert_eq!(layer.num_features(), 12);
331
332        let expected = match &layer {
333            Normalization::Layer(inner) => inner.forward(input.clone()),
334            _ => panic!("Unexpected layer type"),
335        };
336
337        let output = layer.forward(input);
338
339        output
340            .to_data()
341            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
342    }
343
344    #[test]
345    fn test_rms_norm() {
346        type B = TestAutodiffBackend;
347        let device = Default::default();
348
349        let num_features = 12;
350        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);
351
352        let config: NormalizationConfig = RmsNormConfig::new(num_features).into();
353
354        let layer: Normalization<B> = config.init(&device);
355        assert_eq!(layer.num_features(), 12);
356
357        let expected = match &layer {
358            Normalization::Rms(inner) => inner.forward(input.clone()),
359            _ => panic!("Unexpected layer type"),
360        };
361
362        let output = layer.forward(input);
363
364        output
365            .to_data()
366            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
367    }
368}