burn_core/nn/norm/instance.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, ModuleDisplay};
5use crate::module::{Module, Param};
6use crate::nn::Initializer;
7use crate::nn::norm::group_norm;
8use crate::tensor::{Tensor, backend::Backend};
9
10/// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).
11#[derive(Debug, Config)]
12pub struct InstanceNormConfig {
13    /// The number of channels expected in the input
14    pub num_channels: usize,
15    /// A value required for numerical stability. Default: 1e-5
16    #[config(default = 1e-5)]
17    pub epsilon: f64,
18    /// A boolean value that when set to `true`, this module has learnable
19    /// per-channel affine parameters initialized to ones (for weights)
20    /// and zeros (for biases). Default: `true`
21    #[config(default = true)]
22    pub affine: bool,
23}
24
25/// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022)
26///
27/// Should be created using [InstanceNormConfig](InstanceNormConfig).
28#[derive(Module, Debug)]
29#[module(custom_display)]
30pub struct InstanceNorm<B: Backend> {
31    /// The learnable weight
32    pub gamma: Option<Param<Tensor<B, 1>>>,
33    /// The learnable bias
34    pub beta: Option<Param<Tensor<B, 1>>>,
35    /// The number of channels expected in the input
36    pub num_channels: usize,
37    /// A value required for numerical stability
38    pub epsilon: f64,
39    /// A boolean value that when set to `true`, this module has learnable
40    pub affine: bool,
41}
42
43impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
44    fn custom_settings(&self) -> Option<DisplaySettings> {
45        DisplaySettings::new()
46            .with_new_line_after_attribute(false)
47            .optional()
48    }
49
50    fn custom_content(&self, content: Content) -> Option<Content> {
51        content
52            .add("num_channels", &self.num_channels)
53            .add("epsilon", &self.epsilon)
54            .add("affine", &self.affine)
55            .optional()
56    }
57}
58
59impl InstanceNormConfig {
60    /// Initialize a new [instance norm](InstanceNorm) module.
61    pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
62        let (gamma, beta) = if self.affine {
63            let gamma = Initializer::Ones.init([self.num_channels], device);
64            let beta = Initializer::Zeros.init([self.num_channels], device);
65
66            (Some(gamma), Some(beta))
67        } else {
68            (None, None)
69        };
70
71        InstanceNorm {
72            gamma,
73            beta,
74            num_channels: self.num_channels,
75            epsilon: self.epsilon,
76            affine: self.affine,
77        }
78    }
79}
80
81impl<B: Backend> InstanceNorm<B> {
82    /// Applies the forward pass on the input tensor.
83    ///
84    /// See also [InstanceNormConfig](InstanceNormConfig) for more information.
85    ///
86    /// # Shapes
87    ///
88    /// - input: `[batch_size, num_channels, *]`
89    /// - output: `[batch_size, num_channels, *]`
90    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
91        // Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.
92        let num_groups = self.num_channels;
93
94        let gamma = self.gamma.as_ref().map(|x| x.val());
95        let beta = self.beta.as_ref().map(|x| x.val());
96
97        group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Builds a 6-channel instance norm with the given `affine` flag, runs it on `input`
    /// (a rank-3 tensor) and checks the result against `expected` with default tolerance.
    fn assert_forward(affine: bool, input: TensorData, expected: TensorData) {
        let device = Default::default();
        let norm = InstanceNormConfig::new(6)
            .with_affine(affine)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::from_data(input, &device);

        let actual = norm.forward(input).to_data();
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn instance_norm_forward_affine_false() {
        let input = TensorData::from([
            [
                [-0.3034, 0.2726, -0.9659],
                [-1.1845, 1.4078, 0.9774],
                [0.3963, -1.3738, 1.4125],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]);
        let expected = TensorData::from([
            [
                [0.0569, 1.1952, -1.2522],
                [-1.3971, 0.8883, 0.5088],
                [0.2183, -1.3192, 1.1009],
                [1.4126, -0.7649, -0.6477],
                [0.5999, 0.8091, -1.409],
                [1.39, -0.4696, -0.9205],
            ],
            [
                [-1.3492, 1.0417, 0.3075],
                [1.411, -0.6243, -0.7867],
                [-0.9363, 1.386, -0.4497],
                [-1.3899, 0.4692, 0.9208],
                [-1.3822, 0.4319, 0.9503],
                [-1.3714, 0.3868, 0.9846],
            ],
        ]);

        assert_forward(false, input, expected);
    }

    #[test]
    fn instance_norm_forward_affine_true() {
        let input = TensorData::from([
            [
                [0.3345, 0.4429, 0.6639],
                [0.5041, 0.4175, 0.8437],
                [0.6159, 0.3758, 0.4071],
                [0.5417, 0.5785, 0.7671],
                [0.3837, 0.9883, 0.0420],
                [0.4808, 0.8989, 0.6144],
            ],
            [
                [0.3930, 0.2098, 0.0602],
                [0.2298, 0.9425, 0.0333],
                [0.7409, 0.8172, 0.8879],
                [0.4846, 0.0486, 0.2029],
                [0.6741, 0.9765, 0.6864],
                [0.2827, 0.5534, 0.2125],
            ],
        ]);
        let expected = TensorData::from([
            [
                [-1.06458, -0.2738, 1.33838],
                [-0.45848, -0.92929, 1.38777],
                [1.40388, -0.84877, -0.55511],
                [-0.88515, -0.51245, 1.3976],
                [-0.22397, 1.32124, -1.09727],
                [-1.05468, 1.34316, -0.28848],
            ],
            [
                [1.26372, -0.08229, -1.18144],
                [-0.44049, 1.38403, -0.94354],
                [-1.23828, 0.03109, 1.2072],
                [1.32524, -1.08999, -0.23524],
                [-0.75061, 1.4132, -0.66259],
                [-0.45469, 1.38697, -0.93228],
            ],
        ]);

        assert_forward(true, input, expected);
    }

    #[test]
    fn display() {
        let instance_norm = InstanceNormConfig::new(6).init::<TestBackend>(&Default::default());
        let rendered = format!("{instance_norm}");

        assert_eq!(
            rendered,
            "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}