1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_autograd::functions::{
21 Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
22};
23use axonml_autograd::grad_fn::GradFn;
24use axonml_autograd::no_grad::is_grad_enabled;
25use axonml_tensor::Tensor;
26use rayon::prelude::*;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
/// A 1-D convolution layer mapping `[batch, in_channels, length]` to
/// `[batch, out_channels, out_length]`.
pub struct Conv1d {
    /// Learnable filters, shaped `[out_channels, in_channels, kernel_size]`.
    pub weight: Parameter,
    /// Optional per-output-channel additive bias, shaped `[out_channels]`.
    pub bias: Option<Parameter>,
    // Channels expected on the input's second dimension.
    in_channels: usize,
    // Number of filters / channels produced.
    out_channels: usize,
    // Spatial extent of each filter.
    kernel_size: usize,
    // Step between consecutive filter applications.
    stride: usize,
    // Zero-padding implicitly added to both ends of the input.
    padding: usize,
}
59
60impl Conv1d {
61 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
63 Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
64 }
65
66 pub fn with_options(
68 in_channels: usize,
69 out_channels: usize,
70 kernel_size: usize,
71 stride: usize,
72 padding: usize,
73 bias: bool,
74 ) -> Self {
75 let fan_in = in_channels * kernel_size;
77 let weight_data = kaiming_uniform(out_channels, fan_in);
78 let weight_reshaped = weight_data
79 .reshape(&[
80 out_channels as isize,
81 in_channels as isize,
82 kernel_size as isize,
83 ])
84 .unwrap();
85 let weight = Parameter::named("weight", weight_reshaped, true);
86
87 let bias_param = if bias {
88 Some(Parameter::named("bias", zeros(&[out_channels]), true))
89 } else {
90 None
91 };
92
93 Self {
94 weight,
95 bias: bias_param,
96 in_channels,
97 out_channels,
98 kernel_size,
99 stride,
100 padding,
101 }
102 }
103}
104
105impl Module for Conv1d {
106 fn forward(&self, input: &Variable) -> Variable {
107 let input_shape = input.shape();
108 let batch_size = input_shape[0];
109 let in_length = input_shape[2];
110
111 let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113 let input_data = input.data();
114 let weight_data = self.weight.data();
115
116 #[cfg(feature = "cuda")]
119 if input_data.device().is_gpu() {
120 let input_dev = input_data.device();
122 if !weight_data.device().is_gpu() {
123 self.weight.to_device(input_dev);
124 if let Some(ref b) = self.bias {
125 b.to_device(input_dev);
126 }
127 }
128 let weight_data = self.weight.data();
129
130 let input_4d = input_data
132 .reshape(&[
133 batch_size as isize,
134 self.in_channels as isize,
135 in_length as isize,
136 1,
137 ])
138 .unwrap();
139
140 let weight_4d = weight_data
142 .reshape(&[
143 self.out_channels as isize,
144 self.in_channels as isize,
145 self.kernel_size as isize,
146 1,
147 ])
148 .unwrap();
149
150 let bias_tensor = self.bias.as_ref().map(|b| b.data());
151 let gpu_output = input_4d.conv2d_cuda(
152 &weight_4d,
153 bias_tensor.as_ref(),
154 (self.stride, 1),
155 (self.padding, 0),
156 );
157
158 if let Some(output_4d) = gpu_output {
159 let output_tensor = output_4d
161 .reshape(&[
162 batch_size as isize,
163 self.out_channels as isize,
164 out_length as isize,
165 ])
166 .unwrap();
167
168 let requires_grad =
169 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
170 if requires_grad {
171 let weight_var = self.weight.variable();
172 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
173
174 let grad_fn = GradFn::new(Conv1dBackward::new(
175 input.grad_fn().cloned(),
176 weight_var.grad_fn().cloned(),
177 bias_grad_fn,
178 input_data,
179 weight_data,
180 input_shape,
181 self.in_channels,
182 self.out_channels,
183 self.kernel_size,
184 self.stride,
185 self.padding,
186 self.bias.is_some(),
187 ));
188 return Variable::from_operation(output_tensor, grad_fn, true);
189 } else {
190 return Variable::new(output_tensor, false);
191 }
192 }
193 }
195
196 let input_vec = input_data.to_vec();
197 let weight_vec = weight_data.to_vec();
198
199 let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
200
201 for b in 0..batch_size {
202 for oc in 0..self.out_channels {
203 for ol in 0..out_length {
204 let mut sum = 0.0f32;
205 let in_start = ol * self.stride;
206
207 for ic in 0..self.in_channels {
208 for k in 0..self.kernel_size {
209 let in_idx = in_start + k;
210 if in_idx < self.padding || in_idx >= in_length + self.padding {
211 continue;
212 }
213 let actual_idx = in_idx - self.padding;
214
215 let input_idx =
216 b * self.in_channels * in_length + ic * in_length + actual_idx;
217 let weight_idx = oc * self.in_channels * self.kernel_size
218 + ic * self.kernel_size
219 + k;
220
221 sum += input_vec[input_idx] * weight_vec[weight_idx];
222 }
223 }
224
225 if let Some(ref bias) = self.bias {
226 sum += bias.data().to_vec()[oc];
227 }
228
229 let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
230 output_data[output_idx] = sum;
231 }
232 }
233 }
234
235 let output_tensor =
236 Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length])
237 .expect("tensor creation failed");
238
239 let requires_grad =
240 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
241
242 if requires_grad {
243 let weight_var = self.weight.variable();
244 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
245
246 let grad_fn = GradFn::new(Conv1dBackward::new(
247 input.grad_fn().cloned(),
248 weight_var.grad_fn().cloned(),
249 bias_grad_fn,
250 input_data,
251 weight_data,
252 input_shape,
253 self.in_channels,
254 self.out_channels,
255 self.kernel_size,
256 self.stride,
257 self.padding,
258 self.bias.is_some(),
259 ));
260 Variable::from_operation(output_tensor, grad_fn, true)
261 } else {
262 Variable::new(output_tensor, false)
263 }
264 }
265
266 fn parameters(&self) -> Vec<Parameter> {
267 let mut params = vec![self.weight.clone()];
268 if let Some(ref bias) = self.bias {
269 params.push(bias.clone());
270 }
271 params
272 }
273
274 fn named_parameters(&self) -> HashMap<String, Parameter> {
275 let mut params = HashMap::new();
276 params.insert("weight".to_string(), self.weight.clone());
277 if let Some(ref bias) = self.bias {
278 params.insert("bias".to_string(), bias.clone());
279 }
280 params
281 }
282
283 fn name(&self) -> &'static str {
284 "Conv1d"
285 }
286}
287
/// A 2-D convolution layer over NCHW input, with optional channel grouping.
pub struct Conv2d {
    /// Learnable filters, shaped `[out_channels, in_channels / groups, kh, kw]`.
    pub weight: Parameter,
    /// Optional per-output-channel bias, shaped `[out_channels]`.
    pub bias: Option<Parameter>,
    // Channels expected on the input.
    in_channels: usize,
    // Channels produced on the output.
    out_channels: usize,
    // Filter (height, width).
    kernel_size: (usize, usize),
    // (vertical, horizontal) stride.
    stride: (usize, usize),
    // (vertical, horizontal) zero-padding.
    padding: (usize, usize),
    // Channel groups: 1 = dense conv; == channels = depthwise. Both channel
    // counts must divide evenly by this (asserted in `with_groups`).
    groups: usize,
}
317
318impl Conv2d {
319 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
321 Self::with_options(
322 in_channels,
323 out_channels,
324 (kernel_size, kernel_size),
325 (1, 1),
326 (0, 0),
327 true,
328 )
329 }
330
331 pub fn with_options(
333 in_channels: usize,
334 out_channels: usize,
335 kernel_size: (usize, usize),
336 stride: (usize, usize),
337 padding: (usize, usize),
338 bias: bool,
339 ) -> Self {
340 Self::with_groups(
341 in_channels,
342 out_channels,
343 kernel_size,
344 stride,
345 padding,
346 bias,
347 1,
348 )
349 }
350
351 pub fn with_groups(
356 in_channels: usize,
357 out_channels: usize,
358 kernel_size: (usize, usize),
359 stride: (usize, usize),
360 padding: (usize, usize),
361 bias: bool,
362 groups: usize,
363 ) -> Self {
364 assert!(
365 in_channels % groups == 0,
366 "in_channels must be divisible by groups"
367 );
368 assert!(
369 out_channels % groups == 0,
370 "out_channels must be divisible by groups"
371 );
372
373 let (kh, kw) = kernel_size;
374 let in_channels_per_group = in_channels / groups;
375 let fan_in = in_channels_per_group * kh * kw;
376
377 let weight_data = kaiming_uniform(out_channels, fan_in);
378 let weight_reshaped = weight_data
379 .reshape(&[
380 out_channels as isize,
381 in_channels_per_group as isize,
382 kh as isize,
383 kw as isize,
384 ])
385 .unwrap();
386 let weight = Parameter::named("weight", weight_reshaped, true);
387
388 let bias_param = if bias {
389 Some(Parameter::named("bias", zeros(&[out_channels]), true))
390 } else {
391 None
392 };
393
394 Self {
395 weight,
396 bias: bias_param,
397 in_channels,
398 out_channels,
399 kernel_size,
400 stride,
401 padding,
402 groups,
403 }
404 }
405
406 pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
408 Self::with_groups(
409 channels,
410 channels,
411 (kernel_size, kernel_size),
412 (1, 1),
413 (kernel_size / 2, kernel_size / 2),
414 true,
415 channels,
416 )
417 }
418}
419
/// Unrolls one padded CHW image into a `[channels * kernel_h * kernel_w,
/// out_h * out_w]` column matrix so convolution reduces to a single matrix
/// multiply. Positions that fall in the zero-padding are left at 0.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let kk = kernel_h * kernel_w;
    let col_w = out_h * out_w;
    let mut col = vec![0.0f32; channels * kk * col_w];

    // Each destination row corresponds to one (channel, kernel-offset) pair.
    for (col_row, dst_row) in col.chunks_exact_mut(col_w).enumerate() {
        let c = col_row / kk;
        let k_idx = col_row % kk;
        let kh_off = k_idx / kernel_w;
        let kw_off = k_idx % kernel_w;
        // Slicing up front replaces the unchecked pointer arithmetic of the
        // previous version; the slice bounds prove every access in-range.
        let plane = &input[c * height * width..(c + 1) * height * width];

        for oh in 0..out_h {
            // Vertical source coordinate; rows landing in the padding are skipped
            // (their destinations keep the zero fill).
            let h_in = (oh * stride_h + kh_off) as isize - pad_h as isize;
            if h_in < 0 || h_in >= height as isize {
                continue;
            }
            let row_start = h_in as usize * width;
            let src_row = &plane[row_start..row_start + width];
            let dst = &mut dst_row[oh * out_w..(oh + 1) * out_w];

            for (ow, slot) in dst.iter_mut().enumerate() {
                let w_in = (ow * stride_w + kw_off) as isize - pad_w as isize;
                if (0..width as isize).contains(&w_in) {
                    *slot = src_row[w_in as usize];
                }
            }
        }
    }

    col
}
496
497fn conv2d_im2col(
499 input: &[f32],
500 weight: &[f32],
501 bias: Option<&[f32]>,
502 batch_size: usize,
503 in_channels: usize,
504 in_height: usize,
505 in_width: usize,
506 out_channels: usize,
507 kh: usize,
508 kw: usize,
509 sh: usize,
510 sw: usize,
511 ph: usize,
512 pw: usize,
513 groups: usize,
514) -> Vec<f32> {
515 let out_h = (in_height + 2 * ph - kh) / sh + 1;
516 let out_w = (in_width + 2 * pw - kw) / sw + 1;
517 let in_channels_per_group = in_channels / groups;
518 let out_channels_per_group = out_channels / groups;
519 let col_h = in_channels_per_group * kh * kw;
520 let col_w = out_h * out_w;
521 let spatial = out_h * out_w;
522 let in_spatial = in_height * in_width;
523
524 let out_per_batch = out_channels * spatial;
526 let per_batch: Vec<Vec<f32>> = (0..batch_size)
527 .into_par_iter()
528 .map(|b| {
529 let mut batch_out = vec![0.0f32; out_per_batch];
530
531 for g in 0..groups {
532 let ic_start = g * in_channels_per_group;
533 let oc_start = g * out_channels_per_group;
534
535 let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
537 let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];
538
539 let col = im2col(
541 input_slice,
542 in_channels_per_group,
543 in_height,
544 in_width,
545 kh,
546 kw,
547 ph,
548 pw,
549 sh,
550 sw,
551 out_h,
552 out_w,
553 );
554
555 let w_offset = oc_start * in_channels_per_group * kh * kw;
557 let w_size = out_channels_per_group * col_h;
558 let weight_slice = &weight[w_offset..w_offset + w_size];
559
560 let w_tensor =
562 Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
563 .unwrap();
564 let col_tensor =
565 Tensor::from_vec(col, &[col_h, col_w]).expect("tensor creation failed");
566 let result = w_tensor.matmul(&col_tensor).expect("matmul failed");
567 let result_vec = result.to_vec();
568
569 let out_offset = oc_start * spatial;
571 for oc_local in 0..out_channels_per_group {
572 let oc = oc_start + oc_local;
573 let bias_val = bias.map_or(0.0, |bv| bv[oc]);
574 let src_start = oc_local * col_w;
575 let dst_start = out_offset + oc_local * spatial;
576 if bias_val == 0.0 {
577 batch_out[dst_start..dst_start + spatial]
578 .copy_from_slice(&result_vec[src_start..src_start + spatial]);
579 } else {
580 for i in 0..spatial {
581 batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
582 }
583 }
584 }
585 }
586
587 batch_out
588 })
589 .collect();
590
591 let mut output = Vec::with_capacity(batch_size * out_per_batch);
593 for batch_out in per_batch {
594 output.extend_from_slice(&batch_out);
595 }
596 output
597}
598
599impl Module for Conv2d {
600 fn forward(&self, input: &Variable) -> Variable {
601 let input_shape = input.shape();
602 let batch_size = input_shape[0];
603 let in_height = input_shape[2];
604 let in_width = input_shape[3];
605
606 let (kh, kw) = self.kernel_size;
607 let (sh, sw) = self.stride;
608 let (ph, pw) = self.padding;
609
610 let out_height = (in_height + 2 * ph - kh) / sh + 1;
611 let out_width = (in_width + 2 * pw - kw) / sw + 1;
612
613 let input_data = input.data();
614 let weight_data = self.weight.data();
615
616 #[cfg(feature = "cuda")]
619 if input_data.device().is_gpu() {
620 let input_dev = input_data.device();
622 if !weight_data.device().is_gpu() {
623 self.weight.to_device(input_dev);
624 if let Some(ref b) = self.bias {
625 b.to_device(input_dev);
626 }
627 }
628 let weight_data = self.weight.data();
629
630 #[cfg(feature = "cudnn")]
632 let cudnn_output = {
633 let bias_tensor = self.bias.as_ref().map(|b| b.data());
634 input_data.conv2d_cudnn(
635 &weight_data,
636 bias_tensor.as_ref(),
637 self.stride,
638 self.padding,
639 self.groups,
640 )
641 };
642 #[cfg(not(feature = "cudnn"))]
643 let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;
644
645 let gpu_output = if cudnn_output.is_some() {
646 cudnn_output
647 } else if self.groups == 1 {
648 let bias_tensor = self.bias.as_ref().map(|b| b.data());
650 input_data.conv2d_cuda(
651 &weight_data,
652 bias_tensor.as_ref(),
653 self.stride,
654 self.padding,
655 )
656 } else {
657 input_data.conv2d_grouped_cuda(
659 &weight_data,
660 self.bias.as_ref().map(|b| b.data()).as_ref(),
661 self.stride,
662 self.padding,
663 self.groups,
664 )
665 };
666
667 if let Some(output_tensor) = gpu_output {
668 let requires_grad =
669 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
670 if requires_grad {
671 let weight_var = self.weight.variable();
672 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
673 if self.groups == 1 {
674 let grad_fn = GradFn::new(Conv2dBackward::new(
675 input.grad_fn().cloned(),
676 weight_var.grad_fn().cloned(),
677 bias_grad_fn,
678 input_data,
679 weight_data,
680 input_shape,
681 self.in_channels,
682 self.out_channels,
683 self.kernel_size,
684 self.stride,
685 self.padding,
686 self.bias.is_some(),
687 ));
688 return Variable::from_operation(output_tensor, grad_fn, true);
689 } else {
690 let grad_fn = GradFn::new(GroupedConv2dBackward::new(
691 input.grad_fn().cloned(),
692 weight_var.grad_fn().cloned(),
693 bias_grad_fn,
694 input_data,
695 weight_data,
696 input_shape,
697 self.in_channels,
698 self.out_channels,
699 self.kernel_size,
700 self.stride,
701 self.padding,
702 self.groups,
703 self.bias.is_some(),
704 ));
705 return Variable::from_operation(output_tensor, grad_fn, true);
706 }
707 } else {
708 return Variable::new(output_tensor, false);
709 }
710 }
711 }
713
714 let input_vec = input_data.to_vec();
715 let weight_vec = weight_data.to_vec();
716
717 let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
719 let output_data = if self.groups == 1 && conv_flops >= 500_000 {
720 let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
721 let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
722 &input_vec,
723 &weight_vec,
724 bias_vec.as_deref(),
725 batch_size,
726 self.in_channels,
727 in_height,
728 in_width,
729 self.out_channels,
730 kh,
731 kw,
732 sh,
733 sw,
734 ph,
735 pw,
736 );
737
738 if let Some(result) = gpu_result {
739 result
740 } else {
741 conv2d_im2col(
742 &input_vec,
743 &weight_vec,
744 self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
745 batch_size,
746 self.in_channels,
747 in_height,
748 in_width,
749 self.out_channels,
750 kh,
751 kw,
752 sh,
753 sw,
754 ph,
755 pw,
756 self.groups,
757 )
758 }
759 } else {
760 conv2d_im2col(
761 &input_vec,
762 &weight_vec,
763 self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
764 batch_size,
765 self.in_channels,
766 in_height,
767 in_width,
768 self.out_channels,
769 kh,
770 kw,
771 sh,
772 sw,
773 ph,
774 pw,
775 self.groups,
776 )
777 };
778
779 let output_tensor = Tensor::from_vec(
780 output_data,
781 &[batch_size, self.out_channels, out_height, out_width],
782 )
783 .unwrap();
784
785 let requires_grad =
786 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
787
788 if requires_grad && self.groups == 1 {
789 let weight_var = self.weight.variable();
791 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
792
793 let grad_fn = GradFn::new(Conv2dBackward::new(
794 input.grad_fn().cloned(),
795 weight_var.grad_fn().cloned(),
796 bias_grad_fn,
797 input_data,
798 weight_data,
799 input_shape,
800 self.in_channels,
801 self.out_channels,
802 self.kernel_size,
803 self.stride,
804 self.padding,
805 self.bias.is_some(),
806 ));
807 Variable::from_operation(output_tensor, grad_fn, true)
808 } else if requires_grad {
809 let weight_var = self.weight.variable();
811 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
812
813 let grad_fn = GradFn::new(GroupedConv2dBackward::new(
814 input.grad_fn().cloned(),
815 weight_var.grad_fn().cloned(),
816 bias_grad_fn,
817 input_data,
818 weight_data,
819 input_shape,
820 self.in_channels,
821 self.out_channels,
822 self.kernel_size,
823 self.stride,
824 self.padding,
825 self.groups,
826 self.bias.is_some(),
827 ));
828 Variable::from_operation(output_tensor, grad_fn, true)
829 } else {
830 Variable::new(output_tensor, false)
831 }
832 }
833
834 fn parameters(&self) -> Vec<Parameter> {
835 let mut params = vec![self.weight.clone()];
836 if let Some(ref bias) = self.bias {
837 params.push(bias.clone());
838 }
839 params
840 }
841
842 fn named_parameters(&self) -> HashMap<String, Parameter> {
843 let mut params = HashMap::new();
844 params.insert("weight".to_string(), self.weight.clone());
845 if let Some(ref bias) = self.bias {
846 params.insert("bias".to_string(), bias.clone());
847 }
848 params
849 }
850
851 fn name(&self) -> &'static str {
852 "Conv2d"
853 }
854}
855
/// A 2-D transposed ("deconvolution") layer that upsamples NCHW input.
pub struct ConvTranspose2d {
    /// Learnable filters, shaped `[in_channels, out_channels, kh, kw]`
    /// (note the transposed layout relative to `Conv2d`).
    pub weight: Parameter,
    /// Optional per-output-channel bias, shaped `[out_channels]`.
    pub bias: Option<Parameter>,
    // Channels expected on the input.
    in_channels: usize,
    // Channels produced on the output.
    out_channels: usize,
    // Filter (height, width).
    kernel_size: (usize, usize),
    // Upsampling (vertical, horizontal) stride.
    stride: (usize, usize),
    // Border cropped from the output, mirroring forward-conv padding.
    padding: (usize, usize),
    // Extra rows/columns appended to the output to disambiguate its size.
    output_padding: (usize, usize),
}
879
880impl ConvTranspose2d {
881 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
883 Self::with_options(
884 in_channels,
885 out_channels,
886 (kernel_size, kernel_size),
887 (1, 1),
888 (0, 0),
889 (0, 0),
890 true,
891 )
892 }
893
894 pub fn with_options(
896 in_channels: usize,
897 out_channels: usize,
898 kernel_size: (usize, usize),
899 stride: (usize, usize),
900 padding: (usize, usize),
901 output_padding: (usize, usize),
902 bias: bool,
903 ) -> Self {
904 let (kh, kw) = kernel_size;
905 let fan_in = in_channels * kh * kw;
906
907 let weight_data = kaiming_uniform(out_channels, fan_in);
908 let weight_reshaped = weight_data
909 .reshape(&[
910 in_channels as isize,
911 out_channels as isize,
912 kh as isize,
913 kw as isize,
914 ])
915 .unwrap();
916 let weight = Parameter::named("weight", weight_reshaped, true);
917
918 let bias_param = if bias {
919 Some(Parameter::named("bias", zeros(&[out_channels]), true))
920 } else {
921 None
922 };
923
924 Self {
925 weight,
926 bias: bias_param,
927 in_channels,
928 out_channels,
929 kernel_size,
930 stride,
931 padding,
932 output_padding,
933 }
934 }
935}
936
937impl Module for ConvTranspose2d {
938 fn forward(&self, input: &Variable) -> Variable {
939 let input_shape = input.shape();
940 let batch_size = input_shape[0];
941 let in_h = input_shape[2];
942 let in_w = input_shape[3];
943
944 let (kh, kw) = self.kernel_size;
945 let (sh, sw) = self.stride;
946 let (ph, pw) = self.padding;
947 let (oph, opw) = self.output_padding;
948
949 let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
950 let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;
951
952 let input_data = input.data();
953 let weight_data = self.weight.data();
954 let input_vec = input_data.to_vec();
955 let weight_vec = weight_data.to_vec();
956
957 let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];
958
959 for b in 0..batch_size {
961 for ic in 0..self.in_channels {
962 for ih in 0..in_h {
963 for iw in 0..in_w {
964 let in_idx =
965 b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
966 let in_val = input_vec[in_idx];
967
968 for oc in 0..self.out_channels {
969 for ki in 0..kh {
970 for kj in 0..kw {
971 let oh_signed = (ih * sh + ki) as isize - ph as isize;
972 let ow_signed = (iw * sw + kj) as isize - pw as isize;
973
974 if oh_signed >= 0
975 && (oh_signed as usize) < out_h
976 && ow_signed >= 0
977 && (ow_signed as usize) < out_w
978 {
979 let oh = oh_signed as usize;
980 let ow = ow_signed as usize;
981 let out_idx = b * self.out_channels * out_h * out_w
982 + oc * out_h * out_w
983 + oh * out_w
984 + ow;
985 let w_idx = ic * self.out_channels * kh * kw
987 + oc * kh * kw
988 + ki * kw
989 + kj;
990 output_data[out_idx] += in_val * weight_vec[w_idx];
991 }
992 }
993 }
994 }
995 }
996 }
997 }
998 }
999
1000 if let Some(ref bias) = self.bias {
1002 let bias_vec = bias.data().to_vec();
1003 for b in 0..batch_size {
1004 for oc in 0..self.out_channels {
1005 for oh in 0..out_h {
1006 for ow in 0..out_w {
1007 let out_idx = b * self.out_channels * out_h * out_w
1008 + oc * out_h * out_w
1009 + oh * out_w
1010 + ow;
1011 output_data[out_idx] += bias_vec[oc];
1012 }
1013 }
1014 }
1015 }
1016 }
1017
1018 let output_tensor =
1019 Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w])
1020 .expect("tensor creation failed");
1021
1022 let requires_grad =
1023 (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
1024
1025 if requires_grad {
1026 let weight_var = self.weight.variable();
1027 let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
1028
1029 let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
1030 input.grad_fn().cloned(),
1031 weight_var.grad_fn().cloned(),
1032 bias_grad_fn,
1033 input_data,
1034 weight_data,
1035 input_shape,
1036 self.in_channels,
1037 self.out_channels,
1038 self.kernel_size,
1039 self.stride,
1040 self.padding,
1041 self.output_padding,
1042 self.bias.is_some(),
1043 ));
1044 Variable::from_operation(output_tensor, grad_fn, true)
1045 } else {
1046 Variable::new(output_tensor, false)
1047 }
1048 }
1049
1050 fn parameters(&self) -> Vec<Parameter> {
1051 let mut params = vec![self.weight.clone()];
1052 if let Some(ref bias) = self.bias {
1053 params.push(bias.clone());
1054 }
1055 params
1056 }
1057
1058 fn named_parameters(&self) -> HashMap<String, Parameter> {
1059 let mut params = HashMap::new();
1060 params.insert("weight".to_string(), self.weight.clone());
1061 if let Some(ref bias) = self.bias {
1062 params.insert("bias".to_string(), bias.clone());
1063 }
1064 params
1065 }
1066
1067 fn name(&self) -> &'static str {
1068 "ConvTranspose2d"
1069 }
1070}
1071
1072#[cfg(test)]
1077mod tests {
1078 use super::*;
1079
    #[test]
    fn test_conv1d_creation() {
        // The constructor must record the requested geometry verbatim.
        let conv = Conv1d::new(3, 16, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 16);
        assert_eq!(conv.kernel_size, 3);
    }
1087
    #[test]
    fn test_conv1d_forward() {
        // kernel 3, stride 1, padding 1 => "same" output length.
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5]);
    }
