1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_autograd::functions::{
21 Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
22};
23use axonml_autograd::grad_fn::GradFn;
24use axonml_autograd::no_grad::is_grad_enabled;
25use axonml_tensor::Tensor;
26use rayon::prelude::*;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
/// 1D convolution layer over inputs of shape `[batch, in_channels, length]`.
pub struct Conv1d {
    /// Learnable filters, shape `[out_channels, in_channels, kernel_size]`.
    pub weight: Parameter,
    /// Optional learnable bias, shape `[out_channels]`; `None` when disabled.
    pub bias: Option<Parameter>,
    // Number of channels expected in the input tensor.
    in_channels: usize,
    // Number of channels produced in the output tensor.
    out_channels: usize,
    // Width of each 1D filter tap window.
    kernel_size: usize,
    // Step between successive filter applications.
    stride: usize,
    // Implicit zero padding added to both ends of the length axis.
    padding: usize,
}
59
60impl Conv1d {
61 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
63 Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
64 }
65
66 pub fn with_options(
68 in_channels: usize,
69 out_channels: usize,
70 kernel_size: usize,
71 stride: usize,
72 padding: usize,
73 bias: bool,
74 ) -> Self {
75 let fan_in = in_channels * kernel_size;
77 let weight_data = kaiming_uniform(out_channels, fan_in);
78 let weight_reshaped = weight_data
79 .reshape(&[
80 out_channels as isize,
81 in_channels as isize,
82 kernel_size as isize,
83 ])
84 .unwrap();
85 let weight = Parameter::named("weight", weight_reshaped, true);
86
87 let bias_param = if bias {
88 Some(Parameter::named("bias", zeros(&[out_channels]), true))
89 } else {
90 None
91 };
92
93 Self {
94 weight,
95 bias: bias_param,
96 in_channels,
97 out_channels,
98 kernel_size,
99 stride,
100 padding,
101 }
102 }
103}
104
105impl Module for Conv1d {
106 fn forward(&self, input: &Variable) -> Variable {
107 let input_shape = input.shape();
108 let batch_size = input_shape[0];
109 let in_length = input_shape[2];
110
111 let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113 let input_data = input.data();
114 let weight_data = self.weight.data();
115
116 #[cfg(feature = "cuda")]
119 if input_data.device().is_gpu() {
120 let input_dev = input_data.device();
122 if !weight_data.device().is_gpu() {
123 self.weight.to_device(input_dev);
124 if let Some(ref b) = self.bias {
125 b.to_device(input_dev);
126 }
127 }
128 let weight_data = self.weight.data();
129
130 let input_4d = input_data
132 .reshape(&[
133 batch_size as isize,
134 self.in_channels as isize,
135 in_length as isize,
136 1,
137 ])
138 .unwrap();
139
140 let weight_4d = weight_data
142 .reshape(&[
143 self.out_channels as isize,
144 self.in_channels as isize,
145 self.kernel_size as isize,
146 1,
147 ])
148 .unwrap();
149
150 let bias_tensor = self.bias.as_ref().map(|b| b.data());
151 let gpu_output = input_4d.conv2d_cuda(
152 &weight_4d,
153 bias_tensor.as_ref(),
154 (self.stride, 1),
155 (self.padding, 0),
156 );
157
158 if let Some(output_4d) = gpu_output {
159 let output_tensor = output_4d
161 .reshape(&[
162 batch_size as isize,
163 self.out_channels as isize,
164 out_length as isize,
165 ])
166 .unwrap();
167
168 let requires_grad =
169 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
170 if requires_grad {
171 let weight_var = self.weight.variable();
172 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
173
174 let grad_fn = GradFn::new(Conv1dBackward::new(
175 input.grad_fn().cloned(),
176 weight_var.grad_fn().cloned(),
177 bias_grad_fn,
178 input_data,
179 weight_data,
180 input_shape,
181 self.in_channels,
182 self.out_channels,
183 self.kernel_size,
184 self.stride,
185 self.padding,
186 self.bias.is_some(),
187 ));
188 return Variable::from_operation(output_tensor, grad_fn, true);
189 } else {
190 return Variable::new(output_tensor, false);
191 }
192 }
193 }
195
196 let input_vec = input_data.to_vec();
197 let weight_vec = weight_data.to_vec();
198
199 let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
200
201 for b in 0..batch_size {
202 for oc in 0..self.out_channels {
203 for ol in 0..out_length {
204 let mut sum = 0.0f32;
205 let in_start = ol * self.stride;
206
207 for ic in 0..self.in_channels {
208 for k in 0..self.kernel_size {
209 let in_idx = in_start + k;
210 if in_idx < self.padding || in_idx >= in_length + self.padding {
211 continue;
212 }
213 let actual_idx = in_idx - self.padding;
214
215 let input_idx =
216 b * self.in_channels * in_length + ic * in_length + actual_idx;
217 let weight_idx = oc * self.in_channels * self.kernel_size
218 + ic * self.kernel_size
219 + k;
220
221 sum += input_vec[input_idx] * weight_vec[weight_idx];
222 }
223 }
224
225 if let Some(ref bias) = self.bias {
226 sum += bias.data().to_vec()[oc];
227 }
228
229 let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
230 output_data[output_idx] = sum;
231 }
232 }
233 }
234
235 let output_tensor =
236 Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length]).unwrap();
237
238 let requires_grad =
239 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
240
241 if requires_grad {
242 let weight_var = self.weight.variable();
243 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
244
245 let grad_fn = GradFn::new(Conv1dBackward::new(
246 input.grad_fn().cloned(),
247 weight_var.grad_fn().cloned(),
248 bias_grad_fn,
249 input_data,
250 weight_data,
251 input_shape,
252 self.in_channels,
253 self.out_channels,
254 self.kernel_size,
255 self.stride,
256 self.padding,
257 self.bias.is_some(),
258 ));
259 Variable::from_operation(output_tensor, grad_fn, true)
260 } else {
261 Variable::new(output_tensor, false)
262 }
263 }
264
265 fn parameters(&self) -> Vec<Parameter> {
266 let mut params = vec![self.weight.clone()];
267 if let Some(ref bias) = self.bias {
268 params.push(bias.clone());
269 }
270 params
271 }
272
273 fn named_parameters(&self) -> HashMap<String, Parameter> {
274 let mut params = HashMap::new();
275 params.insert("weight".to_string(), self.weight.clone());
276 if let Some(ref bias) = self.bias {
277 params.insert("bias".to_string(), bias.clone());
278 }
279 params
280 }
281
282 fn name(&self) -> &'static str {
283 "Conv1d"
284 }
285}
286
/// 2D convolution layer over inputs of shape `[batch, in_channels, h, w]`,
/// with optional grouped/depthwise operation.
pub struct Conv2d {
    /// Learnable filters, shape `[out_channels, in_channels / groups, kh, kw]`.
    pub weight: Parameter,
    /// Optional learnable bias, shape `[out_channels]`; `None` when disabled.
    pub bias: Option<Parameter>,
    // Number of channels expected in the input tensor.
    in_channels: usize,
    // Number of channels produced in the output tensor.
    out_channels: usize,
    // Filter window size as (height, width).
    kernel_size: (usize, usize),
    // Step between filter applications as (vertical, horizontal).
    stride: (usize, usize),
    // Implicit zero padding as (top/bottom, left/right).
    padding: (usize, usize),
    // Channel groups; 1 = dense conv, in_channels = depthwise.
    groups: usize,
}
316
317impl Conv2d {
318 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
320 Self::with_options(
321 in_channels,
322 out_channels,
323 (kernel_size, kernel_size),
324 (1, 1),
325 (0, 0),
326 true,
327 )
328 }
329
330 pub fn with_options(
332 in_channels: usize,
333 out_channels: usize,
334 kernel_size: (usize, usize),
335 stride: (usize, usize),
336 padding: (usize, usize),
337 bias: bool,
338 ) -> Self {
339 Self::with_groups(
340 in_channels,
341 out_channels,
342 kernel_size,
343 stride,
344 padding,
345 bias,
346 1,
347 )
348 }
349
350 pub fn with_groups(
355 in_channels: usize,
356 out_channels: usize,
357 kernel_size: (usize, usize),
358 stride: (usize, usize),
359 padding: (usize, usize),
360 bias: bool,
361 groups: usize,
362 ) -> Self {
363 assert!(
364 in_channels % groups == 0,
365 "in_channels must be divisible by groups"
366 );
367 assert!(
368 out_channels % groups == 0,
369 "out_channels must be divisible by groups"
370 );
371
372 let (kh, kw) = kernel_size;
373 let in_channels_per_group = in_channels / groups;
374 let fan_in = in_channels_per_group * kh * kw;
375
376 let weight_data = kaiming_uniform(out_channels, fan_in);
377 let weight_reshaped = weight_data
378 .reshape(&[
379 out_channels as isize,
380 in_channels_per_group as isize,
381 kh as isize,
382 kw as isize,
383 ])
384 .unwrap();
385 let weight = Parameter::named("weight", weight_reshaped, true);
386
387 let bias_param = if bias {
388 Some(Parameter::named("bias", zeros(&[out_channels]), true))
389 } else {
390 None
391 };
392
393 Self {
394 weight,
395 bias: bias_param,
396 in_channels,
397 out_channels,
398 kernel_size,
399 stride,
400 padding,
401 groups,
402 }
403 }
404
405 pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
407 Self::with_groups(
408 channels,
409 channels,
410 (kernel_size, kernel_size),
411 (1, 1),
412 (kernel_size / 2, kernel_size / 2),
413 true,
414 channels,
415 )
416 }
417}
418
/// Unrolls one image (CHW layout, `channels * height * width` elements) into a
/// column matrix of shape `[channels * kernel_h * kernel_w, out_h * out_w]`
/// (row-major), so the convolution can be computed as a single matrix multiply.
///
/// Taps that fall into the zero-padding region are left at the `0.0` the
/// buffer is initialised with.
///
/// Precondition: `input.len() >= channels * height * width` (checked in debug
/// builds; the unsafe indexing below relies on it).
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let col_h = channels * kernel_h * kernel_w;
    let col_w = out_h * out_w;
    let mut col = vec![0.0f32; col_h * col_w];
    let hw = height * width;
    let kk = kernel_h * kernel_w;
    let h_signed = height as isize;
    let w_signed = width as isize;
    let pad_h_s = pad_h as isize;
    let pad_w_s = pad_w as isize;

    // Cheap guard for the unchecked reads below; free in release builds.
    debug_assert!(input.len() >= channels * hw);

    for col_row in 0..col_h {
        // Decompose the row index into (channel, kernel row, kernel col).
        let c = col_row / kk;
        let k_idx = col_row % kk;
        let kh_off = k_idx / kernel_w;
        let kw_off = k_idx % kernel_w;
        let input_c = c * hw;
        let col_base = col_row * col_w;

        for oh in 0..out_h {
            let h_in = (oh * stride_h + kh_off) as isize - pad_h_s;
            if h_in < 0 || h_in >= h_signed {
                continue; // whole output row reads from vertical padding
            }
            let input_row = input_c + h_in as usize * width;
            let col_row_base = col_base + oh * out_w;

            for ow in 0..out_w {
                let w_in = (ow * stride_w + kw_off) as isize - pad_w_s;
                if w_in >= 0 && w_in < w_signed {
                    // SAFETY: col_row_base + ow = col_row * col_w + oh * out_w + ow
                    // with col_row < col_h, oh < out_h, ow < out_w, so it is
                    // < col_h * col_w = col.len(). input_row + w_in =
                    // c * hw + h_in * width + w_in with c < channels,
                    // 0 <= h_in < height, 0 <= w_in < width, so it is
                    // < channels * hw <= input.len() (documented precondition,
                    // checked by the debug_assert above).
                    unsafe {
                        *col.get_unchecked_mut(col_row_base + ow) =
                            *input.get_unchecked(input_row + w_in as usize);
                    }
                }
            }
        }
    }

    col
}
484
485fn conv2d_im2col(
487 input: &[f32],
488 weight: &[f32],
489 bias: Option<&[f32]>,
490 batch_size: usize,
491 in_channels: usize,
492 in_height: usize,
493 in_width: usize,
494 out_channels: usize,
495 kh: usize,
496 kw: usize,
497 sh: usize,
498 sw: usize,
499 ph: usize,
500 pw: usize,
501 groups: usize,
502) -> Vec<f32> {
503 let out_h = (in_height + 2 * ph - kh) / sh + 1;
504 let out_w = (in_width + 2 * pw - kw) / sw + 1;
505 let in_channels_per_group = in_channels / groups;
506 let out_channels_per_group = out_channels / groups;
507 let col_h = in_channels_per_group * kh * kw;
508 let col_w = out_h * out_w;
509 let spatial = out_h * out_w;
510 let in_spatial = in_height * in_width;
511
512 let out_per_batch = out_channels * spatial;
514 let per_batch: Vec<Vec<f32>> = (0..batch_size)
515 .into_par_iter()
516 .map(|b| {
517 let mut batch_out = vec![0.0f32; out_per_batch];
518
519 for g in 0..groups {
520 let ic_start = g * in_channels_per_group;
521 let oc_start = g * out_channels_per_group;
522
523 let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
525 let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];
526
527 let col = im2col(
529 input_slice,
530 in_channels_per_group,
531 in_height,
532 in_width,
533 kh,
534 kw,
535 ph,
536 pw,
537 sh,
538 sw,
539 out_h,
540 out_w,
541 );
542
543 let w_offset = oc_start * in_channels_per_group * kh * kw;
545 let w_size = out_channels_per_group * col_h;
546 let weight_slice = &weight[w_offset..w_offset + w_size];
547
548 let w_tensor =
550 Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
551 .unwrap();
552 let col_tensor = Tensor::from_vec(col, &[col_h, col_w]).unwrap();
553 let result = w_tensor.matmul(&col_tensor).unwrap();
554 let result_vec = result.to_vec();
555
556 let out_offset = oc_start * spatial;
558 for oc_local in 0..out_channels_per_group {
559 let oc = oc_start + oc_local;
560 let bias_val = bias.map_or(0.0, |bv| bv[oc]);
561 let src_start = oc_local * col_w;
562 let dst_start = out_offset + oc_local * spatial;
563 if bias_val == 0.0 {
564 batch_out[dst_start..dst_start + spatial]
565 .copy_from_slice(&result_vec[src_start..src_start + spatial]);
566 } else {
567 for i in 0..spatial {
568 batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
569 }
570 }
571 }
572 }
573
574 batch_out
575 })
576 .collect();
577
578 let mut output = Vec::with_capacity(batch_size * out_per_batch);
580 for batch_out in per_batch {
581 output.extend_from_slice(&batch_out);
582 }
583 output
584}
585
impl Module for Conv2d {
    /// Applies the convolution to `input` (NCHW), returning NCHW output of
    /// shape `[batch, out_channels, out_h, out_w]`.
    ///
    /// Dispatch order when the input is on a GPU (with `cuda` compiled in):
    /// cuDNN (if the `cudnn` feature is enabled) → custom CUDA kernels
    /// (grouped or dense) → fall through to the CPU path when none applies.
    /// On CPU, large single-group workloads first try a core GPU helper,
    /// otherwise the rayon-parallel im2col path is used.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_height = input_shape[2];
        let in_width = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard convolution output-size formula (floor division).
        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;

