1use std::sync::Arc;
27
28use ferrotorch_core::autograd::autocast_ops::autocast_guard;
29use ferrotorch_core::autograd::no_grad::is_grad_enabled;
30use ferrotorch_core::grad_fns::shape::{squeeze, unsqueeze};
31use ferrotorch_core::ops::linalg::{mm, transpose};
32use ferrotorch_core::storage::TensorStorage;
33use ferrotorch_core::tensor::{GradFn, Tensor};
34use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
35
36use crate::init::{NonLinearity, kaiming_uniform, uniform as uniform_init};
37use crate::module::Module;
38use crate::parameter::Parameter;
39
40fn reject_non_zeros_transpose(
53 mode: crate::padding::PaddingMode,
54 class_name: &str,
55) -> FerrotorchResult<()> {
56 if mode != crate::padding::PaddingMode::Zeros {
57 return Err(FerrotorchError::InvalidArgument {
58 message: format!("Only \"zeros\" padding mode is supported for {class_name}"),
59 });
60 }
61 Ok(())
62}
63
/// String-style padding selector accepted by the `with_string_padding`
/// builders of the convolution layers in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StringPadding {
    /// Pad the input by `dilation * (kernel_size - 1)` per spatial axis (split
    /// by `same_pad_lr`) before convolving. The builders reject this variant
    /// when any stride differs from 1.
    Same,
    /// Convolve without any padding.
    Valid,
}
93
/// Splits the total "same" padding of one spatial axis into `(before, after)`.
///
/// The total is `dilation * (kernel_size - 1)`. The leading side receives the
/// half rounded down and the trailing side receives the remainder, so an odd
/// total places the extra element on the trailing side.
///
/// A `kernel_size` of 0 yields `(0, 0)` instead of underflowing the
/// subtraction.
fn same_pad_lr(kernel_size: usize, dilation: usize) -> (usize, usize) {
    // `saturating_sub` keeps a degenerate zero-sized kernel from wrapping
    // around (release) or panicking (debug).
    let total_padding = dilation * kernel_size.saturating_sub(1);
    let left = total_padding / 2;
    (left, total_padding - left)
}
117
118#[allow(clippy::too_many_arguments)]
132fn im2col<T: Float>(
133 input: &[T],
134 batch: usize,
135 channels: usize,
136 height: usize,
137 width: usize,
138 kernel_h: usize,
139 kernel_w: usize,
140 stride_h: usize,
141 stride_w: usize,
142 pad_h: usize,
143 pad_w: usize,
144) -> (Vec<T>, usize, usize) {
145 im2col_dilated(
146 input, batch, channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_h,
147 pad_w, 1, 1,
148 )
149}
150
151#[allow(clippy::too_many_arguments)]
168fn im2col_dilated<T: Float>(
169 input: &[T],
170 batch: usize,
171 channels: usize,
172 height: usize,
173 width: usize,
174 kernel_h: usize,
175 kernel_w: usize,
176 stride_h: usize,
177 stride_w: usize,
178 pad_h: usize,
179 pad_w: usize,
180 dil_h: usize,
181 dil_w: usize,
182) -> (Vec<T>, usize, usize) {
183 let eff_kh = dil_h * (kernel_h - 1) + 1;
184 let eff_kw = dil_w * (kernel_w - 1) + 1;
185 let h_out = (height + 2 * pad_h - eff_kh) / stride_h + 1;
186 let w_out = (width + 2 * pad_w - eff_kw) / stride_w + 1;
187 let col_rows = channels * kernel_h * kernel_w;
188 let col_cols = h_out * w_out;
189
190 let zero = <T as num_traits::Zero>::zero();
191 let mut cols = vec![zero; batch * col_rows * col_cols];
192
193 for b in 0..batch {
194 for c in 0..channels {
195 for kh in 0..kernel_h {
196 for kw in 0..kernel_w {
197 let row = c * kernel_h * kernel_w + kh * kernel_w + kw;
198 for oh in 0..h_out {
199 for ow in 0..w_out {
200 let ih = oh * stride_h + kh * dil_h;
202 let iw = ow * stride_w + kw * dil_w;
203 let col = oh * w_out + ow;
204
205 let val = if ih >= pad_h
208 && iw >= pad_w
209 && (ih - pad_h) < height
210 && (iw - pad_w) < width
211 {
212 let real_h = ih - pad_h;
213 let real_w = iw - pad_w;
214 input[b * channels * height * width
215 + c * height * width
216 + real_h * width
217 + real_w]
218 } else {
219 zero
220 };
221
222 cols[b * col_rows * col_cols + row * col_cols + col] = val;
223 }
224 }
225 }
226 }
227 }
228 }
229
230 (cols, col_rows, col_cols)
231}
232
/// Unit-dilation convenience wrapper around [`col2im_dilated`].
///
/// Only compiled for tests. Forwards every argument unchanged and fixes the
/// dilation to 1 on both axes.
#[cfg(test)]
#[allow(clippy::too_many_arguments)]
fn col2im<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    let (dil_h, dil_w) = (1, 1);
    col2im_dilated(
        cols, batch, channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
        dil_h, dil_w, h_out, w_out,
    )
}
267
268#[allow(clippy::too_many_arguments)]
278fn col2im_dilated<T: Float>(
279 cols: &[T],
280 batch: usize,
281 channels: usize,
282 height: usize,
283 width: usize,
284 kernel_h: usize,
285 kernel_w: usize,
286 stride_h: usize,
287 stride_w: usize,
288 pad_h: usize,
289 pad_w: usize,
290 dil_h: usize,
291 dil_w: usize,
292 h_out: usize,
293 w_out: usize,
294) -> Vec<T> {
295 let zero = <T as num_traits::Zero>::zero();
296 let mut output = vec![zero; batch * channels * height * width];
297
298 let col_rows = channels * kernel_h * kernel_w;
299 let col_cols = h_out * w_out;
300
301 for b in 0..batch {
302 for c in 0..channels {
303 for kh in 0..kernel_h {
304 for kw in 0..kernel_w {
305 let row = c * kernel_h * kernel_w + kh * kernel_w + kw;
306 for oh in 0..h_out {
307 for ow in 0..w_out {
308 let ih = oh * stride_h + kh * dil_h;
309 let iw = ow * stride_w + kw * dil_w;
310 let col = oh * w_out + ow;
311
312 if ih >= pad_h
313 && iw >= pad_w
314 && (ih - pad_h) < height
315 && (iw - pad_w) < width
316 {
317 let real_h = ih - pad_h;
318 let real_w = iw - pad_w;
319 output[b * channels * height * width
320 + c * height * width
321 + real_h * width
322 + real_w] +=
323 cols[b * col_rows * col_cols + row * col_cols + col];
324 }
325 }
326 }
327 }
328 }
329 }
330 }
331
332 output
333}
334
/// 2-D convolution layer over `[B, C, H, W]` inputs (3-D inputs are treated
/// as a single unbatched sample by `forward`).
#[derive(Debug)]
pub struct Conv2d<T: Float> {
    /// Kernel weights, shape `[out_channels, in_channels / groups, kH, kW]`.
    weight: Parameter<T>,
    /// Optional per-output-channel bias, shape `[out_channels]`.
    bias: Option<Parameter<T>>,
    /// Number of channels expected in the input.
    in_channels: usize,
    /// Number of channels produced in the output.
    out_channels: usize,
    /// Kernel size as `(height, width)`.
    kernel_size: (usize, usize),
    /// Stride as `(height, width)`.
    stride: (usize, usize),
    /// Symmetric padding as `(height, width)`, applied to both sides of each axis.
    padding: (usize, usize),
    /// Dilation (spacing between kernel taps) as `(height, width)`.
    dilation: (usize, usize),
    /// Number of channel groups; divides both `in_channels` and `out_channels`
    /// when the layer is built through `new_full`.
    groups: usize,
    /// How padded positions are filled. For modes other than `Zeros`, `forward`
    /// pads the input explicitly and then convolves with zero padding.
    padding_mode: crate::padding::PaddingMode,
    /// Optional string-style padding; when set it takes precedence over
    /// `padding` in `forward`.
    string_padding: Option<StringPadding>,
    /// Training-mode flag toggled by `train` / `eval`.
    training: bool,
}
401
402impl<T: Float> Conv2d<T> {
403 pub fn new(
412 in_channels: usize,
413 out_channels: usize,
414 kernel_size: (usize, usize),
415 stride: (usize, usize),
416 padding: (usize, usize),
417 bias: bool,
418 ) -> FerrotorchResult<Self> {
419 Self::new_full(
420 in_channels,
421 out_channels,
422 kernel_size,
423 stride,
424 padding,
425 (1, 1),
426 1,
427 bias,
428 )
429 }
430
431 #[allow(clippy::too_many_arguments)]
448 pub fn new_full(
449 in_channels: usize,
450 out_channels: usize,
451 kernel_size: (usize, usize),
452 stride: (usize, usize),
453 padding: (usize, usize),
454 dilation: (usize, usize),
455 groups: usize,
456 bias: bool,
457 ) -> FerrotorchResult<Self> {
458 if in_channels == 0 || out_channels == 0 {
459 return Err(FerrotorchError::InvalidArgument {
460 message: "in_channels and out_channels must be > 0".into(),
461 });
462 }
463 if kernel_size.0 == 0 || kernel_size.1 == 0 {
464 return Err(FerrotorchError::InvalidArgument {
465 message: "kernel_size must be > 0 in both dimensions".into(),
466 });
467 }
468 if stride.0 == 0 || stride.1 == 0 {
469 return Err(FerrotorchError::InvalidArgument {
470 message: "stride must be > 0 in both dimensions".into(),
471 });
472 }
473 if dilation.0 == 0 || dilation.1 == 0 {
474 return Err(FerrotorchError::InvalidArgument {
475 message: format!(
476 "Conv2d::new_full: dilation must be > 0 in both dimensions, got {dilation:?}"
477 ),
478 });
479 }
480 if groups == 0 {
481 return Err(FerrotorchError::InvalidArgument {
482 message: "Conv2d::new_full: groups must be > 0".into(),
483 });
484 }
485 if in_channels % groups != 0 {
486 return Err(FerrotorchError::InvalidArgument {
487 message: format!(
488 "Conv2d::new_full: groups={groups} must divide in_channels={in_channels}"
489 ),
490 });
491 }
492 if out_channels % groups != 0 {
493 return Err(FerrotorchError::InvalidArgument {
494 message: format!(
495 "Conv2d::new_full: groups={groups} must divide out_channels={out_channels}"
496 ),
497 });
498 }
499
500 let (kh, kw) = kernel_size;
501 let mut weight = Parameter::zeros(&[out_channels, in_channels / groups, kh, kw])?;
503 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
504
505 let bias_param = if bias {
506 let mut b = Parameter::zeros(&[out_channels])?;
507 let fan_in = (in_channels / groups) * kh * kw;
511 let bound = if fan_in > 0 {
512 1.0 / (fan_in as f64).sqrt()
513 } else {
514 0.0
515 };
516 uniform_init(&mut b, -bound, bound)?;
517 Some(b)
518 } else {
519 None
520 };
521
522 Ok(Self {
523 weight,
524 bias: bias_param,
525 in_channels,
526 out_channels,
527 kernel_size,
528 stride,
529 padding,
530 dilation,
531 groups,
532 padding_mode: crate::padding::PaddingMode::Zeros,
533 string_padding: None,
534 training: true,
535 })
536 }
537
538 pub fn with_padding_mode(mut self, mode: crate::padding::PaddingMode) -> Self {
546 self.padding_mode = mode;
547 self
548 }
549
550 pub fn with_string_padding(mut self, padding: StringPadding) -> FerrotorchResult<Self> {
569 if padding == StringPadding::Same && (self.stride.0 != 1 || self.stride.1 != 1) {
570 return Err(FerrotorchError::InvalidArgument {
571 message: "padding='same' is not supported for strided convolutions".into(),
572 });
573 }
574 self.string_padding = Some(padding);
575 self.padding = (0, 0);
576 Ok(self)
577 }
578
579 pub fn set_weight(&mut self, weight: Parameter<T>) -> FerrotorchResult<()> {
586 let expected = [
587 self.out_channels,
588 self.in_channels / self.groups,
589 self.kernel_size.0,
590 self.kernel_size.1,
591 ];
592 let got = weight.tensor().shape();
593 if got != expected {
594 return Err(FerrotorchError::ShapeMismatch {
595 message: format!("Conv2d::set_weight: expected shape {expected:?}, got {got:?}"),
596 });
597 }
598 self.weight = weight;
599 Ok(())
600 }
601
602 pub fn groups(&self) -> usize {
604 self.groups
605 }
606
607 pub fn dilation(&self) -> (usize, usize) {
609 self.dilation
610 }
611
612 pub fn num_parameters(&self) -> usize {
618 let w = self.out_channels
619 * (self.in_channels / self.groups)
620 * self.kernel_size.0
621 * self.kernel_size.1;
622 let b = if self.bias.is_some() {
623 self.out_channels
624 } else {
625 0
626 };
627 w + b
628 }
629
630 pub fn from_parts(
641 weight: Tensor<T>,
642 bias: Option<Tensor<T>>,
643 stride: (usize, usize),
644 padding: (usize, usize),
645 ) -> FerrotorchResult<Self> {
646 if weight.ndim() != 4 {
647 return Err(FerrotorchError::ShapeMismatch {
648 message: format!(
649 "Conv2d::from_parts: weight must be 4-D [out, in, kH, kW], got {:?}",
650 weight.shape()
651 ),
652 });
653 }
654 let out_channels = weight.shape()[0];
655 let in_channels = weight.shape()[1];
656 let kernel_size = (weight.shape()[2], weight.shape()[3]);
657 if let Some(b) = &bias {
658 if b.ndim() != 1 || b.shape()[0] != out_channels {
659 return Err(FerrotorchError::ShapeMismatch {
660 message: format!(
661 "Conv2d::from_parts: bias shape {:?} != [{}]",
662 b.shape(),
663 out_channels
664 ),
665 });
666 }
667 }
668 Ok(Self {
669 weight: Parameter::new(weight),
670 bias: bias.map(Parameter::new),
671 in_channels,
672 out_channels,
673 kernel_size,
674 stride,
675 padding,
676 dilation: (1, 1),
677 groups: 1,
678 padding_mode: crate::padding::PaddingMode::Zeros,
679 string_padding: None,
680 training: true,
681 })
682 }
683}
684
685impl<T: Float> Conv2d<T> {
686 fn recurse_clone(
693 &self,
694 padding: (usize, usize),
695 padding_mode: crate::padding::PaddingMode,
696 ) -> Conv2d<T> {
697 Conv2d {
698 weight: Parameter::new(self.weight.tensor().clone()),
699 bias: self
700 .bias
701 .as_ref()
702 .map(|b| Parameter::new(b.tensor().clone())),
703 in_channels: self.in_channels,
704 out_channels: self.out_channels,
705 kernel_size: self.kernel_size,
706 stride: self.stride,
707 padding,
708 dilation: self.dilation,
709 groups: self.groups,
710 padding_mode,
711 string_padding: None,
712 training: self.training,
713 }
714 }
715}
716
717impl<T: Float> Module<T> for Conv2d<T> {
718 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
719 let _autocast_cat = autocast_guard("conv2d");
721
722 if input.ndim() == 3 {
731 let batched = unsqueeze(input, 0)?;
732 let output = self.forward(&batched)?;
733 return squeeze(&output, 0);
734 }
735
736 if let Some(sp) = self.string_padding {
748 match sp {
749 StringPadding::Valid => {
750 return self
752 .recurse_clone((0, 0), crate::padding::PaddingMode::Zeros)
753 .forward(input);
754 }
755 StringPadding::Same => {
756 let (kh, kw) = self.kernel_size;
757 let (dh, dw) = self.dilation;
758 let (top, bottom) = same_pad_lr(kh, dh);
759 let (left, right) = same_pad_lr(kw, dw);
760 let padded = crate::padding::functional_pad_2d(
761 input,
762 left,
763 right,
764 top,
765 bottom,
766 self.padding_mode,
767 <T as num_traits::Zero>::zero(),
768 )?;
769 return self
770 .recurse_clone((0, 0), crate::padding::PaddingMode::Zeros)
771 .forward(&padded);
772 }
773 }
774 }
775
776 if self.padding_mode != crate::padding::PaddingMode::Zeros
786 && (self.padding.0 != 0 || self.padding.1 != 0)
787 {
788 let padded = crate::padding::functional_pad_2d(
789 input,
790 self.padding.1,
791 self.padding.1,
792 self.padding.0,
793 self.padding.0,
794 self.padding_mode,
795 <T as num_traits::Zero>::zero(),
796 )?;
797 return self
801 .recurse_clone((0, 0), crate::padding::PaddingMode::Zeros)
802 .forward(&padded);
803 }
804
805 if input.ndim() != 4 {
807 return Err(FerrotorchError::InvalidArgument {
808 message: format!(
809 "Conv2d expects 4-D input [B, C, H, W], got {:?}",
810 input.shape()
811 ),
812 });
813 }
814
815 let batch = input.shape()[0];
816 let c_in = input.shape()[1];
817 let h = input.shape()[2];
818 let w = input.shape()[3];
819
820 if c_in != self.in_channels {
821 return Err(FerrotorchError::ShapeMismatch {
822 message: format!(
823 "Conv2d: expected {} input channels, got {}",
824 self.in_channels, c_in
825 ),
826 });
827 }
828
829 let (kh, kw) = self.kernel_size;
830 let (sh, sw) = self.stride;
831 let (ph, pw) = self.padding;
832 let (dh, dw) = self.dilation;
833 let groups = self.groups;
834
835 let eff_kh = dh * (kh - 1) + 1;
837 let eff_kw = dw * (kw - 1) + 1;
838
839 let h_padded = h + 2 * ph;
841 let w_padded = w + 2 * pw;
842 if h_padded < eff_kh || w_padded < eff_kw {
843 return Err(FerrotorchError::InvalidArgument {
844 message: format!(
845 "Conv2d: padded input ({h_padded}, {w_padded}) is smaller than effective kernel ({eff_kh}, {eff_kw})"
846 ),
847 });
848 }
849
850 let h_out = (h_padded - eff_kh) / sh + 1;
851 let w_out = (w_padded - eff_kw) / sw + 1;
852
853 let input_device = input.device();
855
856 let is_f32 = std::mem::size_of::<T>() == 4;
872 if is_f32 && input.is_cuda() {
873 if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
874 let bias_handle = self
875 .bias
876 .as_ref()
877 .and_then(|b| b.tensor().gpu_handle().ok());
878 let weight_shape = self.weight.tensor().shape();
879 let weight_dim4: [usize; 4] = [
880 weight_shape[0],
881 weight_shape[1],
882 weight_shape[2],
883 weight_shape[3],
884 ];
885 let (out_handle, out_shape) = backend.conv2d_f32(
886 input.gpu_handle()?,
887 self.weight.tensor().gpu_handle()?,
888 bias_handle,
889 [batch, c_in, h, w],
890 weight_dim4,
891 self.stride,
892 self.padding,
893 self.dilation,
894 groups,
895 )?;
896
897 let result = Tensor::from_storage(
898 TensorStorage::gpu(out_handle),
899 out_shape.to_vec(),
900 false,
901 )?;
902
903 if is_grad_enabled()
906 && (input.requires_grad()
907 || self.weight.requires_grad()
908 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
909 {
910 let input_data = input.data_vec()?;
912 let (cols, col_rows, col_cols) =
913 im2col(&input_data, batch, c_in, h, w, kh, kw, sh, sw, ph, pw);
914 let grad_fn = Arc::new(Conv2dBackward {
915 input: input.clone(),
916 weight: self.weight.tensor().clone(),
917 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
918 in_channels: self.in_channels,
919 out_channels: self.out_channels,
920 kernel_size: self.kernel_size,
921 stride: self.stride,
922 padding: self.padding,
923 dilation: self.dilation,
924 groups: self.groups,
925 cols,
926 col_rows,
927 col_cols,
928 h_out,
929 w_out,
930 });
931 return Tensor::from_operation(
932 result.into_storage_and_shape()?.0,
933 out_shape.to_vec(),
934 grad_fn,
935 );
936 }
937
938 return Ok(result);
939 }
940 }
941
942 let input_data = input.data_vec()?;
944 let weight_data = self.weight.data_vec()?;
945
946 let zero = <T as num_traits::Zero>::zero();
947 let mut output = vec![zero; batch * self.out_channels * h_out * w_out];
948
949 let in_per_group = self.in_channels / groups;
951 let out_per_group = self.out_channels / groups;
952 let weight_per_group_numel = out_per_group * in_per_group * kh * kw;
953 let group_col_rows = in_per_group * kh * kw;
954 let col_cols = h_out * w_out;
955
956 let saved_cols_rows = self.in_channels * kh * kw;
960 let mut saved_cols: Vec<T> = if is_grad_enabled()
961 && (input.requires_grad()
962 || self.weight.requires_grad()
963 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
964 {
965 vec![zero; batch * saved_cols_rows * col_cols]
966 } else {
967 Vec::new()
968 };
969 let save_cols = !saved_cols.is_empty();
970
971 for g in 0..groups {
972 let mut group_input = vec![zero; batch * in_per_group * h * w];
974 for b in 0..batch {
975 for c in 0..in_per_group {
976 let src_c = g * in_per_group + c;
977 let src_start = b * self.in_channels * h * w + src_c * h * w;
978 let dst_start = b * in_per_group * h * w + c * h * w;
979 group_input[dst_start..dst_start + h * w]
980 .copy_from_slice(&input_data[src_start..src_start + h * w]);
981 }
982 }
983
984 let (g_cols, g_col_rows, g_col_cols) = im2col_dilated(
985 &group_input,
986 batch,
987 in_per_group,
988 h,
989 w,
990 kh,
991 kw,
992 sh,
993 sw,
994 ph,
995 pw,
996 dh,
997 dw,
998 );
999 debug_assert_eq!(g_col_rows, group_col_rows);
1000 debug_assert_eq!(g_col_cols, col_cols);
1001
1002 if save_cols {
1004 for b in 0..batch {
1005 for c in 0..in_per_group {
1006 let dst_c = g * in_per_group + c;
1007 for kk in 0..(kh * kw) {
1008 let src_row = c * kh * kw + kk;
1009 let dst_row = dst_c * kh * kw + kk;
1010 let src_off = b * group_col_rows * col_cols + src_row * col_cols;
1011 let dst_off = b * saved_cols_rows * col_cols + dst_row * col_cols;
1012 saved_cols[dst_off..dst_off + col_cols]
1013 .copy_from_slice(&g_cols[src_off..src_off + col_cols]);
1014 }
1015 }
1016 }
1017 }
1018
1019 let w_group_start = g * weight_per_group_numel;
1022 let w_group_end = w_group_start + weight_per_group_numel;
1023 let weight_group_2d = Tensor::from_storage(
1024 TensorStorage::cpu(weight_data[w_group_start..w_group_end].to_vec()),
1025 vec![out_per_group, group_col_rows],
1026 false,
1027 )?;
1028
1029 for b in 0..batch {
1030 let col_start = b * group_col_rows * col_cols;
1031 let col_end = col_start + group_col_rows * col_cols;
1032 let cols_b = Tensor::from_storage(
1033 TensorStorage::cpu(g_cols[col_start..col_end].to_vec()),
1034 vec![group_col_rows, col_cols],
1035 false,
1036 )?;
1037
1038 let out_b = mm(&weight_group_2d, &cols_b)?;
1039 let out_data = out_b.data()?;
1040 for oc in 0..out_per_group {
1042 let dst_c = g * out_per_group + oc;
1043 let dst_start = b * self.out_channels * h_out * w_out + dst_c * h_out * w_out;
1044 let src_start = oc * h_out * w_out;
1045 output[dst_start..dst_start + h_out * w_out]
1046 .copy_from_slice(&out_data[src_start..src_start + h_out * w_out]);
1047 }
1048 }
1049 }
1050
1051 if let Some(ref bias) = self.bias {
1053 let bias_data = bias.data_vec()?;
1054 for b in 0..batch {
1055 for c in 0..self.out_channels {
1056 let bval = bias_data[c];
1057 for hw in 0..(h_out * w_out) {
1058 output[b * self.out_channels * h_out * w_out + c * h_out * w_out + hw] +=
1059 bval;
1060 }
1061 }
1062 }
1063 }
1064
1065 let result = Tensor::from_storage(
1066 TensorStorage::cpu(output),
1067 vec![batch, self.out_channels, h_out, w_out],
1068 false,
1069 )?;
1070
1071 if save_cols {
1073 let grad_fn = Arc::new(Conv2dBackward {
1074 input: input.clone(),
1075 weight: self.weight.tensor().clone(),
1076 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
1077 in_channels: self.in_channels,
1078 out_channels: self.out_channels,
1079 kernel_size: self.kernel_size,
1080 stride: self.stride,
1081 padding: self.padding,
1082 dilation: self.dilation,
1083 groups: self.groups,
1084 cols: saved_cols,
1085 col_rows: saved_cols_rows,
1086 col_cols,
1087 h_out,
1088 w_out,
1089 });
1090 Tensor::from_operation(
1091 TensorStorage::cpu(result.data()?.to_vec()),
1092 result.shape().to_vec(),
1093 grad_fn,
1094 )?
1095 .to(input_device) } else {
1097 result.to(input_device)
1098 }
1099 }
1100
1101 fn parameters(&self) -> Vec<&Parameter<T>> {
1102 let mut params = vec![&self.weight];
1103 if let Some(ref b) = self.bias {
1104 params.push(b);
1105 }
1106 params
1107 }
1108
1109 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1110 let mut params = vec![&mut self.weight];
1111 if let Some(ref mut b) = self.bias {
1112 params.push(b);
1113 }
1114 params
1115 }
1116
1117 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1118 let mut params = vec![("weight".to_string(), &self.weight)];
1119 if let Some(ref b) = self.bias {
1120 params.push(("bias".to_string(), b));
1121 }
1122 params
1123 }
1124
1125 fn train(&mut self) {
1126 self.training = true;
1127 }
1128
1129 fn eval(&mut self) {
1130 self.training = false;
1131 }
1132
1133 fn is_training(&self) -> bool {
1134 self.training
1135 }
1136}
1137
/// Autograd node for `Conv2d::forward`.
///
/// Holds the forward operands, the layer configuration, and the im2col
/// columns computed during the forward pass so that `backward` can produce
/// gradients for input, weight, and bias.
#[derive(Debug)]
struct Conv2dBackward<T: Float> {
    /// Forward input (4-D; `backward` reads axes 0, 2, 3 as batch, H, W).
    input: Tensor<T>,
    /// Forward weight tensor.
    weight: Tensor<T>,
    /// Forward bias tensor, if the layer has one.
    bias: Option<Tensor<T>>,
    /// Number of input channels of the layer.
    in_channels: usize,
    /// Number of output channels of the layer.
    out_channels: usize,
    /// Kernel size as `(height, width)`.
    kernel_size: (usize, usize),
    /// Stride as `(height, width)`.
    stride: (usize, usize),
    /// Padding as `(height, width)`.
    padding: (usize, usize),
    /// Dilation as `(height, width)`.
    dilation: (usize, usize),
    /// Number of channel groups.
    groups: usize,
    /// Saved im2col columns, indexed as `[batch, col_rows, col_cols]`.
    cols: Vec<T>,
    /// Rows per sample in `cols` (`in_channels * kH * kW` in the forward pass).
    col_rows: usize,
    /// Columns per sample in `cols`; `backward` pairs it with `h_out * w_out`.
    col_cols: usize,
    /// Output height of the forward pass.
    h_out: usize,
    /// Output width of the forward pass.
    w_out: usize,
}
1173
impl<T: Float> GradFn<T> for Conv2dBackward<T> {
    /// Computes gradients for the forward operands from `grad_output`.
    ///
    /// The returned vector is ordered like `inputs()`: input, weight, and —
    /// only when the layer has a bias — bias. An entry is `None` when the
    /// corresponding tensor does not require a gradient. All arithmetic runs
    /// on CPU buffers; each gradient is then moved to the device of the
    /// tensor it belongs to.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading tensor data, constructing tensors,
    /// `transpose`, `mm`, and device transfer.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Remember target devices up front; gradients are computed on the CPU.
        let input_device = self.input.device();
        let weight_device = self.weight.device();
        let bias_device = self.bias.as_ref().map(|b| b.device());
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let h = self.input.shape()[2];
        let w = self.input.shape()[3];
        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (dh, dw) = self.dilation;
        let groups = self.groups;
        let in_per_group = self.in_channels / groups;
        let out_per_group = self.out_channels / groups;
        let group_col_rows = in_per_group * kh * kw;
        let zero = <T as num_traits::Zero>::zero();

        // Weight gradient: for every group and sample,
        // grad_out_g [out_per_group, H_out*W_out] x cols_g^T, summed over the batch.
        let grad_weight = if self.weight.requires_grad() {
            let weight_numel = self.out_channels * in_per_group * kh * kw;
            let mut gw_accum = vec![zero; weight_numel];
            let weight_per_group_numel = out_per_group * group_col_rows;

            for g in 0..groups {
                for b in 0..batch {
                    // Slice this group's output-channel gradients for sample `b`.
                    let mut go_g = vec![zero; out_per_group * self.h_out * self.w_out];
                    for oc in 0..out_per_group {
                        let src_c = g * out_per_group + oc;
                        let src_start = b * self.out_channels * self.h_out * self.w_out
                            + src_c * self.h_out * self.w_out;
                        let dst_start = oc * self.h_out * self.w_out;
                        go_g[dst_start..dst_start + self.h_out * self.w_out].copy_from_slice(
                            &go_data[src_start..src_start + self.h_out * self.w_out],
                        );
                    }
                    let go_b_g = Tensor::from_storage(
                        TensorStorage::cpu(go_g),
                        vec![out_per_group, self.h_out * self.w_out],
                        false,
                    )?;

                    // Slice this group's rows out of the saved all-channel columns.
                    let mut cols_g = vec![zero; group_col_rows * self.col_cols];
                    for c in 0..in_per_group {
                        let src_c = g * in_per_group + c;
                        for kk in 0..(kh * kw) {
                            let src_row = src_c * kh * kw + kk;
                            let dst_row = c * kh * kw + kk;
                            let src_off =
                                b * self.col_rows * self.col_cols + src_row * self.col_cols;
                            let dst_off = dst_row * self.col_cols;
                            cols_g[dst_off..dst_off + self.col_cols]
                                .copy_from_slice(&self.cols[src_off..src_off + self.col_cols]);
                        }
                    }
                    let cols_b_g = Tensor::from_storage(
                        TensorStorage::cpu(cols_g),
                        vec![group_col_rows, self.col_cols],
                        false,
                    )?;

                    let cols_bt = transpose(&cols_b_g)?;
                    let gw_b = mm(&go_b_g, &cols_bt)?;
                    let gw_data = gw_b.data()?;

                    // Accumulate into this group's slice of the flat weight gradient.
                    let dst_off = g * weight_per_group_numel;
                    for i in 0..weight_per_group_numel {
                        gw_accum[dst_off + i] += gw_data[i];
                    }
                }
            }

            Some(
                Tensor::from_storage(
                    TensorStorage::cpu(gw_accum),
                    vec![self.out_channels, in_per_group, kh, kw],
                    false,
                )?
                .to(weight_device)?,
            )
        } else {
            None
        };

