//! burn_backend/backend/ops/modules/base.rs
use super::{conv, pool};
use crate::ops::attention;
use crate::ops::unfold::unfold4d_using_conv2d;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
use crate::{Backend, ElementConversion, TensorMetadata};
use burn_std::Shape;
use core::num::NonZeroUsize;

/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if a bias was used in the forward pass.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    // NOTE(review): singular `weight_grad` here vs `weights_grad` in
    // `Conv2dBackward`/`Conv3dBackward`; renaming would break callers.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask, if a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, if a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if a bias was used in the forward pass.
    pub bias_grad: Option<FloatTensor<B>>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// Indices of the selected maxima, used by the backward pass.
    pub indices: IntTensor<B>,
}

/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}

/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// Indices of the selected maxima, used by the backward pass.
    pub indices: IntTensor<B>,
}

/// Validates that `value` is non-zero, panicking with `msg` otherwise.
///
/// Returns the value unchanged so calls can be chained inline.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg).get()
}

/// Convolution options.
///
/// Construct via [`ConvOptions::new`], which validates the non-zero invariants.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (each entry non-zero).
    pub stride: [usize; N],

    /// Symmetric zero-padding per spatial dimension (may be zero).
    pub padding: [usize; N],

    /// Dilation per spatial dimension (each entry non-zero).
    pub dilation: [usize; N],

    /// Number of convolution groups (non-zero).
    pub groups: usize,
}

111impl<const N: usize> ConvOptions<N> {
112    /// Constructs a new `ConvOptions`.
113    pub fn new(
114        stride: [usize; N],
115        padding: [usize; N],
116        dilation: [usize; N],
117        groups: usize,
118    ) -> Self {
119        Self {
120            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
121            padding,
122            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
123            groups: check_nonzero(groups, "groups must be non-zero"),
124        }
125    }
126}
127
/// Convolution options with support for asymmetric padding.
///
/// Wraps [`ConvOptions`] (which represents symmetric padding for the backend op)
/// and adds optional asymmetric padding. When asymmetric padding is specified,
/// the functional convolution layer applies an explicit pad operation before
/// dispatching to the backend.
///
/// Implements `From<ConvOptions<N>>` for backward compatibility.
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// The underlying convolution options for the backend.
    /// `options.padding` holds the start-of-dimension padding.
    pub options: ConvOptions<N>,
    /// Padding at the end of each dimension (e.g., bottom/right for 2D).
    /// If `None`, padding is symmetric (same as `options.padding`).
    /// If `Some`, specifies different end-padding per dimension.
    pub padding_end: Option<[usize; N]>,
}

146impl<const N: usize> PaddedConvOptions<N> {
147    /// Creates options with asymmetric padding.
148    ///
149    /// `padding_start` is stored in `ConvOptions::padding`.
150    /// `padding_end` specifies the end padding per dimension.
151    pub fn asymmetric(
152        stride: [usize; N],
153        padding_start: [usize; N],
154        padding_end: [usize; N],
155        dilation: [usize; N],
156        groups: usize,
157    ) -> Self {
158        let options = ConvOptions::new(stride, padding_start, dilation, groups);
159        if padding_start == padding_end {
160            Self {
161                options,
162                padding_end: None,
163            }
164        } else {
165            Self {
166                options,
167                padding_end: Some(padding_end),
168            }
169        }
170    }
171
172    /// Returns true if padding is asymmetric.
173    pub fn is_asymmetric(&self) -> bool {
174        self.padding_end.is_some()
175    }
176}
177
178impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
179    fn from(options: ConvOptions<N>) -> Self {
180        Self {
181            options,
182            padding_end: None,
183        }
184    }
185}
186
/// Deformable convolution options.
///
/// Construct via [`DeformConvOptions::new`], which validates the non-zero invariants.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (each entry non-zero).
    pub stride: [usize; N],

    /// Symmetric zero-padding per spatial dimension (may be zero).
    pub padding: [usize; N],

    /// Dilation per spatial dimension (each entry non-zero).
    pub dilation: [usize; N],

    /// Number of weight groups (non-zero).
    pub weight_groups: usize,

    /// Number of offset groups (non-zero).
    pub offset_groups: usize,
}

