1use alloc::format;
2
3use crate as burn;
4
5use crate::{
6 config::Config,
7 module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param},
8 nn::{Initializer, PaddingConfig1d, conv::checks},
9 tensor::{Tensor, backend::Backend, module::conv1d, ops::ConvOptions},
10};
11
/// Configuration used to build a [`Conv1d`] layer through [`Conv1dConfig::init`].
#[derive(Config, Debug)]
pub struct Conv1dConfig {
    /// Number of input channels (split across `groups` in the weight tensor).
    pub channels_in: usize,
    /// Number of output channels (first dimension of the weight, length of the bias).
    pub channels_out: usize,
    /// Size of the convolution kernel.
    pub kernel_size: usize,
    /// Stride of the convolution, forwarded to `ConvOptions`.
    #[config(default = "1")]
    pub stride: usize,
    /// Dilation of the convolution, forwarded to `ConvOptions`.
    #[config(default = "1")]
    pub dilation: usize,
    /// Number of channel groups; the weight holds `channels_in / groups` input
    /// channels per output channel. Validated against the channel counts in `init`.
    #[config(default = "1")]
    pub groups: usize,
    /// Padding configuration.
    ///
    /// `PaddingConfig1d::Same` combined with an even `kernel_size` is rejected
    /// when the layer is initialized.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// Whether a bias parameter of length `channels_out` is created.
    #[config(default = true)]
    pub bias: bool,
    /// Initializer used for the weight and, when enabled, the bias.
    /// Defaults to Kaiming uniform with gain `1 / sqrt(3)` computed on fan-in.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
46
/// A 1D convolution layer.
///
/// Typically built via [`Conv1dConfig::init`]; see [`Conv1d::forward`] for usage.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
    /// Weight tensor of shape `[channels_out, channels_in / groups, kernel_size]`.
    pub weight: Param<Tensor<B, 3>>,
    /// Bias tensor of shape `[channels_out]`, or `None` when the config disabled it.
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// Stride of the convolution.
    pub stride: usize,
    /// Size of the kernel.
    pub kernel_size: usize,
    /// Dilation of the convolution.
    pub dilation: usize,
    /// Number of channel groups.
    pub groups: usize,
    /// Padding configuration; turned into a concrete padding amount on every `forward` call.
    pub padding: Ignored<PaddingConfig1d>,
}
68
69impl<B: Backend> ModuleDisplay for Conv1d<B> {
70 fn custom_settings(&self) -> Option<DisplaySettings> {
71 DisplaySettings::new()
72 .with_new_line_after_attribute(false)
73 .optional()
74 }
75
76 fn custom_content(&self, content: Content) -> Option<Content> {
77 let padding_formatted = format!("{}", &self.padding);
79
80 content
81 .add("stride", &self.stride)
82 .add("kernel_size", &self.kernel_size)
83 .add("dilation", &self.dilation)
84 .add("groups", &self.groups)
85 .add("padding", &padding_formatted)
86 .optional()
87 }
88}
89
90impl Conv1dConfig {
91 pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
93 checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
94 if self.padding == PaddingConfig1d::Same {
95 checks::check_same_padding_support(&[self.kernel_size]);
96 }
97
98 let shape = [
99 self.channels_out,
100 self.channels_in / self.groups,
101 self.kernel_size,
102 ];
103
104 let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
105 let weight = self
106 .initializer
107 .init_with(shape, Some(fan_in), None, device);
108 let mut bias = None;
109
110 if self.bias {
111 bias =
112 Some(
113 self.initializer
114 .init_with([self.channels_out], Some(fan_in), None, device),
115 );
116 }
117
118 Conv1d {
119 weight,
120 bias,
121 stride: self.stride,
122 kernel_size: self.kernel_size,
123 padding: Ignored(self.padding.clone()),
124 dilation: self.dilation,
125 groups: self.groups,
126 }
127 }
128}
129
130impl<B: Backend> Conv1d<B> {
131 pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
140 let length = input.dims()[2];
141 let padding = self
142 .padding
143 .calculate_padding_1d(length, self.kernel_size, self.stride);
144
145 conv1d(
146 input,
147 self.weight.val(),
148 self.bias.as_ref().map(|bias| bias.val()),
149 ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
150 )
151 }
152}
153
#[cfg(test)]
mod tests {
    use burn_tensor::{ElementConversion, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;

    /// Weights from the default initializer must lie within
    /// `±sqrt(groups / (channels_in * kernel_size))`.
    #[test]
    fn initializer_default() {
        TestBackend::seed(0);

        let config = Conv1dConfig::new(5, 5, 5);
        let receptive = (config.channels_in * config.kernel_size) as f64;
        let bound = (config.groups as f64 / receptive).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&Default::default());

        let weights = conv.weight.to_data();
        weights.assert_within_range(-bound..bound);
    }

    /// A zeros initializer must yield an all-zero weight tensor.
    #[test]
    fn initializer_zeros() {
        TestBackend::seed(0);

        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&Default::default());
        assert_eq!(config.initializer, Initializer::Zeros);

        let expected = TensorData::zeros::<FT, _>(conv.weight.shape());
        conv.weight.to_data().assert_eq(&expected, false);
    }

    /// `Same` padding is rejected at init time for even kernel sizes.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let config = Conv1dConfig::new(5, 5, 4).with_padding(PaddingConfig1d::Same);
        let _ = config.init::<TestBackend>(&Default::default());
    }

    /// The custom display renders every hyper-parameter on a single line.
    #[test]
    fn display() {
        let conv = Conv1dConfig::new(5, 5, 5).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{conv}");

        assert_eq!(
            rendered,
            "Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
        );
    }

    /// Feeding a tensor whose channel dimension differs from `channels_in` must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let device = Default::default();
        let conv = Conv1dConfig::new(5, 3, 3).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &device);
        let _ = conv.forward(input);
    }
}