// burn_nn/modules/conv/conv2d.rs

1use alloc::format;
2
3use burn_core as burn;
4
5use crate::PaddingConfig2d;
6use burn::config::Config;
7use burn::module::Initializer;
8use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11use burn::tensor::module::conv2d;
12use burn::tensor::ops::PaddedConvOptions;
13
14use crate::conv::checks;
15
/// Configuration to create a [2D convolution](Conv2d) layer, using the [init function](Conv2dConfig::init).
#[derive(Config, Debug)]
pub struct Conv2dConfig {
    /// The number of channels, as `[channels_in, channels_out]`.
    pub channels: [usize; 2],
    /// The size of the kernel, as `[kernel_height, kernel_width]`.
    pub kernel_size: [usize; 2],
    /// The stride of the convolution. Defaults to `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub stride: [usize; 2],
    /// Spacing between kernel elements. Defaults to `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
    /// Controls the connections between input and output channels.
    ///
    /// Both channel counts must be divisible by `groups` (validated in
    /// [Conv2dConfig::init]). Defaults to `1`.
    #[config(default = "1")]
    pub groups: usize,
    /// The padding configuration. Defaults to [PaddingConfig2d::Valid].
    ///
    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
    /// will automatically use asymmetric padding to preserve input dimensions.
    #[config(default = "PaddingConfig2d::Valid")]
    pub padding: PaddingConfig2d,
    /// If bias should be added to the output. Defaults to `true`.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters.
    ///
    /// Defaults to a Kaiming uniform initializer with `gain = 1 / sqrt(3)`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
47
/// Applies a 2D convolution over input tensors.
///
/// Should be created with [Conv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv2d<B: Backend> {
    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
    pub weight: Param<Tensor<B, 4>>,
    /// Tensor of shape `[channels_out]`. `None` when the config was built with `bias = false`.
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// Stride of the convolution.
    pub stride: [usize; 2],
    /// Size of the kernel.
    pub kernel_size: [usize; 2],
    /// Spacing between kernel elements.
    pub dilation: [usize; 2],
    /// Controls the connections between input and output channels.
    pub groups: usize,
    /// The padding configuration.
    pub padding: PaddingConfig2d,
}
69
70impl Conv2dConfig {
71    /// Initialize a new [conv2d](Conv2d) module.
72    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {
73        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);
74
75        let shape = [
76            self.channels[1],
77            self.channels[0] / self.groups,
78            self.kernel_size[0],
79            self.kernel_size[1],
80        ];
81
82        let k = self.kernel_size.iter().product::<usize>();
83        let fan_in = self.channels[0] / self.groups * k;
84        let fan_out = self.channels[1] / self.groups * k;
85
86        let weight = self
87            .initializer
88            .init_with(shape, Some(fan_in), Some(fan_out), device);
89        let mut bias = None;
90
91        if self.bias {
92            bias = Some(self.initializer.init_with(
93                [self.channels[1]],
94                Some(fan_in),
95                Some(fan_out),
96                device,
97            ));
98        }
99
100        Conv2d {
101            weight,
102            bias,
103            stride: self.stride,
104            kernel_size: self.kernel_size,
105            dilation: self.dilation,
106            padding: self.padding.clone(),
107            groups: self.groups,
108        }
109    }
110}
111
112impl<B: Backend> ModuleDisplay for Conv2d<B> {
113    fn custom_settings(&self) -> Option<DisplaySettings> {
114        DisplaySettings::new()
115            .with_new_line_after_attribute(false)
116            .optional()
117    }
118
119    fn custom_content(&self, content: Content) -> Option<Content> {
120        // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
121        let stride = format!("{:?}", self.stride);
122        let kernel_size = format!("{:?}", self.kernel_size);
123        let dilation = format!("{:?}", self.dilation);
124        let [channels_out, group_channels_in, _, _] = self.weight.dims();
125        let channels_in = group_channels_in * self.groups;
126        let ch_out = format!("{:?}", channels_out);
127        let ch_in = format!("{:?}", channels_in);
128        content
129            .add("ch_in", &ch_in)
130            .add("ch_out", &ch_out)
131            .add("stride", &stride)
132            .add("kernel_size", &kernel_size)
133            .add("dilation", &dilation)
134            .add("groups", &self.groups)
135            .add_debug_attribute("padding", &self.padding)
136            .optional()
137    }
138}
139
140impl<B: Backend> Conv2d<B> {
141    /// Applies the forward pass on the input tensor.
142    ///
143    /// See [conv2d](burn::tensor::module::conv2d) for more information.
144    ///
145    /// # Shapes
146    /// - `input`: `[batch_size, channels_in, height_in, width_in]`
147    /// - `output`: `[batch_size, channels_out, height_out, width_out]`
148    ///
149    /// # Example
150    /// ```rust,ignore
151    /// use burn::nn::conv::Conv2dConfig;
152    /// use burn::tensor::Tensor;
153    ///
154    /// // Assuming backend type alias `B`
155    /// let device = Default::default();
156    /// let conv = Conv2dConfig::new([3, 8], [3, 3]).init::<B>(&device);
157    ///
158    /// let x = Tensor::<B, 4>::zeros([1, 3, 28, 28], &device);
159    /// let y = conv.forward(x);
160    ///
161    /// println!("{:?}", y.dims()); // [1, 8, 26, 26]
162    /// ```
163    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
164        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
165
166        // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
167        let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
168            height_in,
169            width_in,
170            &self.kernel_size,
171            &self.stride,
172        );
173
174        let options = PaddedConvOptions::asymmetric(
175            self.stride,
176            [top, left],
177            [bottom, right],
178            self.dilation,
179            self.groups,
180        );
181
182        conv2d(
183            input,
184            self.weight.val(),
185            self.bias.as_ref().map(|bias| bias.val()),
186            options,
187        )
188    }
189}
190
#[cfg(test)]
mod tests {
    use burn::tensor::ops::FloatElem;
    use burn::tensor::{ElementConversion, Tolerance};

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    type FT = FloatElem<TestBackend>; // Float test