1099
    #[test]
    fn test_conv1d_backward() {
        // Gradients must reach the input and keep its shape.
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5]);
    }
1120
    #[test]
    fn test_conv2d_creation() {
        // Scalar kernel argument expands to a square (k, k) kernel.
        let conv = Conv2d::new(3, 64, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 64);
        assert_eq!(conv.kernel_size, (3, 3));
    }
1128
    #[test]
    fn test_conv2d_forward() {
        // 3x3 kernel with padding 1 preserves the 5x5 spatial size.
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
    }
1139
    #[test]
    fn test_conv2d_backward() {
        // Both the input and the weight must receive gradients.
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5, 5]);

        let w_grad = conv.weight.grad();
        assert!(
            w_grad.is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }
1165
    #[test]
    fn test_conv2d_parameters() {
        let conv = Conv2d::new(3, 64, 3);
        let params = conv.parameters();
        // weight + bias
        assert_eq!(params.len(), 2);
    }
1172
    #[test]
    fn test_conv2d_grouped() {
        // Depthwise = groups == channels, with "same" padding for the 3x3 kernel.
        let conv = Conv2d::depthwise(4, 3);
        assert_eq!(conv.groups, 4);
        assert_eq!(conv.in_channels, 4);
        assert_eq!(conv.out_channels, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 5, 5]);
    }
