use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
/// A 1-D convolution layer over inputs of shape `[batch, in_channels, length]`.
pub struct Conv1d {
    /// Learnable filters of shape `[out_channels, in_channels, kernel_size]`.
    pub weight: Parameter,
    /// Optional learnable bias of shape `[out_channels]`.
    pub bias: Option<Parameter>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
}

impl Conv1d {
    /// Creates a `Conv1d` with stride 1, no padding, and a learnable bias.
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
    }

    /// Creates a `Conv1d` with explicit stride, padding, and bias settings.
    ///
    /// Weights are initialized Kaiming-uniform with
    /// `fan_in = in_channels * kernel_size`; the bias (if any) starts at zero.
    pub fn with_options(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        bias: bool,
    ) -> Self {
        let fan_in = in_channels * kernel_size;
        let weight_data = kaiming_uniform(out_channels, fan_in);
        let weight_reshaped = weight_data
            .reshape(&[
                out_channels as isize,
                in_channels as isize,
                kernel_size as isize,
            ])
            .unwrap();
        let weight = Parameter::named("weight", weight_reshaped, true);

        let bias_param = if bias {
            Some(Parameter::named("bias", zeros(&[out_channels]), true))
        } else {
            None
        };

        Self {
            weight,
            bias: bias_param,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        }
    }
}
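
// A minimal usage sketch (values are illustrative, using only the API in this
// file): kernel 5, stride 2, padding 2 halves a length-100 signal.
//
//     let conv = Conv1d::with_options(3, 16, 5, 2, 2, true);
//     let x = Variable::new(
//         Tensor::from_vec(vec![0.0; 4 * 3 * 100], &[4, 3, 100]).unwrap(),
//         false,
//     );
//     let y = conv.forward(&x);
//     assert_eq!(y.shape(), vec![4, 16, 50]); // (100 + 2*2 - 5) / 2 + 1 = 50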

impl Module for Conv1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        debug_assert_eq!(input_shape[1], self.in_channels);
        let in_length = input_shape[2];

        // Standard convolution output length: (L + 2p - k) / s + 1.
        let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;

        // Copy each tensor into a contiguous buffer once, rather than calling
        // `to_vec()` inside the innermost loop for every multiply-accumulate.
        let input_data = input.data();
        let weight_data = self.weight.data();
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();
        let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];

        for b in 0..batch_size {
            for oc in 0..self.out_channels {
                for ol in 0..out_length {
                    let mut sum = 0.0f32;
                    let in_start = ol * self.stride;

                    for ic in 0..self.in_channels {
                        for k in 0..self.kernel_size {
                            // Position in the conceptually zero-padded input.
                            let in_idx = in_start + k;
                            if in_idx < self.padding || in_idx >= in_length + self.padding {
                                // Inside the zero padding: contributes nothing.
                                continue;
                            }
                            let actual_idx = in_idx - self.padding;

                            let input_idx =
                                b * self.in_channels * in_length + ic * in_length + actual_idx;
                            let weight_idx = oc * self.in_channels * self.kernel_size
                                + ic * self.kernel_size
                                + k;

                            sum += input_vec[input_idx] * weight_vec[weight_idx];
                        }
                    }

                    if let Some(ref bias) = bias_vec {
                        sum += bias[oc];
                    }

                    let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
                    output_data[output_idx] = sum;
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length]).unwrap();

        Variable::new(output_tensor, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv1d"
    }
}
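
// Note on the output-length arithmetic: with stride 1 and padding = (k - 1) / 2
// for odd k, the layer preserves length, since (L + (k - 1) - k) / 1 + 1 = L.
// The kernel-3 / padding-1 case in the tests below exercises exactly this.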

/// A 2-D convolution layer over inputs of shape `[batch, in_channels, height, width]`.
pub struct Conv2d {
    /// Learnable filters of shape `[out_channels, in_channels, kh, kw]`.
    pub weight: Parameter,
    /// Optional learnable bias of shape `[out_channels]`.
    pub bias: Option<Parameter>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
}

impl Conv2d {
    /// Creates a square-kernel `Conv2d` with stride 1, no padding, and a bias.
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        Self::with_options(
            in_channels,
            out_channels,
            (kernel_size, kernel_size),
            (1, 1),
            (0, 0),
            true,
        )
    }

    /// Creates a `Conv2d` with explicit (height, width) kernel, stride, and
    /// padding, plus an optional bias.
    pub fn with_options(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        bias: bool,
    ) -> Self {
        let (kh, kw) = kernel_size;
        let fan_in = in_channels * kh * kw;

        let weight_data = kaiming_uniform(out_channels, fan_in);
        let weight_reshaped = weight_data
            .reshape(&[
                out_channels as isize,
                in_channels as isize,
                kh as isize,
                kw as isize,
            ])
            .unwrap();
        let weight = Parameter::named("weight", weight_reshaped, true);

        let bias_param = if bias {
            Some(Parameter::named("bias", zeros(&[out_channels]), true))
        } else {
            None
        };

        Self {
            weight,
            bias: bias_param,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        }
    }
}
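
// A minimal construction sketch (values are illustrative): a 3x5 kernel with
// stride (1, 2) and padding (1, 2), no bias, holds a weight of shape [8, 3, 3, 5].
//
//     let conv = Conv2d::with_options(3, 8, (3, 5), (1, 2), (1, 2), false);
//     assert_eq!(conv.parameters().len(), 1); // weight only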

impl Module for Conv2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        debug_assert_eq!(input_shape[1], self.in_channels);
        let in_height = input_shape[2];
        let in_width = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard convolution output size, per dimension: (n + 2p - k) / s + 1.
        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;

        // Copy each tensor into a contiguous buffer once, including the bias,
        // rather than re-materializing it for every output element.
        let input_data = input.data();
        let weight_data = self.weight.data();
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();
        let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());

        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_height * out_width];

        for b in 0..batch_size {
            for oc in 0..self.out_channels {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        let mut sum = 0.0f32;

                        for ic in 0..self.in_channels {
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Position in the conceptually zero-padded input.
                                    let ih = oh * sh + ki;
                                    let iw = ow * sw + kj;

                                    if ih < ph
                                        || ih >= in_height + ph
                                        || iw < pw
                                        || iw >= in_width + pw
                                    {
                                        // Inside the zero padding: contributes nothing.
                                        continue;
                                    }

                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;

                                    let input_idx = b * self.in_channels * in_height * in_width
                                        + ic * in_height * in_width
                                        + actual_ih * in_width
                                        + actual_iw;

                                    let weight_idx = oc * self.in_channels * kh * kw
                                        + ic * kh * kw
                                        + ki * kw
                                        + kj;

                                    sum += input_vec[input_idx] * weight_vec[weight_idx];
                                }
                            }
                        }

                        if let Some(ref bias) = bias_vec {
                            sum += bias[oc];
                        }

                        let output_idx = b * self.out_channels * out_height * out_width
                            + oc * out_height * out_width
                            + oh * out_width
                            + ow;
                        output_data[output_idx] = sum;
                    }
                }
            }
        }

        let output_tensor = Tensor::from_vec(
            output_data,
            &[batch_size, self.out_channels, out_height, out_width],
        )
        .unwrap();

        Variable::new(output_tensor, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv2d"
    }
}
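
// A minimal forward sketch (values are illustrative): a stride-2, padding-1,
// kernel-3 convolution halves each spatial dimension.
//
//     let conv = Conv2d::with_options(3, 8, (3, 3), (2, 2), (1, 1), true);
//     let x = Variable::new(
//         Tensor::from_vec(vec![0.0; 2 * 3 * 32 * 32], &[2, 3, 32, 32]).unwrap(),
//         false,
//     );
//     let y = conv.forward(&x);
//     assert_eq!(y.shape(), vec![2, 8, 16, 16]); // (32 + 2*1 - 3) / 2 + 1 = 16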

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv1d_creation() {
        let conv = Conv1d::new(3, 16, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 16);
        assert_eq!(conv.kernel_size, 3);
    }
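
    // Added check (uses only the API defined above): a layer built without a
    // bias should expose exactly one learnable parameter, the weight.
    #[test]
    fn test_conv1d_no_bias_parameters() {
        let conv = Conv1d::with_options(3, 16, 3, 1, 0, false);
        assert_eq!(conv.parameters().len(), 1);
    }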

    #[test]
    fn test_conv1d_forward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        // Kernel 3, padding 1, stride 1 preserves length: (5 + 2 - 3) / 1 + 1 = 5.
        assert_eq!(output.shape(), vec![1, 1, 5]);
    }
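
    // Added check (illustrative): length 7, kernel 3, stride 2, no padding
    // gives (7 - 3) / 2 + 1 = 3 outputs, and an all-zero input with no bias
    // must produce an all-zero output.
    #[test]
    fn test_conv1d_strided_forward() {
        let conv = Conv1d::with_options(1, 1, 3, 2, 0, false);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 7], &[1, 1, 7]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 3]);
        assert!(output.data().to_vec().iter().all(|&v| v == 0.0));
    }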

    #[test]
    fn test_conv2d_creation() {
        let conv = Conv2d::new(3, 64, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 64);
        assert_eq!(conv.kernel_size, (3, 3));
    }

    #[test]
    fn test_conv2d_forward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
    }
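
    // Added check (illustrative): an asymmetric 2x3 kernel with stride (2, 1)
    // on a 4x5 input gives (4 - 2) / 2 + 1 = 2 rows and (5 - 3) / 1 + 1 = 3
    // columns.
    #[test]
    fn test_conv2d_asymmetric_forward() {
        let conv = Conv2d::with_options(1, 2, (2, 3), (2, 1), (0, 0), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 20], &[1, 1, 4, 5]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 2, 2, 3]);
    }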

    #[test]
    fn test_conv2d_parameters() {
        let conv = Conv2d::new(3, 64, 3);
        let params = conv.parameters();
        assert_eq!(params.len(), 2);
    }
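
    // Added check (uses only the API defined above): named_parameters() should
    // always expose "weight" and expose "bias" only when one was requested.
    #[test]
    fn test_conv2d_named_parameters() {
        let with_bias = Conv2d::new(3, 64, 3);
        assert!(with_bias.named_parameters().contains_key("weight"));
        assert!(with_bias.named_parameters().contains_key("bias"));

        let no_bias = Conv2d::with_options(3, 64, (3, 3), (1, 1), (0, 0), false);
        assert!(!no_bias.named_parameters().contains_key("bias"));
    }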
}