#![allow(clippy::single_range_in_vec_init)]
use super::{ConvOptions, ConvTransposeOptions};
use crate::{Backend, TensorMetadata, tensor::FloatTensor};
use burn_std::{MetadataError, Shape, Slice};

use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

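/// Calculate the output shape of a pooling operation applied over the spatial
/// dimensions of an `N + 2` dimensional input (batch and channel dimensions
/// first). Returns a rank-mismatch error when the input rank is not `N + 2`.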
pub fn calculate_pool_output_shape<const N: usize>(
    in_shape: &Shape,
    kernel_size: &[usize; N],
    stride: &[usize; N],
    padding: &[usize; N],
    dilation: &[usize; N],
    ceil_mode: bool,
) -> Result<Shape, MetadataError> {
    if in_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: in_shape.rank(),
            right: N + 2,
        });
    }

    let mut out_shape = in_shape.clone();
    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {
        *size_i = calculate_pool_output_size(
            kernel_size[i],
            stride[i],
            padding[i],
            dilation[i],
            *size_i,
            ceil_mode,
        );
    }

    Ok(out_shape)
}

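/// Calculate the output shape of an `N`-dimensional convolution from the
/// input and weight shapes (both of rank `N + 2`) and the per-dimension
/// stride, padding, and dilation.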
pub fn calculate_conv_output_shape<const N: usize>(
    in_shape: &Shape,
    weight_shape: &Shape,
    stride: &[usize; N],
    padding: &[usize; N],
    dilation: &[usize; N],
) -> Result<Shape, MetadataError> {
    if weight_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: weight_shape.rank(),
            right: N + 2,
        });
    }

    if in_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: in_shape.rank(),
            right: N + 2,
        });
    }

    let kernel_size = &weight_shape[2..];

    let mut out_shape = in_shape.clone();
    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {
        *size_i =
            calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_i);
    }
    out_shape[1] = weight_shape[0];

    Ok(out_shape)
}

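/// Calculate the output shape of an `N`-dimensional transposed convolution
/// from the input and weight shapes (both of rank `N + 2`), the per-dimension
/// stride, padding, output padding, and dilation, and the number of groups.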
pub fn calculate_conv_transpose_output_shape<const N: usize>(
    in_shape: &Shape,
    weight_shape: &Shape,
    stride: &[usize; N],
    padding: &[usize; N],
    padding_out: &[usize; N],
    dilation: &[usize; N],
    groups: usize,
) -> Result<Shape, MetadataError> {
    if weight_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: weight_shape.rank(),
            right: N + 2,
        });
    }

    if in_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: in_shape.rank(),
            right: N + 2,
        });
    }

    let kernel_size = &weight_shape[2..];

    let mut out_shape = in_shape.clone();
    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {
        *size_i = calculate_conv_transpose_output_size(
            kernel_size[i],
            stride[i],
            padding[i],
            padding_out[i],
            dilation[i],
            *size_i,
        );
    }
    out_shape[1] = weight_shape[1] * groups;

    Ok(out_shape)
}

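/// Calculate the padding required for a convolution with the given kernel
/// size and stride to map `size_in` elements to `size_out` elements.
///
/// Note that the formula does not involve dilation, so the result is only
/// exact for a dilation of 1.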
pub fn calculate_conv_padding(
    kernel_size: usize,
    stride: usize,
    size_in: usize,
    size_out: usize,
) -> usize {
    let kernel_size = kernel_size as f32;
    let stride = stride as f32;
    let size_in = size_in as f32;
    let size_out = size_out as f32;

    let padding = stride * (size_out - 1.) - size_in + kernel_size;
    let padding = (padding / 2.).ceil();

    padding as usize
}

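/// Calculate the output size along one spatial dimension of a convolution:
/// `floor((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1`.
///
/// A worked example with illustrative values:
///
/// ```text
/// size_in = 27, kernel_size = 5, stride = 2, padding = 3, dilation = 1
/// floor((27 + 6 - 4 - 1) / 2) + 1 = 15
/// ```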
pub fn calculate_conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

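/// Calculate the output size of every spatial dimension of a convolution by
/// applying [`calculate_conv_output_size`] element-wise.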
pub fn calculate_conv_output_sizes(
    kernel_size: &[usize],
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
    size_in: &[usize],
) -> Vec<usize> {
    size_in
        .iter()
        .enumerate()
        .map(|(i, size_in)| {
            calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_in)
        })
        .collect()
}

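/// Calculate the output size along one spatial dimension of a pooling
/// operation.
///
/// With `ceil_mode` disabled this is the same floor-division formula as
/// [`calculate_conv_output_size`]; with `ceil_mode` enabled the division
/// rounds up, so a trailing partial window still produces an output element.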
pub fn calculate_pool_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
    ceil_mode: bool,
) -> usize {
    let numerator = size_in + 2 * padding - dilation * (kernel_size - 1) - 1;
    if ceil_mode {
        numerator.div_ceil(stride) + 1
    } else {
        numerator / stride + 1
    }
}

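/// Calculate the output size along one spatial dimension of a transposed
/// convolution:
/// `(size_in - 1) * stride + dilation * (kernel_size - 1) + 1 + padding_out - 2 * padding`.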
pub fn calculate_conv_transpose_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding
}

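/// Inverse of [`calculate_conv_transpose_output_size`]: recover the input
/// size that produces `size_out` under the given transposed convolution
/// parameters.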
fn calculate_conv_transpose_input_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    size_out: usize,
) -> usize {
    (size_out + 2 * padding - dilation * (kernel_size - 1) - padding_out - 1) / stride + 1
}

fn calculate_conv_transpose_input_sizes<const D: usize>(
    kernel_size: [usize; D],
    stride: [usize; D],
    padding: [usize; D],
    padding_out: [usize; D],
    dilation: [usize; D],
    size_out: [usize; D],
) -> [usize; D] {
    let mut res = [0; D];
    for i in 0..D {
        res[i] = calculate_conv_transpose_input_size(
            kernel_size[i],
            stride[i],
            padding[i],
            padding_out[i],
            dilation[i],
            size_out[i],
        );
    }
    res
}

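/// Calculate the gradient of a 1D convolution with respect to its input by
/// applying the matching transposed convolution to the output gradient, with
/// the output padding chosen so the gradient matches the input size.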
pub(crate) fn conv1d_x_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let weight_shape = weight.shape();

    let [_batch_size, _, length_in] = x.shape().dims();
    let [_batch_size, _channels_out, length_out] = output_grad.shape().dims();
    let [_, _, kernel_size] = weight_shape.dims();

    let padding_out = calculate_padding_out(
        kernel_size,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        length_in,
        length_out,
    );

    B::conv_transpose1d(
        output_grad,
        weight,
        None,
        ConvTransposeOptions::new(
            options.stride,
            options.padding,
            [padding_out],
            options.dilation,
            options.groups,
        ),
    )
}

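/// Calculate the gradient of a 1D convolution with respect to its weights,
/// dispatching on whether the convolution is grouped.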
pub(crate) fn conv1d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv1d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

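/// Calculate the gradient of a 1D convolution with respect to its bias: the
/// output gradient summed over the batch and spatial dimensions.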
pub(crate) fn conv1d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _, _length_in] = x.shape().dims();
    let [_batch_size, channels_out, length_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

pub(crate) fn conv2d_x_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let weight_shape = weight.shape();

    let [_batch_size, _channels_in, height_in, width_in] = x.shape().dims();
    let [_, _, height_out, width_out] = output_grad.shape().dims();
    let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims();

    let padding_1_out = calculate_padding_out(
        kernel_size_1,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        height_in,
        height_out,
    );
    let padding_2_out = calculate_padding_out(
        kernel_size_2,
        options.stride[1],
        options.padding[1],
        options.dilation[1],
        width_in,
        width_out,
    );

    B::conv_transpose2d(
        output_grad,
        weight,
        None,
        ConvTransposeOptions::new(
            options.stride,
            options.padding,
            [padding_1_out, padding_2_out],
            options.dilation,
            options.groups,
        ),
    )
}

pub(crate) fn conv2d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv2d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

pub(crate) fn conv2d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _, _, _] = x.shape().dims();
    let [_, channels_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([channels_out, batch_size * height_out * width_out]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

pub(crate) fn conv3d_x_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let weight_shape = weight.shape();

    let [_batch_size, _channels_in, depth_in, height_in, width_in] = x.shape().dims();
    let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims();
    let [
        _channels_out,
        _,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_shape.dims();

    let padding_1_out = calculate_padding_out(
        kernel_size_1,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        depth_in,
        depth_out,
    );
    let padding_2_out = calculate_padding_out(
        kernel_size_2,
        options.stride[1],
        options.padding[1],
        options.dilation[1],
        height_in,
        height_out,
    );
    let padding_3_out = calculate_padding_out(
        kernel_size_3,
        options.stride[2],
        options.padding[2],
        options.dilation[2],
        width_in,
        width_out,
    );

    B::conv_transpose3d(
        output_grad,
        weight,
        None,
        ConvTransposeOptions::new(
            options.stride,
            options.padding,
            [padding_1_out, padding_2_out, padding_3_out],
            options.dilation,
            options.groups,
        ),
    )
}

pub(crate) fn conv3d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv3d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

pub(crate) fn conv3d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = x.shape().dims();
    let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([
            channels_out,
            batch_size * depth_out * height_out * width_out,
        ]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [batch_size, _c_out, out_length] = output_grad.shape().dims();
    let [c_in, _c_out_groups, kernel_size] = weight.shape().dims();

    let grad = B::conv1d(
        output_grad,
        weight,
        None,
        ConvOptions::new(
            options.stride,
            options.padding,
            options.dilation,
            options.groups,
        ),
    );

    if options.padding_out[0] == 0 {
        return grad;
    }

    let exp_length = calculate_conv_transpose_input_size(
        kernel_size,
        options.stride[0],
        options.padding[0],
        options.padding_out[0],
        options.dilation[0],
        out_length,
    );

    B::float_slice(
        grad,
        &[
            Slice::from(0..batch_size),
            Slice::from(0..c_in),
            Slice::from(0..exp_length),
        ],
    )
}

pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv_transpose1d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _] = x.shape().dims();
    let [_, channels_out, length_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let [batch_size, _c_out, out_h, out_w] = output_grad.shape().dims();
    let [c_in, _c_out_groups, k_h, k_w] = weight.shape().dims();

    let grad = B::conv2d(
        output_grad,
        weight,
        None,
        ConvOptions::new(
            options.stride,
            options.padding,
            options.dilation,
            options.groups,
        ),
    );

    if options.padding_out[0] == 0 && options.padding_out[1] == 0 {
        return grad;
    }

    let [exp_h, exp_w] = calculate_conv_transpose_input_sizes(
        [k_h, k_w],
        options.stride,
        options.padding,
        options.padding_out,
        options.dilation,
        [out_h, out_w],
    );

    B::float_slice(
        grad,
        &[
            Slice::from(0..batch_size),
            Slice::from(0..c_in),
            Slice::from(0..exp_h),
            Slice::from(0..exp_w),
        ],
    )
}

pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv_transpose2d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _, _] = x.shape().dims();
    let [_, channels_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([channels_out, batch_size * height_out * width_out]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let [batch_size, _c_out, out_d, out_h, out_w] = output_grad.shape().dims();
    let [c_in, _c_out_groups, k_d, k_h, k_w] = weight.shape().dims();

    let grad = B::conv3d(
        output_grad,
        weight,
        None,
        ConvOptions::new(
            options.stride,
            options.padding,
            options.dilation,
            options.groups,
        ),
    );

    if options.padding_out[0] == 0 && options.padding_out[1] == 0 && options.padding_out[2] == 0 {
        return grad;
    }

    let [exp_d, exp_h, exp_w] = calculate_conv_transpose_input_sizes(
        [k_d, k_h, k_w],
        options.stride,
        options.padding,
        options.padding_out,
        options.dilation,
        [out_d, out_h, out_w],
    );

    B::float_slice(
        grad,
        &[
            Slice::from(0..batch_size),
            Slice::from(0..c_in),
            Slice::from(0..exp_d),
            Slice::from(0..exp_h),
            Slice::from(0..exp_w),
        ],
    )
}

pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv_transpose3d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _, _, _] = x.shape().dims();
    let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([
            channels_out,
            batch_size * depth_out * height_out * width_out,
        ]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

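/// Execute a 1D convolution as a 2D convolution by appending a trailing
/// spatial dimension of size 1 to the input and the kernel.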
pub(crate) fn conv1d_from_conv2d<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let [channels_out, _channels_in, kernel_size] = weight.shape().dims();
    let [batch_size, channels_in, length_in] = x.shape().dims();

    let weight = B::float_reshape(
        weight,
        Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]),
    );
    let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));

    let tensor = B::conv2d(
        x,
        weight,
        bias,
        ConvOptions::new(
            [options.stride[0], 1],
            [options.padding[0], 0],
            [options.dilation[0], 1],
            options.groups,
        ),
    );
    let [batch_size, channels_out, height_out, _width_out] = tensor.shape().dims();
    B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}

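/// Execute a 1D transposed convolution as a 2D transposed convolution by
/// appending a trailing spatial dimension of size 1 to the input and the
/// kernel.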
pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [channels_in, channels_out, kernel_size] = weight.shape().dims();
    let [batch_size, _channels_in, length_in] = x.shape().dims();

    let weight = B::float_reshape(
        weight,
        Shape::new([channels_in, channels_out, kernel_size, 1]),
    );
    let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));

    let tensor = B::conv_transpose2d(
        x,
        weight,
        bias,
        ConvTransposeOptions::new(
            [options.stride[0], 1],
            [options.padding[0], 0],
            [options.padding_out[0], 0],
            [options.dilation[0], 1],
            options.groups,
        ),
    );
    let [batch_size, channels_out, height_out, _width_out] = tensor.shape().dims();
    B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}

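/// Weight gradient of an ungrouped 1D convolution, expressed as a convolution
/// of the input with the output gradient (batch and channel dimensions
/// swapped). The forward stride becomes the dilation of this convolution and
/// the forward dilation becomes its stride, which is why the options below
/// look swapped.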
fn conv1d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv1d(
        x_swapped,
        output_grad_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    if weight_grad.shape() != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv2d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv2d(
        x_swapped,
        output_grad_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    if weight_grad.shape() != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv3d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv3d(
        x_swapped,
        output_grad_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    if weight_grad.shape() != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
            Slice::from(0..weight_shape[4]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv1d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let [channels_out, increment_ci, kernel_size] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv1d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims();

        // The gradient convolution can overshoot the kernel size because of
        // flooring in the forward output size; truncate as in the 2D/3D
        // variants below.
        if kernel_size_tmp != kernel_size {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv2d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv2d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv3d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let [
        channels_out,
        increment_ci,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv3d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [
            _,
            _,
            kernel_size_1_tmp,
            kernel_size_2_tmp,
            kernel_size_3_tmp,
        ] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1
            || kernel_size_2_tmp != kernel_size_2
            || kernel_size_3_tmp != kernel_size_3
        {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv_transpose1d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv1d(
        output_grad_swapped,
        x_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    let grad_shape = weight_grad.shape();
    if grad_shape != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv_transpose2d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv2d(
        output_grad_swapped,
        x_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    let grad_shape = weight_grad.shape();
    if grad_shape != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv_transpose3d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv3d(
        output_grad_swapped,
        x_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    let grad_shape = weight_grad.shape();
    if grad_shape != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
            Slice::from(0..weight_shape[4]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv_transpose1d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [channels_in, increment_co, kernel_size] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv1d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims();

        if kernel_size_tmp != kernel_size {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv_transpose2d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let [channels_in, increment_co, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv2d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv_transpose3d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let [
        channels_in,
        increment_co,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv3d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [
            _,
            _,
            kernel_size_1_tmp,
            kernel_size_2_tmp,
            kernel_size_3_tmp,
        ] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1
            || kernel_size_2_tmp != kernel_size_2
            || kernel_size_3_tmp != kernel_size_3
        {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

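/// Calculate the output padding a transposed convolution needs so that an
/// output of `size_out` elements maps back to exactly `size_in` elements;
/// used by the `*_x_backward` functions above. With a stride of 1 the mapping
/// is already unambiguous, so no output padding is required.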
pub fn calculate_padding_out(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
    size_out: usize,
) -> usize {
    if stride <= 1 {
        return 0;
    }

    let dilated_kernel = dilation * (kernel_size - 1) + 1;
    let base = (size_out as i64 - 1) * stride as i64 + dilated_kernel as i64 - 2 * padding as i64;
    i64::max(0, size_in as i64 - base) as usize
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_output_size_1() {
        let kernel_size = 3;
        let stride = 1;
        let padding = 1;
        let size_in = 3;
        let dilation = 1;

        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, 3);
    }

    #[test]
    fn test_calculate_output_size_2() {
        let kernel_size = 5;
        let stride = 2;
        let padding = 3;
        let size_in = 27;
        let dilation = 1;

        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, 15);
    }

    #[test]
    fn test_calculate_output_size_3() {
        let kernel_size = 5;
        let stride = 2;
        let padding = 3;
        let size_in = 27;
        let dilation = 2;

        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, 13);
    }

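    #[test]
    fn test_calculate_pool_output_size_ceil_mode() {
        // Sanity check against the standard pooling formula, with values
        // chosen for illustration: floor((5 - 2) / 2) + 1 = 2 windows in
        // floor mode, ceil((5 - 2) / 2) + 1 = 3 windows in ceil mode.
        let kernel_size = 2;
        let stride = 2;
        let padding = 0;
        let dilation = 1;
        let size_in = 5;

        assert_eq!(
            calculate_pool_output_size(kernel_size, stride, padding, dilation, size_in, false),
            2
        );
        assert_eq!(
            calculate_pool_output_size(kernel_size, stride, padding, dilation, size_in, true),
            3
        );
    }
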
    #[test]
    fn test_calculate_same_padding_1() {
        let kernel_size = 3;
        let stride = 1;
        let size_in = 3;
        let dilation = 1;

        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_in, size_out, "Expected size");
    }

    #[test]
    fn test_calculate_same_padding_2() {
        let kernel_size = 3;
        let stride = 2;
        let size_in = 7;
        let dilation = 1;

        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_in, size_out, "Expected size");
    }

    #[test]
    fn test_calculate_output_padding_1() {
        let kernel_size = 3;
        let stride = 2;
        let size_in = 7;
        let size_out = 10;
        let dilation = 1;

        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out);
        let size_out_expected =
            calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, size_out_expected, "Expected size");
    }

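    #[test]
    fn test_conv_transpose_output_size_roundtrip() {
        // A transposed convolution inverts the spatial mapping of the
        // matching convolution, so feeding the output size back into
        // `calculate_conv_transpose_input_size` must recover the input size.
        // Values are illustrative.
        let kernel_size = 3;
        let stride = 2;
        let padding = 1;
        let padding_out = 1;
        let dilation = 1;
        let size_in = 7;

        let size_out = calculate_conv_transpose_output_size(
            kernel_size,
            stride,
            padding,
            padding_out,
            dilation,
            size_in,
        );
        assert_eq!(size_out, 14);

        let recovered = calculate_conv_transpose_input_size(
            kernel_size,
            stride,
            padding,
            padding_out,
            dilation,
            size_out,
        );
        assert_eq!(recovered, size_in);
    }
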
    #[test]
    fn test_expect_conv2d_output_shape() {
        let stride = [2, 1];
        let padding = [3, 1];
        let dilation = [2, 1];
        let shape = calculate_conv_output_shape(
            &Shape::new([12, 3, 27, 3]),
            &Shape::new([8, 3, 5, 3]),
            &stride,
            &padding,
            &dilation,
        )
        .unwrap();
        assert_eq!(shape, Shape::new([12, 8, 13, 3]))
    }
}