burn_nn/modules/conv/
conv1d.rs

1use alloc::format;
2
3use burn_core as burn;
4
5use crate::{PaddingConfig1d, conv::checks};
6use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::ConvOptions};
7use burn::{
8    config::Config,
9    module::{Content, DisplaySettings, Ignored, Initializer, Module, ModuleDisplay, Param},
10};
11
/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
///
/// `channels_in`, `channels_out` and `kernel_size` carry no default and must be supplied when the
/// configuration is created; every other field falls back to the default declared in its
/// `#[config(...)]` attribute.
#[derive(Config, Debug)]
pub struct Conv1dConfig {
    /// The number of input channels.
    pub channels_in: usize,
    /// The number of output channels.
    pub channels_out: usize,
    /// The size of the kernel.
    pub kernel_size: usize,
    /// The stride of the convolution. Defaults to `1`.
    #[config(default = "1")]
    pub stride: usize,
    /// Spacing between kernel elements. Defaults to `1`.
    #[config(default = "1")]
    pub dilation: usize,
    /// Controls the connections between input and output channels. Defaults to `1`.
    ///
    /// The weight created by [init](Conv1dConfig::init) holds `channels_in / groups` input
    /// channels per output channel.
    #[config(default = "1")]
    pub groups: usize,
    /// The padding configuration. Defaults to [`PaddingConfig1d::Valid`].
    ///
    /// ### Warning
    /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
    /// size is not supported as it will not produce the same output size.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// If bias should be added to the output. Defaults to `true`.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters.
    ///
    /// Defaults to `Initializer::KaimingUniform` with a gain of `1 / sqrt(3)` and
    /// `fan_out_only` set to `false`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
46
/// Applies a 1D convolution over input tensors.
///
/// Should be created with [Conv1dConfig].
///
/// Besides the learnable `weight` and optional `bias`, the module keeps the hyper-parameters
/// it needs to run the convolution and to render itself through its custom display.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
    pub weight: Param<Tensor<B, 3>>,
    /// Tensor of shape `[channels_out]`; `None` when the layer was configured without a bias.
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// Stride of the convolution.
    pub stride: usize,
    /// Size of the kernel.
    pub kernel_size: usize,
    /// Spacing between kernel elements.
    pub dilation: usize,
    /// Controls the connections between input and output channels.
    pub groups: usize,
    /// Padding configuration.
    ///
    /// NOTE(review): wrapped in `Ignored`, presumably so the module derive does not treat it as
    /// module state; confirm against the `Ignored` documentation.
    pub padding: Ignored<PaddingConfig1d>,
}
68
69impl<B: Backend> ModuleDisplay for Conv1d<B> {
70    fn custom_settings(&self) -> Option<DisplaySettings> {
71        DisplaySettings::new()
72            .with_new_line_after_attribute(false)
73            .optional()
74    }
75
76    fn custom_content(&self, content: Content) -> Option<Content> {
77        // Format padding
78        let padding_formatted = format!("{}", &self.padding);
79
80        // Format stride/dilation as strings
81        let stride = format!("{:?}", self.stride);
82        let kernel_size = format!("{:?}", self.kernel_size);
83        let dilation = format!("{:?}", self.dilation);
84
85        // Extract channels in/out from weight dims
86        let [channels_out, group_channels_in, _] = self.weight.dims();
87        let channels_in = group_channels_in * self.groups;
88        let ch_out = format!("{:?}", channels_out);
89        let ch_in = format!("{:?}", channels_in);
90
91        content
92            .add("ch_in", &ch_in)
93            .add("ch_out", &ch_out)
94            .add("stride", &stride)
95            .add("kernel_size", &kernel_size)
96            .add("dilation", &dilation)
97            .add("groups", &self.groups)
98            .add("padding", &padding_formatted)
99            .optional()
100    }
101}
102impl Conv1dConfig {
103    /// Initialize a new [conv1d](Conv1d) module.
104    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
105        checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
106        if self.padding == PaddingConfig1d::Same {
107            checks::check_same_padding_support(&[self.kernel_size]);
108        }
109
110        let shape = [
111            self.channels_out,
112            self.channels_in / self.groups,
113            self.kernel_size,
114        ];
115
116        let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
117        let weight = self
118            .initializer
119            .init_with(shape, Some(fan_in), None, device);
120        let mut bias = None;
121
122        if self.bias {
123            bias =
124                Some(
125                    self.initializer
126                        .init_with([self.channels_out], Some(fan_in), None, device),
127                );
128        }
129
130        Conv1d {
131            weight,
132            bias,
133            stride: self.stride,
134            kernel_size: self.kernel_size,
135            padding: Ignored(self.padding.clone()),
136            dilation: self.dilation,
137            groups: self.groups,
138        }
139    }
140}
141
142impl<B: Backend> Conv1d<B> {
143    /// Applies the forward pass on the input tensor.
144    ///
145    /// See [conv1d](burn::tensor::module::conv1d) for more information.
146    ///
147    /// # Shapes
148    ///
149    /// - input: `[batch_size, channels_in, length_in]`
150    /// - output: `[batch_size, channels_out, length_out]`
151    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
152        let length = input.dims()[2];
153        let padding = self
154            .padding
155            .calculate_padding_1d(length, self.kernel_size, self.stride);
156
157        conv1d(
158            input,
159            self.weight.val(),
160            self.bias.as_ref().map(|bias| bias.val()),
161            ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
162        )
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{ElementConversion, TensorData, ops::FloatElem};

    /// Float element type of the test backend.
    type FT = FloatElem<TestBackend>;

    /// Weights from the default initializer must fall in the range `-k..k`, where
    /// `k = sqrt(groups / (channels_in * kernel_size))`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5);
        let fan_in = (config.channels_in * config.kernel_size) as f64;
        let bound = (config.groups as f64 / fan_in).sqrt().elem::<FT>();

        let layer = config.init::<TestBackend>(&device);
        layer.weight.to_data().assert_within_range(-bound..bound);
    }

    /// The `Zeros` initializer must be kept in the config and yield an all-zero weight.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
        let layer = config.init::<TestBackend>(&Default::default());

        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<FT, _>(layer.weight.shape());
        layer.weight.to_data().assert_eq(&expected, false);
    }

    /// `Same` padding cannot be honoured with an even kernel size, so `init` must panic.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let device = Default::default();
        let _ = Conv1dConfig::new(5, 5, 4)
            .with_padding(PaddingConfig1d::Same)
            .init::<TestBackend>(&device);
    }

    /// The custom display lists every hyper-parameter followed by the parameter count.
    #[test]
    fn display() {
        let conv = Conv1dConfig::new(5, 5, 5).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{conv}");

        assert_eq!(
            rendered,
            "Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
        );
    }

    /// Feeding an input with the wrong channel count must panic in the forward pass.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let layer = Conv1dConfig::new(5, 3, 3).init::<TestBackend>(&Default::default());
        let wrong_channels = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());

        let _ = layer.forward(wrong_channels);
    }
}