1use super::{conv, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
4use crate::{Backend, ElementConversion, TensorMetadata};
5use burn_std::Shape;
6use core::num::NonZeroUsize;
7
/// Gradients produced by the backward pass of a 2D convolution.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the convolution input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
20
/// Gradients produced by the backward pass of a deformable 2D convolution.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask, when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
39
/// Gradients produced by the backward pass of a 3D convolution.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the convolution input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
52
/// Gradient produced by the backward pass of a 1D max pooling.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the pooling input.
    pub x_grad: FloatTensor<B>,
}
59
/// Result of a 1D max pooling that also records the argmax indices.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// Pooled output values.
    pub output: FloatTensor<B>,

    /// Index of the selected maximum for each output element; consumed by
    /// the indices-based backward pass.
    pub indices: IntTensor<B>,
}
69
/// Gradient produced by the backward pass of a 2D max pooling.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the pooling input.
    pub x_grad: FloatTensor<B>,
}
76
/// Result of a 2D max pooling that also records the argmax indices.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// Pooled output values.
    pub output: FloatTensor<B>,

    /// Index of the selected maximum for each output element; consumed by
    /// the indices-based backward pass.
    pub indices: IntTensor<B>,
}
86
/// Validates that `value` is non-zero and returns it unchanged.
///
/// Used by the option constructors in this file to reject zero strides,
/// dilations, and group counts at construction time.
///
/// # Panics
///
/// Panics with `msg` when `value == 0`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // `NonZeroUsize` performs the zero check; returning `.get()` uses the
    // validated value directly instead of discarding the checked wrapper and
    // returning `value` separately, as the previous version did.
    NonZeroUsize::new(value).expect(msg).get()
}
93
/// Options for an `N`-dimensional convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`ConvOptions::new`]).
    pub stride: [usize; N],

    /// Zero-padding per spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`ConvOptions::new`]).
    pub dilation: [usize; N],

    /// Number of channel groups (validated non-zero by [`ConvOptions::new`]).
    pub groups: usize,
}
109
110impl<const N: usize> ConvOptions<N> {
111 pub fn new(
113 stride: [usize; N],
114 padding: [usize; N],
115 dilation: [usize; N],
116 groups: usize,
117 ) -> Self {
118 Self {
119 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
120 padding,
121 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
122 groups: check_nonzero(groups, "groups must be non-zero"),
123 }
124 }
125}
126
/// Convolution options extended with optional asymmetric (start/end) padding.
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// Base convolution options; `options.padding` holds the start padding.
    pub options: ConvOptions<N>,
    /// End padding per spatial dimension; `None` when padding is symmetric
    /// (i.e. equal to the start padding).
    pub padding_end: Option<[usize; N]>,
}
144
145impl<const N: usize> PaddedConvOptions<N> {
146 pub fn asymmetric(
151 stride: [usize; N],
152 padding_start: [usize; N],
153 padding_end: [usize; N],
154 dilation: [usize; N],
155 groups: usize,
156 ) -> Self {
157 let options = ConvOptions::new(stride, padding_start, dilation, groups);
158 if padding_start == padding_end {
159 Self {
160 options,
161 padding_end: None,
162 }
163 } else {
164 Self {
165 options,
166 padding_end: Some(padding_end),
167 }
168 }
169 }
170
171 pub fn is_asymmetric(&self) -> bool {
173 self.padding_end.is_some()
174 }
175}
176
177impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
178 fn from(options: ConvOptions<N>) -> Self {
179 Self {
180 options,
181 padding_end: None,
182 }
183 }
184}
185
/// Options for an `N`-dimensional deformable convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`DeformConvOptions::new`]).
    pub stride: [usize; N],

    /// Zero-padding per spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`DeformConvOptions::new`]).
    pub dilation: [usize; N],

    /// Number of weight groups (validated non-zero by [`DeformConvOptions::new`]).
    pub weight_groups: usize,

    /// Number of offset groups (validated non-zero by [`DeformConvOptions::new`]).
    pub offset_groups: usize,
}
204
205impl<const N: usize> DeformConvOptions<N> {
206 pub fn new(
208 stride: [usize; N],
209 padding: [usize; N],
210 dilation: [usize; N],
211 weight_groups: usize,
212 offset_groups: usize,
213 ) -> Self {
214 Self {
215 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
216 padding,
217 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
218 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
219 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
220 }
221 }
222}
223
/// Options for an `N`-dimensional transposed convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`ConvTransposeOptions::new`]).
    pub stride: [usize; N],

    /// Zero-padding per spatial dimension.
    pub padding: [usize; N],

    /// Output padding: extra size added to the output per spatial dimension.
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`ConvTransposeOptions::new`]).
    pub dilation: [usize; N],

    /// Number of channel groups (validated non-zero by [`ConvTransposeOptions::new`]).
    pub groups: usize,
}
242
243impl<const N: usize> ConvTransposeOptions<N> {
244 pub fn new(
246 stride: [usize; N],
247 padding: [usize; N],
248 padding_out: [usize; N],
249 dilation: [usize; N],
250 groups: usize,
251 ) -> Self {
252 Self {
253 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
254 padding,
255 padding_out,
256 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
257 groups: check_nonzero(groups, "groups must be non-zero"),
258 }
259 }
260}
261
/// Options for the 4D unfold (im2col) operation.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride per spatial dimension (validated non-zero by [`UnfoldOptions::new`]).
    pub stride: [usize; 2],

    /// Zero-padding per spatial dimension.
    pub padding: [usize; 2],

    /// Dilation per spatial dimension (validated non-zero by [`UnfoldOptions::new`]).
    pub dilation: [usize; 2],
}
275
276impl UnfoldOptions {
277 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
279 Self {
280 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
281 padding,
282 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
283 }
284 }
285}
286
/// Interpolation algorithms available for resizing and sampling operations.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,

    /// Lanczos interpolation (3-lobed kernel, per the variant name).
    Lanczos3,
}
306
/// Options for the interpolation module operations.
#[derive(Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm to use.
    pub mode: InterpolateMode,
    /// Whether input and output corner pixels are treated as aligned;
    /// defaults to `true` in [`InterpolateOptions::new`].
    pub align_corners: bool,
}
316
317impl InterpolateOptions {
318 pub fn new(mode: InterpolateMode) -> Self {
321 Self {
322 mode,
323 align_corners: true,
324 }
325 }
326
327 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
329 self.align_corners = align_corners;
330 self
331 }
332}
333
/// How out-of-bounds grid locations are handled during grid sampling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds samples read as zero (the default).
    #[default]
    Zeros,
    /// Out-of-bounds samples clamp to the border values.
    Border,
    /// Out-of-bounds samples reflect at the boundary.
    Reflection,
}
347
/// Options for grid sampling.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation algorithm (default: [`InterpolateMode::Bilinear`]).
    pub mode: InterpolateMode,
    /// Out-of-bounds handling (default: [`GridSamplePaddingMode::Zeros`]).
    pub padding_mode: GridSamplePaddingMode,
    /// Whether corner pixels are treated as aligned (default: `false`).
    pub align_corners: bool,
}
360
361impl Default for GridSampleOptions {
362 fn default() -> Self {
363 Self {
364 mode: InterpolateMode::Bilinear,
365 padding_mode: GridSamplePaddingMode::Zeros,
366 align_corners: false,
367 }
368 }
369}
370
371impl From<InterpolateMode> for GridSampleOptions {
372 fn from(value: InterpolateMode) -> Self {
373 GridSampleOptions::new(value)
374 }
375}
376
377impl GridSampleOptions {
378 pub fn new(mode: InterpolateMode) -> Self {
382 Self {
383 mode,
384 ..Default::default()
385 }
386 }
387
388 pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
390 self.padding_mode = padding_mode;
391 self
392 }
393
394 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
396 self.align_corners = align_corners;
397 self
398 }
399}
400
/// Padding strategies for tensor padding operations.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),

    /// Pad by mirroring values at the border.
    Reflect,

    /// Pad by replicating the edge values.
    Edge,
}
436
437impl Default for PadMode {
438 fn default() -> Self {
439 PadMode::Constant(0.0)
440 }
441}
442
443impl<E: ElementConversion> From<E> for PadMode {
444 fn from(value: E) -> Self {
445 PadMode::Constant(value.elem())
446 }
447}
448
/// Gradient produced by the backward pass of an interpolation.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the interpolation input.
    pub x_grad: FloatTensor<B>,
}
455
/// Options for the attention module operation.
#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct AttentionModuleOptions {
    /// Optional scaling applied to the attention scores; `None` leaves the
    /// choice to the backend (commonly `1/sqrt(head_dim)` — confirm per backend).
    pub scale: Option<f64>,

    /// Optional soft-capping value for the attention scores.
    // NOTE(review): the exact soft-cap formula is backend-defined — confirm
    // against the backend implementations.
    pub softcap: Option<f64>,

    /// Whether a causal mask is applied to the attention scores.
    pub is_causal: bool,
}
472
/// Module operations (embeddings, convolutions, pooling, interpolation,
/// attention, FFT) that a [`Backend`] provides.
///
/// Methods without a body are backend primitives; methods with a default
/// body are derived from other primitives (e.g. the 1D variants are lowered
/// to their 2D counterparts through the `conv`/`pool` helper modules).
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup.
    ///
    /// Gathers rows of `weights` (`[n_embeddings, d_model]`) indexed by
    /// `indices` (`[batch_size, seq_length]`), returning a tensor of shape
    /// `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten the indices so one `select` along dim 0 gathers every row.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Gradient of [`ModuleOps::embedding`] with respect to `weights`.
    ///
    /// Scatter-adds `output_grad` rows into a zero tensor shaped like
    /// `weights`; accumulation (`select_add`, not overwrite) is required
    /// because the same index may occur several times in a batch.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// 1D convolution; the default implementation lowers to [`ModuleOps::conv2d`].
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the input `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to `weight`.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 2D convolution (backend primitive).
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv2d`] with respect to the input `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to `weight`.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Deformable 2D convolution (backend primitive) with learned sampling
    /// `offset` and optional modulation `mask`.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::deform_conv2d`], returning gradients
    /// for every differentiable input (see [`DeformConv2dBackward`]).
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// 3D convolution (backend primitive).
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv3d`] with respect to the input `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to `weight`.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 1D transposed convolution; the default lowers to
    /// [`ModuleOps::conv_transpose2d`].
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the input
    /// (the input itself is not needed for this gradient).
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to `weight`.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D transposed convolution (backend primitive).
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to `weight`.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 3D transposed convolution (backend primitive).
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to `weight`.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Extracts sliding local blocks (im2col) from a 4D input.
    ///
    /// The fast path below produces `[batch, channels * k1 * k2, n_blocks]`
    /// (see the final reshape).
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        // Fast path: without padding and with unit dilation, two primitive
        // `unfold` calls plus a permute/reshape are sufficient.
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Bring the kernel window dims next to the channel dim so the
            // reshape below can merge (channels, k1, k2) into one axis and
            // the block positions into another.
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = blocks.shape();

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            // General path: implemented on top of conv2d.
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// 1D average pooling; the default lowers to 2D average pooling.
    ///
    /// `count_include_pad` controls whether padded zeros count in the divisor;
    /// `ceil_mode` rounds the output size up instead of down.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::avg_pool1d`]; the default lowers to 2D.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// 2D average pooling (backend primitive).
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::avg_pool2d`] (backend primitive).
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Adaptive 2D average pooling to a fixed `output_size` (backend primitive).
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::adaptive_avg_pool2d`] (backend primitive).
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// Adaptive 1D average pooling; the default lowers to 2D.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`ModuleOps::adaptive_avg_pool1d`]; the default lowers to 2D.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// 1D max pooling; the default lowers to 2D max pooling.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// 1D max pooling that also returns the argmax indices used by the
    /// backward pass; the default lowers to 2D.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::max_pool1d_with_indices`], routing each
    /// output gradient to the input position recorded in `indices`.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// 2D max pooling (backend primitive).
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// 2D max pooling returning the argmax indices (backend primitive).
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indices`] (backend primitive).
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Resizes `x` to `output_size` using the algorithm selected in
    /// `options` (backend primitive).
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of [`ModuleOps::interpolate`] (backend primitive).
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Attention over `query`/`key`/`value` tensors (backend primitive),
    /// with an optional boolean `mask`, an optional additive `attn_bias`,
    /// and the scaling/capping/causality settings in `options`.
    // NOTE(review): the exact score formula (scaling default, softcap, bias
    // application) is defined by each backend implementation — confirm there.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
        attn_bias: Option<FloatTensor<B>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<B>;

    /// Real-input FFT of `signal` along `dim`, returning the spectrum as a
    /// `(real, imaginary)` pair of tensors (names follow [`ModuleOps::irfft`]'s
    /// `spectrum_re`/`spectrum_im` parameters).
    fn rfft(signal: FloatTensor<B>, dim: usize) -> (FloatTensor<B>, FloatTensor<B>);

    /// Inverse of [`ModuleOps::rfft`]: reconstructs a real signal along `dim`
    /// from the real/imaginary spectrum parts.
    fn irfft(
        spectrum_re: FloatTensor<B>,
        spectrum_im: FloatTensor<B>,
        dim: usize,
    ) -> FloatTensor<B>;
}
1080
#[cfg(test)]
mod tests {
    //! Verifies that the option constructors panic (via `check_nonzero`) when
    //! any stride, dilation, or group count is zero. Zero padding is valid
    //! and is deliberately never rejected.
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}