1use alloc::format;
2
3use burn_core as burn;
4
5use crate::{PaddingConfig1d, conv::checks};
6use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::PaddedConvOptions};
7use burn::{
8 config::Config,
9 module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param},
10};
11
/// Configuration to create a 1D convolution [layer](Conv1d), using the
/// [init function](Conv1dConfig::init).
#[derive(Config, Debug)]
pub struct Conv1dConfig {
    /// The number of input channels.
    pub channels_in: usize,
    /// The number of output channels.
    pub channels_out: usize,
    /// The size of the convolution kernel.
    pub kernel_size: usize,
    /// The stride of the convolution. Default: 1.
    #[config(default = "1")]
    pub stride: usize,
    /// Spacing between kernel elements. Default: 1.
    #[config(default = "1")]
    pub dilation: usize,
    /// Number of blocked connections from input channels to output channels. Default: 1.
    /// Both `channels_in` and `channels_out` must be divisible by this value.
    #[config(default = "1")]
    pub groups: usize,
    /// The padding configuration. Default: [`Valid`](PaddingConfig1d::Valid) (no padding).
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// If a learnable bias should be added to the output. Default: true.
    #[config(default = true)]
    pub bias: bool,
    /// The initializer used for weights and biases.
    /// Default: Kaiming uniform with `gain = 1 / sqrt(3)` scaled by fan-in.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
45
/// Applies a 1D convolution over input tensors.
///
/// Should be created with [`Conv1dConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
    /// The learnable weights, of shape `[channels_out, channels_in / groups, kernel_size]`.
    pub weight: Param<Tensor<B, 3>>,
    /// The optional learnable bias, of shape `[channels_out]`.
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// The stride of the convolution.
    pub stride: usize,
    /// The size of the convolution kernel.
    pub kernel_size: usize,
    /// Spacing between kernel elements.
    pub dilation: usize,
    /// Number of blocked connections from input channels to output channels.
    pub groups: usize,
    /// The padding configuration.
    pub padding: PaddingConfig1d,
}
67
68impl<B: Backend> ModuleDisplay for Conv1d<B> {
69 fn custom_settings(&self) -> Option<DisplaySettings> {
70 DisplaySettings::new()
71 .with_new_line_after_attribute(false)
72 .optional()
73 }
74
75 fn custom_content(&self, content: Content) -> Option<Content> {
76 let stride = format!("{:?}", self.stride);
78 let kernel_size = format!("{:?}", self.kernel_size);
79 let dilation = format!("{:?}", self.dilation);
80
81 let [channels_out, group_channels_in, _] = self.weight.dims();
83 let channels_in = group_channels_in * self.groups;
84 let ch_out = format!("{:?}", channels_out);
85 let ch_in = format!("{:?}", channels_in);
86
87 content
88 .add("ch_in", &ch_in)
89 .add("ch_out", &ch_out)
90 .add("stride", &stride)
91 .add("kernel_size", &kernel_size)
92 .add("dilation", &dilation)
93 .add("groups", &self.groups)
94 .add_debug_attribute("padding", &self.padding)
95 .optional()
96 }
97}
98impl Conv1dConfig {
99 pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
101 checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
102
103 let shape = [
104 self.channels_out,
105 self.channels_in / self.groups,
106 self.kernel_size,
107 ];
108
109 let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
110 let weight = self
111 .initializer
112 .init_with(shape, Some(fan_in), None, device);
113 let mut bias = None;
114
115 if self.bias {
116 bias =
117 Some(
118 self.initializer
119 .init_with([self.channels_out], Some(fan_in), None, device),
120 );
121 }
122
123 Conv1d {
124 weight,
125 bias,
126 stride: self.stride,
127 kernel_size: self.kernel_size,
128 padding: self.padding.clone(),
129 dilation: self.dilation,
130 groups: self.groups,
131 }
132 }
133}
134
135impl<B: Backend> Conv1d<B> {
136 pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
145 let length = input.dims()[2];
146
147 let (left, right) =
149 self.padding
150 .calculate_padding_1d_pair(length, self.kernel_size, self.stride);
151
152 let options = PaddedConvOptions::asymmetric(
153 [self.stride],
154 [left],
155 [right],
156 [self.dilation],
157 self.groups,
158 );
159
160 conv1d(
161 input,
162 self.weight.val(),
163 self.bias.as_ref().map(|bias| bias.val()),
164 options,
165 )
166 }
167}
168
#[cfg(test)]
mod tests {
    use burn::tensor::{ElementConversion, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    // Default initializer is Kaiming uniform: all weights must lie within
    // (-k, k) where k = sqrt(groups / (channels_in * kernel_size)).
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5);
        let k = (config.channels_in * config.kernel_size) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    // A custom initializer supplied via the config must be used for the weights.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(config.initializer, Initializer::Zeros);
        conv.weight
            .to_data()
            .assert_eq(&TensorData::zeros::<FT, _>(conv.weight.shape()), false);
    }

    // `Same` padding with an even kernel requires unequal left/right padding,
    // and must preserve the input length (5 -> 5 here).
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let config = Conv1dConfig::new(4, 4, 2)
            .with_padding(PaddingConfig1d::Same)
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::ones([1, 4, 5], &device);
        let output = conv.forward(input);

        assert_eq!(output.dims(), [1, 4, 5]);
    }

    // Display output must include the derived channel counts, all
    // hyperparameters, and the parameter count.
    #[test]
    fn display() {
        let config = Conv1dConfig::new(5, 5, 5);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{conv}"),
            "Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
        );
    }

    // Forwarding an input whose channel count disagrees with the module's
    // configured input channels must panic with a descriptive message.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = Conv1dConfig::new(5, 3, 3);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
        let _ = conv.forward(input);
    }

    // Explicit asymmetric padding (left=1, right=2), kernel 3, stride 1:
    // output length = 4 + 1 + 2 - 3 + 1 = 5.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        let config = Conv1dConfig::new(2, 3, 3)
            .with_padding(PaddingConfig1d::Explicit(1, 2))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output = conv.forward(input);

        assert_eq!(output.dims(), [1, 3, 5]);
    }

    // Explicit symmetric padding (2, 2), kernel 3, stride 1:
    // output length = 4 + 2 + 2 - 3 + 1 = 6.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        let config = Conv1dConfig::new(2, 3, 3)
            .with_padding(PaddingConfig1d::Explicit(2, 2))
            .with_initializer(Initializer::Constant { value: 1.0 })
            .with_bias(false);
        let conv = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output = conv.forward(input);

        assert_eq!(output.dims(), [1, 3, 6]);
    }
}