1use super::{conv, pool};
2use crate::ops::attention;
3use crate::ops::unfold::unfold4d_using_conv2d;
4use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
5use crate::{Backend, TensorMetadata};
6use burn_std::Shape;
7use core::num::NonZeroUsize;
8
/// Gradients returned by the 2D convolution backward pass.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
21
/// Gradients returned by the 2D deformable convolution backward pass.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the mask, if a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
40
/// Gradients returned by the 3D convolution backward pass.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
53
/// Gradient returned by the 1D max pooling backward pass.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
60
/// Result of 1D max pooling that also tracks which elements were selected.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Integer tensor of the indices of the selected (max) elements,
    /// consumed by the corresponding backward pass.
    pub indices: IntTensor<B>,
}
70
/// Gradient returned by the 2D max pooling backward pass.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
77
/// Result of 2D max pooling that also tracks which elements were selected.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,

    /// Integer tensor of the indices of the selected (max) elements,
    /// consumed by the corresponding backward pass.
    pub indices: IntTensor<B>,
}
87
/// Validates that `value` is non-zero and returns it unchanged so it can be
/// used inline when building option structs.
///
/// # Panics
/// Panics with `msg` when `value` is zero.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg).get()
}
94
/// Convolution options for `N` spatial dimensions.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`ConvOptions::new`]).
    pub stride: [usize; N],

    /// Padding per spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`ConvOptions::new`]).
    pub dilation: [usize; N],

    /// Number of groups (validated non-zero by [`ConvOptions::new`]).
    pub groups: usize,
}
110
111impl<const N: usize> ConvOptions<N> {
112 pub fn new(
114 stride: [usize; N],
115 padding: [usize; N],
116 dilation: [usize; N],
117 groups: usize,
118 ) -> Self {
119 Self {
120 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
121 padding,
122 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
123 groups: check_nonzero(groups, "groups must be non-zero"),
124 }
125 }
126}
127
/// Deformable convolution options for `N` spatial dimensions.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`DeformConvOptions::new`]).
    pub stride: [usize; N],

    /// Padding per spatial dimension.
    pub padding: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`DeformConvOptions::new`]).
    pub dilation: [usize; N],

    /// Number of weight groups (validated non-zero by [`DeformConvOptions::new`]).
    pub weight_groups: usize,

    /// Number of offset groups (validated non-zero by [`DeformConvOptions::new`]).
    pub offset_groups: usize,
}
146
147impl<const N: usize> DeformConvOptions<N> {
148 pub fn new(
150 stride: [usize; N],
151 padding: [usize; N],
152 dilation: [usize; N],
153 weight_groups: usize,
154 offset_groups: usize,
155 ) -> Self {
156 Self {
157 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
158 padding,
159 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
160 weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
161 offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
162 }
163 }
164}
165
/// Transposed convolution options for `N` spatial dimensions.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by [`ConvTransposeOptions::new`]).
    pub stride: [usize; N],

    /// Padding per spatial dimension.
    pub padding: [usize; N],

    /// Output padding per spatial dimension.
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (validated non-zero by [`ConvTransposeOptions::new`]).
    pub dilation: [usize; N],

    /// Number of groups (validated non-zero by [`ConvTransposeOptions::new`]).
    pub groups: usize,
}
184
185impl<const N: usize> ConvTransposeOptions<N> {
186 pub fn new(
188 stride: [usize; N],
189 padding: [usize; N],
190 padding_out: [usize; N],
191 dilation: [usize; N],
192 groups: usize,
193 ) -> Self {
194 Self {
195 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
196 padding,
197 padding_out,
198 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
199 groups: check_nonzero(groups, "groups must be non-zero"),
200 }
201 }
202}
203
/// Options for the 4D unfold (im2col) operation.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// Stride `[height, width]` (validated non-zero by [`UnfoldOptions::new`]).
    pub stride: [usize; 2],

    /// Padding `[height, width]`.
    pub padding: [usize; 2],

    /// Dilation `[height, width]` (validated non-zero by [`UnfoldOptions::new`]).
    pub dilation: [usize; 2],
}
217
218impl UnfoldOptions {
219 pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
221 Self {
222 stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
223 padding,
224 dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
225 }
226 }
227}
228
/// Interpolation algorithm used when resizing / sampling tensors.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,

    /// Bilinear interpolation.
    Bilinear,

    /// Bicubic interpolation.
    Bicubic,
}
244
/// Options for the interpolate operation.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// The interpolation algorithm to use.
    pub mode: InterpolateMode,
}
251
/// How out-of-bounds grid locations are handled during grid sampling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Use zero for out-of-bounds locations (the default).
    #[default]
    Zeros,
    /// Use the border value for out-of-bounds locations.
    Border,
    /// Reflect out-of-bounds locations back into the input.
    Reflection,
}
265
/// Options for the grid sample operation.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// The interpolation algorithm to use.
    pub mode: InterpolateMode,
    /// How out-of-bounds grid locations are padded.
    pub padding_mode: GridSamplePaddingMode,
    /// Whether extrema grid coordinates map to the centers of the corner
    /// pixels (`true`) or to their outer edges (`false`).
    /// NOTE(review): semantics assumed from the common grid-sample
    /// convention — confirm against the backend implementations.
    pub align_corners: bool,
}
278
279impl Default for GridSampleOptions {
280 fn default() -> Self {
281 Self {
282 mode: InterpolateMode::Bilinear,
283 padding_mode: GridSamplePaddingMode::Zeros,
284 align_corners: false,
285 }
286 }
287}
288
289impl GridSampleOptions {
290 pub fn new(mode: InterpolateMode) -> Self {
294 Self {
295 mode,
296 ..Default::default()
297 }
298 }
299
300 pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
302 self.padding_mode = padding_mode;
303 self
304 }
305
306 pub fn with_align_corners(mut self, align_corners: bool) -> Self {
308 self.align_corners = align_corners;
309 self
310 }
311}
312
/// Padding strategy used when padding a tensor.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),

    /// Reflection padding.
    Reflect,

    /// Edge (replicate) padding.
    Edge,
}
350
351impl Default for PadMode {
352 fn default() -> Self {
353 PadMode::Constant(0.0)
354 }
355}
356
/// Gradient returned by the interpolate backward pass.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input.
    pub x_grad: FloatTensor<B>,
}
363
/// Module operations a [`Backend`] provides for neural-network layers:
/// embeddings, (transposed and deformable) convolutions, unfold, pooling,
/// interpolation and attention.
///
/// Methods ending in `;` are required; the rest have default implementations
/// built from lower-dimensional or more primitive backend ops, which backends
/// may override with specialized kernels.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup: gathers rows of `weights` at the given `indices`.
    ///
    /// * `weights` - embedding table; second dimension is `d_model`.
    /// * `indices` - integer tensor of shape `[batch_size, seq_length]`.
    ///
    /// Returns a tensor of shape `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten indices so one `select` along dim 0 gathers every row.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Backward pass of [`ModuleOps::embedding`]: scatter-adds `output_grad`
    /// into a zero tensor shaped like `weights`, so rows selected multiple
    /// times accumulate their gradients.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        // The gradient is created with the dtype of the incoming gradient.
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// 1D convolution; the default implementation reuses `conv2d`.
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the input `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the weights.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// 2D convolution (required — backends must provide a kernel).
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv2d`] with respect to the input `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to the weights.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// 2D deformable convolution (required).
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::deform_conv2d`], returning the gradients
    /// for every input in one [`DeformConv2dBackward`] (required).
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// 3D convolution (required).
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv3d`] with respect to the input `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to the weights.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// 1D transposed convolution; default reuses `conv_transpose2d`.
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the input.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the weights.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 2D transposed convolution (required).
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the weights.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// 3D transposed convolution (required).
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the weights.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Extracts sliding local blocks from a 4D input tensor (im2col).
    ///
    /// Fast path: with zero padding and unit dilation, two `float_unfold`
    /// calls build the blocks directly; otherwise falls back to a
    /// conv2d-based implementation.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            // Unfold along both spatial dims (2 = height, 3 = width).
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Move the kernel dims next to the channel dim before flattening
            // into [batch, channels * kh * kw, n_blocks].
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// 1D average pooling; default reuses the 2D implementation.
    ///
    /// `count_include_pad` controls whether padded positions count toward
    /// the averaging denominator.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
    }
    /// Backward pass of [`ModuleOps::avg_pool1d`].
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
        )
    }
    /// 2D average pooling (required).
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::avg_pool2d`] (required).
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// 2D adaptive average pooling to a fixed output size (required).
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::adaptive_avg_pool2d`] (required).
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// 1D adaptive average pooling; default reuses the 2D implementation.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`ModuleOps::adaptive_avg_pool1d`].
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// 1D max pooling; default reuses the 2D implementation.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }

    /// 1D max pooling that also returns the argmax indices, which the
    /// backward pass consumes.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }
    /// Backward pass of [`ModuleOps::max_pool1d_with_indices`], routing
    /// `output_grad` back through the recorded `indices`.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        )
    }

    /// 2D max pooling (required).
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<B>;

    /// 2D max pooling that also returns the argmax indices (required).
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indices`] (required).
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Resizes the spatial dims of `x` to `output_size` using the algorithm
    /// selected in `options` (required).
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass of [`ModuleOps::interpolate`] (required).
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Scaled dot-product attention over `query`, `key` and `value`, with an
    /// optional boolean `mask`; default delegates to the naive reference
    /// implementation.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}
921
#[cfg(test)]
mod tests {
    use super::*;

    // Every options constructor must reject zero strides, dilations and
    // group counts. `#[should_panic]` matches a substring of the message,
    // so these tests also pin the exact wording used by `check_nonzero`.

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}