1use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28use axonml_autograd::functions::{
29 Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
30};
31use axonml_autograd::grad_fn::GradFn;
32use axonml_autograd::no_grad::is_grad_enabled;
33use axonml_tensor::Tensor;
34use rayon::prelude::*;
35
36use crate::init::{kaiming_uniform, zeros};
37use crate::module::Module;
38use crate::parameter::Parameter;
39
/// A 1-dimensional convolution layer.
///
/// Applies a cross-correlation over an input of shape
/// `[batch, in_channels, length]`, producing
/// `[batch, out_channels, out_length]` with
/// `out_length = (length + 2 * padding - kernel_size) / stride + 1`.
pub struct Conv1d {
    /// Learnable filters, shaped `[out_channels, in_channels, kernel_size]`.
    pub weight: Parameter,
    /// Optional per-output-channel bias, shaped `[out_channels]`.
    pub bias: Option<Parameter>,
    // Number of channels expected in the forward input.
    in_channels: usize,
    // Number of filters / channels in the output.
    out_channels: usize,
    // Spatial extent of each filter.
    kernel_size: usize,
    // Step between consecutive filter applications.
    stride: usize,
    // Implicit zero padding added to both ends of the input.
    padding: usize,
}
67
68impl Conv1d {
69 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
71 Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
72 }
73
74 pub fn with_options(
76 in_channels: usize,
77 out_channels: usize,
78 kernel_size: usize,
79 stride: usize,
80 padding: usize,
81 bias: bool,
82 ) -> Self {
83 let fan_in = in_channels * kernel_size;
85 let weight_data = kaiming_uniform(out_channels, fan_in);
86 let weight_reshaped = weight_data
87 .reshape(&[
88 out_channels as isize,
89 in_channels as isize,
90 kernel_size as isize,
91 ])
92 .unwrap();
93 let weight = Parameter::named("weight", weight_reshaped, true);
94
95 let bias_param = if bias {
96 Some(Parameter::named("bias", zeros(&[out_channels]), true))
97 } else {
98 None
99 };
100
101 Self {
102 weight,
103 bias: bias_param,
104 in_channels,
105 out_channels,
106 kernel_size,
107 stride,
108 padding,
109 }
110 }
111}
112
113impl Module for Conv1d {
114 fn forward(&self, input: &Variable) -> Variable {
115 let input_shape = input.shape();
116 let batch_size = input_shape[0];
117 let in_length = input_shape[2];
118
119 let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
120
121 let input_data = input.data();
122 let weight_data = self.weight.data();
123
124 #[cfg(feature = "cuda")]
127 if input_data.device().is_gpu() {
128 let input_dev = input_data.device();
130 if !weight_data.device().is_gpu() {
131 self.weight.to_device(input_dev);
132 if let Some(ref b) = self.bias {
133 b.to_device(input_dev);
134 }
135 }
136 let weight_data = self.weight.data();
137
138 let input_4d = input_data
140 .reshape(&[
141 batch_size as isize,
142 self.in_channels as isize,
143 in_length as isize,
144 1,
145 ])
146 .unwrap();
147
148 let weight_4d = weight_data
150 .reshape(&[
151 self.out_channels as isize,
152 self.in_channels as isize,
153 self.kernel_size as isize,
154 1,
155 ])
156 .unwrap();
157
158 let bias_tensor = self.bias.as_ref().map(|b| b.data());
159 let gpu_output = input_4d.conv2d_cuda(
160 &weight_4d,
161 bias_tensor.as_ref(),
162 (self.stride, 1),
163 (self.padding, 0),
164 );
165
166 if let Some(output_4d) = gpu_output {
167 let output_tensor = output_4d
169 .reshape(&[
170 batch_size as isize,
171 self.out_channels as isize,
172 out_length as isize,
173 ])
174 .unwrap();
175
176 let requires_grad =
177 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
178 if requires_grad {
179 let weight_var = self.weight.variable();
180 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
181
182 let grad_fn = GradFn::new(Conv1dBackward::new(
183 input.grad_fn().cloned(),
184 weight_var.grad_fn().cloned(),
185 bias_grad_fn,
186 input_data,
187 weight_data,
188 input_shape,
189 self.in_channels,
190 self.out_channels,
191 self.kernel_size,
192 self.stride,
193 self.padding,
194 self.bias.is_some(),
195 ));
196 return Variable::from_operation(output_tensor, grad_fn, true);
197 } else {
198 return Variable::new(output_tensor, false);
199 }
200 }
201 }
203
204 let input_vec = input_data.to_vec();
205 let weight_vec = weight_data.to_vec();
206
207 let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
208
209 for b in 0..batch_size {
210 for oc in 0..self.out_channels {
211 for ol in 0..out_length {
212 let mut sum = 0.0f32;
213 let in_start = ol * self.stride;
214
215 for ic in 0..self.in_channels {
216 for k in 0..self.kernel_size {
217 let in_idx = in_start + k;
218 if in_idx < self.padding || in_idx >= in_length + self.padding {
219 continue;
220 }
221 let actual_idx = in_idx - self.padding;
222
223 let input_idx =
224 b * self.in_channels * in_length + ic * in_length + actual_idx;
225 let weight_idx = oc * self.in_channels * self.kernel_size
226 + ic * self.kernel_size
227 + k;
228
229 sum += input_vec[input_idx] * weight_vec[weight_idx];
230 }
231 }
232
233 if let Some(ref bias) = self.bias {
234 sum += bias.data().to_vec()[oc];
235 }
236
237 let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
238 output_data[output_idx] = sum;
239 }
240 }
241 }
242
243 let output_tensor =
244 Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length])
245 .expect("tensor creation failed");
246
247 let requires_grad =
248 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
249
250 if requires_grad {
251 let weight_var = self.weight.variable();
252 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
253
254 let grad_fn = GradFn::new(Conv1dBackward::new(
255 input.grad_fn().cloned(),
256 weight_var.grad_fn().cloned(),
257 bias_grad_fn,
258 input_data,
259 weight_data,
260 input_shape,
261 self.in_channels,
262 self.out_channels,
263 self.kernel_size,
264 self.stride,
265 self.padding,
266 self.bias.is_some(),
267 ));
268 Variable::from_operation(output_tensor, grad_fn, true)
269 } else {
270 Variable::new(output_tensor, false)
271 }
272 }
273
274 fn parameters(&self) -> Vec<Parameter> {
275 let mut params = vec![self.weight.clone()];
276 if let Some(ref bias) = self.bias {
277 params.push(bias.clone());
278 }
279 params
280 }
281
282 fn named_parameters(&self) -> HashMap<String, Parameter> {
283 let mut params = HashMap::new();
284 params.insert("weight".to_string(), self.weight.clone());
285 if let Some(ref bias) = self.bias {
286 params.insert("bias".to_string(), bias.clone());
287 }
288 params
289 }
290
291 fn name(&self) -> &'static str {
292 "Conv1d"
293 }
294}
295
/// A 2-dimensional convolution layer with optional channel groups.
///
/// Applies a cross-correlation over an input of shape
/// `[batch, in_channels, height, width]`, producing
/// `[batch, out_channels, out_h, out_w]`.
pub struct Conv2d {
    /// Learnable filters, shaped `[out_channels, in_channels / groups, kh, kw]`.
    pub weight: Parameter,
    /// Optional per-output-channel bias, shaped `[out_channels]`.
    pub bias: Option<Parameter>,
    // Number of channels expected in the forward input.
    in_channels: usize,
    // Number of filters / channels in the output.
    out_channels: usize,
    // Filter extent as (height, width).
    kernel_size: (usize, usize),
    // Step between filter applications as (height, width).
    stride: (usize, usize),
    // Implicit zero padding as (height, width).
    padding: (usize, usize),
    // Channel groups; `groups == in_channels == out_channels` is depthwise.
    groups: usize,
}
325
326impl Conv2d {
327 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
329 Self::with_options(
330 in_channels,
331 out_channels,
332 (kernel_size, kernel_size),
333 (1, 1),
334 (0, 0),
335 true,
336 )
337 }
338
339 pub fn with_options(
341 in_channels: usize,
342 out_channels: usize,
343 kernel_size: (usize, usize),
344 stride: (usize, usize),
345 padding: (usize, usize),
346 bias: bool,
347 ) -> Self {
348 Self::with_groups(
349 in_channels,
350 out_channels,
351 kernel_size,
352 stride,
353 padding,
354 bias,
355 1,
356 )
357 }
358
359 pub fn with_groups(
364 in_channels: usize,
365 out_channels: usize,
366 kernel_size: (usize, usize),
367 stride: (usize, usize),
368 padding: (usize, usize),
369 bias: bool,
370 groups: usize,
371 ) -> Self {
372 assert!(
373 in_channels % groups == 0,
374 "in_channels must be divisible by groups"
375 );
376 assert!(
377 out_channels % groups == 0,
378 "out_channels must be divisible by groups"
379 );
380
381 let (kh, kw) = kernel_size;
382 let in_channels_per_group = in_channels / groups;
383 let fan_in = in_channels_per_group * kh * kw;
384
385 let weight_data = kaiming_uniform(out_channels, fan_in);
386 let weight_reshaped = weight_data
387 .reshape(&[
388 out_channels as isize,
389 in_channels_per_group as isize,
390 kh as isize,
391 kw as isize,
392 ])
393 .unwrap();
394 let weight = Parameter::named("weight", weight_reshaped, true);
395
396 let bias_param = if bias {
397 Some(Parameter::named("bias", zeros(&[out_channels]), true))
398 } else {
399 None
400 };
401
402 Self {
403 weight,
404 bias: bias_param,
405 in_channels,
406 out_channels,
407 kernel_size,
408 stride,
409 padding,
410 groups,
411 }
412 }
413
414 pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
416 Self::with_groups(
417 channels,
418 channels,
419 (kernel_size, kernel_size),
420 (1, 1),
421 (kernel_size / 2, kernel_size / 2),
422 true,
423 channels,
424 )
425 }
426}
427
/// Unrolls one image `[channels, height, width]` into a row-major column
/// matrix of shape `[channels * kernel_h * kernel_w, out_h * out_w]`, so a
/// convolution becomes a single matrix multiply.
///
/// Positions that fall in the zero padding are left at `0.0`.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let kk = kernel_h * kernel_w;
    let hw = height * width;
    let rows = channels * kk;
    let cols = out_h * out_w;
    let mut col = vec![0.0f32; rows * cols];

    for (row, col_row) in col.chunks_mut(cols).enumerate() {
        // Decompose the matrix row into (channel, kernel row, kernel col).
        let c = row / kk;
        let k_idx = row % kk;
        let kh_off = k_idx / kernel_w;
        let kw_off = k_idx % kernel_w;
        let channel = &input[c * hw..(c + 1) * hw];

        for oh in 0..out_h {
            // Source row in the (unpadded) input; rows entirely inside the
            // padding are skipped, leaving zeros.
            let h_in = (oh * stride_h + kh_off) as isize - pad_h as isize;
            if h_in < 0 || h_in >= height as isize {
                continue;
            }
            let src_row = &channel[h_in as usize * width..(h_in as usize + 1) * width];
            let dst_row = &mut col_row[oh * out_w..(oh + 1) * out_w];

            for (ow, dst) in dst_row.iter_mut().enumerate() {
                let w_in = (ow * stride_w + kw_off) as isize - pad_w as isize;
                // Safe slice indexing replaces the previous unchecked access;
                // the slices above already pin both operands in-bounds.
                if w_in >= 0 && w_in < width as isize {
                    *dst = src_row[w_in as usize];
                }
            }
        }
    }

    col
}
504
/// Grouped 2-d convolution on the CPU via im2col + matmul, batches computed
/// in parallel with rayon.
///
/// Layouts (all flattened row-major):
/// - `input`:  `[batch, in_channels, in_height, in_width]`
/// - `weight`: `[out_channels, in_channels / groups, kh, kw]`
/// - `bias`:   `[out_channels]` (optional)
/// - returns:  `[batch, out_channels, out_h, out_w]`
fn conv2d_im2col(
    input: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    groups: usize,
) -> Vec<f32> {
    let out_h = (in_height + 2 * ph - kh) / sh + 1;
    let out_w = (in_width + 2 * pw - kw) / sw + 1;
    let in_channels_per_group = in_channels / groups;
    let out_channels_per_group = out_channels / groups;
    // Column matrix for one group: one row per (channel, kh, kw) tap,
    // one column per output position.
    let col_h = in_channels_per_group * kh * kw;
    let col_w = out_h * out_w;
    let spatial = out_h * out_w;
    let in_spatial = in_height * in_width;

    let out_per_batch = out_channels * spatial;
    // Batch elements are independent, so they are processed in parallel.
    let per_batch: Vec<Vec<f32>> = (0..batch_size)
        .into_par_iter()
        .map(|b| {
            let mut batch_out = vec![0.0f32; out_per_batch];

            for g in 0..groups {
                let ic_start = g * in_channels_per_group;
                let oc_start = g * out_channels_per_group;

                // This group's input channels within batch element `b`.
                let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
                let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];

                let col = im2col(
                    input_slice,
                    in_channels_per_group,
                    in_height,
                    in_width,
                    kh,
                    kw,
                    ph,
                    pw,
                    sh,
                    sw,
                    out_h,
                    out_w,
                );

                // This group's filters viewed as [out_channels_per_group, col_h].
                let w_offset = oc_start * in_channels_per_group * kh * kw;
                let w_size = out_channels_per_group * col_h;
                let weight_slice = &weight[w_offset..w_offset + w_size];

                let w_tensor =
                    Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
                        .unwrap();
                let col_tensor =
                    Tensor::from_vec(col, &[col_h, col_w]).expect("tensor creation failed");
                // One matmul does the whole group:
                // [ocpg, col_h] x [col_h, col_w] -> [ocpg, col_w].
                let result = w_tensor.matmul(&col_tensor).expect("matmul failed");
                let result_vec = result.to_vec();

                // Scatter the group's rows into the per-batch output, adding
                // the bias. A zero bias takes the straight-copy fast path
                // (numerically identical, since adding 0.0 is a no-op here).
                let out_offset = oc_start * spatial;
                for oc_local in 0..out_channels_per_group {
                    let oc = oc_start + oc_local;
                    let bias_val = bias.map_or(0.0, |bv| bv[oc]);
                    let src_start = oc_local * col_w;
                    let dst_start = out_offset + oc_local * spatial;
                    if bias_val == 0.0 {
                        batch_out[dst_start..dst_start + spatial]
                            .copy_from_slice(&result_vec[src_start..src_start + spatial]);
                    } else {
                        for i in 0..spatial {
                            batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
                        }
                    }
                }
            }

            batch_out
        })
        .collect();

    // Concatenate the per-batch results in batch order.
    let mut output = Vec::with_capacity(batch_size * out_per_batch);
    for batch_out in per_batch {
        output.extend_from_slice(&batch_out);
    }
    output
}
606
impl Module for Conv2d {
    /// Runs the 2-d convolution.
    ///
    /// Expects `input` shaped `[batch, in_channels, height, width]`;
    /// returns `[batch, out_channels, out_h, out_w]`.
    ///
    /// Dispatch order: cuDNN (if built with `cudnn` and it succeeds),
    /// then the CUDA kernels (feature `cuda`), then the CPU im2col path.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_height = input_shape[2];
        let in_width = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;

