1use alloc::format;
2
3use burn_core as burn;
4
5use crate::PaddingConfig2d;
6use burn::config::Config;
7use burn::module::Initializer;
8use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11use burn::tensor::module::conv2d;
12use burn::tensor::ops::PaddedConvOptions;
13
14use crate::conv::checks;
15
/// Configuration to create a [2D convolution](Conv2d) layer, initialized with
/// [Conv2dConfig::init].
#[derive(Config, Debug)]
pub struct Conv2dConfig {
    /// Channel counts as `[channels_in, channels_out]`.
    pub channels: [usize; 2],
    /// Kernel size as `[kernel_h, kernel_w]`.
    pub kernel_size: [usize; 2],
    /// Stride of the convolution. Default: `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub stride: [usize; 2],
    /// Dilation of the convolution. Default: `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
    /// Number of groups; both channel counts must be divisible by it. Default: `1`.
    #[config(default = "1")]
    pub groups: usize,
    /// Padding configuration. Default: [PaddingConfig2d::Valid].
    #[config(default = "PaddingConfig2d::Valid")]
    pub padding: PaddingConfig2d,
    /// Whether the layer has a learnable bias. Default: `true`.
    #[config(default = true)]
    pub bias: bool,
    /// Initializer for the weight (and bias, if enabled).
    /// Default: Kaiming uniform with gain `1/sqrt(3)` using fan-in.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
47
/// A 2D convolution module, created from a [Conv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv2d<B: Backend> {
    /// Learnable kernel, shape `[channels_out, channels_in / groups, kernel_h, kernel_w]`.
    pub weight: Param<Tensor<B, 4>>,
    /// Optional learnable bias, shape `[channels_out]`.
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// Stride as `[stride_h, stride_w]`.
    pub stride: [usize; 2],
    /// Kernel size as `[kernel_h, kernel_w]`.
    pub kernel_size: [usize; 2],
    /// Dilation as `[dilation_h, dilation_w]`.
    pub dilation: [usize; 2],
    /// Number of channel groups.
    pub groups: usize,
    /// Padding configuration applied in [Conv2d::forward].
    pub padding: PaddingConfig2d,
}
69
70impl Conv2dConfig {
71 pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
73 checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
74
75 let shape = [
76 self.channels[1],
77 self.channels[0] / self.groups,
78 self.kernel_size[0],
79 self.kernel_size[1],
80 ];
81
82 let k = self.kernel_size.iter().product::<usize>();
83 let fan_in = self.channels[0] / self.groups * k;
84 let fan_out = self.channels[1] / self.groups * k;
85
86 let weight = self
87 .initializer
88 .init_with(shape, Some(fan_in), Some(fan_out), device);
89 let mut bias = None;
90
91 if self.bias {
92 bias = Some(self.initializer.init_with(
93 [self.channels[1]],
94 Some(fan_in),
95 Some(fan_out),
96 device,
97 ));
98 }
99
100 Conv2d {
101 weight,
102 bias,
103 stride: self.stride,
104 kernel_size: self.kernel_size,
105 dilation: self.dilation,
106 padding: self.padding.clone(),
107 groups: self.groups,
108 }
109 }
110}
111
112impl<B: Backend> ModuleDisplay for Conv2d<B> {
113 fn custom_settings(&self) -> Option<DisplaySettings> {
114 DisplaySettings::new()
115 .with_new_line_after_attribute(false)
116 .optional()
117 }
118
119 fn custom_content(&self, content: Content) -> Option<Content> {
120 let stride = format!("{:?}", self.stride);
122 let kernel_size = format!("{:?}", self.kernel_size);
123 let dilation = format!("{:?}", self.dilation);
124 let [channels_out, group_channels_in, _, _] = self.weight.dims();
125 let channels_in = group_channels_in * self.groups;
126 let ch_out = format!("{:?}", channels_out);
127 let ch_in = format!("{:?}", channels_in);
128 content
129 .add("ch_in", &ch_in)
130 .add("ch_out", &ch_out)
131 .add("stride", &stride)
132 .add("kernel_size", &kernel_size)
133 .add("dilation", &dilation)
134 .add("groups", &self.groups)
135 .add_debug_attribute("padding", &self.padding)
136 .optional()
137 }
138}
139
140impl<B: Backend> Conv2d<B> {
141 pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
164 let [_batch_size, _channels_in, height_in, width_in] = input.dims();
165
166 let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
168 height_in,
169 width_in,
170 &self.kernel_size,
171 &self.stride,
172 );
173
174 let options = PaddedConvOptions::asymmetric(
175 self.stride,
176 [top, left],
177 [bottom, right],
178 self.dilation,
179 self.groups,
180 );
181
182 conv2d(
183 input,
184 self.weight.val(),
185 self.bias.as_ref().map(|bias| bias.val()),
186 options,
187 )
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{ElementConversion, Tolerance};

    type FT = FloatElem<TestBackend>;

    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv2dConfig::new([5, 1], [5, 5]);
        // Bound of the default Kaiming-uniform distribution:
        // sqrt(groups / (channels_in * kernel_h * kernel_w)).
        let fan = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
        let bound = (config.groups as f64 / fan).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        conv.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<FT, _>(conv.weight.shape()),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_fan_out() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    #[test]
    fn initializer_fan_with_groups_is_valid() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = Conv2dConfig::new([4, 4], [1, 1])
            .with_initializer(init.clone())
            .with_groups(4);
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    #[test]
    #[should_panic = "Both channels must be divisible by the number of groups."]
    fn channels_with_groups_is_invalid() {
        let device = Default::default();
        // A single input channel cannot be split across four groups.
        let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);
        let _ = config.init::<TestBackend>(&device);
    }

    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let config = Conv2dConfig::new([4, 4], [2, 2])
            .with_padding(PaddingConfig2d::Same)
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 4>::ones([1, 4, 5, 5], &device);
        let output = conv.forward(input);

        // "Same" padding must preserve the 5x5 spatial size even with an even kernel.
        assert_eq!(output.dims(), [1, 4, 5, 5]);
    }

    #[test]
    fn display() {
        let config = Conv2dConfig::new([5, 1], [5, 5]);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{conv}"),
            "Conv2d {ch_in: 5, ch_out: 1, stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
        );
    }

    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = Conv2dConfig::new([5, 3], [3, 3]);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
        let _ = conv.forward(input);
    }

    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        let config = Conv2dConfig::new([2, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let output = conv.forward(input);

        assert_eq!(output.dims(), [1, 3, 6, 9]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        let config = Conv2dConfig::new([2, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let output = conv.forward(input);

        assert_eq!(output.dims(), [1, 3, 6, 7]);
    }
}