burn_core/nn/conv/
conv1d.rs

use alloc::format;

use crate as burn;

use crate::{
    config::Config,
    module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param},
    nn::{Initializer, PaddingConfig1d, conv::checks},
    tensor::{Tensor, backend::Backend, module::conv1d, ops::ConvOptions},
};

/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).
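///
/// # Example
///
/// A minimal usage sketch, assuming a backend type `B` and a `device` in scope
/// (`with_stride` is one of the `with_*` setters generated by the `Config` derive):
///
/// ```ignore
/// let config = Conv1dConfig::new(3, 16, 5).with_stride(2);
/// let conv: Conv1d<B> = config.init(&device);
/// ```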
#[derive(Config, Debug)]
pub struct Conv1dConfig {
    /// The number of input channels.
    pub channels_in: usize,
    /// The number of output channels.
    pub channels_out: usize,
    /// The size of the kernel.
    pub kernel_size: usize,
    /// The stride of the convolution.
    #[config(default = "1")]
    pub stride: usize,
    /// Spacing between kernel elements.
    #[config(default = "1")]
    pub dilation: usize,
    /// Controls the connections between input and output channels.
    #[config(default = "1")]
    pub groups: usize,
    /// The padding configuration.
    ///
    /// ### Warning
    /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
    /// size is not supported as it will not produce the same output size.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// If bias should be added to the output.
    #[config(default = true)]
    pub bias: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// Applies a 1D convolution over input tensors.
///
/// Should be created with [Conv1dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`
    pub weight: Param<Tensor<B, 3>>,
    /// Tensor of shape `[channels_out]`
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// Stride of the convolution.
    pub stride: usize,
    /// Size of the kernel.
    pub kernel_size: usize,
    /// Spacing between kernel elements.
    pub dilation: usize,
    /// Controls the connections between input and output channels.
    pub groups: usize,
    /// Padding configuration.
    pub padding: Ignored<PaddingConfig1d>,
}

impl<B: Backend> ModuleDisplay for Conv1d<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Since padding does not implement ModuleDisplay, we need to format it manually.
        let padding_formatted = format!("{}", &self.padding);

        content
            .add("stride", &self.stride)
            .add("kernel_size", &self.kernel_size)
            .add("dilation", &self.dilation)
            .add("groups", &self.groups)
            .add("padding", &padding_formatted)
            .optional()
    }
}

impl Conv1dConfig {
    /// Initialize a new [conv1d](Conv1d) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
        checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
        if self.padding == PaddingConfig1d::Same {
            checks::check_same_padding_support(&[self.kernel_size]);
        }

        let shape = [
            self.channels_out,
            self.channels_in / self.groups,
            self.kernel_size,
        ];

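        // Fan-in of a grouped convolution: each output unit is connected to
        // (channels_in / groups) * kernel_size inputs; it sets the bounds of the
        // default Kaiming uniform initializer.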
        let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
        let weight = self
            .initializer
            .init_with(shape, Some(fan_in), None, device);
        let mut bias = None;

        if self.bias {
            bias = Some(
                self.initializer
                    .init_with([self.channels_out], Some(fan_in), None, device),
            );
        }

        Conv1d {
            weight,
            bias,
            stride: self.stride,
            kernel_size: self.kernel_size,
            padding: Ignored(self.padding.clone()),
            dilation: self.dilation,
            groups: self.groups,
        }
    }
}

impl<B: Backend> Conv1d<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// See [conv1d](crate::tensor::module::conv1d) for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, channels_in, length_in]`
    /// - output: `[batch_size, channels_out, length_out]`
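    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a backend type `B` and a `device` in scope:
    ///
    /// ```ignore
    /// let conv = Conv1dConfig::new(3, 16, 5).init::<B>(&device);
    /// let input = Tensor::<B, 3>::zeros([8, 3, 100], &device);
    /// // With the default `Valid` padding: length_out = (100 - (5 - 1) - 1) + 1 = 96.
    /// let output = conv.forward(input); // `[8, 16, 96]`
    /// ```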
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let length = input.dims()[2];
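        // Resolve the configured padding policy into a concrete amount for this
        // input length (`Valid` resolves to zero padding).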
        let padding = self
            .padding
            .calculate_padding_1d(length, self.kernel_size, self.stride);

        conv1d(
            input,
            self.weight.val(),
            self.bias.as_ref().map(|bias| bias.val()),
            ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
        )
    }
}

#[cfg(test)]
mod tests {
    use burn_tensor::{ElementConversion, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;

    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = Conv1dConfig::new(5, 5, 5);
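        // Bound of the default Kaiming uniform initializer:
        // sqrt(groups / (channels_in * kernel_size)) = sqrt(1 / fan_in).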
        let k = (config.channels_in * config.kernel_size) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&Default::default());

        conv.weight.to_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(config.initializer, Initializer::Zeros);
        conv.weight
            .to_data()
            .assert_eq(&TensorData::zeros::<FT, _>(conv.weight.shape()), false);
    }

    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let device = Default::default();
        let config = Conv1dConfig::new(5, 5, 4).with_padding(PaddingConfig1d::Same);
        let _ = config.init::<TestBackend>(&device);
    }

    #[test]
    fn display() {
        let config = Conv1dConfig::new(5, 5, 5);
        let conv = config.init::<TestBackend>(&Default::default());

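        // params: 130 = 5 * (5 / 1) * 5 weights + 5 bias terms.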
        assert_eq!(
            alloc::format!("{conv}"),
            "Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
        );
    }

    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = Conv1dConfig::new(5, 3, 3);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
        let _ = conv.forward(input);
    }
}