        let input_data = input.data();
        let weight_data = self.weight.data();

        // GPU fast path (feature-gated); falls through to CPU when the
        // kernels decline by returning None.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let input_dev = input_data.device();
            // Lazily migrate parameters onto the input's device.
            if !weight_data.device().is_gpu() {
                self.weight.to_device(input_dev);
                if let Some(ref b) = self.bias {
                    b.to_device(input_dev);
                }
            }
            let weight_data = self.weight.data();

            // Try cuDNN first when available; it handles groups natively.
            #[cfg(feature = "cudnn")]
            let cudnn_output = {
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cudnn(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };
            #[cfg(not(feature = "cudnn"))]
            let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;

            let gpu_output = if cudnn_output.is_some() {
                cudnn_output
            } else if self.groups == 1 {
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cuda(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                )
            } else {
                input_data.conv2d_grouped_cuda(
                    &weight_data,
                    self.bias.as_ref().map(|b| b.data()).as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };

            if let Some(output_tensor) = gpu_output {
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
                    // Grouped and ungrouped convolutions use distinct
                    // backward implementations.
                    if self.groups == 1 {
                        let grad_fn = GradFn::new(Conv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    } else {
                        let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.groups,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    }
                } else {
                    return Variable::new(output_tensor, false);
                }
            }
        }

        // CPU path.
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        // Rough work estimate used to decide whether an opportunistic GPU
        // offload is worth attempting for large ungrouped convolutions.
        // NOTE(review): this call into axonml_core::backends::cuda is not
        // gated on the `cuda` feature — presumably it returns None when no
        // device is available; confirm against axonml_core.
        let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
        let output_data = if self.groups == 1 && conv_flops >= 500_000 {
            let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
            let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
                &input_vec,
                &weight_vec,
                bias_vec.as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
            );

            if let Some(result) = gpu_result {
                result
            } else {
                // Offload declined: fall back to the parallel CPU kernel.
                conv2d_im2col(
                    &input_vec,
                    &weight_vec,
                    self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                    batch_size,
                    self.in_channels,
                    in_height,
                    in_width,
                    self.out_channels,
                    kh,
                    kw,
                    sh,
                    sw,
                    ph,
                    pw,
                    self.groups,
                )
            }
        } else {
            conv2d_im2col(
                &input_vec,
                &weight_vec,
                self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                self.groups,
            )
        };

