Skip to main content

burn_backend/backend/ops/modules/
base.rs

1use super::{conv, pool};
2use crate::ops::unfold::unfold4d_using_conv2d;
3use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
4use crate::{Backend, ElementConversion, TensorMetadata};
5use burn_std::Shape;
6use core::num::NonZeroUsize;
7
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if any
    /// (presumably `None` when the convolution had no bias — confirm against callers).
    pub bias_grad: Option<FloatTensor<B>>,
}
20
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,

    /// Gradient with respect to the modulation mask, if one was supplied
    /// (presumably `None` otherwise — confirm against callers).
    pub mask_grad: Option<FloatTensor<B>>,

    /// Gradient with respect to the bias, if one was supplied
    /// (presumably `None` otherwise — confirm against callers).
    pub bias_grad: Option<FloatTensor<B>>,
}
39
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,

    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,

    /// Gradient with respect to the bias, if any
    /// (presumably `None` when the convolution had no bias — confirm against callers).
    pub bias_grad: Option<FloatTensor<B>>,
}
52
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the pooled input `x`.
    pub x_grad: FloatTensor<B>,
}
59
/// Results from [max_pool1d_with_indices](ModuleOps::max_pool1d_with_indices).
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor, one entry per element of `output`
    /// (assumed to locate each selected maximum in the input — confirm with backend impls).
    pub indices: IntTensor<B>,
}
69
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the pooled input `x`.
    pub x_grad: FloatTensor<B>,
}
76
/// Results from [max_pool2d_with_indices](ModuleOps::max_pool2d_with_indices).
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output tensor.
    pub output: FloatTensor<B>,

    /// The indices tensor, one entry per element of `output`
    /// (assumed to locate each selected maximum in the input — confirm with backend impls).
    pub indices: IntTensor<B>,
}
86
/// Check that the parameter value is non-zero.
///
/// Returns `value` unchanged so the call can be used inline while building option structs.
///
/// # Panics
///
/// Panics with `msg` when `value` is zero.
// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    NonZeroUsize::new(value).expect(msg).get()
}

/// Convolution options.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride (non-zero).
    pub stride: [usize; N],

    /// Padding (applied symmetrically on both sides of each dimension).
    pub padding: [usize; N],

    /// Dilation (non-zero).
    pub dilation: [usize; N],

    /// Groups (non-zero).
    pub groups: usize,
}

impl<const N: usize> ConvOptions<N> {
    /// Constructs a new `ConvOptions`.
    ///
    /// # Panics
    ///
    /// Panics if any stride entry, any dilation entry, or `groups` is zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        // Validated in field order, so the first offending parameter is the one reported.
        let stride = stride.map(|step| check_nonzero(step, "stride must be non-zero"));
        let dilation = dilation.map(|gap| check_nonzero(gap, "dilation must be non-zero"));
        let groups = check_nonzero(groups, "groups must be non-zero");

        Self {
            stride,
            padding,
            dilation,
            groups,
        }
    }
}

/// Convolution options with support for asymmetric padding.
///
/// Wraps [`ConvOptions`] (which represents symmetric padding for the backend op)
/// and adds optional asymmetric padding. When asymmetric padding is specified,
/// the functional convolution layer applies an explicit pad operation before
/// dispatching to the backend.
///
/// Implements `From<ConvOptions<N>>` for backward compatibility.
#[derive(Debug, Clone)]
pub struct PaddedConvOptions<const N: usize> {
    /// The underlying convolution options for the backend.
    pub options: ConvOptions<N>,
    /// Padding at the end of each dimension (e.g., bottom/right for 2D).
    /// If `None`, padding is symmetric (same as `options.padding`).
    /// If `Some`, specifies different end-padding per dimension.
    pub padding_end: Option<[usize; N]>,
}

impl<const N: usize> PaddedConvOptions<N> {
    /// Creates options with (possibly) asymmetric padding.
    ///
    /// `padding_start` is stored in `ConvOptions::padding`. `padding_end` is kept only
    /// when it differs from `padding_start`; otherwise the result is symmetric and
    /// `padding_end` is `None`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ConvOptions::new`].
    pub fn asymmetric(
        stride: [usize; N],
        padding_start: [usize; N],
        padding_end: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        let options = ConvOptions::new(stride, padding_start, dilation, groups);
        // Collapse to the symmetric representation when both sides agree.
        let padding_end = (padding_end != padding_start).then_some(padding_end);

        Self {
            options,
            padding_end,
        }
    }