        let input_data = input.data();
        let weight_data = self.weight.data();

        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let input_dev = input_data.device();
            // Lazily migrate parameters to the input's device on first GPU use.
            if !weight_data.device().is_gpu() {
                self.weight.to_device(input_dev);
                if let Some(ref b) = self.bias {
                    b.to_device(input_dev);
                }
            }
            // Re-read after the possible device migration (shadows the outer binding).
            let weight_data = self.weight.data();

            // Preferred backend: cuDNN handles all group counts directly.
            #[cfg(feature = "cudnn")]
            let cudnn_output = {
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cudnn(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };
            #[cfg(not(feature = "cudnn"))]
            let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;

            // Fallback chain: cuDNN result, else the dense CUDA kernel
            // (groups == 1), else the grouped CUDA kernel.
            let gpu_output = if cudnn_output.is_some() {
                cudnn_output
            } else if self.groups == 1 {
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cuda(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                )
            } else {
                input_data.conv2d_grouped_cuda(
                    &weight_data,
                    self.bias.as_ref().map(|b| b.data()).as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };

            // If every GPU backend declined, fall through to the CPU path below.
            if let Some(output_tensor) = gpu_output {
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
                    // Dense and grouped convolutions use different backward ops.
                    if self.groups == 1 {
                        let grad_fn = GradFn::new(Conv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    } else {
                        let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.groups,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    }
                } else {
                    return Variable::new(output_tensor, false);
                }
            }
        }

        // ---- CPU path ----
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        // Rough per-batch work estimate; above the threshold it is worth
        // trying to offload the CPU-resident tensors to the core GPU helper.
        // NOTE(review): this helper is invoked even without the `cuda`
        // feature — presumably it returns None when CUDA is unavailable;
        // confirm against axonml_core.
        let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
        let output_data = if self.groups == 1 && conv_flops >= 500_000 {
            let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
            let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
                &input_vec,
                &weight_vec,
                bias_vec.as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
            );

            if let Some(result) = gpu_result {
                result
            } else {
                // Offload declined: fall back to the parallel im2col path.
                conv2d_im2col(
                    &input_vec,
                    &weight_vec,
                    self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                    batch_size,
                    self.in_channels,
                    in_height,
                    in_width,
                    self.out_channels,
                    kh,
                    kw,
                    sh,
                    sw,
                    ph,
                    pw,
                    self.groups,
                )
            }
        } else {
            conv2d_im2col(
                &input_vec,
                &weight_vec,
                self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                self.groups,
            )
        };