        let output_tensor = Tensor::from_vec(
            output_data,
            &[batch_size, self.out_channels, out_height, out_width],
        )
        .unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad && self.groups == 1 {
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(Conv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else if requires_grad {
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.groups,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the weight (and bias, if present) as trainable parameters.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Same as [`Self::parameters`], keyed by parameter name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv2d"
    }
}
863
/// A 2-dimensional transposed ("deconvolution") layer, typically used for
/// learned upsampling.
///
/// Output size per spatial dim:
/// `(in - 1) * stride - 2 * padding + kernel + output_padding`.
pub struct ConvTranspose2d {
    /// Learnable filters, shaped `[in_channels, out_channels, kh, kw]`
    /// (note the transposed-conv channel order).
    pub weight: Parameter,
    /// Optional per-output-channel bias, shaped `[out_channels]`.
    pub bias: Option<Parameter>,
    // Number of channels expected in the forward input.
    in_channels: usize,
    // Number of channels in the output.
    out_channels: usize,
    // Filter extent as (height, width).
    kernel_size: (usize, usize),
    // Upsampling stride as (height, width).
    stride: (usize, usize),
    // Padding removed from the output borders, as (height, width).
    padding: (usize, usize),
    // Extra size added to one side of each output dim, as (height, width).
    output_padding: (usize, usize),
}
887
888impl ConvTranspose2d {
889 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
891 Self::with_options(
892 in_channels,
893 out_channels,
894 (kernel_size, kernel_size),
895 (1, 1),
896 (0, 0),
897 (0, 0),
898 true,
899 )
900 }
901
902 pub fn with_options(
904 in_channels: usize,
905 out_channels: usize,
906 kernel_size: (usize, usize),
907 stride: (usize, usize),
908 padding: (usize, usize),
909 output_padding: (usize, usize),
910 bias: bool,
911 ) -> Self {
912 let (kh, kw) = kernel_size;
913 let fan_in = in_channels * kh * kw;
914
915 let weight_data = kaiming_uniform(out_channels, fan_in);
916 let weight_reshaped = weight_data
917 .reshape(&[
918 in_channels as isize,
919 out_channels as isize,
920 kh as isize,
921 kw as isize,
922 ])
923 .unwrap();
924 let weight = Parameter::named("weight", weight_reshaped, true);
925
926 let bias_param = if bias {
927 Some(Parameter::named("bias", zeros(&[out_channels]), true))
928 } else {
929 None
930 };
931
932 Self {
933 weight,
934 bias: bias_param,
935 in_channels,
936 out_channels,
937 kernel_size,
938 stride,
939 padding,
940 output_padding,
941 }
942 }
943}
944
impl Module for ConvTranspose2d {
    /// Runs the transposed convolution (CPU only; no GPU path here).
    ///
    /// Expects `input` shaped `[batch, in_channels, h, w]`; returns
    /// `[batch, out_channels, out_h, out_w]` with
    /// `out = (in - 1) * stride - 2 * padding + kernel + output_padding`.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_h = input_shape[2];
        let in_w = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;

