1use crate::{
2 AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
3 backward_data::args::ConcreteArgs,
4 components::{ConvolutionOperation, global::args::RuntimeArgs},
5 kernels::algorithm::simple::*,
6};
7use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
8use crate::{
9 components::{ConvolutionProblem, Dimensionality},
10 kernels::algorithm::Algorithm,
11};
12use cubecl::{
13 Runtime,
14 client::ComputeClient,
15 prelude::*,
16 std::{CubeOption, tensor::TensorHandle},
17};
18use cubek_matmul::{
19 components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
20 definition::{AvailableLineSizes, MatmulElems, MatmulSetupError, MatrixLayout},
21 launch::{MatmulInputHandle, MatmulInputHandleRef},
22 routines::BlueprintStrategy,
23};
24use derive_new::new;
25
/// Binds the type alias `$T` to the concrete accelerated tile-matmul type
/// selected by `$kind`, then invokes the `$launch` closure with that alias in
/// scope.
///
/// This must be a macro (not a generic function): the two arms bind `$T` to
/// different concrete types, and the caller's closure body is re-expanded
/// (and separately type-checked) under each binding.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // Cooperative-matrix (CMMA) tile matmul.
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                // Raw MMA tile matmul with strided lhs/rhs IO.
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}
40
41#[allow(clippy::result_large_err, clippy::too_many_arguments)]
42pub fn launch<R: Runtime, const N_SPATIAL: usize>(
43 strategy: &Strategy,
44 client: &ComputeClient<R>,
45 out_grad: MatmulInputHandle<R>,
46 weights: MatmulInputHandle<R>,
47 in_grad: TensorHandle<R>,
48 args: ConvolutionArgs<N_SPATIAL>,
49 dtypes: MatmulElems,
50) -> Result<(), ConvSetupError> {
51 launch_ref(
52 strategy,
53 client,
54 &out_grad.as_ref(),
55 &weights.as_ref(),
56 &in_grad.as_ref(),
57 args,
58 dtypes,
59 )
60}
61
/// Launches the data-backward (input-gradient) convolution kernel for
/// borrowed tensor handles.
///
/// Dispatches on `strategy`: `tile_kind` selects the accelerated tile matmul
/// (CMMA vs. MMA) via `with_tile_kind!`, and `read_strategy` selects the
/// concrete simple-convolution algorithm. The TMA reading strategy is
/// rejected with a setup error — not supported for data backprop yet.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    // Capture all launch state once; `launch` consumes it in each match arm.
    let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided =>
                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
                Box::new("Data backprop doesn't yet work with current TMA tiling strategy")
            ))),
        }),
    }
}
99
/// Bundle of everything needed to launch one data-backward convolution.
/// `#[derive(new)]` generates a `new` constructor taking the fields in
/// declaration order — keep that order stable for callers.
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    // Gradient w.r.t. the forward output; becomes the matmul lhs operand.
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    // Forward weights; become the matmul rhs operand.
    weights: &'a MatmulInputHandleRef<'a, R>,
    // Destination buffer: gradient w.r.t. the forward input.
    in_grad: &'a TensorHandleRef<'a, R>,
    // Per-spatial-dimension stride / padding / dilation.
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}
109
110impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
111 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
112 where
113 Alg::Args: ConcreteArgs<Alg::Routine>,
114 {
115 let ConvolutionArgs {
116 stride,
117 padding,
118 dilation,
119 } = self.args;
120
121 let dimensionality = match N_SPATIAL {
122 1 => Dimensionality::Dim1,
123 2 => Dimensionality::Dim2,
124 3 => Dimensionality::Dim3,
125 other => unimplemented!("Unsupported dimensionality {other}"),
126 };
127
128 launch_with_algorithm::<R, Alg>(
129 self.client,
130 self.out_grad,
131 self.weights,
132 self.in_grad,
133 (&stride, &padding, &dilation),
134 dimensionality,
135 &BlueprintStrategy::Inferred(Default::default()),
136 self.dtypes,
137 )
138 }
139}
140
/// Builds the implicit-GEMM problem description for data backprop and
/// forwards it to [`launch_kernel`].
///
/// Tensors are treated as channels-last: `shape[0]` is the batch, the final
/// dimension is channels, and everything in between is spatial extent.
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    let rank = in_grad.shape.len();
    // Index of the channel dimension (channels-last layout).
    let dim_c = rank - 1;

    let n = in_grad.shape[0];
    let c = in_grad.shape[dim_c];

    let out_c = out_grad.shape()[dim_c];

    // Spatial extents only: everything between batch and channels.
    let in_shape = &in_grad.shape[1..dim_c];
    let kernel_shape = &weights.shape()[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardData;

    // Algorithm-specific operand preparation — presumably casting and/or
    // re-laying-out to the global matmul dtypes (TODO confirm against
    // `Algorithm::into_tensor_handle`).
    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.lhs_global, op)?;
    let weights_data = Alg::into_tensor_handle(client, weights.data(), dtypes.rhs_global, op)?;

    // Shadow the incoming refs with copies that point at the converted data.
    // `out_grad_data`/`weights_data` stay alive until the launch below, which
    // keeps these borrowed views valid.
    let mut out_grad = *out_grad;
    let mut weights = *weights;

    *out_grad.data_mut() = out_grad_data.as_ref();
    *weights.data_mut() = weights_data.as_ref();

    // Implicit-GEMM view of data backprop:
    //   m = batch * input spatial elements (rows of the input gradient)
    //   n = input channels
    //   k = output channels * kernel taps (reduction axis)
    let problem = ConvolutionProblem {
        m: n * in_shape.iter().product::<usize>(),
        n: c,
        k: out_c * kernel_shape.iter().product::<usize>(),

        // Strides come from the converted handles so they reflect any
        // re-layout performed above.
        lhs_strides: out_grad.data().strides.to_vec(),
        rhs_strides: weights.data().strides.to_vec(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        // The problem stores padding as signed values (i32), unlike the
        // unsigned stride/dilation.
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        // NOTE(review): padded_channels mirrors the reduction-side channel
        // count (`out_c`) here — confirm this matches how the forward pass
        // pads channels.
        padded_channels: out_c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(
        client,
        &out_grad,
        &weights,
        in_grad,
        problem,
        blueprint_strategy,
        dtypes,
    )
}
215
216#[allow(clippy::result_large_err, clippy::too_many_arguments)]
217pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
218 client: &ComputeClient<R>,
219 out_grad: &MatmulInputHandleRef<'_, R>,
220 weights: &MatmulInputHandleRef<'_, R>,
221 in_grad: &TensorHandleRef<'_, R>,
222 problem: ConvolutionProblem,
223 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
224 dtypes: MatmulElems,
225) -> Result<(), ConvSetupError>
226where
227 Alg::Args: ConcreteArgs<Alg::Routine>,
228{
229 let line_sizes = AvailableLineSizes::from_type_sizes(
232 client,
233 out_grad.data().elem_size,
234 weights.data().elem_size,
235 in_grad.elem_size,
236 )
237 .filter_lhs_with_tensor(
238 out_grad.data().strides,
239 out_grad.data().shape,
240 MatrixLayout::RowMajor,
241 )
242 .filter_rhs_with_tensor(
243 weights.data().strides,
244 weights.data().shape,
245 MatrixLayout::RowMajor,
246 )
247 .filter_out_with_tensor(in_grad.strides, in_grad.shape);
248
249 let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
250
251 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
252 client,
253 out_grad,
254 weights,
255 in_grad,
256 problem,
257 line_sizes,
258 blueprint_strategy,
259 &dtypes,
260 )
261}