// burn_backend/backend/ops/modules/conv.rs

#![allow(clippy::single_range_in_vec_init)]
use super::{ConvOptions, ConvTransposeOptions};
use crate::{Backend, TensorMetadata, tensor::FloatTensor};
use burn_std::{MetadataError, Shape, Slice};

use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Calculate the expected output shape `[batch_size, channels, spatial_dims, ..]` for a pooling operation.
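///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // A 2x2 window at stride 2 over an 8x8 input halves each spatial dim.
/// let shape = calculate_pool_output_shape(
///     &Shape::new([1, 3, 8, 8]),
///     &[2, 2],
///     &[2, 2],
///     &[0, 0],
///     &[1, 1],
///     false,
/// )
/// .unwrap();
/// assert_eq!(shape, Shape::new([1, 3, 4, 4]));
/// ```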
pub fn calculate_pool_output_shape<const N: usize>(
    in_shape: &Shape,
    kernel_size: &[usize; N],
    stride: &[usize; N],
    padding: &[usize; N],
    dilation: &[usize; N],
    ceil_mode: bool,
) -> Result<Shape, MetadataError> {
    if in_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: in_shape.rank(),
            right: N + 2,
        });
    }

    let mut out_shape = in_shape.clone();
    // Spatial dims
    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {
        *size_i = calculate_pool_output_size(
            kernel_size[i],
            stride[i],
            padding[i],
            dilation[i],
            *size_i,
            ceil_mode,
        );
    }

    Ok(out_shape)
}

/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a convolution.
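///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // Mirrors the shape checked in the tests at the bottom of this file.
/// let shape = calculate_conv_output_shape(
///     &Shape::new([12, 3, 27, 3]),
///     &Shape::new([8, 3, 5, 3]),
///     &[2, 1],
///     &[3, 1],
///     &[2, 1],
/// )
/// .unwrap();
/// assert_eq!(shape, Shape::new([12, 8, 13, 3]));
/// ```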
pub fn calculate_conv_output_shape<const N: usize>(
    in_shape: &Shape,
    weight_shape: &Shape,
    stride: &[usize; N],
    padding: &[usize; N],
    dilation: &[usize; N],
) -> Result<Shape, MetadataError> {
    if weight_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: weight_shape.rank(),
            right: N + 2,
        });
    }

    if in_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: in_shape.rank(),
            right: N + 2,
        });
    }

    let kernel_size = &weight_shape[2..];

    let mut out_shape = in_shape.clone();
    // Spatial dims
    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {
        *size_i =
            calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_i);
    }
    // Output channels
    out_shape[1] = weight_shape[0];

    Ok(out_shape)
}

/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a transposed convolution.
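///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // Transpose weight layout is [channels_in, channels_out / groups, k1, k2].
/// let shape = calculate_conv_transpose_output_shape(
///     &Shape::new([1, 4, 5, 5]),
///     &Shape::new([4, 8, 3, 3]),
///     &[2, 2],
///     &[1, 1],
///     &[1, 1],
///     &[1, 1],
///     1,
/// )
/// .unwrap();
/// assert_eq!(shape, Shape::new([1, 8, 10, 10]));
/// ```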
pub fn calculate_conv_transpose_output_shape<const N: usize>(
    in_shape: &Shape,
    weight_shape: &Shape,
    stride: &[usize; N],
    padding: &[usize; N],
    padding_out: &[usize; N],
    dilation: &[usize; N],
    groups: usize,
) -> Result<Shape, MetadataError> {
    if weight_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: weight_shape.rank(),
            right: N + 2,
        });
    }

    if in_shape.rank() != N + 2 {
        return Err(MetadataError::RankMismatch {
            left: in_shape.rank(),
            right: N + 2,
        });
    }

    let kernel_size = &weight_shape[2..];

    let mut out_shape = in_shape.clone();
    // Spatial dims
    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {
        *size_i = calculate_conv_transpose_output_size(
            kernel_size[i],
            stride[i],
            padding[i],
            padding_out[i],
            dilation[i],
            *size_i,
        );
    }
    // Output channels
    out_shape[1] = weight_shape[1] * groups;

    Ok(out_shape)
}

/// Calculate the expected padding size required when applying a convolution.
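///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // "Same" padding for a 3-wide kernel at stride 1 is 1.
/// assert_eq!(calculate_conv_padding(3, 1, 9, 9), 1);
/// ```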
pub fn calculate_conv_padding(
    kernel_size: usize,
    stride: usize,
    size_in: usize,
    size_out: usize,
) -> usize {
    let kernel_size = kernel_size as f32;
    let stride = stride as f32;
    let size_in = size_in as f32;
    let size_out = size_out as f32;

    let padding = stride * (size_out - 1.) - size_in + kernel_size;
    let padding = (padding / 2.).ceil();

    padding as usize
}

/// Calculate the expected output size when doing a convolution operation.
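///
/// Implements `floor((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1`.
///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // A 3-wide kernel with stride 1 and padding 1 preserves the input size.
/// assert_eq!(calculate_conv_output_size(3, 1, 1, 1, 8), 8);
/// // At stride 2 without padding, a 7-element input yields 3 output positions.
/// assert_eq!(calculate_conv_output_size(3, 2, 0, 1, 7), 3);
/// ```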
pub fn calculate_conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

/// Calculate the expected output sizes when doing a convolution operation.
pub fn calculate_conv_output_sizes(
    kernel_size: &[usize],
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
    size_in: &[usize],
) -> Vec<usize> {
    size_in
        .iter()
        .enumerate()
        .map(|(i, size_in)| {
            calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_in)
        })
        .collect()
}

/// Calculate the expected output size when doing a pooling operation.
///
/// # Arguments
///
/// * `kernel_size` - Size of the pooling kernel
/// * `stride` - Stride of the pooling operation
/// * `padding` - Padding applied to input
/// * `dilation` - Dilation of the pooling kernel
/// * `size_in` - Input size (height or width)
/// * `ceil_mode` - If true, use ceiling instead of floor for output size calculation.
///   This allows the last pooling window to go out-of-bounds if needed.
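///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // A 2-wide window at stride 2 over 7 elements: floor mode drops the trailing
/// // partial window, ceil mode counts it.
/// assert_eq!(calculate_pool_output_size(2, 2, 0, 1, 7, false), 3);
/// assert_eq!(calculate_pool_output_size(2, 2, 0, 1, 7, true), 4);
/// ```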
pub fn calculate_pool_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
    ceil_mode: bool,
) -> usize {
    let numerator = size_in + 2 * padding - dilation * (kernel_size - 1) - 1;
    if ceil_mode {
        // Ceiling division: (a + b - 1) / b
        numerator.div_ceil(stride) + 1
    } else {
        // Floor division (default)
        numerator / stride + 1
    }
}

/// Calculate the expected output size when doing a transposed convolution operation.
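///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // A stride-2, kernel-3 transposed conv expands 4 elements to 9, inverting the
/// // matching forward conv: calculate_conv_output_size(3, 2, 0, 1, 9) == 4.
/// assert_eq!(calculate_conv_transpose_output_size(3, 2, 0, 0, 1, 4), 9);
/// ```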
pub fn calculate_conv_transpose_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    size_in: usize,
) -> usize {
    (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding
}

/// Calculate the original input size that was used for a transposed convolution.
/// This is used during the backward pass to recover the correct gradient shape.
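///
/// For example, a transposed conv with `kernel_size = 3`, `stride = 2`, no padding,
/// no output padding and unit dilation maps `size_in = 4` to `size_out = 9`; this
/// inverse maps 9 back to 4.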
fn calculate_conv_transpose_input_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    size_out: usize,
) -> usize {
    // We solve the forward formula for size_in:
    // size_out = (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding
    (size_out + 2 * padding - dilation * (kernel_size - 1) - padding_out - 1) / stride + 1
}

/// Calculate the original input sizes that were used for a transposed convolution.
fn calculate_conv_transpose_input_sizes<const D: usize>(
    kernel_size: [usize; D],
    stride: [usize; D],
    padding: [usize; D],
    padding_out: [usize; D],
    dilation: [usize; D],
    size_out: [usize; D],
) -> [usize; D] {
    let mut res = [0; D];
    for i in 0..D {
        res[i] = calculate_conv_transpose_input_size(
            kernel_size[i],
            stride[i],
            padding[i],
            padding_out[i],
            dilation[i],
            size_out[i],
        );
    }
    res
}

/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.
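///
/// The gradient with respect to `x` is the adjoint of the forward convolution: a
/// transposed convolution of `output_grad` with the same weight, stride, padding,
/// dilation and groups. `padding_out` is computed below so the result matches
/// `length_in` exactly, even when the forward pass dropped trailing input elements.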
pub(crate) fn conv1d_x_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let weight_shape = weight.shape();

    let [_batch_size, _, length_in] = x.shape().dims();
    let [_batch_size, _channels_out, length_out] = output_grad.shape().dims();
    let [_, _, kernel_size] = weight_shape.dims();

    let padding_out = calculate_padding_out(
        kernel_size,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        length_in,
        length_out,
    );

    B::conv_transpose1d(
        output_grad,
        weight,
        None,
        ConvTransposeOptions::new(
            options.stride,
            options.padding,
            [padding_out],
            options.dilation,
            options.groups,
        ),
    )
}

/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv1d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv1d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv1d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _, _length_in] = x.shape().dims();
    let [_batch_size, channels_out, length_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.
pub(crate) fn conv2d_x_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let weight_shape = weight.shape();

    let [_batch_size, _channels_in, height_in, width_in] = x.shape().dims();
    let [_, _, height_out, width_out] = output_grad.shape().dims();
    let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims();

    let padding_1_out = calculate_padding_out(
        kernel_size_1,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        height_in,
        height_out,
    );
    let padding_2_out = calculate_padding_out(
        kernel_size_2,
        options.stride[1],
        options.padding[1],
        options.dilation[1],
        width_in,
        width_out,
    );

    B::conv_transpose2d(
        output_grad,
        weight,
        None,
        ConvTransposeOptions::new(
            options.stride,
            options.padding,
            [padding_1_out, padding_2_out],
            options.dilation,
            options.groups,
        ),
    )
}

/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv2d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv2d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv2d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _, _, _] = x.shape().dims();
    let [_, channels_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([channels_out, batch_size * height_out * width_out]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.
pub(crate) fn conv3d_x_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let weight_shape = weight.shape();

    let [_batch_size, _channels_in, depth_in, height_in, width_in] = x.shape().dims();
    let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims();
    let [
        _channels_out,
        _,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_shape.dims();

    let padding_1_out = calculate_padding_out(
        kernel_size_1,
        options.stride[0],
        options.padding[0],
        options.dilation[0],
        depth_in,
        depth_out,
    );
    let padding_2_out = calculate_padding_out(
        kernel_size_2,
        options.stride[1],
        options.padding[1],
        options.dilation[1],
        height_in,
        height_out,
    );
    let padding_3_out = calculate_padding_out(
        kernel_size_3,
        options.stride[2],
        options.padding[2],
        options.dilation[2],
        width_in,
        width_out,
    );

    B::conv_transpose3d(
        output_grad,
        weight,
        None,
        ConvTransposeOptions::new(
            options.stride,
            options.padding,
            [padding_1_out, padding_2_out, padding_3_out],
            options.dilation,
            options.groups,
        ),
    )
}

/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv3d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv3d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv3d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = x.shape().dims();
    let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([
            channels_out,
            batch_size * depth_out * height_out * width_out,
        ]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.
pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [batch_size, _c_out, out_length] = output_grad.shape().dims();
    let [c_in, _c_out_groups, kernel_size] = weight.shape().dims();

    let grad = B::conv1d(
        output_grad,
        weight,
        None,
        ConvOptions::new(
            options.stride,
            options.padding,
            options.dilation,
            options.groups,
        ),
    );

    if options.padding_out[0] == 0 {
        return grad;
    }

    let exp_length = calculate_conv_transpose_input_size(
        kernel_size,
        options.stride[0],
        options.padding[0],
        options.padding_out[0],
        options.dilation[0],
        out_length,
    );

    B::float_slice(
        grad,
        &[
            Slice::from(0..batch_size),
            Slice::from(0..c_in),
            Slice::from(0..exp_length),
        ],
    )
}

/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv_transpose1d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv_transpose1d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv_transpose1d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _] = x.shape().dims();
    let [_, channels_out, length_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.
pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let [batch_size, _c_out, out_h, out_w] = output_grad.shape().dims();
    let [c_in, _c_out_groups, k_h, k_w] = weight.shape().dims();

    let grad = B::conv2d(
        output_grad,
        weight,
        None,
        ConvOptions::new(
            options.stride,
            options.padding,
            options.dilation,
            options.groups,
        ),
    );

    if options.padding_out[0] == 0 && options.padding_out[1] == 0 {
        return grad;
    }

    let [exp_h, exp_w] = calculate_conv_transpose_input_sizes(
        [k_h, k_w],
        options.stride,
        options.padding,
        options.padding_out,
        options.dilation,
        [out_h, out_w],
    );

    B::float_slice(
        grad,
        &[
            Slice::from(0..batch_size),
            Slice::from(0..c_in),
            Slice::from(0..exp_h),
            Slice::from(0..exp_w),
        ],
    )
}

/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv_transpose2d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv_transpose2d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv_transpose2d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _, _] = x.shape().dims();
    let [_, channels_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([channels_out, batch_size * height_out * width_out]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.
pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let [batch_size, _c_out, out_d, out_h, out_w] = output_grad.shape().dims();
    let [c_in, _c_out_groups, k_d, k_h, k_w] = weight.shape().dims();

    let grad = B::conv3d(
        output_grad,
        weight,
        None,
        ConvOptions::new(
            options.stride,
            options.padding,
            options.dilation,
            options.groups,
        ),
    );

    if options.padding_out[0] == 0 && options.padding_out[1] == 0 && options.padding_out[2] == 0 {
        return grad;
    }

    let [exp_d, exp_h, exp_w] = calculate_conv_transpose_input_sizes(
        [k_d, k_h, k_w],
        options.stride,
        options.padding,
        options.padding_out,
        options.dilation,
        [out_d, out_h, out_w],
    );

    B::float_slice(
        grad,
        &[
            Slice::from(0..batch_size),
            Slice::from(0..c_in),
            Slice::from(0..exp_d),
            Slice::from(0..exp_h),
            Slice::from(0..exp_w),
        ],
    )
}

/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.
pub(crate) fn conv_transpose3d_weight_backward<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let weight_dtype = weight.dtype();
    let weight_shape = weight.shape();
    let weight_device = B::float_device(&weight);

    match options.groups == 1 {
        true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),
        false => conv_transpose3d_weight_grad_groups::<B>(
            x,
            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),
            output_grad,
            options,
        ),
    }
}

/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.
pub(crate) fn conv_transpose3d_bias_backward<B: Backend>(
    x: FloatTensor<B>,
    bias: FloatTensor<B>,
    output_grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, _channels_in, _, _, _] = x.shape().dims();
    let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims();

    let grad = B::float_swap_dims(output_grad, 0, 1);
    let grad = B::float_reshape(
        grad,
        Shape::new([
            channels_out,
            batch_size * depth_out * height_out * width_out,
        ]),
    );
    let grad = B::float_sum_dim(grad, 1);

    B::float_reshape(grad, bias.shape())
}

/// Execute a 1D convolution using a 2D convolution.
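///
/// The input and weight are reshaped with a trailing unit dimension
/// (`[batch_size, channels, length, 1]`), turning the 1D kernel into a
/// `kernel_size x 1` 2D kernel; stride 1, padding 0 and dilation 1 keep the added
/// axis neutral, and the result is reshaped back to 3D.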
pub(crate) fn conv1d_from_conv2d<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let [channels_out, _channels_in, kernel_size] = weight.shape().dims();
    let [batch_size, channels_in, length_in] = x.shape().dims();

    let weight = B::float_reshape(
        weight,
        Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]),
    );
    let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));

    let tensor = B::conv2d(
        x,
        weight,
        bias,
        ConvOptions::new(
            [options.stride[0], 1],
            [options.padding[0], 0],
            [options.dilation[0], 1],
            options.groups,
        ),
    );
    let [batch_size, channels_out, height_out, _width_out] = tensor.shape().dims();
    B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}

/// Execute a 1D transposed convolution using a 2D transposed convolution.
pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
    x: FloatTensor<B>,
    weight: FloatTensor<B>,
    bias: Option<FloatTensor<B>>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [channels_in, channels_out, kernel_size] = weight.shape().dims();
    let [batch_size, _channels_in, length_in] = x.shape().dims();

    let weight = B::float_reshape(
        weight,
        Shape::new([channels_in, channels_out, kernel_size, 1]),
    );
    let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));

    let tensor = B::conv_transpose2d(
        x,
        weight,
        bias,
        ConvTransposeOptions::new(
            [options.stride[0], 1],
            [options.padding[0], 0],
            [options.padding_out[0], 0],
            [options.dilation[0], 1],
            options.groups,
        ),
    );
    let [batch_size, channels_out, height_out, _width_out] = tensor.shape().dims();
    B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}

fn conv1d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
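    // Note the role swap in the options below: the forward stride becomes this
    // convolution's dilation and the forward dilation becomes its stride. With the
    // batch and channel dims swapped above, that is exactly what expresses
    // d(output)/d(weight) as a convolution of `x` with `output_grad`; the 2D/3D
    // and grouped variants below rely on the same identity.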
    let weight_grad_swapped = B::conv1d(
        x_swapped,
        output_grad_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    if weight_grad.shape() != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv2d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv2d(
        x_swapped,
        output_grad_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    if weight_grad.shape() != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv3d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv3d(
        x_swapped,
        output_grad_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    if weight_grad.shape() != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
            Slice::from(0..weight_shape[4]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv1d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<1>,
) -> FloatTensor<B> {
    let [channels_out, increment_ci, kernel_size] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv1d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv2d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<2>,
) -> FloatTensor<B> {
    let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv2d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv3d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvOptions<3>,
) -> FloatTensor<B> {
    let [
        channels_out,
        increment_ci,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_grad.shape().dims();
    let increment_co = channels_out / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv3d(
            x,
            grad,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [
            _,
            _,
            kernel_size_1_tmp,
            kernel_size_2_tmp,
            kernel_size_3_tmp,
        ] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1
            || kernel_size_2_tmp != kernel_size_2
            || kernel_size_3_tmp != kernel_size_3
        {
            let slices = vec![
                Slice::from(0..increment_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_co..end_idx_co),
                Slice::from(0..increment_ci),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv_transpose1d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
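    // Same stride/dilation swap as in the forward-conv weight gradients above, with
    // the operands flipped: for a transposed convolution, `output_grad` plays the
    // role of the convolution input and `x` the role of the kernel.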
    let weight_grad_swapped = B::conv1d(
        output_grad_swapped,
        x_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    let grad_shape = weight_grad.shape();
    if grad_shape != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv_transpose2d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv2d(
        output_grad_swapped,
        x_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    let grad_shape = weight_grad.shape();
    if grad_shape != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv_transpose3d_weight_grad_no_groups<B: Backend>(
    x: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    weight_shape: Shape,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);
    let weight_grad_swapped = B::conv3d(
        output_grad_swapped,
        x_swapped,
        None,
        ConvOptions::new(options.dilation, options.padding, options.stride, 1),
    );
    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);

    let grad_shape = weight_grad.shape();
    if grad_shape != weight_shape {
        let slices = vec![
            Slice::from(0..weight_shape[0]),
            Slice::from(0..weight_shape[1]),
            Slice::from(0..weight_shape[2]),
            Slice::from(0..weight_shape[3]),
            Slice::from(0..weight_shape[4]),
        ];
        weight_grad = B::float_slice(weight_grad, &slices);
    }
    weight_grad
}

fn conv_transpose1d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
    let [channels_in, increment_co, kernel_size] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv1d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims();

        if kernel_size_tmp != kernel_size {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv_transpose2d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
    let [channels_in, increment_co, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv2d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }

        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

fn conv_transpose3d_weight_grad_groups<B: Backend>(
    x: FloatTensor<B>,
    mut weight_grad: FloatTensor<B>,
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
    let [
        channels_in,
        increment_co,
        kernel_size_1,
        kernel_size_2,
        kernel_size_3,
    ] = weight_grad.shape().dims();
    let increment_ci = channels_in / options.groups;

    let x_swapped = B::float_swap_dims(x, 0, 1);
    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);

    for g in 0..options.groups {
        let start_idx_ci = g * increment_ci;
        let end_idx_ci = (g + 1) * increment_ci;
        let start_idx_co = g * increment_co;
        let end_idx_co = (g + 1) * increment_co;

        let x_slice = vec![Slice::new(
            start_idx_ci as isize,
            Some(end_idx_ci as isize),
            1,
        )];
        let x = B::float_slice(x_swapped.clone(), &x_slice);
        let grad_slice = vec![Slice::new(
            start_idx_co as isize,
            Some(end_idx_co as isize),
            1,
        )];
        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);
        let mut weight_grad_tmp = B::conv3d(
            grad,
            x,
            None,
            ConvOptions::new(options.dilation, options.padding, options.stride, 1),
        );
        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);
        let [
            _,
            _,
            kernel_size_1_tmp,
            kernel_size_2_tmp,
            kernel_size_3_tmp,
        ] = weight_grad_tmp.shape().dims();

        if kernel_size_1_tmp != kernel_size_1
            || kernel_size_2_tmp != kernel_size_2
            || kernel_size_3_tmp != kernel_size_3
        {
            let slices = vec![
                Slice::from(0..increment_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ];
            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);
        }
        weight_grad = B::float_slice_assign(
            weight_grad,
            &[
                Slice::from(start_idx_ci..end_idx_ci),
                Slice::from(0..increment_co),
                Slice::from(0..kernel_size_1),
                Slice::from(0..kernel_size_2),
                Slice::from(0..kernel_size_3),
            ],
            weight_grad_tmp,
        );
    }

    weight_grad
}

/// Compute the `padding_out` for a transpose conv that exactly recovers the
/// original `size_in` from `size_out`, accounting for any input elements the
/// forward conv dropped. Shared by `conv{1,2,3}d_x_backward` and the CubeCL
/// dgrad fallback so the two paths can't drift.
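///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path depends on how the crate re-exports this.
/// // A stride-2, kernel-3 forward conv maps size_in = 8 to size_out = 3, dropping
/// // one trailing input element, so the transposed conv needs padding_out = 1.
/// assert_eq!(calculate_padding_out(3, 2, 0, 1, 8, 3), 1);
/// ```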
pub fn calculate_padding_out(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size_in: usize,
    size_out: usize,
) -> usize {
    if stride <= 1 {
        return 0;
    }

    // Invert the transpose conv output formula to recover the exact number of
    // input elements that a forward conv would drop for this (size_in, size_out).
    //
    // Forward: size_out = floor((size_in + 2*padding - dilated_kernel) / stride) + 1
    // Transpose: trans_out = (size_out - 1)*stride + dilated_kernel + padding_out - 2*padding
    // Setting trans_out == size_in and solving for padding_out:
    let dilated_kernel = dilation * (kernel_size - 1) + 1;
    let base = (size_out as i64 - 1) * stride as i64 + dilated_kernel as i64 - 2 * padding as i64;
    i64::max(0, size_in as i64 - base) as usize
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_output_size_1() {
        let kernel_size = 3;
        let stride = 1;
        let padding = 1;
        let size_in = 3;
        let dilation = 1;

        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, 3);
    }

    #[test]
    fn test_calculate_output_size_2() {
        let kernel_size = 5;
        let stride = 2;
        let padding = 3;
        let size_in = 27;
        let dilation = 1;

        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, 15);
    }

    #[test]
    fn test_calculate_output_size_3() {
        let kernel_size = 5;
        let stride = 2;
        let padding = 3;
        let size_in = 27;
        let dilation = 2;

        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, 13);
    }

    #[test]
    fn test_calculate_same_padding_1() {
        let kernel_size = 3;
        let stride = 1;
        let size_in = 3;
        let dilation = 1;

        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_in, size_out, "Expected size");
    }

    #[test]
    fn test_calculate_same_padding_2() {
        let kernel_size = 3;
        let stride = 2;
        let size_in = 7;
        let dilation = 1;

        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);
        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_in, size_out, "Expected size");
    }

    #[test]
    fn test_calculate_output_padding_1() {
        let kernel_size = 3;
        let stride = 2;
        let size_in = 7;
        let size_out = 10;
        let dilation = 1;

        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out);
        let size_out_expected =
            calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);

        assert_eq!(size_out, size_out_expected, "Expected size");
    }

    #[test]
    fn test_expect_conv2d_output_shape() {
        // in channels: 3
        // out channels: 8
        // size in: [27, 3]
        // kernel size: [5, 3]
        let stride = [2, 1];
        let padding = [3, 1];
        let dilation = [2, 1];
        let shape = calculate_conv_output_shape(
            &Shape::new([12, 3, 27, 3]),
            &Shape::new([8, 3, 5, 3]),
            &stride,
            &padding,
            &dilation,
        )
        .unwrap();
        assert_eq!(shape, Shape::new([12, 8, 13, 3]))
    }
}