206impl<const N: usize> DeformConvOptions<N> {
207    /// Constructs a new `DeformConvOptions`.
208    pub fn new(
209        stride: [usize; N],
210        padding: [usize; N],
211        dilation: [usize; N],
212        weight_groups: usize,
213        offset_groups: usize,
214    ) -> Self {
215        Self {
216            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
217            padding,
218            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
219            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
220            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
221        }
222    }
223}
224
/// Transposed convolution options.
///
/// Construct via [`ConvTransposeOptions::new`], which validates the non-zero invariants.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (each entry non-zero).
    pub stride: [usize; N],

    /// Symmetric zero-padding per spatial dimension (may be zero).
    pub padding: [usize; N],

    /// Additional padding added to one side of the output per dimension.
    pub padding_out: [usize; N],

    /// Dilation per spatial dimension (each entry non-zero).
    pub dilation: [usize; N],

    /// Number of convolution groups (non-zero).
    pub groups: usize,
}

244impl<const N: usize> ConvTransposeOptions<N> {
245    /// Constructs a new `ConvTransposeOptions`.
246    pub fn new(
247        stride: [usize; N],
248        padding: [usize; N],
249        padding_out: [usize; N],
250        dilation: [usize; N],
251        groups: usize,
252    ) -> Self {
253        Self {
254            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
255            padding,
256            padding_out,
257            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
258            groups: check_nonzero(groups, "groups must be non-zero"),
259        }
260    }
261}
262
/// Unfold operation options.
///
/// Construct via [`UnfoldOptions::new`], which validates the non-zero invariants.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    /// Each entry must be non-zero.
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor.
    /// Each entry must be non-zero.
    pub dilation: [usize; 2],
}

277impl UnfoldOptions {
278    /// Constructs a new `UnfoldOptions`.
279    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
280        Self {
281            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
282            padding,
283            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
284        }
285    }
286}
287
/// Algorithm used for upsampling.
// NOTE(review): `#[derive(new)]` on a fieldless enum generates one constructor
// per variant via the `derive-new` crate — confirm it is actually needed here.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}

/// Interpolation options for [interpolate](ModuleOps::interpolate).
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
}

/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}

/// Options for grid sampling operations.
///
/// Build with [`GridSampleOptions::new`] followed by the `with_*` builder
/// methods, or rely on [`Default`] (bilinear, zeros, `align_corners = false`).
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation mode (bilinear, nearest, or bicubic).
    pub mode: InterpolateMode,
    /// Padding mode for out-of-bounds coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.
    /// If `false`, they correspond to the corner points of the corner pixels
    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
    pub align_corners: bool,
}

