burn_backend/backend/ops/modules/base.rs

use super::{conv, pool};
use crate::ops::attention;
use crate::ops::unfold::unfold4d_using_conv2d;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
use crate::{Backend, TensorMetadata};
use burn_std::Shape;
use core::num::NonZeroUsize;

/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,

    /// Offset gradient.
    pub offset_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weight_grad: FloatTensor<B>,

    /// Mask gradient.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor.
    pub indices: IntTensor<B>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor.
    pub indices: IntTensor<B>,
}

/// Check that the parameter value is non-zero.
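///
/// A minimal usage sketch (hedged: the helper is crate-private, so the
/// example is marked `ignore`):
///
/// ```ignore
/// // Returns the value unchanged when it is non-zero; panics with `msg` otherwise.
/// let stride = check_nonzero(2, "stride must be non-zero");
/// assert_eq!(stride, 2);
/// ```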
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg);
    value
}

/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvOptions<N> {
    /// Constructs a new `ConvOptions`.
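    ///
    /// A minimal construction sketch (values are illustrative):
    ///
    /// ```ignore
    /// // Stride 2, padding 1, dilation 1, a single group.
    /// let options = ConvOptions::<2>::new([2, 2], [1, 1], [1, 1], 1);
    /// assert_eq!(options.stride, [2, 2]);
    /// ```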
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

/// Deformable convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Weight Groups (non-zero).
    pub weight_groups: usize,

    /// Offset Groups (non-zero).
    pub offset_groups: usize,
}

impl<const N: usize> DeformConvOptions<N> {
    /// Constructs a new `DeformConvOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        weight_groups: usize,
        offset_groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
        }
    }
}

/// Transposed convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Padding out.
    pub padding_out: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvTransposeOptions<N> {
    /// Constructs a new `ConvTransposeOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        padding_out: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            padding_out,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

/// Unfold operation options.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor.
    pub dilation: [usize; 2],
}

impl UnfoldOptions {
    /// Constructs a new `UnfoldOptions`.
    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
        }
    }
}

/// Algorithm used for upsampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}

/// Interpolation options.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
}

/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}

/// Options for grid sampling operations.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation mode (bilinear, nearest, or bicubic).
    pub mode: InterpolateMode,
    /// Padding mode for out-of-bounds coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.
    /// If `false`, they correspond to the corner points of the corner pixels
    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
    pub align_corners: bool,
}

impl Default for GridSampleOptions {
    fn default() -> Self {
        Self {
            mode: InterpolateMode::Bilinear,
            padding_mode: GridSamplePaddingMode::Zeros,
            align_corners: false,
        }
    }
}

impl GridSampleOptions {
    /// Create new grid sample options with the given interpolation mode.
    ///
    /// Uses default values for padding_mode (Zeros) and align_corners (false).
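    ///
    /// A short builder-style sketch (values are illustrative):
    ///
    /// ```ignore
    /// let opts = GridSampleOptions::new(InterpolateMode::Nearest)
    ///     .with_padding_mode(GridSamplePaddingMode::Border)
    ///     .with_align_corners(true);
    /// ```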
    pub fn new(mode: InterpolateMode) -> Self {
        Self {
            mode,
            ..Default::default()
        }
    }

    /// Set the padding mode.
    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
        self.padding_mode = padding_mode;
        self
    }

    /// Set align_corners.
    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
        self.align_corners = align_corners;
        self
    }
}

/// Padding mode for tensor pad operations.
///
/// Defines how values are filled when padding a tensor beyond its original boundaries.
///
/// **Note**: Currently, padding is only supported on the last two dimensions of a tensor
/// (typically height and width for image data in NCHW format).
///
/// # Modes
///
/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)
/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)
/// - [`Edge`](PadMode::Edge): Replicate boundary values
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Fill padded regions with a constant value.
    ///
    /// # Example
    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:
    /// Result: `[0, 0, 1, 2, 3]`
    Constant(f32),

    /// Reflect values at the boundary, excluding the edge value.
    ///
    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)
    Reflect,

    /// Replicate the edge values.
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[1, 1, 1, 2, 3, 4]`
    Edge,
}

impl Default for PadMode {
    fn default() -> Self {
        PadMode::Constant(0.0)
    }
}

/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,
}

/// Module operations trait.
pub trait ModuleOps<B: Backend> {
    /// Embedding operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
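    ///
    /// A hedged shape sketch (illustrative sizes only): with `weights` of shape
    /// `[n_embeddings, d_model] = [10, 4]` and `indices` of shape
    /// `[batch_size, seq_length] = [2, 3]`, the output has shape `[2, 3, 4]`,
    /// since each index selects one row of `weights`.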
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }

    /// Embedding backward operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `output_grad` - The output gradient.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The gradient.
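    ///
    /// A hedged sketch of the semantics: rows of `output_grad` are
    /// scatter-added into a zeroed `[n_embeddings, d_model]` buffer at the
    /// given indices, so repeated indices accumulate their contributions.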
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// One dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_out, channels_in, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Two dimensional deformable convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// Three dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// One dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_in, channels_out, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Three dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Four-dimensional unfolding.
    ///
    /// # Shapes
    ///
    /// * x:       `[batch_size, channels_in, height, width]`,
    /// * returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`,
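    ///
    /// A worked shape example (illustrative sizes only): for `x` of shape
    /// `[1, 2, 4, 4]` with `kernel_size = [2, 2]`, stride `[1, 1]`, no padding,
    /// and dilation `[1, 1]`, each spatial dimension yields `(4 - 2) / 1 + 1 = 3`
    /// block positions, so the output shape is `[1, 2 * 2 * 2, 3 * 3] = [1, 8, 9]`.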
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // After the two unfolds: [batch, channels, h_blocks, w_blocks, h_kern, w_kern].
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            // After the permute: [batch, channels, h_kern, w_kern, h_blocks, w_blocks].
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }

    /// One dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(x, kernel_size, stride, padding, count_include_pad)
    }
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }

    /// One dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, length]`,
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(x, kernel_size, stride, padding, dilation)
    }
    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            output_grad,
            indices,
        )
    }

    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<B>;

    /// Two dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Computes scaled dot-product attention: softmax(QKᵗ / √d) · V,
    /// optionally applying a mask to the attention scores.
    ///
    /// # Arguments
    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
    ///   where `true` indicates positions to mask (i.e. set to -∞ before softmax).
    ///
    /// # Returns
    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
    /// representing the attended context per head.
    ///
    /// # Note
    /// This implementation does not support dropout and is intended for inference or
    /// use cases where dropout is not needed.
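    ///
    /// A hedged shape sketch (illustrative sizes only): `query` of shape
    /// `[2, 8, 16, 64]` with `key` and `value` of shape `[2, 8, 32, 64]`
    /// produces an output of shape `[2, 8, 16, 64]`.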
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
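
    // Hedged positive-path checks for the constructors and builders above;
    // they rely only on items defined in this file.
    #[test]
    fn conv_options_valid_values_pass_through() {
        let opt = ConvOptions::new([2, 1], [0, 0], [1, 1], 1);
        assert_eq!(opt.stride, [2, 1]);
        assert_eq!(opt.groups, 1);
    }

    #[test]
    fn pad_mode_defaults_to_zero_constant() {
        assert_eq!(PadMode::default(), PadMode::Constant(0.0));
    }

    #[test]
    fn grid_sample_options_builder_sets_fields() {
        let opts = GridSampleOptions::new(InterpolateMode::Nearest)
            .with_padding_mode(GridSamplePaddingMode::Border)
            .with_align_corners(true);
        assert_eq!(opts.padding_mode, GridSamplePaddingMode::Border);
        assert!(opts.align_corners);
    }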
}