    /// Returns true if padding is asymmetric.
    pub fn is_asymmetric(&self) -> bool {
        self.padding_end.is_some()
    }
}

impl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {
    /// Wraps symmetric options; no separate end padding is recorded.
    fn from(options: ConvOptions<N>) -> Self {
        Self {
            padding_end: None,
            options,
        }
    }
}
185
186/// Convolution options.
187#[derive(Debug, Clone, Hash, PartialEq, Eq)]
188pub struct DeformConvOptions<const N: usize> {
189    /// Stride (non-zero).
190    pub stride: [usize; N],
191
192    /// Padding.
193    pub padding: [usize; N],
194
195    /// Dilation (non-zero).
196    pub dilation: [usize; N],
197
198    /// Weight Groups (non-zero).
199    pub weight_groups: usize,
200
201    /// Offset Groups (non-zero).
202    pub offset_groups: usize,
203}
204
205impl<const N: usize> DeformConvOptions<N> {
206    /// Constructs a new `DeformConvOptions`.
207    pub fn new(
208        stride: [usize; N],
209        padding: [usize; N],
210        dilation: [usize; N],
211        weight_groups: usize,
212        offset_groups: usize,
213    ) -> Self {
214        Self {
215            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
216            padding,
217            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
218            weight_groups: check_nonzero(weight_groups, "weight groups must be non-zero"),
219            offset_groups: check_nonzero(offset_groups, "offset groups must be non-zero"),
220        }
221    }
222}
223
224/// Transposed convolution options.
225#[derive(Debug, Clone, Hash, PartialEq, Eq)]
226pub struct ConvTransposeOptions<const N: usize> {
227    /// Stride (non-zero).
228    pub stride: [usize; N],
229
230    /// Padding.
231    pub padding: [usize; N],
232
233    /// Padding out.
234    pub padding_out: [usize; N],
235
236    /// Dilation (non-zero).
237    pub dilation: [usize; N],
238
239    /// Groups (non-zero).
240    pub groups: usize,
241}
242
243impl<const N: usize> ConvTransposeOptions<N> {
244    /// Constructs a new `ConvTransposeOptions`.
245    pub fn new(
246        stride: [usize; N],
247        padding: [usize; N],
248        padding_out: [usize; N],
249        dilation: [usize; N],
250        groups: usize,
251    ) -> Self {
252        Self {
253            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
254            padding,
255            padding_out,
256            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
257            groups: check_nonzero(groups, "groups must be non-zero"),
258        }
259    }
260}
261
262/// Unfold operation options.
263#[derive(Debug, Clone)]
264pub struct UnfoldOptions {
265    /// The number of positions to slide over the input tensor in each dimension.
266    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.
267    pub stride: [usize; 2],
268
269    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.
270    pub padding: [usize; 2],
271
272    /// The spacing between the blocks (patches) in the original input tensor.
273    pub dilation: [usize; 2],
274}
275
276impl UnfoldOptions {
277    /// Constructs a new `UnfoldOptions`.
278    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
279        Self {
280            stride: stride.map(|s| check_nonzero(s, "stride must be non-zero")),
281            padding,
282            dilation: dilation.map(|d| check_nonzero(d, "dilation must be non-zero")),
283        }
284    }
285}
286
/// Algorithm used for upsampling.
///
/// Also used as the sampling mode for grid sampling (see `GridSampleOptions::mode`).
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
    Nearest,

    /// Bilinear interpolation.
    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
    Bilinear,

    /// Bicubic interpolation.
    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>
    Bicubic,

    /// Lanczos3 interpolation (6-tap sinc-based filter).
    /// <https://en.wikipedia.org/wiki/Lanczos_resampling>
    Lanczos3,
}
306
307/// Interpolation options.
308#[derive(Debug, Clone)]
309pub struct InterpolateOptions {
310    /// Algorithm used for upsampling.
311    pub mode: InterpolateMode,
312    /// If `true`, the input and output tensors are aligned by their corner pixels.
313    /// If `false`, half-pixel coordinate mapping is used instead.
314    pub align_corners: bool,
315}
316
317impl InterpolateOptions {
318    /// Create new interpolate options with the given mode.
319    /// Defaults to `align_corners = true`.
320    pub fn new(mode: InterpolateMode) -> Self {
321        Self {
322            mode,
323            align_corners: true,
324        }
325    }
326
327    /// Set align_corners.
328    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
329        self.align_corners = align_corners;
330        self
331    }
332}
333
/// Padding mode for grid sampling when coordinates are out of bounds.
///
/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.
/// The default variant is [`Zeros`](GridSamplePaddingMode::Zeros).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Fill with zeros for out-of-bounds coordinates.
    #[default]
    Zeros,
    /// Clamp coordinates to the border (use nearest edge value).
    Border,
    /// Reflect coordinates at the boundary.
    Reflection,
}
347
348/// Options for grid sampling operations.
349#[derive(Debug, Clone)]
350pub struct GridSampleOptions {
351    /// Interpolation mode (bilinear, nearest, or bicubic).
352    pub mode: InterpolateMode,
353    /// Padding mode for out-of-bounds coordinates.
354    pub padding_mode: GridSamplePaddingMode,
355    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.
356    /// If `false`, they correspond to the corner points of the corner pixels
357    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).
358    pub align_corners: bool,
359}
360
361impl Default for GridSampleOptions {
362    fn default() -> Self {
363        Self {
364            mode: InterpolateMode::Bilinear,
365            padding_mode: GridSamplePaddingMode::Zeros,
366            align_corners: false,
367        }
368    }
369}
370
371impl From<InterpolateMode> for GridSampleOptions {
372    fn from(value: InterpolateMode) -> Self {
373        GridSampleOptions::new(value)
374    }
375}
376
377impl GridSampleOptions {
378    /// Create new grid sample options with the given interpolation mode.
379    ///
380    /// Uses default values for padding_mode (Zeros) and align_corners (false).
381    pub fn new(mode: InterpolateMode) -> Self {
382        Self {
383            mode,
384            ..Default::default()
385        }
386    }
387
388    /// Set the padding mode.
389    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {
390        self.padding_mode = padding_mode;
391        self
392    }
393
394    /// Set align_corners.
395    pub fn with_align_corners(mut self, align_corners: bool) -> Self {
396        self.align_corners = align_corners;
397        self
398    }
399}
400
401/// Padding mode for tensor pad operations.
402///
403/// Defines how values are filled when padding a tensor beyond its original boundaries.
404/// Padding can be applied to any dimension of a tensor.
405///
406/// # Modes
407///
408/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)
409/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)
410/// - [`Edge`](PadMode::Edge): Replicate boundary values
411#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
412pub enum PadMode {
413    /// Fill padded regions with a constant value.
414    ///
415    /// # Example
416    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:
417    /// Result: `[0, 0, 1, 2, 3]`
418    Constant(f32),
419
420    /// Reflect values at the boundary, excluding the edge value.
421    ///
422    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).
423    ///
424    /// # Example
425    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
426    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)
427    Reflect,
428
429    /// Replicate the edge values.
430    ///
431    /// # Example
432    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:
433    /// Result: `[1, 1, 1, 2, 3, 4]`
434    Edge,
435}
436
437impl Default for PadMode {
438    fn default() -> Self {
439        PadMode::Constant(0.0)
440    }
441}
442
443impl<E: ElementConversion> From<E> for PadMode {
444    fn from(value: E) -> Self {
445        PadMode::Constant(value.elem())
446    }
447}
448
/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the interpolated input `x`.
    pub x_grad: FloatTensor<B>,
}
455
/// Options for [attention](ModuleOps::attention).
///
/// The derived `Default` leaves `scale` and `softcap` unset (`None`) and `is_causal` as `false`.
#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct AttentionModuleOptions {
    /// Custom scale factor applied to QK^T. When `None`, defaults to `1/sqrt(head_dim)`.
    pub scale: Option<f64>,

    /// Soft capping applied before softmax: `softcap * tanh(scores / softcap)`.
    /// Used by Gemma-2 and similar models. Must be positive when set
    /// (not validated by this struct; enforcement, if any, happens elsewhere).
    pub softcap: Option<f64>,

    /// When `true`, applies causal (autoregressive) masking so that each query position
    /// can only attend to key positions at or before it. This is more efficient than
    /// passing an explicit lower-triangular bool mask because backends can use optimized
    /// kernel paths (e.g. flash attention with causal mode).
    pub is_causal: bool,
}
472
473/// Module operations trait.
474pub trait ModuleOps<B: Backend> {
475    /// Embedding operation.
476    ///
477    /// # Arguments
478    ///
479    /// * `weights` - The embedding weights.
480    /// * `indices` - The indices tensor.
481    ///
482    /// # Returns
483    ///
484    /// The output tensor.
485    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
486        let [batch_size, seq_length] = indices.shape().dims();
487        let [_, d_model] = weights.shape().dims();
488
489        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
490        let output = B::float_select(weights, 0, indices);
491
492        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
493    }
494
495    /// Embedding backward operation.
496    ///
497    /// # Arguments
498    ///
499    /// * `weights` - The embedding weights.
500    /// * `output_grad` - The output gradient.
501    /// * `indices` - The indices tensor.
502    ///
503    /// # Returns
504    ///
505    /// The gradient.
506    fn embedding_backward(
507        weights: FloatTensor<B>,
508        output_grad: FloatTensor<B>,
509        indices: IntTensor<B>,
510    ) -> FloatTensor<B> {
511        let [batch_size, seq_length] = indices.shape().dims();
512        let [n_embeddings, d_model] = weights.shape().dims();
513        let device = B::float_device(&weights);
514        let dtype = output_grad.dtype();
515
516        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
517        let output_grad =
518            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
519        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());
520
521        B::float_select_add(grad, 0, indices, output_grad)
522    }
523    /// One dimensional convolution.
524    ///
525    /// # Shapes
526    ///
527    /// x:      `[batch_size, channels_in, length]`,
528    /// weight: `[channels_out, channels_in, kernel_size]`,
529    /// bias:   `[channels_out]`,
530    fn conv1d(
531        x: FloatTensor<B>,
532        weight: FloatTensor<B>,
533        bias: Option<FloatTensor<B>>,
534        options: ConvOptions<1>,
535    ) -> FloatTensor<B> {
536        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
537    }
538    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.
539    fn conv1d_x_backward(
540        x: FloatTensor<B>,
541        weight: FloatTensor<B>,
542        output_grad: FloatTensor<B>,
543        options: ConvOptions<1>,
544    ) -> FloatTensor<B> {
545        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
546    }
547    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.
548    fn conv1d_weight_backward(
549        x: FloatTensor<B>,
550        weight: FloatTensor<B>,
551        output_grad: FloatTensor<B>,
552        options: ConvOptions<1>,
553    ) -> FloatTensor<B> {
554        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
555    }
556    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.
557    fn conv1d_bias_backward(
558        x: FloatTensor<B>,
559        bias: FloatTensor<B>,
560        output_grad: FloatTensor<B>,
561    ) -> FloatTensor<B> {
562        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
563    }
564    /// Two dimensional convolution.
565    ///
566    /// # Shapes
567    ///
568    /// x:      `[batch_size, channels_in, height, width]`,
569    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
570    /// bias:   `[channels_out]`,
571    fn conv2d(
572        x: FloatTensor<B>,
573        weight: FloatTensor<B>,
574        bias: Option<FloatTensor<B>>,
575        options: ConvOptions<2>,
576    ) -> FloatTensor<B>;
577    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.
578    fn conv2d_x_backward(
579        x: FloatTensor<B>,
580        weight: FloatTensor<B>,
581        output_grad: FloatTensor<B>,
582        options: ConvOptions<2>,
583    ) -> FloatTensor<B> {
584        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
585    }
586    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.
587    fn conv2d_weight_backward(
588        x: FloatTensor<B>,
589        weight: FloatTensor<B>,
590        output_grad: FloatTensor<B>,
591        options: ConvOptions<2>,
592    ) -> FloatTensor<B> {
593        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
594    }
595    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.
596    fn conv2d_bias_backward(
597        x: FloatTensor<B>,
598        bias: FloatTensor<B>,
599        output_grad: FloatTensor<B>,
600    ) -> FloatTensor<B> {
601        conv::conv2d_bias_backward::<B>(x, bias, output_grad)
602    }
603
604    /// Two dimensional deformable convolution.
605    ///
606    /// # Shapes
607    ///
608    /// x:      `[batch_size, channels_in, height, width]`,
609    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
610    /// bias:   `[channels_out]`,
611    fn deform_conv2d(
612        x: FloatTensor<B>,
613        offset: FloatTensor<B>,
614        weight: FloatTensor<B>,
615        mask: Option<FloatTensor<B>>,
616        bias: Option<FloatTensor<B>>,
617        options: DeformConvOptions<2>,
618    ) -> FloatTensor<B>;
619    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
620    fn deform_conv2d_backward(
621        x: FloatTensor<B>,
622        offset: FloatTensor<B>,
623        weight: FloatTensor<B>,
624        mask: Option<FloatTensor<B>>,
625        bias: Option<FloatTensor<B>>,
626        output_grad: FloatTensor<B>,
627        options: DeformConvOptions<2>,
628    ) -> DeformConv2dBackward<B>;
629
630    /// Three dimensional convolution.
631    ///
632    /// # Shapes
633    ///
634    /// x:      `[batch_size, channels_in, depth, height, width]`,
635    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,
636    /// bias:   `[channels_out]`,
637    fn conv3d(
638        x: FloatTensor<B>,
639        weight: FloatTensor<B>,
640        bias: Option<FloatTensor<B>>,
641        options: ConvOptions<3>,
642    ) -> FloatTensor<B>;
643    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.
644    fn conv3d_x_backward(
645        x: FloatTensor<B>,
646        weight: FloatTensor<B>,
647        output_grad: FloatTensor<B>,
648        options: ConvOptions<3>,
649    ) -> FloatTensor<B> {
650        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
651    }
652    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.
653    fn conv3d_weight_backward(
654        x: FloatTensor<B>,
655        weight: FloatTensor<B>,
656        output_grad: FloatTensor<B>,
657        options: ConvOptions<3>,
658    ) -> FloatTensor<B> {
659        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
660    }
661    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.
662    fn conv3d_bias_backward(
663        x: FloatTensor<B>,
664        bias: FloatTensor<B>,
665        output_grad: FloatTensor<B>,
666    ) -> FloatTensor<B> {
667        conv::conv3d_bias_backward::<B>(x, bias, output_grad)
668    }
669    /// One dimensional transposed convolution.
670    ///
671    /// # Shapes
672    ///
673    /// x:      `[batch_size, channels_in, length]`,
674    /// weight: `[channels_in, channels_out, length]`,
675    /// bias:   `[channels_out]`,
676    fn conv_transpose1d(
677        x: FloatTensor<B>,
678        weight: FloatTensor<B>,
679        bias: Option<FloatTensor<B>>,
680        options: ConvTransposeOptions<1>,
681    ) -> FloatTensor<B> {
682        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
683    }
684    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.
685    fn conv_transpose1d_x_backward(
686        weight: FloatTensor<B>,
687        output_grad: FloatTensor<B>,
688        options: ConvTransposeOptions<1>,
689    ) -> FloatTensor<B> {
690        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
691    }
692    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.
693    fn conv_transpose1d_weight_backward(
694        x: FloatTensor<B>,
695        weight: FloatTensor<B>,
696        output_grad: FloatTensor<B>,
697        options: ConvTransposeOptions<1>,
698    ) -> FloatTensor<B> {
699        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
700    }
701    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.
702    fn conv_transpose1d_bias_backward(
703        x: FloatTensor<B>,
704        bias: FloatTensor<B>,
705        output_grad: FloatTensor<B>,
706    ) -> FloatTensor<B> {
707        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
708    }
709
710    /// Two dimensional transposed convolution.
711    ///
712    /// # Shapes
713    ///
714    /// x:      `[batch_size, channels_in, height, width]`,
715    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
716    /// bias:   `[channels_out]`,
717    fn conv_transpose2d(
718        x: FloatTensor<B>,
719        weight: FloatTensor<B>,
720        bias: Option<FloatTensor<B>>,
721        options: ConvTransposeOptions<2>,
722    ) -> FloatTensor<B>;
723    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.
724    fn conv_transpose2d_x_backward(
725        weight: FloatTensor<B>,
726        output_grad: FloatTensor<B>,
727        options: ConvTransposeOptions<2>,
728    ) -> FloatTensor<B> {
729        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
730    }
731    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.
732    fn conv_transpose2d_weight_backward(
733        x: FloatTensor<B>,
734        weight: FloatTensor<B>,
735        output_grad: FloatTensor<B>,
736        options: ConvTransposeOptions<2>,
737    ) -> FloatTensor<B> {
738        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
739    }
740    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.
741    fn conv_transpose2d_bias_backward(
742        x: FloatTensor<B>,
743        bias: FloatTensor<B>,
744        output_grad: FloatTensor<B>,
745    ) -> FloatTensor<B> {
746        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
747    }
748
749    /// Three dimensional transposed convolution.
750    ///
751    /// # Shapes
752    ///
753    /// x:      `[batch_size, channels_in, height, width]`,
754    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,
755    /// bias:   `[channels_out]`,
756    fn conv_transpose3d(
757        x: FloatTensor<B>,
758        weight: FloatTensor<B>,
759        bias: Option<FloatTensor<B>>,
760        options: ConvTransposeOptions<3>,
761    ) -> FloatTensor<B>;
762    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.
763    fn conv_transpose3d_x_backward(
764        weight: FloatTensor<B>,
765        output_grad: FloatTensor<B>,
766        options: ConvTransposeOptions<3>,
767    ) -> FloatTensor<B> {
768        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
769    }
770    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.
771    fn conv_transpose3d_weight_backward(
772        x: FloatTensor<B>,
773        weight: FloatTensor<B>,
774        output_grad: FloatTensor<B>,
775        options: ConvTransposeOptions<3>,
776    ) -> FloatTensor<B> {
777        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
778    }
779    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.
780    fn conv_transpose3d_bias_backward(
781        x: FloatTensor<B>,
782        bias: FloatTensor<B>,
783        output_grad: FloatTensor<B>,
784    ) -> FloatTensor<B> {
785        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
786    }
787
788    /// Four-dimensional unfolding.
789    ///
790    /// # Shapes
791    ///
792    /// * x:      ``[batch_size, channels_in, height, width]``,
793    /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``,
794    fn unfold4d(
795        x: FloatTensor<B>,
796        kernel_size: [usize; 2],
797        options: UnfoldOptions,
798    ) -> FloatTensor<B> {
799        if options.padding == [0, 0] && options.dilation == [1, 1] {
800            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
801            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);
802
803            // batch, channels, h_blocks, w_blocks, h_kern, w_kern
804
805            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
806            let shape = blocks.shape();
807
808            // batch, channels, h_kern, w_kern, h_blocks, w_blocks
809
810            B::float_reshape(
811                blocks,
812                [
813                    shape[0],
814                    shape[1] * shape[2] * shape[3],
815                    shape[4] * shape[5],
816                ]
817                .into(),
818            )
819        } else {
820            unfold4d_using_conv2d::<B>(x, kernel_size, options)
821        }
822    }
823
824    /// One dimensional avg pooling.
825    ///
826    /// # Shapes
827    ///
828    /// x: [batch_size, channels, length],
829    fn avg_pool1d(
830        x: FloatTensor<B>,
831        kernel_size: usize,
832        stride: usize,
833        padding: usize,
834        count_include_pad: bool,
835        ceil_mode: bool,
836    ) -> FloatTensor<B> {
837        pool::avg_pool1d_from_2d::<B>(
838            x,
839            kernel_size,
840            stride,
841            padding,
842            count_include_pad,
843            ceil_mode,
844        )
845    }
846    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.
847    fn avg_pool1d_backward(
848        x: FloatTensor<B>,
849        grad: FloatTensor<B>,
850        kernel_size: usize,
851        stride: usize,
852        padding: usize,
853        count_include_pad: bool,
854        ceil_mode: bool,
855    ) -> FloatTensor<B> {
856        pool::avg_pool1d_backward_from_2d::<B>(
857            x,
858            grad,
859            kernel_size,
860            stride,
861            padding,
862            count_include_pad,
863            ceil_mode,
864        )
865    }
866    /// Two dimensional avg pooling.
867    ///
868    /// # Shapes
869    ///
870    /// x: [batch_size, channels, height, width],
871    fn avg_pool2d(
872        x: FloatTensor<B>,
873        kernel_size: [usize; 2],
874        stride: [usize; 2],
875        padding: [usize; 2],
876        count_include_pad: bool,
877        ceil_mode: bool,
878    ) -> FloatTensor<B>;
879    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.
880    fn avg_pool2d_backward(
881        x: FloatTensor<B>,
882        grad: FloatTensor<B>,
883        kernel_size: [usize; 2],
884        stride: [usize; 2],
885        padding: [usize; 2],
886        count_include_pad: bool,
887        ceil_mode: bool,
888    ) -> FloatTensor<B>;
889    /// Two dimensional adaptive avg pooling.
890    ///
891    /// # Shapes
892    ///
893    /// x: [batch_size, channels, height, width],
894    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.
    ///
    /// NOTE(review): no `output_size` is passed here, so implementations presumably recover it
    /// from the shape of `grad` — confirm.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
