burn_core/nn/norm/group.rs

use crate as burn;
use crate::nn::Initializer;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

/// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init).
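///
/// # Example
///
/// A minimal construction sketch; the `with_*` setters for the defaulted
/// fields are generated by `#[derive(Config)]`:
///
/// ```ignore
/// let config = GroupNormConfig::new(2, 6) // 2 groups over 6 channels
///     .with_epsilon(1e-5)
///     .with_affine(true);
/// ```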
#[derive(Debug, Config)]
pub struct GroupNormConfig {
    /// The number of groups to separate the channels into
    pub num_groups: usize,
    /// The number of channels expected in the input
    pub num_channels: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// When set to `true`, this module has learnable per-channel affine
    /// parameters initialized to ones (for weights) and zeros (for biases).
    /// Default: `true`
    #[config(default = true)]
    pub affine: bool,
}

/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
///
/// `Y = groupnorm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
/// Should be created using [GroupNormConfig](GroupNormConfig).
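///
/// # Example
///
/// A minimal forward-pass sketch, where `MyBackend` stands in for any
/// [Backend] implementation:
///
/// ```ignore
/// let device = Default::default();
/// let norm = GroupNormConfig::new(2, 6).init::<MyBackend>(&device);
/// // The input must have shape [batch_size, num_channels, *].
/// let input = Tensor::<MyBackend, 3>::zeros([8, 6, 4], &device);
/// let output = norm.forward(input); // same shape as the input
/// ```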
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct GroupNorm<B: Backend> {
    /// The learnable weight
    pub gamma: Option<Param<Tensor<B, 1>>>,
    /// The learnable bias
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// The number of groups to separate the channels into
    pub num_groups: usize,
    /// The number of channels expected in the input
    pub num_channels: usize,
    /// A value required for numerical stability
    pub epsilon: f64,
    /// When set to `true`, the module has learnable per-channel affine
    /// parameters (`gamma` and `beta`)
    pub affine: bool,
}

impl<B: Backend> ModuleDisplay for GroupNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("num_groups", &self.num_groups)
            .add("num_channels", &self.num_channels)
            .add("epsilon", &self.epsilon)
            .add("affine", &self.affine)
            .optional()
    }
}

impl GroupNormConfig {
    /// Initialize a new [group norm](GroupNorm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
        assert_eq!(
            self.num_channels % self.num_groups,
            0,
            "The number of channels must be divisible by the number of groups"
        );

        let (gamma, beta) = if self.affine {
            let gamma = Initializer::Ones.init([self.num_channels], device);
            let beta = Initializer::Zeros.init([self.num_channels], device);

            (Some(gamma), Some(beta))
        } else {
            (None, None)
        };

        GroupNorm {
            num_groups: self.num_groups,
            num_channels: self.num_channels,
            gamma,
            beta,
            epsilon: self.epsilon,
            affine: self.affine,
        }
    }
}

impl<B: Backend> GroupNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [GroupNorm](GroupNorm) for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, num_channels, *]`
    /// - output: `[batch_size, num_channels, *]`
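    ///
    /// # Example
    ///
    /// A shape-preserving call, sketched with an illustrative 3-D input:
    ///
    /// ```ignore
    /// let x = Tensor::<B, 3>::zeros([2, 6, 3], &device); // num_channels = 6
    /// let y = module.forward(x); // `y` also has shape [2, 6, 3]
    /// ```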
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        if input.shape().dims[1] != self.num_channels {
            panic!(
                "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
                self.num_channels,
                input.shape().dims[1]
            );
        }

        let gamma = self.gamma.as_ref().map(|x| x.val());
        let beta = self.beta.as_ref().map(|x| x.val());

        group_norm(
            input,
            gamma,
            beta,
            self.num_groups,
            self.epsilon,
            self.affine,
        )
    }
}

/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
///
/// `Y = groupnorm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
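/// The channels are split into `num_groups` groups. For each `(batch, group)`
/// pair, the mean `μ` and the biased variance `σ²` are computed over all of
/// the group's channels and spatial positions, and the input is normalized as
/// `x̂ = (x - μ) / sqrt(σ² + ε)` before the optional affine transform.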
pub(crate) fn group_norm<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    gamma: Option<Tensor<B, 1>>,
    beta: Option<Tensor<B, 1>>,
    num_groups: usize,
    epsilon: f64,
    affine: bool,
) -> Tensor<B, D> {
    if (beta.is_none() || gamma.is_none()) && affine {
        panic!("Affine is set to true, but gamma or beta is None");
    }

    let shape = input.shape();
    if D < 3 {
        panic!("input rank for GroupNorm should be at least 3, but got {}", D);
    }

    let batch_size = shape.dims[0];
    let num_channels = shape.dims[1];

    let hidden_size = shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
    let input = input.reshape([batch_size, num_groups, hidden_size]);
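    // Each (batch, group) slice now holds num_channels / num_groups channels
    // times all spatial positions, so the statistics below are per group.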

    let mean = input.clone().sum_dim(2) / hidden_size as f64;
    let input = input.sub(mean);

    let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
    let input_normalized = input.div(var.add_scalar(epsilon).sqrt());

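    // γ and β are per-channel vectors; reshaping them to [1, C, 1, ..., 1]
    // below lets them broadcast over the batch and spatial dimensions.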
    if affine {
        let mut affine_shape = [1; D];
        affine_shape[1] = num_channels;

        input_normalized
            .reshape(shape)
            .mul(gamma.unwrap().reshape(affine_shape))
            .add(beta.unwrap().reshape(affine_shape))
    } else {
        input_normalized.reshape(shape)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn group_norm_forward_affine_false() {
        let device = Default::default();
        let module = GroupNormConfig::new(2, 6)
            .with_affine(false)
            .init::<TestBackend>(&device);

        assert!(module.gamma.is_none());
        assert!(module.beta.is_none());

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [-0.3034, 0.2726, -0.9659],
                    [-1.1845, -1.3236, 0.0172],
                    [1.9507, 1.2554, -0.8625],
                    [1.0682, 0.3604, 0.3985],
                    [-0.4957, -0.4461, -0.9721],
                    [1.5157, -0.1546, -0.5596],
                ],
                [
                    [-1.6698, -0.4040, -0.7927],
                    [0.3736, -0.0975, -0.1351],
                    [-0.9461, 0.5461, -0.6334],
                    [-1.0919, -0.1158, 0.1213],
                    [-0.9535, 0.1281, 0.4372],
                    [-0.2845, 0.3488, 0.5641],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-0.1653, 0.3748, -0.7866],
                [-0.9916, -1.1220, 0.1353],
                [1.9485, 1.2965, -0.6896],
                [1.2769, 0.3628, 0.4120],
                [-0.7427, -0.6786, -1.3578],
                [1.8547, -0.3022, -0.8252],
            ],
            [
                [-1.9342, 0.0211, -0.5793],
                [1.2223, 0.4945, 0.4365],
                [-0.8163, 1.4887, -0.3333],
                [-1.7960, -0.0392, 0.3875],
                [-1.5469, 0.3998, 0.9561],
                [-0.3428, 0.7970, 1.1845],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    fn group_norm_forward_affine_true() {
        let device = Default::default();
        let module = GroupNormConfig::new(3, 6)
            .with_affine(true)
            .init::<TestBackend>(&device);

        let tolerance = Tolerance::rel_abs(1e-4, 3e-4);
        module
            .gamma
            .as_ref()
            .expect("gamma should not be None")
            .val()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::ones::<f32, _>([6]), tolerance);

        module
            .beta
            .as_ref()
            .expect("beta should not be None")
            .val()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::zeros::<f32, _>([6]), tolerance);

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [0.3345, 0.4429, 0.6639],
                    [0.5041, 0.4175, 0.8437],
                    [0.6159, 0.3758, 0.4071],
                    [0.5417, 0.5785, 0.7671],
                    [0.3837, 0.9883, 0.0420],
                    [0.4808, 0.8989, 0.6144],
                ],
                [
                    [0.3930, 0.2098, 0.0602],
                    [0.2298, 0.9425, 0.0333],
                    [0.7409, 0.8172, 0.8879],
                    [0.4846, 0.0486, 0.2029],
                    [0.6741, 0.9765, 0.6864],
                    [0.2827, 0.5534, 0.2125],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-1.1694, -0.5353, 0.7572],
                [-0.1775, -0.6838, 1.8087],
                [0.5205, -1.3107, -1.0723],
                [-0.0459, 0.2351, 1.6734],
                [-0.5796, 1.3218, -1.6544],
                [-0.2744, 1.0406, 0.1459],
            ],
            [
                [0.2665, -0.3320, -0.8205],
                [-0.2667, 2.0612, -0.9085],
                [0.6681, 0.9102, 1.1345],
                [-0.1453, -1.5287, -1.0389],
                [0.4253, 1.5962, 0.4731],
                [-1.0903, -0.0419, -1.3623],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn display() {
        let config = GroupNormConfig::new(3, 6);
        let group_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", group_norm),
            "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}