        // Bias gradient: sum of grad_output over batch and spatial positions
        // for each output channel.
        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut gb = vec![zero; self.out_channels];
                for batch_idx in 0..batch {
                    for c in 0..self.out_channels {
                        for hw in 0..(self.h_out * self.w_out) {
                            gb[c] +=
                                go_data[batch_idx * self.out_channels * self.h_out * self.w_out
                                    + c * self.h_out * self.w_out
                                    + hw];
                        }
                    }
                }
                // `bias_device` is always `Some` in this arm; the fallback is
                // only there to avoid an unwrap.
                let target_dev = bias_device.unwrap_or(input_device);
                Some(
                    Tensor::from_storage(TensorStorage::cpu(gb), vec![self.out_channels], false)?
                        .to(target_dev)?,
                )
            }
            _ => None,
        };

        // Input gradient: for every group, W_g^T x grad_out_g gives the
        // column gradients, which col2im folds back into input layout.
        let grad_input = if self.input.requires_grad() {
            let weight_data = self.weight.data_vec()?;
            let mut grad_input_data = vec![zero; batch * self.in_channels * h * w];
            let weight_per_group_numel = out_per_group * group_col_rows;

            for g in 0..groups {
                let w_off = g * weight_per_group_numel;
                let weight_g_2d = Tensor::from_storage(
                    TensorStorage::cpu(weight_data[w_off..w_off + weight_per_group_numel].to_vec()),
                    vec![out_per_group, group_col_rows],
                    false,
                )?;
                let weight_g_t = transpose(&weight_g_2d)?;

                let mut grad_cols_g = vec![zero; batch * group_col_rows * self.col_cols];
                for b in 0..batch {
                    // Slice this group's output-channel gradients for sample `b`.
                    let mut go_g = vec![zero; out_per_group * self.h_out * self.w_out];
                    for oc in 0..out_per_group {
                        let src_c = g * out_per_group + oc;
                        let src_start = b * self.out_channels * self.h_out * self.w_out
                            + src_c * self.h_out * self.w_out;
                        let dst_start = oc * self.h_out * self.w_out;
                        go_g[dst_start..dst_start + self.h_out * self.w_out].copy_from_slice(
                            &go_data[src_start..src_start + self.h_out * self.w_out],
                        );
                    }
                    let go_b_g = Tensor::from_storage(
                        TensorStorage::cpu(go_g),
                        vec![out_per_group, self.h_out * self.w_out],
                        false,
                    )?;

                    let gc_b = mm(&weight_g_t, &go_b_g)?;
                    let gc_data = gc_b.data()?;
                    let gc_start = b * group_col_rows * self.col_cols;
                    grad_cols_g[gc_start..gc_start + group_col_rows * self.col_cols]
                        .copy_from_slice(gc_data);
                }

                // Fold the column gradients into [batch, in_per_group, h, w].
                let gi_g = col2im_dilated(
                    &grad_cols_g,
                    batch,
                    in_per_group,
                    h,
                    w,
                    kh,
                    kw,
                    sh,
                    sw,
                    ph,
                    pw,
                    dh,
                    dw,
                    self.h_out,
                    self.w_out,
                );

                for b in 0..batch {
                    // Scatter the group's channels to their global position.
                    for c in 0..in_per_group {
                        let dst_c = g * in_per_group + c;
                        let dst_start = b * self.in_channels * h * w + dst_c * h * w;
                        let src_start = b * in_per_group * h * w + c * h * w;
                        grad_input_data[dst_start..dst_start + h * w]
                            .copy_from_slice(&gi_g[src_start..src_start + h * w]);
                    }
                }
            }

            Some(
                Tensor::from_storage(
                    TensorStorage::cpu(grad_input_data),
                    self.input.shape().to_vec(),
                    false,
                )?
                .to(input_device)?,
            )
        } else {
            None
        };

        // Keep the gradient list aligned with `inputs()`: the bias slot
        // exists only when the layer has a bias.
        let mut grads = vec![grad_input, grad_weight];
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// Returns the forward operands: input, weight, then bias if present.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut v = vec![&self.input, &self.weight];
        if let Some(ref b) = self.bias {
            v.push(b);
        }
        v
    }

    /// Name of this autograd node.
    fn name(&self) -> &'static str {
        "Conv2dBackward"
    }
}
1416
/// 1-D convolution layer (2-D inputs are treated as a single unbatched
/// sample by `forward`).
#[derive(Debug)]
pub struct Conv1d<T: Float> {
    /// Kernel weights, shape `[out_channels, in_channels / groups, kernel_size]`.
    weight: Parameter<T>,
    /// Optional per-output-channel bias, shape `[out_channels]`.
    bias: Option<Parameter<T>>,
    /// Number of channels expected in the input.
    in_channels: usize,
    /// Number of channels produced in the output.
    out_channels: usize,
    /// Kernel length.
    kernel_size: usize,
    /// Stride along the length axis.
    stride: usize,
    /// Numeric padding; reset to 0 when string padding is selected.
    padding: usize,
    /// Dilation (spacing between kernel taps).
    dilation: usize,
    /// Number of channel groups; divides both `in_channels` and `out_channels`
    /// when the layer is built through `new_full`.
    groups: usize,
    /// How padded positions are filled when `forward` pads the input.
    padding_mode: crate::padding::PaddingMode,
    /// Optional string-style padding, handled in `forward` before the
    /// convolution itself.
    string_padding: Option<StringPadding>,
    /// Training-mode flag, initialised to `true` by the constructors.
    training: bool,
}
1476
1477impl<T: Float> Conv1d<T> {
1478 pub fn new(
1487 in_channels: usize,
1488 out_channels: usize,
1489 kernel_size: usize,
1490 stride: usize,
1491 padding: usize,
1492 bias: bool,
1493 ) -> FerrotorchResult<Self> {
1494 Self::new_full(
1495 in_channels,
1496 out_channels,
1497 kernel_size,
1498 stride,
1499 padding,
1500 1,
1501 1,
1502 bias,
1503 )
1504 }
1505
1506 #[allow(clippy::too_many_arguments)]
1521 pub fn new_full(
1522 in_channels: usize,
1523 out_channels: usize,
1524 kernel_size: usize,
1525 stride: usize,
1526 padding: usize,
1527 dilation: usize,
1528 groups: usize,
1529 bias: bool,
1530 ) -> FerrotorchResult<Self> {
1531 if in_channels == 0 || out_channels == 0 {
1532 return Err(FerrotorchError::InvalidArgument {
1533 message: "in_channels and out_channels must be > 0".into(),
1534 });
1535 }
1536 if kernel_size == 0 {
1537 return Err(FerrotorchError::InvalidArgument {
1538 message: "kernel_size must be > 0".into(),
1539 });
1540 }
1541 if stride == 0 {
1542 return Err(FerrotorchError::InvalidArgument {
1543 message: "stride must be > 0".into(),
1544 });
1545 }
1546 if dilation == 0 {
1547 return Err(FerrotorchError::InvalidArgument {
1548 message: format!("Conv1d::new_full: dilation must be > 0, got {dilation}"),
1549 });
1550 }
1551 if groups == 0 {
1552 return Err(FerrotorchError::InvalidArgument {
1553 message: "Conv1d::new_full: groups must be > 0".into(),
1554 });
1555 }
1556 if in_channels % groups != 0 {
1559 return Err(FerrotorchError::InvalidArgument {
1560 message: format!(
1561 "Conv1d::new_full: groups={groups} must divide in_channels={in_channels}"
1562 ),
1563 });
1564 }
1565 if out_channels % groups != 0 {
1566 return Err(FerrotorchError::InvalidArgument {
1567 message: format!(
1568 "Conv1d::new_full: groups={groups} must divide out_channels={out_channels}"
1569 ),
1570 });
1571 }
1572
1573 let mut weight = Parameter::zeros(&[out_channels, in_channels / groups, kernel_size])?;
1575 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
1576
1577 let bias_param = if bias {
1578 let mut b = Parameter::zeros(&[out_channels])?;
1579 let fan_in = (in_channels / groups) * kernel_size;
1582 let bound = if fan_in > 0 {
1583 1.0 / (fan_in as f64).sqrt()
1584 } else {
1585 0.0
1586 };
1587 uniform_init(&mut b, -bound, bound)?;
1588 Some(b)
1589 } else {
1590 None
1591 };
1592
1593 Ok(Self {
1594 weight,
1595 bias: bias_param,
1596 in_channels,
1597 out_channels,
1598 kernel_size,
1599 stride,
1600 padding,
1601 dilation,
1602 groups,
1603 padding_mode: crate::padding::PaddingMode::Zeros,
1604 string_padding: None,
1605 training: true,
1606 })
1607 }
1608
1609 pub fn groups(&self) -> usize {
1611 self.groups
1612 }
1613
1614 pub fn dilation(&self) -> usize {
1616 self.dilation
1617 }
1618
1619 pub fn with_string_padding(mut self, padding: StringPadding) -> FerrotorchResult<Self> {
1636 if padding == StringPadding::Same && self.stride != 1 {
1637 return Err(FerrotorchError::InvalidArgument {
1638 message: "padding='same' is not supported for strided convolutions".into(),
1639 });
1640 }
1641 self.string_padding = Some(padding);
1642 self.padding = 0;
1643 Ok(self)
1644 }
1645
1646 pub fn with_padding_mode(mut self, mode: crate::padding::PaddingMode) -> Self {
1656 self.padding_mode = mode;
1657 self
1658 }
1659
1660 pub fn num_parameters(&self) -> usize {
1662 let w = self.out_channels * self.in_channels * self.kernel_size;
1663 let b = if self.bias.is_some() {
1664 self.out_channels
1665 } else {
1666 0
1667 };
1668 w + b
1669 }
1670
1671 pub fn from_parts(
1679 weight: Tensor<T>,
1680 bias: Option<Tensor<T>>,
1681 stride: usize,
1682 padding: usize,
1683 ) -> FerrotorchResult<Self> {
1684 if weight.ndim() != 3 {
1685 return Err(FerrotorchError::ShapeMismatch {
1686 message: format!(
1687 "Conv1d::from_parts: weight must be 3-D [out, in, k], got {:?}",
1688 weight.shape()
1689 ),
1690 });
1691 }
1692 let out_channels = weight.shape()[0];
1693 let in_channels = weight.shape()[1];
1694 let kernel_size = weight.shape()[2];
1695 if let Some(b) = &bias {
1696 if b.ndim() != 1 || b.shape()[0] != out_channels {
1697 return Err(FerrotorchError::ShapeMismatch {
1698 message: format!(
1699 "Conv1d::from_parts: bias shape {:?} != [{}]",
1700 b.shape(),
1701 out_channels
1702 ),
1703 });
1704 }
1705 }
1706 Ok(Self {
1707 weight: Parameter::new(weight),
1708 bias: bias.map(Parameter::new),
1709 in_channels,
1710 out_channels,
1711 kernel_size,
1712 stride,
1713 padding,
1714 dilation: 1,
1715 groups: 1,
1716 padding_mode: crate::padding::PaddingMode::Zeros,
1717 string_padding: None,
1718 training: true,
1719 })
1720 }
1721
1722 fn recurse_clone(
1727 &self,
1728 padding: usize,
1729 padding_mode: crate::padding::PaddingMode,
1730 ) -> Conv1d<T> {
1731 Conv1d {
1732 weight: Parameter::new(self.weight.tensor().clone()),
1733 bias: self
1734 .bias
1735 .as_ref()
1736 .map(|b| Parameter::new(b.tensor().clone())),
1737 in_channels: self.in_channels,
1738 out_channels: self.out_channels,
1739 kernel_size: self.kernel_size,
1740 stride: self.stride,
1741 padding,
1742 dilation: self.dilation,
1743 groups: self.groups,
1744 padding_mode,
1745 string_padding: None,
1746 training: self.training,
1747 }
1748 }
1749}
1750
1751impl<T: Float> Module<T> for Conv1d<T> {
1752 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1753 let _autocast_cat = autocast_guard("conv1d");
1755
1756 if input.ndim() == 2 {
1764 let batched = unsqueeze(input, 0)?;
1765 let output = self.forward(&batched)?;
1766 return squeeze(&output, 0);
1767 }
1768
1769 if let Some(sp) = self.string_padding {
1779 match sp {
1780 StringPadding::Valid => {
1781 return self
1782 .recurse_clone(0, crate::padding::PaddingMode::Zeros)
1783 .forward(input);
1784 }
1785 StringPadding::Same => {
1786 let (left, right) = same_pad_lr(self.kernel_size, self.dilation);
1787 let padded = crate::padding::functional_pad_1d(
1788 input,
1789 left,
1790 right,
1791 self.padding_mode,
1792 <T as num_traits::Zero>::zero(),
1793 )?;
1794 return self
1795 .recurse_clone(0, crate::padding::PaddingMode::Zeros)
1796 .forward(&padded);
1797 }
1798 }
1799 }
1800
1801 if self.padding_mode != crate::padding::PaddingMode::Zeros && self.padding != 0 {
1813 let padded = crate::padding::functional_pad_1d(
1814 input,
1815 self.padding,
1816 self.padding,
1817 self.padding_mode,
1818 <T as num_traits::Zero>::zero(),
1819 )?;
1820 return self
1824 .recurse_clone(0, crate::padding::PaddingMode::Zeros)
1825 .forward(&padded);
1826 }
1827
1828 if input.ndim() != 3 {
1830 return Err(FerrotorchError::InvalidArgument {
1831 message: format!(
1832 "Conv1d expects 3-D input [B, C, L], got {:?}",
1833 input.shape()
1834 ),
1835 });
1836 }
1837
1838 let batch = input.shape()[0];
1839 let c_in = input.shape()[1];
1840 let length = input.shape()[2];
1841
1842 if c_in != self.in_channels {
1843 return Err(FerrotorchError::ShapeMismatch {
1844 message: format!(
1845 "Conv1d: expected {} input channels, got {}",
1846 self.in_channels, c_in
1847 ),
1848 });
1849 }
1850
1851 let k = self.kernel_size;
1852 let s = self.stride;
1853 let p = self.padding;
1854 let dil = self.dilation;
1855 let groups = self.groups;
1856
1857 let eff_k = dil * (k - 1) + 1;
1860 let l_padded = length + 2 * p;
1861 if l_padded < eff_k {
1862 return Err(FerrotorchError::InvalidArgument {
1863 message: format!(
1864 "Conv1d: padded input length ({l_padded}) is smaller than effective kernel ({eff_k})"
1865 ),
1866 });
1867 }
1868
1869 let l_out = (l_padded - eff_k) / s + 1;
1870
1871 let input_device = input.device();
1873
1874 let input_data = input.data_vec()?;
1882 let weight_data = self.weight.data_vec()?;
1883
1884 let zero = <T as num_traits::Zero>::zero();
1885 let mut output = vec![zero; batch * self.out_channels * l_out];
1886
1887 let in_per_group = self.in_channels / groups;
1889 let out_per_group = self.out_channels / groups;
1890 let weight_per_group_numel = out_per_group * in_per_group * k;
1891 let group_col_rows = in_per_group * k;
1892 let col_cols = l_out;
1893
1894 let saved_cols_rows = self.in_channels * k;
1899 let mut saved_cols: Vec<T> = if is_grad_enabled()
1900 && (input.requires_grad()
1901 || self.weight.requires_grad()
1902 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
1903 {
1904 vec![zero; batch * saved_cols_rows * col_cols]
1905 } else {
1906 Vec::new()
1907 };
1908 let save_cols = !saved_cols.is_empty();
1909
1910 for g in 0..groups {
1911 let mut group_input = vec![zero; batch * in_per_group * length];
1913 for b in 0..batch {
1914 for c in 0..in_per_group {
1915 let src_c = g * in_per_group + c;
1916 let src_start = b * self.in_channels * length + src_c * length;
1917 let dst_start = b * in_per_group * length + c * length;
1918 group_input[dst_start..dst_start + length]
1919 .copy_from_slice(&input_data[src_start..src_start + length]);
1920 }
1921 }
1922
1923 let (g_cols, g_col_rows, g_col_cols) = im2col_dilated(
1924 &group_input,
1925 batch,
1926 in_per_group,
1927 1,
1928 length,
1929 1,
1930 k,
1931 1,
1932 s,
1933 0,
1934 p,
1935 1,
1936 dil,
1937 );
1938 debug_assert_eq!(g_col_rows, group_col_rows);
1939 debug_assert_eq!(g_col_cols, col_cols);
1940
1941 if save_cols {
1943 for b in 0..batch {
1944 for c in 0..in_per_group {
1945 let dst_c = g * in_per_group + c;
1946 for kk in 0..k {
1947 let src_row = c * k + kk;
1948 let dst_row = dst_c * k + kk;
1949 let src_off = b * group_col_rows * col_cols + src_row * col_cols;
1950 let dst_off = b * saved_cols_rows * col_cols + dst_row * col_cols;
1951 saved_cols[dst_off..dst_off + col_cols]
1952 .copy_from_slice(&g_cols[src_off..src_off + col_cols]);
1953 }
1954 }
1955 }
1956 }
1957
1958 let w_group_start = g * weight_per_group_numel;
1961 let w_group_end = w_group_start + weight_per_group_numel;
1962 let weight_group_2d = Tensor::from_storage(
1963 TensorStorage::cpu(weight_data[w_group_start..w_group_end].to_vec()),
1964 vec![out_per_group, group_col_rows],
1965 false,
1966 )?;
1967
1968 for b in 0..batch {
1969 let col_start = b * group_col_rows * col_cols;
1970 let col_end = col_start + group_col_rows * col_cols;
1971 let cols_b = Tensor::from_storage(
1972 TensorStorage::cpu(g_cols[col_start..col_end].to_vec()),
1973 vec![group_col_rows, col_cols],
1974 false,
1975 )?;
1976
1977 let out_b = mm(&weight_group_2d, &cols_b)?;
1978 let out_data = out_b.data()?;
1979 for oc in 0..out_per_group {
1981 let dst_c = g * out_per_group + oc;
1982 let dst_start = b * self.out_channels * l_out + dst_c * l_out;
1983 let src_start = oc * l_out;
1984 output[dst_start..dst_start + l_out]
1985 .copy_from_slice(&out_data[src_start..src_start + l_out]);
1986 }
1987 }
1988 }
1989
1990 if let Some(ref bias) = self.bias {
1992 let bias_data = bias.data_vec()?;
1993 for b in 0..batch {
1994 for c in 0..self.out_channels {
1995 let bval = bias_data[c];
1996 for l in 0..l_out {
1997 output[b * self.out_channels * l_out + c * l_out + l] += bval;
1998 }
1999 }
2000 }
2001 }
2002
2003 let result = Tensor::from_storage(
2004 TensorStorage::cpu(output),
2005 vec![batch, self.out_channels, l_out],
2006 false,
2007 )?;
2008
2009 if save_cols {
2011 let grad_fn = Arc::new(Conv1dBackward {
2012 input: input.clone(),
2013 weight: self.weight.tensor().clone(),
2014 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
2015 in_channels: self.in_channels,
2016 out_channels: self.out_channels,
2017 kernel_size: self.kernel_size,
2018 stride: self.stride,
2019 padding: self.padding,
2020 dilation: self.dilation,
2021 groups: self.groups,
2022 cols: saved_cols,
2023 col_rows: saved_cols_rows,
2024 col_cols,
2025 l_out,
2026 });
2027 Tensor::from_operation(
2028 TensorStorage::cpu(result.data()?.to_vec()),
2029 result.shape().to_vec(),
2030 grad_fn,
2031 )?
2032 .to(input_device) } else {
2034 result.to(input_device)
2035 }
2036 }
2037
2038 fn parameters(&self) -> Vec<&Parameter<T>> {
2039 let mut params = vec![&self.weight];
2040 if let Some(ref b) = self.bias {
2041 params.push(b);
2042 }
2043 params
2044 }
2045
2046 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2047 let mut params = vec![&mut self.weight];
2048 if let Some(ref mut b) = self.bias {
2049 params.push(b);
2050 }
2051 params
2052 }
2053
2054 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2055 let mut params = vec![("weight".to_string(), &self.weight)];
2056 if let Some(ref b) = self.bias {
2057 params.push(("bias".to_string(), b));
2058 }
2059 params
2060 }
2061
    /// Puts the layer in training mode by setting its `training` flag.
    fn train(&mut self) {
        self.training = true;
    }
2065
    /// Puts the layer in evaluation mode by clearing its `training` flag.
    fn eval(&mut self) {
        self.training = false;
    }
2069
    /// Reports whether the layer is currently in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
2073}
2074
/// Autograd node created by `Conv1d::forward`.
///
/// It keeps the forward operands, the layer hyper-parameters and the im2col
/// columns so that `backward` can compute the gradients without redoing the
/// column expansion.
#[derive(Debug)]
struct Conv1dBackward<T: Float> {
    /// Input tensor passed to the forward pass (`[B, C_in, L]`).
    input: Tensor<T>,
    /// Weight tensor used in the forward pass.
    weight: Tensor<T>,
    /// Bias tensor used in the forward pass, if the layer has one.
    bias: Option<Tensor<T>>,
    /// Number of input channels of the layer.
    in_channels: usize,
    /// Number of output channels of the layer.
    out_channels: usize,
    /// Kernel length.
    kernel_size: usize,
    /// Convolution stride.
    stride: usize,
    /// Zero padding applied on both ends of the length axis.
    padding: usize,
    /// Kernel dilation factor.
    dilation: usize,
    /// Number of channel groups.
    groups: usize,
    /// im2col columns for all groups, laid out as `[B, col_rows, col_cols]`.
    cols: Vec<T>,
    /// Rows per batch item in `cols` (`in_channels * kernel_size`).
    col_rows: usize,
    /// Columns per row in `cols` (set to `l_out` by the forward pass).
    col_cols: usize,
    /// Output length of the forward pass.
    l_out: usize,
}
2103
2104impl<T: Float> GradFn<T> for Conv1dBackward<T> {
2105 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2106 let input_device = self.input.device();
2108 let weight_device = self.weight.device();
2109 let bias_device = self.bias.as_ref().map(|b| b.device());
2110 let go_data = grad_output.data_vec()?;
2111 let batch = self.input.shape()[0];
2112 let length = self.input.shape()[2];
2113 let k = self.kernel_size;
2114 let s = self.stride;
2115 let p = self.padding;
2116 let dil = self.dilation;
2117 let groups = self.groups;
2118 let in_per_group = self.in_channels / groups;
2119 let out_per_group = self.out_channels / groups;
2120 let group_col_rows = in_per_group * k;
2121 let zero = <T as num_traits::Zero>::zero();
2122
2123 let grad_weight = if self.weight.requires_grad() {
2127 let weight_numel = self.out_channels * in_per_group * k;
2128 let mut gw_accum = vec![zero; weight_numel];
2129 let weight_per_group_numel = out_per_group * group_col_rows;
2130
2131 for g in 0..groups {
2132 for b in 0..batch {
2133 let mut go_g = vec![zero; out_per_group * self.l_out];
2135 for oc in 0..out_per_group {
2136 let src_c = g * out_per_group + oc;
2137 let src_start = b * self.out_channels * self.l_out + src_c * self.l_out;
2138 let dst_start = oc * self.l_out;
2139 go_g[dst_start..dst_start + self.l_out]
2140 .copy_from_slice(&go_data[src_start..src_start + self.l_out]);
2141 }
2142 let go_b_g = Tensor::from_storage(
2143 TensorStorage::cpu(go_g),
2144 vec![out_per_group, self.l_out],
2145 false,
2146 )?;
2147
2148 let mut cols_g = vec![zero; group_col_rows * self.col_cols];
2150 for c in 0..in_per_group {
2151 let src_c = g * in_per_group + c;
2152 for kk in 0..k {
2153 let src_row = src_c * k + kk;
2154 let dst_row = c * k + kk;
2155 let src_off =
2156 b * self.col_rows * self.col_cols + src_row * self.col_cols;
2157 let dst_off = dst_row * self.col_cols;
2158 cols_g[dst_off..dst_off + self.col_cols]
2159 .copy_from_slice(&self.cols[src_off..src_off + self.col_cols]);
2160 }
2161 }
2162 let cols_b_g = Tensor::from_storage(
2163 TensorStorage::cpu(cols_g),
2164 vec![group_col_rows, self.col_cols],
2165 false,
2166 )?;
2167
2168 let cols_bt = transpose(&cols_b_g)?;
2169 let gw_b = mm(&go_b_g, &cols_bt)?;
2170 let gw_data = gw_b.data()?;
2171
2172 let dst_off = g * weight_per_group_numel;
2173 for i in 0..weight_per_group_numel {
2174 gw_accum[dst_off + i] += gw_data[i];
2175 }
2176 }
2177 }
2178
2179 Some(
2180 Tensor::from_storage(
2181 TensorStorage::cpu(gw_accum),
2182 vec![self.out_channels, in_per_group, k],
2183 false,
2184 )?
2185 .to(weight_device)?,
2186 )
2187 } else {
2188 None
2189 };
2190
2191 let grad_bias = match &self.bias {
2195 Some(b) if b.requires_grad() => {
2196 let mut gb = vec![zero; self.out_channels];
2197 for batch_idx in 0..batch {
2198 for c in 0..self.out_channels {
2199 for l in 0..self.l_out {
2200 gb[c] += go_data
2201 [batch_idx * self.out_channels * self.l_out + c * self.l_out + l];
2202 }
2203 }
2204 }
2205 let target_dev = bias_device.unwrap_or(input_device);
2206 Some(
2207 Tensor::from_storage(TensorStorage::cpu(gb), vec![self.out_channels], false)?
2208 .to(target_dev)?,
2209 )
2210 }
2211 _ => None,
2212 };
2213
2214 let grad_input = if self.input.requires_grad() {
2219 let weight_data = self.weight.data_vec()?;
2220 let mut grad_input_data = vec![zero; batch * self.in_channels * length];
2221 let weight_per_group_numel = out_per_group * group_col_rows;
2222
2223 for g in 0..groups {
2224 let w_off = g * weight_per_group_numel;
2225 let weight_g_2d = Tensor::from_storage(
2226 TensorStorage::cpu(weight_data[w_off..w_off + weight_per_group_numel].to_vec()),
2227 vec![out_per_group, group_col_rows],
2228 false,
2229 )?;
2230 let weight_g_t = transpose(&weight_g_2d)?;
2231
2232 let mut grad_cols_g = vec![zero; batch * group_col_rows * self.col_cols];
2233 for b in 0..batch {
2234 let mut go_g = vec![zero; out_per_group * self.l_out];
2235 for oc in 0..out_per_group {
2236 let src_c = g * out_per_group + oc;
2237 let src_start = b * self.out_channels * self.l_out + src_c * self.l_out;
2238 let dst_start = oc * self.l_out;
2239 go_g[dst_start..dst_start + self.l_out]
2240 .copy_from_slice(&go_data[src_start..src_start + self.l_out]);
2241 }
2242 let go_b_g = Tensor::from_storage(
2243 TensorStorage::cpu(go_g),
2244 vec![out_per_group, self.l_out],
2245 false,
2246 )?;
2247
2248 let gc_b = mm(&weight_g_t, &go_b_g)?;
2249 let gc_data = gc_b.data()?;
2250 let gc_start = b * group_col_rows * self.col_cols;
2251 grad_cols_g[gc_start..gc_start + group_col_rows * self.col_cols]
2252 .copy_from_slice(gc_data);
2253 }
2254
2255 let gi_g = col2im_dilated(
2258 &grad_cols_g,
2259 batch,
2260 in_per_group,
2261 1,
2262 length,
2263 1,
2264 k,
2265 1,
2266 s,
2267 0,
2268 p,
2269 1,
2270 dil,
2271 1,
2272 self.l_out,
2273 );
2274
2275 for b in 0..batch {
2276 for c in 0..in_per_group {
2277 let dst_c = g * in_per_group + c;
2278 let dst_start = b * self.in_channels * length + dst_c * length;
2279 let src_start = b * in_per_group * length + c * length;
2280 grad_input_data[dst_start..dst_start + length]
2281 .copy_from_slice(&gi_g[src_start..src_start + length]);
2282 }
2283 }
2284 }
2285
2286 Some(
2287 Tensor::from_storage(
2288 TensorStorage::cpu(grad_input_data),
2289 self.input.shape().to_vec(),
2290 false,
2291 )?
2292 .to(input_device)?,
2293 )
2294 } else {
2295 None
2296 };
2297
2298 let mut grads = vec![grad_input, grad_weight];
2299 if self.bias.is_some() {
2300 grads.push(grad_bias);
2301 }
2302 Ok(grads)
2303 }
2304
2305 fn inputs(&self) -> Vec<&Tensor<T>> {
2306 let mut v = vec![&self.input, &self.weight];
2307 if let Some(ref b) = self.bias {
2308 v.push(b);
2309 }
2310 v
2311 }
2312
    /// Returns the static name identifying this backward node.
    fn name(&self) -> &'static str {
        "Conv1dBackward"
    }
2316}
2317
/// 2-D transposed convolution layer.
#[derive(Debug)]
pub struct ConvTranspose2d<T: Float> {
    /// Kernel weights, created by `new_full` with shape
    /// `[in_channels, out_channels / groups, kH, kW]`.
    weight: Parameter<T>,
    /// Optional per-output-channel bias of shape `[out_channels]`.
    bias: Option<Parameter<T>>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Kernel size as `(kH, kW)`.
    kernel_size: (usize, usize),
    /// Stride as `(sH, sW)`.
    stride: (usize, usize),
    /// Padding as `(pH, pW)`.
    padding: (usize, usize),
    /// Output padding as `(opH, opW)`; constructors require each component
    /// to be strictly smaller than the matching stride (or dilation).
    output_padding: (usize, usize),
    /// Kernel dilation as `(dH, dW)`.
    dilation: (usize, usize),
    /// Number of channel groups; must divide both channel counts.
    groups: usize,
    /// Training-mode flag toggled by `train` / `eval`.
    training: bool,
}
2377
2378impl<T: Float> ConvTranspose2d<T> {
2379 pub fn new(
2389 in_channels: usize,
2390 out_channels: usize,
2391 kernel_size: (usize, usize),
2392 stride: (usize, usize),
2393 padding: (usize, usize),
2394 output_padding: (usize, usize),
2395 bias: bool,
2396 ) -> FerrotorchResult<Self> {
2397 Self::new_full(
2398 in_channels,
2399 out_channels,
2400 kernel_size,
2401 stride,
2402 padding,
2403 output_padding,
2404 (1, 1),
2405 1,
2406 bias,
2407 )
2408 }
2409
2410 #[allow(clippy::too_many_arguments)]
2420 pub fn new_full(
2421 in_channels: usize,
2422 out_channels: usize,
2423 kernel_size: (usize, usize),
2424 stride: (usize, usize),
2425 padding: (usize, usize),
2426 output_padding: (usize, usize),
2427 dilation: (usize, usize),
2428 groups: usize,
2429 bias: bool,
2430 ) -> FerrotorchResult<Self> {
2431 if in_channels == 0 || out_channels == 0 {
2432 return Err(FerrotorchError::InvalidArgument {
2433 message: "in_channels and out_channels must be > 0".into(),
2434 });
2435 }
2436 if kernel_size.0 == 0 || kernel_size.1 == 0 {
2437 return Err(FerrotorchError::InvalidArgument {
2438 message: "kernel_size must be > 0 in both dimensions".into(),
2439 });
2440 }
2441 if stride.0 == 0 || stride.1 == 0 {
2442 return Err(FerrotorchError::InvalidArgument {
2443 message: "stride must be > 0 in both dimensions".into(),
2444 });
2445 }
2446 if dilation.0 == 0 || dilation.1 == 0 {
2447 return Err(FerrotorchError::InvalidArgument {
2448 message: "dilation must be > 0 in both dimensions".into(),
2449 });
2450 }
2451 if groups == 0 {
2454 return Err(FerrotorchError::InvalidArgument {
2455 message: "groups must be a positive integer".into(),
2456 });
2457 }
2458 if in_channels % groups != 0 {
2459 return Err(FerrotorchError::InvalidArgument {
2460 message: format!(
2461 "in_channels ({in_channels}) must be divisible by groups ({groups})"
2462 ),
2463 });
2464 }
2465 if out_channels % groups != 0 {
2466 return Err(FerrotorchError::InvalidArgument {
2467 message: format!(
2468 "out_channels ({out_channels}) must be divisible by groups ({groups})"
2469 ),
2470 });
2471 }
2472 if output_padding.0 >= stride.0.max(dilation.0)
2475 || output_padding.1 >= stride.1.max(dilation.1)
2476 {
2477 return Err(FerrotorchError::InvalidArgument {
2478 message: "output_padding must be strictly less than max(stride, dilation)".into(),
2479 });
2480 }
2481
2482 let (kh, kw) = kernel_size;
2485 let out_per_group = out_channels / groups;
2486 let mut weight = Parameter::zeros(&[in_channels, out_per_group, kh, kw])?;
2487 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
2488
2489 let bias_param = if bias {
2490 let mut b = Parameter::zeros(&[out_channels])?;
2491 let fan_in = out_per_group * kh * kw;
2497 let bound = if fan_in > 0 {
2498 1.0 / (fan_in as f64).sqrt()
2499 } else {
2500 0.0
2501 };
2502 uniform_init(&mut b, -bound, bound)?;
2503 Some(b)
2504 } else {
2505 None
2506 };
2507
2508 Ok(Self {
2509 weight,
2510 bias: bias_param,
2511 in_channels,
2512 out_channels,
2513 kernel_size,
2514 stride,
2515 padding,
2516 output_padding,
2517 dilation,
2518 groups,
2519 training: true,
2520 })
2521 }
2522
2523 pub fn with_padding_mode(self, mode: crate::padding::PaddingMode) -> FerrotorchResult<Self> {
2534 reject_non_zeros_transpose(mode, "ConvTranspose2d")?;
2535 Ok(self)
2536 }
2537
2538 pub fn num_parameters(&self) -> usize {
2540 let w = self.in_channels * self.out_channels * self.kernel_size.0 * self.kernel_size.1;
2541 let b = if self.bias.is_some() {
2542 self.out_channels
2543 } else {
2544 0
2545 };
2546 w + b
2547 }
2548
2549 pub fn from_parts(
2555 weight: Tensor<T>,
2556 bias: Option<Tensor<T>>,
2557 stride: (usize, usize),
2558 padding: (usize, usize),
2559 output_padding: (usize, usize),
2560 ) -> FerrotorchResult<Self> {
2561 if weight.ndim() != 4 {
2562 return Err(FerrotorchError::ShapeMismatch {
2563 message: format!(
2564 "ConvTranspose2d::from_parts: weight must be 4-D [in, out, kH, kW], got {:?}",
2565 weight.shape()
2566 ),
2567 });
2568 }
2569 let in_channels = weight.shape()[0];
2570 let out_channels = weight.shape()[1];
2571 let kernel_size = (weight.shape()[2], weight.shape()[3]);
2572 if output_padding.0 >= stride.0 || output_padding.1 >= stride.1 {
2573 return Err(FerrotorchError::InvalidArgument {
2574 message: "output_padding must be strictly less than stride".into(),
2575 });
2576 }
2577 if let Some(b) = &bias {
2578 if b.ndim() != 1 || b.shape()[0] != out_channels {
2579 return Err(FerrotorchError::ShapeMismatch {
2580 message: format!(
2581 "ConvTranspose2d::from_parts: bias shape {:?} != [{}]",
2582 b.shape(),
2583 out_channels
2584 ),
2585 });
2586 }
2587 }
2588 Ok(Self {
2589 weight: Parameter::new(weight),
2590 bias: bias.map(Parameter::new),
2591 in_channels,
2592 out_channels,
2593 kernel_size,
2594 stride,
2595 padding,
2596 output_padding,
2597 dilation: (1, 1),
2604 groups: 1,
2605 training: true,
2606 })
2607 }
2608}
2609
2610fn stride_insert_zeros<T: Float>(
2615 input: &[T],
2616 batch: usize,
2617 channels: usize,
2618 h: usize,
2619 w: usize,
2620 stride_h: usize,
2621 stride_w: usize,
2622) -> (Vec<T>, usize, usize) {
2623 let h_up = (h - 1) * stride_h + 1;
2624 let w_up = (w - 1) * stride_w + 1;
2625 let zero = <T as num_traits::Zero>::zero();
2626 let mut out = vec![zero; batch * channels * h_up * w_up];
2627
2628 for b in 0..batch {
2629 for c in 0..channels {
2630 for ih in 0..h {
2631 for iw in 0..w {
2632 let oh = ih * stride_h;
2633 let ow = iw * stride_w;
2634 out[b * channels * h_up * w_up + c * h_up * w_up + oh * w_up + ow] =
2635 input[b * channels * h * w + c * h * w + ih * w + iw];
2636 }
2637 }
2638 }
2639 }
2640
2641 (out, h_up, w_up)
2642}
2643
2644fn crop_plane_2d<T: Float>(
2651 input: &[T],
2652 batch: usize,
2653 channels: usize,
2654 h: usize,
2655 w: usize,
2656 crop_h: usize,
2657 crop_w: usize,
2658) -> (Vec<T>, usize, usize) {
2659 let h_out = h - 2 * crop_h;
2660 let w_out = w - 2 * crop_w;
2661 let zero = <T as num_traits::Zero>::zero();
2662 let mut out = vec![zero; batch * channels * h_out * w_out];
2663
2664 for b in 0..batch {
2665 for c in 0..channels {
2666 for oh in 0..h_out {
2667 let src = ((b * channels + c) * h + (oh + crop_h)) * w + crop_w;
2668 let dst = ((b * channels + c) * h_out + oh) * w_out;
2669 out[dst..dst + w_out].copy_from_slice(&input[src..src + w_out]);
2670 }
2671 }
2672 }
2673
2674 (out, h_out, w_out)
2675}
2676
2677fn flip_kernel<T: Float>(kernel: &[T], c_in: usize, c_out: usize, kh: usize, kw: usize) -> Vec<T> {
2680 let zero = <T as num_traits::Zero>::zero();
2681 let mut flipped = vec![zero; c_out * c_in * kh * kw];
2682
2683 for ci in 0..c_in {
2684 for co in 0..c_out {
2685 for h in 0..kh {
2686 for w in 0..kw {
2687 let src = ci * c_out * kh * kw + co * kh * kw + h * kw + w;
2689 let dst = co * c_in * kh * kw + ci * kh * kw + (kh - 1 - h) * kw + (kw - 1 - w);
2691 flipped[dst] = kernel[src];
2692 }
2693 }
2694 }
2695 }
2696
2697 flipped
2698}
2699
2700#[allow(clippy::too_many_arguments)]
2713fn conv_transpose2d_forward_group<T: Float>(
2714 input_data: &[T],
2715 batch: usize,
2716 in_pg: usize,
2717 out_pg: usize,
2718 h: usize,
2719 w: usize,
2720 kernel_size: (usize, usize),
2721 stride: (usize, usize),
2722 padding: (usize, usize),
2723 output_padding: (usize, usize),
2724 dilation: (usize, usize),
2725 group_weight: &[T],
2726) -> FerrotorchResult<(Vec<T>, usize, usize)> {
2727 let (kh, kw) = kernel_size;
2728 let (sh, sw) = stride;
2729 let (ph, pw) = padding;
2730 let (oph, opw) = output_padding;
2731 let (dh, dw) = dilation;
2732
2733 let (upsampled, h_up_core, w_up_core) =
2735 stride_insert_zeros(input_data, batch, in_pg, h, w, sh, sw);
2736 let h_up = h_up_core + oph;
2737 let w_up = w_up_core + opw;
2738 let upsampled = if oph > 0 || opw > 0 {
2739 let zero = <T as num_traits::Zero>::zero();
2740 let mut ext = vec![zero; batch * in_pg * h_up * w_up];
2741 for b in 0..batch {
2742 for c in 0..in_pg {
2743 for ih in 0..h_up_core {
2744 let src = ((b * in_pg + c) * h_up_core + ih) * w_up_core;
2745 let dst = ((b * in_pg + c) * h_up + ih) * w_up;
2746 ext[dst..dst + w_up_core].copy_from_slice(&upsampled[src..src + w_up_core]);
2747 }
2748 }
2749 }
2750 ext
2751 } else {
2752 upsampled
2753 };
2754
2755 let flipped = flip_kernel(group_weight, in_pg, out_pg, kh, kw);
2759
2760 let eff_kh = dh * (kh - 1) + 1;
2771 let eff_kw = dw * (kw - 1) + 1;
2772 let signed_pad_h = (eff_kh - 1) as isize - ph as isize;
2773 let signed_pad_w = (eff_kw - 1) as isize - pw as isize;
2774 let crop_h = (-signed_pad_h).max(0) as usize;
2775 let crop_w = (-signed_pad_w).max(0) as usize;
2776 let (conv_input, h_in, w_in) = if crop_h > 0 || crop_w > 0 {
2777 crop_plane_2d(&upsampled, batch, in_pg, h_up, w_up, crop_h, crop_w)
2778 } else {
2779 (upsampled, h_up, w_up)
2780 };
2781 let internal_pad_h = signed_pad_h.max(0) as usize;
2782 let internal_pad_w = signed_pad_w.max(0) as usize;
2783
2784 let (cols, col_rows, col_cols) = im2col_dilated(
2785 &conv_input,
2786 batch,
2787 in_pg,
2788 h_in,
2789 w_in,
2790 kh,
2791 kw,
2792 1,
2793 1,
2794 internal_pad_h,
2795 internal_pad_w,
2796 dh,
2797 dw,
2798 );
2799
2800 let h_out = (h_in + 2 * internal_pad_h - eff_kh) + 1;
2803 let w_out = (w_in + 2 * internal_pad_w - eff_kw) + 1;
2804
2805 let flipped_2d =
2807 Tensor::from_storage(TensorStorage::cpu(flipped), vec![out_pg, col_rows], false)?;
2808
2809 let zero = <T as num_traits::Zero>::zero();
2810 let mut output = vec![zero; batch * out_pg * h_out * w_out];
2811
2812 for b in 0..batch {
2813 let col_start = b * col_rows * col_cols;
2814 let col_end = col_start + col_rows * col_cols;
2815 let cols_b = Tensor::from_storage(
2816 TensorStorage::cpu(cols[col_start..col_end].to_vec()),
2817 vec![col_rows, col_cols],
2818 false,
2819 )?;
2820
2821 let out_b = mm(&flipped_2d, &cols_b)?;
2822 let out_data = out_b.data()?;
2823
2824 let out_start = b * out_pg * h_out * w_out;
2825 let copy_len = out_pg * h_out * w_out;
2826 output[out_start..out_start + copy_len].copy_from_slice(&out_data[..copy_len]);
2827 }
2828
2829 Ok((output, h_out, w_out))
2830}
2831
2832impl<T: Float> Module<T> for ConvTranspose2d<T> {
2833 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2834 let _autocast_cat = autocast_guard("conv_transpose2d");
2836
2837 if input.ndim() == 3 {
2846 let batched = unsqueeze(input, 0)?;
2847 let output = self.forward(&batched)?;
2848 return squeeze(&output, 0);
2849 }
2850
2851 if input.ndim() != 4 {
2853 return Err(FerrotorchError::InvalidArgument {
2854 message: format!(
2855 "ConvTranspose2d expects 4-D input [B, C, H, W], got {:?}",
2856 input.shape()
2857 ),
2858 });
2859 }
2860
2861 let batch = input.shape()[0];
2862 let c_in = input.shape()[1];
2863 let h = input.shape()[2];
2864 let w = input.shape()[3];
2865
2866 if c_in != self.in_channels {
2867 return Err(FerrotorchError::ShapeMismatch {
2868 message: format!(
2869 "ConvTranspose2d: expected {} input channels, got {}",
2870 self.in_channels, c_in
2871 ),
2872 });
2873 }
2874
2875 let (kh, kw) = self.kernel_size;
2876 let groups = self.groups;
2877 let in_pg = self.in_channels / groups;
2878 let out_pg = self.out_channels / groups;
2879 let weight_pg_numel = in_pg * out_pg * kh * kw;
2880
2881 let input_device = input.device();
2883
2884 let input_data = input.data_vec()?;
2885 let weight_data = self.weight.data_vec()?;
2886
2887 let zero = <T as num_traits::Zero>::zero();
2894 let mut output: Vec<T> = Vec::new();
2895 let mut h_out = 0usize;
2896 let mut w_out = 0usize;
2897
2898 for g in 0..groups {
2899 let mut group_input = vec![zero; batch * in_pg * h * w];
2901 for b in 0..batch {
2902 for c in 0..in_pg {
2903 let src_c = g * in_pg + c;
2904 let src_start = b * self.in_channels * h * w + src_c * h * w;
2905 let dst_start = b * in_pg * h * w + c * h * w;
2906 group_input[dst_start..dst_start + h * w]
2907 .copy_from_slice(&input_data[src_start..src_start + h * w]);
2908 }
2909 }
2910
2911 let w_start = g * weight_pg_numel;
2916 let group_weight = &weight_data[w_start..w_start + weight_pg_numel];
2917
2918 let (g_out, gho, gwo) = conv_transpose2d_forward_group(
2919 &group_input,
2920 batch,
2921 in_pg,
2922 out_pg,
2923 h,
2924 w,
2925 self.kernel_size,
2926 self.stride,
2927 self.padding,
2928 self.output_padding,
2929 self.dilation,
2930 group_weight,
2931 )?;
2932 h_out = gho;
2933 w_out = gwo;
2934
2935 if output.is_empty() {
2936 output = vec![zero; batch * self.out_channels * h_out * w_out];
2937 }
2938 for b in 0..batch {
2940 for oc in 0..out_pg {
2941 let dst_c = g * out_pg + oc;
2942 let dst_start = b * self.out_channels * h_out * w_out + dst_c * h_out * w_out;
2943 let src_start = (b * out_pg + oc) * h_out * w_out;
2944 output[dst_start..dst_start + h_out * w_out]
2945 .copy_from_slice(&g_out[src_start..src_start + h_out * w_out]);
2946 }
2947 }
2948 }
2949
2950 if let Some(ref bias) = self.bias {
2952 let bias_data = bias.data_vec()?;
2953 for b in 0..batch {
2954 for c in 0..self.out_channels {
2955 let bval = bias_data[c];
2956 for hw in 0..(h_out * w_out) {
2957 output[b * self.out_channels * h_out * w_out + c * h_out * w_out + hw] +=
2958 bval;
2959 }
2960 }
2961 }
2962 }
2963
2964 let result = Tensor::from_storage(
2965 TensorStorage::cpu(output),
2966 vec![batch, self.out_channels, h_out, w_out],
2967 false,
2968 )?;
2969
2970 if is_grad_enabled()
2972 && (input.requires_grad()
2973 || self.weight.requires_grad()
2974 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
2975 {
2976 let grad_fn = Arc::new(ConvTranspose2dBackward {
2977 input: input.clone(),
2978 weight: self.weight.tensor().clone(),
2979 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
2980 in_channels: self.in_channels,
2981 out_channels: self.out_channels,
2982 kernel_size: self.kernel_size,
2983 stride: self.stride,
2984 padding: self.padding,
2985 _output_padding: self.output_padding,
2986 dilation: self.dilation,
2987 groups: self.groups,
2988 h_out,
2989 w_out,
2990 });
2991 Tensor::from_operation(
2992 TensorStorage::cpu(result.data()?.to_vec()),
2993 result.shape().to_vec(),
2994 grad_fn,
2995 )?
2996 .to(input_device) } else {
2998 result.to(input_device)
2999 }
3000 }
3001
3002 fn parameters(&self) -> Vec<&Parameter<T>> {
3003 let mut params = vec![&self.weight];
3004 if let Some(ref b) = self.bias {
3005 params.push(b);
3006 }
3007 params
3008 }
3009
3010 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
3011 let mut params = vec![&mut self.weight];
3012 if let Some(ref mut b) = self.bias {
3013 params.push(b);
3014 }
3015 params
3016 }
3017
3018 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
3019 let mut params = vec![("weight".to_string(), &self.weight)];
3020 if let Some(ref b) = self.bias {
3021 params.push(("bias".to_string(), b));
3022 }
3023 params
3024 }
3025
    /// Puts the layer in training mode by setting its `training` flag.
    fn train(&mut self) {
        self.training = true;
    }
