1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, ModuleDisplay};
5use crate::module::{Module, Param};
6use crate::nn::Initializer;
7use crate::nn::norm::group_norm;
8use crate::tensor::{Tensor, backend::Backend};
9
/// Configuration used to build an [`InstanceNorm`] layer via [`InstanceNormConfig::init`].
#[derive(Debug, Config)]
pub struct InstanceNormConfig {
    /// Number of channels in the input. It is the length of the `gamma`/`beta`
    /// parameters and is also used as the group count passed to `group_norm`.
    pub num_channels: usize,
    /// Small constant forwarded to `group_norm`, presumably added to the variance for
    /// numerical stability. Default: `1e-5`.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// When `true`, `init` creates learnable `gamma` (ones) and `beta` (zeros)
    /// parameters. Default: `true`.
    #[config(default = true)]
    pub affine: bool,
}
24
/// Instance normalization layer.
///
/// It is implemented as group normalization with one group per channel (see `forward`).
/// Build it with [`InstanceNormConfig::init`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct InstanceNorm<B: Backend> {
    /// Optional learnable per-channel weight of length `num_channels`.
    /// `InstanceNormConfig::init` initializes it to ones, or leaves it `None` when `affine` is false.
    pub gamma: Option<Param<Tensor<B, 1>>>,
    /// Optional learnable per-channel bias of length `num_channels`.
    /// `InstanceNormConfig::init` initializes it to zeros, or leaves it `None` when `affine` is false.
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// Number of channels; passed to `group_norm` as the number of groups.
    pub num_channels: usize,
    /// Numerical-stability constant forwarded to `group_norm`.
    pub epsilon: f64,
    /// Whether `gamma`/`beta` are in use; forwarded to `group_norm`.
    pub affine: bool,
}
42
43impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
44 fn custom_settings(&self) -> Option<DisplaySettings> {
45 DisplaySettings::new()
46 .with_new_line_after_attribute(false)
47 .optional()
48 }
49
50 fn custom_content(&self, content: Content) -> Option<Content> {
51 content
52 .add("num_channels", &self.num_channels)
53 .add("epsilon", &self.epsilon)
54 .add("affine", &self.affine)
55 .optional()
56 }
57}
58
59impl InstanceNormConfig {
60 pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
62 let (gamma, beta) = if self.affine {
63 let gamma = Initializer::Ones.init([self.num_channels], device);
64 let beta = Initializer::Zeros.init([self.num_channels], device);
65
66 (Some(gamma), Some(beta))
67 } else {
68 (None, None)
69 };
70
71 InstanceNorm {
72 gamma,
73 beta,
74 num_channels: self.num_channels,
75 epsilon: self.epsilon,
76 affine: self.affine,
77 }
78 }
79}
80
81impl<B: Backend> InstanceNorm<B> {
82 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
91 let num_groups = self.num_channels;
93
94 let gamma = self.gamma.as_ref().map(|x| x.val());
95 let beta = self.beta.as_ref().map(|x| x.val());
96
97 group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
98 }
99}
100
// Unit tests for `InstanceNorm`: the forward pass with and without affine parameters,
// and the custom `Display` output.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // With `affine` disabled, no gamma/beta exist and the output is the plain normalization.
    #[test]
    fn instance_norm_forward_affine_false() {
        let device = Default::default();
        let module = InstanceNormConfig::new(6)
            .with_affine(false)
            .init::<TestBackend>(&device);

        // Input of shape [2, 6, 3]; the middle dimension matches `num_channels = 6`.
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [-0.3034, 0.2726, -0.9659],
                    [-1.1845, 1.4078, 0.9774],
                    [0.3963, -1.3738, 1.4125],
                    [1.0682, 0.3604, 0.3985],
                    [-0.4957, -0.4461, -0.9721],
                    [1.5157, -0.1546, -0.5596],
                ],
                [
                    [-1.6698, -0.4040, -0.7927],
                    [0.3736, -0.0975, -0.1351],
                    [-0.9461, 0.5461, -0.6334],
                    [-1.0919, -0.1158, 0.1213],
                    [-0.9535, 0.1281, 0.4372],
                    [-0.2845, 0.3488, 0.5641],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        // Each innermost row (one channel of one sample) is normalized independently.
        let expected = TensorData::from([
            [
                [0.0569, 1.1952, -1.2522],
                [-1.3971, 0.8883, 0.5088],
                [0.2183, -1.3192, 1.1009],
                [1.4126, -0.7649, -0.6477],
                [0.5999, 0.8091, -1.409],
                [1.39, -0.4696, -0.9205],
            ],
            [
                [-1.3492, 1.0417, 0.3075],
                [1.411, -0.6243, -0.7867],
                [-0.9363, 1.386, -0.4497],
                [-1.3899, 0.4692, 0.9208],
                [-1.3822, 0.4319, 0.9503],
                [-1.3714, 0.3868, 0.9846],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // With `affine` enabled, gamma starts at ones and beta at zeros (see `init`), so the
    // expected output is still the plain per-row normalization.
    #[test]
    fn instance_norm_forward_affine_true() {
        let device = Default::default();
        let module = InstanceNormConfig::new(6)
            .with_affine(true)
            .init::<TestBackend>(&device);

        // Input of shape [2, 6, 3]; the middle dimension matches `num_channels = 6`.
        let input = Tensor::<TestBackend, 3>::from_data(
            TensorData::from([
                [
                    [0.3345, 0.4429, 0.6639],
                    [0.5041, 0.4175, 0.8437],
                    [0.6159, 0.3758, 0.4071],
                    [0.5417, 0.5785, 0.7671],
                    [0.3837, 0.9883, 0.0420],
                    [0.4808, 0.8989, 0.6144],
                ],
                [
                    [0.3930, 0.2098, 0.0602],
                    [0.2298, 0.9425, 0.0333],
                    [0.7409, 0.8172, 0.8879],
                    [0.4846, 0.0486, 0.2029],
                    [0.6741, 0.9765, 0.6864],
                    [0.2827, 0.5534, 0.2125],
                ],
            ]),
            &device,
        );

        let output = module.forward(input);

        let expected = TensorData::from([
            [
                [-1.06458, -0.2738, 1.33838],
                [-0.45848, -0.92929, 1.38777],
                [1.40388, -0.84877, -0.55511],
                [-0.88515, -0.51245, 1.3976],
                [-0.22397, 1.32124, -1.09727],
                [-1.05468, 1.34316, -0.28848],
            ],
            [
                [1.26372, -0.08229, -1.18144],
                [-0.44049, 1.38403, -0.94354],
                [-1.23828, 0.03109, 1.2072],
                [1.32524, -1.08999, -0.23524],
                [-0.75061, 1.4132, -0.66259],
                [-0.45469, 1.38697, -0.93228],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    // `affine` defaults to true, so the parameter count is gamma (6) + beta (6) = 12.
    #[test]
    fn display() {
        let config = InstanceNormConfig::new(6);
        let instance_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{instance_norm}"),
            "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}