1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
5use burn::tensor::Shape;
6use burn::tensor::Tensor;
7use burn::tensor::backend::Backend;
8use burn::tensor::module::avg_pool1d;
9use burn::tensor::ops::PadMode;
10
/// Configuration to create a [LocalResponseNorm] layer using the [init function](LocalResponseNormConfig::init).
#[derive(Config, Debug)]
pub struct LocalResponseNormConfig {
    /// Number of neighboring channels averaged together in the normalization window.
    pub size: usize,
    /// Multiplicative factor applied to the windowed average of squared activations.
    #[config(default = 0.0001)]
    pub alpha: f64,
    /// Exponent applied to the denominator.
    #[config(default = 0.75)]
    pub beta: f64,
    /// Additive constant inside the denominator.
    #[config(default = 1.0)]
    pub k: f64,
}
27
28impl LocalResponseNormConfig {
29 pub fn init(&self) -> LocalResponseNorm {
35 assert!(self.size > 0, "size must be greater than 0.");
36
37 LocalResponseNorm {
38 size: self.size,
39 alpha: self.alpha,
40 beta: self.beta,
41 k: self.k,
42 }
43 }
44}
45
/// Local Response Normalization layer.
///
/// Normalizes each activation `a` by a factor computed from the average of
/// squared activations in a window of `size` neighboring channels:
/// `b = a / (k + alpha * avg(a^2))^beta`.
///
/// Should be created with [LocalResponseNormConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct LocalResponseNorm {
    // Number of channels in the normalization window.
    size: usize,
    // Multiplicative factor applied to the windowed average of squares.
    alpha: f64,
    // Exponent applied to the denominator.
    beta: f64,
    // Additive constant inside the denominator.
    k: f64,
}
70
71impl LocalResponseNorm {
72 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
83 assert!(
84 D >= 3,
85 "LocalResponseNorm requires input rank >= 3, got {D}"
86 );
87
88 let shape = input.dims();
89 let n = shape[0];
90 let c = shape[1];
91 let d_flat: usize = shape[2..].iter().product();
92
93 let squared = input.clone().square();
95
96 let flat: Tensor<B, 3> = squared.reshape(Shape::new([n, c, d_flat]));
98
99 let transposed = flat.swap_dims(1, 2);
101
102 let batched: Tensor<B, 3> = transposed.reshape(Shape::new([n * d_flat, 1, c]));
104
105 let pad_left = (self.size - 1) / 2;
106 let pad_right = self.size / 2;
107 let square_avg = if pad_left != pad_right {
108 let padded = batched.pad((pad_left, pad_right, 0, 0), PadMode::Constant(0.0));
109 avg_pool1d(padded, self.size, 1, 0, true, false)
110 } else {
111 avg_pool1d(batched, self.size, 1, pad_left, true, false)
112 };
113
114 let unbatched: Tensor<B, 3> = square_avg.reshape(Shape::new([n, d_flat, c]));
116 let untransposed = unbatched.swap_dims(1, 2);
117 let square_avg_restored: Tensor<B, D> = untransposed.reshape(Shape::new(shape));
118
119 input
121 / square_avg_restored
122 .mul_scalar(self.alpha)
123 .add_scalar(self.k)
124 .powf_scalar(self.beta)
125 }
126}
127
128impl ModuleDisplay for LocalResponseNorm {
129 fn custom_settings(&self) -> Option<DisplaySettings> {
130 DisplaySettings::new()
131 .with_new_line_after_attribute(false)
132 .optional()
133 }
134
135 fn custom_content(&self, content: Content) -> Option<Content> {
136 content
137 .add("size", &self.size)
138 .add("alpha", &self.alpha)
139 .add("beta", &self.beta)
140 .add("k", &self.k)
141 .optional()
142 }
143}
144
145#[cfg(test)]
146mod tests {
147 use super::*;
148 use alloc::format;
149 use burn::tensor::TensorData;
150 use burn::tensor::{Tolerance, ops::FloatElem};
151
152 #[cfg(feature = "std")]
153 use crate::{TestAutodiffBackend, TestBackend};
154
155 #[cfg(not(feature = "std"))]
156 use crate::TestBackend;
157
    // Forward with the default hyperparameters (alpha = 1e-4, beta = 0.75,
    // k = 1, size = 5) against precomputed reference values.
    #[test]
    fn forward_default_params() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007, -2.1055],
                    [0.6784, -1.2345, -0.0431, -1.6047],
                    [-0.7521, 1.6487, -0.3925, -1.4036],
                    [-0.7279, -0.5594, -0.7688, 0.7624],
                ],
                [
                    [1.6423, -0.1596, -0.4974, 0.4396],
                    [-0.7581, 1.0783, 0.8008, 1.6806],
                    [1.2791, 1.2964, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516, 0.8599],
                ],
                [
                    [-1.3847, -0.8712, -0.2234, 1.7174],
                    [0.3189, -0.4245, 0.3057, -0.7746],
                    [-1.5576, 0.9956, -0.8798, -0.6011],
                    [-1.2742, 2.1228, -1.2347, -0.4879],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        // With the small default alpha the denominator stays close to 1, so
        // the output is only slightly shrunk toward zero.
        let expected = TensorData::from([[
            [
                [1.9267, 1.4872, 0.9007, -2.1053],
                [0.6784, -1.2345, -0.0431, -1.6045],
                [-0.7521, 1.6486, -0.3925, -1.4035],
                [-0.7279, -0.5594, -0.7688, 0.7624],
            ],
            [
                [1.6421, -0.1596, -0.4974, 0.4395],
                [-0.7581, 1.0783, 0.8008, 1.6805],
                [1.2790, 1.2963, 0.6105, 1.3347],
                [-0.2316, 0.0418, -0.2516, 0.8598],
            ],
            [
                [-1.3845, -0.8712, -0.2234, 1.7172],
                [0.3189, -0.4245, 0.3057, -0.7745],
                [-1.5575, 0.9956, -0.8798, -0.6011],
                [-1.2741, 2.1226, -1.2346, -0.4879],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
216
    // Forward with non-default hyperparameters (alpha = 0.001, beta = 0.5,
    // k = 2) and a smaller window (size = 3).
    #[test]
    fn forward_custom_params() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(3)
            .with_alpha(0.001)
            .with_beta(0.5)
            .with_k(2.0)
            .init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007],
                    [-2.1055, 0.6784, -1.2345],
                    [-0.0431, -1.6047, -0.7521],
                ],
                [
                    [1.6487, -0.3925, -1.4036],
                    [-0.7279, -0.5594, -0.7688],
                    [0.7624, 1.6423, -0.1596],
                ],
                [
                    [-0.4974, 0.4396, 0.3189],
                    [-0.4245, 0.3057, -0.7746],
                    [0.0349, 0.3211, 1.5736],
                ],
                [
                    [-0.8455, -1.2742, 2.1228],
                    [-1.2347, -0.4879, -1.4181],
                    [0.8963, 0.0499, 2.2667],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        // k = 2 dominates the denominator, scaling everything down by roughly
        // 1/sqrt(2).
        let expected = TensorData::from([[
            [
                [1.3618, 1.0515, 0.6368],
                [-1.4882, 0.4797, -0.8728],
                [-0.0305, -1.1342, -0.5318],
            ],
            [
                [1.1652, -0.2775, -0.9923],
                [-0.5145, -0.3955, -0.5435],
                [0.5391, 1.1608, -0.1128],
            ],
            [
                [-0.3516, 0.3108, 0.2254],
                [-0.3001, 0.2162, -0.5476],
                [0.0247, 0.2270, 1.1120],
            ],
            [
                [-0.5978, -0.9008, 1.5005],
                [-0.8729, -0.3450, -1.0025],
                [0.6337, 0.0353, 1.6018],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
281
    // An even window size exercises the asymmetric-padding branch of forward
    // (pad_left != pad_right), which pads explicitly before pooling.
    #[test]
    fn forward_even_size() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(2).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [[0.3367, 0.1288], [0.2345, 0.2303]],
                [[-1.1229, -0.1863], [2.2082, -0.6380]],
                [[0.4617, 0.2674], [0.5349, 0.8094]],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [[0.3367, 0.1288], [0.2345, 0.2303]],
            [[-1.1228, -0.1863], [2.2078, -0.6380]],
            [[0.4616, 0.2673], [0.5348, 0.8093]],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
308
    // With size = 2, alpha = 1, beta = 1, k = 0, the denominator for each
    // channel is the mean of its own squared value and the next channel's
    // (zero past the end), so the expected values can be computed by hand:
    //   channel 0: 1 / ((1 + 4) / 2)  = 0.4
    //   channel 1: 2 / ((4 + 16) / 2) = 0.2
    //   channel 2: 4 / ((16 + 0) / 2) = 0.5
    // This pins the window as extending toward higher channel indices.
    #[test]
    fn forward_even_size_uses_asymmetric_positive_side_window() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(2)
            .with_alpha(1.0)
            .with_beta(1.0)
            .with_k(0.0)
            .init();
        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[1.0], [2.0], [4.0]]]), &device);

        let output = module.forward(input);

        let expected = TensorData::from([[[0.4], [0.2], [0.5]]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-5, 1e-6));
    }
332
    // Minimum supported rank: a 3d input [batch, channels, length].
    #[test]
    fn forward_3d() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(3).init();
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [1.9269, 1.4873, 0.9007, -2.1055, 0.6784, -1.2345],
                [-0.0431, -1.6047, 0.3559, -0.6866, -0.4934, 0.2415],
                [-1.1109, 0.0915, -2.3169, -0.2168, -0.3097, -0.3957],
                [0.8034, -0.6216, -0.5920, -0.0631, -0.8286, 0.3309],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [1.9267, 1.4871, 0.9007, -2.1053, 0.6784, -1.2345],
            [-0.0431, -1.6045, 0.3558, -0.6865, -0.4933, 0.2415],
            [-1.1109, 0.0915, -2.3166, -0.2168, -0.3097, -0.3957],
            [0.8034, -0.6216, -0.5919, -0.0631, -0.8285, 0.3309],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
361
    // A 5d input: all trailing spatial dimensions are flattened internally, so
    // ranks above 4 must work the same way.
    #[test]
    fn forward_5d() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(3).init();
        let input = Tensor::<TestBackend, 5>::from_data(
            TensorData::from([[
                [
                    [[1.9269, 1.4873], [0.9007, -2.1055]],
                    [[0.6784, -1.2345], [-0.0431, -1.6047]],
                ],
                [
                    [[0.3559, -0.6866], [-0.4934, 0.2415]],
                    [[-1.1109, 0.0915], [-2.3169, -0.2168]],
                ],
                [
                    [[-0.3097, -0.3957], [0.8034, -0.6216]],
                    [[-0.5920, -0.0631], [-0.8286, 0.3309]],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [[1.9267, 1.4872], [0.9007, -2.1053]],
                [[0.6784, -1.2345], [-0.0431, -1.6046]],
            ],
            [
                [[0.3558, -0.6866], [-0.4933, 0.2415]],
                [[-1.1108, 0.0915], [-2.3166, -0.2168]],
            ],
            [
                [[-0.3097, -0.3957], [0.8034, -0.6216]],
                [[-0.5920, -0.0631], [-0.8284, 0.3309]],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
406
    // size = 1: each channel is normalized by its own squared value only
    // (the degenerate one-channel window).
    #[test]
    fn forward_size_1() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(1).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007],
                    [-2.1055, 0.6784, -1.2345],
                    [-0.0431, -1.6047, -0.7521],
                ],
                [
                    [1.6487, -0.3925, 0.2415],
                    [-1.1109, 0.0915, -2.3169],
                    [-0.2168, -1.3847, -0.8712],
                ],
                [
                    [-0.2234, -0.6216, -0.5920],
                    [-0.0631, -0.8286, 0.3309],
                    [-1.5576, 0.9956, -0.8798],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [1.9264, 1.4870, 0.9007],
                [-2.1048, 0.6784, -1.2344],
                [-0.0431, -1.6044, -0.7521],
            ],
            [
                [1.6484, -0.3925, 0.2415],
                [-1.1108, 0.0915, -2.3160],
                [-0.2168, -1.3845, -0.8712],
            ],
            [
                [-0.2234, -0.6216, -0.5920],
                [-0.0631, -0.8285, 0.3309],
                [-1.5573, 0.9956, -0.8797],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
459
    // Window (size = 5) larger than the channel count (2): zero padding
    // stands in for the missing neighbor channels.
    #[test]
    fn forward_c_less_than_size() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, -0.4974],
                    [0.4396, -0.7581, 1.0783],
                    [0.8008, 1.6806, 0.3559],
                ],
                [
                    [-0.6866, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516],
                    [0.8599, -0.3097, -0.3957],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [1.9268, 1.4872, -0.4974],
                [0.4396, -0.7581, 1.0783],
                [0.8008, 1.6805, 0.3559],
            ],
            [
                [-0.6866, 0.6104, 1.3347],
                [-0.2316, 0.0418, -0.2516],
                [0.8598, -0.3097, -0.3957],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }
500
    // Two batch entries: each must be normalized independently even though the
    // implementation folds (batch, spatial) positions into one pooling batch.
    #[test]
    fn forward_multi_batch() {
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [
                    [
                        [1.9269, 1.4873, 0.9007, -2.1055],
                        [0.6784, -1.2345, -0.0431, -1.6047],
                        [-0.7521, 1.6487, -0.3925, -1.4036],
                        [-0.7279, -0.5594, -0.7688, 0.7624],
                    ],
                    [
                        [1.6423, -0.1596, -0.4974, 0.4396],
                        [-0.7581, 1.0783, 0.8008, 1.6806],
                        [1.2791, 1.2964, 0.6105, 1.3347],
                        [-0.2316, 0.0418, -0.2516, 0.8599],
                    ],
                    [
                        [-1.3847, -0.8712, -0.2234, 1.7174],
                        [0.3189, -0.4245, 0.3057, -0.7746],
                        [-1.5576, 0.9956, -0.8798, -0.6011],
                        [-1.2742, 2.1228, -1.2347, -0.4879],
                    ],
                ],
                [
                    [
                        [-0.9138, -0.6581, 0.0780, 0.5258],
                        [-0.4880, 1.1914, -0.8140, -0.7360],
                        [-1.4032, 0.0360, -0.0635, 0.6756],
                        [-0.0978, 1.8446, -1.1845, 1.3835],
                    ],
                    [
                        [1.4451, 0.8564, 2.2181, 0.5232],
                        [0.3466, -0.1973, -1.0546, 1.2780],
                        [-0.1722, 0.5238, 0.0566, 0.4263],
                        [0.5750, -0.6417, -2.2064, -0.7508],
                    ],
                    [
                        [0.0109, -0.3387, -1.3407, -0.5854],
                        [0.5362, 0.5246, 1.1412, 0.0516],
                        [0.7440, -0.4816, -1.0495, 0.6039],
                        [-1.7223, -0.8278, 1.3347, 0.4835],
                    ],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let out_data = output.to_data();
        // Shape must be preserved.
        assert_eq!(out_data.shape, [2, 3, 4, 4].into());
        let expected_full = TensorData::from([
            [
                [
                    [1.9267, 1.4872, 0.9007, -2.1053],
                    [0.6784, -1.2345, -0.0431, -1.6045],
                    [-0.7521, 1.6486, -0.3925, -1.4035],
                    [-0.7279, -0.5594, -0.7688, 0.7624],
                ],
                [
                    [1.6421, -0.1596, -0.4974, 0.4395],
                    [-0.7581, 1.0783, 0.8008, 1.6805],
                    [1.2790, 1.2963, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516, 0.8598],
                ],
                [
                    [-1.3845, -0.8712, -0.2234, 1.7172],
                    [0.3189, -0.4245, 0.3057, -0.7745],
                    [-1.5575, 0.9956, -0.8798, -0.6011],
                    [-1.2741, 2.1226, -1.2346, -0.4879],
                ],
            ],
            [
                [
                    [-0.9138, -0.6581, 0.0780, 0.5258],
                    [-0.4880, 1.1913, -0.8140, -0.7360],
                    [-1.4032, 0.0360, -0.0635, 0.6756],
                    [-0.0978, 1.8445, -1.1844, 1.3835],
                ],
                [
                    [1.4451, 0.8564, 2.2179, 0.5232],
                    [0.3466, -0.1973, -1.0545, 1.2780],
                    [-0.1722, 0.5238, 0.0566, 0.4263],
                    [0.5750, -0.6417, -2.2061, -0.7508],
                ],
                [
                    [0.0109, -0.3387, -1.3405, -0.5854],
                    [0.5362, 0.5246, 1.1411, 0.0516],
                    [0.7439, -0.4816, -1.0494, 0.6039],
                    [-1.7222, -0.8277, 1.3345, 0.4835],
                ],
            ],
        ]);
        out_data.assert_approx_eq::<FT>(&expected_full, Tolerance::rel_abs(5e-3, 1e-4));
    }
601
602 #[test]
605 #[should_panic(expected = "size must be greater than 0")]
606 fn config_size_zero_panics() {
607 LocalResponseNormConfig::new(0).init();
608 }
609
610 #[test]
611 #[should_panic(expected = "LocalResponseNorm requires input rank >= 3")]
612 fn forward_rank_2_panics() {
613 let module = LocalResponseNormConfig::new(3).init();
614 let input = Tensor::<TestBackend, 2>::zeros([2, 4], &Default::default());
615 module.forward(input);
616 }
617
    // Gradients must flow back through the square/reshape/pool pipeline. With
    // the small default alpha the expected gradients of sum() are close to 1,
    // as the precomputed reference values show.
    #[cfg(feature = "std")]
    #[test]
    fn backward() {
        type FT = FloatElem<TestAutodiffBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestAutodiffBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007, -2.1055],
                    [0.6784, -1.2345, -0.0431, -1.6047],
                    [-0.7521, 1.6487, -0.3925, -1.4036],
                    [-0.7279, -0.5594, -0.7688, 0.7624],
                ],
                [
                    [1.6423, -0.1596, -0.4974, 0.4396],
                    [-0.7581, 1.0783, 0.8008, 1.6806],
                    [1.2791, 1.2964, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516, 0.8599],
                ],
                [
                    [-1.3847, -0.8712, -0.2234, 1.7174],
                    [0.3189, -0.4245, 0.3057, -0.7746],
                    [-1.5576, 0.9956, -0.8798, -0.6011],
                    [-1.2742, 2.1228, -1.2347, -0.4879],
                ],
            ]]),
            &device,
        )
        .require_grad();

        let output = module.forward(input.clone());
        let grads = output.sum().backward();
        let input_grad = input.grad(&grads).unwrap();

        // Gradient shape matches the input shape.
        assert_eq!(input_grad.dims(), [1, 3, 4, 4]);

        let expected_grad = TensorData::from([[
            [
                [0.9997, 0.9999, 1.0000, 0.9999],
                [1.0000, 0.9999, 1.0000, 0.9999],
                [0.9999, 0.9997, 1.0000, 0.9999],
                [0.9999, 1.0000, 0.9999, 1.0000],
            ],
            [
                [0.9998, 1.0000, 1.0000, 0.9999],
                [1.0000, 1.0000, 1.0000, 0.9999],
                [1.0000, 0.9998, 1.0000, 1.0000],
                [1.0000, 0.9999, 1.0000, 0.9999],
            ],
            [
                [1.0000, 1.0000, 1.0000, 0.9999],
                [1.0000, 1.0000, 1.0000, 0.9999],
                [0.9999, 0.9998, 1.0000, 0.9999],
                [0.9999, 0.9998, 0.9999, 1.0000],
            ],
        ]]);
        input_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected_grad, Tolerance::rel_abs(5e-3, 1e-4));
    }
681
682 #[test]
685 fn display() {
686 let config = LocalResponseNormConfig::new(5);
687 let module = config.init();
688 assert_eq!(
689 format!("{module}"),
690 "LocalResponseNorm {size: 5, alpha: 0.0001, beta: 0.75, k: 1}"
691 );
692 }
693}