Skip to main content

burn_backend/backend/ops/modules/
base.rs

1use super::{conv, ctc, linear, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
4use crate::{Backend, ElementConversion, TensorMetadata};
5use burn_std::Shape;
6use core::num::NonZeroUsize;
7
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias; `None` when the forward pass had no bias.
    pub bias_grad: Option<FloatTensor<B>>,
}
20
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask; `None` when no mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias; `None` when the forward pass had no bias.
    pub bias_grad: Option<FloatTensor<B>>,
}
39
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias; `None` when the forward pass had no bias.
    pub bias_grad: Option<FloatTensor<B>>,
}
52
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
59
/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor returned alongside the pooled output.
    pub indices: IntTensor<B>,
}
69
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
76
/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor returned alongside the pooled output.
    pub indices: IntTensor<B>,
}
86
/// Check that the parameter value is non-zero, returning it unchanged.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    match NonZeroUsize::new(value) {
        Some(_) => value,
        None => panic!("{}", msg),
    }
}
93
/// Convolution options.
///
/// The non-zero invariants below are enforced by [`ConvOptions::new`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (non-zero).
    pub stride: [usize; N],

    /// Padding per spatial dimension (may be zero).
    pub padding: [usize; N],

    /// Dilation per spatial dimension (non-zero).
    pub dilation: [usize; N],

    /// Number of groups (non-zero).
    pub groups: usize,
}
109
110impl<const N: usize> ConvOptions<N> {
111    /// Constructs a new `ConvOptions`.
112    pub fn new(
113        stride: [usize; N],
114        padding: [usize; N],
115        dilation: [usize; N],
116        groups: usize,
117    ) -> Self {
118        Self {
119            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
120            padding,
121            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
122            groups: check_nonzero(groups, "groups must be non-zero"),
123        }
124    }
125}
126
/// Convolution options with support for asymmetric padding.
///
/// Wraps [`ConvOptions`] (which represents symmetric padding for the backend op)
/// and adds optional asymmetric padding. When asymmetric padding is specified,
/// the functional convolution layer applies an explicit pad operation before
/// dispatching to the backend.
///
/// Implements `From<ConvOptions<N>>` for backward compatibility.
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// The underlying convolution options for the backend.
    pub options: ConvOptions<N>,
    /// Padding at the end of each dimension (e.g., bottom/right for 2D).
    /// If `None`, padding is symmetric (same as `options.padding`).
    /// If `Some`, specifies different end-padding per dimension.
    pub padding_end: Option<[usize; N]>,
}
144
145impl<const N: usize> PaddedConvOptions<N> {
146    /// Creates options with asymmetric padding.
147    ///
148    /// `padding_start` is stored in `ConvOptions::padding`.
149    /// `padding_end` specifies the end padding per dimension.
150    pub fn asymmetric(
151        stride: [usize; N],
152        padding_start: [usize; N],
153        padding_end: [usize; N],
154        dilation: [usize; N],
155        groups: usize,
156    ) -> Self {
157        let options = ConvOptions::new(stride, padding_start, dilation, groups);
158        if padding_start == padding_end {
159            Self {
160                options,
161                padding_end: None,
162            }
163        } else {
164            Self {
165                options,
166                padding_end: Some(padding_end),
167            }
168        }
169    }
170
171    /// Returns true if padding is asymmetric.
172    pub fn is_asymmetric(&self) -> bool {
173        self.padding_end.is_some()
174    }
175}
176
177impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
178    fn from(options: ConvOptions<N>) -> Self {
179        Self {
180            options,
181            padding_end: None,
182        }
183    }
184}
185
/// Deformable convolution options.
///
/// The non-zero invariants below are enforced by [`DeformConvOptions::new`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (non-zero).
    pub stride: [usize; N],

    /// Padding per spatial dimension (may be zero).
    pub padding: [usize; N],

    /// Dilation per spatial dimension (non-zero).
    pub dilation: [usize; N],

    /// Weight Groups (non-zero).
    pub weight_groups: usize,

    /// Offset Groups (non-zero).
    pub offset_groups: usize,
}
204
205impl<const N: usize> DeformConvOptions<N> {
206    /// Constructs a new `DeformConvOptions`.
207    pub fn new(
208        stride: [usize; N],
209        padding: [usize; N],
210        dilation: [usize; N],
211        weight_groups: usize,
212        offset_groups: usize,
213    ) -> Self {
214        Self {
215            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
216            padding,
217            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
218            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
219            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
220        }
221    }
222}
223
/// Transposed convolution options.
///
/// The non-zero invariants below are enforced by [`ConvTransposeOptions::new`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (non-zero).
    pub stride: [usize; N],

    /// Padding applied to the input, per spatial dimension (may be zero).
    pub padding: [usize; N],

    /// Extra padding added to one side of the output, per spatial dimension.
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (non-zero).
    pub dilation: [usize; N],

    /// Number of groups (non-zero).
    pub groups: usize,
}
242
243impl<const N: usize> ConvTransposeOptions<N> {
244    /// Constructs a new `ConvTransposeOptions`.
245    pub fn new(
246        stride: [usize; N],
247        padding: [usize; N],
248        padding_out: [usize; N],
249        dilation: [usize; N],
250        groups: usize,
251    ) -> Self {
252        Self {
253            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
254            padding,
255            padding_out,
256            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
257            groups: check_nonzero(groups, "groups must be non-zero"),
258        }
259    }
260}
261
/// Unfold operation options.
///
/// The non-zero invariants below are enforced by [`UnfoldOptions::new`].
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time. Non-zero.
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor. Non-zero.
    pub dilation: [usize; 2],
}
275
276impl UnfoldOptions {
277    /// Constructs a new `UnfoldOptions`.
278    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
279        Self {
280            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
281            padding,
282            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
283        }
284    }
285}
286
/// Algorithm used for upsampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,

    /// Lanczos3 interpolation (6-tap sinc-based filter).
    /// <https://en.wikipedia.org/wiki/Lanczos_resampling>
    Lanczos3,
}
306
/// Interpolation options.
#[derive(Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
    /// If `true`, the input and output tensors are aligned by their corner pixels.
    /// If `false`, half-pixel coordinate mapping is used instead.
    /// Defaults to `true` via [`InterpolateOptions::new`].
    pub align_corners: bool,
}
316
317impl InterpolateOptions {
318    /// Create new interpolate options with the given mode.
319    /// Defaults to `align_corners = true`.
320    pub fn new(mode: InterpolateMode) -> Self {
321        Self {
322            mode,
323            align_corners: true,
324        }
325    }
326
327    /// Set align_corners.
328    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
329        self.align_corners = align_corners;
330        self
331    }
332}
333
/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}
347
/// Options for grid sampling operations.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation mode (bilinear, nearest, or bicubic).
    pub mode: InterpolateMode,
    /// Padding mode for out-of-bounds coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.
    /// If `false`, they correspond to the corner points of the corner pixels
    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
    pub align_corners: bool,
}
360
361impl Default for GridSampleOptions {
362    fn default() -> Self {
363        Self {
364            mode: InterpolateMode::Bilinear,
365            padding_mode: GridSamplePaddingMode::Zeros,
366            align_corners: false,
367        }
368    }
369}
370
371impl From<InterpolateMode> for GridSampleOptions {
372    fn from(value: InterpolateMode) -> Self {
373        GridSampleOptions::new(value)
374    }
375}
376
377impl GridSampleOptions {
378    /// Create new grid sample options with the given interpolation mode.
379    ///
380    /// Uses default values for padding_mode (Zeros) and align_corners (false).
381    pub fn new(mode: InterpolateMode) -> Self {
382        Self {
383            mode,
384            ..Default::default()
385        }
386    }
387
388    /// Set the padding mode.
389    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
390        self.padding_mode = padding_mode;
391        self
392    }
393
394    /// Set align_corners.
395    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
396        self.align_corners = align_corners;
397        self
398    }
399}
400
/// Padding mode for tensor pad operations.
///
/// Defines how values are filled when padding a tensor beyond its original boundaries.
/// Padding can be applied to any dimension of a tensor.
///
/// # Modes
///
/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)
/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)
/// - [`Edge`](PadMode::Edge): Replicate boundary values
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Fill padded regions with a constant value.
    ///
    /// # Example
    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:
    /// Result: `[0, 0, 1, 2, 3]`
    Constant(f32),

    /// Reflect values at the boundary, excluding the edge value.
    ///
    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)
    Reflect,

    /// Replicate the edge values.
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[1, 1, 1, 2, 3, 4]`
    Edge,
}
436
437impl Default for PadMode {
438    fn default() -> Self {
439        PadMode::Constant(0.0)
440    }
441}
442
443impl<E: ElementConversion> From<E> for PadMode {
444    fn from(value: E) -> Self {
445        PadMode::Constant(value.elem())
446    }
447}
448
/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
455
/// Options for [attention](ModuleOps::attention).
#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct AttentionModuleOptions {
    /// Custom scale factor applied to QK^T. When `None`, defaults to `1/sqrt(head_dim)`.
    pub scale: Option<f64>,

    /// Soft capping applied before softmax: `softcap * tanh(scores / softcap)`.
    /// Used by Gemma-2 and similar models. Must be positive when set.
    pub softcap: Option<f64>,

    /// When `true`, applies causal (autoregressive) masking so that each query position
    /// can only attend to key positions at or before it. This is more efficient than
    /// passing an explicit lower-triangular bool mask because backends can use optimized
    /// kernel paths (e.g. flash attention with causal mode).
    pub is_causal: bool,
}
472
473/// Module operations trait.
474pub trait ModuleOps<B: Backend> {
475    /// Embedding operation.
476    ///
477    /// # Arguments
478    ///
479    /// * `weights` - The embedding weights.
480    /// * `indices` - The indices tensor.
481    ///
482    /// # Returns
483    ///
484    /// The output tensor.
485    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
486        let [batch_size, seq_length] = indices.shape().dims();
487        let [_, d_model] = weights.shape().dims();
488
489        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
490        let output = B::float_select(weights, 0, indices);
491
492        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
493    }
494
    /// Embedding backward operation.
    ///
    /// Scatter-adds `output_grad` into a zero tensor shaped like `weights`,
    /// accumulating gradients for rows selected more than once.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights (used here for shape and device).
    /// * `output_grad` - The output gradient (its dtype is used for the result).
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The gradient with respect to `weights`, of shape `[n_embeddings, d_model]`.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        // Flatten batch and sequence dims so each gradient row maps to one index.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }
523
    /// Linear transformation.
    ///
    /// # Shapes
    ///
    /// x:      `[..., d_input]`,
    /// weight: `[d_input, d_output]`,
    /// bias:   `[d_output]`,
    fn linear(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `linear` helper module.
        linear::linear::<B>(x, weight, bias)
    }
    /// Backward pass for [linear](ModuleOps::linear), returning the gradient for `x`.
    fn linear_x_backward(weight: FloatTensor<B>, output_grad: FloatTensor<B>) -> FloatTensor<B> {
        // Default implementation delegating to the shared `linear` helper module.
        linear::linear_x_backward::<B>(weight, output_grad)
    }
    /// Backward pass for [linear](ModuleOps::linear), returning the gradient for `weight`.
    fn linear_weight_backward(x: FloatTensor<B>, output_grad: FloatTensor<B>) -> FloatTensor<B> {
        // Default implementation delegating to the shared `linear` helper module.
        linear::linear_weight_backward::<B>(x, output_grad)
    }
    /// Backward pass for [linear](ModuleOps::linear), returning the gradient for `bias`.
    fn linear_bias_backward(output_grad: FloatTensor<B>) -> FloatTensor<B> {
        // Default implementation delegating to the shared `linear` helper module.
        linear::linear_bias_backward::<B>(output_grad)
    }
550
    /// One dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_out, channels_in, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        // Default implementation: express the 1d convolution via the
        // backend's (required) conv2d.
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution.
    ///
    /// Required method: each backend must supply its own implementation.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
    }
631
    /// Two dimensional deformable convolution.
    ///
    /// Required method: each backend must supply its own implementation.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
    ///
    /// Required method: returns the gradients for all forward-pass inputs
    /// bundled in a [`DeformConv2dBackward`].
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;
657
    /// Three dimensional convolution.
    ///
    /// Required method: each backend must supply its own implementation.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// One dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_in, channels_out, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        // Default implementation: express the 1d transposed convolution via
        // the backend's (required) conv_transpose2d.
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }
737
    /// Two dimensional transposed convolution.
    ///
    /// Required method: each backend must supply its own implementation.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }
776
    /// Three dimensional transposed convolution.
    ///
    /// Required method: each backend must supply its own implementation.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        // Default implementation delegating to the shared `conv` helper module.
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }
815
    /// Four-dimensional unfolding.
    ///
    /// # Shapes
    ///
    /// * x:      ``[batch_size, channels_in, height, width]``,
    /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``,
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        // Fast path: with no padding and unit dilation, the blocks can be
        // extracted directly with two sliding-window unfolds. Otherwise fall
        // back to the conv2d-based implementation.
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // Layout here: batch, channels, h_blocks, w_blocks, h_kern, w_kern

            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = blocks.shape();

            // After the permute: batch, channels, h_kern, w_kern, h_blocks, w_blocks.
            // Merge (channels, h_kern, w_kern) and (h_blocks, w_blocks) below.

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }
851
    /// One dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        // Default implementation: express the 1d pooling via the backend's
        // (required) avg_pool2d.
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        // Default implementation: express the 1d backward pass via the
        // backend's (required) avg_pool2d_backward.
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// # Arguments
    ///
    /// * `kernel_size` - Pooling window size per spatial axis `[height, width]`.
    /// * `stride` - Step of the window per spatial axis.
    /// * `padding` - Zero-padding added to both sides of each spatial axis.
    /// * `count_include_pad` - When `true`, padded zeros count toward the averaging denominator.
    /// * `ceil_mode` - When `true`, uses ceiling instead of floor when computing the output size.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    ///
    /// # Arguments
    ///
    /// * `x` - The forward-pass input.
    /// * `grad` - Upstream gradient w.r.t. the forward-pass output.
    ///
    /// The pooling hyperparameters must match those used in the forward pass.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// The pooling window is chosen so that the output has the requested
    /// spatial size `output_size` (`[height, width]`).
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    ///
    /// # Arguments
    ///
    /// * `x` - The forward-pass input.
    /// * `grad` - Upstream gradient w.r.t. the forward-pass output.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// The pooling window is chosen so that the output has the requested
    /// length `output_size`.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        // Default implementation goes through the 2D adaptive pooling path.
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    ///
    /// # Arguments
    ///
    /// * `x` - The forward-pass input.
    /// * `grad` - Upstream gradient w.r.t. the forward-pass output.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Default implementation goes through the 2D adaptive pooling path.
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    ///
    /// # Arguments
    ///
    /// * `dilation` - Spacing between kernel elements.
    /// * `ceil_mode` - When `true`, uses ceiling instead of floor when computing the output length.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        // Default implementation goes through the 2D max pooling path.
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }
952
    /// One dimensional max pooling with indices.
    ///
    /// Returns both the pooled values and the indices of the selected elements,
    /// as used by [max_pool1d_with_indices_backward](ModuleOps::max_pool1d_with_indices_backward).
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        // Default implementation goes through the 2D max pooling path.
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
    ///
    /// # Arguments
    ///
    /// * `x` - The forward-pass input.
    /// * `output_grad` - Upstream gradient w.r.t. the forward-pass output.
    /// * `indices` - Indices produced by the forward pass with indices.
    ///
    /// The pooling hyperparameters must match those used in the forward pass.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        // Default implementation goes through the 2D max pooling path.
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }
998
    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// # Arguments
    ///
    /// * `kernel_size` - Pooling window size per spatial axis `[height, width]`.
    /// * `stride` - Step of the window per spatial axis.
    /// * `padding` - Padding added to both sides of each spatial axis.
    /// * `dilation` - Spacing between kernel elements per spatial axis.
    /// * `ceil_mode` - When `true`, uses ceiling instead of floor when computing the output size.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;
1012
    /// Two dimensional max pooling with indices.
    ///
    /// Returns both the pooled values and the indices of the selected elements,
    /// as used by [max_pool2d_with_indices_backward](ModuleOps::max_pool2d_with_indices_backward).
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    ///
    /// # Arguments
    ///
    /// * `x` - The forward-pass input.
    /// * `output_grad` - Upstream gradient w.r.t. the forward-pass output.
    /// * `indices` - Indices produced by the forward pass with indices.
    ///
    /// The pooling hyperparameters must match those used in the forward pass.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;
1038
    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    ///
    /// # Arguments
    ///
    /// * `output_size` - Target spatial size `[height, width]` of the result.
    /// * `options` - Interpolation options (e.g. the resampling mode); see [InterpolateOptions].
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
1049
    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    ///
    /// # Arguments
    ///
    /// * `x` - The forward-pass input.
    /// * `grad` - Upstream gradient w.r.t. the forward-pass output.
    /// * `output_size` / `options` - Must match the values used in the forward pass.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
1057
    /// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
    /// where scale defaults to 1/sqrt(head_dim). Optionally applies masking,
    /// additive bias, causal masking, and softcap to the attention scores.
    ///
    /// All optional adjustments (mask, bias, causal masking, softcap) are
    /// applied to the pre-softmax attention scores.
    ///
    /// # Arguments
    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
    ///   where `true` indicates positions to mask (i.e. set to -inf before softmax).
    /// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
    ///   added to the attention scores before softmax (e.g. ALiBi, relative position biases).
    /// - `options`: Additional attention options (custom scale, softcap, causal masking);
    ///   see [AttentionModuleOptions].
    ///
    /// # Returns
    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
    /// representing the attended context per head.
    ///
    /// # Note
    /// This implementation does not support dropout and is intended for inference or
    /// use cases where dropout is not needed.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
        attn_bias: Option<FloatTensor<B>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<B>;
1087
1088    /// Applies Layer Normalization over the last dimension of the input tensor.
1089    ///
1090    /// Computes `(x - mean) / sqrt(var + epsilon) * gamma + beta`, where `mean` and
1091    /// (biased) `var` are reduced over the last axis.
1092    ///
1093    /// # Arguments
1094    ///
1095    /// * `tensor` - Input tensor of shape `[..., d_model]`.
1096    /// * `gamma` - Scale tensor of shape `[d_model]`.
1097    /// * `beta` - Optional bias tensor of shape `[d_model]`.
1098    /// * `epsilon` - Numerical stability term added to the variance before the square root.
1099    ///
1100    /// # Returns
1101    ///
1102    /// A tensor with the same shape as `tensor`.
1103    fn layer_norm(
1104        tensor: FloatTensor<B>,
1105        gamma: FloatTensor<B>,
1106        beta: Option<FloatTensor<B>>,
1107        epsilon: f64,
1108    ) -> FloatTensor<B> {
1109        let shape = tensor.shape();
1110        let rank = shape.num_dims();
1111        let last_dim = rank - 1;
1112        let d_model = shape[last_dim];
1113
1114        let mean = B::float_mean_dim(tensor.clone(), last_dim);
1115        let centered = B::float_sub(tensor, mean);
1116        let var = B::float_mean_dim(B::float_mul(centered.clone(), centered.clone()), last_dim);
1117        let denom = B::float_sqrt(B::float_add_scalar(var, epsilon.into()));
1118        let normalized = B::float_div(centered, denom);
1119
1120        let broadcast_dims: alloc::vec::Vec<usize> = (0..rank)
1121            .map(|i| if i == last_dim { d_model } else { 1 })
1122            .collect();
1123        let gamma_b = B::float_reshape(gamma, Shape::from(broadcast_dims.clone()));
1124        let scaled = B::float_mul(normalized, gamma_b);
1125
1126        match beta {
1127            Some(beta) => {
1128                let beta_b = B::float_reshape(beta, Shape::from(broadcast_dims));
1129                B::float_add(scaled, beta_b)
1130            }
1131            None => scaled,
1132        }
1133    }
1134
    /// Computes the Connectionist Temporal Classification (CTC) loss.
    ///
    /// Sums over all valid alignments between the input and target sequences
    /// using the forward (alpha) algorithm.
    ///
    /// # Arguments
    ///
    /// * `log_probs` - Log-probabilities of shape `[T, N, C]`
    /// * `targets` - Target label indices of shape `[N, S]`
    /// * `input_lengths` - Actual input sequence lengths per batch element `[N]`
    /// * `target_lengths` - Actual target lengths per batch element `[N]`
    /// * `blank` - Index of the blank label
    ///
    /// # Returns
    ///
    /// Per-sample loss of shape `[N]`
    fn ctc_loss(
        log_probs: FloatTensor<B>,
        targets: IntTensor<B>,
        input_lengths: IntTensor<B>,
        target_lengths: IntTensor<B>,
        blank: usize,
    ) -> FloatTensor<B> {
        // Default implementation delegates to the backend-agnostic decomposition.
        ctc::ctc_loss_default::<B>(log_probs, targets, input_lengths, target_lengths, blank)
    }
1160
    /// Returns `true` if this backend implements [ctc_loss_backward](ModuleOps::ctc_loss_backward)
    /// natively. Defaults to `false`.
    ///
    /// Autodiff queries this flag to decide between two paths:
    /// - `true`: use the backend's [ctc_loss](ModuleOps::ctc_loss) and
    ///   [ctc_loss_backward](ModuleOps::ctc_loss_backward) directly.
    /// - `false`: call [ctc::ctc_loss_default] for the forward pass; autodiff
    ///   then differentiates through the decomposed tensor ops.
    ///
    /// Backends that override `ctc_loss_backward` must also override this to
    /// return `true`.
    fn has_ctc_loss_backward() -> bool {
        false
    }
1175
    /// Backward pass for [ctc_loss](ModuleOps::ctc_loss): gradient w.r.t. `log_probs`.
    ///
    /// Only called when [has_ctc_loss_backward](ModuleOps::has_ctc_loss_backward)
    /// returns `true`. Backends without a native implementation should leave
    /// both methods at their defaults; the gradient is computed automatically by
    /// autodiff against the decomposed [ctc::ctc_loss_default] forward.
    ///
    /// # Arguments
    ///
    /// * `log_probs` - Log-probabilities of shape `[T, N, C]`
    /// * `targets` - Target label indices of shape `[N, S]`
    /// * `input_lengths` - Actual input sequence lengths per batch element `[N]`
    /// * `target_lengths` - Actual target lengths per batch element `[N]`
    /// * `grad_loss` - Upstream gradient w.r.t. the per-sample loss `[N]`
    /// * `blank` - Index of the blank label
    ///
    /// # Returns
    ///
    /// Gradient w.r.t. `log_probs` of shape `[T, N, C]`
    fn ctc_loss_backward(
        _log_probs: FloatTensor<B>,
        _targets: IntTensor<B>,
        _input_lengths: IntTensor<B>,
        _target_lengths: IntTensor<B>,
        _grad_loss: FloatTensor<B>,
        _blank: usize,
    ) -> FloatTensor<B> {
        // Autodiff routes here only when has_ctc_loss_backward() opts in, so a
        // call landing on this default signals a backend implementation bug.
        unreachable!(
            "ctc_loss_backward called on a backend whose has_ctc_loss_backward() returns false"
        )
    }
1207
    /// Real-valued FFT with optional size parameter.
    ///
    /// When `n` is `None`, the signal must be a power of two along `dim`, and the output has
    /// `signal_len / 2 + 1` frequency bins along `dim`.
    ///
    /// When `n` is `Some(size)`, `size` must also be a power of two. The signal is truncated
    /// or zero-padded to `size` along `dim` and the output has `size / 2 + 1` frequency bins.
    /// Non-power-of-two sizes are currently rejected at the public API boundary; true
    /// arbitrary-`n` DFT support (Bluestein's algorithm) is tracked as a follow-up.
    ///
    /// Returns two tensors: the real part and the imaginary part.
    fn rfft(
        signal: FloatTensor<B>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<B>, FloatTensor<B>);
1224
    /// Inverse real-valued FFT with optional output size.
    ///
    /// Takes the real and imaginary spectrum parts as produced by [rfft](ModuleOps::rfft).
    ///
    /// When `n` is `None`, the reconstructed signal length `2 * (spectrum_size - 1)` must be
    /// a power of two.
    ///
    /// When `n` is `Some(size)`, `size` must also be a power of two. Output has exactly
    /// `size` samples along `dim`.
    fn irfft(
        spectrum_re: FloatTensor<B>,
        spectrum_im: FloatTensor<B>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<B>;
1238}
1239
/// Panic-contract tests for the option constructors: each constructor must
/// panic with a descriptive message when given a zero stride, dilation, or
/// group count.
#[cfg(test)]
mod tests {
    use super::*;

    // --- ConvOptions: (stride, padding, dilation, groups) ---

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    // --- ConvTransposeOptions: (stride, padding, padding_out, dilation, groups) ---

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    // --- DeformConvOptions: (stride, padding, dilation, weight_groups, offset_groups) ---

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    // --- UnfoldOptions: (stride, padding, dilation) ---

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}