1use alloc::format;
2
3use burn_core as burn;
4
5use crate::{PaddingConfig1d, conv::checks};
6use burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::ConvOptions};
7use burn::{
8 config::Config,
9 module::{Content, DisplaySettings, Ignored, Initializer, Module, ModuleDisplay, Param},
10};
11
/// Configuration used to create a [`Conv1d`] layer through [`Conv1dConfig::init`].
///
/// `channels_in`, `channels_out` and `kernel_size` have no default and must be
/// supplied; every other field carries a `#[config(default = ...)]` value.
#[derive(Config, Debug)]
pub struct Conv1dConfig {
    /// Number of input channels.
    pub channels_in: usize,
    /// Number of output channels.
    pub channels_out: usize,
    /// Size of the convolution kernel.
    pub kernel_size: usize,
    /// Stride of the convolution. Defaults to `1`.
    #[config(default = "1")]
    pub stride: usize,
    /// Dilation of the convolution (spacing between kernel elements). Defaults to `1`.
    #[config(default = "1")]
    pub dilation: usize,
    /// Number of channel groups. Defaults to `1`.
    ///
    /// `init` validates `channels_in`, `channels_out` and `groups` together and
    /// allocates the weight with `channels_in / groups` input channels per filter.
    #[config(default = "1")]
    pub groups: usize,
    /// Padding configuration. Defaults to [`PaddingConfig1d::Valid`].
    ///
    /// When set to [`PaddingConfig1d::Same`], `init` runs
    /// `check_same_padding_support` on the kernel size; the accompanying test
    /// expects a panic for an even kernel size.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// Whether a bias parameter is created by `init`. Defaults to `true`.
    #[config(default = true)]
    pub bias: bool,
    /// Initializer used for both the weight and the bias.
    ///
    /// Defaults to `Initializer::KaimingUniform` with a gain of `1 / sqrt(3)`
    /// and `fan_out_only` set to `false`.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
46
/// 1D convolution layer.
///
/// Instances are built by [`Conv1dConfig::init`]; the layer is applied with
/// [`Conv1d::forward`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Conv1d<B: Backend> {
    /// Kernel weights, created with shape
    /// `[channels_out, channels_in / groups, kernel_size]`.
    pub weight: Param<Tensor<B, 3>>,
    /// Optional bias, created with shape `[channels_out]`; `None` when the
    /// configuration disables the bias.
    pub bias: Option<Param<Tensor<B, 1>>>,
    /// Stride of the convolution.
    pub stride: usize,
    /// Size of the convolution kernel.
    pub kernel_size: usize,
    /// Dilation of the convolution (spacing between kernel elements).
    pub dilation: usize,
    /// Number of channel groups.
    pub groups: usize,
    /// Padding configuration, wrapped in `Ignored`; the concrete padding
    /// amount is computed from it on every forward pass.
    pub padding: Ignored<PaddingConfig1d>,
}
68
69impl<B: Backend> ModuleDisplay for Conv1d<B> {
70 fn custom_settings(&self) -> Option<DisplaySettings> {
71 DisplaySettings::new()
72 .with_new_line_after_attribute(false)
73 .optional()
74 }
75
76 fn custom_content(&self, content: Content) -> Option<Content> {
77 let padding_formatted = format!("{}", &self.padding);
79
80 let stride = format!("{:?}", self.stride);
82 let kernel_size = format!("{:?}", self.kernel_size);
83 let dilation = format!("{:?}", self.dilation);
84
85 let [channels_out, group_channels_in, _] = self.weight.dims();
87 let channels_in = group_channels_in * self.groups;
88 let ch_out = format!("{:?}", channels_out);
89 let ch_in = format!("{:?}", channels_in);
90
91 content
92 .add("ch_in", &ch_in)
93 .add("ch_out", &ch_out)
94 .add("stride", &stride)
95 .add("kernel_size", &kernel_size)
96 .add("dilation", &dilation)
97 .add("groups", &self.groups)
98 .add("padding", &padding_formatted)
99 .optional()
100 }
101}
102impl Conv1dConfig {
103 pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {
105 checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);
106 if self.padding == PaddingConfig1d::Same {
107 checks::check_same_padding_support(&[self.kernel_size]);
108 }
109
110 let shape = [
111 self.channels_out,
112 self.channels_in / self.groups,
113 self.kernel_size,
114 ];
115
116 let fan_in: usize = self.channels_in / self.groups * self.kernel_size;
117 let weight = self
118 .initializer
119 .init_with(shape, Some(fan_in), None, device);
120 let mut bias = None;
121
122 if self.bias {
123 bias =
124 Some(
125 self.initializer
126 .init_with([self.channels_out], Some(fan_in), None, device),
127 );
128 }
129
130 Conv1d {
131 weight,
132 bias,
133 stride: self.stride,
134 kernel_size: self.kernel_size,
135 padding: Ignored(self.padding.clone()),
136 dilation: self.dilation,
137 groups: self.groups,
138 }
139 }
140}
141
142impl<B: Backend> Conv1d<B> {
143 pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
152 let length = input.dims()[2];
153 let padding = self
154 .padding
155 .calculate_padding_1d(length, self.kernel_size, self.stride);
156
157 conv1d(
158 input,
159 self.weight.val(),
160 self.bias.as_ref().map(|bias| bias.val()),
161 ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
162 )
163 }
164}
165
#[cfg(test)]
mod tests {
    use burn::tensor::{ElementConversion, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// With the default initializer every weight must lie within `-k..k`,
    /// where `k = sqrt(groups / (channels_in * kernel_size))`.
    #[test]
    fn initializer_default() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5);
        let k = (config.channels_in * config.kernel_size) as f64;
        let k = (config.groups as f64 / k).sqrt().elem::<FT>();
        let conv = config.init::<TestBackend>(&device);

        conv.weight.to_data().assert_within_range(-k..k);
    }

    /// `Initializer::Zeros` must be kept on the config and yield an all-zero weight.
    #[test]
    fn initializer_zeros() {
        let device = Default::default();
        TestBackend::seed(&device, 0);

        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);
        // Initialise on the device that was seeded above rather than on a
        // separately constructed default device.
        let conv = config.init::<TestBackend>(&device);

        assert_eq!(config.initializer, Initializer::Zeros);
        conv.weight
            .to_data()
            .assert_eq(&TensorData::zeros::<FT, _>(conv.weight.shape()), false);
    }

    /// `Same` padding combined with an even kernel size is rejected by `init`.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let device = Default::default();
        let config = Conv1dConfig::new(5, 5, 4).with_padding(PaddingConfig1d::Same);
        let _ = config.init::<TestBackend>(&device);
    }

    /// The custom display lists every hyper-parameter on a single line.
    /// `params: 130` is 5 * 5 * 5 weight values plus 5 bias values.
    #[test]
    fn display() {
        let config = Conv1dConfig::new(5, 5, 5);
        let conv = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{conv}"),
            "Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
        );
    }

    /// Feeding an input with 4 channels to a layer configured for 5 must panic.
    #[test]
    #[should_panic = "Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5"]
    fn input_channels_mismatch() {
        let config = Conv1dConfig::new(5, 3, 3);
        let conv = config.init::<TestBackend>(&Default::default());

        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());
        let _ = conv.forward(input);
    }
}