use crate as burn;
use crate::nn::Initializer;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

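/// Configuration to create a [GroupNorm] layer using the [init function](GroupNormConfig::init).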
#[derive(Debug, Config)]
pub struct GroupNormConfig {
    /// The number of groups to separate the channels into.
    pub num_groups: usize,
    /// The number of channels expected in the input.
    pub num_channels: usize,
    /// A value added to the variance denominator for numerical stability. Default: 1e-5.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Whether the layer has learnable per-channel scale (`gamma`) and shift (`beta`)
    /// parameters. Default: true.
    #[config(default = true)]
    pub affine: bool,
}

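/// Applies Group Normalization over a mini-batch of inputs, normalizing each group of
/// channels to zero mean and unit variance:
///
/// `Y = (X - mean) / sqrt(var + epsilon) * gamma + beta`
///
/// Should be created using [GroupNormConfig].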
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct GroupNorm<B: Backend> {
    /// The learnable per-channel scale, of shape `[num_channels]`. `None` when `affine` is false.
    pub gamma: Option<Param<Tensor<B, 1>>>,
    /// The learnable per-channel shift, of shape `[num_channels]`. `None` when `affine` is false.
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// The number of groups the channels are separated into.
    pub num_groups: usize,
    /// The number of channels expected in the input.
    pub num_channels: usize,
    /// A value added to the variance denominator for numerical stability.
    pub epsilon: f64,
    /// Whether the layer has learnable affine parameters.
    pub affine: bool,
}

impl<B: Backend> ModuleDisplay for GroupNorm<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("num_groups", &self.num_groups)
            .add("num_channels", &self.num_channels)
            .add("epsilon", &self.epsilon)
            .add("affine", &self.affine)
            .optional()
    }
}

impl GroupNormConfig {
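    /// Initializes a new [GroupNorm] module on the given device.
    ///
    /// A minimal usage sketch (the backend type and the `input` tensor are placeholders,
    /// not part of this module):
    ///
    /// ```ignore
    /// // `MyBackend` stands in for any type that implements `Backend`;
    /// // `input` is any tensor of shape `[batch_size, 6, ...]`.
    /// let device = Default::default();
    /// let layer: GroupNorm<MyBackend> = GroupNormConfig::new(2, 6).init(&device);
    /// let output = layer.forward(input);
    /// ```
    ///
    /// # Panics
    /// Panics if `num_channels` is not divisible by `num_groups`.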
    pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
        assert_eq!(
            self.num_channels % self.num_groups,
            0,
            "The number of channels must be divisible by the number of groups"
        );

        let (gamma, beta) = if self.affine {
            let gamma = Initializer::Ones.init([self.num_channels], device);
            let beta = Initializer::Zeros.init([self.num_channels], device);

            (Some(gamma), Some(beta))
        } else {
            (None, None)
        };

        GroupNorm {
            num_groups: self.num_groups,
            num_channels: self.num_channels,
            gamma,
            beta,
            epsilon: self.epsilon,
            affine: self.affine,
        }
    }
}

impl<B: Backend> GroupNorm<B> {
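    /// Applies the forward pass on the input tensor.
    ///
    /// - input: `[batch_size, num_channels, ...]`
    /// - output: `[batch_size, num_channels, ...]` (same shape as the input)
    ///
    /// # Panics
    /// Panics if the second dimension of the input does not match `num_channels`.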
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        if input.shape().dims[1] != self.num_channels {
            panic!(
                "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
                self.num_channels,
                input.shape().dims[1]
            );
        }

        let gamma = self.gamma.as_ref().map(|x| x.val());
        let beta = self.beta.as_ref().map(|x| x.val());

        group_norm(
            input,
            gamma,
            beta,
            self.num_groups,
            self.epsilon,
            self.affine,
        )
    }
}

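/// Applies group normalization over `input`, treating dimension 0 as the batch and
/// dimension 1 as the channels: each group of channels is normalized to zero mean and
/// unit variance, then optionally scaled by `gamma` and shifted by `beta`.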
pub(crate) fn group_norm<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    gamma: Option<Tensor<B, 1>>,
    beta: Option<Tensor<B, 1>>,
    num_groups: usize,
    epsilon: f64,
    affine: bool,
) -> Tensor<B, D> {
    if (beta.is_none() || gamma.is_none()) && affine {
        panic!("Affine is set to true, but gamma or beta is None");
    }

    let shape = input.shape();
    if D < 3 {
        panic!("input rank for GroupNorm should be at least 3, but got {}", D);
    }

    let batch_size = shape.dims[0];
    let num_channels = shape.dims[1];

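    // Collapse all spatial dimensions so each row holds one group:
    // [batch_size, num_channels, d2, ...] -> [batch_size, num_groups, elements_per_group].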
    let hidden_size = shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
    let input = input.reshape([batch_size, num_groups, hidden_size]);

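    // Per-group mean over the last dimension; the broadcasted subtraction centers each group.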
    let mean = input.clone().sum_dim(2) / hidden_size as f64;
    let input = input.sub(mean);

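    // Biased (population) variance of the centered values, then normalize:
    // (x - mean) / sqrt(var + epsilon).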
    let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
    let input_normalized = input.div(var.add_scalar(epsilon).sqrt());

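    // Restore the original shape, then apply the learnable per-channel scale and shift:
    // gamma and beta are reshaped to [1, num_channels, 1, ...] so they broadcast over channels.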
    if affine {
        let mut affine_shape = [1; D];
        affine_shape[1] = num_channels;

        input_normalized
            .reshape(shape)
            .mul(gamma.unwrap().reshape(affine_shape))
            .add(beta.unwrap().reshape(affine_shape))
    } else {
        input_normalized.reshape(shape)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn group_norm_forward_affine_false() {
        let device = Default::default();
        let module = GroupNormConfig::new(2, 6)
            .with_affine(false)
            .init::<TestBackend>(&device);

        assert!(module.gamma.is_none());
        assert!(module.beta.is_none());

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [-0.3034, 0.2726, -0.9659],
                    [-1.1845, -1.3236, 0.0172],
                    [1.9507, 1.2554, -0.8625],
                    [1.0682, 0.3604, 0.3985],
                    [-0.4957, -0.4461, -0.9721],
                    [1.5157, -0.1546, -0.5596],
                ],
                [
                    [-1.6698, -0.4040, -0.7927],
                    [0.3736, -0.0975, -0.1351],
                    [-0.9461, 0.5461, -0.6334],
                    [-1.0919, -0.1158, 0.1213],
                    [-0.9535, 0.1281, 0.4372],
                    [-0.2845, 0.3488, 0.5641],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-0.1653, 0.3748, -0.7866],
                [-0.9916, -1.1220, 0.1353],
                [1.9485, 1.2965, -0.6896],
                [1.2769, 0.3628, 0.4120],
                [-0.7427, -0.6786, -1.3578],
                [1.8547, -0.3022, -0.8252],
            ],
            [
                [-1.9342, 0.0211, -0.5793],
                [1.2223, 0.4945, 0.4365],
                [-0.8163, 1.4887, -0.3333],
                [-1.7960, -0.0392, 0.3875],
                [-1.5469, 0.3998, 0.9561],
                [-0.3428, 0.7970, 1.1845],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    fn group_norm_forward_affine_true() {
        let device = Default::default();
        let module = GroupNormConfig::new(3, 6)
            .with_affine(true)
            .init::<TestBackend>(&device);

        let tolerance = Tolerance::rel_abs(1e-4, 3e-4);
        module
            .gamma
            .as_ref()
            .expect("gamma should not be None")
            .val()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::ones::<f32, _>([6]), tolerance);

        module
            .beta
            .as_ref()
            .expect("beta should not be None")
            .val()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::zeros::<f32, _>([6]), tolerance);

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [0.3345, 0.4429, 0.6639],
                    [0.5041, 0.4175, 0.8437],
                    [0.6159, 0.3758, 0.4071],
                    [0.5417, 0.5785, 0.7671],
                    [0.3837, 0.9883, 0.0420],
                    [0.4808, 0.8989, 0.6144],
                ],
                [
                    [0.3930, 0.2098, 0.0602],
                    [0.2298, 0.9425, 0.0333],
                    [0.7409, 0.8172, 0.8879],
                    [0.4846, 0.0486, 0.2029],
                    [0.6741, 0.9765, 0.6864],
                    [0.2827, 0.5534, 0.2125],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-1.1694, -0.5353, 0.7572],
                [-0.1775, -0.6838, 1.8087],
                [0.5205, -1.3107, -1.0723],
                [-0.0459, 0.2351, 1.6734],
                [-0.5796, 1.3218, -1.6544],
                [-0.2744, 1.0406, 0.1459],
            ],
            [
                [0.2665, -0.3320, -0.8205],
                [-0.2667, 2.0612, -0.9085],
                [0.6681, 0.9102, 1.1345],
                [-0.1453, -1.5287, -1.0389],
                [0.4253, 1.5962, 0.4731],
                [-1.0903, -0.0419, -1.3623],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn display() {
        let config = GroupNormConfig::new(3, 6);
        let group_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", group_norm),
            "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}