1use core::num::NonZeroUsize;
2
3use super::{conv, pool, unfold::unfold4d_using_conv2d};
4use crate::{
5 backend::Backend,
6 ops::{FloatTensor, IntTensor},
7 Shape, TensorMetadata,
8};
9
/// Gradients produced by the backward pass of a 2D convolution.
///
/// NOTE: `derive(new)` (assumed to be the `derive-new` crate) generates a constructor whose
/// parameters follow the field order below, so reordering the fields changes that API.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the convolution input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias; presumably `None` when the convolution has no bias.
    pub bias_grad: Option<FloatTensor<B>>,
}
22
/// Gradients returned by `ModuleOps::deform_conv2d_backward`.
///
/// NOTE: `derive(new)` (assumed to be the `derive-new` crate) generates a constructor whose
/// parameters follow the field order below, so reordering the fields changes that API.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling `offset` tensor.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the optional `mask` input; presumably `None` when no mask
    /// was given to the forward pass.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the optional bias; presumably `None` when no bias was given.
    pub bias_grad: Option<FloatTensor<B>>,
}
41
/// Gradients produced by the backward pass of a 3D convolution.
///
/// NOTE: `derive(new)` (assumed to be the `derive-new` crate) generates a constructor whose
/// parameters follow the field order below, so reordering the fields changes that API.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the convolution input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias; presumably `None` when the convolution has no bias.
    pub bias_grad: Option<FloatTensor<B>>,
}
54
/// Result of `ModuleOps::max_pool1d_with_indices_backward`.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the pooling input `x`.
    pub x_grad: FloatTensor<B>,
}
61
/// Result of `ModuleOps::max_pool1d_with_indices`: the pooled values plus the indices that
/// `max_pool1d_with_indices_backward` takes as its `indices` argument.
///
/// NOTE: `derive(new)` (assumed to be the `derive-new` crate) generates a constructor whose
/// parameters follow the field order below.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// Pooled output tensor.
    pub output: FloatTensor<B>,

    /// Integer indices recorded by the pooling; presumably the positions of the selected
    /// maxima in the input — confirm against the backend implementation.
    pub indices: IntTensor<B>,
}
71
/// Result of `ModuleOps::max_pool2d_with_indices_backward`.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the pooling input `x`.
    pub x_grad: FloatTensor<B>,
}
78
/// Result of `ModuleOps::max_pool2d_with_indices`: the pooled values plus the indices that
/// `max_pool2d_with_indices_backward` takes as its `indices` argument.
///
/// NOTE: `derive(new)` (assumed to be the `derive-new` crate) generates a constructor whose
/// parameters follow the field order below.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// Pooled output tensor.
    pub output: FloatTensor<B>,

    /// Integer indices recorded by the pooling; presumably the positions of the selected
    /// maxima in the input — confirm against the backend implementation.
    pub indices: IntTensor<B>,
}
88
/// Returns `value` unchanged after checking that it is not zero.
///
/// Used by the option constructors in this module to validate strides, dilations and
/// group counts.
///
/// # Panics
///
/// Panics with `msg` as the panic message when `value` is `0`. `#[track_caller]` makes the
/// reported panic location the call site rather than this helper.
#[track_caller]
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // Return the value extracted from the checked wrapper so the result is tied to the
    // validation instead of discarding it and returning the raw input.
    NonZeroUsize::new(value).expect(msg).get()
}
95
/// Hyper-parameters for the `N`-dimensional convolutions in `ModuleOps`
/// (`ConvOptions<1>` for `conv1d`, `ConvOptions<2>` for `conv2d`, `ConvOptions<3>` for `conv3d`,
/// and their backward passes).
///
/// Build it with `ConvOptions::new`, which rejects zero strides, dilations and groups.
/// The fields are public, so a struct literal bypasses that validation.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride for each of the `N` convolution dimensions; non-zero when built via `new`.
    pub stride: [usize; N],

    /// Padding for each of the `N` convolution dimensions; zero is allowed.
    pub padding: [usize; N],

    /// Dilation for each of the `N` convolution dimensions; non-zero when built via `new`.
    pub dilation: [usize; N],

    /// Number of groups; non-zero when built via `new`.
    pub groups: usize,
}
111
112impl<const N: usize> ConvOptions<N> {
113 pub fn new(
115 stride: [usize; N],
116 padding: [usize; N],
117 dilation: [usize; N],
118 groups: usize,
119 ) -> Self {
120 Self {
121 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
122 padding,
123 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
124 groups: check_nonzero(groups, "groups must be non-zero"),
125 }
126 }
127}
128
/// Hyper-parameters for deformable convolution (`ModuleOps::deform_conv2d` uses
/// `DeformConvOptions<2>`).
///
/// Build it with `DeformConvOptions::new`, which rejects zero strides, dilations and group
/// counts. The fields are public, so a struct literal bypasses that validation.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride for each of the `N` convolution dimensions; non-zero when built via `new`.
    pub stride: [usize; N],

    /// Padding for each of the `N` convolution dimensions; zero is allowed.
    pub padding: [usize; N],

    /// Dilation for each of the `N` convolution dimensions; non-zero when built via `new`.
    pub dilation: [usize; N],

    /// Number of weight groups; non-zero when built via `new`.
    pub weight_groups: usize,

    /// Number of offset groups; non-zero when built via `new`.
    pub offset_groups: usize,
}
147
148impl<const N: usize> DeformConvOptions<N> {
149 pub fn new(
151 stride: [usize; N],
152 padding: [usize; N],
153 dilation: [usize; N],
154 weight_groups: usize,
155 offset_groups: usize,
156 ) -> Self {
157 Self {
158 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
159 padding,
160 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
161 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
162 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
163 }
164 }
165}
166
/// Hyper-parameters for the `N`-dimensional transposed convolutions in `ModuleOps`
/// (`conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d` and their backward passes).
///
/// Build it with `ConvTransposeOptions::new`, which rejects zero strides, dilations and
/// groups. The fields are public, so a struct literal bypasses that validation.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride for each of the `N` convolution dimensions; non-zero when built via `new`.
    pub stride: [usize; N],

    /// Padding for each of the `N` convolution dimensions; zero is allowed.
    pub padding: [usize; N],

    /// Output-side padding for each dimension (meaning inferred from the name — confirm
    /// against the transposed-convolution helpers); zero is allowed.
    pub padding_out: [usize; N],

    /// Dilation for each of the `N` convolution dimensions; non-zero when built via `new`.
    pub dilation: [usize; N],

    /// Number of groups; non-zero when built via `new`.
    pub groups: usize,
}
185
186impl<const N: usize> ConvTransposeOptions<N> {
187 pub fn new(
189 stride: [usize; N],
190 padding: [usize; N],
191 padding_out: [usize; N],
192 dilation: [usize; N],
193 groups: usize,
194 ) -> Self {
195 Self {
196 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
197 padding,
198 padding_out,
199 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
200 groups: check_nonzero(groups, "groups must be non-zero"),
201 }
202 }
203}
204
/// Options for `ModuleOps::unfold4d`, one entry per each of the two unfolded dimensions.
///
/// Build it with `UnfoldOptions::new` to get `stride` and `dilation` validated as non-zero.
/// The fields are public, so a struct literal bypasses that check.
///
/// Derives the same comparison and hashing traits as the other option types in this module
/// (`ConvOptions`, `DeformConvOptions`, `ConvTransposeOptions`), so it can be compared,
/// hashed and deduplicated like them.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct UnfoldOptions {
    /// Stride for each of the two dimensions; non-zero when built via `new`.
    pub stride: [usize; 2],

    /// Padding for each of the two dimensions; zero is allowed.
    pub padding: [usize; 2],

    /// Dilation for each of the two dimensions; non-zero when built via `new`.
    pub dilation: [usize; 2],
}
218
219impl UnfoldOptions {
220 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
222 Self {
223 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
224 padding,
225 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
226 }
227 }
228}
229
/// Interpolation algorithm, selected through `InterpolateOptions::mode` and consumed by
/// `ModuleOps::interpolate` / `ModuleOps::interpolate_backward`.
///
/// NOTE: serde formats that encode enums by variant index depend on declaration order,
/// so avoid reordering the variants.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}
245
/// Options passed to `ModuleOps::interpolate` and `ModuleOps::interpolate_backward`.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm to use.
    pub mode: InterpolateMode,
}
252
/// Gradient container for an interpolation backward pass.
///
/// NOTE(review): `ModuleOps::interpolate_backward` in this file returns a bare `FloatTensor`,
/// not this struct; its users are outside this view.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the interpolation input `x`.
    pub x_grad: FloatTensor<B>,
}
259
260pub trait ModuleOps<B: Backend> {
262 fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
273 let [batch_size, seq_length] = indices.shape().dims();
274 let [_, d_model] = weights.shape().dims();
275
276 let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
277 let output = B::float_select(weights, 0, indices);
278
279 B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
280 }
281
282 fn embedding_backward(
294 weights: FloatTensor<B>,
295 output_grad: FloatTensor<B>,
296 indices: IntTensor<B>,
297 ) -> FloatTensor<B> {
298 let [batch_size, seq_length] = indices.shape().dims();
299 let [n_embeddings, d_model] = weights.shape().dims();
300 let device = B::float_device(&weights);
301
302 let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
303 let output_grad =
304 B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
305 let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device);
306
307 B::float_select_assign(grad, 0, indices, output_grad)
308 }
309 fn conv1d(
317 x: FloatTensor<B>,
318 weight: FloatTensor<B>,
319 bias: Option<FloatTensor<B>>,
320 options: ConvOptions<1>,
321 ) -> FloatTensor<B> {
322 conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
323 }
324 fn conv1d_x_backward(
326 x: FloatTensor<B>,
327 weight: FloatTensor<B>,
328 output_grad: FloatTensor<B>,
329 options: ConvOptions<1>,
330 ) -> FloatTensor<B> {
331 conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
332 }
333 fn conv1d_weight_backward(
335 x: FloatTensor<B>,
336 weight: FloatTensor<B>,
337 output_grad: FloatTensor<B>,
338 options: ConvOptions<1>,
339 ) -> FloatTensor<B> {
340 conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
341 }
342 fn conv1d_bias_backward(
344 x: FloatTensor<B>,
345 bias: FloatTensor<B>,
346 output_grad: FloatTensor<B>,
347 ) -> FloatTensor<B> {
348 conv::conv1d_bias_backward::<B>(x, bias, output_grad)
349 }
350 fn conv2d(
358 x: FloatTensor<B>,
359 weight: FloatTensor<B>,
360 bias: Option<FloatTensor<B>>,
361 options: ConvOptions<2>,
362 ) -> FloatTensor<B>;
363 fn conv2d_x_backward(
365 x: FloatTensor<B>,
366 weight: FloatTensor<B>,
367 output_grad: FloatTensor<B>,
368 options: ConvOptions<2>,
369 ) -> FloatTensor<B> {
370 conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
371 }
372 fn conv2d_weight_backward(
374 x: FloatTensor<B>,
375 weight: FloatTensor<B>,
376 output_grad: FloatTensor<B>,
377 options: ConvOptions<2>,
378 ) -> FloatTensor<B> {
379 conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
380 }
381 fn conv2d_bias_backward(
383 x: FloatTensor<B>,
384 weight: FloatTensor<B>,
385 bias: FloatTensor<B>,
386 output_grad: FloatTensor<B>,
387 ) -> FloatTensor<B> {
388 conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
389 }
390
391 fn deform_conv2d(
399 x: FloatTensor<B>,
400 offset: FloatTensor<B>,
401 weight: FloatTensor<B>,
402 mask: Option<FloatTensor<B>>,
403 bias: Option<FloatTensor<B>>,
404 options: DeformConvOptions<2>,
405 ) -> FloatTensor<B>;
406 fn deform_conv2d_backward(
408 x: FloatTensor<B>,
409 offset: FloatTensor<B>,
410 weight: FloatTensor<B>,
411 mask: Option<FloatTensor<B>>,
412 bias: Option<FloatTensor<B>>,
413 output_grad: FloatTensor<B>,
414 options: DeformConvOptions<2>,
415 ) -> DeformConv2dBackward<B>;
416
417 fn conv3d(
425 x: FloatTensor<B>,
426 weight: FloatTensor<B>,
427 bias: Option<FloatTensor<B>>,
428 options: ConvOptions<3>,
429 ) -> FloatTensor<B>;
430 fn conv3d_x_backward(
432 x: FloatTensor<B>,
433 weight: FloatTensor<B>,
434 output_grad: FloatTensor<B>,
435 options: ConvOptions<3>,
436 ) -> FloatTensor<B> {
437 conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
438 }
439 fn conv3d_weight_backward(
441 x: FloatTensor<B>,
442 weight: FloatTensor<B>,
443 output_grad: FloatTensor<B>,
444 options: ConvOptions<3>,
445 ) -> FloatTensor<B> {
446 conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
447 }
448 fn conv3d_bias_backward(
450 x: FloatTensor<B>,
451 weight: FloatTensor<B>,
452 bias: FloatTensor<B>,
453 output_grad: FloatTensor<B>,
454 ) -> FloatTensor<B> {
455 conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
456 }
457 fn conv_transpose1d(
465 x: FloatTensor<B>,
466 weight: FloatTensor<B>,
467 bias: Option<FloatTensor<B>>,
468 options: ConvTransposeOptions<1>,
469 ) -> FloatTensor<B> {
470 conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
471 }
472 fn conv_transpose1d_x_backward(
474 weight: FloatTensor<B>,
475 output_grad: FloatTensor<B>,
476 options: ConvTransposeOptions<1>,
477 ) -> FloatTensor<B> {
478 conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
479 }
480 fn conv_transpose1d_weight_backward(
482 x: FloatTensor<B>,
483 weight: FloatTensor<B>,
484 output_grad: FloatTensor<B>,
485 options: ConvTransposeOptions<1>,
486 ) -> FloatTensor<B> {
487 conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
488 }
489 fn conv_transpose1d_bias_backward(
491 x: FloatTensor<B>,
492 bias: FloatTensor<B>,
493 output_grad: FloatTensor<B>,
494 ) -> FloatTensor<B> {
495 conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
496 }
497
498 fn conv_transpose2d(
506 x: FloatTensor<B>,
507 weight: FloatTensor<B>,
508 bias: Option<FloatTensor<B>>,
509 options: ConvTransposeOptions<2>,
510 ) -> FloatTensor<B>;
511 fn conv_transpose2d_x_backward(
513 weight: FloatTensor<B>,
514 output_grad: FloatTensor<B>,
515 options: ConvTransposeOptions<2>,
516 ) -> FloatTensor<B> {
517 conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
518 }
519 fn conv_transpose2d_weight_backward(
521 x: FloatTensor<B>,
522 weight: FloatTensor<B>,
523 output_grad: FloatTensor<B>,
524 options: ConvTransposeOptions<2>,
525 ) -> FloatTensor<B> {
526 conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
527 }
528 fn conv_transpose2d_bias_backward(
530 x: FloatTensor<B>,
531 bias: FloatTensor<B>,
532 output_grad: FloatTensor<B>,
533 ) -> FloatTensor<B> {
534 conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
535 }
536
537 fn conv_transpose3d(
545 x: FloatTensor<B>,
546 weight: FloatTensor<B>,
547 bias: Option<FloatTensor<B>>,
548 options: ConvTransposeOptions<3>,
549 ) -> FloatTensor<B>;
550 fn conv_transpose3d_x_backward(
552 weight: FloatTensor<B>,
553 output_grad: FloatTensor<B>,
554 options: ConvTransposeOptions<3>,
555 ) -> FloatTensor<B> {
556 conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
557 }
558 fn conv_transpose3d_weight_backward(
560 x: FloatTensor<B>,
561 weight: FloatTensor<B>,
562 output_grad: FloatTensor<B>,
563 options: ConvTransposeOptions<3>,
564 ) -> FloatTensor<B> {
565 conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
566 }
567 fn conv_transpose3d_bias_backward(
569 x: FloatTensor<B>,
570 bias: FloatTensor<B>,
571 output_grad: FloatTensor<B>,
572 ) -> FloatTensor<B> {
573 conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
574 }
575
576 fn unfold4d(
583 x: FloatTensor<B>,
584 kernel_size: [usize; 2],
585 options: UnfoldOptions,
586 ) -> FloatTensor<B> {
587 unfold4d_using_conv2d::<B>(x, kernel_size, options)
588 }
589
590 fn avg_pool1d(
596 x: FloatTensor<B>,
597 kernel_size: usize,
598 stride: usize,
599 padding: usize,
600 count_include_pad: bool,
601 ) -> FloatTensor<B> {
602 pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
603 }
604 fn avg_pool1d_backward(
606 x: FloatTensor<B>,
607 grad: FloatTensor<B>,
608 kernel_size: usize,
609 stride: usize,
610 padding: usize,
611 count_include_pad: bool,
612 ) -> FloatTensor<B> {
613 pool::avg_pool1d_backward_from_2d::<B>(
614 x,
615 grad,
616 kernel_size,
617 stride,
618 padding,
619 count_include_pad,
620 )
621 }
622 fn avg_pool2d(
628 x: FloatTensor<B>,
629 kernel_size: [usize; 2],
630 stride: [usize; 2],
631 padding: [usize; 2],
632 count_include_pad: bool,
633 ) -> FloatTensor<B>;
634 fn avg_pool2d_backward(
636 x: FloatTensor<B>,
637 grad: FloatTensor<B>,
638 kernel_size: [usize; 2],
639 stride: [usize; 2],
640 padding: [usize; 2],
641 count_include_pad: bool,
642 ) -> FloatTensor<B>;
643 fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
649 fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
651 fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
657 pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
658 }
659 fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
661 pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
662 }
663 fn max_pool1d(
669 x: FloatTensor<B>,
670 kernel_size: usize,
671 stride: usize,
672 padding: usize,
673 dilation: usize,
674 ) -> FloatTensor<B> {
675 pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
676 }
677
678 fn max_pool1d_with_indices(
684 x: FloatTensor<B>,
685 kernel_size: usize,
686 stride: usize,
687 padding: usize,
688 dilation: usize,
689 ) -> MaxPool1dWithIndices<B> {
690 pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
691 }
692 fn max_pool1d_with_indices_backward(
694 x: FloatTensor<B>,
695 kernel_size: usize,
696 stride: usize,
697 padding: usize,
698 dilation: usize,
699 output_grad: FloatTensor<B>,
700 indices: IntTensor<B>,
701 ) -> MaxPool1dBackward<B> {
702 pool::max_pool1d_with_indices_backward_from_2d::<B>(
703 x,
704 kernel_size,
705 stride,
706 padding,
707 dilation,
708 output_grad,
709 indices,
710 )
711 }
712
713 fn max_pool2d(
719 x: FloatTensor<B>,
720 kernel_size: [usize; 2],
721 stride: [usize; 2],
722 padding: [usize; 2],
723 dilation: [usize; 2],
724 ) -> FloatTensor<B>;
725
726 fn max_pool2d_with_indices(
732 x: FloatTensor<B>,
733 kernel_size: [usize; 2],
734 stride: [usize; 2],
735 padding: [usize; 2],
736 dilation: [usize; 2],
737 ) -> MaxPool2dWithIndices<B>;
738 fn max_pool2d_with_indices_backward(
740 x: FloatTensor<B>,
741 kernel_size: [usize; 2],
742 stride: [usize; 2],
743 padding: [usize; 2],
744 dilation: [usize; 2],
745 output_grad: FloatTensor<B>,
746 indices: IntTensor<B>,
747 ) -> MaxPool2dBackward<B>;
748
749 fn interpolate(
755 x: FloatTensor<B>,
756 output_size: [usize; 2],
757 options: InterpolateOptions,
758 ) -> FloatTensor<B>;
759
760 fn interpolate_backward(
762 x: FloatTensor<B>,
763 grad: FloatTensor<B>,
764 output_size: [usize; 2],
765 options: InterpolateOptions,
766 ) -> FloatTensor<B>;
767}
768
/// Validation tests for the option constructors.
///
/// The `*_zero` tests pin the panic message for every rejected argument. The
/// `*_keeps_valid_values` tests pin the non-panicking path: valid arguments, including zero
/// padding, must be stored unchanged and in the right fields.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    fn conv_options_keeps_valid_values() {
        let opt = ConvOptions::new([2, 3], [0, 1], [1, 2], 4);
        assert_eq!(opt.stride, [2, 3]);
        assert_eq!(opt.padding, [0, 1]);
        assert_eq!(opt.dilation, [1, 2]);
        assert_eq!(opt.groups, 4);

        // `N` is a const generic; other ranks go through the same checks.
        let opt_1d = ConvOptions::new([1], [0], [1], 1);
        assert_eq!(
            opt_1d,
            ConvOptions {
                stride: [1],
                padding: [0],
                dilation: [1],
                groups: 1,
            }
        );
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    fn conv_transpose_options_keeps_valid_values() {
        // Distinct `padding` / `padding_out` values catch an accidental field swap.
        let opt = ConvTransposeOptions::new([2, 1], [1, 0], [0, 2], [1, 3], 2);
        assert_eq!(opt.stride, [2, 1]);
        assert_eq!(opt.padding, [1, 0]);
        assert_eq!(opt.padding_out, [0, 2]);
        assert_eq!(opt.dilation, [1, 3]);
        assert_eq!(opt.groups, 2);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    fn deform_conv_options_keeps_valid_values() {
        let opt = DeformConvOptions::new([1, 2], [3, 0], [2, 1], 2, 3);
        assert_eq!(opt.stride, [1, 2]);
        assert_eq!(opt.padding, [3, 0]);
        assert_eq!(opt.dilation, [2, 1]);
        assert_eq!(opt.weight_groups, 2);
        assert_eq!(opt.offset_groups, 3);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }

    #[test]
    fn unfold_options_keeps_valid_values() {
        let opt = UnfoldOptions::new([1, 2], [0, 3], [2, 1]);
        assert_eq!(opt.stride, [1, 2]);
        assert_eq!(opt.padding, [0, 3]);
        assert_eq!(opt.dilation, [2, 1]);
    }
}