3029
    /// Puts the layer in evaluation mode by clearing its `training` flag.
    fn eval(&mut self) {
        self.training = false;
    }
3033
    /// Reports whether the layer is currently in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
3037}
3038
/// Autograd node created by `ConvTranspose2d::forward`.
///
/// It keeps the forward operands and the layer hyper-parameters needed to
/// compute the gradients in `backward`.
#[derive(Debug)]
struct ConvTranspose2dBackward<T: Float> {
    /// Input tensor passed to the forward pass (`[B, C_in, H, W]`).
    input: Tensor<T>,
    /// Weight tensor used in the forward pass.
    weight: Tensor<T>,
    /// Bias tensor used in the forward pass, if the layer has one.
    bias: Option<Tensor<T>>,
    /// Number of input channels of the layer.
    in_channels: usize,
    /// Number of output channels of the layer.
    out_channels: usize,
    /// Kernel size as `(kH, kW)`.
    kernel_size: (usize, usize),
    /// Stride as `(sH, sW)`.
    stride: (usize, usize),
    /// Padding as `(pH, pW)`.
    padding: (usize, usize),
    /// Output padding recorded from the layer; the leading underscore marks
    /// it as intentionally unused.
    _output_padding: (usize, usize),
    /// Kernel dilation as `(dH, dW)`.
    dilation: (usize, usize),
    /// Number of channel groups.
    groups: usize,
    /// Output height produced by the forward pass.
    h_out: usize,
    /// Output width produced by the forward pass.
    w_out: usize,
}
3062
3063impl<T: Float> GradFn<T> for ConvTranspose2dBackward<T> {
3064 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
3065 let go_data = grad_output.data_vec()?;
3067 let batch = self.input.shape()[0];
3068 let h_in = self.input.shape()[2];
3069 let w_in = self.input.shape()[3];
3070 let (kh, kw) = self.kernel_size;
3071 let (sh, sw) = self.stride;
3072 let (ph, pw) = self.padding;
3073 let (dh, dw) = self.dilation;
3074 let groups = self.groups;
3075 let in_pg = self.in_channels / groups;
3076 let out_pg = self.out_channels / groups;
3077 let zero = <T as num_traits::Zero>::zero();
3078
3079 let weight_data_all = self.weight.data_vec()?;
3080 let input_data_all = self.input.data_vec()?;
3081
3082 let mut gi_all = if self.input.requires_grad() {
3088 Some(vec![zero; batch * self.in_channels * h_in * w_in])
3089 } else {
3090 None
3091 };
3092 let mut gw_all = if self.weight.requires_grad() {
3093 Some(vec![zero; self.in_channels * out_pg * kh * kw])
3094 } else {
3095 None
3096 };
3097
3098 for g in 0..groups {
3099 if let Some(gi) = gi_all.as_mut() {
3105 let col_rows = out_pg * kh * kw;
3106 let w_start = g * in_pg * out_pg * kh * kw;
3107 let weight_2d = Tensor::from_storage(
3108 TensorStorage::cpu(
3109 weight_data_all[w_start..w_start + in_pg * out_pg * kh * kw].to_vec(),
3110 ),
3111 vec![in_pg, col_rows],
3112 false,
3113 )?;
3114
3115 let mut go_g = vec![zero; batch * out_pg * self.h_out * self.w_out];
3117 for b in 0..batch {
3118 for c in 0..out_pg {
3119 let src_c = g * out_pg + c;
3120 let src = (b * self.out_channels + src_c) * self.h_out * self.w_out;
3121 let dst = (b * out_pg + c) * self.h_out * self.w_out;
3122 go_g[dst..dst + self.h_out * self.w_out]
3123 .copy_from_slice(&go_data[src..src + self.h_out * self.w_out]);
3124 }
3125 }
3126
3127 let (go_cols, _gcr, go_col_cols) = im2col_dilated(
3128 &go_g, batch, out_pg, self.h_out, self.w_out, kh, kw, sh, sw, ph, pw, dh, dw,
3129 );
3130 debug_assert_eq!(go_col_cols, h_in * w_in);
3131
3132 for b in 0..batch {
3133 let col_start = b * col_rows * go_col_cols;
3134 let col_end = col_start + col_rows * go_col_cols;
3135 let go_cols_b = Tensor::from_storage(
3136 TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
3137 vec![col_rows, go_col_cols],
3138 false,
3139 )?;
3140 let gi_b = mm(&weight_2d, &go_cols_b)?;
3141 let gi_data = gi_b.data()?;
3142 for c in 0..in_pg {
3144 let dst_c = g * in_pg + c;
3145 let dst = (b * self.in_channels + dst_c) * h_in * w_in;
3146 let src = c * h_in * w_in;
3147 gi[dst..dst + h_in * w_in]
3148 .copy_from_slice(&gi_data[src..src + h_in * w_in]);
3149 }
3150 }
3151 }
3152
3153 if let Some(gw) = gw_all.as_mut() {
3157 for ci in 0..in_pg {
3158 let in_c = g * in_pg + ci;
3159 for co in 0..out_pg {
3160 let out_c = g * out_pg + co;
3161 for tkh in 0..kh {
3162 for tkw in 0..kw {
3163 let mut acc = zero;
3164 for ih in 0..h_in {
3165 for iw in 0..w_in {
3166 let oh = ih * sh + tkh * dh;
3167 let ow = iw * sw + tkw * dw;
3168 if oh >= ph
3169 && ow >= pw
3170 && (oh - ph) < self.h_out
3171 && (ow - pw) < self.w_out
3172 {
3173 let go_index = (out_c * self.h_out + (oh - ph))
3174 * self.w_out
3175 + (ow - pw);
3176 let in_index = (in_c * h_in + ih) * w_in + iw;
3177 for b in 0..batch {
3179 let goi =
3180 b * self.out_channels * self.h_out * self.w_out
3181 + go_index;
3182 let ini =
3183 b * self.in_channels * h_in * w_in + in_index;
3184 acc += input_data_all[ini] * go_data[goi];
3185 }
3186 }
3187 }
3188 }
3189 gw[((in_c * out_pg + co) * kh + tkh) * kw + tkw] += acc;
3191 }
3192 }
3193 }
3194 }
3195 }
3196 }
3197
3198 let grad_input = match gi_all {
3199 Some(gi) => Some(Tensor::from_storage(
3200 TensorStorage::cpu(gi),
3201 self.input.shape().to_vec(),
3202 false,
3203 )?),
3204 None => None,
3205 };
3206 let grad_weight = match gw_all {
3207 Some(gw) => Some(Tensor::from_storage(
3208 TensorStorage::cpu(gw),
3209 vec![self.in_channels, out_pg, kh, kw],
3210 false,
3211 )?),
3212 None => None,
3213 };
3214
3215 let grad_bias = match &self.bias {
3217 Some(b) if b.requires_grad() => {
3218 let zero = <T as num_traits::Zero>::zero();
3219 let mut gb = vec![zero; self.out_channels];
3220 for batch_idx in 0..batch {
3221 for c in 0..self.out_channels {
3222 for hw in 0..(self.h_out * self.w_out) {
3223 gb[c] +=
3224 go_data[batch_idx * self.out_channels * self.h_out * self.w_out
3225 + c * self.h_out * self.w_out
3226 + hw];
3227 }
3228 }
3229 }
3230 Some(Tensor::from_storage(
3231 TensorStorage::cpu(gb),
3232 vec![self.out_channels],
3233 false,
3234 )?)
3235 }
3236 _ => None,
3237 };
3238
3239 let mut grads = vec![grad_input, grad_weight];
3240 if self.bias.is_some() {
3241 grads.push(grad_bias);
3242 }
3243 Ok(grads)
3244 }
3245
3246 fn inputs(&self) -> Vec<&Tensor<T>> {
3247 let mut v = vec![&self.input, &self.weight];
3248 if let Some(ref b) = self.bias {
3249 v.push(b);
3250 }
3251 v
3252 }
3253
    /// Returns the fixed name of this backward node.
    fn name(&self) -> &'static str {
        "ConvTranspose2dBackward"
    }
3257}
3258
3259#[allow(clippy::too_many_arguments)]
3274fn im2col_3d_dilated<T: Float>(
3275 input: &[T],
3276 batch: usize,
3277 channels: usize,
3278 depth: usize,
3279 height: usize,
3280 width: usize,
3281 kernel_d: usize,
3282 kernel_h: usize,
3283 kernel_w: usize,
3284 stride_d: usize,
3285 stride_h: usize,
3286 stride_w: usize,
3287 pad_d: usize,
3288 pad_h: usize,
3289 pad_w: usize,
3290 dil_d: usize,
3291 dil_h: usize,
3292 dil_w: usize,
3293) -> (Vec<T>, usize, usize) {
3294 let eff_kd = dil_d * (kernel_d - 1) + 1;
3295 let eff_kh = dil_h * (kernel_h - 1) + 1;
3296 let eff_kw = dil_w * (kernel_w - 1) + 1;
3297 let d_out = (depth + 2 * pad_d - eff_kd) / stride_d + 1;
3298 let h_out = (height + 2 * pad_h - eff_kh) / stride_h + 1;
3299 let w_out = (width + 2 * pad_w - eff_kw) / stride_w + 1;
3300 let col_rows = channels * kernel_d * kernel_h * kernel_w;
3301 let col_cols = d_out * h_out * w_out;
3302
3303 let zero = <T as num_traits::Zero>::zero();
3304 let mut cols = vec![zero; batch * col_rows * col_cols];
3305
3306 for b in 0..batch {
3307 for c in 0..channels {
3308 for kd in 0..kernel_d {
3309 for kh in 0..kernel_h {
3310 for kw in 0..kernel_w {
3311 let row = c * kernel_d * kernel_h * kernel_w
3312 + kd * kernel_h * kernel_w
3313 + kh * kernel_w
3314 + kw;
3315 for od in 0..d_out {
3316 for oh in 0..h_out {
3317 for ow in 0..w_out {
3318 let id = od * stride_d + kd * dil_d;
3319 let ih = oh * stride_h + kh * dil_h;
3320 let iw = ow * stride_w + kw * dil_w;
3321 let col = od * h_out * w_out + oh * w_out + ow;
3322
3323 let val = if id >= pad_d
3324 && ih >= pad_h
3325 && iw >= pad_w
3326 && (id - pad_d) < depth
3327 && (ih - pad_h) < height
3328 && (iw - pad_w) < width
3329 {
3330 let real_d = id - pad_d;
3331 let real_h = ih - pad_h;
3332 let real_w = iw - pad_w;
3333 input[b * channels * depth * height * width
3334 + c * depth * height * width
3335 + real_d * height * width
3336 + real_h * width
3337 + real_w]
3338 } else {
3339 zero
3340 };
3341
3342 cols[b * col_rows * col_cols + row * col_cols + col] = val;
3343 }
3344 }
3345 }
3346 }
3347 }
3348 }
3349 }
3350 }
3351
3352 (cols, col_rows, col_cols)
3353}
3354
3355#[allow(clippy::too_many_arguments)]
3362fn col2im_3d_dilated<T: Float>(
3363 cols: &[T],
3364 batch: usize,
3365 channels: usize,
3366 depth: usize,
3367 height: usize,
3368 width: usize,
3369 kernel_d: usize,
3370 kernel_h: usize,
3371 kernel_w: usize,
3372 stride_d: usize,
3373 stride_h: usize,
3374 stride_w: usize,
3375 pad_d: usize,
3376 pad_h: usize,
3377 pad_w: usize,
3378 dil_d: usize,
3379 dil_h: usize,
3380 dil_w: usize,
3381 d_out: usize,
3382 h_out: usize,
3383 w_out: usize,
3384) -> Vec<T> {
3385 let zero = <T as num_traits::Zero>::zero();
3386 let mut output = vec![zero; batch * channels * depth * height * width];
3387
3388 let col_rows = channels * kernel_d * kernel_h * kernel_w;
3389 let col_cols = d_out * h_out * w_out;
3390
3391 for b in 0..batch {
3392 for c in 0..channels {
3393 for kd in 0..kernel_d {
3394 for kh in 0..kernel_h {
3395 for kw in 0..kernel_w {
3396 let row = c * kernel_d * kernel_h * kernel_w
3397 + kd * kernel_h * kernel_w
3398 + kh * kernel_w
3399 + kw;
3400 for od in 0..d_out {
3401 for oh in 0..h_out {
3402 for ow in 0..w_out {
3403 let id = od * stride_d + kd * dil_d;
3404 let ih = oh * stride_h + kh * dil_h;
3405 let iw = ow * stride_w + kw * dil_w;
3406 let col = od * h_out * w_out + oh * w_out + ow;
3407
3408 if id >= pad_d
3409 && ih >= pad_h
3410 && iw >= pad_w
3411 && (id - pad_d) < depth
3412 && (ih - pad_h) < height
3413 && (iw - pad_w) < width
3414 {
3415 let real_d = id - pad_d;
3416 let real_h = ih - pad_h;
3417 let real_w = iw - pad_w;
3418 output[b * channels * depth * height * width
3419 + c * depth * height * width
3420 + real_d * height * width
3421 + real_h * width
3422 + real_w] +=
3423 cols[b * col_rows * col_cols + row * col_cols + col];
3424 }
3425 }
3426 }
3427 }
3428 }
3429 }
3430 }
3431 }
3432 }
3433
3434 output
3435}
3436
/// 3-D convolution layer operating on `[B, C, D, H, W]` inputs
/// (4-D inputs are treated as unbatched by `forward`).
#[derive(Debug)]
pub struct Conv3d<T: Float> {
    /// Kernel; `new_full` allocates it as
    /// `[out_channels, in_channels / groups, kD, kH, kW]`.
    weight: Parameter<T>,
    /// Optional bias with one value per output channel.
    bias: Option<Parameter<T>>,
    /// Number of input channels expected by `forward`.
    in_channels: usize,
    /// Number of output channels produced by `forward`.
    out_channels: usize,
    /// Kernel extent as `(kD, kH, kW)`.
    kernel_size: (usize, usize, usize),
    /// Stride as `(sD, sH, sW)`.
    stride: (usize, usize, usize),
    /// Symmetric padding as `(pD, pH, pW)`; reset to zero by `with_string_padding`.
    padding: (usize, usize, usize),
    /// Dilation as `(dD, dH, dW)`.
    dilation: (usize, usize, usize),
    /// Number of channel groups; divides both channel counts when built via `new_full`.
    groups: usize,
    /// How padded cells are filled; non-`Zeros` modes are applied by pre-padding in `forward`.
    padding_mode: crate::padding::PaddingMode,
    /// Optional string padding (`Same` / `Valid`) that overrides `padding` in `forward`.
    string_padding: Option<StringPadding>,
    /// Training-mode flag toggled by `train` / `eval`.
    training: bool,
}
3497
3498impl<T: Float> Conv3d<T> {
3499 pub fn new(
3508 in_channels: usize,
3509 out_channels: usize,
3510 kernel_size: (usize, usize, usize),
3511 stride: (usize, usize, usize),
3512 padding: (usize, usize, usize),
3513 bias: bool,
3514 ) -> FerrotorchResult<Self> {
3515 Self::new_full(
3516 in_channels,
3517 out_channels,
3518 kernel_size,
3519 stride,
3520 padding,
3521 (1, 1, 1),
3522 1,
3523 bias,
3524 )
3525 }
3526
3527 #[allow(clippy::too_many_arguments)]
3543 pub fn new_full(
3544 in_channels: usize,
3545 out_channels: usize,
3546 kernel_size: (usize, usize, usize),
3547 stride: (usize, usize, usize),
3548 padding: (usize, usize, usize),
3549 dilation: (usize, usize, usize),
3550 groups: usize,
3551 bias: bool,
3552 ) -> FerrotorchResult<Self> {
3553 if in_channels == 0 || out_channels == 0 {
3554 return Err(FerrotorchError::InvalidArgument {
3555 message: "in_channels and out_channels must be > 0".into(),
3556 });
3557 }
3558 if kernel_size.0 == 0 || kernel_size.1 == 0 || kernel_size.2 == 0 {
3559 return Err(FerrotorchError::InvalidArgument {
3560 message: "kernel_size must be > 0 in all dimensions".into(),
3561 });
3562 }
3563 if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
3564 return Err(FerrotorchError::InvalidArgument {
3565 message: "stride must be > 0 in all dimensions".into(),
3566 });
3567 }
3568 if dilation.0 == 0 || dilation.1 == 0 || dilation.2 == 0 {
3569 return Err(FerrotorchError::InvalidArgument {
3570 message: format!(
3571 "Conv3d::new_full: dilation must be > 0 in all dimensions, got {dilation:?}"
3572 ),
3573 });
3574 }
3575 if groups == 0 {
3576 return Err(FerrotorchError::InvalidArgument {
3577 message: "Conv3d::new_full: groups must be > 0".into(),
3578 });
3579 }
3580 if in_channels % groups != 0 {
3583 return Err(FerrotorchError::InvalidArgument {
3584 message: format!(
3585 "Conv3d::new_full: groups={groups} must divide in_channels={in_channels}"
3586 ),
3587 });
3588 }
3589 if out_channels % groups != 0 {
3590 return Err(FerrotorchError::InvalidArgument {
3591 message: format!(
3592 "Conv3d::new_full: groups={groups} must divide out_channels={out_channels}"
3593 ),
3594 });
3595 }
3596
3597 let (kd, kh, kw) = kernel_size;
3598 let mut weight = Parameter::zeros(&[out_channels, in_channels / groups, kd, kh, kw])?;
3600 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
3601
3602 let bias_param = if bias {
3603 let mut b = Parameter::zeros(&[out_channels])?;
3604 let fan_in = (in_channels / groups) * kd * kh * kw;
3607 let bound = if fan_in > 0 {
3608 1.0 / (fan_in as f64).sqrt()
3609 } else {
3610 0.0
3611 };
3612 uniform_init(&mut b, -bound, bound)?;
3613 Some(b)
3614 } else {
3615 None
3616 };
3617
3618 Ok(Self {
3619 weight,
3620 bias: bias_param,
3621 in_channels,
3622 out_channels,
3623 kernel_size,
3624 stride,
3625 padding,
3626 dilation,
3627 groups,
3628 padding_mode: crate::padding::PaddingMode::Zeros,
3629 string_padding: None,
3630 training: true,
3631 })
3632 }
3633
    /// Returns the number of channel groups.
    pub fn groups(&self) -> usize {
        self.groups
    }
3638
    /// Returns the dilation as `(dD, dH, dW)`.
    pub fn dilation(&self) -> (usize, usize, usize) {
        self.dilation
    }
3643
3644 pub fn with_string_padding(mut self, padding: StringPadding) -> FerrotorchResult<Self> {
3663 if padding == StringPadding::Same
3664 && (self.stride.0 != 1 || self.stride.1 != 1 || self.stride.2 != 1)
3665 {
3666 return Err(FerrotorchError::InvalidArgument {
3667 message: "padding='same' is not supported for strided convolutions".into(),
3668 });
3669 }
3670 self.string_padding = Some(padding);
3671 self.padding = (0, 0, 0);
3672 Ok(self)
3673 }
3674
    /// Sets the padding mode and returns the layer (builder style).
    pub fn with_padding_mode(mut self, mode: crate::padding::PaddingMode) -> Self {
        self.padding_mode = mode;
        self
    }
3688
3689 pub fn num_parameters(&self) -> usize {
3691 let w = self.out_channels
3692 * self.in_channels
3693 * self.kernel_size.0
3694 * self.kernel_size.1
3695 * self.kernel_size.2;
3696 let b = if self.bias.is_some() {
3697 self.out_channels
3698 } else {
3699 0
3700 };
3701 w + b
3702 }
3703
3704 pub fn from_parts(
3712 weight: Tensor<T>,
3713 bias: Option<Tensor<T>>,
3714 stride: (usize, usize, usize),
3715 padding: (usize, usize, usize),
3716 ) -> FerrotorchResult<Self> {
3717 if weight.ndim() != 5 {
3718 return Err(FerrotorchError::ShapeMismatch {
3719 message: format!(
3720 "Conv3d::from_parts: weight must be 5-D [out, in, kD, kH, kW], got {:?}",
3721 weight.shape()
3722 ),
3723 });
3724 }
3725 let out_channels = weight.shape()[0];
3726 let in_channels = weight.shape()[1];
3727 let kernel_size = (weight.shape()[2], weight.shape()[3], weight.shape()[4]);
3728 if let Some(b) = &bias {
3729 if b.ndim() != 1 || b.shape()[0] != out_channels {
3730 return Err(FerrotorchError::ShapeMismatch {
3731 message: format!(
3732 "Conv3d::from_parts: bias shape {:?} != [{}]",
3733 b.shape(),
3734 out_channels
3735 ),
3736 });
3737 }
3738 }
3739 Ok(Self {
3740 weight: Parameter::new(weight),
3741 bias: bias.map(Parameter::new),
3742 in_channels,
3743 out_channels,
3744 kernel_size,
3745 stride,
3746 padding,
3747 dilation: (1, 1, 1),
3748 groups: 1,
3749 padding_mode: crate::padding::PaddingMode::Zeros,
3750 string_padding: None,
3751 training: true,
3752 })
3753 }
3754
3755 fn recurse_clone(
3760 &self,
3761 padding: (usize, usize, usize),
3762 padding_mode: crate::padding::PaddingMode,
3763 ) -> Conv3d<T> {
3764 Conv3d {
3765 weight: Parameter::new(self.weight.tensor().clone()),
3766 bias: self
3767 .bias
3768 .as_ref()
3769 .map(|b| Parameter::new(b.tensor().clone())),
3770 in_channels: self.in_channels,
3771 out_channels: self.out_channels,
3772 kernel_size: self.kernel_size,
3773 stride: self.stride,
3774 padding,
3775 dilation: self.dilation,
3776 groups: self.groups,
3777 padding_mode,
3778 string_padding: None,
3779 training: self.training,
3780 }
3781 }
3782}
3783
3784impl<T: Float> Module<T> for Conv3d<T> {
3785 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3786 let _autocast_cat = autocast_guard("conv3d");
3788
3789 if input.ndim() == 4 {
3797 let batched = unsqueeze(input, 0)?;
3798 let output = self.forward(&batched)?;
3799 return squeeze(&output, 0);
3800 }
3801
3802 if let Some(sp) = self.string_padding {
3814 match sp {
3815 StringPadding::Valid => {
3816 return self
3817 .recurse_clone((0, 0, 0), crate::padding::PaddingMode::Zeros)
3818 .forward(input);
3819 }
3820 StringPadding::Same => {
3821 let (kd, kh, kw) = self.kernel_size;
3822 let (dd, dh, dw) = self.dilation;
3823 let (front, back) = same_pad_lr(kd, dd);
3824 let (top, bottom) = same_pad_lr(kh, dh);
3825 let (left, right) = same_pad_lr(kw, dw);
3826 let padded = crate::padding::functional_pad_3d(
3827 input,
3828 left,
3829 right,
3830 top,
3831 bottom,
3832 front,
3833 back,
3834 self.padding_mode,
3835 <T as num_traits::Zero>::zero(),
3836 )?;
3837 return self
3838 .recurse_clone((0, 0, 0), crate::padding::PaddingMode::Zeros)
3839 .forward(&padded);
3840 }
3841 }
3842 }
3843
3844 if self.padding_mode != crate::padding::PaddingMode::Zeros
3857 && (self.padding.0 != 0 || self.padding.1 != 0 || self.padding.2 != 0)
3858 {
3859 let (pd, ph, pw) = self.padding;
3860 let padded = crate::padding::functional_pad_3d(
3861 input,
3862 pw,
3863 pw,
3864 ph,
3865 ph,
3866 pd,
3867 pd,
3868 self.padding_mode,
3869 <T as num_traits::Zero>::zero(),
3870 )?;
3871 return self
3875 .recurse_clone((0, 0, 0), crate::padding::PaddingMode::Zeros)
3876 .forward(&padded);
3877 }
3878
3879 if input.ndim() != 5 {
3881 return Err(FerrotorchError::InvalidArgument {
3882 message: format!(
3883 "Conv3d expects 5-D input [B, C, D, H, W], got {:?}",
3884 input.shape()
3885 ),
3886 });
3887 }
3888
3889 let batch = input.shape()[0];
3890 let c_in = input.shape()[1];
3891 let d = input.shape()[2];
3892 let h = input.shape()[3];
3893 let w = input.shape()[4];
3894
3895 if c_in != self.in_channels {
3896 return Err(FerrotorchError::ShapeMismatch {
3897 message: format!(
3898 "Conv3d: expected {} input channels, got {}",
3899 self.in_channels, c_in
3900 ),
3901 });
3902 }
3903
3904 let (kd, kh, kw) = self.kernel_size;
3905 let (sd, sh, sw) = self.stride;
3906 let (pd, ph, pw) = self.padding;
3907 let (dd, dh, dw) = self.dilation;
3908 let groups = self.groups;
3909
3910 let eff_kd = dd * (kd - 1) + 1;
3912 let eff_kh = dh * (kh - 1) + 1;
3913 let eff_kw = dw * (kw - 1) + 1;
3914
3915 let d_padded = d + 2 * pd;
3916 let h_padded = h + 2 * ph;
3917 let w_padded = w + 2 * pw;
3918 if d_padded < eff_kd || h_padded < eff_kh || w_padded < eff_kw {
3919 return Err(FerrotorchError::InvalidArgument {
3920 message: format!(
3921 "Conv3d: padded input ({d_padded}, {h_padded}, {w_padded}) is smaller than effective kernel ({eff_kd}, {eff_kh}, {eff_kw})"
3922 ),
3923 });
3924 }
3925
3926 let d_out = (d_padded - eff_kd) / sd + 1;
3927 let h_out = (h_padded - eff_kh) / sh + 1;
3928 let w_out = (w_padded - eff_kw) / sw + 1;
3929
3930 let input_device = input.device();
3932
3933 let input_data = input.data_vec()?;
3939 let weight_data = self.weight.data_vec()?;
3940
3941 let zero = <T as num_traits::Zero>::zero();
3942 let spatial_in = d * h * w;
3943 let spatial_out = d_out * h_out * w_out;
3944 let mut output = vec![zero; batch * self.out_channels * spatial_out];
3945
3946 let in_per_group = self.in_channels / groups;
3948 let out_per_group = self.out_channels / groups;
3949 let group_col_rows = in_per_group * kd * kh * kw;
3950 let weight_per_group_numel = out_per_group * group_col_rows;
3951 let col_cols = spatial_out;
3952
3953 let saved_cols_rows = self.in_channels * kd * kh * kw;
3957 let mut saved_cols: Vec<T> = if is_grad_enabled()
3958 && (input.requires_grad()
3959 || self.weight.requires_grad()
3960 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
3961 {
3962 vec![zero; batch * saved_cols_rows * col_cols]
3963 } else {
3964 Vec::new()
3965 };
3966 let save_cols = !saved_cols.is_empty();
3967 let kvol = kd * kh * kw;
3968
3969 for g in 0..groups {
3970 let mut group_input = vec![zero; batch * in_per_group * spatial_in];
3972 for b in 0..batch {
3973 for c in 0..in_per_group {
3974 let src_c = g * in_per_group + c;
3975 let src_start = b * self.in_channels * spatial_in + src_c * spatial_in;
3976 let dst_start = b * in_per_group * spatial_in + c * spatial_in;
3977 group_input[dst_start..dst_start + spatial_in]
3978 .copy_from_slice(&input_data[src_start..src_start + spatial_in]);
3979 }
3980 }
3981
3982 let (g_cols, g_col_rows, g_col_cols) = im2col_3d_dilated(
3983 &group_input,
3984 batch,
3985 in_per_group,
3986 d,
3987 h,
3988 w,
3989 kd,
3990 kh,
3991 kw,
3992 sd,
3993 sh,
3994 sw,
3995 pd,
3996 ph,
3997 pw,
3998 dd,
3999 dh,
4000 dw,
4001 );
4002 debug_assert_eq!(g_col_rows, group_col_rows);
4003 debug_assert_eq!(g_col_cols, col_cols);
4004
4005 if save_cols {
4007 for b in 0..batch {
4008 for c in 0..in_per_group {
4009 let dst_c = g * in_per_group + c;
4010 for kk in 0..kvol {
4011 let src_row = c * kvol + kk;
4012 let dst_row = dst_c * kvol + kk;
4013 let src_off = b * group_col_rows * col_cols + src_row * col_cols;
4014 let dst_off = b * saved_cols_rows * col_cols + dst_row * col_cols;
4015 saved_cols[dst_off..dst_off + col_cols]
4016 .copy_from_slice(&g_cols[src_off..src_off + col_cols]);
4017 }
4018 }
4019 }
4020 }
4021
4022 let w_group_start = g * weight_per_group_numel;
4025 let w_group_end = w_group_start + weight_per_group_numel;
4026 let weight_group_2d = Tensor::from_storage(
4027 TensorStorage::cpu(weight_data[w_group_start..w_group_end].to_vec()),
4028 vec![out_per_group, group_col_rows],
4029 false,
4030 )?;
4031
4032 for b in 0..batch {
4033 let col_start = b * group_col_rows * col_cols;
4034 let col_end = col_start + group_col_rows * col_cols;
4035 let cols_b = Tensor::from_storage(
4036 TensorStorage::cpu(g_cols[col_start..col_end].to_vec()),
4037 vec![group_col_rows, col_cols],
4038 false,
4039 )?;
4040
4041 let out_b = mm(&weight_group_2d, &cols_b)?;
4042 let out_data = out_b.data()?;
4043 for oc in 0..out_per_group {
4044 let dst_c = g * out_per_group + oc;
4045 let dst_start = b * self.out_channels * spatial_out + dst_c * spatial_out;
4046 let src_start = oc * spatial_out;
4047 output[dst_start..dst_start + spatial_out]
4048 .copy_from_slice(&out_data[src_start..src_start + spatial_out]);
4049 }
4050 }
4051 }
4052
4053 if let Some(ref bias) = self.bias {
4055 let bias_data = bias.data_vec()?;
4056 for b in 0..batch {
4057 for c in 0..self.out_channels {
4058 let bval = bias_data[c];
4059 for s in 0..spatial_out {
4060 output[b * self.out_channels * spatial_out + c * spatial_out + s] += bval;
4061 }
4062 }
4063 }
4064 }
4065
4066 let result = Tensor::from_storage(
4067 TensorStorage::cpu(output),
4068 vec![batch, self.out_channels, d_out, h_out, w_out],
4069 false,
4070 )?;
4071
4072 if save_cols {
4074 let grad_fn = Arc::new(Conv3dBackward {
4075 input: input.clone(),
4076 weight: self.weight.tensor().clone(),
4077 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
4078 in_channels: self.in_channels,
4079 out_channels: self.out_channels,
4080 kernel_size: self.kernel_size,
4081 stride: self.stride,
4082 padding: self.padding,
4083 dilation: self.dilation,
4084 groups: self.groups,
4085 cols: saved_cols,
4086 col_rows: saved_cols_rows,
4087 col_cols,
4088 d_out,
4089 h_out,
4090 w_out,
4091 });
4092 Tensor::from_operation(
4093 TensorStorage::cpu(result.data()?.to_vec()),
4094 result.shape().to_vec(),
4095 grad_fn,
4096 )?
4097 .to(input_device) } else {
4099 result.to(input_device)
4100 }
4101 }
4102
4103 fn parameters(&self) -> Vec<&Parameter<T>> {
4104 let mut params = vec![&self.weight];
4105 if let Some(ref b) = self.bias {
4106 params.push(b);
4107 }
4108 params
4109 }
4110
4111 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4112 let mut params = vec![&mut self.weight];
4113 if let Some(ref mut b) = self.bias {
4114 params.push(b);
4115 }
4116 params
4117 }
4118
4119 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4120 let mut params = vec![("weight".to_string(), &self.weight)];
4121 if let Some(ref b) = self.bias {
4122 params.push(("bias".to_string(), b));
4123 }
4124 params
4125 }
4126
    /// Puts the layer into training mode by setting its `training` flag.
    fn train(&mut self) {
        self.training = true;
    }
