//! Module operations for the backend API (burn_backend/backend/ops/modules/base.rs):
//! convolution, transposed and deformable convolution, pooling, unfolding,
//! interpolation, grid sampling, padding, and attention options, plus the
//! `ModuleOps` trait implemented by each backend.
1use super::{conv, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
4use crate::{Backend, ElementConversion, TensorMetadata};
5use burn_std::Shape;
6use core::num::NonZeroUsize;
7
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient w.r.t. the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient, if any.
    pub bias_grad: Option<FloatTensor<B>>,
}
20
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient w.r.t. the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Offset gradient.
    pub offset_grad: FloatTensor<B>,

    /// Weights gradient.
    // NOTE(review): singular `weight_grad` here vs `weights_grad` in `Conv2dBackward`.
    // Renaming the public field would break callers, so the inconsistency is only flagged.
    pub weight_grad: FloatTensor<B>,

    /// Mask gradient, if any.
    pub mask_grad: Option<FloatTensor<B>>,

    /// Bias gradient, if any.
    pub bias_grad: Option<FloatTensor<B>>,
}
39
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient w.r.t. the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Weights gradient.
    pub weights_grad: FloatTensor<B>,

    /// Bias gradient, if any.
    pub bias_grad: Option<FloatTensor<B>>,
}
52
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient w.r.t. the pooling input `x`.
    pub x_grad: FloatTensor<B>,
}
59
/// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor.
    // presumably the input positions of the selected max elements — confirm with backend impls.
    pub indices: IntTensor<B>,
}
69
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient w.r.t. the pooling input `x`.
    pub x_grad: FloatTensor<B>,
}
76
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor.
    // presumably the input positions of the selected max elements — confirm with backend impls.
    pub indices: IntTensor<B>,
}
86
/// Check that the parameter value is non-zero.
///
/// Returns `value` unchanged; panics with `msg` when it is zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    // `NonZeroUsize::new` yields `None` for zero, turning the check into an `expect`.
    NonZeroUsize::new(value).map(NonZeroUsize::get).expect(msg)
}
93
/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero, enforced by [`ConvOptions::new`]).
    pub stride: [usize; N],

    /// Padding (zero allowed).
    pub padding: [usize; N],

    /// Dilation (non-zero, enforced by [`ConvOptions::new`]).
    pub dilation: [usize; N],

    /// Groups (non-zero, enforced by [`ConvOptions::new`]).
    pub groups: usize,
}
109
110impl<const N: usize> ConvOptions<N> {
111    /// Constructs a new `ConvOptions`.
112    pub fn new(
113        stride: [usize; N],
114        padding: [usize; N],
115        dilation: [usize; N],
116        groups: usize,
117    ) -> Self {
118        Self {
119            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
120            padding,
121            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
122            groups: check_nonzero(groups, "groups must be non-zero"),
123        }
124    }
125}
126
/// Convolution options with support for asymmetric padding.
///
/// Wraps [`ConvOptions`] (which represents symmetric padding for the backend op)
/// and adds optional asymmetric padding. When asymmetric padding is specified,
/// the functional convolution layer applies an explicit pad operation before
/// dispatching to the backend.
///
/// Implements `From<ConvOptions<N>>` for backward compatibility.
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// The underlying convolution options for the backend.
    pub options: ConvOptions<N>,
    /// Padding at the end of each dimension (e.g., bottom/right for 2D).
    /// If `None`, padding is symmetric (same as `options.padding`).
    /// If `Some`, specifies different end-padding per dimension.
    /// Start padding is always held in `options.padding`.
    pub padding_end: Option<[usize; N]>,
}
144
145impl<const N: usize> PaddedConvOptions<N> {
146    /// Creates options with asymmetric padding.
147    ///
148    /// `padding_start` is stored in `ConvOptions::padding`.
149    /// `padding_end` specifies the end padding per dimension.
150    pub fn asymmetric(
151        stride: [usize; N],
152        padding_start: [usize; N],
153        padding_end: [usize; N],
154        dilation: [usize; N],
155        groups: usize,
156    ) -> Self {
157        let options = ConvOptions::new(stride, padding_start, dilation, groups);
158        if padding_start == padding_end {
159            Self {
160                options,
161                padding_end: None,
162            }
163        } else {
164            Self {
165                options,
166                padding_end: Some(padding_end),
167            }
168        }
169    }
170
171    /// Returns true if padding is asymmetric.
172    pub fn is_asymmetric(&self) -> bool {
173        self.padding_end.is_some()
174    }
175}
176
177impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
178    fn from(options: ConvOptions<N>) -> Self {
179        Self {
180            options,
181            padding_end: None,
182        }
183    }
184}
185
/// Deformable convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride (non-zero, enforced by [`DeformConvOptions::new`]).
    pub stride: [usize; N],

    /// Padding (zero allowed).
    pub padding: [usize; N],

    /// Dilation (non-zero, enforced by [`DeformConvOptions::new`]).
    pub dilation: [usize; N],

    /// Weight Groups (non-zero, enforced by [`DeformConvOptions::new`]).
    pub weight_groups: usize,

    /// Offset Groups (non-zero, enforced by [`DeformConvOptions::new`]).
    pub offset_groups: usize,
}
204
205impl<const N: usize> DeformConvOptions<N> {
206    /// Constructs a new `DeformConvOptions`.
207    pub fn new(
208        stride: [usize; N],
209        padding: [usize; N],
210        dilation: [usize; N],
211        weight_groups: usize,
212        offset_groups: usize,
213    ) -> Self {
214        Self {
215            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
216            padding,
217            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
218            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
219            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
220        }
221    }
222}
223
/// Transposed convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride (non-zero, enforced by [`ConvTransposeOptions::new`]).
    pub stride: [usize; N],

    /// Padding (zero allowed).
    pub padding: [usize; N],

    /// Padding out (extra size added to one side of the output; zero allowed).
    pub padding_out: [usize; N],

    /// Dilation (non-zero, enforced by [`ConvTransposeOptions::new`]).
    pub dilation: [usize; N],

    /// Groups (non-zero, enforced by [`ConvTransposeOptions::new`]).
    pub groups: usize,
}
242
243impl<const N: usize> ConvTransposeOptions<N> {
244    /// Constructs a new `ConvTransposeOptions`.
245    pub fn new(
246        stride: [usize; N],
247        padding: [usize; N],
248        padding_out: [usize; N],
249        dilation: [usize; N],
250        groups: usize,
251    ) -> Self {
252        Self {
253            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
254            padding,
255            padding_out,
256            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
257            groups: check_nonzero(groups, "groups must be non-zero"),
258        }
259    }
260}
261
/// Unfold operation options.
#[derive(Debug, Clone)]
pub struct UnfoldOptions {
    /// The number of positions to slide over the input tensor in each dimension.
    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
    /// Non-zero, enforced by [`UnfoldOptions::new`].
    pub stride: [usize; 2],

    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
    pub padding: [usize; 2],