897    /// One dimensional adaptive avg pooling.
898    ///
899    /// # Shapes
900    ///
901    /// x: [batch_size, channels, length],
902    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
903        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
904    }
905    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.
906    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
907        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
908    }
909    /// One dimensional max pooling.
910    ///
911    /// # Shapes
912    ///
913    /// x: [batch_size, channels, length],
914    fn max_pool1d(
915        x: FloatTensor<B>,
916        kernel_size: usize,
917        stride: usize,
918        padding: usize,
919        dilation: usize,
920        ceil_mode: bool,
921    ) -> FloatTensor<B> {
922        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
923    }
924
925    /// One dimensional max pooling with indices.
926    ///
927    /// # Shapes
928    ///
929    /// x: [batch_size, channels, height, width],
930    fn max_pool1d_with_indices(
931        x: FloatTensor<B>,
932        kernel_size: usize,
933        stride: usize,
934        padding: usize,
935        dilation: usize,
936        ceil_mode: bool,
937    ) -> MaxPool1dWithIndices<B> {
938        pool::max_pool1d_with_indices_from_2d::<B>(
939            x,
940            kernel_size,
941            stride,
942            padding,
943            dilation,
944            ceil_mode,
945        )
946    }
947    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.
948    #[allow(clippy::too_many_arguments)]
949    fn max_pool1d_with_indices_backward(
950        x: FloatTensor<B>,
951        kernel_size: usize,
952        stride: usize,
953        padding: usize,
954        dilation: usize,
955        ceil_mode: bool,
956        output_grad: FloatTensor<B>,
957        indices: IntTensor<B>,
958    ) -> MaxPool1dBackward<B> {
959        pool::max_pool1d_with_indices_backward_from_2d::<B>(
960            x,
961            kernel_size,
962            stride,
963            padding,
964            dilation,
965            ceil_mode,
966            output_grad,
967            indices,
968        )
969    }
970
    /// Two dimensional max pooling.
    ///
    /// `kernel_size`, `stride`, `padding` and `dilation` each carry one value per spatial
    /// dimension (presumably `[height, width]` order — TODO confirm).
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;
984
    /// Two dimensional max pooling with indices.
    ///
    /// Same hyper-parameters as [max_pool2d](ModuleOps::max_pool2d). The returned
    /// `MaxPool2dWithIndices` presumably pairs the pooled output with an indices tensor, as
    /// `MaxPool1dWithIndices` does. Those indices are what
    /// [max_pool2d_with_indices_backward](ModuleOps::max_pool2d_with_indices_backward) expects.
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
    ///
    /// `output_grad` is the gradient of the pooled output and `indices` is the indices tensor
    /// returned by the forward call; the other arguments repeat the forward hyper-parameters.
    // Eight parameters: the forward hyper-parameters plus the gradient and saved indices.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;
