Skip to main content

burn_nn/modules/norm/
local_response.rs

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
5use burn::tensor::Shape;
6use burn::tensor::Tensor;
7use burn::tensor::backend::Backend;
8use burn::tensor::module::avg_pool1d;
9use burn::tensor::ops::PadMode;
10
/// Configuration to create a [LocalResponseNorm](LocalResponseNorm) layer
/// using the [init function](LocalResponseNormConfig::init).
///
/// Only `size` is required; `alpha`, `beta` and `k` fall back to the defaults
/// given by their `#[config(default = ...)]` attributes.
#[derive(Config, Debug)]
pub struct LocalResponseNormConfig {
    /// Number of channels in the sliding normalization window.
    ///
    /// Must be greater than 0: [`init`](LocalResponseNormConfig::init) panics otherwise.
    pub size: usize,
    /// Scaling parameter. Default: 0.0001
    ///
    /// The layer averages the windowed squares (it divides their sum by `size`),
    /// so the effective multiplier on the raw sum of squares is `alpha / size`.
    #[config(default = 0.0001)]
    pub alpha: f64,
    /// Exponent applied to the normalization denominator. Default: 0.75
    #[config(default = 0.75)]
    pub beta: f64,
    /// Bias constant (called `bias` in ONNX). Default: 1.0
    ///
    /// Added to the scaled average of squares before raising to `beta`.
    #[config(default = 1.0)]
    pub k: f64,
}
27
28impl LocalResponseNormConfig {
29    /// Initialize a new [LocalResponseNorm](LocalResponseNorm) module.
30    ///
31    /// # Panics
32    ///
33    /// Panics if `size` is 0.
34    pub fn init(&self) -> LocalResponseNorm {
35        assert!(self.size > 0, "size must be greater than 0.");
36
37        LocalResponseNorm {
38            size: self.size,
39            alpha: self.alpha,
40            beta: self.beta,
41            k: self.k,
42        }
43    }
44}
45
/// Applies Local Response Normalization as described in
/// [ImageNet Classification with Deep Convolutional Neural Networks](https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html).
///
/// `Y = X / (k + (alpha / size) * sum(X^2))^beta`
///
/// Where the sum is computed over a sliding window of `size` channels.
///
/// For output channel `c`, the window covers input channels
/// `c - (size - 1) / 2 ..= c + size / 2` (integer division). Channels outside
/// `0..C` contribute zero, and the sum is always divided by `size`, even where
/// the window is clipped at the channel boundaries.
///
/// For odd `size`, the window is centered on each channel position.
/// For even `size`, the window uses asymmetric padding and includes the current
/// channel plus one extra channel on the positive side. This matches the window
/// bounds of the ONNX `LRN` operator. NOTE: PyTorch's `nn.LocalResponseNorm`
/// places the extra channel on the negative side for even `size`, so results can
/// differ from PyTorch in that case.
///
/// Should be created using [LocalResponseNormConfig](LocalResponseNormConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct LocalResponseNorm {
    /// Number of channels in the sliding window.
    size: usize,
    /// Scaling parameter.
    alpha: f64,
    /// Exponent.
    beta: f64,
    /// Bias constant.
    k: f64,
}
70
71impl LocalResponseNorm {
72    /// Applies Local Response Normalization on the input tensor.
73    ///
74    /// # Shapes
75    ///
76    /// - input: `[N, C, D1, D2, ..., Dk]` (rank >= 3)
77    /// - output: same shape as input
78    ///
79    /// # Panics
80    ///
81    /// Panics if the input tensor rank is less than 3.
82    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
83        assert!(
84            D >= 3,
85            "LocalResponseNorm requires input rank >= 3, got {D}"
86        );
87
88        let shape = input.dims();
89        let n = shape[0];
90        let c = shape[1];
91        let d_flat: usize = shape[2..].iter().product();
92
93        // Square the input
94        let squared = input.clone().square();
95
96        // Flatten spatial dims: [N, C, D1..Dk] -> [N, C, D_flat]
97        let flat: Tensor<B, 3> = squared.reshape(Shape::new([n, c, d_flat]));
98
99        // Move channel to last dim: [N, D_flat, C]
100        let transposed = flat.swap_dims(1, 2);
101
102        // Batch all spatial positions: [N*D_flat, 1, C]
103        let batched: Tensor<B, 3> = transposed.reshape(Shape::new([n * d_flat, 1, c]));
104
105        let pad_left = (self.size - 1) / 2;
106        let pad_right = self.size / 2;
107        let square_avg = if pad_left != pad_right {
108            let padded = batched.pad((pad_left, pad_right, 0, 0), PadMode::Constant(0.0));
109            avg_pool1d(padded, self.size, 1, 0, true, false)
110        } else {
111            avg_pool1d(batched, self.size, 1, pad_left, true, false)
112        };
113
114        // Restore shape: [N*D_flat, 1, C] -> [N, D_flat, C] -> [N, C, D_flat] -> original
115        let unbatched: Tensor<B, 3> = square_avg.reshape(Shape::new([n, d_flat, c]));
116        let untransposed = unbatched.swap_dims(1, 2);
117        let square_avg_restored: Tensor<B, D> = untransposed.reshape(Shape::new(shape));
118
119        // Apply LRN formula: output = input / (k + alpha * avg(x^2))^beta
120        input
121            / square_avg_restored
122                .mul_scalar(self.alpha)
123                .add_scalar(self.k)
124                .powf_scalar(self.beta)
125    }
126}
127
128impl ModuleDisplay for LocalResponseNorm {
129    fn custom_settings(&self) -> Option<DisplaySettings> {
130        DisplaySettings::new()
131            .with_new_line_after_attribute(false)
132            .optional()
133    }
134
135    fn custom_content(&self, content: Content) -> Option<Content> {
136        content
137            .add("size", &self.size)
138            .add("alpha", &self.alpha)
139            .add("beta", &self.beta)
140            .add("k", &self.k)
141            .optional()
142    }
143}
144
#[cfg(test)]
mod tests {
    // Unit tests for `LocalResponseNorm`: numerical correctness against reference
    // values, edge cases, input validation, autodiff, and `Display` formatting.
    use super::*;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    // --- Correctness tests (values from PyTorch, torch.manual_seed(42)) ---