1188
    #[test]
    fn test_conv_transpose2d_forward() {
        // (2-1)*2 - 2*1 + 3 + 1 = 4 on each spatial axis.
        let conv_t = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 4, 4]);
    }
1200
    #[test]
    fn test_conv_transpose2d_backward() {
        // Gradient must flow back through the transposed convolution.
        let conv_t = ConvTranspose2d::new(1, 1, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).expect("tensor creation failed"),
            true,
        );
        let output = conv_t.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }
1217
    #[test]
    fn test_conv1d_with_padding_and_stride() {
        // (16 + 2*1 - 3) / 2 + 1 = 8 output positions.
        let conv = Conv1d::with_options(1, 4, 3, 2, 1, true);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 1 * 16], &[1, 1, 16]).unwrap(),
            true,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 8]);

        output.sum().backward();
        let grad = input.grad().expect("Conv1d should propagate gradients");
        assert_eq!(grad.shape(), &[1, 1, 16]);
        // At least one nonzero entry: the gradient is not silently all-zero.
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));
    }
1238
    #[test]
    fn test_conv1d_multi_channel() {
        // Batch 2, 3 -> 8 channels, kernel 5: length 20 -> 16.
        let conv = Conv1d::new(3, 8, 5);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 3 * 20], &[2, 3, 20]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 16]);
    }
