1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_autograd::functions::{
21 Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
22};
23use axonml_autograd::grad_fn::GradFn;
24use axonml_autograd::no_grad::is_grad_enabled;
25use axonml_tensor::Tensor;
26use rayon::prelude::*;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
/// A 1-dimensional convolution layer (`[N, C_in, L]` -> `[N, C_out, L_out]`).
///
/// The weight is stored with shape `[out_channels, in_channels, kernel_size]`
/// (see `with_options`); the bias, when present, has shape `[out_channels]`.
pub struct Conv1d {
    /// Learnable filter weights, shape `[out_channels, in_channels, kernel_size]`.
    pub weight: Parameter,
    /// Optional learnable bias, shape `[out_channels]`.
    pub bias: Option<Parameter>,
    // Number of channels expected in the input (dim 1).
    in_channels: usize,
    // Number of channels produced in the output.
    out_channels: usize,
    // Length of the 1-D convolution kernel.
    kernel_size: usize,
    // Step between consecutive kernel applications.
    stride: usize,
    // Zero-padding (conceptually) applied to both ends of the input length.
    padding: usize,
}
59
60impl Conv1d {
61 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
63 Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
64 }
65
66 pub fn with_options(
68 in_channels: usize,
69 out_channels: usize,
70 kernel_size: usize,
71 stride: usize,
72 padding: usize,
73 bias: bool,
74 ) -> Self {
75 let fan_in = in_channels * kernel_size;
77 let weight_data = kaiming_uniform(out_channels, fan_in);
78 let weight_reshaped = weight_data
79 .reshape(&[
80 out_channels as isize,
81 in_channels as isize,
82 kernel_size as isize,
83 ])
84 .unwrap();
85 let weight = Parameter::named("weight", weight_reshaped, true);
86
87 let bias_param = if bias {
88 Some(Parameter::named("bias", zeros(&[out_channels]), true))
89 } else {
90 None
91 };
92
93 Self {
94 weight,
95 bias: bias_param,
96 in_channels,
97 out_channels,
98 kernel_size,
99 stride,
100 padding,
101 }
102 }
103}
104
105impl Module for Conv1d {
106 fn forward(&self, input: &Variable) -> Variable {
107 let input_shape = input.shape();
108 let batch_size = input_shape[0];
109 let in_length = input_shape[2];
110
111 let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113 let input_data = input.data();
114 let weight_data = self.weight.data();
115
116 #[cfg(feature = "cuda")]
119 if input_data.device().is_gpu() {
120 let input_dev = input_data.device();
122 if !weight_data.device().is_gpu() {
123 self.weight.to_device(input_dev);
124 if let Some(ref b) = self.bias {
125 b.to_device(input_dev);
126 }
127 }
128 let weight_data = self.weight.data();
129
130 let input_4d = input_data
132 .reshape(&[
133 batch_size as isize,
134 self.in_channels as isize,
135 in_length as isize,
136 1,
137 ])
138 .unwrap();
139
140 let weight_4d = weight_data
142 .reshape(&[
143 self.out_channels as isize,
144 self.in_channels as isize,
145 self.kernel_size as isize,
146 1,
147 ])
148 .unwrap();
149
150 let bias_tensor = self.bias.as_ref().map(|b| b.data());
151 let gpu_output = input_4d.conv2d_cuda(
152 &weight_4d,
153 bias_tensor.as_ref(),
154 (self.stride, 1),
155 (self.padding, 0),
156 );
157
158 if let Some(output_4d) = gpu_output {
159 let output_tensor = output_4d
161 .reshape(&[
162 batch_size as isize,
163 self.out_channels as isize,
164 out_length as isize,
165 ])
166 .unwrap();
167
168 let requires_grad =
169 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
170 if requires_grad {
171 let weight_var = self.weight.variable();
172 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
173
174 let grad_fn = GradFn::new(Conv1dBackward::new(
175 input.grad_fn().cloned(),
176 weight_var.grad_fn().cloned(),
177 bias_grad_fn,
178 input_data,
179 weight_data,
180 input_shape,
181 self.in_channels,
182 self.out_channels,
183 self.kernel_size,
184 self.stride,
185 self.padding,
186 self.bias.is_some(),
187 ));
188 return Variable::from_operation(output_tensor, grad_fn, true);
189 } else {
190 return Variable::new(output_tensor, false);
191 }
192 }
193 }
195
196 let input_vec = input_data.to_vec();
197 let weight_vec = weight_data.to_vec();
198
199 let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
200
201 for b in 0..batch_size {
202 for oc in 0..self.out_channels {
203 for ol in 0..out_length {
204 let mut sum = 0.0f32;
205 let in_start = ol * self.stride;
206
207 for ic in 0..self.in_channels {
208 for k in 0..self.kernel_size {
209 let in_idx = in_start + k;
210 if in_idx < self.padding || in_idx >= in_length + self.padding {
211 continue;
212 }
213 let actual_idx = in_idx - self.padding;
214
215 let input_idx =
216 b * self.in_channels * in_length + ic * in_length + actual_idx;
217 let weight_idx = oc * self.in_channels * self.kernel_size
218 + ic * self.kernel_size
219 + k;
220
221 sum += input_vec[input_idx] * weight_vec[weight_idx];
222 }
223 }
224
225 if let Some(ref bias) = self.bias {
226 sum += bias.data().to_vec()[oc];
227 }
228
229 let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
230 output_data[output_idx] = sum;
231 }
232 }
233 }
234
235 let output_tensor =
236 Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length]).expect("tensor creation failed");
237
238 let requires_grad =
239 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
240
241 if requires_grad {
242 let weight_var = self.weight.variable();
243 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
244
245 let grad_fn = GradFn::new(Conv1dBackward::new(
246 input.grad_fn().cloned(),
247 weight_var.grad_fn().cloned(),
248 bias_grad_fn,
249 input_data,
250 weight_data,
251 input_shape,
252 self.in_channels,
253 self.out_channels,
254 self.kernel_size,
255 self.stride,
256 self.padding,
257 self.bias.is_some(),
258 ));
259 Variable::from_operation(output_tensor, grad_fn, true)
260 } else {
261 Variable::new(output_tensor, false)
262 }
263 }
264
265 fn parameters(&self) -> Vec<Parameter> {
266 let mut params = vec![self.weight.clone()];
267 if let Some(ref bias) = self.bias {
268 params.push(bias.clone());
269 }
270 params
271 }
272
273 fn named_parameters(&self) -> HashMap<String, Parameter> {
274 let mut params = HashMap::new();
275 params.insert("weight".to_string(), self.weight.clone());
276 if let Some(ref bias) = self.bias {
277 params.insert("bias".to_string(), bias.clone());
278 }
279 params
280 }
281
282 fn name(&self) -> &'static str {
283 "Conv1d"
284 }
285}
286
/// A 2-dimensional convolution layer (`[N, C_in, H, W]` -> `[N, C_out, H_out, W_out]`),
/// with optional grouped / depthwise operation.
///
/// The weight is stored with shape `[out_channels, in_channels / groups, kh, kw]`
/// (see `with_groups`); the bias, when present, has shape `[out_channels]`.
pub struct Conv2d {
    /// Learnable filters, shape `[out_channels, in_channels / groups, kh, kw]`.
    pub weight: Parameter,
    /// Optional learnable bias, shape `[out_channels]`.
    pub bias: Option<Parameter>,
    // Number of channels expected in the input (dim 1).
    in_channels: usize,
    // Number of channels produced in the output.
    out_channels: usize,
    // Kernel (height, width).
    kernel_size: (usize, usize),
    // (vertical, horizontal) stride.
    stride: (usize, usize),
    // (vertical, horizontal) zero-padding.
    padding: (usize, usize),
    // Number of channel groups; in/out channels are split evenly across groups.
    groups: usize,
}
316
317impl Conv2d {
318 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
320 Self::with_options(
321 in_channels,
322 out_channels,
323 (kernel_size, kernel_size),
324 (1, 1),
325 (0, 0),
326 true,
327 )
328 }
329
330 pub fn with_options(
332 in_channels: usize,
333 out_channels: usize,
334 kernel_size: (usize, usize),
335 stride: (usize, usize),
336 padding: (usize, usize),
337 bias: bool,
338 ) -> Self {
339 Self::with_groups(
340 in_channels,
341 out_channels,
342 kernel_size,
343 stride,
344 padding,
345 bias,
346 1,
347 )
348 }
349
350 pub fn with_groups(
355 in_channels: usize,
356 out_channels: usize,
357 kernel_size: (usize, usize),
358 stride: (usize, usize),
359 padding: (usize, usize),
360 bias: bool,
361 groups: usize,
362 ) -> Self {
363 assert!(
364 in_channels % groups == 0,
365 "in_channels must be divisible by groups"
366 );
367 assert!(
368 out_channels % groups == 0,
369 "out_channels must be divisible by groups"
370 );
371
372 let (kh, kw) = kernel_size;
373 let in_channels_per_group = in_channels / groups;
374 let fan_in = in_channels_per_group * kh * kw;
375
376 let weight_data = kaiming_uniform(out_channels, fan_in);
377 let weight_reshaped = weight_data
378 .reshape(&[
379 out_channels as isize,
380 in_channels_per_group as isize,
381 kh as isize,
382 kw as isize,
383 ])
384 .unwrap();
385 let weight = Parameter::named("weight", weight_reshaped, true);
386
387 let bias_param = if bias {
388 Some(Parameter::named("bias", zeros(&[out_channels]), true))
389 } else {
390 None
391 };
392
393 Self {
394 weight,
395 bias: bias_param,
396 in_channels,
397 out_channels,
398 kernel_size,
399 stride,
400 padding,
401 groups,
402 }
403 }
404
405 pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
407 Self::with_groups(
408 channels,
409 channels,
410 (kernel_size, kernel_size),
411 (1, 1),
412 (kernel_size / 2, kernel_size / 2),
413 true,
414 channels,
415 )
416 }
417}
418
/// Unfolds an image into column form for GEMM-based convolution.
///
/// `input` is a contiguous `[channels, height, width]` block; the result is a
/// row-major `[channels * kernel_h * kernel_w, out_h * out_w]` matrix in
/// which taps that fall into the zero padding are left at zero.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let rows = channels * kernel_h * kernel_w;
    let cols = out_h * out_w;
    let mut col = vec![0.0f32; rows * cols];

    let plane = height * width;
    let taps = kernel_h * kernel_w;
    let h_lim = height as isize;
    let w_lim = width as isize;
    let ph = pad_h as isize;
    let pw = pad_w as isize;

    // Each output row corresponds to one (channel, kernel-tap) pair.
    for row in 0..rows {
        let chan = row / taps;
        let tap = row % taps;
        let dy = tap / kernel_w;
        let dx = tap % kernel_w;
        let chan_base = chan * plane;
        let row_base = row * cols;

        for oy in 0..out_h {
            let y = (oy * stride_h + dy) as isize - ph;
            if !(0..h_lim).contains(&y) {
                continue; // whole row of taps lies in vertical padding
            }
            let src_row = chan_base + y as usize * width;
            let dst_row = row_base + oy * out_w;

            for ox in 0..out_w {
                let x = (ox * stride_w + dx) as isize - pw;
                if (0..w_lim).contains(&x) {
                    let dst = dst_row + ox;
                    let src = src_row + x as usize;
                    debug_assert!(dst < col.len(), "im2col fwd col OOB: {} >= {}", dst, col.len());
                    debug_assert!(src < input.len(), "im2col fwd input OOB: {} >= {}", src, input.len());
                    // SAFETY: `dst < rows * cols` because `row < rows`,
                    // `oy < out_h`, `ox < out_w`; `src < channels * height * width`
                    // because `chan < channels`, `0 <= y < height`, `0 <= x < width`.
                    // Both facts are checked by the debug_asserts above.
                    unsafe {
                        *col.get_unchecked_mut(dst) = *input.get_unchecked(src);
                    }
                }
            }
        }
    }

    col
}
488
/// CPU implementation of (grouped) 2-D convolution via im2col + GEMM.
///
/// `input` is a flattened `[batch, in_channels, in_height, in_width]` buffer,
/// `weight` a flattened `[out_channels, in_channels / groups, kh, kw]` buffer,
/// and `bias` (when present) holds one value per output channel. Returns the
/// output as a flattened `[batch, out_channels, out_h, out_w]` buffer.
/// Batch elements are processed in parallel with rayon.
fn conv2d_im2col(
    input: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    groups: usize,
) -> Vec<f32> {
    let out_h = (in_height + 2 * ph - kh) / sh + 1;
    let out_w = (in_width + 2 * pw - kw) / sw + 1;
    let in_channels_per_group = in_channels / groups;
    let out_channels_per_group = out_channels / groups;
    // im2col matrix dimensions for one group: [col_h, col_w].
    let col_h = in_channels_per_group * kh * kw;
    let col_w = out_h * out_w;
    let spatial = out_h * out_w;
    let in_spatial = in_height * in_width;

    let out_per_batch = out_channels * spatial;
    // One independent output buffer per batch element, computed in parallel.
    let per_batch: Vec<Vec<f32>> = (0..batch_size)
        .into_par_iter()
        .map(|b| {
            let mut batch_out = vec![0.0f32; out_per_batch];

            for g in 0..groups {
                let ic_start = g * in_channels_per_group;
                let oc_start = g * out_channels_per_group;

                // Contiguous run of input channels belonging to this group.
                let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
                let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];

                let col = im2col(
                    input_slice,
                    in_channels_per_group,
                    in_height,
                    in_width,
                    kh,
                    kw,
                    ph,
                    pw,
                    sh,
                    sw,
                    out_h,
                    out_w,
                );

                // This group's filters viewed as a [out_channels_per_group, col_h] matrix.
                let w_offset = oc_start * in_channels_per_group * kh * kw;
                let w_size = out_channels_per_group * col_h;
                let weight_slice = &weight[w_offset..w_offset + w_size];

                let w_tensor =
                    Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
                        .unwrap();
                let col_tensor = Tensor::from_vec(col, &[col_h, col_w]).expect("tensor creation failed");
                // The convolution reduces to one matrix product per group.
                let result = w_tensor.matmul(&col_tensor).expect("matmul failed");
                let result_vec = result.to_vec();

                // Copy the GEMM rows into the batch output, adding the bias.
                let out_offset = oc_start * spatial;
                for oc_local in 0..out_channels_per_group {
                    let oc = oc_start + oc_local;
                    let bias_val = bias.map_or(0.0, |bv| bv[oc]);
                    let src_start = oc_local * col_w;
                    let dst_start = out_offset + oc_local * spatial;
                    if bias_val == 0.0 {
                        // Zero bias: a straight slice copy suffices.
                        batch_out[dst_start..dst_start + spatial]
                            .copy_from_slice(&result_vec[src_start..src_start + spatial]);
                    } else {
                        for i in 0..spatial {
                            batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
                        }
                    }
                }
            }

            batch_out
        })
        .collect();

    // Concatenate the per-batch buffers in batch order.
    let mut output = Vec::with_capacity(batch_size * out_per_batch);
    for batch_out in per_batch {
        output.extend_from_slice(&batch_out);
    }
    output
}
589
590impl Module for Conv2d {
591 fn forward(&self, input: &Variable) -> Variable {
592 let input_shape = input.shape();
593 let batch_size = input_shape[0];
594 let in_height = input_shape[2];
595 let in_width = input_shape[3];
596
597 let (kh, kw) = self.kernel_size;
598 let (sh, sw) = self.stride;
599 let (ph, pw) = self.padding;
600
601 let out_height = (in_height + 2 * ph - kh) / sh + 1;
602 let out_width = (in_width + 2 * pw - kw) / sw + 1;
603
604 let input_data = input.data();
605 let weight_data = self.weight.data();
606
607 #[cfg(feature = "cuda")]
610 if input_data.device().is_gpu() {
611 let input_dev = input_data.device();
613 if !weight_data.device().is_gpu() {
614 self.weight.to_device(input_dev);
615 if let Some(ref b) = self.bias {
616 b.to_device(input_dev);
617 }
618 }
619 let weight_data = self.weight.data();
620
621 #[cfg(feature = "cudnn")]
623 let cudnn_output = {
624 let bias_tensor = self.bias.as_ref().map(|b| b.data());
625 input_data.conv2d_cudnn(
626 &weight_data,
627 bias_tensor.as_ref(),
628 self.stride,
629 self.padding,
630 self.groups,
631 )
632 };
633 #[cfg(not(feature = "cudnn"))]
634 let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;
635
636 let gpu_output = if cudnn_output.is_some() {
637 cudnn_output
638 } else if self.groups == 1 {
639 let bias_tensor = self.bias.as_ref().map(|b| b.data());
641 input_data.conv2d_cuda(
642 &weight_data,
643 bias_tensor.as_ref(),
644 self.stride,
645 self.padding,
646 )
647 } else {
648 input_data.conv2d_grouped_cuda(
650 &weight_data,
651 self.bias.as_ref().map(|b| b.data()).as_ref(),
652 self.stride,
653 self.padding,
654 self.groups,
655 )
656 };
657
658 if let Some(output_tensor) = gpu_output {
659 let requires_grad =
660 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
661 if requires_grad {
662 let weight_var = self.weight.variable();
663 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
664 if self.groups == 1 {
665 let grad_fn = GradFn::new(Conv2dBackward::new(
666 input.grad_fn().cloned(),
667 weight_var.grad_fn().cloned(),
668 bias_grad_fn,
669 input_data,
670 weight_data,
671 input_shape,
672 self.in_channels,
673 self.out_channels,
674 self.kernel_size,
675 self.stride,
676 self.padding,
677 self.bias.is_some(),
678 ));
679 return Variable::from_operation(output_tensor, grad_fn, true);
680 } else {
681 let grad_fn = GradFn::new(GroupedConv2dBackward::new(
682 input.grad_fn().cloned(),
683 weight_var.grad_fn().cloned(),
684 bias_grad_fn,
685 input_data,
686 weight_data,
687 input_shape,
688 self.in_channels,
689 self.out_channels,
690 self.kernel_size,
691 self.stride,
692 self.padding,
693 self.groups,
694 self.bias.is_some(),
695 ));
696 return Variable::from_operation(output_tensor, grad_fn, true);
697 }
698 } else {
699 return Variable::new(output_tensor, false);
700 }
701 }
702 }
704
705 let input_vec = input_data.to_vec();
706 let weight_vec = weight_data.to_vec();
707
708 let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
710 let output_data = if self.groups == 1 && conv_flops >= 500_000 {
711 let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
712 let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
713 &input_vec,
714 &weight_vec,
715 bias_vec.as_deref(),
716 batch_size,
717 self.in_channels,
718 in_height,
719 in_width,
720 self.out_channels,
721 kh,
722 kw,
723 sh,
724 sw,
725 ph,
726 pw,
727 );
728
729 if let Some(result) = gpu_result {
730 result
731 } else {
732 conv2d_im2col(
733 &input_vec,
734 &weight_vec,
735 self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
736 batch_size,
737 self.in_channels,
738 in_height,
739 in_width,
740 self.out_channels,
741 kh,
742 kw,
743 sh,
744 sw,
745 ph,
746 pw,
747 self.groups,
748 )
749 }
750 } else {
751 conv2d_im2col(
752 &input_vec,
753 &weight_vec,
754 self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
755 batch_size,
756 self.in_channels,
757 in_height,
758 in_width,
759 self.out_channels,
760 kh,
761 kw,
762 sh,
763 sw,
764 ph,
765 pw,
766 self.groups,
767 )
768 };
769
770 let output_tensor = Tensor::from_vec(
771 output_data,
772 &[batch_size, self.out_channels, out_height, out_width],
773 )
774 .unwrap();
775
776 let requires_grad =
777 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
778
779 if requires_grad && self.groups == 1 {
780 let weight_var = self.weight.variable();
782 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
783
784 let grad_fn = GradFn::new(Conv2dBackward::new(
785 input.grad_fn().cloned(),
786 weight_var.grad_fn().cloned(),
787 bias_grad_fn,
788 input_data,
789 weight_data,
790 input_shape,
791 self.in_channels,
792 self.out_channels,
793 self.kernel_size,
794 self.stride,
795 self.padding,
796 self.bias.is_some(),
797 ));
798 Variable::from_operation(output_tensor, grad_fn, true)
799 } else if requires_grad {
800 let weight_var = self.weight.variable();
802 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
803
804 let grad_fn = GradFn::new(GroupedConv2dBackward::new(
805 input.grad_fn().cloned(),
806 weight_var.grad_fn().cloned(),
807 bias_grad_fn,
808 input_data,
809 weight_data,
810 input_shape,
811 self.in_channels,
812 self.out_channels,
813 self.kernel_size,
814 self.stride,
815 self.padding,
816 self.groups,
817 self.bias.is_some(),
818 ));
819 Variable::from_operation(output_tensor, grad_fn, true)
820 } else {
821 Variable::new(output_tensor, false)
822 }
823 }
824
825 fn parameters(&self) -> Vec<Parameter> {
826 let mut params = vec![self.weight.clone()];
827 if let Some(ref bias) = self.bias {
828 params.push(bias.clone());
829 }
830 params
831 }
832
833 fn named_parameters(&self) -> HashMap<String, Parameter> {
834 let mut params = HashMap::new();
835 params.insert("weight".to_string(), self.weight.clone());
836 if let Some(ref bias) = self.bias {
837 params.insert("bias".to_string(), bias.clone());
838 }
839 params
840 }
841
842 fn name(&self) -> &'static str {
843 "Conv2d"
844 }
845}
846
/// A 2-dimensional transposed ("deconvolution") layer that upsamples a
/// `[N, C_in, H, W]` input to `[N, C_out, H_out, W_out]`.
///
/// The weight is stored with shape `[in_channels, out_channels, kh, kw]`
/// (see `with_options`); the bias, when present, has shape `[out_channels]`.
pub struct ConvTranspose2d {
    /// Learnable filters, shape `[in_channels, out_channels, kh, kw]`.
    pub weight: Parameter,
    /// Optional learnable bias, shape `[out_channels]`.
    pub bias: Option<Parameter>,
    // Number of channels expected in the input (dim 1).
    in_channels: usize,
    // Number of channels produced in the output.
    out_channels: usize,
    // Kernel (height, width).
    kernel_size: (usize, usize),
    // (vertical, horizontal) stride of the underlying forward convolution.
    stride: (usize, usize),
    // (vertical, horizontal) padding of the underlying forward convolution.
    padding: (usize, usize),
    // Extra size added to one side of each output dimension (disambiguates
    // the output size when stride > 1).
    output_padding: (usize, usize),
}
870
871impl ConvTranspose2d {
872 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
874 Self::with_options(
875 in_channels,
876 out_channels,
877 (kernel_size, kernel_size),
878 (1, 1),
879 (0, 0),
880 (0, 0),
881 true,
882 )
883 }
884
885 pub fn with_options(
887 in_channels: usize,
888 out_channels: usize,
889 kernel_size: (usize, usize),
890 stride: (usize, usize),
891 padding: (usize, usize),
892 output_padding: (usize, usize),
893 bias: bool,
894 ) -> Self {
895 let (kh, kw) = kernel_size;
896 let fan_in = in_channels * kh * kw;
897
898 let weight_data = kaiming_uniform(out_channels, fan_in);
899 let weight_reshaped = weight_data
900 .reshape(&[
901 in_channels as isize,
902 out_channels as isize,
903 kh as isize,
904 kw as isize,
905 ])
906 .unwrap();
907 let weight = Parameter::named("weight", weight_reshaped, true);
908
909 let bias_param = if bias {
910 Some(Parameter::named("bias", zeros(&[out_channels]), true))
911 } else {
912 None
913 };
914
915 Self {
916 weight,
917 bias: bias_param,
918 in_channels,
919 out_channels,
920 kernel_size,
921 stride,
922 padding,
923 output_padding,
924 }
925 }
926}
927
928impl Module for ConvTranspose2d {
929 fn forward(&self, input: &Variable) -> Variable {
930 let input_shape = input.shape();
931 let batch_size = input_shape[0];
932 let in_h = input_shape[2];
933 let in_w = input_shape[3];
934
935 let (kh, kw) = self.kernel_size;
936 let (sh, sw) = self.stride;
937 let (ph, pw) = self.padding;
938 let (oph, opw) = self.output_padding;
939
940 let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
941 let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;
942
943 let input_data = input.data();
944 let weight_data = self.weight.data();
945 let input_vec = input_data.to_vec();
946 let weight_vec = weight_data.to_vec();
947
948 let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];
949
950 for b in 0..batch_size {
952 for ic in 0..self.in_channels {
953 for ih in 0..in_h {
954 for iw in 0..in_w {
955 let in_idx =
956 b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
957 let in_val = input_vec[in_idx];
958
959 for oc in 0..self.out_channels {
960 for ki in 0..kh {
961 for kj in 0..kw {
962 let oh_signed = (ih * sh + ki) as isize - ph as isize;
963 let ow_signed = (iw * sw + kj) as isize - pw as isize;
964
965 if oh_signed >= 0
966 && (oh_signed as usize) < out_h
967 && ow_signed >= 0
968 && (ow_signed as usize) < out_w
969 {
970 let oh = oh_signed as usize;
971 let ow = ow_signed as usize;
972 let out_idx = b * self.out_channels * out_h * out_w
973 + oc * out_h * out_w
974 + oh * out_w
975 + ow;
976 let w_idx = ic * self.out_channels * kh * kw
978 + oc * kh * kw
979 + ki * kw
980 + kj;
981 output_data[out_idx] += in_val * weight_vec[w_idx];
982 }
983 }
984 }
985 }
986 }
987 }
988 }
989 }
990
991 if let Some(ref bias) = self.bias {
993 let bias_vec = bias.data().to_vec();
994 for b in 0..batch_size {
995 for oc in 0..self.out_channels {
996 for oh in 0..out_h {
997 for ow in 0..out_w {
998 let out_idx = b * self.out_channels * out_h * out_w
999 + oc * out_h * out_w
1000 + oh * out_w
1001 + ow;
1002 output_data[out_idx] += bias_vec[oc];
1003 }
1004 }
1005 }
1006 }
1007 }
1008
1009 let output_tensor =
1010 Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w]).expect("tensor creation failed");
1011
1012 let requires_grad =
1013 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
1014
1015 if requires_grad {
1016 let weight_var = self.weight.variable();
1017 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
1018
1019 let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
1020 input.grad_fn().cloned(),
1021 weight_var.grad_fn().cloned(),
1022 bias_grad_fn,
1023 input_data,
1024 weight_data,
1025 input_shape,
1026 self.in_channels,
1027 self.out_channels,
1028 self.kernel_size,
1029 self.stride,
1030 self.padding,
1031 self.output_padding,
1032 self.bias.is_some(),
1033 ));
1034 Variable::from_operation(output_tensor, grad_fn, true)
1035 } else {
1036 Variable::new(output_tensor, false)
1037 }
1038 }
1039
1040 fn parameters(&self) -> Vec<Parameter> {
1041 let mut params = vec![self.weight.clone()];
1042 if let Some(ref bias) = self.bias {
1043 params.push(bias.clone());
1044 }
1045 params
1046 }
1047
1048 fn named_parameters(&self) -> HashMap<String, Parameter> {
1049 let mut params = HashMap::new();
1050 params.insert("weight".to_string(), self.weight.clone());
1051 if let Some(ref bias) = self.bias {
1052 params.insert("bias".to_string(), bias.clone());
1053 }
1054 params
1055 }
1056
1057 fn name(&self) -> &'static str {
1058 "ConvTranspose2d"
1059 }
1060}
1061
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv1d_creation() {
        let layer = Conv1d::new(3, 16, 3);
        assert_eq!(layer.in_channels, 3);
        assert_eq!(layer.out_channels, 16);
        assert_eq!(layer.kernel_size, 3);
    }

    #[test]
    fn test_conv1d_forward() {
        let layer = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            false,
        );
        // Padding 1 with a 3-wide kernel and stride 1 preserves the length.
        assert_eq!(layer.forward(&x).shape(), vec![1, 1, 5]);
    }

    #[test]
    fn test_conv1d_backward() {
        let layer = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            true,
        );
        layer.forward(&x).sum().backward();

        assert!(
            x.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        assert_eq!(x.grad().unwrap().shape(), &[1, 1, 5]);
    }

    #[test]
    fn test_conv2d_creation() {
        let layer = Conv2d::new(3, 64, 3);
        assert_eq!(layer.in_channels, 3);
        assert_eq!(layer.out_channels, 64);
        assert_eq!(layer.kernel_size, (3, 3));
    }

    #[test]
    fn test_conv2d_forward() {
        let layer = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            false,
        );
        assert_eq!(layer.forward(&x).shape(), vec![1, 1, 5, 5]);
    }

    #[test]
    fn test_conv2d_backward() {
        let layer = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            true,
        );
        layer.forward(&x).sum().backward();

        assert!(
            x.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        assert_eq!(x.grad().unwrap().shape(), &[1, 1, 5, 5]);

        assert!(
            layer.weight.grad().is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }

    #[test]
    fn test_conv2d_parameters() {
        // weight + bias
        assert_eq!(Conv2d::new(3, 64, 3).parameters().len(), 2);
    }

    #[test]
    fn test_conv2d_grouped() {
        let layer = Conv2d::depthwise(4, 3);
        assert_eq!(layer.groups, 4);
        assert_eq!(layer.in_channels, 4);
        assert_eq!(layer.out_channels, 4);

        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).expect("tensor creation failed"),
            false,
        );
        assert_eq!(layer.forward(&x).shape(), vec![1, 4, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_forward() {
        let layer = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).expect("tensor creation failed"),
            false,
        );
        // (2 - 1) * 2 - 2 * 1 + 3 + 1 = 4 along each spatial dim.
        assert_eq!(layer.forward(&x).shape(), vec![1, 1, 4, 4]);
    }

    #[test]
    fn test_conv_transpose2d_backward() {
        let layer = ConvTranspose2d::new(1, 1, 3);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).expect("tensor creation failed"),
            true,
        );
        layer.forward(&x).sum().backward();

        assert!(
            x.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }
}