        let output_tensor = Tensor::from_vec(
            output_data,
            &[batch_size, self.out_channels, out_height, out_width],
        )
        .unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        // Dense and grouped convolutions use different backward ops.
        if requires_grad && self.groups == 1 {
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(Conv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else if requires_grad {
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.groups,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the weight followed by the bias (when present).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed `"weight"` / `"bias"`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv2d"
    }
}
842
/// 2D transposed convolution ("deconvolution") layer; upsamples inputs of
/// shape `[batch, in_channels, h, w]`.
pub struct ConvTranspose2d {
    /// Learnable filters, shape `[in_channels, out_channels, kh, kw]`
    /// (input-channel-major — the transposed-convolution convention).
    pub weight: Parameter,
    /// Optional learnable bias, shape `[out_channels]`; `None` when disabled.
    pub bias: Option<Parameter>,
    // Number of channels expected in the input tensor.
    in_channels: usize,
    // Number of channels produced in the output tensor.
    out_channels: usize,
    // Filter window size as (height, width).
    kernel_size: (usize, usize),
    // Upsampling stride as (vertical, horizontal).
    stride: (usize, usize),
    // Padding removed from the output edges as (vertical, horizontal).
    padding: (usize, usize),
    // Extra rows/cols appended to the output to disambiguate output size.
    output_padding: (usize, usize),
}
866
867impl ConvTranspose2d {
868 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
870 Self::with_options(
871 in_channels,
872 out_channels,
873 (kernel_size, kernel_size),
874 (1, 1),
875 (0, 0),
876 (0, 0),
877 true,
878 )
879 }
880
881 pub fn with_options(
883 in_channels: usize,
884 out_channels: usize,
885 kernel_size: (usize, usize),
886 stride: (usize, usize),
887 padding: (usize, usize),
888 output_padding: (usize, usize),
889 bias: bool,
890 ) -> Self {
891 let (kh, kw) = kernel_size;
892 let fan_in = in_channels * kh * kw;
893
894 let weight_data = kaiming_uniform(out_channels, fan_in);
895 let weight_reshaped = weight_data
896 .reshape(&[
897 in_channels as isize,
898 out_channels as isize,
899 kh as isize,
900 kw as isize,
901 ])
902 .unwrap();
903 let weight = Parameter::named("weight", weight_reshaped, true);
904
905 let bias_param = if bias {
906 Some(Parameter::named("bias", zeros(&[out_channels]), true))
907 } else {
908 None
909 };
910
911 Self {
912 weight,
913 bias: bias_param,
914 in_channels,
915 out_channels,
916 kernel_size,
917 stride,
918 padding,
919 output_padding,
920 }
921 }
922}
923
impl Module for ConvTranspose2d {
    /// Applies the transposed convolution to `input` (NCHW), producing
    /// `[batch, out_channels, out_h, out_w]` where
    /// `out = (in - 1) * stride - 2 * padding + kernel + output_padding`.
    ///
    /// Implemented as a scatter: each input element distributes its value,
    /// weighted by the kernel, into the output window it maps to (the adjoint
    /// of the forward-convolution gather). CPU-only.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_h = input_shape[2];
        let in_w = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;