    /// The default Kaiming-uniform initializer must keep every weight inside
    /// the bound k = sqrt(groups / (channels_in * k_h * k_w)).
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv2dConfig::new([5, 1], [5, 5]);
        let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    /// `Initializer::Zeros` must produce an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        conv.weight.to_data().assert_approx_eq::<FT>(
            &TensorData::zeros::<FT, _>(conv.weight.shape()),
            Tolerance::default(),
        );
    }

    /// A fan-out-only initializer must be accepted and survive `init`.
    #[test]
    fn initializer_fan_out() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true, // test that fan_out is passed to `init_with()`
        };

        let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// Grouped convolution with divisible channel counts must initialize
    /// without panicking.
    #[test]
    fn initializer_fan_with_groups_is_valid() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let init = Initializer::KaimingUniform {
            gain: 1.0 / 3.0f64.sqrt(),
            fan_out_only: true,
        };

        let config = Conv2dConfig::new([4, 4], [1, 1])
            .with_initializer(init.clone())
            .with_groups(4);
        let _ = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, init);
    }

    /// `init` must panic when channels are not divisible by `groups`.
    #[test]
    #[should_panic = "Both channels must be divisible by the number of groups."]
    fn channels_with_groups_is_invalid() {
        let device = Default::default();
        let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);
        let _ = config.init::<TestBackend>(&device);
    }

    /// `Same` padding with an even kernel cannot pad symmetrically; the layer
    /// must fall back to asymmetric padding and still preserve spatial dims.
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let config = Conv2dConfig::new([4, 4], [2, 2])
            .with_padding(PaddingConfig2d::Same)
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        // Input: [batch=1, channels=4, height=5, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 4, 5, 5], &device);
        let output = conv.forward(input);

        // Same padding should preserve spatial dimensions
        assert_eq!(output.dims(), [1, 4, 5, 5]);
    }

    /// The custom `ModuleDisplay` output must match the expected format
    /// (params: 5*1*5*5 weights + 1 bias = 126).
    #[test]
    fn display() {
        let config = Conv2dConfig::new([5, 1], [5, 5]);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{conv}"),
            "Conv2d {ch_in: 5, ch_out: 1, stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
        );
    }

    /// Feeding a tensor whose channel dim disagrees with the layer config
    /// must panic with a descriptive message.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = Conv2dConfig::new([5, 3], [3, 3]);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());
        let _ = conv.forward(input);
    }

    /// Explicit asymmetric padding must produce the analytically expected
    /// output shape.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        // Create conv with asymmetric padding: top=1, left=2, bottom=3, right=4
        let config = Conv2dConfig::new([2, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        // Input: [batch=1, channels=2, height=4, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let output = conv.forward(input);

        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
        assert_eq!(output.dims(), [1, 3, 6, 9]);
    }

    /// Symmetric padding expressed through `Explicit` must behave like plain
    /// symmetric padding.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        // Create conv with symmetric explicit padding: top=2, left=2, bottom=2, right=2
        let config = Conv2dConfig::new([2, 3], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        // Input: [batch=1, channels=2, height=4, width=5]
        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);
        let output = conv.forward(input);

        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
        assert_eq!(output.dims(), [1, 3, 6, 7]);
    }
}