use super::{conv, pool};
use crate::ops::attention;
use crate::ops::unfold::unfold4d_using_conv2d;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
use crate::{Backend, ElementConversion, TensorMetadata};
use burn_std::Shape;
use core::num::NonZeroUsize;

/// Gradients computed during the backward pass for conv2d.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if present.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradients computed during the backward pass for deform_conv2d.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the mask, if present.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, if present.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradients computed during the backward pass for conv3d.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if present.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for max_pool1d.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}

/// Results from max_pool1d when the indices of the maximum values are requested.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices of the maximum values.
    pub indices: IntTensor<B>,
}

/// Gradient computed during the backward pass for max_pool2d.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}

/// Results from max_pool2d when the indices of the maximum values are requested.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices of the maximum values.
    pub indices: IntTensor<B>,
}

/// Checks that the given value is non-zero, panicking with `msg` otherwise.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg);
    value
}

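/// Convolution options.
///
/// A minimal construction sketch (hedged: the `burn_tensor::ops` import path
/// is an assumption, so the example is not compiled as a doctest):
///
/// ```ignore
/// use burn_tensor::ops::ConvOptions;
///
/// // 2D convolution: stride 1, padding 1, dilation 1, a single group.
/// let options = ConvOptions::<2>::new([1, 1], [1, 1], [1, 1], 1);
/// assert_eq!(options.groups, 1);
/// ```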
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvOptions<N> {
    /// Constructs a new `ConvOptions`, validating that stride, dilation, and
    /// groups are non-zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

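/// Deformable convolution options.
///
/// A minimal construction sketch (hedged: the import path is an assumption):
///
/// ```ignore
/// use burn_tensor::ops::DeformConvOptions;
///
/// // Stride 1, padding 1, dilation 1, one weight group, one offset group.
/// let options = DeformConvOptions::<2>::new([1, 1], [1, 1], [1, 1], 1, 1);
/// assert_eq!(options.offset_groups, 1);
/// ```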
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Weight groups (non-zero).
    pub weight_groups: usize,

    /// Offset groups (non-zero).
    pub offset_groups: usize,
}

impl<const N: usize> DeformConvOptions<N> {
    /// Constructs a new `DeformConvOptions`, validating that stride, dilation,
    /// and both group counts are non-zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        weight_groups: usize,
        offset_groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
        }
    }
}

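/// Transposed convolution options.
///
/// A minimal construction sketch (hedged: the import path is an assumption):
///
/// ```ignore
/// use burn_tensor::ops::ConvTransposeOptions;
///
/// // Stride-2 upsampling with no padding and no output padding.
/// let options = ConvTransposeOptions::<2>::new([2, 2], [0, 0], [0, 0], [1, 1], 1);
/// assert_eq!(options.stride, [2, 2]);
/// ```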
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Output padding.
    pub padding_out: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvTransposeOptions<N> {
    /// Constructs a new `ConvTransposeOptions`, validating that stride,
    /// dilation, and groups are non-zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        padding_out: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            padding_out,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

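/// Unfold operation options.
///
/// A minimal construction sketch (hedged: the import path is an assumption):
///
/// ```ignore
/// use burn_tensor::ops::UnfoldOptions;
///
/// // Extract patches with stride 1, no padding, and no dilation.
/// let options = UnfoldOptions::new([1, 1], [0, 0], [1, 1]);
/// ```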
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride (non-zero).
    pub stride: [usize; 2],

    /// Padding.
    pub padding: [usize; 2],

    /// Dilation (non-zero).
    pub dilation: [usize; 2],
}

impl UnfoldOptions {
    /// Constructs a new `UnfoldOptions`, validating that stride and dilation
    /// are non-zero.
    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
        }
    }
}

/// Algorithm used for interpolation.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}

/// Interpolation options.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// The interpolation algorithm.
    pub mode: InterpolateMode,
}

/// How out-of-bounds grid locations are handled during grid sampling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds locations sample zeros.
    #[default]
    Zeros,
    /// Out-of-bounds locations clamp to the border values.
    Border,
    /// Out-of-bounds locations reflect back into the tensor.
    Reflection,
}

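/// Grid sampling options.
///
/// A minimal builder-style sketch (hedged: the import path is an assumption):
///
/// ```ignore
/// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
///
/// // Bilinear sampling, clamping out-of-bounds locations to the border.
/// let options = GridSampleOptions::new(InterpolateMode::Bilinear)
///     .with_padding_mode(GridSamplePaddingMode::Border)
///     .with_align_corners(true);
/// ```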
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// The interpolation algorithm.
    pub mode: InterpolateMode,
    /// How out-of-bounds grid locations are handled.
    pub padding_mode: GridSamplePaddingMode,
    /// Whether the grid extrema (-1 and 1) refer to the centers of the corner
    /// pixels rather than their edges.
    pub align_corners: bool,
}

impl Default for GridSampleOptions {
    fn default() -> Self {
        Self {
            mode: InterpolateMode::Bilinear,
            padding_mode: GridSamplePaddingMode::Zeros,
            align_corners: false,
        }
    }
}

impl From<InterpolateMode> for GridSampleOptions {
    fn from(value: InterpolateMode) -> Self {
        GridSampleOptions::new(value)
    }
}

impl GridSampleOptions {
    /// Creates new grid sampling options with the given interpolation mode and
    /// the default padding mode and corner alignment.
    pub fn new(mode: InterpolateMode) -> Self {
        Self {
            mode,
            ..Default::default()
        }
    }

    /// Sets the padding mode.
    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
        self.padding_mode = padding_mode;
        self
    }

    /// Sets corner alignment.
    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
        self.align_corners = align_corners;
        self
    }
}

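/// Padding mode for tensor padding operations.
///
/// A minimal usage sketch (hedged: the import path is an assumption):
///
/// ```ignore
/// use burn_tensor::ops::PadMode;
///
/// // Constant padding with zeros is the default.
/// assert_eq!(PadMode::default(), PadMode::Constant(0.0));
///
/// // Any element-convertible value becomes constant padding.
/// let mode: PadMode = 1.5_f32.into();
/// assert_eq!(mode, PadMode::Constant(1.5));
/// ```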
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pads with a constant value.
    Constant(f32),

    /// Pads with the reflection of the tensor, excluding the border values.
    Reflect,

    /// Pads by repeating the border values.
    Edge,
}

impl Default for PadMode {
    fn default() -> Self {
        PadMode::Constant(0.0)
    }
}

impl<E: ElementConversion> From<E> for PadMode {
    fn from(value: E) -> Self {
        PadMode::Constant(value.elem())
    }
}

/// Gradient computed during the backward pass for interpolate.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}

/// Module operations that a backend must provide.
pub trait ModuleOps<B: Backend> {
    /// Embedding operation: selects rows of `weights` by `indices`.
    ///
    /// # Shapes
    ///
    /// - weights: `[n_embeddings, d_model]`
    /// - indices: `[batch_size, seq_length]`
    /// - output: `[batch_size, seq_length, d_model]`
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Embedding backward operation: scatters `output_grad` back onto the rows
    /// of the weight gradient selected by `indices`.
    ///
    /// # Shapes
    ///
    /// - weights: `[n_embeddings, d_model]`
    /// - output_grad: `[batch_size, seq_length, d_model]`
    /// - indices: `[batch_size, seq_length]`
    /// - output: `[n_embeddings, d_model]`
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }

    /// One-dimensional convolution.
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }

    /// Backward pass for conv1d with respect to the input.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv1d with respect to the weights.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv1d with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two-dimensional convolution.
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;

    /// Backward pass for conv2d with respect to the input.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv2d with respect to the weights.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv2d with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Two-dimensional deformable convolution.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;

    /// Backward pass for deform_conv2d.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// Three-dimensional convolution.
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;

    /// Backward pass for conv3d with respect to the input.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv3d with respect to the weights.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv3d with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// One-dimensional transposed convolution.
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }

    /// Backward pass for conv_transpose1d with respect to the input.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }

    /// Backward pass for conv_transpose1d with respect to the weights.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv_transpose1d with respect to the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two-dimensional transposed convolution.
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;

    /// Backward pass for conv_transpose2d with respect to the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }

    /// Backward pass for conv_transpose2d with respect to the weights.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv_transpose2d with respect to the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Three-dimensional transposed convolution.
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;

    /// Backward pass for conv_transpose3d with respect to the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }

    /// Backward pass for conv_transpose3d with respect to the weights.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }

    /// Backward pass for conv_transpose3d with respect to the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

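    /// Four-dimensional unfolding: extracts sliding local blocks from a
    /// batched input tensor, as in `im2col`.
    ///
    /// # Shapes
    ///
    /// - x: `[batch_size, channels_in, height, width]`
    /// - output: `[batch_size, channels_in * kernel_h * kernel_w, n_blocks]`
    ///
    /// A minimal usage sketch (hedged: the backend type and call site are
    /// assumptions):
    ///
    /// ```ignore
    /// // Extract 2x2 patches with stride 1 from a [1, 3, 4, 4] input;
    /// // the output shape is [1, 3 * 2 * 2, 9].
    /// let unfolded = B::unfold4d(x, [2, 2], UnfoldOptions::new([1, 1], [0, 0], [1, 1]));
    /// ```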
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            // Fast path: without padding or dilation, the blocks can be
            // extracted directly with two window unfolds.
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // [batch, channels, blocks_h, blocks_w, k_h, k_w]
            // -> [batch, channels, k_h, k_w, blocks_h, blocks_w]
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            // Flatten to [batch, channels * k_h * k_w, n_blocks].
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// One-dimensional average pooling.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }

    /// Backward pass for avg_pool1d.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }

    /// Two-dimensional average pooling.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// Backward pass for avg_pool2d.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// Two-dimensional adaptive average pooling.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;

    /// Backward pass for adaptive_avg_pool2d.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;

    /// One-dimensional adaptive average pooling.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }

    /// Backward pass for adaptive_avg_pool1d.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }

    /// One-dimensional max pooling.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// One-dimensional max pooling that also returns the indices of the
    /// maximum values.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }

    /// Backward pass for max_pool1d_with_indices.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// Two-dimensional max pooling.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// Two-dimensional max pooling that also returns the indices of the
    /// maximum values.
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;

    /// Backward pass for max_pool2d_with_indices.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Down/up samples the input to the given output size.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass for interpolate.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

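    /// Scaled dot-product attention: `softmax(Q K^T / sqrt(d_k)) V`, with an
    /// optional boolean mask applied to the attention scores before the
    /// softmax. The default implementation delegates to a naive version that
    /// materializes the full score matrix.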
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
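
    // Additional sketches (hedged): these exercise the builder-style
    // GridSampleOptions API and the PadMode conversions defined above.
    #[test]
    fn grid_sample_options_builder() {
        let opt = GridSampleOptions::new(InterpolateMode::Nearest)
            .with_padding_mode(GridSamplePaddingMode::Border)
            .with_align_corners(true);
        assert_eq!(opt.padding_mode, GridSamplePaddingMode::Border);
        assert!(opt.align_corners);
    }

    #[test]
    fn pad_mode_from_element() {
        // Any element-convertible value becomes constant padding.
        let mode: PadMode = 1.5_f32.into();
        assert_eq!(mode, PadMode::Constant(1.5));
        assert_eq!(PadMode::default(), PadMode::Constant(0.0));
    }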
}