1use super::{conv, ctc, linear, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
4use crate::{Backend, ElementConversion, TensorMetadata};
5use burn_std::Shape;
6use core::num::NonZeroUsize;
7
/// Gradients returned by the backward pass of a 2D convolution.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
20
/// Gradients returned by the backward pass of a deformable 2D convolution.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the mask, if a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
39
/// Gradients returned by the backward pass of a 3D convolution.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
52
/// Gradient returned by the backward pass of a 1D max pooling.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
59
/// Result of a 1D max pooling that also tracks the argmax positions.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Integer indices of the selected (maximal) elements, used by the
    /// backward pass.
    pub indices: IntTensor<B>,
}
69
/// Gradient returned by the backward pass of a 2D max pooling.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
76
/// Result of a 2D max pooling that also tracks the argmax positions.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Integer indices of the selected (maximal) elements, used by the
    /// backward pass.
    pub indices: IntTensor<B>,
}
86
/// Validates that `value` is non-zero and returns it unchanged.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // Return through the validated `NonZeroUsize` instead of constructing one
    // only to discard it and hand back the raw input.
    NonZeroUsize::new(value).expect(msg).get()
}
93
/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`ConvOptions::new`]).
    pub stride: [usize; N],

    /// Padding per spatial dimension (zero is allowed).
    pub padding: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`ConvOptions::new`]).
    pub dilation: [usize; N],

    /// Number of groups (validated non-zero by [`ConvOptions::new`]).
    pub groups: usize,
}
109
110impl<const N: usize> ConvOptions<N> {
111 pub fn new(
113 stride: [usize; N],
114 padding: [usize; N],
115 dilation: [usize; N],
116 groups: usize,
117 ) -> Self {
118 Self {
119 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
120 padding,
121 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
122 groups: check_nonzero(groups, "groups must be non-zero"),
123 }
124 }
125}
126
/// Convolution options extended with an optional, possibly different,
/// end padding (asymmetric padding).
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// Base convolution options; `options.padding` holds the start padding.
    pub options: ConvOptions<N>,
    /// End padding when it differs from the start padding; `None` means the
    /// padding is symmetric.
    pub padding_end: Option<[usize; N]>,
}
144
145impl<const N: usize> PaddedConvOptions<N> {
146 pub fn asymmetric(
151 stride: [usize; N],
152 padding_start: [usize; N],
153 padding_end: [usize; N],
154 dilation: [usize; N],
155 groups: usize,
156 ) -> Self {
157 let options = ConvOptions::new(stride, padding_start, dilation, groups);
158 if padding_start == padding_end {
159 Self {
160 options,
161 padding_end: None,
162 }
163 } else {
164 Self {
165 options,
166 padding_end: Some(padding_end),
167 }
168 }
169 }
170
171 pub fn is_asymmetric(&self) -> bool {
173 self.padding_end.is_some()
174 }
175}
176
177impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
178 fn from(options: ConvOptions<N>) -> Self {
179 Self {
180 options,
181 padding_end: None,
182 }
183 }
184}
185
/// Deformable convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`DeformConvOptions::new`]).
    pub stride: [usize; N],

    /// Padding per spatial dimension (zero is allowed).
    pub padding: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`DeformConvOptions::new`]).
    pub dilation: [usize; N],

    /// Number of weight groups (validated non-zero by [`DeformConvOptions::new`]).
    pub weight_groups: usize,

    /// Number of offset groups (validated non-zero by [`DeformConvOptions::new`]).
    pub offset_groups: usize,
}
204
205impl<const N: usize> DeformConvOptions<N> {
206 pub fn new(
208 stride: [usize; N],
209 padding: [usize; N],
210 dilation: [usize; N],
211 weight_groups: usize,
212 offset_groups: usize,
213 ) -> Self {
214 Self {
215 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
216 padding,
217 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
218 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
219 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
220 }
221 }
222}
223
/// Transposed convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`ConvTransposeOptions::new`]).
    pub stride: [usize; N],

    /// Input padding per spatial dimension (zero is allowed).
    pub padding: [usize; N],

    /// Output padding per spatial dimension (zero is allowed).
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`ConvTransposeOptions::new`]).
    pub dilation: [usize; N],

    /// Number of groups (validated non-zero by [`ConvTransposeOptions::new`]).
    pub groups: usize,
}
242
243impl<const N: usize> ConvTransposeOptions<N> {
244 pub fn new(
246 stride: [usize; N],
247 padding: [usize; N],
248 padding_out: [usize; N],
249 dilation: [usize; N],
250 groups: usize,
251 ) -> Self {
252 Self {
253 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
254 padding,
255 padding_out,
256 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
257 groups: check_nonzero(groups, "groups must be non-zero"),
258 }
259 }
260}
261
/// Unfold (im2col) options for 2D inputs.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride per spatial dimension (validated non-zero by [`UnfoldOptions::new`]).
    pub stride: [usize; 2],

    /// Padding per spatial dimension (zero is allowed).
    pub padding: [usize; 2],

    /// Dilation per spatial dimension (validated non-zero by [`UnfoldOptions::new`]).
    pub dilation: [usize; 2],
}
275
276impl UnfoldOptions {
277 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
279 Self {
280 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
281 padding,
282 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
283 }
284 }
285}
286
/// Interpolation algorithm used for resizing / sampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,

    /// Lanczos interpolation with a window of 3.
    Lanczos3,
}
306
/// Options for the interpolate (resize) operation.
#[derive(Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm.
    pub mode: InterpolateMode,
    /// Whether corner pixels are aligned between input and output.
    /// NOTE(review): [`InterpolateOptions::new`] defaults this to `true`,
    /// while [`GridSampleOptions`] defaults to `false` — confirm this
    /// asymmetry is intentional.
    pub align_corners: bool,
}
316
317impl InterpolateOptions {
318 pub fn new(mode: InterpolateMode) -> Self {
321 Self {
322 mode,
323 align_corners: true,
324 }
325 }
326
327 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
329 self.align_corners = align_corners;
330 self
331 }
332}
333
/// How grid-sample handles coordinates that fall outside the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds samples read as zero.
    #[default]
    Zeros,
    /// Out-of-bounds samples clamp to the border values.
    Border,
    /// Out-of-bounds samples reflect back into the input.
    Reflection,
}
347
/// Options for the grid-sample operation.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation algorithm used when sampling between pixels.
    pub mode: InterpolateMode,
    /// Behavior for coordinates outside the input bounds.
    pub padding_mode: GridSamplePaddingMode,
    /// Whether corner pixels are aligned between input and sampling grid
    /// (defaults to `false`).
    pub align_corners: bool,
}
360
361impl Default for GridSampleOptions {
362 fn default() -> Self {
363 Self {
364 mode: InterpolateMode::Bilinear,
365 padding_mode: GridSamplePaddingMode::Zeros,
366 align_corners: false,
367 }
368 }
369}
370
371impl From<InterpolateMode> for GridSampleOptions {
372 fn from(value: InterpolateMode) -> Self {
373 GridSampleOptions::new(value)
374 }
375}
376
377impl GridSampleOptions {
378 pub fn new(mode: InterpolateMode) -> Self {
382 Self {
383 mode,
384 ..Default::default()
385 }
386 }
387
388 pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
390 self.padding_mode = padding_mode;
391 self
392 }
393
394 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
396 self.align_corners = align_corners;
397 self
398 }
399}
400
/// Padding strategy for the pad operation.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),

    /// Pad by reflecting the input around its border.
    Reflect,

    /// Pad by repeating the edge values.
    Edge,
}
436
437impl Default for PadMode {
438 fn default() -> Self {
439 PadMode::Constant(0.0)
440 }
441}
442
/// Converts any element-convertible scalar into constant padding with that
/// value.
///
/// NOTE(review): this blanket impl accepts every `E: ElementConversion`, so
/// any numeric literal coerces into `PadMode::Constant` — convenient, but it
/// prevents other `From` impls for `PadMode`.
impl<E: ElementConversion> From<E> for PadMode {
    fn from(value: E) -> Self {
        PadMode::Constant(value.elem())
    }
}
448
/// Gradient returned by the backward pass of the interpolate operation.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
455
/// Options for the scaled dot-product attention operation.
#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct AttentionModuleOptions {
    /// Scaling factor applied to the attention scores; `None` presumably
    /// selects the backend's default scale — TODO confirm.
    pub scale: Option<f64>,

    /// Soft-capping value for the attention scores, if any — TODO confirm
    /// exact semantics against the backend implementations.
    pub softcap: Option<f64>,

    /// Whether to apply a causal (lower-triangular) attention mask.
    pub is_causal: bool,
}
472
/// Module operations a [`Backend`] must provide.
///
/// Many methods have default implementations built from lower-dimensional or
/// more primitive ops (e.g. 1D pooling from 2D pooling); backends can
/// override them with specialized kernels.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup.
    ///
    /// Selects rows of `weights` (`[n_embeddings, d_model]`) using the 2-D
    /// integer `indices` (`[batch_size, seq_length]`), producing a
    /// `[batch_size, seq_length, d_model]` tensor.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten the batch so a single `select` along dim 0 gathers rows.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Gradient of [`embedding`](ModuleOps::embedding) with respect to
    /// `weights`: scatter-adds `output_grad` rows into a zero tensor shaped
    /// like `weights` (rows selected multiple times accumulate).
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        // The zero gradient buffer matches the dtype of the incoming gradient.
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }

    /// Applies a linear (fully connected) transformation with optional bias.
    fn linear(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
    ) -> FloatTensor<B> {
        linear::linear::<B>(x, weight, bias)
    }
    /// Gradient of [`linear`](ModuleOps::linear) with respect to the input.
    fn linear_x_backward(weight: FloatTensor<B>, output_grad: FloatTensor<B>) -> FloatTensor<B> {
        linear::linear_x_backward::<B>(weight, output_grad)
    }
    /// Gradient of [`linear`](ModuleOps::linear) with respect to the weight.
    fn linear_weight_backward(x: FloatTensor<B>, output_grad: FloatTensor<B>) -> FloatTensor<B> {
        linear::linear_weight_backward::<B>(x, output_grad)
    }
    /// Gradient of [`linear`](ModuleOps::linear) with respect to the bias.
    fn linear_bias_backward(output_grad: FloatTensor<B>) -> FloatTensor<B> {
        linear::linear_bias_backward::<B>(output_grad)
    }

    /// 1D convolution; the default lowers to [`conv2d`](ModuleOps::conv2d).
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`conv1d`](ModuleOps::conv1d) with respect to the input.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv1d`](ModuleOps::conv1d) with respect to the weight.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv1d`](ModuleOps::conv1d) with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 2D convolution. Required: backends must supply a kernel.
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`conv2d`](ModuleOps::conv2d) with respect to the input.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv2d`](ModuleOps::conv2d) with respect to the weight.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv2d`](ModuleOps::conv2d) with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Deformable 2D convolution. Required: backends must supply a kernel.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`deform_conv2d`](ModuleOps::deform_conv2d),
    /// returning gradients for all inputs. Required.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// 3D convolution. Required: backends must supply a kernel.
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`conv3d`](ModuleOps::conv3d) with respect to the input.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv3d`](ModuleOps::conv3d) with respect to the weight.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv3d`](ModuleOps::conv3d) with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 1D transposed convolution; the default lowers to
    /// [`conv_transpose2d`](ModuleOps::conv_transpose2d).
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`conv_transpose1d`](ModuleOps::conv_transpose1d) with
    /// respect to the input.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`conv_transpose1d`](ModuleOps::conv_transpose1d) with
    /// respect to the weight.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv_transpose1d`](ModuleOps::conv_transpose1d) with
    /// respect to the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D transposed convolution. Required: backends must supply a kernel.
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`conv_transpose2d`](ModuleOps::conv_transpose2d) with
    /// respect to the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`conv_transpose2d`](ModuleOps::conv_transpose2d) with
    /// respect to the weight.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv_transpose2d`](ModuleOps::conv_transpose2d) with
    /// respect to the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 3D transposed convolution. Required: backends must supply a kernel.
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`conv_transpose3d`](ModuleOps::conv_transpose3d) with
    /// respect to the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`conv_transpose3d`](ModuleOps::conv_transpose3d) with
    /// respect to the weight.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`conv_transpose3d`](ModuleOps::conv_transpose3d) with
    /// respect to the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Unfold (im2col) on a 4D input: extracts sliding local blocks of size
    /// `kernel_size` and flattens them into `[batch, channels * kh * kw, blocks]`.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        // Fast path: with no padding and unit dilation, two windowing passes
        // plus a permute/reshape avoid the conv2d-based fallback.
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Bring the kernel window dims next to the channel dim before
            // flattening.
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = blocks.shape();

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// 1D average pooling; the default lowers to 2D pooling.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass of [`avg_pool1d`](ModuleOps::avg_pool1d).
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// 2D average pooling. Required: backends must supply a kernel.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`avg_pool2d`](ModuleOps::avg_pool2d). Required.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// 2D adaptive average pooling to `output_size`. Required.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`adaptive_avg_pool2d`](ModuleOps::adaptive_avg_pool2d). Required.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// 1D adaptive average pooling; the default lowers to 2D.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`adaptive_avg_pool1d`](ModuleOps::adaptive_avg_pool1d).
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// 1D max pooling; the default lowers to 2D pooling.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// 1D max pooling returning both values and argmax indices.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass of [`max_pool1d_with_indices`](ModuleOps::max_pool1d_with_indices),
    /// routing gradients through the recorded argmax `indices`.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// 2D max pooling. Required: backends must supply a kernel.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// 2D max pooling returning both values and argmax indices. Required.
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`max_pool2d_with_indices`](ModuleOps::max_pool2d_with_indices). Required.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Resizes a 4D input to `output_size` using the configured interpolation
    /// mode. Required.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of [`interpolate`](ModuleOps::interpolate). Required.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Scaled dot-product attention with optional boolean mask and additive
    /// bias. Required: backends must supply a kernel.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
        attn_bias: Option<FloatTensor<B>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<B>;

    /// Layer normalization over the last dimension.
    ///
    /// Normalizes with mean and (biased) variance computed along the last
    /// dim, then scales by `gamma` and optionally shifts by `beta` (both
    /// broadcast over the leading dims).
    fn layer_norm(
        tensor: FloatTensor<B>,
        gamma: FloatTensor<B>,
        beta: Option<FloatTensor<B>>,
        epsilon: f64,
    ) -> FloatTensor<B> {
        let shape = tensor.shape();
        let rank = shape.num_dims();
        let last_dim = rank - 1;
        let d_model = shape[last_dim];

        let mean = B::float_mean_dim(tensor.clone(), last_dim);
        let centered = B::float_sub(tensor, mean);
        // Biased variance: mean of squared deviations (no Bessel correction).
        let var = B::float_mean_dim(B::float_mul(centered.clone(), centered.clone()), last_dim);
        let denom = B::float_sqrt(B::float_add_scalar(var, epsilon.into()));
        let normalized = B::float_div(centered, denom);

        // Reshape gamma/beta to [1, ..., 1, d_model] so they broadcast.
        let broadcast_dims: alloc::vec::Vec<usize> = (0..rank)
            .map(|i| if i == last_dim { d_model } else { 1 })
            .collect();
        let gamma_b = B::float_reshape(gamma, Shape::from(broadcast_dims.clone()));
        let scaled = B::float_mul(normalized, gamma_b);

        match beta {
            Some(beta) => {
                let beta_b = B::float_reshape(beta, Shape::from(broadcast_dims));
                B::float_add(scaled, beta_b)
            }
            None => scaled,
        }
    }

    /// Connectionist Temporal Classification loss, with `blank` as the blank
    /// label index.
    fn ctc_loss(
        log_probs: FloatTensor<B>,
        targets: IntTensor<B>,
        input_lengths: IntTensor<B>,
        target_lengths: IntTensor<B>,
        blank: usize,
    ) -> FloatTensor<B> {
        ctc::ctc_loss_default::<B>(log_probs, targets, input_lengths, target_lengths, blank)
    }

    /// Whether this backend implements
    /// [`ctc_loss_backward`](ModuleOps::ctc_loss_backward). Backends that
    /// override the backward pass must also override this to return `true`.
    fn has_ctc_loss_backward() -> bool {
        false
    }

    /// Backward pass of [`ctc_loss`](ModuleOps::ctc_loss).
    ///
    /// Only callable when [`has_ctc_loss_backward`](ModuleOps::has_ctc_loss_backward)
    /// returns `true`; the default is deliberately unreachable.
    fn ctc_loss_backward(
        _log_probs: FloatTensor<B>,
        _targets: IntTensor<B>,
        _input_lengths: IntTensor<B>,
        _target_lengths: IntTensor<B>,
        _grad_loss: FloatTensor<B>,
        _blank: usize,
    ) -> FloatTensor<B> {
        unreachable!(
            "ctc_loss_backward called on a backend whose has_ctc_loss_backward() returns false"
        )
    }

    /// Real-to-complex FFT along `dim`, returning `(real, imaginary)` parts.
    /// `n` presumably sets the transform length — TODO confirm against
    /// backend implementations. Required.
    fn rfft(
        signal: FloatTensor<B>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<B>, FloatTensor<B>);

    /// Inverse of [`rfft`](ModuleOps::rfft): reconstructs a real signal from
    /// the split real/imaginary spectrum along `dim`. Required.
    fn irfft(
        spectrum_re: FloatTensor<B>,
        spectrum_im: FloatTensor<B>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<B>;
}
1239
#[cfg(test)]
mod tests {
    use super::*;

    // Each options constructor funnels through `check_nonzero`; these tests
    // pin both the rejection of zero values and the exact panic messages.

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}