4130
    /// Puts the layer into evaluation mode by clearing its `training` flag.
    fn eval(&mut self) {
        self.training = false;
    }
4134
    /// Returns the current value of the `training` flag.
    fn is_training(&self) -> bool {
        self.training
    }
4138}
4139
/// Backward node for `Conv3d`.
///
/// Stores the forward operands, the hyper-parameters and the im2col buffer
/// that `backward` reuses to build the weight gradient.
#[derive(Debug)]
struct Conv3dBackward<T: Float> {
    /// Forward input; `backward` indexes it as `[batch, in_channels, d, h, w]`.
    input: Tensor<T>,
    /// Forward weight; `backward` indexes it as
    /// `[out_channels, in_channels / groups, kd, kh, kw]`.
    weight: Tensor<T>,
    /// Optional bias; when present it is reported as a third graph input.
    bias: Option<Tensor<T>>,
    /// Total number of input channels (across all groups).
    in_channels: usize,
    /// Total number of output channels (across all groups).
    out_channels: usize,
    /// Kernel extent as `(kd, kh, kw)`.
    kernel_size: (usize, usize, usize),
    /// Stride as `(sd, sh, sw)`.
    stride: (usize, usize, usize),
    /// Padding as `(pd, ph, pw)`.
    padding: (usize, usize, usize),
    /// Dilation as `(dd, dh, dw)`.
    dilation: (usize, usize, usize),
    /// Number of channel groups.
    groups: usize,
    /// im2col of the forward input, laid out as `[batch, col_rows, col_cols]`.
    cols: Vec<T>,
    /// Rows per batch item in `cols` (`in_channels * kd * kh * kw` in `Conv3d::forward`).
    col_rows: usize,
    /// Columns per row in `cols` (`d_out * h_out * w_out` in `Conv3d::forward`).
    col_cols: usize,
    /// Depth of the forward output.
    d_out: usize,
    /// Height of the forward output.
    h_out: usize,
    /// Width of the forward output.
    w_out: usize,
}
4175
4176impl<T: Float> GradFn<T> for Conv3dBackward<T> {
    /// Computes the gradients of a 3-D convolution from `grad_output`.
    ///
    /// Returns `[grad_input, grad_weight]`, followed by `grad_bias` when the
    /// layer has a bias. An entry is `None` when the corresponding tensor does
    /// not require gradients. Each gradient is moved to the device of the
    /// tensor it belongs to.
    ///
    /// # Errors
    ///
    /// Propagates failures from reading tensor data, constructing tensors,
    /// `transpose`, `mm`, or the device transfer.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let input_device = self.input.device();
        let weight_device = self.weight.device();
        let bias_device = self.bias.as_ref().map(|b| b.device());
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let d = self.input.shape()[2];
        let h = self.input.shape()[3];
        let w = self.input.shape()[4];
        let (kd, kh, kw) = self.kernel_size;
        let (sd, sh, sw) = self.stride;
        let (pd, ph, pw) = self.padding;
        let (dd, dh, dw) = self.dilation;
        let groups = self.groups;
        let in_per_group = self.in_channels / groups;
        let out_per_group = self.out_channels / groups;
        let kvol = kd * kh * kw;
        let group_col_rows = in_per_group * kvol;
        let spatial_in = d * h * w;
        let spatial_out = self.d_out * self.h_out * self.w_out;
        let zero = <T as num_traits::Zero>::zero();

        // grad_weight[g] = sum over batch of grad_output[b, g] x cols[b, g]^T.
        let grad_weight = if self.weight.requires_grad() {
            let weight_numel = self.out_channels * group_col_rows;
            let mut gw_accum = vec![zero; weight_numel];
            let weight_per_group_numel = out_per_group * group_col_rows;

            for g in 0..groups {
                for b in 0..batch {
                    // Gather this group's output channels for batch item `b`.
                    let mut go_g = vec![zero; out_per_group * spatial_out];
                    for oc in 0..out_per_group {
                        let src_c = g * out_per_group + oc;
                        let src_start = b * self.out_channels * spatial_out + src_c * spatial_out;
                        let dst_start = oc * spatial_out;
                        go_g[dst_start..dst_start + spatial_out]
                            .copy_from_slice(&go_data[src_start..src_start + spatial_out]);
                    }
                    let go_b_g = Tensor::from_storage(
                        TensorStorage::cpu(go_g),
                        vec![out_per_group, spatial_out],
                        false,
                    )?;

                    // Pull this group's rows out of the saved im2col buffer.
                    let mut cols_g = vec![zero; group_col_rows * self.col_cols];
                    for c in 0..in_per_group {
                        let src_c = g * in_per_group + c;
                        for kk in 0..kvol {
                            let src_row = src_c * kvol + kk;
                            let dst_row = c * kvol + kk;
                            let src_off =
                                b * self.col_rows * self.col_cols + src_row * self.col_cols;
                            let dst_off = dst_row * self.col_cols;
                            cols_g[dst_off..dst_off + self.col_cols]
                                .copy_from_slice(&self.cols[src_off..src_off + self.col_cols]);
                        }
                    }
                    let cols_b_g = Tensor::from_storage(
                        TensorStorage::cpu(cols_g),
                        vec![group_col_rows, self.col_cols],
                        false,
                    )?;

                    let cols_bt = transpose(&cols_b_g)?;
                    let gw_b = mm(&go_b_g, &cols_bt)?;
                    let gw_data = gw_b.data()?;

                    // Accumulate into this group's slice of the weight gradient.
                    let dst_off = g * weight_per_group_numel;
                    for i in 0..weight_per_group_numel {
                        gw_accum[dst_off + i] += gw_data[i];
                    }
                }
            }

            Some(
                Tensor::from_storage(
                    TensorStorage::cpu(gw_accum),
                    vec![self.out_channels, in_per_group, kd, kh, kw],
                    false,
                )?
                .to(weight_device)?,
            )
        } else {
            None
        };

        // grad_bias[c] = sum of grad_output over batch and spatial positions.
        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut gb = vec![zero; self.out_channels];
                for batch_idx in 0..batch {
                    for c in 0..self.out_channels {
                        for s in 0..spatial_out {
                            gb[c] += go_data
                                [batch_idx * self.out_channels * spatial_out + c * spatial_out + s];
                        }
                    }
                }
                let target_dev = bias_device.unwrap_or(input_device);
                Some(
                    Tensor::from_storage(TensorStorage::cpu(gb), vec![self.out_channels], false)?
                        .to(target_dev)?,
                )
            }
            _ => None,
        };

        // grad_input = col2im(weight[g]^T x grad_output[b, g]) per group.
        let grad_input = if self.input.requires_grad() {
            let weight_data = self.weight.data_vec()?;
            let mut grad_input_data = vec![zero; batch * self.in_channels * spatial_in];
            let weight_per_group_numel = out_per_group * group_col_rows;

            for g in 0..groups {
                let w_off = g * weight_per_group_numel;
                let weight_g_2d = Tensor::from_storage(
                    TensorStorage::cpu(weight_data[w_off..w_off + weight_per_group_numel].to_vec()),
                    vec![out_per_group, group_col_rows],
                    false,
                )?;
                let weight_g_t = transpose(&weight_g_2d)?;

                let mut grad_cols_g = vec![zero; batch * group_col_rows * self.col_cols];
                for b in 0..batch {
                    let mut go_g = vec![zero; out_per_group * spatial_out];
                    for oc in 0..out_per_group {
                        let src_c = g * out_per_group + oc;
                        let src_start = b * self.out_channels * spatial_out + src_c * spatial_out;
                        let dst_start = oc * spatial_out;
                        go_g[dst_start..dst_start + spatial_out]
                            .copy_from_slice(&go_data[src_start..src_start + spatial_out]);
                    }
                    let go_b_g = Tensor::from_storage(
                        TensorStorage::cpu(go_g),
                        vec![out_per_group, spatial_out],
                        false,
                    )?;

                    let gc_b = mm(&weight_g_t, &go_b_g)?;
                    let gc_data = gc_b.data()?;
                    let gc_start = b * group_col_rows * self.col_cols;
                    grad_cols_g[gc_start..gc_start + group_col_rows * self.col_cols]
                        .copy_from_slice(gc_data);
                }

                // Fold the column gradients back onto the group's input volume.
                let gi_g = col2im_3d_dilated(
                    &grad_cols_g,
                    batch,
                    in_per_group,
                    d,
                    h,
                    w,
                    kd,
                    kh,
                    kw,
                    sd,
                    sh,
                    sw,
                    pd,
                    ph,
                    pw,
                    dd,
                    dh,
                    dw,
                    self.d_out,
                    self.h_out,
                    self.w_out,
                );

                // Scatter the group's channels into the full input gradient.
                for b in 0..batch {
                    for c in 0..in_per_group {
                        let dst_c = g * in_per_group + c;
                        let dst_start = b * self.in_channels * spatial_in + dst_c * spatial_in;
                        let src_start = b * in_per_group * spatial_in + c * spatial_in;
                        grad_input_data[dst_start..dst_start + spatial_in]
                            .copy_from_slice(&gi_g[src_start..src_start + spatial_in]);
                    }
                }
            }

            Some(
                Tensor::from_storage(
                    TensorStorage::cpu(grad_input_data),
                    self.input.shape().to_vec(),
                    false,
                )?
                .to(input_device)?,
            )
        } else {
            None
        };

        // The gradient list lines up with `inputs()`: a bias slot only when
        // the layer has a bias.
        let mut grads = vec![grad_input, grad_weight];
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }
4390
4391 fn inputs(&self) -> Vec<&Tensor<T>> {
4392 let mut v = vec![&self.input, &self.weight];
4393 if let Some(ref b) = self.bias {
4394 v.push(b);
4395 }
4396 v
4397 }
4398
    /// Returns the fixed name of this backward node.
    fn name(&self) -> &'static str {
        "Conv3dBackward"
    }
