burn_backend/backend/ops/modules/base.rs

use super::{conv, pool};
use crate::ops::attention;
use crate::ops::unfold::unfold4d_using_conv2d;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
use crate::{Backend, ElementConversion, TensorMetadata};
use burn_std::Shape;
use core::num::NonZeroUsize;

/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,

    /// Offset gradient.
    pub offset_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weight_grad: FloatTensor<B>,

    /// Mask gradient.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor.
    pub indices: IntTensor<B>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor.
    pub indices: IntTensor<B>,
}

/// Check that the parameter value is non-zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg);
    value
}

/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvOptions<N> {
    /// Constructs a new `ConvOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}

/// Deformable convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Weight Groups (non-zero).
    pub weight_groups: usize,

    /// Offset Groups (non-zero).
    pub offset_groups: usize,
}

impl<const N: usize> DeformConvOptions<N> {
    /// Constructs a new `DeformConvOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        weight_groups: usize,
        offset_groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
        }
    }
}

/// Transposed convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding.
    pub padding: [usize; N],

    /// Padding out.
    pub padding_out: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvTransposeOptions<N> {
    /// Constructs a new `ConvTransposeOptions`.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        padding_out: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            padding_out,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
            groups: check_nonzero(groups, "groups must be non-zero"),
        }
    }
}
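
// Note (informal): for a transposed convolution, the standard output-size formula
// (assumed here; the backend kernels implement the actual computation) is:
//
//     out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + padding_out + 1
//
// `padding_out` selects among the several input sizes that a forward convolution
// with the same options would collapse onto the same output size.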

/// Unfold operation options.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor.
    pub dilation: [usize; 2],
}

impl UnfoldOptions {
    /// Constructs a new `UnfoldOptions`.
    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
        Self {
            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
            padding,
            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
        }
    }
}

/// Algorithm used for upsampling.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}

/// Interpolation options.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
}

/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}

/// Options for grid sampling operations.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation mode (bilinear, nearest, or bicubic).
    pub mode: InterpolateMode,
    /// Padding mode for out-of-bounds coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// If `true`, grid values of -1 and 1 correspond to the centers of the corner pixels.
    /// If `false`, they correspond to the corner points of the corner pixels
    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
    pub align_corners: bool,
}

impl Default for GridSampleOptions {
    fn default() -> Self {
        Self {
            mode: InterpolateMode::Bilinear,
            padding_mode: GridSamplePaddingMode::Zeros,
            align_corners: false,
        }
    }
}

impl From<InterpolateMode> for GridSampleOptions {
    fn from(value: InterpolateMode) -> Self {
        GridSampleOptions::new(value)
    }
}

impl GridSampleOptions {
    /// Create new grid sample options with the given interpolation mode.
    ///
    /// Uses default values for padding_mode (Zeros) and align_corners (false).
    pub fn new(mode: InterpolateMode) -> Self {
        Self {
            mode,
            ..Default::default()
        }
    }

    /// Set the padding mode.
    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
        self.padding_mode = padding_mode;
        self
    }

    /// Set align_corners.
    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
        self.align_corners = align_corners;
        self
    }
}

/// Padding mode for tensor pad operations.
///
/// Defines how values are filled when padding a tensor beyond its original boundaries.
///
/// **Note**: Currently, padding is only supported on the last two dimensions of a tensor
/// (typically height and width for image data in NCHW format).
///
/// # Modes
///
/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)
/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)
/// - [`Edge`](PadMode::Edge): Replicate boundary values
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Fill padded regions with a constant value.
    ///
    /// # Example
    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:
    /// Result: `[0, 0, 1, 2, 3]`
    Constant(f32),

    /// Reflect values at the boundary, excluding the edge value.
    ///
    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)
    Reflect,

    /// Replicate the edge values.
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[1, 1, 1, 2, 3, 4]`
    Edge,
}

impl Default for PadMode {
    fn default() -> Self {
        PadMode::Constant(0.0)
    }
}

impl<E: ElementConversion> From<E> for PadMode {
    fn from(value: E) -> Self {
        PadMode::Constant(value.elem())
    }
}

/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient.
    pub x_grad: FloatTensor<B>,
}

/// Module operations trait.
pub trait ModuleOps<B: Backend> {
    /// Embedding operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);

        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }
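
    // Shape walk-through (illustrative): with `weights` of shape
    // `[n_embeddings, d_model] = [10, 4]` and `indices` of shape
    // `[batch_size, seq_length] = [2, 3]`, the flattened indices have shape `[6]`,
    // `float_select` gathers the corresponding rows into `[6, 4]`, and the final
    // reshape returns `[2, 3, 4]`.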

    /// Embedding backward operation.
    ///
    /// # Arguments
    ///
    /// * `weights` - The embedding weights.
    /// * `output_grad` - The output gradient.
    /// * `indices` - The indices tensor.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        let dtype = output_grad.dtype();

        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());

        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// One dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_out, channels_in, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }

    /// Two dimensional deformable convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;

    /// Three dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// One dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_in, channels_out, kernel_size]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Two dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Three dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }

    /// Four-dimensional unfolding.
    ///
    /// # Shapes
    ///
    /// * x:      `[batch_size, channels_in, height, width]`,
    /// * returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`,
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // batch, channels, h_blocks, w_blocks, h_kern, w_kern

            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;

            // batch, channels, h_kern, w_kern, h_blocks, w_blocks

            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }
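
    // Shape walk-through (illustrative) for the fast path above: with
    // `x: [1, 2, 4, 4]`, `kernel_size = [2, 2]` and `stride = [1, 1]`, the two
    // `float_unfold` calls produce `[1, 2, 3, 3, 2, 2]`
    // (batch, channels, h_blocks, w_blocks, h_kern, w_kern), the permute yields
    // `[1, 2, 2, 2, 3, 3]`, and the reshape gives `[1, 2 * 2 * 2, 3 * 3] = [1, 8, 9]`,
    // i.e. `[batch_size, channels_in * kernel_size_1 * kernel_size_2, num_blocks]`.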

    /// One dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// One dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }

    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;

    /// Two dimensional max pooling with indices.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;

    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;

    /// Computes scaled dot-product attention: softmax(QKᵗ / √d) · V,
    /// optionally applying a mask to the attention scores.
    ///
    /// # Arguments
    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
    ///   where `true` indicates positions to mask (i.e. set to -∞ before softmax).
    ///
    /// # Returns
    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
    /// representing the attended context per head.
    ///
    /// # Note
    /// This implementation does not support dropout and is intended for inference or
    /// use cases where dropout is not needed.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }
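
    // Happy-path sketches (illustrative): `check_nonzero` passes valid values
    // through unchanged, and `ConvOptions::new` stores its validated parameters.
    #[test]
    fn check_nonzero_passthrough() {
        assert_eq!(check_nonzero(3, "must be non-zero"), 3);
    }

    #[test]
    fn conv_options_valid() {
        let opt = ConvOptions::new([2, 2], [1, 1], [1, 1], 1);
        assert_eq!(opt.stride, [2, 2]);
        assert_eq!(opt.padding, [1, 1]);
        assert_eq!(opt.dilation, [1, 1]);
        assert_eq!(opt.groups, 1);
    }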

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }
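
    // Builder sketch (illustrative): `GridSampleOptions` defaults mirror PyTorch's
    // `grid_sample` (bilinear, zeros, `align_corners = false`); the fluent setters
    // override them.
    #[test]
    fn grid_sample_options_builder() {
        let default = GridSampleOptions::default();
        assert_eq!(default.padding_mode, GridSamplePaddingMode::Zeros);
        assert!(!default.align_corners);

        let opts = GridSampleOptions::new(InterpolateMode::Nearest)
            .with_padding_mode(GridSamplePaddingMode::Border)
            .with_align_corners(true);
        assert!(matches!(opts.mode, InterpolateMode::Nearest));
        assert_eq!(opts.padding_mode, GridSamplePaddingMode::Border);
        assert!(opts.align_corners);
    }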

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
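
    // Conversion sketch: `PadMode` defaults to `Constant(0.0)`, and any element
    // type with an `ElementConversion` impl (assumed here to include `f32`) maps
    // into `PadMode::Constant` via `From`.
    #[test]
    fn pad_mode_default_and_from() {
        assert_eq!(PadMode::default(), PadMode::Constant(0.0));
        assert_eq!(PadMode::from(1.5f32), PadMode::Constant(1.5));
    }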
}