1use super::{conv, pool};
2use crate::ops::attention;
3use crate::ops::unfold::unfold4d_using_conv2d;
4use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
5use crate::{Backend, ElementConversion, TensorMetadata};
6use burn_std::Shape;
7use core::num::NonZeroUsize;
8
/// Gradients returned by the backward pass of a 2D convolution.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input tensor.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
21
/// Gradients returned by the backward pass of a 2D deformable convolution.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input tensor.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask, when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
40
/// Gradients returned by the backward pass of a 3D convolution.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input tensor.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
53
/// Gradient returned by the backward pass of a 1D max pooling.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input tensor.
    pub x_grad: FloatTensor<B>,
}
60
/// Result of a 1D max pooling that also tracks the argmax indices
/// (needed by the corresponding backward pass).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// Index of the selected maximum for each output element.
    pub indices: IntTensor<B>,
}
70
/// Gradient returned by the backward pass of a 2D max pooling.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input tensor.
    pub x_grad: FloatTensor<B>,
}
77
/// Result of a 2D max pooling that also tracks the argmax indices
/// (needed by the corresponding backward pass).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// Index of the selected maximum for each output element.
    pub indices: IntTensor<B>,
}
87
/// Validates that `value` is non-zero, panicking with `msg` otherwise.
///
/// Returns the value unchanged so the check can be applied inline while
/// building option structs.
///
/// # Panics
/// Panics with `msg` when `value == 0`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    if NonZeroUsize::new(value).is_none() {
        panic!("{msg}");
    }
    value
}
94
/// Convolution options for an `N`-dimensional convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (each entry must be non-zero).
    pub stride: [usize; N],

    /// Zero-padding added on both sides of each spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (each entry must be non-zero).
    pub dilation: [usize; N],

    /// Number of channel groups (must be non-zero).
    pub groups: usize,
}
110
111impl<const N: usize> ConvOptions<N> {
112 pub fn new(
114 stride: [usize; N],
115 padding: [usize; N],
116 dilation: [usize; N],
117 groups: usize,
118 ) -> Self {
119 Self {
120 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
121 padding,
122 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
123 groups: check_nonzero(groups, "groups must be non-zero"),
124 }
125 }
126}
127
/// Convolution options that additionally support asymmetric padding
/// (different padding at the start and the end of a spatial dimension).
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// Base options; `options.padding` holds the start-side padding.
    pub options: ConvOptions<N>,
    /// End-side padding; `None` means padding is symmetric and equal to
    /// `options.padding`.
    pub padding_end: Option<[usize; N]>,
}
145
146impl<const N: usize> PaddedConvOptions<N> {
147 pub fn asymmetric(
152 stride: [usize; N],
153 padding_start: [usize; N],
154 padding_end: [usize; N],
155 dilation: [usize; N],
156 groups: usize,
157 ) -> Self {
158 let options = ConvOptions::new(stride, padding_start, dilation, groups);
159 if padding_start == padding_end {
160 Self {
161 options,
162 padding_end: None,
163 }
164 } else {
165 Self {
166 options,
167 padding_end: Some(padding_end),
168 }
169 }
170 }
171
172 pub fn is_asymmetric(&self) -> bool {
174 self.padding_end.is_some()
175 }
176}
177
178impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
179 fn from(options: ConvOptions<N>) -> Self {
180 Self {
181 options,
182 padding_end: None,
183 }
184 }
185}
186
/// Options for an `N`-dimensional deformable convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (each entry must be non-zero).
    pub stride: [usize; N],

    /// Zero-padding added on both sides of each spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (each entry must be non-zero).
    pub dilation: [usize; N],

    /// Number of weight groups (must be non-zero).
    pub weight_groups: usize,

    /// Number of offset groups (must be non-zero).
    pub offset_groups: usize,
}
205
206impl<const N: usize> DeformConvOptions<N> {
207 pub fn new(
209 stride: [usize; N],
210 padding: [usize; N],
211 dilation: [usize; N],
212 weight_groups: usize,
213 offset_groups: usize,
214 ) -> Self {
215 Self {
216 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
217 padding,
218 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
219 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
220 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
221 }
222 }
223}
224
/// Options for an `N`-dimensional transposed convolution.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (each entry must be non-zero).
    pub stride: [usize; N],

    /// Padding applied to the input per spatial dimension.
    pub padding: [usize; N],

    /// Extra padding added to one side of the output per spatial dimension.
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (each entry must be non-zero).
    pub dilation: [usize; N],

    /// Number of channel groups (must be non-zero).
    pub groups: usize,
}
243
244impl<const N: usize> ConvTransposeOptions<N> {
245 pub fn new(
247 stride: [usize; N],
248 padding: [usize; N],
249 padding_out: [usize; N],
250 dilation: [usize; N],
251 groups: usize,
252 ) -> Self {
253 Self {
254 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
255 padding,
256 padding_out,
257 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
258 groups: check_nonzero(groups, "groups must be non-zero"),
259 }
260 }
261}
262
/// Options for the 4D unfold (im2col) operation.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride of the sliding window, `[height, width]` (entries must be
    /// non-zero).
    pub stride: [usize; 2],

    /// Zero-padding added on both sides, `[height, width]`.
    pub padding: [usize; 2],

    /// Dilation of the sliding window, `[height, width]` (entries must be
    /// non-zero).
    pub dilation: [usize; 2],
}
276
277impl UnfoldOptions {
278 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
280 Self {
281 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
282 padding,
283 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
284 }
285 }
286}
287
/// Interpolation algorithm used for resizing / sampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}
303
/// Options for the interpolate (resize) operation.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm to use.
    pub mode: InterpolateMode,
}
310
/// How grid-sample handles coordinates that fall outside the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds locations read as zero (the default).
    #[default]
    Zeros,
    /// Out-of-bounds locations clamp to the border values.
    Border,
    /// Out-of-bounds locations reflect back into the input.
    Reflection,
}
324
/// Options for the grid-sample operation.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation algorithm used when sampling.
    pub mode: InterpolateMode,
    /// Handling of out-of-bounds grid coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// Whether extreme grid coordinates map to the corner centers of the
    /// input pixels (`true`) or to the outer edges (`false`).
    pub align_corners: bool,
}
337
338impl Default for GridSampleOptions {
339 fn default() -> Self {
340 Self {
341 mode: InterpolateMode::Bilinear,
342 padding_mode: GridSamplePaddingMode::Zeros,
343 align_corners: false,
344 }
345 }
346}
347
348impl From<InterpolateMode> for GridSampleOptions {
349 fn from(value: InterpolateMode) -> Self {
350 GridSampleOptions::new(value)
351 }
352}
353
354impl GridSampleOptions {
355 pub fn new(mode: InterpolateMode) -> Self {
359 Self {
360 mode,
361 ..Default::default()
362 }
363 }
364
365 pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
367 self.padding_mode = padding_mode;
368 self
369 }
370
371 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
373 self.align_corners = align_corners;
374 self
375 }
376}
377
/// Padding strategy for the pad operation.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),

    /// Pad by reflecting the input around its edges.
    Reflect,

    /// Pad by replicating the edge values.
    Edge,
}
415
416impl Default for PadMode {
417 fn default() -> Self {
418 PadMode::Constant(0.0)
419 }
420}
421
/// Any element value converts into constant padding with that value.
impl<E: ElementConversion> From<E> for PadMode {
    fn from(value: E) -> Self {
        PadMode::Constant(value.elem())
    }
}
427
/// Gradient returned by the backward pass of the interpolate operation.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input tensor.
    pub x_grad: FloatTensor<B>,
}
434
/// Module operations (convolution, pooling, interpolation, embedding,
/// attention) that a [`Backend`] must provide.
///
/// Only a core set of methods is required; the rest have default
/// implementations built from lower-level ops (e.g. the 1D variants are
/// computed through their 2D counterparts via the `conv`/`pool` helpers).
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup.
    ///
    /// Gathers rows of `weights` (shape `[n_embeddings, d_model]`) at
    /// `indices` (shape `[batch_size, seq_length]`), producing a tensor of
    /// shape `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten the indices so a single select along dim 0 gathers all rows.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Backward pass of [`ModuleOps::embedding`] with respect to the weights.
    ///
    /// Scatter-adds the flattened `output_grad` rows back into a zero tensor
    /// with the weights' shape `[n_embeddings, d_model]`.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// 1D convolution. Default implementation reshapes to a 2D convolution.
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass of [`ModuleOps::conv1d`] with respect to the input.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv1d`] with respect to the weights.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv1d`] with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 2D convolution (required — the backend supplies the implementation).
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv2d`] with respect to the input.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv2d`] with respect to the weights.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv2d`] with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D deformable convolution (required).
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::deform_conv2d`]; returns all gradients
    /// bundled in a [`DeformConv2dBackward`] (required).
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// 3D convolution (required).
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv3d`] with respect to the input.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv3d`] with respect to the weights.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv3d`] with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 1D transposed convolution. Default implementation reshapes to the
    /// 2D transposed convolution.
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose1d`] w.r.t. the input.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose1d`] w.r.t. the weights.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose1d`] w.r.t. the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D transposed convolution (required).
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv_transpose2d`] w.r.t. the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose2d`] w.r.t. the weights.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose2d`] w.r.t. the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 3D transposed convolution (required).
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::conv_transpose3d`] w.r.t. the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose3d`] w.r.t. the weights.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass of [`ModuleOps::conv_transpose3d`] w.r.t. the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Unfold (im2col): extracts sliding `kernel_size` windows from a 4D
    /// input, returning shape `[batch, channels * k_h * k_w, n_windows]`.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        // Fast path: without padding/dilation, two float_unfold calls along the
        // spatial dims produce the windows directly, avoiding the conv2d-based
        // fallback.
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Bring the kernel dims next to the channel dim before flattening:
            // [b, c, out_h, out_w, k_h, k_w] -> [b, c, k_h, k_w, out_h, out_w].
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// 1D average pooling. Default implementation goes through the 2D op.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::avg_pool1d`].
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// 2D average pooling (required).
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::avg_pool2d`] (required).
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// 2D adaptive average pooling to the given output size (required).
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::adaptive_avg_pool2d`] (required).
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// 1D adaptive average pooling. Default goes through the 2D op.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`ModuleOps::adaptive_avg_pool1d`].
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// 1D max pooling. Default implementation goes through the 2D op.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// 1D max pooling that also returns the argmax indices.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::max_pool1d_with_indices`], routing each
    /// gradient back to the position recorded in `indices`.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// 2D max pooling (required).
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// 2D max pooling that also returns the argmax indices (required).
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indices`] (required).
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Resizes a 4D tensor to `output_size` using the mode in `options`
    /// (required).
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of [`ModuleOps::interpolate`] (required).
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Attention over `query`/`key`/`value` with an optional boolean mask.
    /// Default delegates to the naive reference implementation in
    /// `ops::attention`.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}
1018
#[cfg(test)]
mod tests {
    use super::*;

    // Every options constructor funnels its size arguments through
    // `check_nonzero`; each test below passes a single zero value and expects
    // a panic whose message contains the matching text (`should_panic`
    // matches by substring).

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}