    /// The spacing between the blocks (patches) in the original input tensor.
    /// Non-zero, enforced by [`UnfoldOptions::new`].
    pub dilation: [usize; 2],
}
275
276impl UnfoldOptions {
277    /// Constructs a new `UnfoldOptions`.
278    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
279        Self {
280            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
281            padding,
282            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
283        }
284    }
285}
286
/// Algorithm used for upsampling.
// NOTE(review): `#[derive(new)]` on a field-less enum generates per-variant
// constructors (`new_nearest`, ...) via the derive-new crate — confirm they are used.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,
}
302
/// Interpolation options.
#[derive(Debug, Clone)]
pub struct InterpolateOptions {
    /// Algorithm used for upsampling.
    pub mode: InterpolateMode,
    /// If `true`, the input and output tensors are aligned by their corner pixels.
    /// If `false`, half-pixel coordinate mapping is used instead.
    /// Defaults to `true` in [`InterpolateOptions::new`].
    pub align_corners: bool,
}
312
313impl InterpolateOptions {
314    /// Create new interpolate options with the given mode.
315    /// Defaults to `align_corners = true`.
316    pub fn new(mode: InterpolateMode) -> Self {
317        Self {
318            mode,
319            align_corners: true,
320        }
321    }
322
323    /// Set align_corners.
324    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
325        self.align_corners = align_corners;
326        self
327    }
328}
329
/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    /// This is the default.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}
343
/// Options for grid sampling operations.
///
/// The `Default` is bilinear interpolation, zero padding, and `align_corners = false`.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation mode (bilinear, nearest, or bicubic).
    pub mode: InterpolateMode,
    /// Padding mode for out-of-bounds coordinates.
    pub padding_mode: GridSamplePaddingMode,
    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.
    /// If `false`, they correspond to the corner points of the corner pixels
    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
    pub align_corners: bool,
}
356
357impl Default for GridSampleOptions {
358    fn default() -> Self {
359        Self {
360            mode: InterpolateMode::Bilinear,
361            padding_mode: GridSamplePaddingMode::Zeros,
362            align_corners: false,
363        }
364    }
365}
366
367impl From<InterpolateMode> for GridSampleOptions {
368    fn from(value: InterpolateMode) -> Self {
369        GridSampleOptions::new(value)
370    }
371}
372
373impl GridSampleOptions {
374    /// Create new grid sample options with the given interpolation mode.
375    ///
376    /// Uses default values for padding_mode (Zeros) and align_corners (false).
377    pub fn new(mode: InterpolateMode) -> Self {
378        Self {
379            mode,
380            ..Default::default()
381        }
382    }
383
384    /// Set the padding mode.
385    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
386        self.padding_mode = padding_mode;
387        self
388    }
389
390    /// Set align_corners.
391    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
392        self.align_corners = align_corners;
393        self
394    }
395}
396
/// Padding mode for tensor pad operations.
///
/// Defines how values are filled when padding a tensor beyond its original boundaries.
/// Padding can be applied to any dimension of a tensor.
///
/// The `Default` is `Constant(0.0)` (zero fill).
///
/// # Modes
///
/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)
/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)
/// - [`Edge`](PadMode::Edge): Replicate boundary values
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Fill padded regions with a constant value.
    ///
    /// # Example
    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:
    /// Result: `[0, 0, 1, 2, 3]`
    Constant(f32),

    /// Reflect values at the boundary, excluding the edge value.
    ///
    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)
    Reflect,

    /// Replicate the edge values.
    ///
    /// # Example
    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
    /// Result: `[1, 1, 1, 2, 3, 4]`
    Edge,
}
432
433impl Default for PadMode {
434    fn default() -> Self {
435        PadMode::Constant(0.0)
436    }
437}
438
439impl<E: ElementConversion> From<E> for PadMode {
440    fn from(value: E) -> Self {
441        PadMode::Constant(value.elem())
442    }
443}
444
/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient w.r.t. the interpolation input `x`.
    pub x_grad: FloatTensor<B>,
}
451
/// Options for [attention](ModuleOps::attention).
///
/// The derived `Default` is no custom scale, no softcap, and non-causal.
#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct AttentionModuleOptions {
    /// Custom scale factor applied to QK^T. When `None`, defaults to `1/sqrt(head_dim)`.
    pub scale: Option<f64>,

    /// Soft capping applied before softmax: `softcap * tanh(scores / softcap)`.
    /// Used by Gemma-2 and similar models. Must be positive when set.
    pub softcap: Option<f64>,

    /// When `true`, applies causal (autoregressive) masking so that each query position
    /// can only attend to key positions at or before it. This is more efficient than
    /// passing an explicit lower-triangular bool mask because backends can use optimized
    /// kernel paths (e.g. flash attention with causal mode).
    pub is_causal: bool,
}
468
469/// Module operations trait.
470pub trait ModuleOps<B: Backend> {
471    /// Embedding operation.
472    ///
473    /// # Arguments
474    ///
475    /// * `weights` - The embedding weights.
476    /// * `indices` - The indices tensor.
477    ///
478    /// # Returns
479    ///
480    /// The output tensor.
481    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
482        let [batch_size, seq_length] = indices.shape().dims();
483        let [_, d_model] = weights.shape().dims();
484
485        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
486        let output = B::float_select(weights, 0, indices);
487
488        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
489    }
490
491    /// Embedding backward operation.
492    ///
493    /// # Arguments
494    ///
495    /// * `weights` - The embedding weights.
496    /// * `output_grad` - The output gradient.
497    /// * `indices` - The indices tensor.
498    ///
499    /// # Returns
500    ///
501    /// The gradient.
502    fn embedding_backward(
503        weights: FloatTensor<B>,
504        output_grad: FloatTensor<B>,
505        indices: IntTensor<B>,
506    ) -> FloatTensor<B> {
507        let [batch_size, seq_length] = indices.shape().dims();
508        let [n_embeddings, d_model] = weights.shape().dims();
509        let device = B::float_device(&weights);
510        let dtype = output_grad.dtype();
511
512        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
513        let output_grad =
514            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
515        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());
516
517        B::float_select_add(grad, 0, indices, output_grad)
518    }
    /// One dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_out, channels_in, kernel_size]`,
    /// bias:   `[channels_out]`,
    ///
    /// The default implementation delegates to the conv2d-based helper.
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    ///
    /// Required method: each backend must provide its own implementation.
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
    }
