burn_core/nn/norm/
group.rs

1use crate as burn;
2use crate::nn::Initializer;
3
4use crate::config::Config;
5use crate::module::Module;
6use crate::module::Param;
7use crate::module::{Content, DisplaySettings, ModuleDisplay};
8use crate::tensor::backend::Backend;
9use crate::tensor::Tensor;
10
11/// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init).
12#[derive(Debug, Config)]
13pub struct GroupNormConfig {
14    /// The number of groups to separate the channels into
15    pub num_groups: usize,
16    /// The number of channels expected in the input
17    pub num_channels: usize,
18    /// A value required for numerical stability. Default: 1e-5
19    #[config(default = 1e-5)]
20    pub epsilon: f64,
21    /// A boolean value that when set to `true`, this module has learnable
22    /// per-channel affine parameters initialized to ones (for weights)
23    /// and zeros (for biases). Default: `true`
24    #[config(default = true)]
25    pub affine: bool,
26}
27
28/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
29///
30/// `Y = groupnorm(X) * γ + β`
31///
32/// Where:
33/// - `X` is the input tensor
34/// - `Y` is the output tensor
35/// - `γ` is the learnable weight
36/// - `β` is the learnable bias
37///
38/// Should be created using [GroupNormConfig](GroupNormConfig).
39#[derive(Module, Debug)]
40#[module(custom_display)]
41pub struct GroupNorm<B: Backend> {
42    /// The learnable weight
43    pub gamma: Option<Param<Tensor<B, 1>>>,
44    /// The learnable bias
45    pub beta: Option<Param<Tensor<B, 1>>>,
46    /// The number of groups to separate the channels into
47    pub num_groups: usize,
48    /// The number of channels expected in the input
49    pub num_channels: usize,
50    /// A value required for numerical stability
51    pub epsilon: f64,
52    /// A boolean value that when set to `true`, this module has learnable
53    pub affine: bool,
54}
55
56impl<B: Backend> ModuleDisplay for GroupNorm<B> {
57    fn custom_settings(&self) -> Option<DisplaySettings> {
58        DisplaySettings::new()
59            .with_new_line_after_attribute(false)
60            .optional()
61    }
62
63    fn custom_content(&self, content: Content) -> Option<Content> {
64        content
65            .add("num_groups", &self.num_groups)
66            .add("num_channels", &self.num_channels)
67            .add("epsilon", &self.epsilon)
68            .add("affine", &self.affine)
69            .optional()
70    }
71}
72
73impl GroupNormConfig {
74    /// Initialize a new [group norm](GroupNorm) module.
75    pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
76        assert_eq!(
77            self.num_channels % self.num_groups,
78            0,
79            "The number of channels must be divisible by the number of groups"
80        );
81
82        let (gamma, beta) = if self.affine {
83            let gamma = Initializer::Ones.init([self.num_channels], device);
84            let beta = Initializer::Zeros.init([self.num_channels], device);
85
86            (Some(gamma), Some(beta))
87        } else {
88            (None, None)
89        };
90
91        GroupNorm {
92            num_groups: self.num_groups,
93            num_channels: self.num_channels,
94            gamma,
95            beta,
96            epsilon: self.epsilon,
97            affine: self.affine,
98        }
99    }
100}
101
102impl<B: Backend> GroupNorm<B> {
103    /// Applies the forward pass on the input tensor.
104    ///
105    /// See [GroupNorm](GroupNorm) for more information.
106    ///
107    /// # Shapes
108    ///
109    /// - input: `[batch_size, num_channels, *]`
110    /// - output: `[batch_size, num_channels, *]`
111    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
112        if input.shape().dims[1] != self.num_channels {
113            panic!(
114                "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
115                self.num_channels,
116                input.shape().dims[1]
117            );
118        }
119
120        let gamma = self.gamma.as_ref().map(|x| x.val());
121        let beta = self.beta.as_ref().map(|x| x.val());
122
123        group_norm(
124            input,
125            gamma,
126            beta,
127            self.num_groups,
128            self.epsilon,
129            self.affine,
130        )
131    }
132}
133
134/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
135///
136/// `Y = groupnorm(X) * γ + β`
137///
138/// Where:
139/// - `X` is the input tensor
140/// - `Y` is the output tensor
141/// - `γ` is the learnable weight
142/// - `β` is the learnable bias
143///
144pub(crate) fn group_norm<B: Backend, const D: usize>(
145    input: Tensor<B, D>,
146    gamma: Option<Tensor<B, 1>>,
147    beta: Option<Tensor<B, 1>>,
148    num_groups: usize,
149    epsilon: f64,
150    affine: bool,
151) -> Tensor<B, D> {
152    if (beta.is_none() || gamma.is_none()) && affine {
153        panic!("Affine is set to true, but gamma or beta is None");
154    }
155
156    let shape = input.shape();
157    if shape.num_elements() <= 2 {
158        panic!(
159            "input rank for GroupNorm should be at least 3, but got {}",
160            shape.num_elements()
161        );
162    }
163
164    let batch_size = shape.dims[0];
165    let num_channels = shape.dims[1];
166
167    let hidden_size = shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
168    let input = input.reshape([batch_size, num_groups, hidden_size]);
169
170    let mean = input.clone().sum_dim(2) / hidden_size as f64;
171    let input = input.sub(mean);
172
173    let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
174    let input_normalized = input.div(var.sqrt().add_scalar(epsilon));
175
176    if affine {
177        let mut affine_shape = [1; D];
178        affine_shape[1] = num_channels;
179
180        input_normalized
181            .reshape(shape)
182            .mul(gamma.clone().unwrap().reshape(affine_shape))
183            .add(beta.clone().unwrap().reshape(affine_shape))
184    } else {
185        input_normalized.reshape(shape)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    use alloc::format;

    /// Without affine parameters, the module only normalizes each group.
    #[test]
    fn group_norm_forward_affine_false() {
        let device = Default::default();
        let norm = GroupNormConfig::new(2, 6)
            .with_affine(false)
            .init::<TestBackend>(&device);

        assert!(norm.gamma.is_none());
        assert!(norm.beta.is_none());

        let input_data = TensorData::from([
            [
                [-0.3034, 0.2726, -0.9659],
                [-1.1845, -1.3236, 0.0172],
                [1.9507, 1.2554, -0.8625],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]);
        let input = Tensor::<TestBackend, 3>::from_data(input_data, &device);

        let actual = norm.forward(input).to_data();

        let expected = TensorData::from([
            [
                [-0.1653, 0.3748, -0.7866],
                [-0.9916, -1.1220, 0.1353],
                [1.9485, 1.2965, -0.6896],
                [1.2769, 0.3628, 0.4120],
                [-0.7427, -0.6786, -1.3578],
                [1.8547, -0.3022, -0.8252],
            ],
            [
                [-1.9342, 0.0211, -0.5793],
                [1.2223, 0.4945, 0.4365],
                [-0.8163, 1.4887, -0.3333],
                [-1.7960, -0.0392, 0.3875],
                [-1.5469, 0.3998, 0.9561],
                [-0.3428, 0.7970, 1.1845],
            ],
        ]);
        actual.assert_approx_eq(&expected, 3);
    }

    /// With affine parameters at their initial values (ones/zeros), the output
    /// equals the plain normalization.
    #[test]
    fn group_norm_forward_affine_true() {
        let device = Default::default();
        let norm = GroupNormConfig::new(3, 6)
            .with_affine(true)
            .init::<TestBackend>(&device);

        let gamma = norm.gamma.as_ref().expect("gamma should not be None").val();
        gamma
            .to_data()
            .assert_approx_eq(&TensorData::ones::<f32, _>([6]), 3);

        let beta = norm.beta.as_ref().expect("beta should not be None").val();
        beta
            .to_data()
            .assert_approx_eq(&TensorData::zeros::<f32, _>([6]), 3);

        let input_data = TensorData::from([
            [
                [0.3345, 0.4429, 0.6639],
                [0.5041, 0.4175, 0.8437],
                [0.6159, 0.3758, 0.4071],
                [0.5417, 0.5785, 0.7671],
                [0.3837, 0.9883, 0.0420],
                [0.4808, 0.8989, 0.6144],
            ],
            [
                [0.3930, 0.2098, 0.0602],
                [0.2298, 0.9425, 0.0333],
                [0.7409, 0.8172, 0.8879],
                [0.4846, 0.0486, 0.2029],
                [0.6741, 0.9765, 0.6864],
                [0.2827, 0.5534, 0.2125],
            ],
        ]);
        let input = Tensor::<TestBackend, 3>::from_data(input_data, &device);

        let actual = norm.forward(input).to_data();

        let expected = TensorData::from([
            [
                [-1.1694, -0.5353, 0.7572],
                [-0.1775, -0.6838, 1.8087],
                [0.5205, -1.3107, -1.0723],
                [-0.0459, 0.2351, 1.6734],
                [-0.5796, 1.3218, -1.6544],
                [-0.2744, 1.0406, 0.1459],
            ],
            [
                [0.2665, -0.3320, -0.8205],
                [-0.2667, 2.0612, -0.9085],
                [0.6681, 0.9102, 1.1345],
                [-0.1453, -1.5287, -1.0389],
                [0.4253, 1.5962, 0.4731],
                [-1.0903, -0.0419, -1.3623],
            ],
        ]);
        actual.assert_approx_eq(&expected, 3);
    }

    /// The single-line display lists the configuration and the parameter count.
    #[test]
    fn display() {
        let norm = GroupNormConfig::new(3, 6).init::<TestBackend>(&Default::default());
        let rendered = format!("{}", norm);

        assert_eq!(
            rendered,
            "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}