//! Module operation definitions (`burn_tensor/tensor/ops/modules/base.rs`).

1use core::num::NonZeroUsize;
2
3use super::{conv, pool, unfold::unfold4d_using_conv2d};
4use crate::{
5    backend::Backend,
6    ops::{FloatTensor, IntTensor},
7    Shape, TensorMetadata,
8};
9
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, present only when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
22
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask, present only when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, present only when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
41
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, present only when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
54
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
61
/// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor returned alongside the output (consumed by the backward pass).
    pub indices: IntTensor<B>,
}
71
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
78
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor returned alongside the output (consumed by the backward pass).
    pub indices: IntTensor<B>,
}
88
/// Check that the parameter value is non-zero, panicking with `msg` otherwise.
///
/// Returns the value unchanged when it is non-zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    match NonZeroUsize::new(value) {
        Some(nonzero) => nonzero.get(),
        None => panic!("{}", msg),
    }
}
95
/// Convolution options.
///
/// Construct via [`ConvOptions::new`] to get non-zero validation of
/// `stride`, `dilation` and `groups`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding (zero is valid).
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}
111
112impl<const N: usize> ConvOptions<N> {
113    /// Constructs a new `ConvOptions`.
114    pub fn new(
115        stride: [usize; N],
116        padding: [usize; N],
117        dilation: [usize; N],
118        groups: usize,
119    ) -> Self {
120        Self {
121            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
122            padding,
123            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
124            groups: check_nonzero(groups, "groups must be non-zero"),
125        }
126    }
127}
128
/// Deformable convolution options.
///
/// Construct via [`DeformConvOptions::new`] to get non-zero validation of
/// `stride`, `dilation` and the two group counts.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding (zero is valid).
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Weight Groups (non-zero).
    pub weight_groups: usize,

    /// Offset Groups (non-zero).
    pub offset_groups: usize,
}
147
148impl<const N: usize> DeformConvOptions<N> {
149    /// Constructs a new `DeformConvOptions`.
150    pub fn new(
151        stride: [usize; N],
152        padding: [usize; N],
153        dilation: [usize; N],
154        weight_groups: usize,
155        offset_groups: usize,
156    ) -> Self {
157        Self {
158            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
159            padding,
160            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
161            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
162            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
163        }
164    }
165}
166
/// Transposed convolution options.
///
/// Construct via [`ConvTransposeOptions::new`] to get non-zero validation of
/// `stride`, `dilation` and `groups`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding (zero is valid).
    pub padding: [usize; N],

    /// Output padding (zero is valid).
    pub padding_out: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}
185
186impl<const N: usize> ConvTransposeOptions<N> {
187    /// Constructs a new `ConvTransposeOptions`.
188    pub fn new(
189        stride: [usize; N],
190        padding: [usize; N],
191        padding_out: [usize; N],
192        dilation: [usize; N],
193        groups: usize,
194    ) -> Self {
195        Self {
196            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
197            padding,
198            padding_out,
199            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
200            groups: check_nonzero(groups, "groups must be non-zero"),
201        }
202    }
203}
204
/// Unfold operation options.
///
/// Construct via [`UnfoldOptions::new`] to get non-zero validation of
/// `stride` and `dilation`.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor.
    pub dilation: [usize; 2],
}
218
219impl UnfoldOptions {
220    /// Constructs a new `UnfoldOptions`.
221    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
222        Self {
223            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
224            padding,
225            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
226        }
227    }
228}
229
/// Algorithm used for upsampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}
245
/// Interpolation options.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
}
252
/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
259
/// Module operations trait.
///
/// Backends implement the required methods; the provided defaults express the
/// 1D/3D variants in terms of the 2D primitives where possible.
pub trait ModuleOps<B: Backend> {
    /// Embedding operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        // Flatten the indices so the lookup becomes a single `select` on dim 0,
        // then restore the `[batch, seq, d_model]` layout.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Embedding backward operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `output_grad` - The output gradient.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);

        // Route each position's gradient back to the embedding row it selected,
        // starting from a zero-initialized weight gradient.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device);

        B::float_select_assign(grad, 0, indices, output_grad)
    }
    /// One dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_out, channels_in, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Two dimensional deformable convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// Three dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// One dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_in, channels_out, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Three dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Four-dimensional unfolding.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`,
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        unfold4d_using_conv2d::<B>(x, kernel_size, options)
    }

    /// One dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
    }
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }

    /// One dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }
    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        )
    }

    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<B>;

    /// Two dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
}
768
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: valid (non-zero) parameters must be stored unchanged.
    // The remaining tests each pin one panic message from the non-zero checks.
    #[test]
    fn conv_options_valid_values_pass_through() {
        let opt = ConvOptions::new([2, 3], [0, 1], [1, 2], 4);
        assert_eq!(opt.stride, [2, 3]);
        assert_eq!(opt.padding, [0, 1]);
        assert_eq!(opt.dilation, [1, 2]);
        assert_eq!(opt.groups, 4);
    }

    #[test]
    fn deform_conv_options_valid_values_pass_through() {
        let opt = DeformConvOptions::new([1, 2], [0, 0], [2, 1], 3, 5);
        assert_eq!(opt.stride, [1, 2]);
        assert_eq!(opt.dilation, [2, 1]);
        assert_eq!(opt.weight_groups, 3);
        assert_eq!(opt.offset_groups, 5);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}