1use super::{conv, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
4use crate::{Backend, ElementConversion, TensorMetadata};
5use burn_std::Shape;
6use core::num::NonZeroUsize;
7
/// Gradient tensors produced by the backward pass of a 2D convolution.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
20
/// Gradient tensors produced by the backward pass of a deformable 2D convolution.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask, when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
39
/// Gradient tensors produced by the backward pass of a 3D convolution.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
52
/// Gradient tensor produced by the backward pass of 1D max pooling.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
59
/// Result of 1D max pooling that also tracks the index of each maximum.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Indices of the selected maxima, used by the backward pass.
    pub indices: IntTensor<B>,
}
69
/// Gradient tensor produced by the backward pass of 2D max pooling.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
76
/// Result of 2D max pooling that also tracks the index of each maximum.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Indices of the selected maxima, used by the backward pass.
    pub indices: IntTensor<B>,
}
86
/// Validates that `value` is non-zero and returns it unchanged.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // Return the value extracted from the validated `NonZeroUsize` instead of
    // discarding the parse result and echoing the raw input.
    NonZeroUsize::new(value).expect(msg).get()
}
93
/// Configuration for an `N`-dimensional convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride along each spatial dimension (each entry non-zero).
    pub stride: [usize; N],

    /// Zero-padding applied on each spatial dimension.
    pub padding: [usize; N],

    /// Dilation along each spatial dimension (each entry non-zero).
    pub dilation: [usize; N],

    /// Number of convolution groups (non-zero).
    pub groups: usize,
}
109
110impl<const N: usize> ConvOptions<N> {
111 pub fn new(
113 stride: [usize; N],
114 padding: [usize; N],
115 dilation: [usize; N],
116 groups: usize,
117 ) -> Self {
118 Self {
119 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
120 padding,
121 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
122 groups: check_nonzero(groups, "groups must be non-zero"),
123 }
124 }
125}
126
/// Convolution options that may carry a distinct end padding, enabling
/// asymmetric padding on top of the symmetric [`ConvOptions`].
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// The base convolution options; its `padding` is the start padding.
    pub options: ConvOptions<N>,
    /// End padding per spatial dimension; `None` means padding is symmetric.
    pub padding_end: Option<[usize; N]>,
}
144
145impl<const N: usize> PaddedConvOptions<N> {
146 pub fn asymmetric(
151 stride: [usize; N],
152 padding_start: [usize; N],
153 padding_end: [usize; N],
154 dilation: [usize; N],
155 groups: usize,
156 ) -> Self {
157 let options = ConvOptions::new(stride, padding_start, dilation, groups);
158 if padding_start == padding_end {
159 Self {
160 options,
161 padding_end: None,
162 }
163 } else {
164 Self {
165 options,
166 padding_end: Some(padding_end),
167 }
168 }
169 }
170
171 pub fn is_asymmetric(&self) -> bool {
173 self.padding_end.is_some()
174 }
175}
176
177impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
178 fn from(options: ConvOptions<N>) -> Self {
179 Self {
180 options,
181 padding_end: None,
182 }
183 }
184}
185
/// Configuration for an `N`-dimensional deformable convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride along each spatial dimension (each entry non-zero).
    pub stride: [usize; N],

    /// Zero-padding applied on each spatial dimension.
    pub padding: [usize; N],

    /// Dilation along each spatial dimension (each entry non-zero).
    pub dilation: [usize; N],

    /// Number of weight groups (non-zero).
    pub weight_groups: usize,

    /// Number of offset groups (non-zero).
    pub offset_groups: usize,
}
204
205impl<const N: usize> DeformConvOptions<N> {
206 pub fn new(
208 stride: [usize; N],
209 padding: [usize; N],
210 dilation: [usize; N],
211 weight_groups: usize,
212 offset_groups: usize,
213 ) -> Self {
214 Self {
215 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
216 padding,
217 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
218 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
219 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
220 }
221 }
222}
223
/// Configuration for an `N`-dimensional transposed convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride along each spatial dimension (each entry non-zero).
    pub stride: [usize; N],

    /// Padding applied on each spatial dimension.
    pub padding: [usize; N],

    /// Additional output padding per spatial dimension.
    pub padding_out: [usize; N],

    /// Dilation along each spatial dimension (each entry non-zero).
    pub dilation: [usize; N],

    /// Number of convolution groups (non-zero).
    pub groups: usize,
}
242
243impl<const N: usize> ConvTransposeOptions<N> {
244 pub fn new(
246 stride: [usize; N],
247 padding: [usize; N],
248 padding_out: [usize; N],
249 dilation: [usize; N],
250 groups: usize,
251 ) -> Self {
252 Self {
253 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
254 padding,
255 padding_out,
256 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
257 groups: check_nonzero(groups, "groups must be non-zero"),
258 }
259 }
260}
261
/// Configuration for the 4D unfold (im2col) operation.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride along height and width (each entry non-zero).
    pub stride: [usize; 2],

    /// Zero-padding along height and width.
    pub padding: [usize; 2],

    /// Dilation along height and width (each entry non-zero).
    pub dilation: [usize; 2],
}
275
276impl UnfoldOptions {
277 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
279 Self {
280 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
281 padding,
282 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
283 }
284 }
285}
286
/// Interpolation algorithm used for resizing and grid sampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbour interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}
302
/// Configuration for the interpolate (resize) operation.
#[derive(Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm to use.
    pub mode: InterpolateMode,
    /// Whether corner pixels of input and output are aligned.
    pub align_corners: bool,
}
312
313impl InterpolateOptions {
314 pub fn new(mode: InterpolateMode) -> Self {
317 Self {
318 mode,
319 align_corners: true,
320 }
321 }
322
323 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
325 self.align_corners = align_corners;
326 self
327 }
328}
329
/// How grid-sample coordinates that fall outside the input are handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds samples are treated as zero.
    #[default]
    Zeros,
    /// Out-of-bounds coordinates clamp to the border.
    Border,
    /// Out-of-bounds coordinates reflect back into the input.
    Reflection,
}
343
/// Configuration for the grid-sample operation.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation algorithm used when sampling.
    pub mode: InterpolateMode,
    /// Handling of out-of-bounds sample coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// Whether corner pixels of input and output are aligned.
    pub align_corners: bool,
}
356
357impl Default for GridSampleOptions {
358 fn default() -> Self {
359 Self {
360 mode: InterpolateMode::Bilinear,
361 padding_mode: GridSamplePaddingMode::Zeros,
362 align_corners: false,
363 }
364 }
365}
366
367impl From<InterpolateMode> for GridSampleOptions {
368 fn from(value: InterpolateMode) -> Self {
369 GridSampleOptions::new(value)
370 }
371}
372
373impl GridSampleOptions {
374 pub fn new(mode: InterpolateMode) -> Self {
378 Self {
379 mode,
380 ..Default::default()
381 }
382 }
383
384 pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
386 self.padding_mode = padding_mode;
387 self
388 }
389
390 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
392 self.align_corners = align_corners;
393 self
394 }
395}
396
/// Padding strategy for tensor pad operations.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),

    /// Reflection padding.
    Reflect,

    /// Replicate the edge values.
    Edge,
}
432
433impl Default for PadMode {
434 fn default() -> Self {
435 PadMode::Constant(0.0)
436 }
437}
438
439impl<E: ElementConversion> From<E> for PadMode {
440 fn from(value: E) -> Self {
441 PadMode::Constant(value.elem())
442 }
443}
444
/// Gradient tensor produced by the backward pass of interpolation.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
451
/// Configuration for the attention operation.
#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct AttentionModuleOptions {
    // Scaling factor for attention scores; `None` presumably selects the
    // backend's default scaling — confirm against backend implementations.
    pub scale: Option<f64>,

    // Optional soft-cap applied to attention scores, when supported.
    pub softcap: Option<f64>,

    // Whether causal masking is applied.
    pub is_causal: bool,
}
468
/// Module operations for a backend `B`: convolutions, pooling,
/// interpolation, unfolding, embeddings, and attention.
///
/// Required methods (no default body) must be implemented with backend
/// kernels; the remaining methods have default implementations that delegate
/// to the `conv` and `pool` helper modules, typically by reducing to a
/// higher-dimensional or more general operation.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup: gathers rows of `weights` by `indices`.
    ///
    /// `indices` has shape `[batch_size, seq_length]`, `weights` has shape
    /// `[_, d_model]`; the result has shape
    /// `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten indices so a single `select` on dim 0 gathers every row.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Backward pass of [`ModuleOps::embedding`]: scatter-adds the output
    /// gradient back into a zero tensor shaped like `weights`.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        // Accumulate in the gradient's dtype, not the weights' dtype.
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        // Rows selected multiple times receive the sum of their gradients.
        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// 1D convolution, implemented by reducing to a 2D convolution.
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the input `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the weights.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 2D convolution. Required: backends must supply a kernel.
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv2d`] with respect to the input `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to the weights.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Deformable 2D convolution. Required: backends must supply a kernel.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::deform_conv2d`], producing gradients
    /// for every input. Required: backends must supply a kernel.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// 3D convolution. Required: backends must supply a kernel.
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv3d`] with respect to the input `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to the weights.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 1D transposed convolution, reduced to its 2D counterpart.
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] w.r.t. the input.
    /// Note: does not need `x` itself, only the weight and output gradient.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] w.r.t. the weights.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] w.r.t. the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D transposed convolution. Required: backends must supply a kernel.
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose2d`] w.r.t. the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] w.r.t. the weights.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] w.r.t. the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 3D transposed convolution. Required: backends must supply a kernel.
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose3d`] w.r.t. the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] w.r.t. the weights.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] w.r.t. the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Extracts sliding local blocks from a 4D input (im2col).
    ///
    /// Fast path: when there is no padding and unit dilation, uses two
    /// `float_unfold` calls plus a permute/reshape. Otherwise falls back to
    /// a conv2d-based implementation.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            // Unfold height (dim 2) then width (dim 3).
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Move the kernel window dims next to channels before flattening.
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = blocks.shape();

            // Collapse to [batch, channels * kh * kw, n_blocks].
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// 1D average pooling, reduced to 2D average pooling.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::avg_pool1d`].
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// 2D average pooling. Required: backends must supply a kernel.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::avg_pool2d`]. Required.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Adaptive 2D average pooling to `output_size`. Required.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::adaptive_avg_pool2d`]. Required.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// Adaptive 1D average pooling, reduced to its 2D counterpart.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`ModuleOps::adaptive_avg_pool1d`].
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// 1D max pooling, reduced to 2D max pooling.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// 1D max pooling that also returns the indices of the maxima.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::max_pool1d_with_indices`], routing the
    /// output gradient back through the recorded max indices.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// 2D max pooling. Required: backends must supply a kernel.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// 2D max pooling with max indices. Required.
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indices`]. Required.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Resizes a 4D tensor to `output_size` using the configured
    /// interpolation mode. Required.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of [`ModuleOps::interpolate`]. Required.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Scaled dot-product attention over query/key/value tensors, with
    /// optional boolean mask and additive bias, configured by
    /// [`AttentionModuleOptions`]. Required: backends must supply a kernel.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
        attn_bias: Option<FloatTensor<B>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<B>;
}
1056
#[cfg(test)]
mod tests {
    // Every option constructor must reject zero-valued strides, dilations,
    // and group counts; the panic messages come from `check_nonzero`.
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}