    #[test]
    fn forward_default_params() {
        // size=5, alpha=0.0001, beta=0.75, k=1.0, input [1,3,4,4]
        // `FT` is the backend float element type used for approximate comparison.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007, -2.1055],
                    [0.6784, -1.2345, -0.0431, -1.6047],
                    [-0.7521, 1.6487, -0.3925, -1.4036],
                    [-0.7279, -0.5594, -0.7688, 0.7624],
                ],
                [
                    [1.6423, -0.1596, -0.4974, 0.4396],
                    [-0.7581, 1.0783, 0.8008, 1.6806],
                    [1.2791, 1.2964, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516, 0.8599],
                ],
                [
                    [-1.3847, -0.8712, -0.2234, 1.7174],
                    [0.3189, -0.4245, 0.3057, -0.7746],
                    [-1.5576, 0.9956, -0.8798, -0.6011],
                    [-1.2742, 2.1228, -1.2347, -0.4879],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [1.9267, 1.4872, 0.9007, -2.1053],
                [0.6784, -1.2345, -0.0431, -1.6045],
                [-0.7521, 1.6486, -0.3925, -1.4035],
                [-0.7279, -0.5594, -0.7688, 0.7624],
            ],
            [
                [1.6421, -0.1596, -0.4974, 0.4395],
                [-0.7581, 1.0783, 0.8008, 1.6805],
                [1.2790, 1.2963, 0.6105, 1.3347],
                [-0.2316, 0.0418, -0.2516, 0.8598],
            ],
            [
                [-1.3845, -0.8712, -0.2234, 1.7172],
                [0.3189, -0.4245, 0.3057, -0.7745],
                [-1.5575, 0.9956, -0.8798, -0.6011],
                [-1.2741, 2.1226, -1.2346, -0.4879],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    #[test]
    fn forward_custom_params() {
        // size=3, alpha=0.001, beta=0.5, k=2.0, input [1,4,3,3]
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(3)
            .with_alpha(0.001)
            .with_beta(0.5)
            .with_k(2.0)
            .init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007],
                    [-2.1055, 0.6784, -1.2345],
                    [-0.0431, -1.6047, -0.7521],
                ],
                [
                    [1.6487, -0.3925, -1.4036],
                    [-0.7279, -0.5594, -0.7688],
                    [0.7624, 1.6423, -0.1596],
                ],
                [
                    [-0.4974, 0.4396, 0.3189],
                    [-0.4245, 0.3057, -0.7746],
                    [0.0349, 0.3211, 1.5736],
                ],
                [
                    [-0.8455, -1.2742, 2.1228],
                    [-1.2347, -0.4879, -1.4181],
                    [0.8963, 0.0499, 2.2667],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [1.3618, 1.0515, 0.6368],
                [-1.4882, 0.4797, -0.8728],
                [-0.0305, -1.1342, -0.5318],
            ],
            [
                [1.1652, -0.2775, -0.9923],
                [-0.5145, -0.3955, -0.5435],
                [0.5391, 1.1608, -0.1128],
            ],
            [
                [-0.3516, 0.3108, 0.2254],
                [-0.3001, 0.2162, -0.5476],
                [0.0247, 0.2270, 1.1120],
            ],
            [
                [-0.5978, -0.9008, 1.5005],
                [-0.8729, -0.3450, -1.0025],
                [0.6337, 0.0353, 1.6018],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    #[test]
    fn forward_even_size() {
        // size=2, alpha=0.0001, beta=0.75, k=1.0, input [1,3,2,2]
        // NOTE(review): for even `size` this module puts the extra channel on the
        // positive side, whereas PyTorch's reference puts it on the negative side.
        // With alpha=1e-4 the two conventions differ far below the tolerance used
        // here, so this test does not pin the window placement. See
        // `forward_even_size_uses_asymmetric_positive_side_window` for that.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(2).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [[0.3367, 0.1288], [0.2345, 0.2303]],
                [[-1.1229, -0.1863], [2.2082, -0.6380]],
                [[0.4617, 0.2674], [0.5349, 0.8094]],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [[0.3367, 0.1288], [0.2345, 0.2303]],
            [[-1.1228, -0.1863], [2.2078, -0.6380]],
            [[0.4616, 0.2673], [0.5348, 0.8093]],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    #[test]
    fn forward_even_size_uses_asymmetric_positive_side_window() {
        // alpha=1, beta=1, k=0 reduces the formula to `x / avg(window of x^2)`.
        // The expected values can then be derived by hand and the window
        // placement is checked exactly.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(2)
            .with_alpha(1.0)
            .with_beta(1.0)
            .with_k(0.0)
            .init();
        let input =
            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[1.0], [2.0], [4.0]]]), &device);

        let output = module.forward(input);

        // For size=2, the implementation pads asymmetrically and uses:
        // c0 -> avg([1^2, 2^2]) = 2.5
        // c1 -> avg([2^2, 4^2]) = 10.0
        // c2 -> avg([4^2, 0]) = 8.0
        // so the outputs are 1/2.5, 2/10 and 4/8.
        let expected = TensorData::from([[[0.4], [0.2], [0.5]]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-5, 1e-6));
    }

    #[test]
    fn forward_3d() {
        // size=3, input [1,4,6]
        // Minimum supported rank: a single spatial dimension.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(3).init();
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([[
                [1.9269, 1.4873, 0.9007, -2.1055, 0.6784, -1.2345],
                [-0.0431, -1.6047, 0.3559, -0.6866, -0.4934, 0.2415],
                [-1.1109, 0.0915, -2.3169, -0.2168, -0.3097, -0.3957],
                [0.8034, -0.6216, -0.5920, -0.0631, -0.8286, 0.3309],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [1.9267, 1.4871, 0.9007, -2.1053, 0.6784, -1.2345],
            [-0.0431, -1.6045, 0.3558, -0.6865, -0.4933, 0.2415],
            [-1.1109, 0.0915, -2.3166, -0.2168, -0.3097, -0.3957],
            [0.8034, -0.6216, -0.5919, -0.0631, -0.8285, 0.3309],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    #[test]
    fn forward_5d() {
        // size=3, input [1,3,2,2,2]
        // Several spatial dimensions are flattened together inside `forward`.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(3).init();
        let input = Tensor::<TestBackend, 5>::from_data(
            TensorData::from([[
                [
                    [[1.9269, 1.4873], [0.9007, -2.1055]],
                    [[0.6784, -1.2345], [-0.0431, -1.6047]],
                ],
                [
                    [[0.3559, -0.6866], [-0.4934, 0.2415]],
                    [[-1.1109, 0.0915], [-2.3169, -0.2168]],
                ],
                [
                    [[-0.3097, -0.3957], [0.8034, -0.6216]],
                    [[-0.5920, -0.0631], [-0.8286, 0.3309]],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [[1.9267, 1.4872], [0.9007, -2.1053]],
                [[0.6784, -1.2345], [-0.0431, -1.6046]],
            ],
            [
                [[0.3558, -0.6866], [-0.4933, 0.2415]],
                [[-1.1108, 0.0915], [-2.3166, -0.2168]],
            ],
            [
                [[-0.3097, -0.3957], [0.8034, -0.6216]],
                [[-0.5920, -0.0631], [-0.8284, 0.3309]],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    // --- Edge case tests ---

    #[test]
    fn forward_size_1() {
        // size=1: window covers only self-channel, input [1,3,3,3]
        // Each output is therefore x / (1 + 1e-4 * x^2)^0.75.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(1).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007],
                    [-2.1055, 0.6784, -1.2345],
                    [-0.0431, -1.6047, -0.7521],
                ],
                [
                    [1.6487, -0.3925, 0.2415],
                    [-1.1109, 0.0915, -2.3169],
                    [-0.2168, -1.3847, -0.8712],
                ],
                [
                    [-0.2234, -0.6216, -0.5920],
                    [-0.0631, -0.8286, 0.3309],
                    [-1.5576, 0.9956, -0.8798],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [1.9264, 1.4870, 0.9007],
                [-2.1048, 0.6784, -1.2344],
                [-0.0431, -1.6044, -0.7521],
            ],
            [
                [1.6484, -0.3925, 0.2415],
                [-1.1108, 0.0915, -2.3160],
                [-0.2168, -1.3845, -0.8712],
            ],
            [
                [-0.2234, -0.6216, -0.5920],
                [-0.0631, -0.8285, 0.3309],
                [-1.5573, 0.9956, -0.8797],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    #[test]
    fn forward_c_less_than_size() {
        // C=2 < size=5, input [1,2,3,3]
        // The window is wider than the channel axis. The zero padding supplies
        // the missing channels, and the sum is still divided by `size`.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, -0.4974],
                    [0.4396, -0.7581, 1.0783],
                    [0.8008, 1.6806, 0.3559],
                ],
                [
                    [-0.6866, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516],
                    [0.8599, -0.3097, -0.3957],
                ],
            ]]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([[
            [
                [1.9268, 1.4872, -0.4974],
                [0.4396, -0.7581, 1.0783],
                [0.8008, 1.6805, 0.3559],
            ],
            [
                [-0.6866, 0.6104, 1.3347],
                [-0.2316, 0.0418, -0.2516],
                [0.8598, -0.3097, -0.3957],
            ],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(5e-3, 1e-4));
    }

    #[test]
    fn forward_multi_batch() {
        // N=2, size=5, input [2,3,4,4]
        // Batch 0 reuses the `forward_default_params` input, so its expected
        // values must match that test. This checks that batches are normalized
        // independently.
        type FT = FloatElem<TestBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestBackend, 4>::from_data(
            TensorData::from([
                [
                    [
                        [1.9269, 1.4873, 0.9007, -2.1055],
                        [0.6784, -1.2345, -0.0431, -1.6047],
                        [-0.7521, 1.6487, -0.3925, -1.4036],
                        [-0.7279, -0.5594, -0.7688, 0.7624],
                    ],
                    [
                        [1.6423, -0.1596, -0.4974, 0.4396],
                        [-0.7581, 1.0783, 0.8008, 1.6806],
                        [1.2791, 1.2964, 0.6105, 1.3347],
                        [-0.2316, 0.0418, -0.2516, 0.8599],
                    ],
                    [
                        [-1.3847, -0.8712, -0.2234, 1.7174],
                        [0.3189, -0.4245, 0.3057, -0.7746],
                        [-1.5576, 0.9956, -0.8798, -0.6011],
                        [-1.2742, 2.1228, -1.2347, -0.4879],
                    ],
                ],
                [
                    [
                        [-0.9138, -0.6581, 0.0780, 0.5258],
                        [-0.4880, 1.1914, -0.8140, -0.7360],
                        [-1.4032, 0.0360, -0.0635, 0.6756],
                        [-0.0978, 1.8446, -1.1845, 1.3835],
                    ],
                    [
                        [1.4451, 0.8564, 2.2181, 0.5232],
                        [0.3466, -0.1973, -1.0546, 1.2780],
                        [-0.1722, 0.5238, 0.0566, 0.4263],
                        [0.5750, -0.6417, -2.2064, -0.7508],
                    ],
                    [
                        [0.0109, -0.3387, -1.3407, -0.5854],
                        [0.5362, 0.5246, 1.1412, 0.0516],
                        [0.7440, -0.4816, -1.0495, 0.6039],
                        [-1.7223, -0.8278, 1.3347, 0.4835],
                    ],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let out_data = output.to_data();
        assert_eq!(out_data.shape, [2, 3, 4, 4].into());
        let expected_full = TensorData::from([
            [
                [
                    [1.9267, 1.4872, 0.9007, -2.1053],
                    [0.6784, -1.2345, -0.0431, -1.6045],
                    [-0.7521, 1.6486, -0.3925, -1.4035],
                    [-0.7279, -0.5594, -0.7688, 0.7624],
                ],
                [
                    [1.6421, -0.1596, -0.4974, 0.4395],
                    [-0.7581, 1.0783, 0.8008, 1.6805],
                    [1.2790, 1.2963, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516, 0.8598],
                ],
                [
                    [-1.3845, -0.8712, -0.2234, 1.7172],
                    [0.3189, -0.4245, 0.3057, -0.7745],
                    [-1.5575, 0.9956, -0.8798, -0.6011],
                    [-1.2741, 2.1226, -1.2346, -0.4879],
                ],
            ],
            [
                [
                    [-0.9138, -0.6581, 0.0780, 0.5258],
                    [-0.4880, 1.1913, -0.8140, -0.7360],
                    [-1.4032, 0.0360, -0.0635, 0.6756],
                    [-0.0978, 1.8445, -1.1844, 1.3835],
                ],
                [
                    [1.4451, 0.8564, 2.2179, 0.5232],
                    [0.3466, -0.1973, -1.0545, 1.2780],
                    [-0.1722, 0.5238, 0.0566, 0.4263],
                    [0.5750, -0.6417, -2.2061, -0.7508],
                ],
                [
                    [0.0109, -0.3387, -1.3405, -0.5854],
                    [0.5362, 0.5246, 1.1411, 0.0516],
                    [0.7439, -0.4816, -1.0494, 0.6039],
                    [-1.7222, -0.8277, 1.3345, 0.4835],
                ],
            ],
        ]);
        out_data.assert_approx_eq::<FT>(&expected_full, Tolerance::rel_abs(5e-3, 1e-4));
    }

    // --- Validation / panic tests ---

    #[test]
    #[should_panic(expected = "size must be greater than 0")]
    fn config_size_zero_panics() {
        LocalResponseNormConfig::new(0).init();
    }

    #[test]
    #[should_panic(expected = "LocalResponseNorm requires input rank >= 3")]
    fn forward_rank_2_panics() {
        let module = LocalResponseNormConfig::new(3).init();
        let input = Tensor::<TestBackend, 2>::zeros([2, 4], &Default::default());
        module.forward(input);
    }

    // --- Autodiff ---

    #[cfg(feature = "std")]
    #[test]
    fn backward() {
        // Gradient of `sum(output)` w.r.t. the input. With the default alpha=1e-4
        // the normalization is close to identity, so every gradient is close to 1.
        type FT = FloatElem<TestAutodiffBackend>;
        let device = Default::default();
        let module = LocalResponseNormConfig::new(5).init();
        let input = Tensor::<TestAutodiffBackend, 4>::from_data(
            TensorData::from([[
                [
                    [1.9269, 1.4873, 0.9007, -2.1055],
                    [0.6784, -1.2345, -0.0431, -1.6047],
                    [-0.7521, 1.6487, -0.3925, -1.4036],
                    [-0.7279, -0.5594, -0.7688, 0.7624],
                ],
                [
                    [1.6423, -0.1596, -0.4974, 0.4396],
                    [-0.7581, 1.0783, 0.8008, 1.6806],
                    [1.2791, 1.2964, 0.6105, 1.3347],
                    [-0.2316, 0.0418, -0.2516, 0.8599],
                ],
                [
                    [-1.3847, -0.8712, -0.2234, 1.7174],
                    [0.3189, -0.4245, 0.3057, -0.7746],
                    [-1.5576, 0.9956, -0.8798, -0.6011],
                    [-1.2742, 2.1228, -1.2347, -0.4879],
                ],
            ]]),
            &device,
        )
        .require_grad();

        let output = module.forward(input.clone());
        let grads = output.sum().backward();
        let input_grad = input.grad(&grads).unwrap();

        assert_eq!(input_grad.dims(), [1, 3, 4, 4]);

        let expected_grad = TensorData::from([[
            [
                [0.9997, 0.9999, 1.0000, 0.9999],
                [1.0000, 0.9999, 1.0000, 0.9999],
                [0.9999, 0.9997, 1.0000, 0.9999],
                [0.9999, 1.0000, 0.9999, 1.0000],
            ],
            [
                [0.9998, 1.0000, 1.0000, 0.9999],
                [1.0000, 1.0000, 1.0000, 0.9999],
                [1.0000, 0.9998, 1.0000, 1.0000],
                [1.0000, 0.9999, 1.0000, 0.9999],
            ],
            [
                [1.0000, 1.0000, 1.0000, 0.9999],
                [1.0000, 1.0000, 1.0000, 0.9999],
                [0.9999, 0.9998, 1.0000, 0.9999],
                [0.9999, 0.9998, 0.9999, 1.0000],
            ],
        ]]);
        input_grad
            .to_data()
            .assert_approx_eq::<FT>(&expected_grad, Tolerance::rel_abs(5e-3, 1e-4));
    }

    // --- Display ---

    #[test]
    fn display() {
        // All attributes on one line; the f64 default `k = 1.0` is rendered as `1`.
        let config = LocalResponseNormConfig::new(5);
        let module = config.init();
        assert_eq!(
            format!("{module}"),
            "LocalResponseNorm {size: 5, alpha: 0.0001, beta: 0.75, k: 1}"
        );
    }
}