1use super::{conv, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::{
4 Shape, TensorMetadata,
5 backend::Backend,
6 ops::{FloatTensor, IntTensor},
7};
8use core::num::NonZeroUsize;
9
/// Gradients returned by the 2D convolution backward pass.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
22
/// Gradients returned by the deformable 2D convolution backward pass.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the mask, when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
41
/// Gradients returned by the 3D convolution backward pass.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
54
/// Gradient returned by the 1D max pooling backward pass.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
61
/// Result of the 1D max pooling forward pass when indices are requested.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Indices of the selected maxima (used by the backward pass).
    pub indices: IntTensor<B>,
}
71
/// Gradient returned by the 2D max pooling backward pass.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
78
/// Result of the 2D max pooling forward pass when indices are requested.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Indices of the selected maxima (used by the backward pass).
    pub indices: IntTensor<B>,
}
88
/// Validates that `value` is non-zero, returning it unchanged.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // `NonZeroUsize::new` fails exactly when the value is zero, so it doubles
    // as the validity check here.
    match NonZeroUsize::new(value) {
        Some(nonzero) => nonzero.get(),
        None => panic!("{msg}"),
    }
}
95
/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride for each spatial dimension (validated non-zero by `new`).
    pub stride: [usize; N],

    /// Padding for each spatial dimension (zero is allowed).
    pub padding: [usize; N],

    /// Dilation for each spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; N],

    /// Number of groups (validated non-zero by `new`).
    pub groups: usize,
}
111
112impl<const N: usize> ConvOptions<N> {
113 pub fn new(
115 stride: [usize; N],
116 padding: [usize; N],
117 dilation: [usize; N],
118 groups: usize,
119 ) -> Self {
120 Self {
121 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
122 padding,
123 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
124 groups: check_nonzero(groups, "groups must be non-zero"),
125 }
126 }
127}
128
/// Deformable convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride for each spatial dimension (validated non-zero by `new`).
    pub stride: [usize; N],

    /// Padding for each spatial dimension (zero is allowed).
    pub padding: [usize; N],

    /// Dilation for each spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; N],

    /// Number of weight groups (validated non-zero by `new`).
    pub weight_groups: usize,

    /// Number of offset groups (validated non-zero by `new`).
    pub offset_groups: usize,
}
147
148impl<const N: usize> DeformConvOptions<N> {
149 pub fn new(
151 stride: [usize; N],
152 padding: [usize; N],
153 dilation: [usize; N],
154 weight_groups: usize,
155 offset_groups: usize,
156 ) -> Self {
157 Self {
158 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
159 padding,
160 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
161 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
162 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
163 }
164 }
165}
166
/// Transposed convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride for each spatial dimension (validated non-zero by `new`).
    pub stride: [usize; N],

    /// Padding for each spatial dimension (zero is allowed).
    pub padding: [usize; N],

    /// Output padding for each spatial dimension (zero is allowed).
    pub padding_out: [usize; N],

    /// Dilation for each spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; N],

    /// Number of groups (validated non-zero by `new`).
    pub groups: usize,
}
185
186impl<const N: usize> ConvTransposeOptions<N> {
187 pub fn new(
189 stride: [usize; N],
190 padding: [usize; N],
191 padding_out: [usize; N],
192 dilation: [usize; N],
193 groups: usize,
194 ) -> Self {
195 Self {
196 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
197 padding,
198 padding_out,
199 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
200 groups: check_nonzero(groups, "groups must be non-zero"),
201 }
202 }
203}
204
/// Unfold (im2col) options.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride for each spatial dimension (validated non-zero by `new`).
    pub stride: [usize; 2],

    /// Padding for each spatial dimension (zero is allowed).
    pub padding: [usize; 2],

    /// Dilation for each spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; 2],
}
218
219impl UnfoldOptions {
220 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
222 Self {
223 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
224 padding,
225 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
226 }
227 }
228}
229
/// Algorithm used for interpolation (resizing).
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}
245
/// Interpolation options.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// The interpolation algorithm to use.
    pub mode: InterpolateMode,
}
252
/// Gradient returned by the interpolation backward pass.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
259
/// Module operations a [`Backend`] must provide: embedding, convolution
/// (regular, deformable and transposed), unfolding, pooling, and
/// interpolation.
///
/// Default implementations are supplied wherever the operation can be
/// expressed through other backend ops (e.g. the 1D variants are built from
/// their 2D counterparts via the `conv`/`pool` helper modules); backends only
/// have to implement the remaining required methods.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup.
    ///
    /// Gathers the rows of `weights` (shape `[n_embeddings, d_model]`)
    /// selected by `indices` (shape `[batch_size, seq_length]`), returning a
    /// tensor of shape `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten indices so a single `select` along dim 0 performs the lookup.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Backward pass of [`embedding`](ModuleOps::embedding).
    ///
    /// Scatter-adds `output_grad` into a zero tensor shaped like `weights`,
    /// so rows selected multiple times accumulate their gradients.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        // The gradient is created with the output gradient's dtype.
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_assign(grad, 0, indices, output_grad)
    }

    /// 1D convolution. Default: implemented via `conv2d`.
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }

    /// Backward pass of `conv1d` with respect to the input.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv1d` with respect to the weight.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv1d` with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D convolution. Required: backends must provide this directly.
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;

    /// Backward pass of `conv2d` with respect to the input.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv2d` with respect to the weight.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv2d` with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Deformable 2D convolution. Required: backends must provide this
    /// directly.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;

    /// Backward pass of `deform_conv2d`, returning the gradients for every
    /// differentiable argument at once. Required.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// 3D convolution. Required: backends must provide this directly.
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;

    /// Backward pass of `conv3d` with respect to the input.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv3d` with respect to the weight.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv3d` with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// 1D transposed convolution. Default: implemented via
    /// `conv_transpose2d`.
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }

    /// Backward pass of `conv_transpose1d` with respect to the input.
    /// Note: the input itself is not needed for this gradient.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }

    /// Backward pass of `conv_transpose1d` with respect to the weight.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv_transpose1d` with respect to the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D transposed convolution. Required: backends must provide this
    /// directly.
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;

    /// Backward pass of `conv_transpose2d` with respect to the input.
    /// Note: the input itself is not needed for this gradient.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }

    /// Backward pass of `conv_transpose2d` with respect to the weight.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv_transpose2d` with respect to the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 3D transposed convolution. Required: backends must provide this
    /// directly.
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;

    /// Backward pass of `conv_transpose3d` with respect to the input.
    /// Note: the input itself is not needed for this gradient.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }

    /// Backward pass of `conv_transpose3d` with respect to the weight.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass of `conv_transpose3d` with respect to the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Unfold (im2col): extracts sliding local blocks of `kernel_size` from
    /// `x` and flattens each block into a column of the output.
    ///
    /// The unpadded, undilated case is handled directly with `float_unfold`;
    /// otherwise it falls back to a conv2d-based implementation.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            // Unfold the two spatial dims (2 and 3) one after the other;
            // each call appends a kernel dimension at the end.
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Move the block dims after channel/kernel dims before flattening.
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            // Collapse to [batch, channels * kh * kw, n_blocks].
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            // General case: padding/dilation handled via an equivalent conv2d.
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// 1D average pooling. Default: implemented via the 2D variant.
    ///
    /// `count_include_pad` controls whether padded positions count toward the
    /// averaging divisor.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
    }

    /// Backward pass of `avg_pool1d`. Default: implemented via the 2D variant.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
        )
    }

    /// 2D average pooling. Required: backends must provide this directly.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;

    /// Backward pass of `avg_pool2d`. Required.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;

    /// 2D adaptive average pooling to the given `output_size`. Required.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;

    /// Backward pass of `adaptive_avg_pool2d`. Required.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;

    /// 1D adaptive average pooling. Default: implemented via the 2D variant.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }

    /// Backward pass of `adaptive_avg_pool1d`. Default: via the 2D variant.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }

    /// 1D max pooling. Default: implemented via the 2D variant.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }

    /// 1D max pooling that also returns the indices of the maxima.
    /// Default: implemented via the 2D variant.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }

    /// Backward pass of `max_pool1d_with_indices`, routing the output
    /// gradient back through the recorded indices. Default: via the 2D
    /// variant.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        )
    }

    /// 2D max pooling. Required: backends must provide this directly.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<B>;

    /// 2D max pooling that also returns the indices of the maxima. Required.
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<B>;

    /// Backward pass of `max_pool2d_with_indices`. Required.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Interpolates (resizes) `x` to `output_size` using the algorithm in
    /// `options`. Required.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of `interpolate`. Required.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
}
791
#[cfg(test)]
mod tests {
    use super::*;

    // Every option constructor must reject zero-valued stride, dilation and
    // group counts; the panic message is asserted via `should_panic`'s
    // substring match.

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _ = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _ = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _ = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _ = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _ = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}