1use crate as burn;
2use crate::nn::Initializer;
3
4use crate::config::Config;
5use crate::module::Module;
6use crate::module::Param;
7use crate::module::{Content, DisplaySettings, ModuleDisplay};
8use crate::tensor::backend::Backend;
9use crate::tensor::Tensor;
10
/// Configuration to create a [GroupNorm](GroupNorm) layer using the
/// [init function](GroupNormConfig::init).
#[derive(Debug, Config)]
pub struct GroupNormConfig {
    /// The number of groups the channels are split into.
    /// `num_channels` must be divisible by this value.
    pub num_groups: usize,
    /// The number of channels expected in the input (dimension 1).
    pub num_channels: usize,
    /// A small value used in the normalization denominator for numerical stability.
    /// Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// When `true`, the layer holds learnable per-channel `gamma` (initialized to ones)
    /// and `beta` (initialized to zeros) parameters. Default: `true`
    #[config(default = true)]
    pub affine: bool,
}
27
/// Applies Group Normalization over a mini-batch of inputs.
///
/// The channel dimension (dimension 1) is split into `num_groups` groups. Each group is
/// normalized over its channels and all trailing dimensions. When `affine` is enabled, the
/// result is then scaled by `gamma` and shifted by `beta` per channel.
///
/// Should be created using [GroupNormConfig](GroupNormConfig).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct GroupNorm<B: Backend> {
    /// Per-channel scale applied after normalization; `Some` only when `affine` is true.
    pub gamma: Option<Param<Tensor<B, 1>>>,
    /// Per-channel shift applied after scaling; `Some` only when `affine` is true.
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// The number of groups the channels are split into.
    pub num_groups: usize,
    /// The number of channels the input is expected to have.
    pub num_channels: usize,
    /// Numerical-stability term used in the normalization denominator.
    pub epsilon: f64,
    /// Whether the learnable `gamma`/`beta` parameters are applied.
    pub affine: bool,
}
55
56impl<B: Backend> ModuleDisplay for GroupNorm<B> {
57 fn custom_settings(&self) -> Option<DisplaySettings> {
58 DisplaySettings::new()
59 .with_new_line_after_attribute(false)
60 .optional()
61 }
62
63 fn custom_content(&self, content: Content) -> Option<Content> {
64 content
65 .add("num_groups", &self.num_groups)
66 .add("num_channels", &self.num_channels)
67 .add("epsilon", &self.epsilon)
68 .add("affine", &self.affine)
69 .optional()
70 }
71}
72
73impl GroupNormConfig {
74 pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
76 assert_eq!(
77 self.num_channels % self.num_groups,
78 0,
79 "The number of channels must be divisible by the number of groups"
80 );
81
82 let (gamma, beta) = if self.affine {
83 let gamma = Initializer::Ones.init([self.num_channels], device);
84 let beta = Initializer::Zeros.init([self.num_channels], device);
85
86 (Some(gamma), Some(beta))
87 } else {
88 (None, None)
89 };
90
91 GroupNorm {
92 num_groups: self.num_groups,
93 num_channels: self.num_channels,
94 gamma,
95 beta,
96 epsilon: self.epsilon,
97 affine: self.affine,
98 }
99 }
100}
101
102impl<B: Backend> GroupNorm<B> {
103 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
112 if input.shape().dims[1] != self.num_channels {
113 panic!(
114 "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
115 self.num_channels,
116 input.shape().dims[1]
117 );
118 }
119
120 let gamma = self.gamma.as_ref().map(|x| x.val());
121 let beta = self.beta.as_ref().map(|x| x.val());
122
123 group_norm(
124 input,
125 gamma,
126 beta,
127 self.num_groups,
128 self.epsilon,
129 self.affine,
130 )
131 }
132}
133
134pub(crate) fn group_norm<B: Backend, const D: usize>(
145 input: Tensor<B, D>,
146 gamma: Option<Tensor<B, 1>>,
147 beta: Option<Tensor<B, 1>>,
148 num_groups: usize,
149 epsilon: f64,
150 affine: bool,
151) -> Tensor<B, D> {
152 if (beta.is_none() || gamma.is_none()) && affine {
153 panic!("Affine is set to true, but gamma or beta is None");
154 }
155
156 let shape = input.shape();
157 if shape.num_elements() <= 2 {
158 panic!(
159 "input rank for GroupNorm should be at least 3, but got {}",
160 shape.num_elements()
161 );
162 }
163
164 let batch_size = shape.dims[0];
165 let num_channels = shape.dims[1];
166
167 let hidden_size = shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
168 let input = input.reshape([batch_size, num_groups, hidden_size]);
169
170 let mean = input.clone().sum_dim(2) / hidden_size as f64;
171 let input = input.sub(mean);
172
173 let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
174 let input_normalized = input.div(var.sqrt().add_scalar(epsilon));
175
176 if affine {
177 let mut affine_shape = [1; D];
178 affine_shape[1] = num_channels;
179
180 input_normalized
181 .reshape(shape)
182 .mul(gamma.clone().unwrap().reshape(affine_shape))
183 .add(beta.clone().unwrap().reshape(affine_shape))
184 } else {
185 input_normalized.reshape(shape)
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    use alloc::format;

    /// With `affine` disabled the module holds no `gamma`/`beta`, and the output is the
    /// plain per-group normalization. The config has 2 groups over 6 channels, so each group
    /// spans 3 channels x 3 positions.
    #[test]
    fn group_norm_forward_affine_false() {
        let device = Default::default();
        let module = GroupNormConfig::new(2, 6)
            .with_affine(false)
            .init::<TestBackend>(&device);

        assert!(module.gamma.is_none());
        assert!(module.beta.is_none());

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [-0.3034, 0.2726, -0.9659],
                    [-1.1845, -1.3236, 0.0172],
                    [1.9507, 1.2554, -0.8625],
                    [1.0682, 0.3604, 0.3985],
                    [-0.4957, -0.4461, -0.9721],
                    [1.5157, -0.1546, -0.5596],
                ],
                [
                    [-1.6698, -0.4040, -0.7927],
                    [0.3736, -0.0975, -0.1351],
                    [-0.9461, 0.5461, -0.6334],
                    [-1.0919, -0.1158, 0.1213],
                    [-0.9535, 0.1281, 0.4372],
                    [-0.2845, 0.3488, 0.5641],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-0.1653, 0.3748, -0.7866],
                [-0.9916, -1.1220, 0.1353],
                [1.9485, 1.2965, -0.6896],
                [1.2769, 0.3628, 0.4120],
                [-0.7427, -0.6786, -1.3578],
                [1.8547, -0.3022, -0.8252],
            ],
            [
                [-1.9342, 0.0211, -0.5793],
                [1.2223, 0.4945, 0.4365],
                [-0.8163, 1.4887, -0.3333],
                [-1.7960, -0.0392, 0.3875],
                [-1.5469, 0.3998, 0.9561],
                [-0.3428, 0.7970, 1.1845],
            ],
        ]);
        // Compared to 3 decimal places.
        output.to_data().assert_approx_eq(&expected, 3);
    }

    /// With `affine` enabled, `gamma` starts as ones and `beta` as zeros, so the output
    /// equals the plain normalization. The config has 3 groups over 6 channels, so each
    /// group spans 2 channels x 3 positions.
    #[test]
    fn group_norm_forward_affine_true() {
        let device = Default::default();
        let module = GroupNormConfig::new(3, 6)
            .with_affine(true)
            .init::<TestBackend>(&device);

        module
            .gamma
            .as_ref()
            .expect("gamma should not be None")
            .val()
            .to_data()
            .assert_approx_eq(&TensorData::ones::<f32, _>([6]), 3);

        module
            .beta
            .as_ref()
            .expect("beta should not be None")
            .val()
            .to_data()
            .assert_approx_eq(&TensorData::zeros::<f32, _>([6]), 3);

        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [0.3345, 0.4429, 0.6639],
                    [0.5041, 0.4175, 0.8437],
                    [0.6159, 0.3758, 0.4071],
                    [0.5417, 0.5785, 0.7671],
                    [0.3837, 0.9883, 0.0420],
                    [0.4808, 0.8989, 0.6144],
                ],
                [
                    [0.3930, 0.2098, 0.0602],
                    [0.2298, 0.9425, 0.0333],
                    [0.7409, 0.8172, 0.8879],
                    [0.4846, 0.0486, 0.2029],
                    [0.6741, 0.9765, 0.6864],
                    [0.2827, 0.5534, 0.2125],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-1.1694, -0.5353, 0.7572],
                [-0.1775, -0.6838, 1.8087],
                [0.5205, -1.3107, -1.0723],
                [-0.0459, 0.2351, 1.6734],
                [-0.5796, 1.3218, -1.6544],
                [-0.2744, 1.0406, 0.1459],
            ],
            [
                [0.2665, -0.3320, -0.8205],
                [-0.2667, 2.0612, -0.9085],
                [0.6681, 0.9102, 1.1345],
                [-0.1453, -1.5287, -1.0389],
                [0.4253, 1.5962, 0.4731],
                [-1.0903, -0.0419, -1.3623],
            ],
        ]);
        // Compared to 3 decimal places.
        output.to_data().assert_approx_eq(&expected, 3);
    }

    /// The display string lists the hyper-parameters on one line. The parameter count is
    /// 12: 6 for `gamma` plus 6 for `beta`, since `affine` defaults to true.
    #[test]
    fn display() {
        let config = GroupNormConfig::new(3, 6);
        let group_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", group_norm),
            "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}