        let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
        let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;

        let input_data = input.data();
        let weight_data = self.weight.data();
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];

        // Scatter form of transposed convolution: each input element
        // distributes `in_val * w` over a kernel-sized output window.
        // The loop nesting fixes the float accumulation order, so do not
        // reorder these loops without re-validating numerics.
        for b in 0..batch_size {
            for ic in 0..self.in_channels {
                for ih in 0..in_h {
                    for iw in 0..in_w {
                        let in_idx =
                            b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
                        let in_val = input_vec[in_idx];

                        for oc in 0..self.out_channels {
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Output coordinate; `padding` shifts the
                                    // window and may push it off either edge.
                                    let oh_signed = (ih * sh + ki) as isize - ph as isize;
                                    let ow_signed = (iw * sw + kj) as isize - pw as isize;

                                    if oh_signed >= 0
                                        && (oh_signed as usize) < out_h
                                        && ow_signed >= 0
                                        && (ow_signed as usize) < out_w
                                    {
                                        let oh = oh_signed as usize;
                                        let ow = ow_signed as usize;
                                        let out_idx = b * self.out_channels * out_h * out_w
                                            + oc * out_h * out_w
                                            + oh * out_w
                                            + ow;
                                        // Weight layout is
                                        // [in_channels, out_channels, kh, kw].
                                        let w_idx = ic * self.out_channels * kh * kw
                                            + oc * kh * kw
                                            + ki * kw
                                            + kj;
                                        output_data[out_idx] += in_val * weight_vec[w_idx];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Bias is added in a separate pass, once per output element.
        if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            for b in 0..batch_size {
                for oc in 0..self.out_channels {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let out_idx = b * self.out_channels * out_h * out_w
                                + oc * out_h * out_w
                                + oh * out_w
                                + ow;
                            output_data[out_idx] += bias_vec[oc];
                        }
                    }
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w])
                .expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.output_padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the weight (and bias, if present) as trainable parameters.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Same as [`Self::parameters`], keyed by parameter name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "ConvTranspose2d"
    }
}
1079
// Unit tests: construction, forward shapes, and gradient flow for all
// three convolution layers on the CPU path.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv1d_creation() {
        let conv = Conv1d::new(3, 16, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 16);
        assert_eq!(conv.kernel_size, 3);
    }

    #[test]
    fn test_conv1d_forward() {
        // kernel 3, stride 1, pad 1 -> "same" length output.
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5]);
    }

    #[test]
    fn test_conv1d_backward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5]);
    }

    #[test]
    fn test_conv2d_creation() {
        let conv = Conv2d::new(3, 64, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 64);
        assert_eq!(conv.kernel_size, (3, 3));
    }

    #[test]
    fn test_conv2d_forward() {
        // 3x3 kernel with pad 1 preserves the 5x5 spatial size.
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
    }

    #[test]
    fn test_conv2d_backward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5, 5]);

        // Weight gradients must also be produced (bias is disabled here).
        let w_grad = conv.weight.grad();
        assert!(
            w_grad.is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }

    #[test]
    fn test_conv2d_parameters() {
        let conv = Conv2d::new(3, 64, 3);
        let params = conv.parameters();
        // weight + bias
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_conv2d_grouped() {
        // Depthwise = one group per channel.
        let conv = Conv2d::depthwise(4, 3);
        assert_eq!(conv.groups, 4);
        assert_eq!(conv.in_channels, 4);
        assert_eq!(conv.out_channels, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_forward() {
        // (2-1)*2 - 2*1 + 3 + 1 = 4 per spatial dim.
        let conv_t = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 4, 4]);
    }

    #[test]
    fn test_conv_transpose2d_backward() {
        let conv_t = ConvTranspose2d::new(1, 1, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).expect("tensor creation failed"),
            true,
        );
        let output = conv_t.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }

    #[test]
    fn test_conv1d_with_padding_and_stride() {
        // (16 + 2*1 - 3) / 2 + 1 = 8.
        let conv = Conv1d::with_options(1, 4, 3, 2, 1, true);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 16], &[1, 1, 16]).unwrap(), true);
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 8]);

        output.sum().backward();
        let grad = input.grad().expect("Conv1d should propagate gradients");
        assert_eq!(grad.shape(), &[1, 1, 16]);
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));
    }

    #[test]
    fn test_conv1d_multi_channel() {
        // 20 - 5 + 1 = 16 output positions, batch of 2.
        let conv = Conv1d::new(3, 8, 5);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 3 * 20], &[2, 3, 20]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 16]);
    }

    #[test]
    fn test_conv2d_grouped_gradient_flow() {
        let conv = Conv2d::depthwise(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 8 * 8], &[1, 4, 8, 8]).unwrap(),
            true,
        );
        let output = conv.forward(&input);
        output.sum().backward();

        let grad = input
            .grad()
            .expect("Grouped conv should propagate gradients");
        assert_eq!(grad.shape(), &[1, 4, 8, 8]);
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));

        // Both weight and bias should receive non-zero gradients.
        for p in conv.parameters() {
            let g = p.grad().expect("Conv params should have gradients");
            assert!(g.to_vec().iter().any(|v| v.abs() > 0.0));
        }
    }

    #[test]
    fn test_conv2d_groups_two() {
        let conv = Conv2d::with_groups(4, 8, (3, 3), (1, 1), (1, 1), true, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 6 * 6], &[1, 4, 6, 6]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 8, 6, 6]);
    }

    #[test]
    fn test_conv2d_depthwise_separable_pattern() {
        // Depthwise 3x3 followed by a pointwise 1x1 projection.
        let dw = Conv2d::depthwise(16, 3);
        let pw = Conv2d::with_options(16, 32, (1, 1), (1, 1), (0, 0), true);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16 * 8 * 8], &[1, 16, 8, 8]).unwrap(),
            true,
        );
        let dw_out = dw.forward(&input);
        assert_eq!(dw_out.shape(), vec![1, 16, 8, 8]);

        let pw_out = pw.forward(&dw_out);
        assert_eq!(pw_out.shape(), vec![1, 32, 8, 8]);

        pw_out.sum().backward();
        let grad = input
            .grad()
            .expect("Should propagate through depthwise separable");
        assert_eq!(grad.shape(), &[1, 16, 8, 8]);
    }

    #[test]
    fn test_conv_transpose2d_upsamples() {
        // (4-1)*2 - 2*1 + 4 + 0 = 8: classic 2x upsampling config.
        let conv_t = ConvTranspose2d::with_options(1, 1, (4, 4), (2, 2), (1, 1), (0, 0), true);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 4], &[1, 1, 4, 4]).unwrap(),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 8, 8]);
    }

    #[test]
    fn test_conv_transpose2d_gradient_correctness() {
        let conv_t = ConvTranspose2d::new(2, 4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 4 * 4], &[1, 2, 4, 4]).unwrap(),
            true,
        );
        let output = conv_t.forward(&input);
        output.sum().backward();

        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 2, 4, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));

        for p in conv_t.parameters() {
            assert!(p.grad().is_some(), "ConvTranspose2d params need gradients");
        }
    }

    #[test]
    fn test_conv_transpose2d_multi_channel() {
        let conv_t = ConvTranspose2d::new(8, 16, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 4 * 4], &[2, 8, 4, 4]).unwrap(),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape()[0], 2);
        assert_eq!(output.shape()[1], 16);
    }
}