burn_core/nn/norm/
instance.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, ModuleDisplay};
5use crate::module::{Module, Param};
6use crate::nn::norm::group_norm;
7use crate::nn::Initializer;
8use crate::tensor::{backend::Backend, Tensor};
9
10/// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).
11#[derive(Debug, Config)]
12pub struct InstanceNormConfig {
13    /// The number of channels expected in the input
14    pub num_channels: usize,
15    /// A value required for numerical stability. Default: 1e-5
16    #[config(default = 1e-5)]
17    pub epsilon: f64,
18    /// A boolean value that when set to `true`, this module has learnable
19    /// per-channel affine parameters initialized to ones (for weights)
20    /// and zeros (for biases). Default: `true`
21    #[config(default = true)]
22    pub affine: bool,
23}
24
25/// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022)
26///
27/// Should be created using [InstanceNormConfig](InstanceNormConfig).
28#[derive(Module, Debug)]
29#[module(custom_display)]
30pub struct InstanceNorm<B: Backend> {
31    /// The learnable weight
32    pub gamma: Option<Param<Tensor<B, 1>>>,
33    /// The learnable bias
34    pub beta: Option<Param<Tensor<B, 1>>>,
35    /// The number of channels expected in the input
36    pub num_channels: usize,
37    /// A value required for numerical stability
38    pub epsilon: f64,
39    /// A boolean value that when set to `true`, this module has learnable
40    pub affine: bool,
41}
42
43impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
44    fn custom_settings(&self) -> Option<DisplaySettings> {
45        DisplaySettings::new()
46            .with_new_line_after_attribute(false)
47            .optional()
48    }
49
50    fn custom_content(&self, content: Content) -> Option<Content> {
51        content
52            .add("num_channels", &self.num_channels)
53            .add("epsilon", &self.epsilon)
54            .add("affine", &self.affine)
55            .optional()
56    }
57}
58
59impl InstanceNormConfig {
60    /// Initialize a new [instance norm](InstanceNorm) module.
61    pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
62        let (gamma, beta) = if self.affine {
63            let gamma = Initializer::Ones.init([self.num_channels], device);
64            let beta = Initializer::Zeros.init([self.num_channels], device);
65
66            (Some(gamma), Some(beta))
67        } else {
68            (None, None)
69        };
70
71        InstanceNorm {
72            gamma,
73            beta,
74            num_channels: self.num_channels,
75            epsilon: self.epsilon,
76            affine: self.affine,
77        }
78    }
79}
80
81impl<B: Backend> InstanceNorm<B> {
82    /// Applies the forward pass on the input tensor.
83    ///
84    /// See also [InstanceNormConfig](InstanceNormConfig) for more information.
85    ///
86    /// # Shapes
87    ///
88    /// - input: `[batch_size, num_channels, *]`
89    /// - output: `[batch_size, num_channels, *]`
90    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
91        // Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.
92        let num_groups = self.num_channels;
93
94        let gamma = self.gamma.as_ref().map(|x| x.val());
95        let beta = self.beta.as_ref().map(|x| x.val());
96
97        group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    use alloc::format;

    /// Forward pass of a 6-channel layer created with `affine` disabled.
    #[test]
    fn instance_norm_forward_affine_false() {
        let device = Default::default();
        let config = InstanceNormConfig::new(6).with_affine(false);
        let layer = config.init::<TestBackend>(&device);

        let input_data = TensorData::from([
            [
                [-0.3034, 0.2726, -0.9659],
                [-1.1845, 1.4078, 0.9774],
                [0.3963, -1.3738, 1.4125],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]);
        let input = Tensor::<TestBackend, 3>::from_data(input_data, &device);

        let expected = TensorData::from([
            [
                [0.0569, 1.1952, -1.2522],
                [-1.3971, 0.8883, 0.5088],
                [0.2183, -1.3192, 1.1009],
                [1.4126, -0.7649, -0.6477],
                [0.5999, 0.8091, -1.409],
                [1.39, -0.4696, -0.9205],
            ],
            [
                [-1.3492, 1.0417, 0.3075],
                [1.411, -0.6243, -0.7867],
                [-0.9363, 1.386, -0.4497],
                [-1.3899, 0.4692, 0.9208],
                [-1.3822, 0.4319, 0.9503],
                [-1.3714, 0.3868, 0.9846],
            ],
        ]);

        let actual = layer.forward(input).to_data();
        actual.assert_approx_eq(&expected, 3);
    }

    /// Forward pass of a 6-channel layer created with `affine` enabled.
    #[test]
    fn instance_norm_forward_affine_true() {
        let device = Default::default();
        let config = InstanceNormConfig::new(6).with_affine(true);
        let layer = config.init::<TestBackend>(&device);

        let input_data = TensorData::from([
            [
                [0.3345, 0.4429, 0.6639],
                [0.5041, 0.4175, 0.8437],
                [0.6159, 0.3758, 0.4071],
                [0.5417, 0.5785, 0.7671],
                [0.3837, 0.9883, 0.0420],
                [0.4808, 0.8989, 0.6144],
            ],
            [
                [0.3930, 0.2098, 0.0602],
                [0.2298, 0.9425, 0.0333],
                [0.7409, 0.8172, 0.8879],
                [0.4846, 0.0486, 0.2029],
                [0.6741, 0.9765, 0.6864],
                [0.2827, 0.5534, 0.2125],
            ],
        ]);
        let input = Tensor::<TestBackend, 3>::from_data(input_data, &device);

        let expected = TensorData::from([
            [
                [-1.06458, -0.2738, 1.33838],
                [-0.45848, -0.92929, 1.38777],
                [1.40388, -0.84877, -0.55511],
                [-0.88515, -0.51245, 1.3976],
                [-0.22397, 1.32124, -1.09727],
                [-1.05468, 1.34316, -0.28848],
            ],
            [
                [1.26372, -0.08229, -1.18144],
                [-0.44049, 1.38403, -0.94354],
                [-1.23979, 0.03109, 1.2087],
                [1.32524, -1.08999, -0.23524],
                [-0.75061, 1.4132, -0.66259],
                [-0.45469, 1.38697, -0.93228],
            ],
        ]);

        let actual = layer.forward(input).to_data();
        actual.assert_approx_eq(&expected, 3);
    }

    /// The custom display lists the configured attributes and the parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let layer = InstanceNormConfig::new(6).init::<TestBackend>(&device);

        let rendered = format!("{layer}");

        assert_eq!(
            rendered,
            "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}