1250
    #[test]
    fn test_conv2d_grouped_gradient_flow() {
        // Depthwise conv must propagate nonzero gradients to input and params.
        let conv = Conv2d::depthwise(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 4 * 8 * 8], &[1, 4, 8, 8]).unwrap(),
            true,
        );
        let output = conv.forward(&input);
        output.sum().backward();

        let grad = input
            .grad()
            .expect("Grouped conv should propagate gradients");
        assert_eq!(grad.shape(), &[1, 4, 8, 8]);
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));

        // Every parameter (weight and bias) receives a nonzero gradient.
        for p in conv.parameters() {
            let g = p.grad().expect("Conv params should have gradients");
            assert!(g.to_vec().iter().any(|v| v.abs() > 0.0));
        }
    }
1277
    #[test]
    fn test_conv2d_groups_two() {
        // 2 groups of 2->4 channels each; padding keeps the 6x6 spatial size.
        let conv = Conv2d::with_groups(4, 8, (3, 3), (1, 1), (1, 1), true, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 4 * 6 * 6], &[1, 4, 6, 6]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 8, 6, 6]);
    }
1289
    #[test]
    fn test_conv2d_depthwise_separable_pattern() {
        // Depthwise 3x3 followed by pointwise 1x1 (MobileNet-style block).
        let dw = Conv2d::depthwise(16, 3);
        let pw = Conv2d::with_options(16, 32, (1, 1), (1, 1), (0, 0), true);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 16 * 8 * 8], &[1, 16, 8, 8]).unwrap(),
            true,
        );
        let dw_out = dw.forward(&input);
        assert_eq!(dw_out.shape(), vec![1, 16, 8, 8]);

        let pw_out = pw.forward(&dw_out);
        assert_eq!(pw_out.shape(), vec![1, 32, 8, 8]);

        // Gradient must survive the full two-stage composition.
        pw_out.sum().backward();
        let grad = input
            .grad()
            .expect("Should propagate through depthwise separable");
        assert_eq!(grad.shape(), &[1, 16, 8, 8]);
    }
