1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, ModuleDisplay};
5use crate::module::{Module, Param};
6use crate::nn::norm::group_norm;
7use crate::nn::Initializer;
8use crate::tensor::{backend::Backend, Tensor};
9
10#[derive(Debug, Config)]
12pub struct InstanceNormConfig {
13 pub num_channels: usize,
15 #[config(default = 1e-5)]
17 pub epsilon: f64,
18 #[config(default = true)]
22 pub affine: bool,
23}
24
25#[derive(Module, Debug)]
29#[module(custom_display)]
30pub struct InstanceNorm<B: Backend> {
31 pub gamma: Option<Param<Tensor<B, 1>>>,
33 pub beta: Option<Param<Tensor<B, 1>>>,
35 pub num_channels: usize,
37 pub epsilon: f64,
39 pub affine: bool,
41}
42
43impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
44 fn custom_settings(&self) -> Option<DisplaySettings> {
45 DisplaySettings::new()
46 .with_new_line_after_attribute(false)
47 .optional()
48 }
49
50 fn custom_content(&self, content: Content) -> Option<Content> {
51 content
52 .add("num_channels", &self.num_channels)
53 .add("epsilon", &self.epsilon)
54 .add("affine", &self.affine)
55 .optional()
56 }
57}
58
59impl InstanceNormConfig {
60 pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {
62 let (gamma, beta) = if self.affine {
63 let gamma = Initializer::Ones.init([self.num_channels], device);
64 let beta = Initializer::Zeros.init([self.num_channels], device);
65
66 (Some(gamma), Some(beta))
67 } else {
68 (None, None)
69 };
70
71 InstanceNorm {
72 gamma,
73 beta,
74 num_channels: self.num_channels,
75 epsilon: self.epsilon,
76 affine: self.affine,
77 }
78 }
79}
80
81impl<B: Backend> InstanceNorm<B> {
82 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
91 let num_groups = self.num_channels;
93
94 let gamma = self.gamma.as_ref().map(|x| x.val());
95 let beta = self.beta.as_ref().map(|x| x.val());
96
97 group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    use alloc::format;

    /// Builds a six-channel instance norm with the given `affine` flag, runs it on
    /// `input` (a `[2, 6, 3]` tensor) and compares the result with `expected`
    /// to 3 decimal places.
    fn check_forward(affine: bool, input: TensorData, expected: TensorData) {
        let device = Default::default();
        let module = InstanceNormConfig::new(6)
            .with_affine(affine)
            .init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::from_data(input, &device);
        let output = module.forward(input);

        output.to_data().assert_approx_eq(&expected, 3);
    }

    #[test]
    fn instance_norm_forward_affine_false() {
        check_forward(
            false,
            TensorData::from([
                [
                    [-0.3034, 0.2726, -0.9659],
                    [-1.1845, 1.4078, 0.9774],
                    [0.3963, -1.3738, 1.4125],
                    [1.0682, 0.3604, 0.3985],
                    [-0.4957, -0.4461, -0.9721],
                    [1.5157, -0.1546, -0.5596],
                ],
                [
                    [-1.6698, -0.4040, -0.7927],
                    [0.3736, -0.0975, -0.1351],
                    [-0.9461, 0.5461, -0.6334],
                    [-1.0919, -0.1158, 0.1213],
                    [-0.9535, 0.1281, 0.4372],
                    [-0.2845, 0.3488, 0.5641],
                ],
            ]),
            TensorData::from([
                [
                    [0.0569, 1.1952, -1.2522],
                    [-1.3971, 0.8883, 0.5088],
                    [0.2183, -1.3192, 1.1009],
                    [1.4126, -0.7649, -0.6477],
                    [0.5999, 0.8091, -1.409],
                    [1.39, -0.4696, -0.9205],
                ],
                [
                    [-1.3492, 1.0417, 0.3075],
                    [1.411, -0.6243, -0.7867],
                    [-0.9363, 1.386, -0.4497],
                    [-1.3899, 0.4692, 0.9208],
                    [-1.3822, 0.4319, 0.9503],
                    [-1.3714, 0.3868, 0.9846],
                ],
            ]),
        );
    }

    #[test]
    fn instance_norm_forward_affine_true() {
        check_forward(
            true,
            TensorData::from([
                [
                    [0.3345, 0.4429, 0.6639],
                    [0.5041, 0.4175, 0.8437],
                    [0.6159, 0.3758, 0.4071],
                    [0.5417, 0.5785, 0.7671],
                    [0.3837, 0.9883, 0.0420],
                    [0.4808, 0.8989, 0.6144],
                ],
                [
                    [0.3930, 0.2098, 0.0602],
                    [0.2298, 0.9425, 0.0333],
                    [0.7409, 0.8172, 0.8879],
                    [0.4846, 0.0486, 0.2029],
                    [0.6741, 0.9765, 0.6864],
                    [0.2827, 0.5534, 0.2125],
                ],
            ]),
            TensorData::from([
                [
                    [-1.06458, -0.2738, 1.33838],
                    [-0.45848, -0.92929, 1.38777],
                    [1.40388, -0.84877, -0.55511],
                    [-0.88515, -0.51245, 1.3976],
                    [-0.22397, 1.32124, -1.09727],
                    [-1.05468, 1.34316, -0.28848],
                ],
                [
                    [1.26372, -0.08229, -1.18144],
                    [-0.44049, 1.38403, -0.94354],
                    [-1.23979, 0.03109, 1.2087],
                    [1.32524, -1.08999, -0.23524],
                    [-0.75061, 1.4132, -0.66259],
                    [-0.45469, 1.38697, -0.93228],
                ],
            ]),
        );
    }

    #[test]
    fn display() {
        let instance_norm = InstanceNormConfig::new(6).init::<TestBackend>(&Default::default());
        let rendered = format!("{}", instance_norm);

        assert_eq!(
            rendered,
            "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
        );
    }
}