use crate::{
    AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
    backward_data::args::ConcreteArgs,
    components::{ConvGemmConfig as _, ConvolutionOperation},
    kernels::forward::simple::*,
};
use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    kernels::forward::algorithm::Algorithm,
};
use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{CubeOption, tensor::TensorHandle},
};
use cubek_matmul::{
    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
    definition::{AvailableLineSizes, MatmulElems, MatmulSetupError, MatrixLayout},
    launch::{MatmulInputHandle, MatmulInputHandleRef},
};
use derive_new::new;

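/// Monomorphizes `$launch` over the accelerated tile matmul: `$T` is aliased to
/// [`CmmaMatmul`] or [`MmaMatmul`] depending on the runtime [`AcceleratedTileKind`].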
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}

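/// Launches a backward-data (input gradient) convolution from owned handles.
/// Thin wrapper around [`launch_ref`].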
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: MatmulInputHandle<R>,
    weights: MatmulInputHandle<R>,
    in_grad: TensorHandle<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    launch_ref(
        strategy,
        client,
        &out_grad.as_ref(),
        &weights.as_ref(),
        &in_grad.as_ref(),
        args,
        dtypes,
    )
}

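/// Launches a backward-data convolution from borrowed handles. `out_grad` and
/// `weights` form the inputs of the underlying implicit GEMM; `in_grad` receives
/// the computed input gradient. Currently only the `Simple` strategy is
/// supported, specialized per reading strategy and accelerated tile kind.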
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided =>
                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
                Box::new("Data backprop doesn't yet work with the current TMA tiling strategy")
            ))),
        }),
    }
}

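/// Everything needed to launch a backward-data convolution, bundled so that the
/// strategy dispatch above only has to supply the algorithm type parameter.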
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    weights: &'a MatmulInputHandleRef<'a, R>,
    in_grad: &'a TensorHandleRef<'a, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

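        // Map the compile-time spatial rank onto the runtime dimensionality tag.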
        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.out_grad,
            self.weights,
            self.in_grad,
            (&stride, &padding, &dilation),
            dimensionality,
            self.dtypes,
        )
    }
}

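/// Derives the [`ConvolutionProblem`] for the backward-data pass from the tensor
/// shapes (batch-first, channels-last) and hands it to [`launch_kernel`].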
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let rank = in_grad.shape.len();
    let dim_c = rank - 1;

    let n = in_grad.shape[0];
    let c = in_grad.shape[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &in_grad.shape[1..dim_c];
    let kernel_shape = &weights.shape()[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

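    // Let the algorithm convert both GEMM inputs into its preferred device
    // representation, then re-point the handle refs at the converted data.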
    let op = ConvolutionOperation::BackwardData;

    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.lhs_global, op)?;
    let weights_data = Alg::into_tensor_handle(client, weights.data(), dtypes.rhs_global, op)?;

    let mut out_grad = *out_grad;
    let mut weights = *weights;

    *out_grad.data_mut() = out_grad_data.as_ref();
    *weights.data_mut() = weights_data.as_ref();

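    // Implicit GEMM view of backward data: rows are input positions
    // (batch x spatial), columns are input channels, and the reduction runs
    // over output channels x kernel window.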
    let problem = ConvolutionProblem {
        m: n * in_shape.iter().product::<usize>(),
        n: c,
        k: out_c * kernel_shape.iter().product::<usize>(),

        lhs_strides: out_grad.data().strides.to_vec(),
        rhs_strides: weights.data().strides.to_vec(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        padded_channels: out_c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(client, &out_grad, &weights, in_grad, problem, dtypes)
}

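/// Resolves line sizes, the tile selection, and the kernel config for `problem`,
/// then dispatches to the concrete backward-data launch.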
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    mut dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let plane_dim = client.properties().hardware.plane_size_max;
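    // Start from the line sizes the element types allow, then keep only those
    // compatible with each tensor's strides, shape, and layout.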
    let line_sizes = AvailableLineSizes::from_type_sizes(
        client,
        out_grad.data().elem_size,
        weights.data().elem_size,
        in_grad.elem_size,
    )
    .filter_lhs_with_tensor(
        out_grad.data().strides,
        out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        weights.data().strides,
        weights.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(in_grad.strides, in_grad.shape);

    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

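    // Select tiling parameters for this problem, then give the algorithm's
    // argument type a chance to adjust the problem and dtypes before expanding
    // the final config.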
    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
    let problem = Alg::Args::adjust_problem(client, problem, &selection, &dtypes);

    let device_props = client.properties();
    let config = Alg::expand_config(device_props, &problem, &selection, &line_sizes, &dtypes)?;

    // The config may have refined the line sizes; use its final values.
    let line_sizes = config.line_sizes();

    launch_kernel_concrete::<R, Alg>(
        client, out_grad, weights, in_grad, problem, line_sizes, selection, &dtypes,
    )
}