use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::convolution::{AcceleratedTileKind, components::ConvSetupError};

#[cfg(feature = "autotune")]
use crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune};
use crate::{
    CubeRuntime,
    kernel::conv::{
        backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
        backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
        forward::implicit_gemm::conv_gemm_simple_sync,
    },
    ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw},
    tensor::CubeTensor,
};

use super::conv_direct;
#[cfg(feature = "autotune")]
use super::forward::conv_autotune;

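/// Strategy used to select the convolution kernel implementation.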
pub enum ConvStrategy {
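    /// Use the direct convolution kernel.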
    Direct,
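    /// Select the fastest available kernel for each problem by runtime autotuning.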
    #[cfg(feature = "autotune")]
    Autotune,
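    /// Use the implicit GEMM kernel with accelerated (CMMA) tiles, falling
    /// back to a reference kernel for unsupported configurations such as
    /// grouped convolutions.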
    ImplicitGemm,
}

impl Default for ConvStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return ConvStrategy::Autotune;

        #[cfg(not(feature = "autotune"))]
        ConvStrategy::Direct
    }
}

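/// Perform a convolution forward pass on a channels-first (NCHW) input and
/// weight, dispatching to the kernel selected by `strategy`.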
pub fn conv_forward<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let input = permute_nchw_to_nhwc(input);
    let weight = permute_nchw_to_nhwc(weight);

    let out = conv_forward_nhwc(input, weight, bias, options, strategy)?;

    Ok(permute_nhwc_to_nchw(out))
}

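/// Perform a convolution forward pass on a channels-last (NHWC) input and
/// weight, dispatching to the kernel selected by `strategy`.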
pub fn conv_forward_nhwc<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    match strategy {
        ConvStrategy::Direct => conv_direct::<R, N>(input, weight, bias, options),
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => Ok(conv_autotune::<R, N>(input, weight, bias, options)),
        ConvStrategy::ImplicitGemm => {
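            // Grouped convolutions are not handled by the implicit GEMM path;
            // fall back to the direct kernel.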
            if options.groups != 1 {
                conv_direct::<R, N>(input, weight, bias, options)
            } else {
                conv_gemm_simple_sync::<R, N>(
                    input,
                    weight,
                    bias,
                    options,
                    AcceleratedTileKind::Cmma,
                )
            }
        }
    }
}

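/// Compute the gradient of the convolution weights from the channels-first
/// (NCHW) input and output gradient, dispatching on `strategy`.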
pub fn conv_weight_backward<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    out_grad: CubeTensor<R>,
    weight_shape: Shape,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let input = permute_nchw_to_nhwc(input);
    let out_grad = permute_nchw_to_nhwc(out_grad);
    let weight_shape = permute_nchw_to_nhwc_shape(weight_shape);

    let weight_grad = match strategy {
        ConvStrategy::Direct => {
            conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
        }
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => Ok(wgrad_autotune::<R, N>(
            input,
            out_grad,
            weight_shape,
            options,
        )),
        ConvStrategy::ImplicitGemm => {
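            // Grouped convolutions are not handled by the implicit GEMM path;
            // fall back to the reference kernel.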
            if options.groups != 1 {
                conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)
            } else {
                wgrad_gemm_simple_sync::<R, N>(
                    input,
                    out_grad,
                    weight_shape,
                    options,
                    AcceleratedTileKind::Cmma,
                )
            }
        }
    }?;

    Ok(permute_nhwc_to_nchw(weight_grad))
}

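/// Compute the gradient of the convolution input (data) from the
/// channels-first (NCHW) output gradient and weights, dispatching on `strategy`.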
pub fn conv_data_backward<R: CubeRuntime, const N: usize>(
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    in_shape: Shape,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let out_grad = permute_nchw_to_nhwc(out_grad);
    let weights = permute_nchw_to_nhwc(weights);
    let in_shape = permute_nchw_to_nhwc_shape(in_shape);

    let data_grad = match strategy {
        ConvStrategy::Direct => {
            conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
        }
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => dgrad_autotune::<R, N>(out_grad, weights, in_shape, options),
        ConvStrategy::ImplicitGemm => {
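            // The implicit GEMM path handles neither grouped nor strided
            // convolutions here; fall back to the reference kernel.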
            if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
                conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
            } else {
                dgrad_gemm_simple_sync::<R, N>(
                    out_grad,
                    weights,
                    in_shape,
                    options,
                    AcceleratedTileKind::Cmma,
                )?
            }
        }
    };

    Ok(permute_nhwc_to_nchw(data_grad))
}