4402}
4403
/// 1-D transposed convolution layer.
#[derive(Debug)]
pub struct ConvTranspose1d<T: Float> {
    /// Kernel; `new_full` allocates it as
    /// `[in_channels, out_channels / groups, kernel_size]`.
    weight: Parameter<T>,
    /// Optional bias with one value per output channel.
    bias: Option<Parameter<T>>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Kernel length.
    kernel_size: usize,
    /// Stride; validated to be non-zero by `new_full`.
    stride: usize,
    /// Padding amount supplied at construction.
    padding: usize,
    /// Output padding; `new_full` requires it to be below `max(stride, dilation)`.
    output_padding: usize,
    /// Dilation; validated to be non-zero by `new_full`.
    dilation: usize,
    /// Number of channel groups; divides both channel counts when built via `new_full`.
    groups: usize,
    /// Training-mode flag; constructors initialise it to `true`.
    training: bool,
}
4457
4458impl<T: Float> ConvTranspose1d<T> {
4459 pub fn new(
4467 in_channels: usize,
4468 out_channels: usize,
4469 kernel_size: usize,
4470 stride: usize,
4471 padding: usize,
4472 output_padding: usize,
4473 bias: bool,
4474 ) -> FerrotorchResult<Self> {
4475 Self::new_full(
4476 in_channels,
4477 out_channels,
4478 kernel_size,
4479 stride,
4480 padding,
4481 output_padding,
4482 1,
4483 1,
4484 bias,
4485 )
4486 }
4487
4488 #[allow(clippy::too_many_arguments)]
4496 pub fn new_full(
4497 in_channels: usize,
4498 out_channels: usize,
4499 kernel_size: usize,
4500 stride: usize,
4501 padding: usize,
4502 output_padding: usize,
4503 dilation: usize,
4504 groups: usize,
4505 bias: bool,
4506 ) -> FerrotorchResult<Self> {
4507 if in_channels == 0 || out_channels == 0 {
4508 return Err(FerrotorchError::InvalidArgument {
4509 message: "in_channels and out_channels must be > 0".into(),
4510 });
4511 }
4512 if kernel_size == 0 {
4513 return Err(FerrotorchError::InvalidArgument {
4514 message: "kernel_size must be > 0".into(),
4515 });
4516 }
4517 if stride == 0 {
4518 return Err(FerrotorchError::InvalidArgument {
4519 message: "stride must be > 0".into(),
4520 });
4521 }
4522 if dilation == 0 {
4523 return Err(FerrotorchError::InvalidArgument {
4524 message: "dilation must be > 0".into(),
4525 });
4526 }
4527 if groups == 0 {
4528 return Err(FerrotorchError::InvalidArgument {
4529 message: "groups must be a positive integer".into(),
4530 });
4531 }
4532 if in_channels % groups != 0 {
4533 return Err(FerrotorchError::InvalidArgument {
4534 message: format!(
4535 "in_channels ({in_channels}) must be divisible by groups ({groups})"
4536 ),
4537 });
4538 }
4539 if out_channels % groups != 0 {
4540 return Err(FerrotorchError::InvalidArgument {
4541 message: format!(
4542 "out_channels ({out_channels}) must be divisible by groups ({groups})"
4543 ),
4544 });
4545 }
4546 if output_padding >= stride.max(dilation) {
4547 return Err(FerrotorchError::InvalidArgument {
4548 message: "output_padding must be strictly less than max(stride, dilation)".into(),
4549 });
4550 }
4551
4552 let out_per_group = out_channels / groups;
4554 let mut weight = Parameter::zeros(&[in_channels, out_per_group, kernel_size])?;
4555 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
4556
4557 let bias_param = if bias {
4558 let mut b = Parameter::zeros(&[out_channels])?;
4559 let fan_in = out_per_group * kernel_size;
4563 let bound = if fan_in > 0 {
4564 1.0 / (fan_in as f64).sqrt()
4565 } else {
4566 0.0
4567 };
4568 uniform_init(&mut b, -bound, bound)?;
4569 Some(b)
4570 } else {
4571 None
4572 };
4573
4574 Ok(Self {
4575 weight,
4576 bias: bias_param,
4577 in_channels,
4578 out_channels,
4579 kernel_size,
4580 stride,
4581 padding,
4582 output_padding,
4583 dilation,
4584 groups,
4585 training: true,
4586 })
4587 }
4588
4589 pub fn with_padding_mode(self, mode: crate::padding::PaddingMode) -> FerrotorchResult<Self> {
4600 reject_non_zeros_transpose(mode, "ConvTranspose1d")?;
4601 Ok(self)
4602 }
4603
4604 pub fn num_parameters(&self) -> usize {
4606 let w = self.in_channels * self.out_channels * self.kernel_size;
4607 let b = if self.bias.is_some() {
4608 self.out_channels
4609 } else {
4610 0
4611 };
4612 w + b
4613 }
4614
4615 pub fn from_parts(
4621 weight: Tensor<T>,
4622 bias: Option<Tensor<T>>,
4623 stride: usize,
4624 padding: usize,
4625 output_padding: usize,
4626 ) -> FerrotorchResult<Self> {
4627 if weight.ndim() != 3 {
4628 return Err(FerrotorchError::ShapeMismatch {
4629 message: format!(
4630 "ConvTranspose1d::from_parts: weight must be 3-D [in, out, k], got {:?}",
4631 weight.shape()
4632 ),
4633 });
4634 }
4635 let in_channels = weight.shape()[0];
4636 let out_channels = weight.shape()[1];
4637 let kernel_size = weight.shape()[2];
4638 if output_padding >= stride {
4639 return Err(FerrotorchError::InvalidArgument {
4640 message: "output_padding must be strictly less than stride".into(),
4641 });
4642 }
4643 if let Some(b) = &bias {
4644 if b.ndim() != 1 || b.shape()[0] != out_channels {
4645 return Err(FerrotorchError::ShapeMismatch {
4646 message: format!(
4647 "ConvTranspose1d::from_parts: bias shape {:?} != [{}]",
4648 b.shape(),
4649 out_channels
4650 ),
4651 });
4652 }
4653 }
4654 Ok(Self {
4655 weight: Parameter::new(weight),
4656 bias: bias.map(Parameter::new),
4657 in_channels,
4658 out_channels,
4659 kernel_size,
4660 stride,
4661 padding,
4662 output_padding,
4663 dilation: 1,
4667 groups: 1,
4668 training: true,
4669 })
4670 }
4671}
4672
4673impl<T: Float> Module<T> for ConvTranspose1d<T> {
4674 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
4675 let _autocast_cat = autocast_guard("conv_transpose1d");
4677
4678 if input.ndim() == 2 {
4684 let batched = unsqueeze(input, 0)?;
4685 let output = self.forward(&batched)?;
4686 return squeeze(&output, 0);
4687 }
4688
4689 if input.ndim() != 3 {
4691 return Err(FerrotorchError::InvalidArgument {
4692 message: format!(
4693 "ConvTranspose1d expects 3-D input [B, C, L], got {:?}",
4694 input.shape()
4695 ),
4696 });
4697 }
4698
4699 let batch = input.shape()[0];
4700 let c_in = input.shape()[1];
4701 let length = input.shape()[2];
4702
4703 if c_in != self.in_channels {
4704 return Err(FerrotorchError::ShapeMismatch {
4705 message: format!(
4706 "ConvTranspose1d: expected {} input channels, got {}",
4707 self.in_channels, c_in
4708 ),
4709 });
4710 }
4711
4712 let k = self.kernel_size;
4713 let groups = self.groups;
4714 let in_pg = self.in_channels / groups;
4715 let out_pg = self.out_channels / groups;
4716 let weight_pg_numel = in_pg * out_pg * k;
4717
4718 let input_device = input.device();
4720
4721 let input_data = input.data_vec()?;
4722 let weight_data = self.weight.data_vec()?;
4723
4724 let zero = <T as num_traits::Zero>::zero();
4730 let mut output: Vec<T> = Vec::new();
4731 let mut l_out = 0usize;
4732
4733 for g in 0..groups {
4734 let mut group_input = vec![zero; batch * in_pg * length];
4735 for b in 0..batch {
4736 for c in 0..in_pg {
4737 let src_c = g * in_pg + c;
4738 let src = (b * self.in_channels + src_c) * length;
4739 let dst = (b * in_pg + c) * length;
4740 group_input[dst..dst + length].copy_from_slice(&input_data[src..src + length]);
4741 }
4742 }
4743
4744 let w_start = g * weight_pg_numel;
4745 let group_weight = &weight_data[w_start..w_start + weight_pg_numel];
4746
4747 let (g_out, gho, glo) = conv_transpose2d_forward_group(
4748 &group_input,
4749 batch,
4750 in_pg,
4751 out_pg,
4752 1,
4753 length,
4754 (1, k),
4755 (1, self.stride),
4756 (0, self.padding),
4757 (0, self.output_padding),
4758 (1, self.dilation),
4759 group_weight,
4760 )?;
4761 debug_assert_eq!(gho, 1);
4762 l_out = glo;
4763
4764 if output.is_empty() {
4765 output = vec![zero; batch * self.out_channels * l_out];
4766 }
4767 for b in 0..batch {
4768 for oc in 0..out_pg {
4769 let dst_c = g * out_pg + oc;
4770 let dst = (b * self.out_channels + dst_c) * l_out;
4771 let src = (b * out_pg + oc) * l_out;
4772 output[dst..dst + l_out].copy_from_slice(&g_out[src..src + l_out]);
4773 }
4774 }
4775 }
4776
4777 if let Some(ref bias) = self.bias {
4779 let bias_data = bias.data_vec()?;
4780 for b in 0..batch {
4781 for c in 0..self.out_channels {
4782 let bval = bias_data[c];
4783 for l in 0..l_out {
4784 output[b * self.out_channels * l_out + c * l_out + l] += bval;
4785 }
4786 }
4787 }
4788 }
4789
4790 let result = Tensor::from_storage(
4791 TensorStorage::cpu(output),
4792 vec![batch, self.out_channels, l_out],
4793 false,
4794 )?;
4795
4796 if is_grad_enabled()
4798 && (input.requires_grad()
4799 || self.weight.requires_grad()
4800 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
4801 {
4802 let grad_fn = Arc::new(ConvTranspose1dBackward {
4803 input: input.clone(),
4804 weight: self.weight.tensor().clone(),
4805 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
4806 in_channels: self.in_channels,
4807 out_channels: self.out_channels,
4808 kernel_size: self.kernel_size,
4809 stride: self.stride,
4810 padding: self.padding,
4811 _output_padding: self.output_padding,
4812 dilation: self.dilation,
4813 groups: self.groups,
4814 l_out,
4815 });
4816 Tensor::from_operation(
4817 TensorStorage::cpu(result.data()?.to_vec()),
4818 result.shape().to_vec(),
4819 grad_fn,
4820 )?
4821 .to(input_device) } else {
4823 result.to(input_device)
4824 }
4825 }
4826
4827 fn parameters(&self) -> Vec<&Parameter<T>> {
4828 let mut params = vec![&self.weight];
4829 if let Some(ref b) = self.bias {
4830 params.push(b);
4831 }
4832 params
4833 }
4834
4835 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4836 let mut params = vec![&mut self.weight];
4837 if let Some(ref mut b) = self.bias {
4838 params.push(b);
4839 }
4840 params
4841 }
4842
4843 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4844 let mut params = vec![("weight".to_string(), &self.weight)];
4845 if let Some(ref b) = self.bias {
4846 params.push(("bias".to_string(), b));
4847 }
4848 params
4849 }
4850
4851 fn train(&mut self) {
4852 self.training = true;
4853 }
4854
4855 fn eval(&mut self) {
4856 self.training = false;
4857 }
4858
4859 fn is_training(&self) -> bool {
4860 self.training
4861 }
4862}
4863
4864#[derive(Debug)]
4872struct ConvTranspose1dBackward<T: Float> {
4873 input: Tensor<T>,
4874 weight: Tensor<T>,
4875 bias: Option<Tensor<T>>,
4876 in_channels: usize,
4877 out_channels: usize,
4878 kernel_size: usize,
4879 stride: usize,
4880 padding: usize,
4881 _output_padding: usize,
4882 dilation: usize,
4883 groups: usize,
4884 l_out: usize,
4885}
4886
4887impl<T: Float> GradFn<T> for ConvTranspose1dBackward<T> {
4888 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
4889 let go_data = grad_output.data_vec()?;
4891 let batch = self.input.shape()[0];
4892 let l_in = self.input.shape()[2];
4893 let k = self.kernel_size;
4894 let s = self.stride;
4895 let p = self.padding;
4896 let d = self.dilation;
4897 let groups = self.groups;
4898 let in_pg = self.in_channels / groups;
4899 let out_pg = self.out_channels / groups;
4900 let zero = <T as num_traits::Zero>::zero();
4901
4902 let weight_data_all = self.weight.data_vec()?;
4903 let input_data_all = self.input.data_vec()?;
4904
4905 let mut gi_all = if self.input.requires_grad() {
4907 Some(vec![zero; batch * self.in_channels * l_in])
4908 } else {
4909 None
4910 };
4911 let mut gw_all = if self.weight.requires_grad() {
4912 Some(vec![zero; self.in_channels * out_pg * k])
4913 } else {
4914 None
4915 };
4916
4917 for g in 0..groups {
4918 if let Some(gi) = gi_all.as_mut() {
4920 let col_rows = out_pg * k;
4921 let w_start = g * in_pg * out_pg * k;
4922 let weight_2d = Tensor::from_storage(
4923 TensorStorage::cpu(
4924 weight_data_all[w_start..w_start + in_pg * out_pg * k].to_vec(),
4925 ),
4926 vec![in_pg, col_rows],
4927 false,
4928 )?;
4929
4930 let mut go_g = vec![zero; batch * out_pg * self.l_out];
4931 for b in 0..batch {
4932 for c in 0..out_pg {
4933 let src_c = g * out_pg + c;
4934 let src = (b * self.out_channels + src_c) * self.l_out;
4935 let dst = (b * out_pg + c) * self.l_out;
4936 go_g[dst..dst + self.l_out]
4937 .copy_from_slice(&go_data[src..src + self.l_out]);
4938 }
4939 }
4940
4941 let (go_cols, _gcr, go_col_cols) =
4944 im2col_dilated(&go_g, batch, out_pg, 1, self.l_out, 1, k, 1, s, 0, p, 1, d);
4945 debug_assert_eq!(go_col_cols, l_in);
4946
4947 for b in 0..batch {
4948 let col_start = b * col_rows * go_col_cols;
4949 let col_end = col_start + col_rows * go_col_cols;
4950 let go_cols_b = Tensor::from_storage(
4951 TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
4952 vec![col_rows, go_col_cols],
4953 false,
4954 )?;
4955 let gi_b = mm(&weight_2d, &go_cols_b)?;
4956 let gi_data = gi_b.data()?;
4957 for c in 0..in_pg {
4958 let dst_c = g * in_pg + c;
4959 let dst = (b * self.in_channels + dst_c) * l_in;
4960 let src = c * l_in;
4961 gi[dst..dst + l_in].copy_from_slice(&gi_data[src..src + l_in]);
4962 }
4963 }
4964 }
4965
4966 if let Some(gw) = gw_all.as_mut() {
4968 for ci in 0..in_pg {
4969 let in_c = g * in_pg + ci;
4970 for co in 0..out_pg {
4971 let out_c = g * out_pg + co;
4972 for tk in 0..k {
4973 let mut acc = zero;
4974 for il in 0..l_in {
4975 let ow = il * s + tk * d;
4976 if ow >= p && (ow - p) < self.l_out {
4977 let go_index = out_c * self.l_out + (ow - p);
4978 let in_index = in_c * l_in + il;
4979 for b in 0..batch {
4980 let goi = b * self.out_channels * self.l_out + go_index;
4981 let ini = b * self.in_channels * l_in + in_index;
4982 acc += input_data_all[ini] * go_data[goi];
4983 }
4984 }
4985 }
4986 gw[(in_c * out_pg + co) * k + tk] += acc;
4988 }
4989 }
4990 }
4991 }
4992 }
4993
4994 let grad_input = match gi_all {
4995 Some(gi) => Some(Tensor::from_storage(
4996 TensorStorage::cpu(gi),
4997 self.input.shape().to_vec(),
4998 false,
4999 )?),
5000 None => None,
5001 };
5002 let grad_weight = match gw_all {
5003 Some(gw) => Some(Tensor::from_storage(
5004 TensorStorage::cpu(gw),
5005 vec![self.in_channels, out_pg, k],
5006 false,
5007 )?),
5008 None => None,
5009 };
5010
5011 let grad_bias = match &self.bias {
5013 Some(b) if b.requires_grad() => {
5014 let zero = <T as num_traits::Zero>::zero();
5015 let mut gb = vec![zero; self.out_channels];
5016 for batch_idx in 0..batch {
5017 for c in 0..self.out_channels {
5018 for l in 0..self.l_out {
5019 gb[c] += go_data
5020 [batch_idx * self.out_channels * self.l_out + c * self.l_out + l];
5021 }
5022 }
5023 }
5024 Some(Tensor::from_storage(
5025 TensorStorage::cpu(gb),
5026 vec![self.out_channels],
5027 false,
5028 )?)
5029 }
5030 _ => None,
5031 };
5032
5033 let mut grads = vec![grad_input, grad_weight];
5034 if self.bias.is_some() {
5035 grads.push(grad_bias);
5036 }
5037 Ok(grads)
5038 }
5039
5040 fn inputs(&self) -> Vec<&Tensor<T>> {
5041 let mut v = vec![&self.input, &self.weight];
5042 if let Some(ref b) = self.bias {
5043 v.push(b);
5044 }
5045 v
5046 }
5047
5048 fn name(&self) -> &'static str {
5049 "ConvTranspose1dBackward"
5050 }
5051}
5052
5053#[derive(Debug)]
5077pub struct ConvTranspose3d<T: Float> {
5078 weight: Parameter<T>,
5083 bias: Option<Parameter<T>>,
5085 in_channels: usize,
5087 out_channels: usize,
5089 kernel_size: (usize, usize, usize),
5091 stride: (usize, usize, usize),
5093 padding: (usize, usize, usize),
5095 output_padding: (usize, usize, usize),
5097 dilation: (usize, usize, usize),
5100 groups: usize,
5105 training: bool,
5107}
5108
5109impl<T: Float> ConvTranspose3d<T> {
5110 pub fn new(
5119 in_channels: usize,
5120 out_channels: usize,
5121 kernel_size: (usize, usize, usize),
5122 stride: (usize, usize, usize),
5123 padding: (usize, usize, usize),
5124 output_padding: (usize, usize, usize),
5125 bias: bool,
5126 ) -> FerrotorchResult<Self> {
5127 Self::new_full(
5128 in_channels,
5129 out_channels,
5130 kernel_size,
5131 stride,
5132 padding,
5133 output_padding,
5134 (1, 1, 1),
5135 1,
5136 bias,
5137 )
5138 }
5139
5140 #[allow(clippy::too_many_arguments)]
5148 pub fn new_full(
5149 in_channels: usize,
5150 out_channels: usize,
5151 kernel_size: (usize, usize, usize),
5152 stride: (usize, usize, usize),
5153 padding: (usize, usize, usize),
5154 output_padding: (usize, usize, usize),
5155 dilation: (usize, usize, usize),
5156 groups: usize,
5157 bias: bool,
5158 ) -> FerrotorchResult<Self> {
5159 if in_channels == 0 || out_channels == 0 {
5160 return Err(FerrotorchError::InvalidArgument {
5161 message: "in_channels and out_channels must be > 0".into(),
5162 });
5163 }
5164 if kernel_size.0 == 0 || kernel_size.1 == 0 || kernel_size.2 == 0 {
5165 return Err(FerrotorchError::InvalidArgument {
5166 message: "kernel_size must be > 0 in all dimensions".into(),
5167 });
5168 }
5169 if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
5170 return Err(FerrotorchError::InvalidArgument {
5171 message: "stride must be > 0 in all dimensions".into(),
5172 });
5173 }
5174 if dilation.0 == 0 || dilation.1 == 0 || dilation.2 == 0 {
5175 return Err(FerrotorchError::InvalidArgument {
5176 message: "dilation must be > 0 in all dimensions".into(),
5177 });
5178 }
5179 if groups == 0 {
5180 return Err(FerrotorchError::InvalidArgument {
5181 message: "groups must be a positive integer".into(),
5182 });
5183 }
5184 if in_channels % groups != 0 {
5185 return Err(FerrotorchError::InvalidArgument {
5186 message: format!(
5187 "in_channels ({in_channels}) must be divisible by groups ({groups})"
5188 ),
5189 });
5190 }
5191 if out_channels % groups != 0 {
5192 return Err(FerrotorchError::InvalidArgument {
5193 message: format!(
5194 "out_channels ({out_channels}) must be divisible by groups ({groups})"
5195 ),
5196 });
5197 }
5198 if output_padding.0 >= stride.0.max(dilation.0)
5199 || output_padding.1 >= stride.1.max(dilation.1)
5200 || output_padding.2 >= stride.2.max(dilation.2)
5201 {
5202 return Err(FerrotorchError::InvalidArgument {
5203 message:
5204 "output_padding must be strictly less than max(stride, dilation) in all dimensions"
5205 .into(),
5206 });
5207 }
5208
5209 let (kd, kh, kw) = kernel_size;
5211 let out_per_group = out_channels / groups;
5212 let mut weight = Parameter::zeros(&[in_channels, out_per_group, kd, kh, kw])?;
5213 kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
5214
5215 let bias_param = if bias {
5216 let mut b = Parameter::zeros(&[out_channels])?;
5217 let fan_in = out_per_group * kd * kh * kw;
5220 let bound = if fan_in > 0 {
5221 1.0 / (fan_in as f64).sqrt()
5222 } else {
5223 0.0
5224 };
5225 uniform_init(&mut b, -bound, bound)?;
5226 Some(b)
5227 } else {
5228 None
5229 };
5230
5231 Ok(Self {
5232 weight,
5233 bias: bias_param,
5234 in_channels,
5235 out_channels,
5236 kernel_size,
5237 stride,
5238 padding,
5239 output_padding,
5240 dilation,
5241 groups,
5242 training: true,
5243 })
5244 }
5245
5246 pub fn with_padding_mode(self, mode: crate::padding::PaddingMode) -> FerrotorchResult<Self> {
5257 reject_non_zeros_transpose(mode, "ConvTranspose3d")?;
5258 Ok(self)
5259 }
5260
5261 pub fn num_parameters(&self) -> usize {
5263 let w = self.in_channels
5264 * self.out_channels
5265 * self.kernel_size.0
5266 * self.kernel_size.1
5267 * self.kernel_size.2;
5268 let b = if self.bias.is_some() {
5269 self.out_channels
5270 } else {
5271 0
5272 };
5273 w + b
5274 }
5275
5276 pub fn from_parts(
5282 weight: Tensor<T>,
5283 bias: Option<Tensor<T>>,
5284 stride: (usize, usize, usize),
5285 padding: (usize, usize, usize),
5286 output_padding: (usize, usize, usize),
5287 ) -> FerrotorchResult<Self> {
5288 if weight.ndim() != 5 {
5289 return Err(FerrotorchError::ShapeMismatch {
5290 message: format!(
5291 "ConvTranspose3d::from_parts: weight must be 5-D [in, out, kD, kH, kW], got {:?}",
5292 weight.shape()
5293 ),
5294 });
5295 }
5296 let in_channels = weight.shape()[0];
5297 let out_channels = weight.shape()[1];
5298 let kernel_size = (weight.shape()[2], weight.shape()[3], weight.shape()[4]);
5299 if output_padding.0 >= stride.0
5300 || output_padding.1 >= stride.1
5301 || output_padding.2 >= stride.2
5302 {
5303 return Err(FerrotorchError::InvalidArgument {
5304 message: "output_padding must be strictly less than stride in all dimensions"
5305 .into(),
5306 });
5307 }
5308 if let Some(b) = &bias {
5309 if b.ndim() != 1 || b.shape()[0] != out_channels {
5310 return Err(FerrotorchError::ShapeMismatch {
5311 message: format!(
5312 "ConvTranspose3d::from_parts: bias shape {:?} != [{}]",
5313 b.shape(),
5314 out_channels
5315 ),
5316 });
5317 }
5318 }
5319 Ok(Self {
5320 weight: Parameter::new(weight),
5321 bias: bias.map(Parameter::new),
5322 in_channels,
5323 out_channels,
5324 kernel_size,
5325 stride,
5326 padding,
5327 output_padding,
5328 dilation: (1, 1, 1),
5331 groups: 1,
5332 training: true,
5333 })
5334 }
5335}
5336
5337#[allow(clippy::too_many_arguments)]
5345fn stride_insert_zeros_3d<T: Float>(
5346 input: &[T],
5347 batch: usize,
5348 channels: usize,
5349 d: usize,
5350 h: usize,
5351 w: usize,
5352 stride_d: usize,
5353 stride_h: usize,
5354 stride_w: usize,
5355) -> (Vec<T>, usize, usize, usize) {
5356 let d_up = (d - 1) * stride_d + 1;
5357 let h_up = (h - 1) * stride_h + 1;
5358 let w_up = (w - 1) * stride_w + 1;
5359 let zero = <T as num_traits::Zero>::zero();
5360 let mut out = vec![zero; batch * channels * d_up * h_up * w_up];
5361
5362 for b in 0..batch {
5363 for c in 0..channels {
5364 for id in 0..d {
5365 for ih in 0..h {
5366 for iw in 0..w {
5367 let od = id * stride_d;
5368 let oh = ih * stride_h;
5369 let ow = iw * stride_w;
5370 out[b * channels * d_up * h_up * w_up
5371 + c * d_up * h_up * w_up
5372 + od * h_up * w_up
5373 + oh * w_up
5374 + ow] = input
5375 [b * channels * d * h * w + c * d * h * w + id * h * w + ih * w + iw];
5376 }
5377 }
5378 }
5379 }
5380 }
5381
5382 (out, d_up, h_up, w_up)
5383}
5384
5385#[allow(clippy::too_many_arguments)]
5398fn crop_volume_3d<T: Float>(
5399 input: &[T],
5400 batch: usize,
5401 channels: usize,
5402 d: usize,
5403 h: usize,
5404 w: usize,
5405 crop_d: usize,
5406 crop_h: usize,
5407 crop_w: usize,
5408) -> (Vec<T>, usize, usize, usize) {
5409 let d_out = d - 2 * crop_d;
5410 let h_out = h - 2 * crop_h;
5411 let w_out = w - 2 * crop_w;
5412 let zero = <T as num_traits::Zero>::zero();
5413 let mut out = vec![zero; batch * channels * d_out * h_out * w_out];
5414
5415 for b in 0..batch {
5416 for c in 0..channels {
5417 for od in 0..d_out {
5418 for oh in 0..h_out {
5419 let src =
5420 (((b * channels + c) * d + (od + crop_d)) * h + (oh + crop_h)) * w + crop_w;
5421 let dst = (((b * channels + c) * d_out + od) * h_out + oh) * w_out;
5422 out[dst..dst + w_out].copy_from_slice(&input[src..src + w_out]);
5423 }
5424 }
5425 }
5426 }
5427
5428 (out, d_out, h_out, w_out)
5429}
5430
5431fn flip_kernel_3d<T: Float>(
5435 kernel: &[T],
5436 c_in: usize,
5437 c_out: usize,
5438 kd: usize,
5439 kh: usize,
5440 kw: usize,
5441) -> Vec<T> {
5442 let zero = <T as num_traits::Zero>::zero();
5443 let mut flipped = vec![zero; c_out * c_in * kd * kh * kw];
5444
5445 for ci in 0..c_in {
5446 for co in 0..c_out {
5447 for dd in 0..kd {
5448 for dh in 0..kh {
5449 for dw in 0..kw {
5450 let src = ci * c_out * kd * kh * kw
5452 + co * kd * kh * kw
5453 + dd * kh * kw
5454 + dh * kw
5455 + dw;
5456 let dst = co * c_in * kd * kh * kw
5458 + ci * kd * kh * kw
5459 + (kd - 1 - dd) * kh * kw
5460 + (kh - 1 - dh) * kw
5461 + (kw - 1 - dw);
5462 flipped[dst] = kernel[src];
5463 }
5464 }
5465 }
5466 }
5467 }
5468
5469 flipped
5470}
5471
5472#[allow(clippy::too_many_arguments)]
5482fn conv_transpose3d_forward_group<T: Float>(
5483 input_data: &[T],
5484 batch: usize,
5485 in_pg: usize,
5486 out_pg: usize,
5487 d: usize,
5488 h: usize,
5489 w: usize,
5490 kernel_size: (usize, usize, usize),
5491 stride: (usize, usize, usize),
5492 padding: (usize, usize, usize),
5493 output_padding: (usize, usize, usize),
5494 dilation: (usize, usize, usize),
5495 group_weight: &[T],
5496) -> FerrotorchResult<(Vec<T>, usize, usize, usize)> {
5497 let (kd, kh, kw) = kernel_size;
5498 let (sd, sh, sw) = stride;
5499 let (pd, ph, pw) = padding;
5500 let (opd, oph, opw) = output_padding;
5501 let (dd, dh, dw) = dilation;
5502
5503 let (upsampled, d_up_core, h_up_core, w_up_core) =
5505 stride_insert_zeros_3d(input_data, batch, in_pg, d, h, w, sd, sh, sw);
5506 let d_up = d_up_core + opd;
5507 let h_up = h_up_core + oph;
5508 let w_up = w_up_core + opw;
5509 let upsampled = if opd > 0 || oph > 0 || opw > 0 {
5510 let zero = <T as num_traits::Zero>::zero();
5511 let mut ext = vec![zero; batch * in_pg * d_up * h_up * w_up];
5512 for b in 0..batch {
5513 for c in 0..in_pg {
5514 for id in 0..d_up_core {
5515 for ih in 0..h_up_core {
5516 let src = (((b * in_pg + c) * d_up_core + id) * h_up_core + ih) * w_up_core;
5517 let dst = (((b * in_pg + c) * d_up + id) * h_up + ih) * w_up;
5518 ext[dst..dst + w_up_core].copy_from_slice(&upsampled[src..src + w_up_core]);
5519 }
5520 }
5521 }
5522 }
5523 ext
5524 } else {
5525 upsampled
5526 };
5527
5528 let flipped = flip_kernel_3d(group_weight, in_pg, out_pg, kd, kh, kw);
5530
5531 let eff_kd = dd * (kd - 1) + 1;
5545 let eff_kh = dh * (kh - 1) + 1;
5546 let eff_kw = dw * (kw - 1) + 1;
5547 let signed_pad_d = (eff_kd - 1) as isize - pd as isize;
5548 let signed_pad_h = (eff_kh - 1) as isize - ph as isize;
5549 let signed_pad_w = (eff_kw - 1) as isize - pw as isize;
5550 let crop_d = (-signed_pad_d).max(0) as usize;
5553 let crop_h = (-signed_pad_h).max(0) as usize;
5554 let crop_w = (-signed_pad_w).max(0) as usize;
5555 let (conv_input, d_in, h_in, w_in) = if crop_d > 0 || crop_h > 0 || crop_w > 0 {
5556 crop_volume_3d(
5557 &upsampled, batch, in_pg, d_up, h_up, w_up, crop_d, crop_h, crop_w,
5558 )
5559 } else {
5560 (upsampled, d_up, h_up, w_up)
5561 };
5562 let internal_pad_d = signed_pad_d.max(0) as usize;
5563 let internal_pad_h = signed_pad_h.max(0) as usize;
5564 let internal_pad_w = signed_pad_w.max(0) as usize;
5565
5566 let (cols, col_rows, col_cols) = im2col_3d_dilated(
5567 &conv_input,
5568 batch,
5569 in_pg,
5570 d_in,
5571 h_in,
5572 w_in,
5573 kd,
5574 kh,
5575 kw,
5576 1,
5577 1,
5578 1,
5579 internal_pad_d,
5580 internal_pad_h,
5581 internal_pad_w,
5582 dd,
5583 dh,
5584 dw,
5585 );
5586
5587 let d_out = (d_in + 2 * internal_pad_d - eff_kd) + 1;
5588 let h_out = (h_in + 2 * internal_pad_h - eff_kh) + 1;
5589 let w_out = (w_in + 2 * internal_pad_w - eff_kw) + 1;
5590 let spatial_out = d_out * h_out * w_out;
5591
5592 let flipped_2d =
5593 Tensor::from_storage(TensorStorage::cpu(flipped), vec![out_pg, col_rows], false)?;
5594
5595 let zero = <T as num_traits::Zero>::zero();
5596 let mut output = vec![zero; batch * out_pg * spatial_out];
5597
5598 for b in 0..batch {
5599 let col_start = b * col_rows * col_cols;
5600 let col_end = col_start + col_rows * col_cols;
5601 let cols_b = Tensor::from_storage(
5602 TensorStorage::cpu(cols[col_start..col_end].to_vec()),
5603 vec![col_rows, col_cols],
5604 false,
5605 )?;
5606 let out_b = mm(&flipped_2d, &cols_b)?;
5607 let out_data = out_b.data()?;
5608 let out_start = b * out_pg * spatial_out;
5609 let copy_len = out_pg * spatial_out;
5610 output[out_start..out_start + copy_len].copy_from_slice(&out_data[..copy_len]);
5611 }
5612
5613 Ok((output, d_out, h_out, w_out))
5614}
5615
5616impl<T: Float> Module<T> for ConvTranspose3d<T> {
5617 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
5618 let _autocast_cat = autocast_guard("conv_transpose3d");
5620
5621 if input.ndim() == 4 {
5627 let batched = unsqueeze(input, 0)?;
5628 let output = self.forward(&batched)?;
5629 return squeeze(&output, 0);
5630 }
5631
5632 if input.ndim() != 5 {
5634 return Err(FerrotorchError::InvalidArgument {
5635 message: format!(
5636 "ConvTranspose3d expects 5-D input [B, C, D, H, W], got {:?}",
5637 input.shape()
5638 ),
5639 });
5640 }
5641
5642 let batch = input.shape()[0];
5643 let c_in = input.shape()[1];
5644 let d = input.shape()[2];
5645 let h = input.shape()[3];
5646 let w = input.shape()[4];
5647
5648 if c_in != self.in_channels {
5649 return Err(FerrotorchError::ShapeMismatch {
5650 message: format!(
5651 "ConvTranspose3d: expected {} input channels, got {}",
5652 self.in_channels, c_in
5653 ),
5654 });
5655 }
5656
5657 let (kd, kh, kw) = self.kernel_size;
5658 let groups = self.groups;
5659 let in_pg = self.in_channels / groups;
5660 let out_pg = self.out_channels / groups;
5661 let weight_pg_numel = in_pg * out_pg * kd * kh * kw;
5662
5663 let input_device = input.device();
5665
5666 let input_data = input.data_vec()?;
5667 let weight_data = self.weight.data_vec()?;
5668
5669 let zero = <T as num_traits::Zero>::zero();
5671 let mut output: Vec<T> = Vec::new();
5672 let (mut d_out, mut h_out, mut w_out) = (0usize, 0usize, 0usize);
5673 let spatial_in = d * h * w;
5674
5675 for g in 0..groups {
5676 let mut group_input = vec![zero; batch * in_pg * spatial_in];
5677 for b in 0..batch {
5678 for c in 0..in_pg {
5679 let src_c = g * in_pg + c;
5680 let src = (b * self.in_channels + src_c) * spatial_in;
5681 let dst = (b * in_pg + c) * spatial_in;
5682 group_input[dst..dst + spatial_in]
5683 .copy_from_slice(&input_data[src..src + spatial_in]);
5684 }
5685 }
5686
5687 let w_start = g * weight_pg_numel;
5688 let group_weight = &weight_data[w_start..w_start + weight_pg_numel];
5689
5690 let (g_out, gdo, gho, gwo) = conv_transpose3d_forward_group(
5691 &group_input,
5692 batch,
5693 in_pg,
5694 out_pg,
5695 d,
5696 h,
5697 w,
5698 self.kernel_size,
5699 self.stride,
5700 self.padding,
5701 self.output_padding,
5702 self.dilation,
5703 group_weight,
5704 )?;
5705 d_out = gdo;
5706 h_out = gho;
5707 w_out = gwo;
5708 let spatial_out = d_out * h_out * w_out;
5709
5710 if output.is_empty() {
5711 output = vec![zero; batch * self.out_channels * spatial_out];
5712 }
5713 for b in 0..batch {
5714 for oc in 0..out_pg {
5715 let dst_c = g * out_pg + oc;
5716 let dst = (b * self.out_channels + dst_c) * spatial_out;
5717 let src = (b * out_pg + oc) * spatial_out;
5718 output[dst..dst + spatial_out].copy_from_slice(&g_out[src..src + spatial_out]);
5719 }
5720 }
5721 }
5722
5723 let spatial_out = d_out * h_out * w_out;
5724
5725 if let Some(ref bias) = self.bias {
5727 let bias_data = bias.data_vec()?;
5728 for b in 0..batch {
5729 for c in 0..self.out_channels {
5730 let bval = bias_data[c];
5731 for s in 0..spatial_out {
5732 output[b * self.out_channels * spatial_out + c * spatial_out + s] += bval;
5733 }
5734 }
5735 }
5736 }
5737
5738 let result = Tensor::from_storage(
5739 TensorStorage::cpu(output),
5740 vec![batch, self.out_channels, d_out, h_out, w_out],
5741 false,
5742 )?;
5743
5744 if is_grad_enabled()
5746 && (input.requires_grad()
5747 || self.weight.requires_grad()
5748 || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
5749 {
5750 let grad_fn = Arc::new(ConvTranspose3dBackward {
5751 input: input.clone(),
5752 weight: self.weight.tensor().clone(),
5753 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
5754 in_channels: self.in_channels,
5755 out_channels: self.out_channels,
5756 kernel_size: self.kernel_size,
5757 stride: self.stride,
5758 padding: self.padding,
5759 _output_padding: self.output_padding,
5760 dilation: self.dilation,
5761 groups: self.groups,
5762 d_out,
5763 h_out,
5764 w_out,
5765 });
5766 Tensor::from_operation(
5767 TensorStorage::cpu(result.data()?.to_vec()),
5768 result.shape().to_vec(),
5769 grad_fn,
5770 )?
5771 .to(input_device) } else {
5773 result.to(input_device)
5774 }
5775 }
5776
5777 fn parameters(&self) -> Vec<&Parameter<T>> {
5778 let mut params = vec![&self.weight];
5779 if let Some(ref b) = self.bias {
5780 params.push(b);
5781 }
5782 params
5783 }
5784
5785 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
5786 let mut params = vec![&mut self.weight];
5787 if let Some(ref mut b) = self.bias {
5788 params.push(b);
5789 }
5790 params
5791 }
5792
5793 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
5794 let mut params = vec![("weight".to_string(), &self.weight)];
5795 if let Some(ref b) = self.bias {
5796 params.push(("bias".to_string(), b));
5797 }
5798 params
5799 }
5800
5801 fn train(&mut self) {
5802 self.training = true;
5803 }
5804
5805 fn eval(&mut self) {
5806 self.training = false;
5807 }
5808
5809 fn is_training(&self) -> bool {
5810 self.training
5811 }
5812}
5813
5814#[derive(Debug)]
5822struct ConvTranspose3dBackward<T: Float> {
5823 input: Tensor<T>,
5824 weight: Tensor<T>,
5825 bias: Option<Tensor<T>>,
5826 in_channels: usize,
5827 out_channels: usize,
5828 kernel_size: (usize, usize, usize),
5829 stride: (usize, usize, usize),
5830 padding: (usize, usize, usize),
5831 _output_padding: (usize, usize, usize),
5832 dilation: (usize, usize, usize),
5833 groups: usize,
5834 d_out: usize,
5835 h_out: usize,
5836 w_out: usize,
5837}
5838
5839impl<T: Float> GradFn<T> for ConvTranspose3dBackward<T> {
5840 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
5841 let go_data = grad_output.data_vec()?;
5843 let batch = self.input.shape()[0];
5844 let d_in = self.input.shape()[2];
5845 let h_in = self.input.shape()[3];
5846 let w_in = self.input.shape()[4];
5847 let (kd, kh, kw) = self.kernel_size;
5848 let (sd, sh, sw) = self.stride;
5849 let (pd, ph, pw) = self.padding;
5850 let (dd_, dh_, dw_) = self.dilation;
5851 let groups = self.groups;
5852 let in_pg = self.in_channels / groups;
5853 let out_pg = self.out_channels / groups;
5854 let spatial_out = self.d_out * self.h_out * self.w_out;
5855 let spatial_in = d_in * h_in * w_in;
5856 let zero = <T as num_traits::Zero>::zero();
5857
5858 let weight_data_all = self.weight.data_vec()?;
5859 let input_data_all = self.input.data_vec()?;
5860
5861 let mut gi_all = if self.input.requires_grad() {
5862 Some(vec![zero; batch * self.in_channels * spatial_in])
5863 } else {
5864 None
5865 };
5866 let mut gw_all = if self.weight.requires_grad() {
5867 Some(vec![zero; self.in_channels * out_pg * kd * kh * kw])
5868 } else {
5869 None
5870 };
5871
5872 for g in 0..groups {
5873 if let Some(gi) = gi_all.as_mut() {
5875 let col_rows = out_pg * kd * kh * kw;
5876 let w_start = g * in_pg * out_pg * kd * kh * kw;
5877 let weight_2d = Tensor::from_storage(
5878 TensorStorage::cpu(
5879 weight_data_all[w_start..w_start + in_pg * out_pg * kd * kh * kw].to_vec(),
5880 ),
5881 vec![in_pg, col_rows],
5882 false,
5883 )?;
5884
5885 let mut go_g = vec![zero; batch * out_pg * spatial_out];
5886 for b in 0..batch {
5887 for c in 0..out_pg {
5888 let src_c = g * out_pg + c;
5889 let src = (b * self.out_channels + src_c) * spatial_out;
5890 let dst = (b * out_pg + c) * spatial_out;
5891 go_g[dst..dst + spatial_out]
5892 .copy_from_slice(&go_data[src..src + spatial_out]);
5893 }
5894 }
5895
5896 let (go_cols, _gcr, go_col_cols) = im2col_3d_dilated(
5897 &go_g, batch, out_pg, self.d_out, self.h_out, self.w_out, kd, kh, kw, sd, sh,
5898 sw, pd, ph, pw, dd_, dh_, dw_,
5899 );
5900 debug_assert_eq!(go_col_cols, spatial_in);
5901
5902 for b in 0..batch {
5903 let col_start = b * col_rows * go_col_cols;
5904 let col_end = col_start + col_rows * go_col_cols;
5905 let go_cols_b = Tensor::from_storage(
5906 TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
5907 vec![col_rows, go_col_cols],
5908 false,
5909 )?;
5910 let gi_b = mm(&weight_2d, &go_cols_b)?;
5911 let gi_data = gi_b.data()?;
5912 for c in 0..in_pg {
5913 let dst_c = g * in_pg + c;
5914 let dst = (b * self.in_channels + dst_c) * spatial_in;
5915 let src = c * spatial_in;
5916 gi[dst..dst + spatial_in].copy_from_slice(&gi_data[src..src + spatial_in]);
5917 }
5918 }
5919 }
5920
5921 if let Some(gw) = gw_all.as_mut() {
5923 for ci in 0..in_pg {
5924 let in_c = g * in_pg + ci;
5925 for co in 0..out_pg {
5926 let out_c = g * out_pg + co;
5927 for tkd in 0..kd {
5928 for tkh in 0..kh {
5929 for tkw in 0..kw {
5930 let mut acc = zero;
5931 for id in 0..d_in {
5932 for ih in 0..h_in {
5933 for iw in 0..w_in {
5934 let od = id * sd + tkd * dd_;
5935 let oh = ih * sh + tkh * dh_;
5936 let ow = iw * sw + tkw * dw_;
5937 if od >= pd
5938 && oh >= ph
5939 && ow >= pw
5940 && (od - pd) < self.d_out
5941 && (oh - ph) < self.h_out
5942 && (ow - pw) < self.w_out
5943 {
5944 let go_index = out_c * spatial_out
5945 + (od - pd) * self.h_out * self.w_out
5946 + (oh - ph) * self.w_out
5947 + (ow - pw);
5948 let in_index = in_c * spatial_in
5949 + id * h_in * w_in
5950 + ih * w_in
5951 + iw;
5952 for b in 0..batch {
5953 let goi =
5954 b * self.out_channels * spatial_out
5955 + go_index;
5956 let ini = b * self.in_channels * spatial_in
5957 + in_index;
5958 acc += input_data_all[ini] * go_data[goi];
5959 }
5960 }
5961 }
5962 }
5963 }
5964 gw[((((in_c * out_pg + co) * kd + tkd) * kh + tkh) * kw)
5966 + tkw] += acc;
5967 }
5968 }
5969 }
5970 }
5971 }
5972 }
5973 }
5974
5975 let grad_input = match gi_all {
5976 Some(gi) => Some(Tensor::from_storage(
5977 TensorStorage::cpu(gi),
5978 self.input.shape().to_vec(),
5979 false,
5980 )?),
5981 None => None,
5982 };
5983 let grad_weight = match gw_all {
5984 Some(gw) => Some(Tensor::from_storage(
5985 TensorStorage::cpu(gw),
5986 vec![self.in_channels, out_pg, kd, kh, kw],
5987 false,
5988 )?),
5989 None => None,
5990 };
5991
5992 let grad_bias = match &self.bias {
5994 Some(b) if b.requires_grad() => {
5995 let zero = <T as num_traits::Zero>::zero();
5996 let mut gb = vec![zero; self.out_channels];
5997 for batch_idx in 0..batch {
5998 for c in 0..self.out_channels {
5999 for s in 0..spatial_out {
6000 gb[c] += go_data
6001 [batch_idx * self.out_channels * spatial_out + c * spatial_out + s];
6002 }
6003 }
6004 }
6005 Some(Tensor::from_storage(
6006 TensorStorage::cpu(gb),
6007 vec![self.out_channels],
6008 false,
6009 )?)
6010 }
6011 _ => None,
6012 };
6013
6014 let mut grads = vec![grad_input, grad_weight];
6015 if self.bias.is_some() {
6016 grads.push(grad_bias);
6017 }
6018 Ok(grads)
6019 }
6020
6021 fn inputs(&self) -> Vec<&Tensor<T>> {
6022 let mut v = vec![&self.input, &self.weight];
6023 if let Some(ref b) = self.bias {
6024 v.push(b);
6025 }
6026 v
6027 }
6028
6029 fn name(&self) -> &'static str {
6030 "ConvTranspose3dBackward"
6031 }
6032}
6033
6034#[cfg(test)]
6039mod tests {
6040 use super::*;
6041 use crate::module::Module;
6042
6043 #[test]
6051 fn test_conv2d_bias_init_bounded_uniform() {
6052 let in_c = 16usize;
6053 let out_c = 32usize;
6054 let kh = 3usize;
6055 let kw = 3usize;
6056 let groups = 1usize;
6057 let layer =
6058 Conv2d::<f32>::new_full(in_c, out_c, (kh, kw), (1, 1), (0, 0), (1, 1), groups, true)
6059 .unwrap();
6060 let bias = layer.bias.as_ref().expect("bias requested");
6061 let bias_data = bias.tensor().data_vec().unwrap();
6062 let fan_in = (in_c / groups) * kh * kw;
6063 let bound = 1.0_f32 / (fan_in as f32).sqrt();
6064 let mut nonzero = 0usize;
6065 for &b in &bias_data {
6066 assert!(
6067 b.abs() <= bound + 1e-6,
6068 "Conv2d bias element {b} exceeds bound {bound}"
6069 );
6070 if b != 0.0 {
6071 nonzero += 1;
6072 }
6073 }
6074 assert!(
6075 nonzero > out_c / 2,
6076 "expected most Conv2d bias entries to be nonzero; \
6077 would FAIL pre-fix when bias was zeros_init"
6078 );
6079 }
6080
6081 fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
6083 Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
6084 }
6085
6086 fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
6088 Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
6089 }
6090
6091 fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
6093 assert_eq!(
6094 actual.len(),
6095 expected.len(),
6096 "length mismatch: {} vs {}",
6097 actual.len(),
6098 expected.len()
6099 );
6100 for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
6101 assert!(
6102 (a - e).abs() < tol,
6103 "index {i}: actual={a} expected={e} (diff {})",
6104 (a - e).abs()
6105 );
6106 }
6107 }
6108
6109 #[test]
6114 fn test_output_shape_no_padding() {
6115 let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6118 let input = t(&[0.0; 25], &[1, 1, 5, 5]);
6119 let output = conv.forward(&input).unwrap();
6120 assert_eq!(output.shape(), &[1, 1, 3, 3]);
6121 }
6122
6123 #[test]
6124 fn test_output_shape_with_padding() {
6125 let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (1, 1), true).unwrap();
6128 let input = t(&vec![0.0; 2 * 3 * 8 * 8], &[2, 3, 8, 8]);
6129 let output = conv.forward(&input).unwrap();
6130 assert_eq!(output.shape(), &[2, 16, 8, 8]);
6131 }
6132
6133 #[test]
6134 fn test_output_shape_with_stride() {
6135 let conv = Conv2d::<f32>::new(1, 4, (3, 3), (2, 2), (0, 0), false).unwrap();
6138 let input = t(&[0.0; 36], &[1, 1, 6, 6]);
6139 let output = conv.forward(&input).unwrap();
6140 assert_eq!(output.shape(), &[1, 4, 2, 2]);
6141 }
6142
6143 #[test]
6148 fn test_1x1_conv_equals_linear() {
6149 let weight_data: Vec<f32> = vec![
6158 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ];
6162 let input_data: Vec<f32> = vec![
6164 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ];
6167
6168 let weight_param = Parameter::from_slice(&weight_data, &[3, 2, 1, 1]).unwrap();
6170 let conv = Conv2d {
6171 weight: weight_param,
6172 bias: None,
6173 in_channels: 2,
6174 out_channels: 3,
6175 kernel_size: (1, 1),
6176 stride: (1, 1),
6177 padding: (0, 0),
6178 dilation: (1, 1),
6179 groups: 1,
6180 padding_mode: crate::padding::PaddingMode::Zeros,
6181 string_padding: None,
6182 training: false,
6183 };
6184
6185 let input = t(&input_data, &[1, 2, 2, 2]);
6186 let output = conv.forward(&input).unwrap();
6187 assert_eq!(output.shape(), &[1, 3, 2, 2]);
6188
6189 let out = output.data().unwrap();
6190
6191 let expected = [
6201 11.0, 14.0, 17.0, 20.0, 23.0, 30.0, 37.0, 44.0, 35.0, 46.0, 57.0, 68.0, ];
6205 assert_close(out, &expected, 1e-5);
6206 }
6207
6208 #[test]
6213 fn test_bias_addition() {
6214 let weight_data = vec![1.0f32]; let bias_data = vec![10.0f32]; let conv = Conv2d {
6219 weight: Parameter::from_slice(&weight_data, &[1, 1, 1, 1]).unwrap(),
6220 bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
6221 in_channels: 1,
6222 out_channels: 1,
6223 kernel_size: (1, 1),
6224 stride: (1, 1),
6225 padding: (0, 0),
6226 dilation: (1, 1),
6227 groups: 1,
6228 padding_mode: crate::padding::PaddingMode::Zeros,
6229 string_padding: None,
6230 training: false,
6231 };
6232
6233 let input = t(&[2.0, 3.0, 4.0, 5.0], &[1, 1, 2, 2]);
6234 let output = conv.forward(&input).unwrap();
6235 assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0, 15.0], 1e-5);
6237 }
6238
6239 #[test]
6244 fn test_backward_produces_correct_shapes() {
6245 let weight_data = vec![1.0f32; 2 * 3 * 3]; let input_data = vec![1.0f32; 5 * 5]; let bias_data = vec![0.0f32; 2];
6249
6250 let weight_param = Parameter::from_slice(&weight_data, &[2, 1, 3, 3]).unwrap();
6251 let bias_param = Parameter::from_slice(&bias_data, &[2]).unwrap();
6252
6253 let conv = Conv2d {
6254 weight: weight_param,
6255 bias: Some(bias_param),
6256 in_channels: 1,
6257 out_channels: 2,
6258 kernel_size: (3, 3),
6259 stride: (1, 1),
6260 padding: (0, 0),
6261 dilation: (1, 1),
6262 groups: 1,
6263 padding_mode: crate::padding::PaddingMode::Zeros,
6264 string_padding: None,
6265 training: false,
6266 };
6267
6268 let input = leaf(&input_data, &[1, 1, 5, 5]);
6270 let output = conv.forward(&input).unwrap();
6271 assert_eq!(output.shape(), &[1, 2, 3, 3]);
6272
6273 assert!(output.grad_fn().is_some());
6275 assert_eq!(output.grad_fn().unwrap().name(), "Conv2dBackward");
6276
6277 let grad_output = t(&[1.0; 2 * 3 * 3], &[1, 2, 3, 3]);
6279 let grads = output.grad_fn().unwrap().backward(&grad_output).unwrap();
6280
6281 assert!(grads[0].is_some());
6283 assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 1, 5, 5]);
6284
6285 assert!(grads[1].is_some());
6287 assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 3, 3]);
6288
6289 assert!(grads[2].is_some());
6291 assert_eq!(grads[2].as_ref().unwrap().shape(), &[2]);
6292 }
6293
6294 #[test]
6299 fn test_parameter_count_with_bias() {
6300 let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), true).unwrap();
6301 assert_eq!(conv.num_parameters(), 448);
6305 assert_eq!(conv.parameters().len(), 2);
6306 }
6307
6308 #[test]
6309 fn test_parameter_count_without_bias() {
6310 let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), false).unwrap();
6311 assert_eq!(conv.num_parameters(), 432);
6312 assert_eq!(conv.parameters().len(), 1);
6313 }
6314
6315 #[test]
6320 fn test_named_parameters() {
6321 let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
6322 let named = conv.named_parameters();
6323 assert_eq!(named.len(), 2);
6324 assert_eq!(named[0].0, "weight");
6325 assert_eq!(named[1].0, "bias");
6326 }
6327
6328 #[test]
6329 fn test_train_eval() {
6330 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6331 assert!(conv.is_training());
6332 conv.eval();
6333 assert!(!conv.is_training());
6334 conv.train();
6335 assert!(conv.is_training());
6336 }
6337
6338 #[test]
6343 fn test_invalid_input_ndim() {
6344 let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6345 let input = t(&[1.0, 2.0, 3.0], &[3]);
6346 assert!(conv.forward(&input).is_err());
6347 }
6348
6349 #[test]
6350 fn test_channel_mismatch() {
6351 let conv = Conv2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6352 let input = t(&[0.0; 5 * 5], &[1, 1, 5, 5]);
6353 assert!(conv.forward(&input).is_err());
6354 }
6355
6356 #[test]
6357 fn test_zero_channels_rejected() {
6358 assert!(Conv2d::<f32>::new(0, 16, (3, 3), (1, 1), (0, 0), false).is_err());
6359 assert!(Conv2d::<f32>::new(3, 0, (3, 3), (1, 1), (0, 0), false).is_err());
6360 }
6361
6362 #[test]
6363 fn test_zero_kernel_rejected() {
6364 assert!(Conv2d::<f32>::new(1, 1, (0, 3), (1, 1), (0, 0), false).is_err());
6365 }
6366
6367 #[test]
6368 fn test_zero_stride_rejected() {
6369 assert!(Conv2d::<f32>::new(1, 1, (3, 3), (0, 1), (0, 0), false).is_err());
6370 }
6371
6372 #[test]
6377 fn test_im2col_basic() {
6378 #[rustfmt::skip]
6382 let input: Vec<f32> = vec![
6383 1.0, 2.0, 3.0,
6384 4.0, 5.0, 6.0,
6385 7.0, 8.0, 9.0,
6386 ];
6387 let (cols, rows, n_cols) = im2col(&input, 1, 1, 3, 3, 2, 2, 1, 1, 0, 0);
6388 assert_eq!(rows, 4); assert_eq!(n_cols, 4); assert_close(
6402 &cols,
6403 &[
6404 1.0, 2.0, 4.0, 5.0, 2.0, 3.0, 5.0, 6.0, 4.0, 5.0, 7.0, 8.0, 5.0, 6.0, 8.0, 9.0, ],
6409 1e-7,
6410 );
6411 }
6412
6413 #[test]
6414 fn test_col2im_roundtrip_no_overlap() {
6415 #[rustfmt::skip]
6419 let input: Vec<f32> = vec![
6420 1.0, 2.0, 3.0, 4.0,
6421 5.0, 6.0, 7.0, 8.0,
6422 9.0, 10.0, 11.0, 12.0,
6423 13.0, 14.0, 15.0, 16.0,
6424 ];
6425
6426 let (cols, _rows, _n_cols) = im2col(&input, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0);
6427 let recovered = col2im(&cols, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0, 2, 2);
6428
6429 assert_close(&recovered, &input, 1e-7);
6430 }
6431
6432 #[test]
6437 fn test_3x3_conv_forward() {
6438 #[rustfmt::skip]
6441 let input_data: Vec<f32> = vec![
6442 1.0, 2.0, 3.0,
6443 4.0, 5.0, 6.0,
6444 7.0, 8.0, 9.0,
6445 ];
6446 #[rustfmt::skip]
6447 let weight_data: Vec<f32> = vec![
6448 1.0, 0.0, -1.0,
6449 1.0, 0.0, -1.0,
6450 1.0, 0.0, -1.0,
6451 ];
6452
6453 let conv = Conv2d {
6454 weight: Parameter::from_slice(&weight_data, &[1, 1, 3, 3]).unwrap(),
6455 bias: None,
6456 in_channels: 1,
6457 out_channels: 1,
6458 kernel_size: (3, 3),
6459 stride: (1, 1),
6460 padding: (0, 0),
6461 dilation: (1, 1),
6462 groups: 1,
6463 padding_mode: crate::padding::PaddingMode::Zeros,
6464 string_padding: None,
6465 training: false,
6466 };
6467
6468 let input = t(&input_data, &[1, 1, 3, 3]);
6469 let output = conv.forward(&input).unwrap();
6470 assert_eq!(output.shape(), &[1, 1, 1, 1]);
6471
6472 assert_close(output.data().unwrap(), &[-6.0], 1e-5);
6475 }
6476
6477 #[test]
6482 fn test_padding_preserves_spatial_size() {
6483 let weight_data = vec![0.0f32; 9];
6486 let mut weight_data_center = weight_data;
6487 weight_data_center[4] = 1.0; let conv = Conv2d {
6490 weight: Parameter::from_slice(&weight_data_center, &[1, 1, 3, 3]).unwrap(),
6491 bias: None,
6492 in_channels: 1,
6493 out_channels: 1,
6494 kernel_size: (3, 3),
6495 stride: (1, 1),
6496 padding: (1, 1),
6497 dilation: (1, 1),
6498 groups: 1,
6499 padding_mode: crate::padding::PaddingMode::Zeros,
6500 string_padding: None,
6501 training: false,
6502 };
6503
6504 let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
6505 let input = t(&input_data, &[1, 1, 3, 3]);
6506 let output = conv.forward(&input).unwrap();
6507 assert_eq!(output.shape(), &[1, 1, 3, 3]);
6508
6509 assert_close(output.data().unwrap(), &input_data, 1e-5);
6511 }
6512
6513 #[test]
6522 fn test_conv1d_output_shape_no_padding() {
6523 let conv = Conv1d::<f32>::new(1, 4, 3, 1, 0, false).unwrap();
6526 let input = t(&[0.0; 10], &[1, 1, 10]);
6527 let output = conv.forward(&input).unwrap();
6528 assert_eq!(output.shape(), &[1, 4, 8]);
6529 }
6530
6531 #[test]
6532 fn test_conv1d_output_shape_with_padding() {
6533 let conv = Conv1d::<f32>::new(3, 8, 3, 1, 1, true).unwrap();
6536 let input = t(&vec![0.0; 2 * 3 * 16], &[2, 3, 16]);
6537 let output = conv.forward(&input).unwrap();
6538 assert_eq!(output.shape(), &[2, 8, 16]);
6539 }
6540
6541 #[test]
6542 fn test_conv1d_output_shape_with_stride() {
6543 let conv = Conv1d::<f32>::new(1, 2, 3, 2, 0, false).unwrap();
6546 let input = t(&[0.0; 10], &[1, 1, 10]);
6547 let output = conv.forward(&input).unwrap();
6548 assert_eq!(output.shape(), &[1, 2, 4]);
6549 }
6550
6551 #[test]
6556 fn test_conv1d_1x1_kernel_correctness() {
6557 let weight_data = vec![3.0f32, 5.0];
6566 let conv = Conv1d {
6567 weight: Parameter::from_slice(&weight_data, &[2, 1, 1]).unwrap(),
6568 bias: None,
6569 in_channels: 1,
6570 out_channels: 2,
6571 kernel_size: 1,
6572 stride: 1,
6573 padding: 0,
6574 dilation: 1,
6575 groups: 1,
6576 padding_mode: crate::padding::PaddingMode::Zeros,
6577 string_padding: None,
6578 training: false,
6579 };
6580
6581 let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
6582 let output = conv.forward(&input).unwrap();
6583 assert_eq!(output.shape(), &[1, 2, 4]);
6584 assert_close(
6585 output.data().unwrap(),
6586 &[3.0, 6.0, 9.0, 12.0, 5.0, 10.0, 15.0, 20.0],
6587 1e-5,
6588 );
6589 }
6590
6591 #[test]
6596 fn test_conv1d_3_kernel_forward() {
6597 let conv = Conv1d {
6602 weight: Parameter::from_slice(&[1.0f32, 0.0, -1.0], &[1, 1, 3]).unwrap(),
6603 bias: None,
6604 in_channels: 1,
6605 out_channels: 1,
6606 kernel_size: 3,
6607 stride: 1,
6608 padding: 0,
6609 dilation: 1,
6610 groups: 1,
6611 padding_mode: crate::padding::PaddingMode::Zeros,
6612 string_padding: None,
6613 training: false,
6614 };
6615
6616 let input = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
6617 let output = conv.forward(&input).unwrap();
6618 assert_eq!(output.shape(), &[1, 1, 3]);
6619 assert_close(output.data().unwrap(), &[-2.0, -2.0, -2.0], 1e-5);
6620 }
6621
6622 #[test]
6627 fn test_conv1d_bias() {
6628 let conv = Conv1d {
6629 weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1]).unwrap(),
6630 bias: Some(Parameter::from_slice(&[10.0f32], &[1]).unwrap()),
6631 in_channels: 1,
6632 out_channels: 1,
6633 kernel_size: 1,
6634 stride: 1,
6635 padding: 0,
6636 dilation: 1,
6637 groups: 1,
6638 padding_mode: crate::padding::PaddingMode::Zeros,
6639 string_padding: None,
6640 training: false,
6641 };
6642
6643 let input = t(&[2.0, 3.0, 4.0], &[1, 1, 3]);
6644 let output = conv.forward(&input).unwrap();
6645 assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0], 1e-5);
6646 }
6647
6648 #[test]
6653 fn test_conv1d_invalid_ndim() {
6654 let conv = Conv1d::<f32>::new(1, 1, 3, 1, 0, false).unwrap();
6655 let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6656 assert!(conv.forward(&input).is_err());
6657 }
6658
6659 #[test]
6660 fn test_conv1d_channel_mismatch() {
6661 let conv = Conv1d::<f32>::new(3, 1, 3, 1, 0, false).unwrap();
6662 let input = t(&[0.0; 10], &[1, 1, 10]);
6663 assert!(conv.forward(&input).is_err());
6664 }
6665
6666 #[test]
6667 fn test_conv1d_zero_channels_rejected() {
6668 assert!(Conv1d::<f32>::new(0, 4, 3, 1, 0, false).is_err());
6669 assert!(Conv1d::<f32>::new(1, 0, 3, 1, 0, false).is_err());
6670 }
6671
6672 #[test]
6673 fn test_conv1d_zero_kernel_rejected() {
6674 assert!(Conv1d::<f32>::new(1, 1, 0, 1, 0, false).is_err());
6675 }
6676
6677 #[test]
6678 fn test_conv1d_zero_stride_rejected() {
6679 assert!(Conv1d::<f32>::new(1, 1, 3, 0, 0, false).is_err());
6680 }
6681
6682 #[test]
6683 fn test_conv1d_parameter_count() {
6684 let conv = Conv1d::<f32>::new(3, 8, 5, 1, 0, true).unwrap();
6685 assert_eq!(conv.num_parameters(), 128);
6687 assert_eq!(conv.parameters().len(), 2);
6688 }
6689
6690 #[test]
6699 fn test_conv_transpose2d_output_shape_basic() {
6700 let conv =
6703 ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
6704 let input = t(&[0.0; 9], &[1, 1, 3, 3]);
6705 let output = conv.forward(&input).unwrap();
6706 assert_eq!(output.shape(), &[1, 1, 5, 5]);
6707 }
6708
6709 #[test]
6710 fn test_conv_transpose2d_output_shape_stride2() {
6711 let conv =
6714 ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (0, 0), false).unwrap();
6715 let input = t(&[0.0; 4], &[1, 1, 2, 2]);
6716 let output = conv.forward(&input).unwrap();
6717 assert_eq!(output.shape(), &[1, 1, 5, 5]);
6718 }
6719
6720 #[test]
6721 fn test_conv_transpose2d_output_shape_with_padding() {
6722 let conv =
6725 ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (0, 0), false).unwrap();
6726 let input = t(&[0.0; 9], &[1, 1, 3, 3]);
6727 let output = conv.forward(&input).unwrap();
6728 assert_eq!(output.shape(), &[1, 1, 5, 5]);
6729 }
6730
6731 #[test]
6732 fn test_conv_transpose2d_output_shape_with_output_padding() {
6733 let conv =
6736 ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false).unwrap();
6737 let input = t(&[0.0; 9], &[1, 1, 3, 3]);
6738 let output = conv.forward(&input).unwrap();
6739 assert_eq!(output.shape(), &[1, 1, 6, 6]);
6740 }
6741
6742 #[test]
6747 fn test_conv_transpose2d_stride2_upsamples() {
6748 let conv =
6752 ConvTranspose2d::<f32>::new(1, 1, (2, 2), (2, 2), (0, 0), (0, 0), false).unwrap();
6753 let input = t(&[0.0; 4 * 4], &[1, 1, 4, 4]);
6754 let output = conv.forward(&input).unwrap();
6755 assert_eq!(output.shape(), &[1, 1, 8, 8]);
6756 }
6757
6758 #[test]
6759 fn test_conv_transpose2d_stride2_upsamples_multichannel() {
6760 let conv =
6762 ConvTranspose2d::<f32>::new(8, 16, (2, 2), (2, 2), (0, 0), (0, 0), true).unwrap();
6763 let input = t(&vec![0.0; 2 * 8 * 4 * 4], &[2, 8, 4, 4]);
6764 let output = conv.forward(&input).unwrap();
6765 assert_eq!(output.shape(), &[2, 16, 8, 8]);
6766 }
6767
6768 #[test]
6773 fn test_conv_transpose2d_1x1_kernel() {
6774 let weight_data = vec![3.0f32, 7.0]; let conv = ConvTranspose2d {
6782 weight: Parameter::from_slice(&weight_data, &[1, 2, 1, 1]).unwrap(),
6783 bias: None,
6784 in_channels: 1,
6785 out_channels: 2,
6786 kernel_size: (1, 1),
6787 stride: (1, 1),
6788 padding: (0, 0),
6789 output_padding: (0, 0),
6790 dilation: (1, 1),
6791 groups: 1,
6792 training: false,
6793 };
6794
6795 let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6796 let output = conv.forward(&input).unwrap();
6797 assert_eq!(output.shape(), &[1, 2, 2, 2]);
6798
6799 assert_close(
6802 output.data().unwrap(),
6803 &[3.0, 6.0, 9.0, 12.0, 7.0, 14.0, 21.0, 28.0],
6804 1e-5,
6805 );
6806 }
6807
6808 #[test]
6813 fn test_conv_transpose2d_stride2_correctness() {
6814 let weight_data = vec![1.0f32; 4]; let conv = ConvTranspose2d {
6859 weight: Parameter::from_slice(&weight_data, &[1, 1, 2, 2]).unwrap(),
6860 bias: None,
6861 in_channels: 1,
6862 out_channels: 1,
6863 kernel_size: (2, 2),
6864 stride: (2, 2),
6865 padding: (0, 0),
6866 output_padding: (0, 0),
6867 dilation: (1, 1),
6868 groups: 1,
6869 training: false,
6870 };
6871
6872 let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6873 let output = conv.forward(&input).unwrap();
6874 assert_eq!(output.shape(), &[1, 1, 4, 4]);
6875
6876 #[rustfmt::skip]
6877 let expected = [
6878 1.0, 1.0, 2.0, 2.0,
6879 1.0, 1.0, 2.0, 2.0,
6880 3.0, 3.0, 4.0, 4.0,
6881 3.0, 3.0, 4.0, 4.0,
6882 ];
6883 assert_close(output.data().unwrap(), &expected, 1e-5);
6884 }
6885
6886 #[test]
6891 fn test_conv_transpose2d_bias() {
6892 let weight_data = vec![1.0f32]; let bias_data = vec![5.0f32];
6894 let conv = ConvTranspose2d {
6895 weight: Parameter::from_slice(&weight_data, &[1, 1, 1, 1]).unwrap(),
6896 bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
6897 in_channels: 1,
6898 out_channels: 1,
6899 kernel_size: (1, 1),
6900 stride: (1, 1),
6901 padding: (0, 0),
6902 output_padding: (0, 0),
6903 dilation: (1, 1),
6904 groups: 1,
6905 training: false,
6906 };
6907
6908 let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6909 let output = conv.forward(&input).unwrap();
6910 assert_close(output.data().unwrap(), &[6.0, 7.0, 8.0, 9.0], 1e-5);
6911 }
6912
6913 #[test]
6918 fn test_conv_transpose2d_invalid_ndim() {
6919 let conv =
6920 ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
6921 let input = t(&[1.0, 2.0, 3.0], &[1, 3]);
6925 assert!(conv.forward(&input).is_err());
6926 }
6927
6928 #[test]
6929 fn test_conv_transpose2d_channel_mismatch() {
6930 let conv =
6931 ConvTranspose2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
6932 let input = t(&[0.0; 5 * 5], &[1, 1, 5, 5]);
6933 assert!(conv.forward(&input).is_err());
6934 }
6935
6936 #[test]
6937 fn test_conv_transpose2d_zero_channels_rejected() {
6938 assert!(ConvTranspose2d::<f32>::new(0, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).is_err());
6939 assert!(ConvTranspose2d::<f32>::new(1, 0, (3, 3), (1, 1), (0, 0), (0, 0), false).is_err());
6940 }
6941
6942 #[test]
6943 fn test_conv_transpose2d_output_padding_too_large() {
6944 assert!(ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (2, 2), false).is_err());
6946 }
6947
6948 #[test]
6949 fn test_conv_transpose2d_parameter_count() {
6950 let conv =
6951 ConvTranspose2d::<f32>::new(8, 16, (3, 3), (2, 2), (1, 1), (0, 0), true).unwrap();
6952 assert_eq!(conv.num_parameters(), 1168);
6954 assert_eq!(conv.parameters().len(), 2);
6955 }
6956
6957 #[test]
6966 fn test_conv3d_output_shape_no_padding() {
6967 let conv = Conv3d::<f32>::new(1, 4, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
6970 let input = t(&vec![0.0; 5 * 5 * 5], &[1, 1, 5, 5, 5]);
6971 let output = conv.forward(&input).unwrap();
6972 assert_eq!(output.shape(), &[1, 4, 3, 3, 3]);
6973 }
6974
6975 #[test]
6976 fn test_conv3d_output_shape_with_padding() {
6977 let conv = Conv3d::<f32>::new(3, 16, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
6980 let input = t(&vec![0.0; 2 * 3 * 8 * 8 * 8], &[2, 3, 8, 8, 8]);
6981 let output = conv.forward(&input).unwrap();
6982 assert_eq!(output.shape(), &[2, 16, 8, 8, 8]);
6983 }
6984
6985 #[test]
6986 fn test_conv3d_output_shape_with_stride() {
6987 let conv = Conv3d::<f32>::new(1, 4, (3, 3, 3), (2, 2, 2), (0, 0, 0), false).unwrap();
6990 let input = t(&vec![0.0; 6 * 6 * 6], &[1, 1, 6, 6, 6]);
6991 let output = conv.forward(&input).unwrap();
6992 assert_eq!(output.shape(), &[1, 4, 2, 2, 2]);
6993 }
6994
6995 #[test]
7000 fn test_conv3d_1x1x1_kernel_correctness() {
7001 let weight_data = vec![3.0f32, 5.0];
7007 let conv = Conv3d {
7008 weight: Parameter::from_slice(&weight_data, &[2, 1, 1, 1, 1]).unwrap(),
7009 bias: None,
7010 in_channels: 1,
7011 out_channels: 2,
7012 kernel_size: (1, 1, 1),
7013 stride: (1, 1, 1),
7014 padding: (0, 0, 0),
7015 dilation: (1, 1, 1),
7016 groups: 1,
7017 padding_mode: crate::padding::PaddingMode::Zeros,
7018 string_padding: None,
7019 training: false,
7020 };
7021
7022 let input = t(&[1.0, 2.0], &[1, 1, 2, 1, 1]);
7023 let output = conv.forward(&input).unwrap();
7024 assert_eq!(output.shape(), &[1, 2, 2, 1, 1]);
7025 assert_close(output.data().unwrap(), &[3.0, 6.0, 5.0, 10.0], 1e-5);
7026 }
7027
7028 #[test]
7033 fn test_conv3d_3x3x3_kernel_forward() {
7034 let input_data = vec![1.0f32; 27];
7037 let weight_data = vec![1.0f32; 27];
7038 let conv = Conv3d {
7039 weight: Parameter::from_slice(&weight_data, &[1, 1, 3, 3, 3]).unwrap(),
7040 bias: None,
7041 in_channels: 1,
7042 out_channels: 1,
7043 kernel_size: (3, 3, 3),
7044 stride: (1, 1, 1),
7045 padding: (0, 0, 0),
7046 dilation: (1, 1, 1),
7047 groups: 1,
7048 padding_mode: crate::padding::PaddingMode::Zeros,
7049 string_padding: None,
7050 training: false,
7051 };
7052
7053 let input = t(&input_data, &[1, 1, 3, 3, 3]);
7054 let output = conv.forward(&input).unwrap();
7055 assert_eq!(output.shape(), &[1, 1, 1, 1, 1]);
7056 assert_close(output.data().unwrap(), &[27.0], 1e-5);
7057 }
7058
7059 #[test]
7064 fn test_conv3d_bias() {
7065 let conv = Conv3d {
7066 weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1, 1, 1]).unwrap(),
7067 bias: Some(Parameter::from_slice(&[10.0f32], &[1]).unwrap()),
7068 in_channels: 1,
7069 out_channels: 1,
7070 kernel_size: (1, 1, 1),
7071 stride: (1, 1, 1),
7072 padding: (0, 0, 0),
7073 dilation: (1, 1, 1),
7074 groups: 1,
7075 padding_mode: crate::padding::PaddingMode::Zeros,
7076 string_padding: None,
7077 training: false,
7078 };
7079
7080 let input = t(&[2.0, 3.0], &[1, 1, 2, 1, 1]);
7081 let output = conv.forward(&input).unwrap();
7082 assert_close(output.data().unwrap(), &[12.0, 13.0], 1e-5);
7083 }
7084
7085 #[test]
7090 fn test_conv3d_backward_produces_correct_shapes() {
7091 let weight_data = vec![1.0f32; 2 * 3 * 3 * 3]; let input_data = vec![1.0f32; 5 * 5 * 5]; let bias_data = vec![0.0f32; 2];
7094
7095 let conv = Conv3d {
7096 weight: Parameter::from_slice(&weight_data, &[2, 1, 3, 3, 3]).unwrap(),
7097 bias: Some(Parameter::from_slice(&bias_data, &[2]).unwrap()),
7098 in_channels: 1,
7099 out_channels: 2,
7100 kernel_size: (3, 3, 3),
7101 stride: (1, 1, 1),
7102 padding: (0, 0, 0),
7103 dilation: (1, 1, 1),
7104 groups: 1,
7105 padding_mode: crate::padding::PaddingMode::Zeros,
7106 string_padding: None,
7107 training: false,
7108 };
7109
7110 let input = leaf(&input_data, &[1, 1, 5, 5, 5]);
7111 let output = conv.forward(&input).unwrap();
7112 assert_eq!(output.shape(), &[1, 2, 3, 3, 3]);
7113 assert!(output.grad_fn().is_some());
7114 assert_eq!(output.grad_fn().unwrap().name(), "Conv3dBackward");
7115
7116 let grad_output = t(&vec![1.0; 2 * 3 * 3 * 3], &[1, 2, 3, 3, 3]);
7117 let grads = output.grad_fn().unwrap().backward(&grad_output).unwrap();
7118
7119 assert!(grads[0].is_some());
7120 assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 1, 5, 5, 5]);
7121 assert!(grads[1].is_some());
7122 assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 3, 3, 3]);
7123 assert!(grads[2].is_some());
7124 assert_eq!(grads[2].as_ref().unwrap().shape(), &[2]);
7125 }
7126
7127 #[test]
7132 fn test_conv3d_invalid_ndim() {
7133 let conv = Conv3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
7134 let input = t(&[0.0; 25], &[1, 1, 5, 5]);
7135 assert!(conv.forward(&input).is_err());
7136 }
7137
7138 #[test]
7139 fn test_conv3d_channel_mismatch() {
7140 let conv = Conv3d::<f32>::new(3, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
7141 let input = t(&vec![0.0; 5 * 5 * 5], &[1, 1, 5, 5, 5]);
7142 assert!(conv.forward(&input).is_err());
7143 }
7144
7145 #[test]
7146 fn test_conv3d_zero_channels_rejected() {
7147 assert!(Conv3d::<f32>::new(0, 16, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
7148 assert!(Conv3d::<f32>::new(3, 0, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
7149 }
7150
7151 #[test]
7152 fn test_conv3d_zero_kernel_rejected() {
7153 assert!(Conv3d::<f32>::new(1, 1, (0, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
7154 }
7155
7156 #[test]
7157 fn test_conv3d_zero_stride_rejected() {
7158 assert!(Conv3d::<f32>::new(1, 1, (3, 3, 3), (0, 1, 1), (0, 0, 0), false).is_err());
7159 }
7160
7161 #[test]
7162 fn test_conv3d_parameter_count() {
7163 let conv = Conv3d::<f32>::new(3, 8, (3, 3, 3), (1, 1, 1), (0, 0, 0), true).unwrap();
7164 assert_eq!(conv.num_parameters(), 656);
7166 assert_eq!(conv.parameters().len(), 2);
7167 }
7168
7169 #[test]
7170 fn test_conv3d_named_parameters() {
7171 let conv = Conv3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), true).unwrap();
7172 let named = conv.named_parameters();
7173 assert_eq!(named.len(), 2);
7174 assert_eq!(named[0].0, "weight");
7175 assert_eq!(named[1].0, "bias");
7176 }
7177
7178 #[test]
7187 fn test_conv_transpose1d_output_shape_basic() {
7188 let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 1, 0, 0, false).unwrap();
7191 let input = t(&[0.0; 5], &[1, 1, 5]);
7192 let output = conv.forward(&input).unwrap();
7193 assert_eq!(output.shape(), &[1, 1, 7]);
7194 }
7195
7196 #[test]
7197 fn test_conv_transpose1d_output_shape_stride2() {
7198 let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 0, 0, false).unwrap();
7201 let input = t(&[0.0; 3], &[1, 1, 3]);
7202 let output = conv.forward(&input).unwrap();
7203 assert_eq!(output.shape(), &[1, 1, 7]);
7204 }
7205
7206 #[test]
7207 fn test_conv_transpose1d_output_shape_with_padding() {
7208 let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 1, 0, false).unwrap();
7211 let input = t(&[0.0; 5], &[1, 1, 5]);
7212 let output = conv.forward(&input).unwrap();
7213 assert_eq!(output.shape(), &[1, 1, 9]);
7214 }
7215
7216 #[test]
7217 fn test_conv_transpose1d_output_shape_with_output_padding() {
7218 let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 1, 1, false).unwrap();
7221 let input = t(&[0.0; 5], &[1, 1, 5]);
7222 let output = conv.forward(&input).unwrap();
7223 assert_eq!(output.shape(), &[1, 1, 10]);
7224 }
7225
7226 #[test]
7231 fn test_conv_transpose1d_1x1_kernel() {
7232 let weight_data = vec![3.0f32, 7.0]; let conv = ConvTranspose1d {
7237 weight: Parameter::from_slice(&weight_data, &[1, 2, 1]).unwrap(),
7238 bias: None,
7239 in_channels: 1,
7240 out_channels: 2,
7241 kernel_size: 1,
7242 stride: 1,
7243 padding: 0,
7244 output_padding: 0,
7245 dilation: 1,
7246 groups: 1,
7247 training: false,
7248 };
7249
7250 let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
7251 let output = conv.forward(&input).unwrap();
7252 assert_eq!(output.shape(), &[1, 2, 3]);
7253
7254 assert_close(
7257 output.data().unwrap(),
7258 &[3.0, 6.0, 9.0, 7.0, 14.0, 21.0],
7259 1e-5,
7260 );
7261 }
7262
7263 #[test]
7268 fn test_conv_transpose1d_stride2_correctness() {
7269 let weight_data = vec![1.0f32; 2]; let conv = ConvTranspose1d {
7285 weight: Parameter::from_slice(&weight_data, &[1, 1, 2]).unwrap(),
7286 bias: None,
7287 in_channels: 1,
7288 out_channels: 1,
7289 kernel_size: 2,
7290 stride: 2,
7291 padding: 0,
7292 output_padding: 0,
7293 dilation: 1,
7294 groups: 1,
7295 training: false,
7296 };
7297
7298 let input = t(&[1.0, 2.0], &[1, 1, 2]);
7299 let output = conv.forward(&input).unwrap();
7300 assert_eq!(output.shape(), &[1, 1, 4]);
7301 assert_close(output.data().unwrap(), &[1.0, 1.0, 2.0, 2.0], 1e-5);
7302 }
7303
7304 #[test]
7309 fn test_conv_transpose1d_bias() {
7310 let conv = ConvTranspose1d {
7311 weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1]).unwrap(),
7312 bias: Some(Parameter::from_slice(&[5.0f32], &[1]).unwrap()),
7313 in_channels: 1,
7314 out_channels: 1,
7315 kernel_size: 1,
7316 stride: 1,
7317 padding: 0,
7318 output_padding: 0,
7319 dilation: 1,
7320 groups: 1,
7321 training: false,
7322 };
7323
7324 let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
7325 let output = conv.forward(&input).unwrap();
7326 assert_close(output.data().unwrap(), &[6.0, 7.0, 8.0], 1e-5);
7327 }
7328
#[test]
// Forward through a leaf input must attach a `ConvTranspose1dBackward` node,
// and calling its `backward` must yield three gradients whose shapes match
// the input, the weight and the bias respectively.
fn test_conv_transpose1d_backward_produces_gradients() {
    let weight = Parameter::from_slice(&[1.0f32; 3], &[1, 1, 3]).unwrap();
    let bias = Parameter::from_slice(&[0.0f32; 1], &[1]).unwrap();
    let layer = ConvTranspose1d {
        weight,
        bias: Some(bias),
        in_channels: 1,
        out_channels: 1,
        kernel_size: 3,
        stride: 1,
        padding: 0,
        output_padding: 0,
        dilation: 1,
        groups: 1,
        training: false,
    };

    let signal = leaf(&[1.0f32, 2.0, 3.0], &[1, 1, 3]);
    let result = layer.forward(&signal).unwrap();
    assert_eq!(result.shape(), &[1, 1, 5]);
    assert!(result.grad_fn().is_some());
    let node = result.grad_fn().unwrap();
    assert_eq!(node.name(), "ConvTranspose1dBackward");

    let upstream = t(&[1.0; 5], &[1, 1, 5]);
    let grads = node.backward(&upstream).unwrap();

    // Slot 0: gradient with the input's shape.
    let input_grad = grads[0].as_ref();
    assert!(input_grad.is_some());
    assert_eq!(input_grad.unwrap().shape(), &[1, 1, 3]);

    // Slot 1: gradient with the weight's shape.
    let weight_grad = grads[1].as_ref();
    assert!(weight_grad.is_some());
    assert_eq!(weight_grad.unwrap().shape(), &[1, 1, 3]);

    // Slot 2: gradient with the bias's shape.
    let bias_grad = grads[2].as_ref();
    assert!(bias_grad.is_some());
    assert_eq!(bias_grad.unwrap().shape(), &[1]);
}
7372
#[test]
// A rank-4 tensor is not a valid input for the 1-D layer; forward must
// return an error rather than a tensor.
fn test_conv_transpose1d_invalid_ndim() {
    let layer: ConvTranspose1d<f32> = ConvTranspose1d::new(1, 1, 3, 1, 0, 0, false).unwrap();
    let rank4_input = t(&[0.0; 4], &[1, 1, 2, 2]);
    let outcome = layer.forward(&rank4_input);
    assert!(outcome.is_err());
}
7383
#[test]
// The layer is built for 3 input channels but fed a 1-channel tensor, so
// forward must fail.
fn test_conv_transpose1d_channel_mismatch() {
    let layer: ConvTranspose1d<f32> = ConvTranspose1d::new(3, 1, 3, 1, 0, 0, false).unwrap();
    let one_channel = t(&[0.0; 10], &[1, 1, 10]);
    let outcome = layer.forward(&one_channel);
    assert!(outcome.is_err());
}
7390
#[test]
// Construction must fail when either channel count is zero.
fn test_conv_transpose1d_zero_channels_rejected() {
    // First pair has zero input channels, second has zero output channels.
    for (in_c, out_c) in [(0, 1), (1, 0)] {
        let built = ConvTranspose1d::<f32>::new(in_c, out_c, 3, 1, 0, 0, false);
        assert!(built.is_err());
    }
}
7396
#[test]
// An output_padding of 2 combined with a stride of 2 must be rejected by the
// constructor.
fn test_conv_transpose1d_output_padding_too_large() {
    let (stride, output_padding) = (2, 2);
    let built = ConvTranspose1d::<f32>::new(1, 1, 3, stride, 0, output_padding, false);
    assert!(built.is_err());
}
7401
#[test]
// 8 * 16 * 5 weight scalars plus 16 bias scalars gives 656 parameters held in
// two parameter tensors.
fn test_conv_transpose1d_parameter_count() {
    let layer: ConvTranspose1d<f32> = ConvTranspose1d::new(8, 16, 5, 2, 1, 0, true).unwrap();
    let scalar_count = layer.num_parameters();
    assert_eq!(scalar_count, 656);
    let tensor_count = layer.parameters().len();
    assert_eq!(tensor_count, 2);
}
7409
#[test]
// Stride 1, no padding, 3x3x3 kernel: a 3x3x3 volume must come out as 5x5x5.
fn test_conv_transpose3d_output_shape_basic() {
    let (kernel, stride) = ((3, 3, 3), (1, 1, 1));
    let (padding, output_padding) = ((0, 0, 0), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(1, 1, kernel, stride, padding, output_padding, false).unwrap();

    let volume = t(&[0.0; 27], &[1, 1, 3, 3, 3]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 1, 5, 5, 5]);
}
7429
#[test]
// Stride 2, no padding, 3x3x3 kernel: a 2x2x2 volume must come out as 5x5x5.
fn test_conv_transpose3d_output_shape_stride2() {
    let (kernel, stride) = ((3, 3, 3), (2, 2, 2));
    let (padding, output_padding) = ((0, 0, 0), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(1, 1, kernel, stride, padding, output_padding, false).unwrap();

    let volume = t(&[0.0; 8], &[1, 1, 2, 2, 2]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 1, 5, 5, 5]);
}
7441
#[test]
// Stride 2 with padding 1 and a 3x3x3 kernel: a 3x3x3 volume must come out
// as 5x5x5.
fn test_conv_transpose3d_output_shape_with_padding() {
    let (kernel, stride) = ((3, 3, 3), (2, 2, 2));
    let (padding, output_padding) = ((1, 1, 1), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(1, 1, kernel, stride, padding, output_padding, false).unwrap();

    let volume = t(&[0.0; 27], &[1, 1, 3, 3, 3]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 1, 5, 5, 5]);
}
7453
#[test]
// Same configuration as the padded case but with output_padding 1, which
// must extend every spatial dimension from 5 to 6.
fn test_conv_transpose3d_output_shape_with_output_padding() {
    let (kernel, stride) = ((3, 3, 3), (2, 2, 2));
    let (padding, output_padding) = ((1, 1, 1), (1, 1, 1));
    let layer =
        ConvTranspose3d::<f32>::new(1, 1, kernel, stride, padding, output_padding, false).unwrap();

    let volume = t(&[0.0; 27], &[1, 1, 3, 3, 3]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 1, 6, 6, 6]);
}
7465
#[test]
// Kernel 2 with stride 2 must double every spatial dimension: 4x4x4 -> 8x8x8.
fn test_conv_transpose3d_stride2_upsamples() {
    let (kernel, stride) = ((2, 2, 2), (2, 2, 2));
    let (padding, output_padding) = ((0, 0, 0), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(1, 1, kernel, stride, padding, output_padding, false).unwrap();

    let voxels = vec![0.0f32; 4 * 4 * 4];
    let volume = t(&voxels, &[1, 1, 4, 4, 4]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 1, 8, 8, 8]);
}
7481
#[test]
// A 1x1x1 kernel reduces the 3-D transposed convolution to a per-output-channel
// scaling of the input.
fn test_conv_transpose3d_1x1x1_kernel() {
    // Weight is laid out as [in_channels, out_channels, kD, kH, kW].
    let weight = Parameter::from_slice(&[3.0f32, 7.0], &[1, 2, 1, 1, 1]).unwrap();
    let layer = ConvTranspose3d {
        weight,
        bias: None,
        in_channels: 1,
        out_channels: 2,
        kernel_size: (1, 1, 1),
        stride: (1, 1, 1),
        padding: (0, 0, 0),
        output_padding: (0, 0, 0),
        dilation: (1, 1, 1),
        groups: 1,
        training: false,
    };

    let volume = t(&[1.0, 2.0], &[1, 1, 2, 1, 1]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 2, 2, 1, 1]);

    // Channel 0 is the input scaled by 3, channel 1 the input scaled by 7.
    let expected = [3.0, 6.0, 7.0, 14.0];
    assert_close(result.data().unwrap(), &expected, 1e-5);
}
7509
#[test]
// Identity weight (1.0, 1x1x1 kernel) plus a bias of 5 must shift every input
// voxel by exactly 5.
fn test_conv_transpose3d_bias() {
    let weight = Parameter::from_slice(&[1.0f32], &[1, 1, 1, 1, 1]).unwrap();
    let bias = Parameter::from_slice(&[5.0f32], &[1]).unwrap();
    let layer = ConvTranspose3d {
        weight,
        bias: Some(bias),
        in_channels: 1,
        out_channels: 1,
        kernel_size: (1, 1, 1),
        stride: (1, 1, 1),
        padding: (0, 0, 0),
        output_padding: (0, 0, 0),
        dilation: (1, 1, 1),
        groups: 1,
        training: false,
    };

    let volume = t(&[1.0, 2.0], &[1, 1, 2, 1, 1]);
    let shifted = layer.forward(&volume).unwrap();
    let expected = [6.0, 7.0];
    assert_close(shifted.data().unwrap(), &expected, 1e-5);
}
7534
#[test]
// Forward through a leaf input must attach a `ConvTranspose3dBackward` node,
// and calling its `backward` must yield three gradients whose shapes match
// the input, the weight and the bias respectively.
fn test_conv_transpose3d_backward_produces_gradients() {
    let weight = Parameter::from_slice(&[1.0f32; 2 * 2 * 2], &[1, 1, 2, 2, 2]).unwrap();
    let bias = Parameter::from_slice(&[0.0f32; 1], &[1]).unwrap();
    let layer = ConvTranspose3d {
        weight,
        bias: Some(bias),
        in_channels: 1,
        out_channels: 1,
        kernel_size: (2, 2, 2),
        stride: (1, 1, 1),
        padding: (0, 0, 0),
        output_padding: (0, 0, 0),
        dilation: (1, 1, 1),
        groups: 1,
        training: false,
    };

    let volume = leaf(&[1.0f32; 8], &[1, 1, 2, 2, 2]);
    let result = layer.forward(&volume).unwrap();
    assert_eq!(result.shape(), &[1, 1, 3, 3, 3]);
    assert!(result.grad_fn().is_some());
    let node = result.grad_fn().unwrap();
    assert_eq!(node.name(), "ConvTranspose3dBackward");

    let upstream = t(&[1.0; 27], &[1, 1, 3, 3, 3]);
    let grads = node.backward(&upstream).unwrap();

    // Slot 0: gradient with the input's shape.
    let input_grad = grads[0].as_ref();
    assert!(input_grad.is_some());
    assert_eq!(input_grad.unwrap().shape(), &[1, 1, 2, 2, 2]);

    // Slot 1: gradient with the weight's shape.
    let weight_grad = grads[1].as_ref();
    assert!(weight_grad.is_some());
    assert_eq!(weight_grad.unwrap().shape(), &[1, 1, 2, 2, 2]);

    // Slot 2: gradient with the bias's shape.
    let bias_grad = grads[2].as_ref();
    assert!(bias_grad.is_some());
    assert_eq!(bias_grad.unwrap().shape(), &[1]);
}
7575
#[test]
// A rank-3 tensor is not a valid input for the 3-D layer; forward must
// return an error.
fn test_conv_transpose3d_invalid_ndim() {
    let (kernel, stride) = ((3, 3, 3), (1, 1, 1));
    let (padding, output_padding) = ((0, 0, 0), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(1, 1, kernel, stride, padding, output_padding, false).unwrap();

    let rank3_input = t(&[0.0; 25], &[1, 5, 5]);
    let outcome = layer.forward(&rank3_input);
    assert!(outcome.is_err());
}
7590
#[test]
// The layer is built for 3 input channels but fed a 1-channel volume, so
// forward must fail.
fn test_conv_transpose3d_channel_mismatch() {
    let (kernel, stride) = ((3, 3, 3), (1, 1, 1));
    let (padding, output_padding) = ((0, 0, 0), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(3, 1, kernel, stride, padding, output_padding, false).unwrap();

    let voxels = vec![0.0f32; 5 * 5 * 5];
    let one_channel = t(&voxels, &[1, 1, 5, 5, 5]);
    let outcome = layer.forward(&one_channel);
    assert!(outcome.is_err());
}
7599
#[test]
// Construction must fail when either channel count is zero.
fn test_conv_transpose3d_zero_channels_rejected() {
    // First pair has zero input channels, second has zero output channels.
    for (in_c, out_c) in [(0, 1), (1, 0)] {
        let built = ConvTranspose3d::<f32>::new(
            in_c,
            out_c,
            (3, 3, 3),
            (1, 1, 1),
            (0, 0, 0),
            (0, 0, 0),
            false,
        );
        assert!(built.is_err());
    }
}
7611
#[test]
// An output_padding of 2 per axis combined with a stride of 2 per axis must
// be rejected by the constructor.
fn test_conv_transpose3d_output_padding_too_large() {
    let stride = (2, 2, 2);
    let output_padding = (2, 2, 2);
    let built =
        ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), stride, (0, 0, 0), output_padding, false);
    assert!(built.is_err());
}
7619
#[test]
// 8 * 16 * 27 weight scalars plus 16 bias scalars gives 3472 parameters held
// in two parameter tensors.
fn test_conv_transpose3d_parameter_count() {
    let (kernel, stride) = ((3, 3, 3), (2, 2, 2));
    let (padding, output_padding) = ((1, 1, 1), (0, 0, 0));
    let layer =
        ConvTranspose3d::<f32>::new(8, 16, kernel, stride, padding, output_padding, true).unwrap();

    let scalar_count = layer.num_parameters();
    assert_eq!(scalar_count, 3472);
    let tensor_count = layer.parameters().len();
    assert_eq!(tensor_count, 2);
}
7629
7630 #[allow(clippy::too_many_arguments)]
7648 fn ct1d_full_fixed(
7649 in_c: usize,
7650 out_c: usize,
7651 k: usize,
7652 stride: usize,
7653 padding: usize,
7654 output_padding: usize,
7655 dilation: usize,
7656 groups: usize,
7657 weight: &[f32],
7658 bias: Option<&[f32]>,
7659 ) -> ConvTranspose1d<f32> {
7660 let mut ct = ConvTranspose1d::<f32>::new_full(
7661 in_c,
7662 out_c,
7663 k,
7664 stride,
7665 padding,
7666 output_padding,
7667 dilation,
7668 groups,
7669 bias.is_some(),
7670 )
7671 .unwrap();
7672 ct.weight = Parameter::from_slice(weight, &[in_c, out_c / groups, k]).unwrap();
7673 if let Some(bvals) = bias {
7674 ct.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
7675 }
7676 ct
7677 }
7678
#[test]
// groups = 2 with 4 -> 4 channels and kernel 2: forward output plus the
// input, weight and bias gradients are compared against fixed reference
// values (PyTorch-derived, per the test name).
fn test_conv_transpose1d_groups2_matches_torch() {
    let weight: Vec<f32> = (1u8..=16).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.5f32, -0.5, 0.25, -0.25];
    let layer = ct1d_full_fixed(4, 4, 2, 1, 0, 0, 1, 2, &weight, Some(bias.as_slice()));

    let x_vals: Vec<f32> = (1u8..=20).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 4, 5]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 4, 6]);
    let expected_y = [
        3.6, 8.0, 9.4, 10.8, 12.2, 7.5, 4.0, 10.2, 12.4, 14.6, 16.8, 9.5, 30.95, 66.55, 71.15,
        75.75, 80.35, 43.25, 35.85, 77.25, 82.65, 88.05, 93.45, 49.75,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 24], &[1, 4, 6]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    // Input gradient: one constant per input channel, repeated over length 5.
    let expected_gx = [[1.0f32; 5], [2.6; 5], [4.2; 5], [5.8; 5]].concat();
    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), expected_gx.as_slice(), 1e-4);

    // Weight gradient: one constant per input channel over its [2, 2] slab.
    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[4, 2, 2]);
    let expected_gw = [[15.0f32; 4], [40.0; 4], [65.0; 4], [90.0; 4]].concat();
    assert_close(gw.data().unwrap(), expected_gw.as_slice(), 1e-4);

    // Bias gradient: the output length (6) for each output channel.
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[6.0, 6.0, 6.0, 6.0], 1e-4);
}
7727
#[test]
// Depthwise case (groups == in_channels == out_channels == 3) without bias:
// forward output and input/weight gradients are compared against fixed
// reference values (PyTorch-derived, per the test name).
fn test_conv_transpose1d_groups3_depthwise_matches_torch() {
    let weight: Vec<f32> = (1u8..=6).map(|i| f32::from(i) * 0.5).collect();
    let layer = ct1d_full_fixed(3, 3, 2, 1, 0, 0, 1, 3, &weight, None);

    let x_vals: Vec<f32> = (1u8..=15).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 3, 5]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 3, 6]);
    let expected_y = [
        0.5, 2.0, 3.5, 5.0, 6.5, 5.0, 9.0, 22.5, 26.0, 29.5, 33.0, 20.0, 27.5, 63.0, 68.5, 74.0,
        79.5, 45.0,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 18], &[1, 3, 6]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    // Input gradient: one constant per channel, repeated over length 5.
    let expected_gx = [[1.5f32; 5], [3.5; 5], [5.5; 5]].concat();
    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), expected_gx.as_slice(), 1e-4);

    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[3, 1, 2]);
    let expected_gw = [15.0, 15.0, 40.0, 40.0, 65.0, 65.0];
    assert_close(gw.data().unwrap(), &expected_gw, 1e-4);
}
7765
#[test]
// dilation = 2 with a 3-tap kernel and 2 -> 2 channels: forward output plus
// all three gradients are compared against fixed reference values
// (PyTorch-derived, per the test name).
fn test_conv_transpose1d_dilation2_matches_torch() {
    let weight: Vec<f32> = (1u8..=12).map(|i| f32::from(i) * 0.1).collect();
    let bias = [1.0f32, -1.0];
    let layer = ct1d_full_fixed(2, 2, 3, 1, 0, 0, 2, 1, &weight, Some(bias.as_slice()));

    let x_vals: Vec<f32> = (1u8..=8).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 2, 4]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 8]);
    let expected_y = [
        4.6, 5.4, 10.4, 12.2, 12.0, 14.2, 8.2, 9.4, 4.4, 5.8, 13.2, 16.2, 14.8, 18.2, 9.2, 11.0,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 16], &[1, 2, 8]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    // Input gradient: one constant per input channel, repeated over length 4.
    let expected_gx = [[2.1f32; 4], [5.7; 4]].concat();
    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), expected_gx.as_slice(), 1e-4);

    // Weight gradient: one constant per input channel over its [2, 3] slab.
    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[2, 2, 3]);
    let expected_gw = [[10.0f32; 6], [26.0; 6]].concat();
    assert_close(gw.data().unwrap(), expected_gw.as_slice(), 1e-4);

    // Bias gradient: the output length (8) for each output channel.
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[8.0, 8.0], 1e-4);
}
7809
#[test]
// Combined configuration (stride 2, padding 1, output_padding 1, dilation 2,
// groups 2): forward output plus all three gradients are compared against
// fixed reference values (PyTorch-derived, per the test name).
fn test_conv_transpose1d_combo_matches_torch() {
    let weight: Vec<f32> = (1u8..=8).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.5f32, -0.5];
    let layer = ct1d_full_fixed(4, 2, 2, 2, 1, 1, 2, 2, &weight, Some(bias.as_slice()));

    let x_vals: Vec<f32> = (1u8..=12).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 4, 3]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 6]);
    let expected_y = [
        0.5, 4.0, 0.5, 5.0, 0.5, 3.5, -0.5, 23.4, -0.5, 26.0, -0.5, 14.5,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 12], &[1, 2, 6]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    let gx = grads[0].as_ref().unwrap();
    let expected_gx = [0.2, 0.3, 0.3, 0.4, 0.7, 0.7, 0.6, 1.1, 1.1, 0.8, 1.5, 1.5];
    assert_close(gx.data().unwrap(), &expected_gx, 1e-4);

    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[4, 1, 2]);
    let expected_gw = [5.0, 6.0, 11.0, 15.0, 17.0, 24.0, 23.0, 33.0];
    assert_close(gw.data().unwrap(), &expected_gw, 1e-4);

    // Bias gradient: the output length (6) for each output channel.
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[6.0, 6.0], 1e-4);
}
7851
#[test]
// A rank-2 (channels, length) input must produce a rank-2 output, and the
// gradient reaching the leaf through the full autograd pass must keep the
// unbatched input shape.
fn test_conv_transpose1d_unbatched_matches_torch() {
    let weight: Vec<f32> = (1u8..=12).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.5f32, -0.5, 0.25];
    let layer = ct1d_full_fixed(2, 3, 2, 1, 0, 0, 1, 1, &weight, Some(bias.as_slice()));

    // No batch axis on the input.
    let x_vals: Vec<f32> = (1u8..=6).map(f32::from).collect();
    let x = leaf(&x_vals, &[2, 3]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(
        y.shape(),
        &[3, 4],
        "unbatched output must be rank 2 (C_out, L_out)"
    );
    let expected_y = [
        3.4, 7.6, 9.4, 5.9, 3.4, 9.0, 11.6, 6.7, 5.15, 12.15, 15.55, 9.25,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Reduce to a scalar and run the whole backward pass so the gradient is
    // accumulated on the leaf itself.
    let total = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
    ferrotorch_core::backward(&total).unwrap();

    let x_grad = x.grad().unwrap().expect("input grad must be populated");
    assert_eq!(
        x_grad.shape(),
        &[2, 3],
        "grad must match unbatched input shape"
    );
    let expected_gx = [2.1, 2.1, 2.1, 5.7, 5.7, 5.7];
    assert_close(x_grad.data().unwrap(), &expected_gx, 1e-4);
}
7882
#[test]
// With groups = 2, construction must fail when either channel count is odd.
fn test_conv_transpose1d_groups_must_divide_channels() {
    // (3, 4): in_channels not divisible; (4, 5): out_channels not divisible.
    for (in_c, out_c) in [(3, 4), (4, 5)] {
        let built = ConvTranspose1d::<f32>::new_full(in_c, out_c, 2, 1, 0, 0, 1, 2, true);
        assert!(built.is_err());
    }
}
7890
7891 #[allow(clippy::too_many_arguments)]
7896 fn ct2d_full_fixed(
7897 in_c: usize,
7898 out_c: usize,
7899 k: (usize, usize),
7900 stride: (usize, usize),
7901 padding: (usize, usize),
7902 output_padding: (usize, usize),
7903 dilation: (usize, usize),
7904 groups: usize,
7905 weight: &[f32],
7906 bias: Option<&[f32]>,
7907 ) -> ConvTranspose2d<f32> {
7908 let mut ct = ConvTranspose2d::<f32>::new_full(
7909 in_c,
7910 out_c,
7911 k,
7912 stride,
7913 padding,
7914 output_padding,
7915 dilation,
7916 groups,
7917 bias.is_some(),
7918 )
7919 .unwrap();
7920 ct.weight = Parameter::from_slice(weight, &[in_c, out_c / groups, k.0, k.1]).unwrap();
7921 if let Some(bvals) = bias {
7922 ct.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
7923 }
7924 ct
7925 }
7926
#[test]
// groups = 2 with 4 -> 2 channels and a 2x2 kernel: forward output plus all
// three gradients are compared against fixed reference values
// (PyTorch-derived, per the test name).
fn test_conv_transpose2d_groups2_matches_torch() {
    let weight: Vec<f32> = (1u8..=16).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.5f32, -0.5];
    let (kernel, stride, dilation) = ((2, 2), (1, 1), (1, 1));
    let (padding, output_padding) = ((0, 0), (0, 0));
    let layer = ct2d_full_fixed(
        4,
        2,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        2,
        &weight,
        Some(bias.as_slice()),
    );

    let x_vals: Vec<f32> = (1u8..=16).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 4, 2, 2]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 3, 3]);
    let expected_y = [
        3.1, 6.9, 4.5, 8.1, 18.9, 11.7, 6.3, 14.1, 8.5, 24.5, 53.9, 29.1, 58.3, 126.7, 68.3,
        34.1, 73.9, 39.5,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 18], &[1, 2, 3, 3]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    // Input gradient: one constant per input channel over its 2x2 plane.
    let expected_gx = [[1.0f32; 4], [2.6; 4], [4.2; 4], [5.8; 4]].concat();
    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), expected_gx.as_slice(), 1e-4);

    // Weight gradient: one constant per input channel over its 2x2 kernel.
    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[4, 1, 2, 2]);
    let expected_gw = [[10.0f32; 4], [26.0; 4], [42.0; 4], [58.0; 4]].concat();
    assert_close(gw.data().unwrap(), expected_gw.as_slice(), 1e-4);

    // Bias gradient: the number of output pixels (9) per output channel.
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[9.0, 9.0], 1e-4);
}
7987
#[test]
// dilation = (2, 2) with a 2x2 kernel, single channel, no bias: forward
// output and input/weight gradients are compared against fixed reference
// values (PyTorch-derived, per the test name).
fn test_conv_transpose2d_dilation2_matches_torch() {
    let weight: Vec<f32> = (1u8..=4).map(|i| f32::from(i) * 0.1).collect();
    let (kernel, stride, dilation) = ((2, 2), (1, 1), (2, 2));
    let (padding, output_padding) = ((0, 0), (0, 0));
    let layer = ct2d_full_fixed(
        1,
        1,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        1,
        &weight,
        None,
    );

    let x_vals: Vec<f32> = (1u8..=9).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 1, 3, 3]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 1, 5, 5]);
    let expected_y = [
        0.1, 0.2, 0.5, 0.4, 0.6, 0.4, 0.5, 1.4, 1.0, 1.2, 1.0, 1.4, 3.6, 2.4, 3.0, 1.2, 1.5, 3.4,
        2.0, 2.4, 2.1, 2.4, 5.5, 3.2, 3.6,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 25], &[1, 1, 5, 5]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    // Every input pixel receives the same gradient of 1.0.
    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), &[1.0f32; 9], 1e-4);

    // Every kernel tap receives 45.0, the sum of the inputs 1..=9.
    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[1, 1, 2, 2]);
    assert_close(gw.data().unwrap(), &[45.0f32; 4], 1e-4);
}
8037
#[test]
// Combined configuration (stride 2, padding 1, output_padding 1, dilation 2,
// groups 2): forward output plus all three gradients are compared against
// fixed reference values (PyTorch-derived, per the test name).
fn test_conv_transpose2d_combo_matches_torch() {
    let weight: Vec<f32> = (1u8..=8).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.25f32, -0.25];
    let (kernel, stride, dilation) = ((2, 2), (2, 2), (2, 2));
    let (padding, output_padding) = ((1, 1), (1, 1));
    let layer = ct2d_full_fixed(
        2,
        2,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        2,
        &weight,
        Some(bias.as_slice()),
    );

    let x_vals: Vec<f32> = (1u8..=8).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 2, 2, 2]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 4, 4]);
    let expected_y = [
        0.25, 0.25, 0.25, 0.25, 0.25, 2.25, 0.25, 1.85, 0.25, 0.25, 0.25, 0.25, 0.25, 2.65, 0.25,
        1.85, -0.25, -0.25, -0.25, -0.25, -0.25, 16.15, -0.25, 9.35, -0.25, -0.25, -0.25, -0.25,
        -0.25, 10.95, -0.25, 6.15,
    ];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 32], &[1, 2, 4, 4]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    let gx = grads[0].as_ref().unwrap();
    let expected_gx = [0.4, 0.7, 0.6, 1.0, 0.8, 1.5, 1.4, 2.6];
    assert_close(gx.data().unwrap(), &expected_gx, 1e-4);

    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[2, 1, 2, 2]);
    let expected_gw = [4.0, 7.0, 6.0, 10.0, 8.0, 15.0, 14.0, 26.0];
    assert_close(gw.data().unwrap(), &expected_gw, 1e-4);

    // Bias gradient: the number of output pixels (16) per output channel.
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[16.0, 16.0], 1e-4);
}
8095
#[test]
// A rank-3 (channels, height, width) input must produce a rank-3 output, and
// the gradient reaching the leaf through the full autograd pass must keep the
// unbatched input shape.
fn test_conv_transpose2d_unbatched_matches_torch() {
    let weight: Vec<f32> = (1u8..=8).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.5f32];
    let (kernel, stride, dilation) = ((2, 2), (1, 1), (1, 1));
    let (padding, output_padding) = ((0, 0), (0, 0));
    let layer = ct2d_full_fixed(
        2,
        1,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        1,
        &weight,
        Some(bias.as_slice()),
    );

    // No batch axis on the input.
    let x_vals: Vec<f32> = (1u8..=8).map(f32::from).collect();
    let x = leaf(&x_vals, &[2, 2, 2]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 3, 3], "unbatched output must be rank 3");
    let expected_y = [3.1, 6.9, 4.5, 8.1, 18.9, 11.7, 6.3, 14.1, 8.5];
    assert_close(y.data().unwrap(), &expected_y, 1e-3);

    // Reduce to a scalar and run the whole backward pass so the gradient is
    // accumulated on the leaf itself.
    let total = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
    ferrotorch_core::backward(&total).unwrap();

    let x_grad = x.grad().unwrap().expect("input grad must be populated");
    assert_eq!(
        x_grad.shape(),
        &[2, 2, 2],
        "grad must match unbatched input shape"
    );
    // One constant per input channel over its 2x2 plane.
    let expected_gx = [[1.0f32; 4], [2.6; 4]].concat();
    assert_close(x_grad.data().unwrap(), expected_gx.as_slice(), 1e-4);
}
8135
8136 #[allow(clippy::too_many_arguments)]
8141 fn ct3d_full_fixed(
8142 in_c: usize,
8143 out_c: usize,
8144 k: (usize, usize, usize),
8145 stride: (usize, usize, usize),
8146 padding: (usize, usize, usize),
8147 output_padding: (usize, usize, usize),
8148 dilation: (usize, usize, usize),
8149 groups: usize,
8150 weight: &[f32],
8151 bias: Option<&[f32]>,
8152 ) -> ConvTranspose3d<f32> {
8153 let mut ct = ConvTranspose3d::<f32>::new_full(
8154 in_c,
8155 out_c,
8156 k,
8157 stride,
8158 padding,
8159 output_padding,
8160 dilation,
8161 groups,
8162 bias.is_some(),
8163 )
8164 .unwrap();
8165 ct.weight = Parameter::from_slice(weight, &[in_c, out_c / groups, k.0, k.1, k.2]).unwrap();
8166 if let Some(bvals) = bias {
8167 ct.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
8168 }
8169 ct
8170 }
8171
#[test]
// groups = 2 with a 1x1x1 kernel: each output channel is its own input
// channel scaled by one weight plus a bias. Forward output and all three
// gradients (input, weight, bias) are checked against fixed reference values
// (PyTorch-derived, per the test name).
fn test_conv_transpose3d_groups2_matches_torch() {
    // Weights are [0.5, 1.0], one scalar per group.
    let weight: Vec<f32> = (1..=2).map(|i| i as f32 * 0.5).collect();
    let bias = [0.5f32, -0.5];
    let ct = ct3d_full_fixed(
        2,
        2,
        (1, 1, 1),
        (1, 1, 1),
        (0, 0, 0),
        (0, 0, 0),
        (1, 1, 1),
        2,
        &weight,
        Some(&bias),
    );
    let x = leaf(
        &(1..=16).map(|i| i as f32).collect::<Vec<_>>(),
        &[1, 2, 2, 2, 2],
    );
    let y = ct.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 2, 2, 2]);
    // Channel 0: 0.5 * x + 0.5 over inputs 1..=8; channel 1: 1.0 * x - 0.5
    // over inputs 9..=16.
    assert_close(
        y.data().unwrap(),
        &[
            1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5,
            15.5,
        ],
        1e-3,
    );
    let grads = ct
        .forward(&x)
        .unwrap()
        .grad_fn()
        .unwrap()
        .backward(&t(&[1.0f32; 16], &[1, 2, 2, 2, 2]))
        .unwrap();
    // Input gradient: with an all-ones upstream gradient and a 1x1x1 kernel,
    // every voxel of a channel receives that channel's weight (0.5, then 1.0).
    // The sibling `*_matches_torch` tests all check this slot as well.
    assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 2, 2, 2, 2]);
    assert_close(
        grads[0].as_ref().unwrap().data().unwrap(),
        &[
            0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ],
        1e-4,
    );
    // Weight gradient: the sum of each channel's inputs (1..=8 and 9..=16).
    assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 1, 1, 1]);
    assert_close(
        grads[1].as_ref().unwrap().data().unwrap(),
        &[36.0, 100.0],
        1e-4,
    );
    // Bias gradient: the number of output voxels (8) per output channel.
    assert_close(
        grads[2].as_ref().unwrap().data().unwrap(),
        &[8.0, 8.0],
        1e-4,
    );
}
8222
#[test]
// dilation = (2, 2, 2) with a 2x2x2 kernel, single channel, no bias. Only the
// first and last eight output values are spot-checked; input and weight
// gradients are checked in full.
fn test_conv_transpose3d_dilation2_matches_torch() {
    let weight: Vec<f32> = (1u8..=8).map(|i| f32::from(i) * 0.1).collect();
    let (kernel, stride, dilation) = ((2, 2, 2), (1, 1, 1), (2, 2, 2));
    let (padding, output_padding) = ((0, 0, 0), (0, 0, 0));
    let layer = ct3d_full_fixed(
        1,
        1,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        1,
        &weight,
        None,
    );

    let x_vals: Vec<f32> = (1u8..=8).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 1, 2, 2, 2]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 1, 4, 4, 4]);

    let flat = y.data().unwrap();
    let head = [0.1, 0.2, 0.2, 0.4, 0.3, 0.4, 0.6, 0.8];
    assert_close(&flat[..8], &head, 1e-3);
    let tail = [3.5, 4.2, 4.0, 4.8, 4.9, 5.6, 5.6, 6.4];
    assert_close(&flat[56..64], &tail, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 64], &[1, 1, 4, 4, 4]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    // Every input voxel receives 3.6, the sum of the eight weights.
    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), &[3.6f32; 8], 1e-4);

    // Every kernel tap receives 36.0, the sum of the inputs 1..=8.
    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[1, 1, 2, 2, 2]);
    assert_close(gw.data().unwrap(), &[36.0f32; 8], 1e-4);
}
8268
#[test]
// Combined configuration (stride 2, output_padding 1, groups 2) on a single
// voxel per channel. The first nine values of each output channel are
// spot-checked; all three gradients are checked in full.
fn test_conv_transpose3d_combo_matches_torch() {
    let weight: Vec<f32> = (1u8..=16).map(|i| f32::from(i) * 0.05).collect();
    let bias = [0.1f32, -0.1];
    let (kernel, stride, dilation) = ((2, 2, 2), (2, 2, 2), (1, 1, 1));
    let (padding, output_padding) = ((0, 0, 0), (1, 1, 1));
    let layer = ct3d_full_fixed(
        2,
        2,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        2,
        &weight,
        Some(bias.as_slice()),
    );

    let x_vals: Vec<f32> = (1u8..=2).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 2, 1, 1, 1]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 3, 3, 3]);

    let flat = y.data().unwrap();
    // Start of output channel 0.
    let channel0_head = [0.15, 0.2, 0.1, 0.25, 0.3, 0.1, 0.1, 0.1, 0.1];
    assert_close(&flat[..9], &channel0_head, 1e-3);
    // Start of output channel 1 (offset 27 = 3 * 3 * 3).
    let channel1_head = [0.8, 0.9, -0.1, 1.0, 1.1, -0.1, -0.1, -0.1, -0.1];
    assert_close(&flat[27..36], &channel1_head, 1e-3);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 54], &[1, 2, 3, 3, 3]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    let gx = grads[0].as_ref().unwrap();
    assert_close(gx.data().unwrap(), &[1.8, 5.0], 1e-4);

    // Weight gradient: each channel's single input value over its 8 taps.
    let gw = grads[1].as_ref().unwrap();
    assert_eq!(gw.shape(), &[2, 1, 2, 2, 2]);
    let expected_gw = [[1.0f32; 8], [2.0; 8]].concat();
    assert_close(gw.data().unwrap(), expected_gw.as_slice(), 1e-4);

    // Bias gradient: the number of output voxels (27) per output channel.
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[27.0, 27.0], 1e-4);
}
8331
#[test]
// Mixed per-axis dilation (2, 3, 2) with stride 2, padding 1 and
// output_padding 1 on a (2, 2, 1) kernel. The full 4x5x2 output and all three
// gradients are compared against fixed reference values (PyTorch-derived,
// per the test name).
fn test_conv_transpose3d_dilated_output_padding_negative_internal_pad_matches_torch() {
    let weight: Vec<f32> = (1u8..=4).map(|i| f32::from(i) * 0.1).collect();
    let bias = [0.5f32];
    let (kernel, stride, dilation) = ((2, 2, 1), (2, 2, 2), (2, 3, 2));
    let (padding, output_padding) = ((1, 1, 1), (1, 1, 1));
    let layer = ct3d_full_fixed(
        1,
        1,
        kernel,
        stride,
        padding,
        output_padding,
        dilation,
        1,
        &weight,
        Some(bias.as_slice()),
    );

    let x_vals: Vec<f32> = (1u8..=8).map(f32::from).collect();
    let x = leaf(&x_vals, &[1, 1, 2, 2, 2]);
    let y = layer.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 1, 4, 5, 2]);

    // One row per depth slice (5 * 2 = 10 values each); positions no kernel
    // tap reaches hold just the bias, 0.5.
    #[rustfmt::skip]
    let oracle: [f32; 40] = [
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 2.5, 0.5, 2.5, 0.5, 0.5, 0.5, 3.7,
        0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
        0.5, 0.5, 0.5, 2.9, 0.5, 2.9, 0.5, 0.5, 0.5, 3.7,
    ];
    assert_close(y.data().unwrap(), &oracle, 1e-4);

    // Drive the backward node directly with an all-ones upstream gradient.
    let upstream = t(&[1.0f32; 40], &[1, 1, 4, 5, 2]);
    let replay = layer.forward(&x).unwrap();
    let grads = replay.grad_fn().unwrap().backward(&upstream).unwrap();

    let gx = grads[0].as_ref().unwrap();
    let expected_gx = [0.0, 0.4, 0.0, 0.7, 0.0, 0.6, 0.0, 1.0];
    assert_close(gx.data().unwrap(), &expected_gx, 1e-4);

    let gw = grads[1].as_ref().unwrap();
    let expected_gw = [8.0, 14.0, 12.0, 20.0];
    assert_close(gw.data().unwrap(), &expected_gw, 1e-4);

    // Bias gradient: the number of output voxels (4 * 5 * 2 = 40).
    let gb = grads[2].as_ref().unwrap();
    assert_close(gb.data().unwrap(), &[40.0], 1e-4);
}
8398
#[test]
// A rank-4 (channels, depth, height, width) input must produce a rank-4
// output, and the gradient reaching the leaf through the full autograd pass
// must keep the unbatched input shape. Forward values are checked as well,
// matching what the 1-D and 2-D unbatched tests do.
fn test_conv_transpose3d_unbatched_matches_torch() {
    // Weights are [0.5, 1.0]: one scalar per input channel (1x1x1 kernel).
    let weight: Vec<f32> = (1..=2).map(|i| i as f32 * 0.5).collect();
    let bias = [1.0f32];
    let ct = ct3d_full_fixed(
        2,
        1,
        (1, 1, 1),
        (1, 1, 1),
        (0, 0, 0),
        (0, 0, 0),
        (1, 1, 1),
        1,
        &weight,
        Some(&bias),
    );
    // No batch axis on the input.
    let x = leaf(
        &(1..=16).map(|i| i as f32).collect::<Vec<_>>(),
        &[2, 2, 2, 2],
    );
    let y = ct.forward(&x).unwrap();
    assert_eq!(y.shape(), &[1, 2, 2, 2], "unbatched output must be rank 4");
    // With a 1x1x1 kernel each output voxel is 0.5 * x0 + 1.0 * x1 + 1.0,
    // where x0 runs over 1..=8 and x1 over 9..=16.
    assert_close(
        y.data().unwrap(),
        &[10.5, 12.0, 13.5, 15.0, 16.5, 18.0, 19.5, 21.0],
        1e-3,
    );
    // Reduce to a scalar and run the whole backward pass so the gradient is
    // accumulated on the leaf itself.
    let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
    ferrotorch_core::backward(&sum).unwrap();
    let gx = x.grad().unwrap().expect("input grad must be populated");
    assert_eq!(
        gx.shape(),
        &[2, 2, 2, 2],
        "grad must match unbatched input shape"
    );
    // Every voxel of an input channel receives that channel's weight.
    assert_close(
        gx.data().unwrap(),
        &[
            0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ],
        1e-4,
    );
}
8441
#[test]
// With groups = 2, construction must fail when either channel count is not
// divisible by the group count. The 3-D layer is checked for both
// in_channels and out_channels; the 2-D out_channels check is kept as well.
fn test_conv_transpose3d_groups_must_divide_channels() {
    // 3-D: in_channels = 3 is not divisible by groups = 2.
    assert!(
        ConvTranspose3d::<f32>::new_full(
            3,
            4,
            (1, 1, 1),
            (1, 1, 1),
            (0, 0, 0),
            (0, 0, 0),
            (1, 1, 1),
            2,
            true
        )
        .is_err()
    );
    // 3-D: out_channels = 5 is not divisible by groups = 2. This case was
    // previously exercised only through the 2-D layer.
    assert!(
        ConvTranspose3d::<f32>::new_full(
            4,
            5,
            (1, 1, 1),
            (1, 1, 1),
            (0, 0, 0),
            (0, 0, 0),
            (1, 1, 1),
            2,
            true
        )
        .is_err()
    );
    // 2-D: out_channels = 5 is not divisible by groups = 2.
    assert!(
        ConvTranspose2d::<f32>::new_full(4, 5, (1, 1), (1, 1), (0, 0), (0, 0), (1, 1), 2, true)
            .is_err()
    );
}
8464
8465 fn conv1d_fixed(
8481 weight: &[f32],
8482 wshape: &[usize],
8483 bias: &[f32],
8484 kernel: usize,
8485 padding: usize,
8486 mode: crate::padding::PaddingMode,
8487 ) -> Conv1d<f32> {
8488 let w = Parameter::from_slice(weight, wshape).unwrap();
8489 let b = Parameter::from_slice(bias, &[wshape[0]]).unwrap();
8490 Conv1d {
8491 weight: w,
8492 bias: Some(b),
8493 in_channels: wshape[1],
8494 out_channels: wshape[0],
8495 kernel_size: kernel,
8496 stride: 1,
8497 padding,
8498 dilation: 1,
8499 groups: 1,
8500 padding_mode: mode,
8501 string_padding: None,
8502 training: false,
8503 }
8504 }
8505
#[test]
// Reflect padding of 1 with kernel [1, 2, 3] and bias 0.5: the output keeps
// the input length and matches fixed reference values (PyTorch-derived, per
// the test name).
fn test_conv1d_reflect_forward_matches_torch() {
    let kernel = [1.0, 2.0, 3.0];
    let mode = crate::padding::PaddingMode::Reflect;
    let layer = conv1d_fixed(&kernel, &[1, 1, 3], &[0.5], 3, 1, mode);

    let signal = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
    let result = layer.forward(&signal).unwrap();
    assert_eq!(result.shape(), &[1, 1, 5]);

    let expected = [10.5, 14.5, 20.5, 26.5, 26.5];
    assert_close(result.data().unwrap(), &expected, 1e-4);
}
8524
#[test]
// Replicate padding of 1 with kernel [1, 2, 3] and bias 0.5: only the two
// edge outputs differ from the reflect case.
fn test_conv1d_replicate_forward_matches_torch() {
    let kernel = [1.0, 2.0, 3.0];
    let mode = crate::padding::PaddingMode::Replicate;
    let layer = conv1d_fixed(&kernel, &[1, 1, 3], &[0.5], 3, 1, mode);

    let signal = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
    let result = layer.forward(&signal).unwrap();

    let expected = [9.5, 14.5, 20.5, 26.5, 29.5];
    assert_close(result.data().unwrap(), &expected, 1e-4);
}
8541
#[test]
// Circular padding of 1 with kernel [1, 2, 3] and bias 0.5: the two edge
// outputs wrap around to the opposite end of the signal.
fn test_conv1d_circular_forward_matches_torch() {
    let kernel = [1.0, 2.0, 3.0];
    let mode = crate::padding::PaddingMode::Circular;
    let layer = conv1d_fixed(&kernel, &[1, 1, 3], &[0.5], 3, 1, mode);

    let signal = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
    let result = layer.forward(&signal).unwrap();

    let expected = [13.5, 14.5, 20.5, 26.5, 17.5];
    assert_close(result.data().unwrap(), &expected, 1e-4);
}
8558
#[test]
// With reflect padding the output must still carry a grad_fn, and the full
// backward pass must deliver the expected gradient to the leaf input, i.e.
// the padding step must not cut the autograd graph.
fn test_conv1d_reflect_backward_input_grad_matches_torch() {
    let kernel = [1.0, 2.0, 3.0];
    let mode = crate::padding::PaddingMode::Reflect;
    let layer = conv1d_fixed(&kernel, &[1, 1, 3], &[0.5], 3, 1, mode);

    let signal = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
    let result = layer.forward(&signal).unwrap();
    assert!(
        result.grad_fn().is_some(),
        "Conv1d reflect output lost its grad_fn — pre-pad severed autograd (#1550 class)"
    );

    // Reduce to a scalar and backpropagate through the whole graph.
    let total = ferrotorch_core::grad_fns::reduction::sum(&result).unwrap();
    ferrotorch_core::backward(&total).unwrap();

    let signal_grad = signal
        .grad()
        .unwrap()
        .expect("input grad must be populated — pre-pad must be autograd-aware");
    let expected = [3.0, 7.0, 6.0, 9.0, 5.0];
    assert_close(signal_grad.data().unwrap(), &expected, 1e-4);
}
8589
#[test]
// With circular padding every input sample is touched by all three kernel
// taps, so each receives the same gradient: 1 + 2 + 3 = 6.
fn test_conv1d_circular_backward_input_grad_matches_torch() {
    let kernel = [1.0, 2.0, 3.0];
    let mode = crate::padding::PaddingMode::Circular;
    let layer = conv1d_fixed(&kernel, &[1, 1, 3], &[0.5], 3, 1, mode);

    let signal = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
    let result = layer.forward(&signal).unwrap();
    assert!(result.grad_fn().is_some());

    // Reduce to a scalar and backpropagate through the whole graph.
    let total = ferrotorch_core::grad_fns::reduction::sum(&result).unwrap();
    ferrotorch_core::backward(&total).unwrap();

    let signal_grad = signal.grad().unwrap().expect("input grad must be populated");
    let expected = [6.0, 6.0, 6.0, 6.0, 6.0];
    assert_close(signal_grad.data().unwrap(), &expected, 1e-4);
}
8609
8610 fn conv1d_full_fixed(
8619 in_c: usize,
8620 out_c: usize,
8621 k: usize,
8622 dilation: usize,
8623 groups: usize,
8624 weight: &[f32],
8625 bias: Option<&[f32]>,
8626 ) -> Conv1d<f32> {
8627 let mut conv =
8628 Conv1d::<f32>::new_full(in_c, out_c, k, 1, 0, dilation, groups, bias.is_some())
8629 .unwrap();
8630 conv.weight = Parameter::from_slice(weight, &[out_c, in_c / groups, k]).unwrap();
8633 if let Some(bvals) = bias {
8634 conv.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
8635 }
8636 conv
8637 }
8638
8639 #[test]
8643 fn test_conv1d_groups2_forward_and_backward_matches_torch() {
8644 let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.1).collect();
8646 let bias = [0.5f32, -0.5, 0.25, -0.25];
8647 let conv = conv1d_full_fixed(4, 4, 2, 1, 2, &weight, Some(&bias));
8648
8649 let x_data: Vec<f32> = (1..=20).map(|i| i as f32).collect();
8651 let x = leaf(&x_data, &[1, 4, 5]);
8652 let y = conv.forward(&x).unwrap();
8653 assert_eq!(y.shape(), &[1, 4, 4]);
8654 assert_close(
8656 y.data().unwrap(),
8657 &[
8658 5.6, 6.6, 7.6, 8.6, 11.0, 13.6, 16.2, 18.8, 60.15, 64.35, 68.55, 72.75, 82.05,
8659 87.85, 93.65, 99.45,
8660 ],
8661 1e-3,
8662 );
8663
8664 let grad_output = t(&[1.0f32; 16], &[1, 4, 4]);
8666 let grads = conv
8667 .forward(&x)
8668 .unwrap()
8669 .grad_fn()
8670 .unwrap()
8671 .backward(&grad_output)
8672 .unwrap();
8673 assert_close(
8675 grads[0].as_ref().unwrap().data().unwrap(),
8676 &[
8677 0.6, 1.4, 1.4, 1.4, 0.8, 1.0, 2.2, 2.2, 2.2, 1.2, 2.2, 4.6, 4.6, 4.6, 2.4, 2.6,
8678 5.4, 5.4, 5.4, 2.8,
8679 ],
8680 1e-4,
8681 );
8682 assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 2, 2]);
8684 assert_close(
8685 grads[1].as_ref().unwrap().data().unwrap(),
8686 &[
8687 10.0, 14.0, 30.0, 34.0, 10.0, 14.0, 30.0, 34.0, 50.0, 54.0, 70.0, 74.0, 50.0, 54.0,
8688 70.0, 74.0,
8689 ],
8690 1e-4,
8691 );
8692 assert_close(
8694 grads[2].as_ref().unwrap().data().unwrap(),
8695 &[4.0, 4.0, 4.0, 4.0],
8696 1e-4,
8697 );
8698 }
8699
8700 #[test]
8703 fn test_conv1d_groups3_depthwise_forward_and_backward_matches_torch() {
8704 let weight: Vec<f32> = (1..=6).map(|i| i as f32 * 0.5).collect();
8706 let conv = conv1d_full_fixed(3, 3, 2, 1, 3, &weight, None);
8707
8708 let x_data: Vec<f32> = (1..=18).map(|i| i as f32).collect();
8710 let x = leaf(&x_data, &[1, 3, 6]);
8711 let y = conv.forward(&x).unwrap();
8712 assert_eq!(y.shape(), &[1, 3, 5]);
8713 assert_close(
8715 y.data().unwrap(),
8716 &[
8717 2.5, 4.0, 5.5, 7.0, 8.5, 26.5, 30.0, 33.5, 37.0, 40.5, 74.5, 80.0, 85.5, 91.0, 96.5,
8718 ],
8719 1e-3,
8720 );
8721
8722 let grad_output = t(&[1.0f32; 15], &[1, 3, 5]);
8723 let grads = conv
8724 .forward(&x)
8725 .unwrap()
8726 .grad_fn()
8727 .unwrap()
8728 .backward(&grad_output)
8729 .unwrap();
8730 assert_close(
8732 grads[0].as_ref().unwrap().data().unwrap(),
8733 &[
8734 0.5, 1.5, 1.5, 1.5, 1.5, 1.0, 1.5, 3.5, 3.5, 3.5, 3.5, 2.0, 2.5, 5.5, 5.5, 5.5,
8735 5.5, 3.0,
8736 ],
8737 1e-4,
8738 );
8739 assert_eq!(grads[1].as_ref().unwrap().shape(), &[3, 1, 2]);
8741 assert_close(
8742 grads[1].as_ref().unwrap().data().unwrap(),
8743 &[15.0, 20.0, 45.0, 50.0, 75.0, 80.0],
8744 1e-4,
8745 );
8746 }
8747
8748 #[test]
8752 fn test_conv1d_dilation2_forward_and_backward_matches_torch() {
8753 let weight: Vec<f32> = (1..=12).map(|i| i as f32 * 0.1).collect();
8755 let bias = [1.0f32, -1.0];
8756 let conv = conv1d_full_fixed(2, 2, 3, 2, 1, &weight, Some(&bias));
8757
8758 let x_data: Vec<f32> = (1..=14).map(|i| i as f32).collect();
8760 let x = leaf(&x_data, &[1, 2, 7]);
8761 let y = conv.forward(&x).unwrap();
8762 assert_eq!(y.shape(), &[1, 2, 3]);
8763 assert_close(
8765 y.data().unwrap(),
8766 &[18.6, 20.7, 22.8, 40.0, 45.7, 51.4],
8767 1e-3,
8768 );
8769
8770 let grad_output = t(&[1.0f32; 6], &[1, 2, 3]);
8771 let grads = conv
8772 .forward(&x)
8773 .unwrap()
8774 .grad_fn()
8775 .unwrap()
8776 .backward(&grad_output)
8777 .unwrap();
8778 assert_close(
8780 grads[0].as_ref().unwrap().data().unwrap(),
8781 &[
8782 0.8, 0.8, 1.8, 1.0, 2.2, 1.2, 1.2, 1.4, 1.4, 3.0, 1.6, 3.4, 1.8, 1.8,
8783 ],
8784 1e-4,
8785 );
8786 assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 2, 3]);
8788 assert_close(
8789 grads[1].as_ref().unwrap().data().unwrap(),
8790 &[
8791 6.0, 12.0, 18.0, 27.0, 33.0, 39.0, 6.0, 12.0, 18.0, 27.0, 33.0, 39.0,
8792 ],
8793 1e-4,
8794 );
8795 assert_close(
8797 grads[2].as_ref().unwrap().data().unwrap(),
8798 &[3.0, 3.0],
8799 1e-4,
8800 );
8801 }
8802
8803 #[test]
8806 fn test_conv1d_groups_must_divide_channels() {
8807 assert!(Conv1d::<f32>::new_full(3, 4, 2, 1, 0, 1, 2, true).is_err());
8809 assert!(Conv1d::<f32>::new_full(4, 5, 2, 1, 0, 1, 2, true).is_err());
8811 assert!(Conv1d::<f32>::new_full(4, 4, 2, 1, 0, 1, 0, true).is_err());
8813 assert!(Conv1d::<f32>::new_full(4, 4, 2, 1, 0, 0, 2, true).is_err());
8815 assert!(Conv1d::<f32>::new_full(4, 4, 2, 1, 0, 1, 2, true).is_ok());
8817 }
8818
8819 fn conv3d_fixed(
8821 weight: &[f32],
8822 wshape: &[usize],
8823 bias: &[f32],
8824 kernel: (usize, usize, usize),
8825 padding: (usize, usize, usize),
8826 mode: crate::padding::PaddingMode,
8827 ) -> Conv3d<f32> {
8828 let w = Parameter::from_slice(weight, wshape).unwrap();
8829 let b = Parameter::from_slice(bias, &[wshape[0]]).unwrap();
8830 Conv3d {
8831 weight: w,
8832 bias: Some(b),
8833 in_channels: wshape[1],
8834 out_channels: wshape[0],
8835 kernel_size: kernel,
8836 stride: (1, 1, 1),
8837 padding,
8838 dilation: (1, 1, 1),
8839 groups: 1,
8840 padding_mode: mode,
8841 string_padding: None,
8842 training: false,
8843 }
8844 }
8845
8846 #[test]
8850 fn test_conv3d_replicate_forward_matches_torch() {
8851 let w: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8852 let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8853 let conv = conv3d_fixed(
8854 &w,
8855 &[1, 1, 2, 2, 2],
8856 &[0.0],
8857 (2, 2, 2),
8858 (1, 1, 1),
8859 crate::padding::PaddingMode::Replicate,
8860 );
8861 let x = t(&x_data, &[1, 1, 2, 2, 2]);
8862 let y = conv.forward(&x).unwrap();
8863 assert_eq!(y.shape(), &[1, 1, 3, 3, 3]);
8864 let expected = [
8865 36.0, 56.0, 72.0, 80.0, 100.0, 116.0, 108.0, 128.0, 144.0, 140.0, 160.0, 176.0, 184.0,
8866 204.0, 220.0, 212.0, 232.0, 248.0, 180.0, 200.0, 216.0, 224.0, 244.0, 260.0, 252.0,
8867 272.0, 288.0,
8868 ];
8869 assert_close(y.data().unwrap(), &expected, 1e-3);
8870 }
8871
8872 #[test]
8874 fn test_conv3d_reflect_forward_matches_torch() {
8875 let w: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8876 let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8877 let conv = conv3d_fixed(
8878 &w,
8879 &[1, 1, 2, 2, 2],
8880 &[0.0],
8881 (2, 2, 2),
8882 (1, 1, 1),
8883 crate::padding::PaddingMode::Reflect,
8884 );
8885 let x = t(&x_data, &[1, 1, 2, 2, 2]);
8886 let y = conv.forward(&x).unwrap();
8887 let expected = [
8888 120.0, 124.0, 120.0, 136.0, 140.0, 136.0, 120.0, 124.0, 120.0, 184.0, 188.0, 184.0,
8889 200.0, 204.0, 200.0, 184.0, 188.0, 184.0, 120.0, 124.0, 120.0, 136.0, 140.0, 136.0,
8890 120.0, 124.0, 120.0,
8891 ];
8892 assert_close(y.data().unwrap(), &expected, 1e-3);
8893 }
8894
8895 #[test]
8899 fn test_conv3d_circular_forward_matches_torch() {
8900 let mut w = vec![0.0f32; 8];
8901 w[0] = 1.0;
8902 let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8903 let conv = conv3d_fixed(
8904 &w,
8905 &[1, 1, 2, 2, 2],
8906 &[0.0],
8907 (2, 2, 2),
8908 (1, 1, 1),
8909 crate::padding::PaddingMode::Circular,
8910 );
8911 let x = t(&x_data, &[1, 1, 2, 2, 2]);
8912 let y = conv.forward(&x).unwrap();
8913 let expected = [
8914 8.0, 7.0, 8.0, 6.0, 5.0, 6.0, 8.0, 7.0, 8.0, 4.0, 3.0, 4.0, 2.0, 1.0, 2.0, 4.0, 3.0,
8915 4.0, 8.0, 7.0, 8.0, 6.0, 5.0, 6.0, 8.0, 7.0, 8.0,
8916 ];
8917 assert_close(y.data().unwrap(), &expected, 1e-3);
8918 }
8919
8920 #[test]
8924 fn test_conv3d_replicate_backward_input_grad_matches_torch() {
8925 let w: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8926 let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8927 let conv = conv3d_fixed(
8928 &w,
8929 &[1, 1, 2, 2, 2],
8930 &[0.0],
8931 (2, 2, 2),
8932 (1, 1, 1),
8933 crate::padding::PaddingMode::Replicate,
8934 );
8935 let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
8936 let y = conv.forward(&x).unwrap();
8937 assert!(
8938 y.grad_fn().is_some(),
8939 "Conv3d replicate output lost its grad_fn — pre-pad severed autograd (#1550 class)"
8940 );
8941 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
8942 ferrotorch_core::backward(&sum).unwrap();
8943 let xg = x.grad().unwrap().expect("input grad must be populated");
8944 assert_close(
8945 xg.data().unwrap(),
8946 &[90.0, 99.0, 108.0, 117.0, 126.0, 135.0, 144.0, 153.0],
8947 1e-3,
8948 );
8949 }
8950
8951 #[test]
8960 fn test_conv3d_groups2_dilation2_forward_and_backward_matches_torch() {
8961 let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.01).collect();
8963 let bias = [0.1f32, -0.1];
8964 let mut conv =
8965 Conv3d::<f32>::new_full(2, 2, (2, 2, 2), (1, 1, 1), (0, 0, 0), (2, 2, 2), 2, true)
8966 .unwrap();
8967 conv.weight = Parameter::from_slice(&weight, &[2, 1, 2, 2, 2]).unwrap();
8968 conv.bias = Some(Parameter::from_slice(&bias, &[2]).unwrap());
8969
8970 let x_data: Vec<f32> = (1..=128).map(|i| i as f32).collect();
8972 let x = leaf(&x_data, &[1, 2, 4, 4, 4]);
8973 let y = conv.forward(&x).unwrap();
8974 assert_eq!(y.shape(), &[1, 2, 2, 2, 2]);
8975 assert_close(
8977 y.data().unwrap(),
8978 &[
8979 10.94, 11.3, 12.38, 12.74, 16.7, 17.06, 18.14, 18.5, 88.82, 89.82, 92.82, 93.82,
8980 104.82, 105.82, 108.82, 109.82,
8981 ],
8982 1e-3,
8983 );
8984
8985 let grad_output = t(&[1.0f32; 16], &[1, 2, 2, 2, 2]);
8986 let grads = conv
8987 .forward(&x)
8988 .unwrap()
8989 .grad_fn()
8990 .unwrap()
8991 .backward(&grad_output)
8992 .unwrap();
8993 #[rustfmt::skip]
8995 let d_gx: [f32; 128] = [
8996 0.01, 0.01, 0.02, 0.02, 0.01, 0.01, 0.02, 0.02, 0.03, 0.03, 0.04, 0.04, 0.03, 0.03,
8997 0.04, 0.04, 0.01, 0.01, 0.02, 0.02, 0.01, 0.01, 0.02, 0.02, 0.03, 0.03, 0.04, 0.04,
8998 0.03, 0.03, 0.04, 0.04, 0.05, 0.05, 0.06, 0.06, 0.05, 0.05, 0.06, 0.06, 0.07, 0.07,
8999 0.08, 0.08, 0.07, 0.07, 0.08, 0.08, 0.05, 0.05, 0.06, 0.06, 0.05, 0.05, 0.06, 0.06,
9000 0.07, 0.07, 0.08, 0.08, 0.07, 0.07, 0.08, 0.08, 0.09, 0.09, 0.1, 0.1, 0.09, 0.09, 0.1,
9001 0.1, 0.11, 0.11, 0.12, 0.12, 0.11, 0.11, 0.12, 0.12, 0.09, 0.09, 0.1, 0.1, 0.09, 0.09,
9002 0.1, 0.1, 0.11, 0.11, 0.12, 0.12, 0.11, 0.11, 0.12, 0.12, 0.13, 0.13, 0.14, 0.14, 0.13,
9003 0.13, 0.14, 0.14, 0.15, 0.15, 0.16, 0.16, 0.15, 0.15, 0.16, 0.16, 0.13, 0.13, 0.14,
9004 0.14, 0.13, 0.13, 0.14, 0.14, 0.15, 0.15, 0.16, 0.16, 0.15, 0.15, 0.16, 0.16,
9005 ];
9006 assert_close(grads[0].as_ref().unwrap().data().unwrap(), &d_gx, 1e-4);
9007 assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 2, 2, 2]);
9009 assert_close(
9010 grads[1].as_ref().unwrap().data().unwrap(),
9011 &[
9012 92.0, 108.0, 156.0, 172.0, 348.0, 364.0, 412.0, 428.0, 604.0, 620.0, 668.0, 684.0,
9013 860.0, 876.0, 924.0, 940.0,
9014 ],
9015 1e-3,
9016 );
9017 assert_close(
9019 grads[2].as_ref().unwrap().data().unwrap(),
9020 &[8.0, 8.0],
9021 1e-4,
9022 );
9023 }
9024
9025 #[test]
9029 fn test_conv3d_groups2_forward_and_backward_matches_torch() {
9030 let weight = [1.0f32, 2.0, 3.0, 4.0];
9032 let mut conv =
9033 Conv3d::<f32>::new_full(2, 4, (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1), 2, false)
9034 .unwrap();
9035 conv.weight = Parameter::from_slice(&weight, &[4, 1, 1, 1, 1]).unwrap();
9036
9037 let x_data: Vec<f32> = (1..=16).map(|i| i as f32).collect();
9039 let x = leaf(&x_data, &[1, 2, 2, 2, 2]);
9040 let y = conv.forward(&x).unwrap();
9041 assert_eq!(y.shape(), &[1, 4, 2, 2, 2]);
9042 assert_close(
9044 y.data().unwrap(),
9045 &[
9046 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0,
9047 27.0, 30.0, 33.0, 36.0, 39.0, 42.0, 45.0, 48.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0,
9048 60.0, 64.0,
9049 ],
9050 1e-3,
9051 );
9052
9053 let grad_output = t(&[1.0f32; 32], &[1, 4, 2, 2, 2]);
9054 let grads = conv
9055 .forward(&x)
9056 .unwrap()
9057 .grad_fn()
9058 .unwrap()
9059 .backward(&grad_output)
9060 .unwrap();
9061 assert_close(
9063 grads[0].as_ref().unwrap().data().unwrap(),
9064 &[
9065 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0,
9066 ],
9067 1e-4,
9068 );
9069 assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 1, 1, 1, 1]);
9071 assert_close(
9072 grads[1].as_ref().unwrap().data().unwrap(),
9073 &[36.0, 36.0, 100.0, 100.0],
9074 1e-4,
9075 );
9076 }
9077
9078 #[test]
9081 fn test_conv3d_groups_must_divide_channels() {
9082 assert!(
9084 Conv3d::<f32>::new_full(3, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (1, 1, 1), 2, true)
9085 .is_err()
9086 );
9087 assert!(
9089 Conv3d::<f32>::new_full(4, 5, (2, 2, 2), (1, 1, 1), (0, 0, 0), (1, 1, 1), 2, true)
9090 .is_err()
9091 );
9092 assert!(
9094 Conv3d::<f32>::new_full(4, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (1, 1, 1), 0, true)
9095 .is_err()
9096 );
9097 assert!(
9099 Conv3d::<f32>::new_full(4, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (0, 1, 1), 2, true)
9100 .is_err()
9101 );
9102 assert!(
9104 Conv3d::<f32>::new_full(4, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (2, 2, 2), 2, true)
9105 .is_ok()
9106 );
9107 }
9108
9109 #[test]
9112 fn test_conv1d_reflect_zero_padding_is_noop() {
9113 let conv = conv1d_fixed(
9114 &[1.0, 2.0, 3.0],
9115 &[1, 1, 3],
9116 &[0.0],
9117 3,
9118 0,
9119 crate::padding::PaddingMode::Reflect,
9120 );
9121 let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9122 let y = conv.forward(&x).unwrap();
9123 assert_eq!(y.shape(), &[1, 1, 3]);
9125 assert_close(y.data().unwrap(), &[14.0, 20.0, 26.0], 1e-4);
9126 }
9127
9128 #[test]
9131 fn test_conv_transpose1d_reflect_padding_mode_rejected() {
9132 let conv = ConvTranspose1d::<f32>::new(2, 2, 3, 1, 0, 0, false).unwrap();
9133 let err = conv
9134 .with_padding_mode(crate::padding::PaddingMode::Reflect)
9135 .unwrap_err();
9136 let msg = format!("{err}");
9139 assert!(
9140 msg.contains("Only \"zeros\" padding mode is supported for ConvTranspose1d"),
9141 "got: {msg}"
9142 );
9143 }
9144
9145 #[test]
9146 fn test_conv_transpose2d_replicate_padding_mode_rejected() {
9147 let conv =
9148 ConvTranspose2d::<f32>::new(2, 2, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
9149 let err = conv
9150 .with_padding_mode(crate::padding::PaddingMode::Replicate)
9151 .unwrap_err();
9152 let msg = format!("{err}");
9153 assert!(
9154 msg.contains("Only \"zeros\" padding mode is supported for ConvTranspose2d"),
9155 "got: {msg}"
9156 );
9157 }
9158
9159 #[test]
9160 fn test_conv_transpose3d_circular_padding_mode_rejected() {
9161 let conv =
9162 ConvTranspose3d::<f32>::new(2, 2, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
9163 .unwrap();
9164 let err = conv
9165 .with_padding_mode(crate::padding::PaddingMode::Circular)
9166 .unwrap_err();
9167 let msg = format!("{err}");
9168 assert!(
9169 msg.contains("Only \"zeros\" padding mode is supported for ConvTranspose3d"),
9170 "got: {msg}"
9171 );
9172 }
9173
9174 #[test]
9176 fn test_conv_transpose2d_zeros_padding_mode_accepted() {
9177 let conv =
9178 ConvTranspose2d::<f32>::new(2, 2, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
9179 assert!(
9180 conv.with_padding_mode(crate::padding::PaddingMode::Zeros)
9181 .is_ok()
9182 );
9183 }
9184
9185 fn conv1d_with_weight(weight: &[f32], wshape: &[usize], bias: f32) -> Conv1d<f32> {
9196 let mut c = Conv1d::<f32>::new(wshape[1], wshape[0], wshape[2], 1, 0, true).unwrap();
9197 c.weight = Parameter::from_slice(weight, wshape).unwrap();
9200 c.bias = Some(Parameter::from_slice(&[bias], &[wshape[0]]).unwrap());
9201 c
9202 }
9203
9204 #[test]
9208 fn test_conv1d_same_odd_kernel_matches_torch() {
9209 let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5)
9210 .with_string_padding(StringPadding::Same)
9211 .unwrap();
9212 let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9213 let y = conv.forward(&x).unwrap();
9214 assert_eq!(y.shape(), &[1, 1, 5]);
9215 assert_close(y.data().unwrap(), &[8.5, 14.5, 20.5, 26.5, 14.5], 1e-4);
9216 }
9217
9218 #[test]
9225 fn test_conv1d_same_even_kernel_asymmetric_matches_torch() {
9226 let conv = conv1d_with_weight(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4], 0.0)
9227 .with_string_padding(StringPadding::Same)
9228 .unwrap();
9229 let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]);
9230 let y = conv.forward(&x).unwrap();
9231 assert_eq!(y.shape(), &[1, 1, 6]);
9232 assert_close(
9233 y.data().unwrap(),
9234 &[20.0, 30.0, 40.0, 50.0, 32.0, 17.0],
9235 1e-4,
9236 );
9237 }
9238
9239 #[test]
9244 fn test_conv1d_same_odd_kernel_backward_matches_torch() {
9245 let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5)
9246 .with_string_padding(StringPadding::Same)
9247 .unwrap();
9248 let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9249 let y = conv.forward(&x).unwrap();
9250 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9252 ferrotorch_core::backward(&sum).unwrap();
9253 let gx = x.grad().unwrap().expect("input grad must be populated");
9254 assert_eq!(gx.shape(), &[1, 1, 5]);
9255 assert_close(gx.data().unwrap(), &[3.0, 6.0, 6.0, 6.0, 5.0], 1e-4);
9256 }
9257
9258 #[test]
9261 fn test_conv1d_same_even_kernel_backward_matches_torch() {
9262 let conv = conv1d_with_weight(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4], 0.0)
9263 .with_string_padding(StringPadding::Same)
9264 .unwrap();
9265 let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]);
9266 let y = conv.forward(&x).unwrap();
9267 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9269 ferrotorch_core::backward(&sum).unwrap();
9270 let gx = x.grad().unwrap().expect("input grad must be populated");
9271 assert_eq!(gx.shape(), &[1, 1, 6]);
9272 assert_close(gx.data().unwrap(), &[3.0, 6.0, 10.0, 10.0, 10.0, 9.0], 1e-4);
9273 }
9274
9275 #[test]
9278 fn test_conv1d_valid_matches_torch() {
9279 let conv = conv1d_with_weight(&[1.0, 1.0, 1.0], &[1, 1, 3], 0.0)
9280 .with_string_padding(StringPadding::Valid)
9281 .unwrap();
9282 let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9283 let y = conv.forward(&x).unwrap();
9284 assert_eq!(y.shape(), &[1, 1, 3]);
9285 assert_close(y.data().unwrap(), &[6.0, 9.0, 12.0], 1e-4);
9286 }
9287
9288 #[test]
9292 fn test_conv2d_same_odd_kernel_matches_torch() {
9293 let weight = Parameter::from_slice(
9294 &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9295 &[1, 1, 3, 3],
9296 )
9297 .unwrap();
9298 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
9299 conv.set_weight(weight).unwrap();
9300 conv.bias = Some(Parameter::from_slice(&[0.5], &[1]).unwrap());
9301 let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9302 let x = t(
9303 &(1..=16).map(|v| v as f32).collect::<Vec<_>>(),
9304 &[1, 1, 4, 4],
9305 );
9306 let y = conv.forward(&x).unwrap();
9307 assert_eq!(y.shape(), &[1, 1, 4, 4]);
9308 let expected = [
9309 111.5, 178.5, 217.5, 145.5, 231.5, 348.5, 393.5, 252.5, 363.5, 528.5, 573.5, 360.5,
9310 197.5, 274.5, 295.5, 175.5,
9311 ];
9312 assert_close(y.data().unwrap(), &expected, 1e-3);
9313 }
9314
9315 #[test]
9319 fn test_conv2d_same_even_kernel_asymmetric_matches_torch() {
9320 let weight = Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
9321 let mut conv = Conv2d::<f32>::new(1, 1, (2, 2), (1, 1), (0, 0), false).unwrap();
9322 conv.set_weight(weight).unwrap();
9323 let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9324 let x = t(
9325 &(1..=9).map(|v| v as f32).collect::<Vec<_>>(),
9326 &[1, 1, 3, 3],
9327 );
9328 let y = conv.forward(&x).unwrap();
9329 assert_eq!(y.shape(), &[1, 1, 3, 3]);
9330 let expected = [37.0, 47.0, 21.0, 67.0, 77.0, 33.0, 23.0, 26.0, 9.0];
9331 assert_close(y.data().unwrap(), &expected, 1e-3);
9332 }
9333
9334 #[test]
9337 fn test_conv2d_same_backward_matches_torch() {
9338 let weight = Parameter::from_slice(
9339 &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9340 &[1, 1, 3, 3],
9341 )
9342 .unwrap();
9343 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
9344 conv.set_weight(weight).unwrap();
9345 conv.bias = Some(Parameter::from_slice(&[0.5], &[1]).unwrap());
9346 let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9347 let x = leaf(
9348 &(1..=16).map(|v| v as f32).collect::<Vec<_>>(),
9349 &[1, 1, 4, 4],
9350 );
9351 let y = conv.forward(&x).unwrap();
9352 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9354 ferrotorch_core::backward(&sum).unwrap();
9355 let gx = x.grad().unwrap().expect("input grad must be populated");
9356 assert_eq!(gx.shape(), &[1, 1, 4, 4]);
9357 let expected = [
9358 12.0, 21.0, 21.0, 16.0, 27.0, 45.0, 45.0, 33.0, 27.0, 45.0, 45.0, 33.0, 24.0, 39.0,
9359 39.0, 28.0,
9360 ];
9361 assert_close(gx.data().unwrap(), &expected, 1e-3);
9362 }
9363
9364 #[test]
9367 fn test_conv2d_valid_matches_torch() {
9368 let weight = Parameter::from_slice(&[1.0; 9], &[1, 1, 3, 3]).unwrap();
9369 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9370 conv.set_weight(weight).unwrap();
9371 let conv = conv.with_string_padding(StringPadding::Valid).unwrap();
9372 let x = t(
9373 &(1..=25).map(|v| v as f32).collect::<Vec<_>>(),
9374 &[1, 1, 5, 5],
9375 );
9376 let y = conv.forward(&x).unwrap();
9377 assert_eq!(y.shape(), &[1, 1, 3, 3]);
9378 let expected = [63.0, 72.0, 81.0, 108.0, 117.0, 126.0, 153.0, 162.0, 171.0];
9379 assert_close(y.data().unwrap(), &expected, 1e-3);
9380 }
9381
9382 #[test]
9387 fn test_conv3d_same_even_kernel_asymmetric_matches_torch() {
9388 let weight =
9389 Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
9390 .unwrap();
9391 let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
9392 conv.weight = weight;
9394 let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9395 let x = t(
9396 &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
9397 &[1, 1, 3, 3, 3],
9398 );
9399 let y = conv.forward(&x).unwrap();
9400 assert_eq!(y.shape(), &[1, 1, 3, 3, 3]);
9401 let expected = [
9402 356.0, 392.0, 186.0, 464.0, 500.0, 234.0, 205.0, 219.0, 99.0, 680.0, 716.0, 330.0,
9403 788.0, 824.0, 378.0, 331.0, 345.0, 153.0, 217.0, 227.0, 93.0, 247.0, 257.0, 105.0,
9404 77.0, 80.0, 27.0,
9405 ];
9406 assert_close(y.data().unwrap(), &expected, 1e-3);
9407 }
9408
9409 #[test]
9411 fn test_conv3d_same_backward_matches_torch() {
9412 let weight =
9413 Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
9414 .unwrap();
9415 let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
9416 conv.weight = weight;
9418 let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9419 let x = leaf(
9420 &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
9421 &[1, 1, 3, 3, 3],
9422 );
9423 let y = conv.forward(&x).unwrap();
9424 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9426 ferrotorch_core::backward(&sum).unwrap();
9427 let gx = x.grad().unwrap().expect("input grad must be populated");
9428 assert_eq!(gx.shape(), &[1, 1, 3, 3, 3]);
9429 let expected = [
9430 1.0, 3.0, 3.0, 4.0, 10.0, 10.0, 4.0, 10.0, 10.0, 6.0, 14.0, 14.0, 16.0, 36.0, 36.0,
9431 16.0, 36.0, 36.0, 6.0, 14.0, 14.0, 16.0, 36.0, 36.0, 16.0, 36.0, 36.0,
9432 ];
9433 assert_close(gx.data().unwrap(), &expected, 1e-3);
9434 }
9435
9436 #[test]
9440 fn test_conv_same_stride_gt1_rejected() {
9441 let c1 = Conv1d::<f32>::new(1, 1, 3, 2, 0, false)
9443 .unwrap()
9444 .with_string_padding(StringPadding::Same);
9445 let e1 = c1.unwrap_err();
9446 assert!(
9447 format!("{e1}").contains("padding='same' is not supported for strided convolutions"),
9448 "conv1d: {e1}"
9449 );
9450 let c2 = Conv2d::<f32>::new(1, 1, (3, 3), (1, 2), (0, 0), false)
9452 .unwrap()
9453 .with_string_padding(StringPadding::Same);
9454 assert!(
9455 format!("{}", c2.unwrap_err())
9456 .contains("padding='same' is not supported for strided convolutions")
9457 );
9458 let c3 = Conv3d::<f32>::new(1, 1, (2, 2, 2), (2, 1, 1), (0, 0, 0), false)
9460 .unwrap()
9461 .with_string_padding(StringPadding::Same);
9462 assert!(
9463 format!("{}", c3.unwrap_err())
9464 .contains("padding='same' is not supported for strided convolutions")
9465 );
9466 assert!(
9468 Conv2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), false)
9469 .unwrap()
9470 .with_string_padding(StringPadding::Valid)
9471 .is_ok()
9472 );
9473 }
9474
9475 #[test]
9486 fn test_conv1d_unbatched_forward_matches_torch() {
9487 let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5);
9488 let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 5]); let y = conv.forward(&x).unwrap();
9490 assert_eq!(y.ndim(), 2, "unbatched output must be rank 2");
9491 assert_eq!(y.shape(), &[1, 3]);
9492 assert_close(y.data().unwrap(), &[14.5, 20.5, 26.5], 1e-4);
9493 }
9494
9495 #[test]
9498 fn test_conv1d_unbatched_backward_matches_torch() {
9499 let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5);
9500 let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 5]);
9501 let y = conv.forward(&x).unwrap();
9502 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9504 ferrotorch_core::backward(&sum).unwrap();
9505 let gx = x.grad().unwrap().expect("input grad must be populated");
9506 assert_eq!(gx.shape(), &[1, 5], "grad must match unbatched input shape");
9507 assert_close(gx.data().unwrap(), &[1.0, 3.0, 6.0, 5.0, 3.0], 1e-4);
9508 }
9509
9510 #[test]
9514 fn test_conv2d_unbatched_forward_matches_torch() {
9515 let weight = Parameter::from_slice(
9516 &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9517 &[1, 1, 3, 3],
9518 )
9519 .unwrap();
9520 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9521 conv.set_weight(weight).unwrap();
9522 let x = t(&(1..=25).map(|v| v as f32).collect::<Vec<_>>(), &[1, 5, 5]); let y = conv.forward(&x).unwrap();
9524 assert_eq!(y.ndim(), 3, "unbatched output must be rank 3");
9525 assert_eq!(y.shape(), &[1, 3, 3]);
9526 let expected = [
9527 411.0, 456.0, 501.0, 636.0, 681.0, 726.0, 861.0, 906.0, 951.0,
9528 ];
9529 assert_close(y.data().unwrap(), &expected, 1e-3);
9530 }
9531
9532 #[test]
9535 fn test_conv2d_unbatched_backward_matches_torch() {
9536 let weight = Parameter::from_slice(
9537 &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9538 &[1, 1, 3, 3],
9539 )
9540 .unwrap();
9541 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9542 conv.set_weight(weight).unwrap();
9543 let x = leaf(&(1..=25).map(|v| v as f32).collect::<Vec<_>>(), &[1, 5, 5]);
9544 let y = conv.forward(&x).unwrap();
9545 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9547 ferrotorch_core::backward(&sum).unwrap();
9548 let gx = x.grad().unwrap().expect("input grad must be populated");
9549 assert_eq!(gx.shape(), &[1, 5, 5], "grad must match unbatched input");
9550 let expected = [
9551 1.0, 3.0, 6.0, 5.0, 3.0, 5.0, 12.0, 21.0, 16.0, 9.0, 12.0, 27.0, 45.0, 33.0, 18.0,
9552 11.0, 24.0, 39.0, 28.0, 15.0, 7.0, 15.0, 24.0, 17.0, 9.0,
9553 ];
9554 assert_close(gx.data().unwrap(), &expected, 1e-3);
9555 }
9556
9557 #[test]
9561 fn test_conv3d_unbatched_forward_matches_torch() {
9562 let weight =
9563 Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
9564 .unwrap();
9565 let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
9566 conv.weight = weight;
9568 let x = t(
9569 &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
9570 &[1, 3, 3, 3],
9571 ); let y = conv.forward(&x).unwrap();
9573 assert_eq!(y.ndim(), 4, "unbatched output must be rank 4");
9574 assert_eq!(y.shape(), &[1, 2, 2, 2]);
9575 let expected = [356.0, 392.0, 464.0, 500.0, 680.0, 716.0, 788.0, 824.0];
9576 assert_close(y.data().unwrap(), &expected, 1e-3);
9577 }
9578
9579 #[test]
9582 fn test_conv3d_unbatched_backward_matches_torch() {
9583 let weight =
9584 Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
9585 .unwrap();
9586 let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
9587 conv.weight = weight;
9589 let x = leaf(
9590 &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
9591 &[1, 3, 3, 3],
9592 );
9593 let y = conv.forward(&x).unwrap();
9594 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9596 ferrotorch_core::backward(&sum).unwrap();
9597 let gx = x.grad().unwrap().expect("input grad must be populated");
9598 assert_eq!(
9599 gx.shape(),
9600 &[1, 3, 3, 3],
9601 "grad must match unbatched input shape"
9602 );
9603 let expected = [
9604 1.0, 3.0, 2.0, 4.0, 10.0, 6.0, 3.0, 7.0, 4.0, 6.0, 14.0, 8.0, 16.0, 36.0, 20.0, 10.0,
9605 22.0, 12.0, 5.0, 11.0, 6.0, 12.0, 26.0, 14.0, 7.0, 15.0, 8.0,
9606 ];
9607 assert_close(gx.data().unwrap(), &expected, 1e-3);
9608 }
9609
9610 #[test]
9613 fn test_conv2d_unbatched_same_composes() {
9614 let weight = Parameter::from_slice(
9615 &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9616 &[1, 1, 3, 3],
9617 )
9618 .unwrap();
9619 let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9620 conv.set_weight(weight).unwrap();
9621 let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9622 let x = t(&(1..=16).map(|v| v as f32).collect::<Vec<_>>(), &[1, 4, 4]); let y = conv.forward(&x).unwrap();
9624 assert_eq!(y.ndim(), 3);
9625 assert_eq!(y.shape(), &[1, 4, 4], "same padding preserves spatial dims");
9626 }
9627}