1313
    #[test]
    fn test_conv_transpose2d_upsamples() {
        // Classic 2x upsampling config: (4-1)*2 - 2*1 + 4 + 0 = 8.
        let conv_t = ConvTranspose2d::with_options(1, 1, (4, 4), (2, 2), (1, 1), (0, 0), true);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 1 * 4 * 4], &[1, 1, 4, 4]).unwrap(),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 8, 8]);
    }
1330
    #[test]
    fn test_conv_transpose2d_gradient_correctness() {
        // Gradients must be finite, nonzero, and shape-preserving.
        let conv_t = ConvTranspose2d::new(2, 4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 1 * 2 * 4 * 4], &[1, 2, 4, 4]).unwrap(),
            true,
        );
        let output = conv_t.forward(&input);
        output.sum().backward();

        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 2, 4, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));

        // Weight and bias both participate in the backward pass.
        for p in conv_t.parameters() {
            assert!(p.grad().is_some(), "ConvTranspose2d params need gradients");
        }
    }
1351
    #[test]
    fn test_conv_transpose2d_multi_channel() {
        // Batch and channel dimensions map through independently.
        let conv_t = ConvTranspose2d::new(8, 16, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 4 * 4], &[2, 8, 4, 4]).unwrap(),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape()[0], 2);
        assert_eq!(output.shape()[1], 16);
    }
1363}