burn_tensor/tensor/ops/modules/base.rs

use super::{conv, pool};
use crate::ops::unfold::unfold4d_using_conv2d;
use crate::{
    Shape, TensorMetadata,
    backend::Backend,
    ops::{FloatTensor, IntTensor},
};
use core::num::NonZeroUsize;

/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Input gradient.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Input gradient.
    pub x_grad: FloatTensor<B>,

    /// Offset gradient.
    pub offset_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weight_grad: FloatTensor<B>,

    /// Mask gradient.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Input gradient.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Input gradient.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The output tensor.
    pub output: FloatTensor<B>,

    /// The indices of the maximum values, used in the backward pass.
    pub indices: IntTensor<B>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Input gradient.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The output tensor.
    pub output: FloatTensor<B>,

    /// The indices of the maximum values, used in the backward pass.
    pub indices: IntTensor<B>,
}

/// Check that the parameter value is non-zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg);
    value
}

/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvOptions<N> {
    /// Constructs a new `ConvOptions`.
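    ///
    /// Panics if any `stride` or `dilation` component, or `groups`, is zero.
    ///
    /// A minimal sketch with made-up values (2D convolution, stride 2,
    /// padding 1, unit dilation, a single group):
    ///
    /// ```ignore
    /// let options = ConvOptions::<2>::new([2, 2], [1, 1], [1, 1], 1);
    /// assert_eq!(options.stride, [2, 2]);
    /// ```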
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

/// Deformable convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Weight groups (non-zero).
    pub weight_groups: usize,

    /// Offset groups (non-zero).
    pub offset_groups: usize,
}

impl<const N: usize> DeformConvOptions<N> {
    /// Constructs a new `DeformConvOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        weight_groups: usize,
        offset_groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
        }
    }
}

/// Transposed convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Output padding.
    pub padding_out: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvTransposeOptions<N> {
    /// Constructs a new `ConvTransposeOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        padding_out: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            padding_out,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

/// Unfold operation options.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor.
    pub dilation: [usize; 2],
}

impl UnfoldOptions {
    /// Constructs a new `UnfoldOptions`.
    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
        }
    }
}

/// Algorithm used for upsampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}

/// Interpolation options.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
}

/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Input gradient.
    pub x_grad: FloatTensor<B>,
}

/// Module operations trait.
pub trait ModuleOps<B: Backend> {
    /// Embedding operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
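    ///
    /// # Shapes
    ///
    /// weights: `[n_embeddings, d_model]`,
    /// indices: `[batch_size, seq_length]`,
    /// output:  `[batch_size, seq_length, d_model]`.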
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Embedding backward operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `output_grad` - The output gradient.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The gradient.
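    /// It has the shape of `weights`, `[n_embeddings, d_model]`, with the rows of
    /// `output_grad` scattered back to the rows selected by `indices`.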
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_assign(grad, 0, indices, output_grad)
    }
    /// One dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_out, channels_in, kernel_size]`,
    /// bias:   `[channels_out]`,
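    ///
    /// The output length is `(length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`
    /// (integer division).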
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Two dimensional deformable convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// Three dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// One dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_in, channels_out, kernel_size]`,
    /// bias:   `[channels_out]`,
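    ///
    /// The output length is
    /// `(length - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + padding_out + 1`.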
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Three dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Four-dimensional unfolding.
    ///
    /// # Shapes
    ///
    /// x:       `[batch_size, channels_in, height, width]`,
    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number_of_blocks]`,
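    ///
    /// `number_of_blocks` is the product over both spatial dimensions of
    /// `(size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`.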
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);
            // blocks: [batch, channels, h_blocks, w_blocks, h_kern, w_kern]

            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            // blocks: [batch, channels, h_kern, w_kern, h_blocks, w_blocks]
            let shape = &blocks.shape().dims;

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// One dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
    }
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }

    /// One dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }
    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        )
    }

    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
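    ///
    /// Output spatial dimensions follow the same arithmetic as [conv2d](ModuleOps::conv2d).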
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<B>;

    /// Two dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
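    ///
    /// Output: `[batch_size, channels, output_size[0], output_size[1]]`.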
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
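
    // A positive-path check to complement the should_panic tests above:
    // check_nonzero returns the value unchanged when it is non-zero.
    #[test]
    fn check_nonzero_returns_value() {
        assert_eq!(check_nonzero(3, "must be non-zero"), 3);
    }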
}