1010
    /// Down/up samples the input.
    ///
    /// Resizes `x` to `output_size` (presumably the target `[height, width]` — TODO confirm).
    /// The resampling behavior is configured through [`InterpolateOptions`].
    ///
    /// # Shapes
    ///
    /// x: `[batch_size, channels, height, width]`,
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
1021
    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.
    ///
    /// Receives the forward input `x`, the incoming gradient `grad`, and the same
    /// `output_size`/`options` used in the forward call. Returns a single float tensor,
    /// presumably the gradient with respect to `x`.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
1029
    /// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
    /// where scale defaults to 1/sqrt(head_dim). Optionally applies masking,
    /// additive bias, causal masking, and softcap to the attention scores.
    ///
    /// # Arguments
    ///
    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
    ///   where `true` indicates positions to mask (i.e. set to -inf before softmax).
    /// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
    ///   added to the attention scores before softmax (e.g. ALiBi, relative position biases).
    /// - `options`: Additional attention options (custom scale, softcap, causal masking).
    ///
    /// # Returns
    ///
    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
    /// representing the attended context per head.
    ///
    /// # Note
    ///
    /// This operation does not support dropout and is intended for inference or
    /// use cases where dropout is not needed.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
        attn_bias: Option<FloatTensor<B>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<B>;
