1use super::{conv, pool};
2use crate::ops::attention;
3use crate::ops::unfold::unfold4d_using_conv2d;
4use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
5use crate::{Backend, TensorMetadata};
6use burn_std::Shape;
7use core::num::NonZeroUsize;
8
/// Gradients returned by the 2D convolution backward pass.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if the convolution had one.
    pub bias_grad: Option<FloatTensor<B>>,
}
21
/// Gradients returned by the deformable 2D convolution backward pass.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the weight.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the mask, if a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, if the convolution had one.
    pub bias_grad: Option<FloatTensor<B>>,
}
40
/// Gradients returned by the 3D convolution backward pass.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if the convolution had one.
    pub bias_grad: Option<FloatTensor<B>>,
}
53
/// Gradient returned by the 1D max pooling backward pass.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
60
/// Output of 1D max pooling when the argmax indices are also requested.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Index of the selected (maximum) input element for each output position.
    pub indices: IntTensor<B>,
}
70
/// Gradient returned by the 2D max pooling backward pass.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
77
/// Output of 2D max pooling when the argmax indices are also requested.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Index of the selected (maximum) input element for each output position.
    pub indices: IntTensor<B>,
}
87
/// Validates that `value` is non-zero, then returns it unchanged.
///
/// Used by the option constructors below to reject zero strides, dilations,
/// and group counts at construction time.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // Use the checked value instead of discarding it: `expect` performs the
    // validation and `get` converts back to a plain `usize`.
    NonZeroUsize::new(value).expect(msg).get()
}
94
/// Convolution options, validated at construction (see [`ConvOptions::new`]).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (every component non-zero).
    pub stride: [usize; N],

    /// Padding per spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (every component non-zero).
    pub dilation: [usize; N],

    /// Number of groups (non-zero).
    pub groups: usize,
}
110
111impl<const N: usize> ConvOptions<N> {
112 pub fn new(
114 stride: [usize; N],
115 padding: [usize; N],
116 dilation: [usize; N],
117 groups: usize,
118 ) -> Self {
119 Self {
120 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
121 padding,
122 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
123 groups: check_nonzero(groups, "groups must be non-zero"),
124 }
125 }
126}
127
/// Deformable convolution options, validated at construction
/// (see [`DeformConvOptions::new`]).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (every component non-zero).
    pub stride: [usize; N],

    /// Padding per spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (every component non-zero).
    pub dilation: [usize; N],

    /// Number of weight groups (non-zero).
    pub weight_groups: usize,

    /// Number of offset groups (non-zero).
    pub offset_groups: usize,
}
146
147impl<const N: usize> DeformConvOptions<N> {
148 pub fn new(
150 stride: [usize; N],
151 padding: [usize; N],
152 dilation: [usize; N],
153 weight_groups: usize,
154 offset_groups: usize,
155 ) -> Self {
156 Self {
157 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
158 padding,
159 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
160 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
161 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
162 }
163 }
164}
165
/// Transposed convolution options, validated at construction
/// (see [`ConvTransposeOptions::new`]).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (every component non-zero).
    pub stride: [usize; N],

    /// Padding per spatial dimension.
    pub padding: [usize; N],

    /// Output padding per spatial dimension.
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (every component non-zero).
    pub dilation: [usize; N],

    /// Number of groups (non-zero).
    pub groups: usize,
}
184
185impl<const N: usize> ConvTransposeOptions<N> {
186 pub fn new(
188 stride: [usize; N],
189 padding: [usize; N],
190 padding_out: [usize; N],
191 dilation: [usize; N],
192 groups: usize,
193 ) -> Self {
194 Self {
195 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
196 padding,
197 padding_out,
198 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
199 groups: check_nonzero(groups, "groups must be non-zero"),
200 }
201 }
202}
203
/// Unfold (sliding-window extraction) options, validated at construction
/// (see [`UnfoldOptions::new`]).
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride per spatial dimension (every component non-zero).
    pub stride: [usize; 2],

    /// Padding per spatial dimension.
    pub padding: [usize; 2],

    /// Dilation per spatial dimension (every component non-zero).
    pub dilation: [usize; 2],
}
217
218impl UnfoldOptions {
219 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
221 Self {
222 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
223 padding,
224 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
225 }
226 }
227}
228
/// Interpolation algorithm used when resizing.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}
244
/// Options for the interpolate operation.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm to use.
    pub mode: InterpolateMode,
}
251
/// How grid sampling treats grid coordinates that fall outside the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds locations sample as zero (the default).
    #[default]
    Zeros,
    /// Out-of-bounds locations clamp to the border values.
    Border,
    /// Out-of-bounds locations reflect back into the input.
    Reflection,
}
265
/// Options for the grid sample operation.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation algorithm used between grid points.
    pub mode: InterpolateMode,
    /// How out-of-bounds grid coordinates are handled.
    pub padding_mode: GridSamplePaddingMode,
    /// The `align_corners` flag forwarded to the sampling kernel.
    pub align_corners: bool,
}
278
279impl Default for GridSampleOptions {
280 fn default() -> Self {
281 Self {
282 mode: InterpolateMode::Bilinear,
283 padding_mode: GridSamplePaddingMode::Zeros,
284 align_corners: false,
285 }
286 }
287}
288
289impl GridSampleOptions {
290 pub fn new(mode: InterpolateMode) -> Self {
294 Self {
295 mode,
296 ..Default::default()
297 }
298 }
299
300 pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
302 self.padding_mode = padding_mode;
303 self
304 }
305
306 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
308 self.align_corners = align_corners;
309 self
310 }
311}
312
/// Padding mode for tensor padding operations.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),

    /// Mirror-reflect values across the edges.
    Reflect,

    /// Repeat the edge values.
    Edge,
}
350
351impl Default for PadMode {
352 fn default() -> Self {
353 PadMode::Constant(0.0)
354 }
355}
356
/// Gradient returned by the interpolate backward pass.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
363
/// Module operations a backend must provide.
///
/// The 2D/3D primitives are required methods; most 1D operations have default
/// implementations expressed through their 2D counterparts, which a backend
/// may override with a dedicated kernel.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup.
    ///
    /// Gathers rows of `weights` (`[n_embeddings, d_model]`) at `indices`
    /// (`[batch_size, seq_length]`), returning a tensor of shape
    /// `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten so a single `select` along dim 0 gathers every row at once.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Backward pass of [`ModuleOps::embedding`].
    ///
    /// Scatter-adds `output_grad` into a zero tensor shaped like `weights`, so
    /// rows selected multiple times accumulate their gradients.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        // `select_add` accumulates, so repeated indices sum their gradients.
        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// One dimensional convolution (defaults to the 2D implementation).
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass of [`ModuleOps::conv1d`] w.r.t. the input `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv1d`] w.r.t. the weight.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv1d`] w.r.t. the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution (required backend primitive).
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv2d`] w.r.t. the input `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv2d`] w.r.t. the weight.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv2d`] w.r.t. the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Two dimensional deformable convolution (required backend primitive).
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::deform_conv2d`], returning the gradients
    /// of every differentiable input.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// Three dimensional convolution (required backend primitive).
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv3d`] w.r.t. the input `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv3d`] w.r.t. the weight.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv3d`] w.r.t. the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// One dimensional transposed convolution (defaults to the 2D
    /// implementation).
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose1d`] w.r.t. the input.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose1d`] w.r.t. the weight.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose1d`] w.r.t. the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two dimensional transposed convolution (required backend primitive).
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv_transpose2d`] w.r.t. the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose2d`] w.r.t. the weight.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose2d`] w.r.t. the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Three dimensional transposed convolution (required backend primitive).
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv_transpose3d`] w.r.t. the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose3d`] w.r.t. the weight.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose3d`] w.r.t. the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Unfold operation: extracts sliding local blocks from a 4D tensor.
    ///
    /// With zero padding and unit dilation, a fast path unfolds both spatial
    /// dimensions directly; otherwise the conv2d-based fallback is used.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            // Collapse dims 1..4 and dims 4..6 of the permuted tensor into the
            // (channels x kernel, blocks) layout.
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// One dimensional average pooling (defaults to the 2D implementation).
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::avg_pool1d`].
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Two dimensional average pooling (required backend primitive).
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::avg_pool2d`].
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive average pooling (required backend primitive).
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::adaptive_avg_pool2d`].
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive average pooling (defaults to the 2D
    /// implementation).
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`ModuleOps::adaptive_avg_pool1d`].
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling (defaults to the 2D implementation).
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// One dimensional max pooling that also returns the argmax indices.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::max_pool1d_with_indices`].
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// Two dimensional max pooling (required backend primitive).
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// Two dimensional max pooling that also returns the argmax indices
    /// (required backend primitive).
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indices`].
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Resizes the spatial dimensions of `x` to `output_size` using the
    /// interpolation mode in `options` (required backend primitive).
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of [`ModuleOps::interpolate`].
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Attention over `query`, `key` and `value` with an optional boolean
    /// `mask` (defaults to [`attention::naive_attention`]).
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}
949
/// Panic-path tests for the zero-value validation performed by the option
/// constructors ([`ConvOptions`], [`ConvTransposeOptions`],
/// [`DeformConvOptions`], [`UnfoldOptions`]).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}