599
    /// Two dimensional deformable convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    ///
    /// Required method: each backend must provide its own implementation.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
    ///
    /// Required method; returns the gradients for every input in [DeformConv2dBackward].
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;
625
    /// Three dimensional convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    ///
    /// Required method: each backend must provide its own implementation.
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// One dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, length]`,
    /// weight: `[channels_in, channels_out, kernel_size]`,
    /// bias:   `[channels_out]`,
    ///
    /// The default implementation delegates to the conv-transpose-2d-based helper.
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }
705
    /// Two dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
    /// bias:   `[channels_out]`,
    ///
    /// Required method: each backend must provide its own implementation.
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }
744
    /// Three dimensional transposed convolution.
    ///
    /// # Shapes
    ///
    /// x:      `[batch_size, channels_in, depth, height, width]`,
    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
    /// bias:   `[channels_out]`,
    ///
    /// Required method: each backend must provide its own implementation.
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
    ///
    /// `options` should match those used in the forward pass.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }
783
    /// Four-dimensional unfolding.
    ///
    /// # Shapes
    ///
    /// * x:      ``[batch_size, channels_in, height, width]``,
    /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``,
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        // Fast path: with no padding and unit dilation, unfolding is a pair of
        // windowing ops plus a permute/reshape — no convolution needed.
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);

            // batch, channels, h_blocks, w_blocks, h_kern, w_kern

            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = blocks.shape();

            // batch, channels, h_kern, w_kern, h_blocks, w_blocks

            // Merge (channels, h_kern, w_kern) and (h_blocks, w_blocks) into the
            // documented 3D output layout.
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            // General path: padding and/or dilation are handled by the conv2d-based helper.
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }
819
    /// One dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    ///
    /// The default implementation delegates to the 2D pooling helper.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
    ///
    /// The pooling parameters should match those used in the forward pass.
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Two dimensional avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// Required method: each backend must provide its own implementation.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
    ///
    /// Required method; the pooling parameters should match those used in the forward pass.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Two dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// Required method: each backend must provide its own implementation.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    /// Required method: each backend must provide its own implementation.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One dimensional adaptive avg pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, length],
    ///
    /// `output_size` is the target output length.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        // Default implementation: delegate to the shared 2d-based helper so
        // backends only need to provide the 2d kernel.
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
    ///
    /// `grad` is the gradient with respect to the forward output; returns the
    /// gradient with respect to `x`.
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Default implementation: delegate to the shared 2d-based helper.
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
905    /// One dimensional max pooling.
906    ///
907    /// # Shapes
908    ///
909    /// x: [batch_size, channels, length],
910    fn max_pool1d(
911        x: FloatTensor<B>,
912        kernel_size: usize,
913        stride: usize,
914        padding: usize,
915        dilation: usize,
916        ceil_mode: bool,
917    ) -> FloatTensor<B> {
918        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
919    }
920
921    /// One dimensional max pooling with indices.
922    ///
923    /// # Shapes
924    ///
925    /// x: [batch_size, channels, height, width],
926    fn max_pool1d_with_indices(
927        x: FloatTensor<B>,
928        kernel_size: usize,
929        stride: usize,
930        padding: usize,
931        dilation: usize,
932        ceil_mode: bool,
933    ) -> MaxPool1dWithIndices<B> {
934        pool::max_pool1d_with_indices_from_2d::<B>(
935            x,
936            kernel_size,
937            stride,
938            padding,
939            dilation,
940            ceil_mode,
941        )
942    }
943    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
944    #[allow(clippy::too_many_arguments)]
945    fn max_pool1d_with_indices_backward(
946        x: FloatTensor<B>,
947        kernel_size: usize,
948        stride: usize,
949        padding: usize,
950        dilation: usize,
951        ceil_mode: bool,
952        output_grad: FloatTensor<B>,
953        indices: IntTensor<B>,
954    ) -> MaxPool1dBackward<B> {
955        pool::max_pool1d_with_indices_backward_from_2d::<B>(
956            x,
957            kernel_size,
958            stride,
959            padding,
960            dilation,
961            ceil_mode,
962            output_grad,
963            indices,
964        )
965    }
966
    /// Two dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    ///
    /// # Arguments
    ///
    /// - `kernel_size`: pooling window size per spatial dimension.
    /// - `stride`: window step per spatial dimension.
    /// - `padding`: padding applied on both sides of each spatial dimension.
    /// - `dilation`: spacing between elements within a pooling window.
    /// - `ceil_mode`: when `true`, the output size is computed with ceiling
    ///   instead of floor.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;