338impl Default for GridSampleOptions {
339    fn default() -> Self {
340        Self {
341            mode: InterpolateMode::Bilinear,
342            padding_mode: GridSamplePaddingMode::Zeros,
343            align_corners: false,
344        }
345    }
346}
347
348impl From<InterpolateMode> for GridSampleOptions {
349    fn from(value: InterpolateMode) -> Self {
350        GridSampleOptions::new(value)
351    }
352}
353
354impl GridSampleOptions {
355    /// Create new grid sample options with the given interpolation mode.
356    ///
357    /// Uses default values for padding_mode (Zeros) and align_corners (false).
358    pub fn new(mode: InterpolateMode) -> Self {
359        Self {
360            mode,
361            ..Default::default()
362        }
363    }
364
365    /// Set the padding mode.
366    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
367        self.padding_mode = padding_mode;
368        self
369    }
370
371    /// Set align_corners.
372    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
373        self.align_corners = align_corners;
374        self
375    }
376}
377
/// Padding mode for tensor pad operations.
///
/// Defines how values are filled when padding a tensor beyond its original boundaries.
///
/// **Note**: Currently, padding is only supported on the last two dimensions of a tensor
/// (typically height and width for image data in NCHW format).
///
/// # Modes
///
/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)
/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)
/// - [`Edge`](PadMode::Edge): Replicate boundary values
// No `Eq`/`Hash` derives: `Constant` holds an `f32`, which is only `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Fill padded regions with a constant value.
    ///
    /// # Example
    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:
    /// Result: `[0, 0, 1, 2, 3]`
    Constant(f32),

    /// Reflect values at the boundary, excluding the edge value.
    ///
    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)
    Reflect,

    /// Replicate the edge values.
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[1, 1, 1, 2, 3, 4]`
    Edge,
}

416impl Default for PadMode {
417    fn default() -> Self {
418        PadMode::Constant(0.0)
419    }
420}
421
/// Allows any element-convertible scalar to be used directly as a constant pad value.
impl<E: ElementConversion> From<E> for PadMode {
    fn from(value: E) -> Self {
        // `elem()` converts the scalar into the `f32` held by `Constant`.
        // NOTE(review): assumes the conversion is acceptable for all element
        // types used as pad values (f64 inputs would be narrowed) — confirm.
        PadMode::Constant(value.elem())
    }
}

/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}

435/// Module operations trait.
436pub trait ModuleOps<B: Backend> {
437    /// Embedding operation.
438    ///
439    /// # Arguments
440    ///
441    /// * `weights` - The embedding weights.
442    /// * `indices` - The indices tensor.
443    ///
444    /// # Returns
445    ///
446    /// The output tensor.
447    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
448        let [batch_size, seq_length] = indices.shape().dims();
449        let [_, d_model] = weights.shape().dims();
450
451        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
452        let output = B::float_select(weights, 0, indices);
453
454        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
455    }
456
457    /// Embedding backward operation.
458    ///
459    /// # Arguments
460    ///
461    /// * `weights` - The embedding weights.
462    /// * `output_grad` - The output gradient.
463    /// * `indices` - The indices tensor.
464    ///
465    /// # Returns
466    ///
467    /// The gradient.
468    fn embedding_backward(
469        weights: FloatTensor<B>,
470        output_grad: FloatTensor<B>,
471        indices: IntTensor<B>,
472    ) -> FloatTensor<B> {
473        let [batch_size, seq_length] = indices.shape().dims();
474        let [n_embeddings, d_model] = weights.shape().dims();
475        let device = B::float_device(&weights);
476        let dtype = output_grad.dtype();
477
478        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
479        let output_grad =
480            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
481        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());
482
483        B::float_select_add(grad, 0, indices, output_grad)
484    }
485    /// One dimensional convolution.
486    ///
487    /// # Shapes
488    ///
489    /// x:      `[batch_size, channels_in, length]`,
490    /// weight: `[channels_out, channels_in, kernel_size]`,
491    /// bias:   `[channels_out]`,
492    fn conv1d(
493        x: FloatTensor<B>,
494        weight: FloatTensor<B>,
495        bias: Option<FloatTensor<B>>,
496        options: ConvOptions<1>,
497    ) -> FloatTensor<B> {
498        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
499    }
500    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
501    fn conv1d_x_backward(
502        x: FloatTensor<B>,
503        weight: FloatTensor<B>,
504        output_grad: FloatTensor<B>,
505        options: ConvOptions<1>,
506    ) -> FloatTensor<B> {
507        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
508    }
509    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
510    fn conv1d_weight_backward(
511        x: FloatTensor<B>,
512        weight: FloatTensor<B>,
513        output_grad: FloatTensor<B>,
514        options: ConvOptions<1>,
515    ) -> FloatTensor<B> {
516        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
517    }
518    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
519    fn conv1d_bias_backward(
520        x: FloatTensor<B>,
521        bias: FloatTensor<B>,
522        output_grad: FloatTensor<B>,
523    ) -> FloatTensor<B> {
524        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
525    }
526    /// Two dimensional convolution.
527    ///
528    /// # Shapes
529    ///
530    /// x:      `[batch_size, channels_in, height, width]`,
531    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
532    /// bias:   `[channels_out]`,
533    fn conv2d(
534        x: FloatTensor<B>,
535        weight: FloatTensor<B>,
536        bias: Option<FloatTensor<B>>,
537        options: ConvOptions<2>,
538    ) -> FloatTensor<B>;
539    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
540    fn conv2d_x_backward(
541        x: FloatTensor<B>,
542        weight: FloatTensor<B>,
543        output_grad: FloatTensor<B>,
544        options: ConvOptions<2>,
545    ) -> FloatTensor<B> {
546        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
547    }
548    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
549    fn conv2d_weight_backward(
550        x: FloatTensor<B>,
551        weight: FloatTensor<B>,
552        output_grad: FloatTensor<B>,
553        options: ConvOptions<2>,
554    ) -> FloatTensor<B> {
555        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
556    }
557    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
558    fn conv2d_bias_backward(
559        x: FloatTensor<B>,
560        bias: FloatTensor<B>,
561        output_grad: FloatTensor<B>,
562    ) -> FloatTensor<B> {
563        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
564    }
565
566    /// Two dimensional deformable convolution.
567    ///
568    /// # Shapes
569    ///
570    /// x:      `[batch_size, channels_in, height, width]`,
571    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
572    /// bias:   `[channels_out]`,
573    fn deform_conv2d(
574        x: FloatTensor<B>,
575        offset: FloatTensor<B>,
576        weight: FloatTensor<B>,
577        mask: Option<FloatTensor<B>>,
578        bias: Option<FloatTensor<B>>,
579        options: DeformConvOptions<2>,
580    ) -> FloatTensor<B>;
581    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
582    fn deform_conv2d_backward(
583        x: FloatTensor<B>,
584        offset: FloatTensor<B>,
585        weight: FloatTensor<B>,
586        mask: Option<FloatTensor<B>>,
587        bias: Option<FloatTensor<B>>,
588        output_grad: FloatTensor<B>,
589        options: DeformConvOptions<2>,
590    ) -> DeformConv2dBackward<B>;
591
592    /// Three dimensional convolution.
593    ///
594    /// # Shapes
595    ///
596    /// x:      `[batch_size, channels_in, depth, height, width]`,
597    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
598    /// bias:   `[channels_out]`,
599    fn conv3d(
600        x: FloatTensor<B>,
601        weight: FloatTensor<B>,
602        bias: Option<FloatTensor<B>>,
603        options: ConvOptions<3>,
604    ) -> FloatTensor<B>;
605    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
606    fn conv3d_x_backward(
607        x: FloatTensor<B>,
608        weight: FloatTensor<B>,
609        output_grad: FloatTensor<B>,
610        options: ConvOptions<3>,
611    ) -> FloatTensor<B> {
612        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
613    }
614    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
615    fn conv3d_weight_backward(
616        x: FloatTensor<B>,
617        weight: FloatTensor<B>,
618        output_grad: FloatTensor<B>,
619        options: ConvOptions<3>,
620    ) -> FloatTensor<B> {
621        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
622    }
623    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
624    fn conv3d_bias_backward(
625        x: FloatTensor<B>,
626        bias: FloatTensor<B>,
627        output_grad: FloatTensor<B>,
628    ) -> FloatTensor<B> {
629        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
630    }
631    /// One dimensional transposed convolution.
632    ///
633    /// # Shapes
634    ///
635    /// x:      `[batch_size, channels_in, length]`,
636    /// weight: `[channels_in, channels_out, length]`,
637    /// bias:   `[channels_out]`,
638    fn conv_transpose1d(
639        x: FloatTensor<B>,
640        weight: FloatTensor<B>,
641        bias: Option<FloatTensor<B>>,
642        options: ConvTransposeOptions<1>,
643    ) -> FloatTensor<B> {
644        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
645    }
646    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
647    fn conv_transpose1d_x_backward(
648        weight: FloatTensor<B>,
649        output_grad: FloatTensor<B>,
650        options: ConvTransposeOptions<1>,
651    ) -> FloatTensor<B> {
652        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
653    }
654    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
655    fn conv_transpose1d_weight_backward(
656        x: FloatTensor<B>,
657        weight: FloatTensor<B>,
658        output_grad: FloatTensor<B>,
659        options: ConvTransposeOptions<1>,
660    ) -> FloatTensor<B> {
661        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
662    }
663    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
664    fn conv_transpose1d_bias_backward(
665        x: FloatTensor<B>,
666        bias: FloatTensor<B>,
667        output_grad: FloatTensor<B>,
668    ) -> FloatTensor<B> {
669        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
670    }
671
672    /// Two dimensional transposed convolution.
673    ///
674    /// # Shapes
675    ///
676    /// x:      `[batch_size, channels_in, height, width]`,
677    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
678    /// bias:   `[channels_out]`,
679    fn conv_transpose2d(
680        x: FloatTensor<B>,
681        weight: FloatTensor<B>,
682        bias: Option<FloatTensor<B>>,
683        options: ConvTransposeOptions<2>,
684    ) -> FloatTensor<B>;
685    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
686    fn conv_transpose2d_x_backward(
687        weight: FloatTensor<B>,
688        output_grad: FloatTensor<B>,
689        options: ConvTransposeOptions<2>,
690    ) -> FloatTensor<B> {
691        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
692    }
693    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
694    fn conv_transpose2d_weight_backward(
695        x: FloatTensor<B>,
696        weight: FloatTensor<B>,
697        output_grad: FloatTensor<B>,
698        options: ConvTransposeOptions<2>,
699    ) -> FloatTensor<B> {
700        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
701    }
702    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
703    fn conv_transpose2d_bias_backward(
704        x: FloatTensor<B>,
705        bias: FloatTensor<B>,
706        output_grad: FloatTensor<B>,
707    ) -> FloatTensor<B> {
708        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
709    }
710
711    /// Three dimensional transposed convolution.
712    ///
713    /// # Shapes
714    ///
715    /// x:      `[batch_size, channels_in, height, width]`,
716    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
717    /// bias:   `[channels_out]`,
718    fn conv_transpose3d(
719        x: FloatTensor<B>,
720        weight: FloatTensor<B>,
721        bias: Option<FloatTensor<B>>,
722        options: ConvTransposeOptions<3>,
723    ) -> FloatTensor<B>;
724    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
725    fn conv_transpose3d_x_backward(
726        weight: FloatTensor<B>,
727        output_grad: FloatTensor<B>,
728        options: ConvTransposeOptions<3>,
729    ) -> FloatTensor<B> {
730        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
731    }
732    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
733    fn conv_transpose3d_weight_backward(
734        x: FloatTensor<B>,
735        weight: FloatTensor<B>,
736        output_grad: FloatTensor<B>,
737        options: ConvTransposeOptions<3>,
738    ) -> FloatTensor<B> {
739        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
740    }
741    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
742    fn conv_transpose3d_bias_backward(
743        x: FloatTensor<B>,
744        bias: FloatTensor<B>,
745        output_grad: FloatTensor<B>,
746    ) -> FloatTensor<B> {
747        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
748    }
749
750    /// Four-dimensional unfolding.
751    ///
752    /// # Shapes
753    ///
754    /// * x:      ``[batch_size, channels_in, height, width]``,
755    /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``,
756    fn unfold4d(
757        x: FloatTensor<B>,
758        kernel_size: [usize; 2],
759        options: UnfoldOptions,
760    ) -> FloatTensor<B> {
761        if options.padding == [0, 0] && options.dilation == [1, 1] {
762            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
763            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);
764
765            // batch, channels, h_blocks, w_blocks, h_kern, w_kern
766
767            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
768            let shape = &blocks.shape().dims;
769
770            // batch, channels, h_kern, w_kern, h_blocks, w_blocks
771
772            B::float_reshape(
773                blocks,
774                [
775                    shape[0],
776                    shape[1] * shape[2] * shape[3],
777                    shape[4] * shape[5],
778                ]
779                .into(),
780            )
781        } else {
782            unfold4d_using_conv2d::<B>(x, kernel_size, options)
783        }
784    }
785
786    /// One dimensional avg pooling.
787    ///
788    /// # Shapes
789    ///
790    /// x: [batch_size, channels, length],
791    fn avg_pool1d(
792        x: FloatTensor<B>,
793        kernel_size: usize,
794        stride: usize,
795        padding: usize,
796        count_include_pad: bool,
797        ceil_mode: bool,
798    ) -> FloatTensor<B> {
799        pool::avg_pool1d_from_2d::<B>(
800            x,
801            kernel_size,
802            stride,
803            padding,
804            count_include_pad,
805            ceil_mode,
806        )
807    }
808    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
809    fn avg_pool1d_backward(
810        x: FloatTensor<B>,
811        grad: FloatTensor<B>,
812        kernel_size: usize,
813        stride: usize,
814        padding: usize,
815        count_include_pad: bool,
816        ceil_mode: bool,
817    ) -> FloatTensor<B> {
818        pool::avg_pool1d_backward_from_2d::<B>(
819            x,
820            grad,
821            kernel_size,
822            stride,
823            padding,
824            count_include_pad,
825            ceil_mode,
826        )
827    }
828    /// Two dimensional avg pooling.
829    ///
830    /// # Shapes
831    ///
832    /// x: [batch_size, channels, height, width],
833    fn avg_pool2d(
834        x: FloatTensor<B>,
835        kernel_size: [usize; 2],
836        stride: [usize; 2],
837        padding: [usize; 2],
838        count_include_pad: bool,
839        ceil_mode: bool,
840    ) -> FloatTensor<B>;
841    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
842    fn avg_pool2d_backward(
843        x: FloatTensor<B>,
844        grad: FloatTensor<B>,
845        kernel_size: [usize; 2],
846        stride: [usize; 2],
847        padding: [usize; 2],
848        count_include_pad: bool,
849        ceil_mode: bool,
850    ) -> FloatTensor<B>;
851    /// Two dimensional adaptive avg pooling.
852    ///
853    /// # Shapes
854    ///
855    /// x: [batch_size, channels, height, width],
856    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
857    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
858    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
859    /// One dimensional adaptive avg pooling.
860    ///
861    /// # Shapes
862    ///
863    /// x: [batch_size, channels, length],
864    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
865        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
866    }
867    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
868    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
869        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
870    }
871    /// One dimensional max pooling.
872    ///
873    /// # Shapes
874    ///
875    /// x: [batch_size, channels, length],
876    fn max_pool1d(
877        x: FloatTensor<B>,
878        kernel_size: usize,
879        stride: usize,
880        padding: usize,
881        dilation: usize,
882        ceil_mode: bool,
883    ) -> FloatTensor<B> {
884        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
885    }
886
    /// One dimensional max pooling with indices.
    ///
    /// Same as [max_pool1d](ModuleOps::max_pool1d), but also returns the
    /// indices of the selected maxima, as consumed by the
    /// [backward pass](ModuleOps::max_pool1d_with_indices_backward).
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
909    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
910    #[allow(clippy::too_many_arguments)]
911    fn max_pool1d_with_indices_backward(
912        x: FloatTensor<B>,
913        kernel_size: usize,
914        stride: usize,
915        padding: usize,
916        dilation: usize,
917        ceil_mode: bool,
918        output_grad: FloatTensor<B>,
919        indices: IntTensor<B>,
920    ) -> MaxPool1dBackward<B> {
921        pool::max_pool1d_with_indices_backward_from_2d::<B>(
922            x,
923            kernel_size,
924            stride,
925            padding,
926            dilation,
927            ceil_mode,
928            output_grad,
929            indices,
930        )
931    }
932
    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d(
        x: FloatTensor<B>,        // input tensor
        kernel_size: [usize; 2],  // pooling window size
        stride: [usize; 2],       // window step
        padding: [usize; 2],      // padding applied on both sides of each spatial dim
        dilation: [usize; 2],     // spacing between window elements
        ceil_mode: bool,          // round the output size up (ceil) instead of down (floor)
    ) -> FloatTensor<B>;
946
    /// Two dimensional max pooling with indices.
    ///
    /// Same as [max_pool2d](ModuleOps::max_pool2d), but also returns the
    /// indices of the selected maxima, as consumed by the
    /// [backward pass](ModuleOps::max_pool2d_with_indices_backward).
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    ///
    /// `output_grad` is the gradient w.r.t. the forward output and `indices`
    /// the maxima positions returned by the forward pass; the pooling
    /// parameters are expected to match the ones used in the forward call.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;
972
    /// Down/up samples the input.
    ///
    /// The spatial dimensions of the result are given by `output_size`; the
    /// interpolation method is selected through `options`.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    ///
    /// output: `[batch_size, channels, output_size[0], output_size[1]]`,
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
983
    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    ///
    /// `grad` is the gradient w.r.t. the forward output; `output_size` and
    /// `options` are expected to be the values used in the forward call.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
991
992    /// Computes scaled dot-product attention: softmax(QKᵗ / √d) · V,
993    /// optionally applying a mask to the attention scores.
994    ///
995    /// # Arguments
996    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q,  head_dim]`
997    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
998    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
999    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
1000    ///   here `true` indicates positions to mask (i.e. set to -∞ before softmax).
1001    ///
1002    /// # Returns
1003    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
1004    /// representing the attended context per head.
1005    ///
1006    /// # Note
1007    /// This implementation does not support dropout and is intended for inference or
1008    /// use cases where dropout is not needed.
1009    fn attention(
1010        query: FloatTensor<B>,
1011        key: FloatTensor<B>,
1012        value: FloatTensor<B>,
1013        mask: Option<BoolTensor<B>>,
1014    ) -> FloatTensor<B> {
1015        attention::naive_attention::<B>(query, key, value, mask)
1016    }
1017}
1018
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor validation: every options type must reject zero-valued
    // strides, dilations and group counts. The `should_panic` messages pin
    // the exact wording of the validation errors.

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _ = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _ = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _ = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _ = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _ = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}