1059
1060    /// Real-valued fast Fourier transform.
1061    ///
1062    /// Computes the discrete Fourier transform of a real-valued input along the given dimension.
1063    /// The transform is applied independently for each slice along `dim`, returning the non-redundant
1064    /// frequency components as separate real and imaginary tensors.
1065    /// #Returns
1066    /// two tensors, the first is the real part, the second is the imaginary
1067    fn rfft(signal: FloatTensor<B>, dim: usize) -> (FloatTensor<B>, FloatTensor<B>);
1068
    /// Inverse real-valued fast Fourier transform.
    ///
    /// Computes the inverse discrete Fourier transform from a frequency-domain
    /// representation given as separate real and imaginary components.
    /// The transform is applied independently for each slice along `dim`.
    ///
    /// # Arguments
    ///
    /// - `spectrum_re`: real part of the spectrum.
    /// - `spectrum_im`: imaginary part of the spectrum (presumably the same shape as
    ///   `spectrum_re` — TODO confirm).
    /// - `dim`: dimension along which the inverse transform is applied.
    ///
    /// # Returns
    ///
    /// A single real-valued tensor. NOTE(review): the output length along `dim` is not
    /// specified by this signature — confirm the convention used by backends.
    fn irfft(
        spectrum_re: FloatTensor<B>,
        spectrum_im: FloatTensor<B>,
        dim: usize,
    ) -> FloatTensor<B>;
