burn_cubecl/kernel/conv/base.rs

use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::convolution::{AcceleratedTileKind, components::ConvSetupError};

#[cfg(feature = "autotune")]
use crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune};
use crate::{
    CubeRuntime,
    kernel::conv::{
        backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
        backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
        forward::implicit_gemm::conv_gemm_simple_sync,
    },
    ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw},
    tensor::CubeTensor,
};

use super::conv_direct;
#[cfg(feature = "autotune")]
use super::forward::conv_autotune;

/// The strategy to be used when launching a convolution kernel.
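///
/// A minimal sketch of picking a strategy by hand (`supports_cmma` is a
/// hypothetical capability check; `Autotune` is only available with the
/// `autotune` feature):
///
/// ```ignore
/// let strategy = if supports_cmma {
///     ConvStrategy::ImplicitGemm
/// } else {
///     ConvStrategy::Direct
/// };
/// ```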
pub enum ConvStrategy {
    /// A simple direct convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
    /// has constraints on tensor shape.
    ImplicitGemm,
}

impl Default for ConvStrategy {
    fn default() -> Self {
        // If autotune is enabled, default to autotune.
        #[cfg(feature = "autotune")]
        return ConvStrategy::Autotune;

        // If autotune is disabled, default to the more memory-conservative algorithm.
        #[cfg(not(feature = "autotune"))]
        ConvStrategy::Direct
    }
}

/// Performs an N-dimensional convolution with the given strategy.
///
/// Expects NCHW tensors; they are permuted to NHWC internally and the result is
/// permuted back to NCHW.
///
/// * `input` - The input feature map
/// * `weight` - The convolution weights (filters)
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
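///
/// # Example
///
/// A minimal sketch, assuming NCHW tensors `input`, `weight`, and `bias` and an
/// `options: ConvOptions<2>` already exist for some runtime `R` (not compiled
/// as a doc test):
///
/// ```ignore
/// // 2D convolution with the default strategy (autotune when available).
/// let out = conv_forward::<R, 2>(input, weight, Some(bias), options, ConvStrategy::default())?;
/// ```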
pub fn conv_forward<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let input = permute_nchw_to_nhwc(input);
    let weight = permute_nchw_to_nhwc(weight);

    let out = conv_forward_nhwc(input, weight, bias, options, strategy)?;

    Ok(permute_nhwc_to_nchw(out))
}

/// Performs an N-dimensional convolution with the given strategy on NHWC inputs/outputs.
///
/// Unlike [`conv_forward`], no layout permutation is performed; all tensors are
/// expected to already be NHWC.
///
/// * `input` - The input feature map
/// * `weight` - The convolution weights (filters)
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
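///
/// # Example
///
/// A minimal sketch, assuming `input` and `weight` are already laid out as NHWC
/// (not compiled as a doc test):
///
/// ```ignore
/// // Force the implicit GEMM path; grouped convolutions fall back to `conv_direct`.
/// let out = conv_forward_nhwc::<R, 2>(input, weight, None, options, ConvStrategy::ImplicitGemm)?;
/// ```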
pub fn conv_forward_nhwc<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    match strategy {
        ConvStrategy::Direct => conv_direct::<R, N>(input, weight, bias, options),
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => Ok(conv_autotune::<R, N>(input, weight, bias, options)),
        ConvStrategy::ImplicitGemm => {
            // The implicit GEMM kernel does not support grouped convolutions;
            // fall back to the direct implementation in that case.
            if options.groups != 1 {
                conv_direct::<R, N>(input, weight, bias, options)
            } else {
                conv_gemm_simple_sync::<R, N>(
                    input,
                    weight,
                    bias,
                    options,
                    AcceleratedTileKind::Cmma,
                )
            }
        }
    }
}

/// Performs the backward pass of an N-dimensional convolution with respect to the weights, using the given strategy.
///
/// * `input` - The input feature map
/// * `out_grad` - The output gradients
/// * `weight_shape` - The shape of the weights/weight gradients
/// * `options` - The options used for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
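///
/// # Example
///
/// A minimal sketch, assuming NCHW tensors `input` and `out_grad` from the
/// forward pass (not compiled as a doc test):
///
/// ```ignore
/// // Gradient w.r.t. the weights; `weight_shape` is the NCHW shape of the forward weights.
/// let weight_grad =
///     conv_weight_backward::<R, 2>(input, out_grad, weight_shape, options, ConvStrategy::default())?;
/// ```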
pub fn conv_weight_backward<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    out_grad: CubeTensor<R>,
    weight_shape: Shape,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let input = permute_nchw_to_nhwc(input);
    let out_grad = permute_nchw_to_nhwc(out_grad);
    let weight_shape = permute_nchw_to_nhwc_shape(weight_shape);

    let weight_grad = match strategy {
        ConvStrategy::Direct => {
            conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
        }
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => Ok(wgrad_autotune::<R, N>(
            input,
            out_grad,
            weight_shape,
            options,
        )),
        ConvStrategy::ImplicitGemm => {
            // The implicit GEMM kernel does not support grouped convolutions;
            // use the fallback kernel instead.
            if options.groups != 1 {
                conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
            } else {
                wgrad_gemm_simple_sync::<R, N>(
                    input,
                    out_grad,
                    weight_shape,
                    options,
                    AcceleratedTileKind::Cmma,
                )
            }
        }
    }?;

    Ok(permute_nhwc_to_nchw(weight_grad))
}

/// Performs the backward pass of an N-dimensional convolution with respect to the input (data), using the given strategy.
///
/// * `out_grad` - The output gradients
/// * `weights` - The weights (filters) used in the forward pass
/// * `in_shape` - The shape of the input to the layer
/// * `options` - The options used for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
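///
/// # Example
///
/// A minimal sketch, assuming an NCHW output gradient `out_grad` and the
/// forward-pass weights `weights` (not compiled as a doc test):
///
/// ```ignore
/// // Gradient w.r.t. the input; `in_shape` is the NCHW shape of the forward input.
/// let input_grad =
///     conv_data_backward::<R, 2>(out_grad, weights, in_shape, options, ConvStrategy::default())?;
/// ```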
pub fn conv_data_backward<R: CubeRuntime, const N: usize>(
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    in_shape: Shape,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let out_grad = permute_nchw_to_nhwc(out_grad);
    let weights = permute_nchw_to_nhwc(weights);
    let in_shape = permute_nchw_to_nhwc_shape(in_shape);

    let input_grad = match strategy {
        ConvStrategy::Direct => {
            conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
        }
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => dgrad_autotune::<R, N>(out_grad, weights, in_shape, options),
        ConvStrategy::ImplicitGemm => {
            // The implicit GEMM kernel supports neither grouped nor strided
            // convolutions; fall back to the direct implementation for those.
            if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
                conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
            } else {
                dgrad_gemm_simple_sync::<R, N>(
                    out_grad,
                    weights,
                    in_shape,
                    options,
                    AcceleratedTileKind::Cmma,
                )?
            }
        }
    };

    Ok(permute_nhwc_to_nchw(input_grad))
}