980
    /// Two dimensional max pooling, returning both the pooled output and the
    /// corresponding indices tensor (see [MaxPool2dWithIndices]).
    ///
    /// # Shapes
    ///
    /// x: [batch_size, channels, height, width],
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    ///
    /// `output_grad` is the gradient with respect to the forward output and
    /// `indices` the indices returned by the forward pass; the pooling
    /// parameters must match the ones used in the forward call.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;
1006
    /// Down/up samples the input.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    ///
    /// `output_size` is the target `[height, width]` of the result; the
    /// interpolation mode is selected through `options`.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
1017
    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    ///
    /// `grad` is the gradient with respect to the forward output; returns the
    /// gradient with respect to `x`. `output_size` and `options` must match
    /// the ones used in the forward call.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
1025
    /// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
    /// where scale defaults to 1/sqrt(head_dim). Optionally applies masking,
    /// additive bias, causal masking, and softcap to the attention scores.
    ///
    /// # Arguments
    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
    ///   (must share `head_dim` with `query`).
    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
    ///   (must share `seq_len_k` with `key`).
    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
    ///   where `true` indicates positions to mask (i.e. set to -inf before softmax).
    /// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
    ///   added to the attention scores before softmax (e.g. ALiBi, relative position biases).
    /// - `options`: Additional attention options (custom scale, softcap, causal masking).
    ///
    /// # Returns
    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
    /// representing the attended context per head.
    ///
    /// # Note
    /// This implementation does not support dropout and is intended for inference or
    /// use cases where dropout is not needed.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
        attn_bias: Option<FloatTensor<B>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<B>;
1055}
1056
#[cfg(test)]
mod tests {
    use super::*;

    // Every option constructor is expected to reject zero-valued structural
    // parameters; each test pins the exact panic message.

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn conv_options_stride_zero() {
        let _ = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn conv_options_dilation_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic(expected = "groups must be non-zero")]
    fn conv_options_groups_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn conv_transpose_options_stride_zero() {
        let _ = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn conv_transpose_options_dilation_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic(expected = "groups must be non-zero")]
    fn conv_transpose_options_groups_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn deform_conv_options_stride_zero() {
        let _ = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn deform_conv_options_dilation_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic(expected = "weight groups must be non-zero")]
    fn deform_conv_options_weights_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic(expected = "offset groups must be non-zero")]
    fn deform_conv_options_offset_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn unfold_options_stride_zero() {
        let _ = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn unfold_options_dilation_zero() {
        let _ = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}