1079}
1080
#[cfg(test)]
mod tests {
    //! Constructor validation for the convolution and unfold option types: each
    //! constructor must panic with a descriptive message when given a zero stride,
    //! dilation, or group count.

    use super::*;

    // ---- ConvOptions -------------------------------------------------------

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn conv_options_stride_zero() {
        let _ = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn conv_options_dilation_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic(expected = "groups must be non-zero")]
    fn conv_options_groups_zero() {
        let _ = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    // ---- ConvTransposeOptions ----------------------------------------------

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn conv_transpose_options_stride_zero() {
        let _ = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn conv_transpose_options_dilation_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic(expected = "groups must be non-zero")]
    fn conv_transpose_options_groups_zero() {
        let _ = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    // ---- DeformConvOptions -------------------------------------------------

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn deform_conv_options_stride_zero() {
        let _ = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn deform_conv_options_dilation_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic(expected = "weight groups must be non-zero")]
    fn deform_conv_options_weights_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic(expected = "offset groups must be non-zero")]
    fn deform_conv_options_offset_groups_zero() {
        let _ = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    // ---- UnfoldOptions -----------------------------------------------------

    #[test]
    #[should_panic(expected = "stride must be non-zero")]
    fn unfold_options_stride_zero() {
        let _ = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic(expected = "dilation must be non-zero")]
    fn unfold_options_dilation_zero() {
        let _ = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}