        // Transposed-convolution output-size formula.
        let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
        let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;

        let input_data = input.data();
        let weight_data = self.weight.data();
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];

        // Scatter-accumulate each input element into its output window.
        for b in 0..batch_size {
            for ic in 0..self.in_channels {
                for ih in 0..in_h {
                    for iw in 0..in_w {
                        let in_idx =
                            b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
                        let in_val = input_vec[in_idx];

                        for oc in 0..self.out_channels {
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Target output coordinate; may land in the
                                    // cropped `padding` border (then skipped).
                                    let oh_signed = (ih * sh + ki) as isize - ph as isize;
                                    let ow_signed = (iw * sw + kj) as isize - pw as isize;

                                    if oh_signed >= 0
                                        && (oh_signed as usize) < out_h
                                        && ow_signed >= 0
                                        && (ow_signed as usize) < out_w
                                    {
                                        let oh = oh_signed as usize;
                                        let ow = ow_signed as usize;
                                        let out_idx = b * self.out_channels * out_h * out_w
                                            + oc * out_h * out_w
                                            + oh * out_w
                                            + ow;
                                        // Weight layout is [in, out, kh, kw].
                                        let w_idx = ic * self.out_channels * kh * kw
                                            + oc * kh * kw
                                            + ki * kw
                                            + kj;
                                        output_data[out_idx] += in_val * weight_vec[w_idx];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Add the per-output-channel bias in a separate pass.
        if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            for b in 0..batch_size {
                for oc in 0..self.out_channels {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let out_idx = b * self.out_channels * out_h * out_w
                                + oc * out_h * out_w
                                + oh * out_w
                                + ow;
                            output_data[out_idx] += bias_vec[oc];
                        }
                    }
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w]).unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.output_padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the weight followed by the bias (when present).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed `"weight"` / `"bias"`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "ConvTranspose2d"
    }
}
1057
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv1d_creation() {
        let layer = Conv1d::new(3, 16, 3);
        assert_eq!(layer.in_channels, 3);
        assert_eq!(layer.out_channels, 16);
        assert_eq!(layer.kernel_size, 3);
    }

    #[test]
    fn test_conv1d_forward() {
        // kernel 3, stride 1, padding 1 => "same" output length.
        let layer = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![1, 1, 5]);
    }

    #[test]
    fn test_conv1d_backward() {
        let layer = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).unwrap();
        let x = Variable::new(data, true);
        let out = layer.forward(&x);
        out.sum().backward();

        assert!(
            x.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        assert_eq!(x.grad().unwrap().shape(), &[1, 1, 5]);
    }

    #[test]
    fn test_conv2d_creation() {
        let layer = Conv2d::new(3, 64, 3);
        assert_eq!(layer.in_channels, 3);
        assert_eq!(layer.out_channels, 64);
        assert_eq!(layer.kernel_size, (3, 3));
    }

    #[test]
    fn test_conv2d_forward() {
        // 3x3 kernel with padding 1 keeps the 5x5 spatial size.
        let layer = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let data = Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![1, 1, 5, 5]);
    }

    #[test]
    fn test_conv2d_backward() {
        let layer = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let data = Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).unwrap();
        let x = Variable::new(data, true);
        let out = layer.forward(&x);
        out.sum().backward();

        assert!(
            x.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        assert_eq!(x.grad().unwrap().shape(), &[1, 1, 5, 5]);

        assert!(
            layer.weight.grad().is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }

    #[test]
    fn test_conv2d_parameters() {
        // Weight + bias.
        assert_eq!(Conv2d::new(3, 64, 3).parameters().len(), 2);
    }

    #[test]
    fn test_conv2d_grouped() {
        let layer = Conv2d::depthwise(4, 3);
        assert_eq!(layer.groups, 4);
        assert_eq!(layer.in_channels, 4);
        assert_eq!(layer.out_channels, 4);

        let data = Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![1, 4, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_forward() {
        // (2-1)*2 - 2*1 + 3 + 1 = 4 in each spatial dimension.
        let layer = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let data = Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![1, 1, 4, 4]);
    }

    #[test]
    fn test_conv_transpose2d_backward() {
        let layer = ConvTranspose2d::new(1, 1, 3);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).unwrap(), true);
        let out = layer.forward(&x);
        out.sum().backward();

        assert!(
            x.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }
}