1use crate::{AcceleratedTileKind, ReadingStrategy};
2use crate::{
3 ConvolutionArgs, Strategy,
4 backward_weight::args::ConcreteArgs,
5 components::{ConvGemmConfig as _, ConvolutionOperation},
6 kernels::forward::simple::*,
7};
8use crate::{
9 components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
10};
11use crate::{
12 components::{ConvolutionProblem, Dimensionality},
13 kernels::forward::algorithm::Algorithm,
14};
15use cubecl::{
16 Runtime,
17 client::ComputeClient,
18 prelude::*,
19 std::{CubeOption, tensor::TensorHandle},
20};
21use cubek_matmul::definition::{AvailableLineSizes, MatmulElems, MatrixLayout};
22use cubek_matmul::launch::{MatmulInputHandle, MatmulInputHandleRef};
23use cubek_matmul::{
24 components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
25 definition,
26};
27use derive_new::new;
28
/// Expands `$body` with the type alias `$Tile` bound to the tile matmul
/// implementation matching the requested accelerated tile kind.
macro_rules! with_tile_kind {
    ($tile_kind: expr, $Tile: ident, $body: expr) => {
        match $tile_kind {
            AcceleratedTileKind::Cmma => {
                type $Tile = CmmaMatmul<CubeOption<Strided>>;
                ($body)()
            }
            AcceleratedTileKind::Mma => {
                type $Tile = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($body)()
            }
        }
    };
}
43
44#[allow(clippy::result_large_err, clippy::too_many_arguments)]
45pub fn launch<R: Runtime, const N_SPATIAL: usize>(
46 strategy: &Strategy,
47 client: &ComputeClient<R>,
48 input: MatmulInputHandle<R>,
49 out_grad: MatmulInputHandle<R>,
50 weight_grad: TensorHandle<R>,
51 args: ConvolutionArgs<N_SPATIAL>,
52 dtypes: MatmulElems,
53) -> Result<(), ConvSetupError> {
54 launch_ref(
55 strategy,
56 client,
57 &input.as_ref(),
58 &out_grad.as_ref(),
59 &weight_grad.as_ref(),
60 args,
61 dtypes,
62 )
63}
64
65#[allow(clippy::result_large_err, clippy::too_many_arguments)]
74pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
75 strategy: &Strategy,
76 client: &ComputeClient<R>,
77 input: &MatmulInputHandleRef<'_, R>,
78 out_grad: &MatmulInputHandleRef<'_, R>,
79 weight_grad: &TensorHandleRef<'_, R>,
80 args: ConvolutionArgs<N_SPATIAL>,
81 dtypes: MatmulElems,
82) -> Result<(), ConvSetupError> {
83 let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);
84
85 match strategy {
86 Strategy::Simple {
87 read_strategy,
88 tile_kind,
89 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
90 ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
91 ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
92 ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
93 ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
94 ReadingStrategy::AsyncStrided =>
95 backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
96 ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv<Accelerated>>(),
97 }),
98 }
99}
100
/// Bundles everything needed to launch one backward-weight convolution so the
/// strategy dispatch in `launch_ref` can select an algorithm type parameter
/// without re-threading seven arguments.
///
/// `#[derive(new)]` generates `BackwardsWeight::new` taking the fields in
/// declaration order — do not reorder fields without updating callers.
#[derive(new)]
struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
    /// Compute client the kernel is launched on.
    client: &'a ComputeClient<R>,
    /// Forward-pass input tensor (one GEMM operand).
    input: &'a MatmulInputHandleRef<'a, R>,
    /// Gradient w.r.t. the forward output (the other GEMM operand).
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    /// Destination buffer for the gradient w.r.t. the weights.
    weight_grad: &'a TensorHandleRef<'a, R>,
    /// Stride / padding / dilation for the `N_SPATIAL` spatial dimensions.
    args: ConvolutionArgs<N_SPATIAL>,
    /// Element types used by the underlying matmul.
    dtypes: MatmulElems,
}
110
111impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
112 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
113 where
114 Alg::Args: ConcreteArgs,
115 {
116 let ConvolutionArgs {
117 stride,
118 padding,
119 dilation,
120 } = self.args;
121
122 let dimensionality = match N_SPATIAL {
123 1 => Dimensionality::Dim1,
124 2 => Dimensionality::Dim2,
125 3 => Dimensionality::Dim3,
126 other => unimplemented!("Unsupported dimensionality {other}"),
127 };
128
129 launch_with_algorithm::<R, Alg>(
130 self.client,
131 self.input,
132 self.out_grad,
133 self.weight_grad,
134 (&stride, &padding, &dilation),
135 dimensionality,
136 self.dtypes,
137 )
138 }
139}
140
/// Builds the implicit-GEMM [`ConvolutionProblem`] describing a
/// backward-weight convolution and launches it with algorithm `Alg`.
///
/// Tensors are assumed channels-last: shape `[batch, spatial..., channels]`,
/// so index 0 is the batch axis and the last index is the channel axis.
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weight_grad: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let rank = input.data().shape.len();
    // Channels-last layout: the last axis is the channel dimension.
    let dim_c = rank - 1;

    let n = input.shape()[0]; // batch size
    let c = input.shape()[dim_c]; // input channels

    let out_c = out_grad.shape()[dim_c]; // output channels

    // Spatial extents exclude the batch (index 0) and channel (last) axes.
    let in_shape = &input.shape()[1..dim_c];
    let kernel_shape = &weight_grad.shape[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardWeight;

    // Let the algorithm convert each operand into the buffer/dtype it needs
    // for this pass; `lhs_global`/`rhs_global` are the target element types.
    let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.rhs_global, op)?;

    // Work on by-value copies of the handles and repoint their data at the
    // (possibly converted) buffers above, keeping the rest of the handle
    // untouched.
    let mut input = *input;
    let mut out_grad = *out_grad;

    *input.data_mut() = input_data.as_ref();
    *out_grad.data_mut() = out_grad_data.as_ref();

    // Implicit-GEMM view of the weight gradient:
    //   m = output channels, n = input channels x kernel positions,
    //   k = batch x output spatial positions.
    // NOTE(review): lhs strides come from `input` (ColMajor) and rhs strides
    // from `out_grad` (RowMajor), while `launch_kernel` filters lhs line
    // sizes against `out_grad` and rhs against `input` — confirm which
    // operand plays which GEMM role.
    let problem = ConvolutionProblem {
        m: out_c,
        n: c * kernel_shape.iter().product::<usize>(),
        k: n * out_shape.iter().product::<usize>(),
        lhs_strides: input.data().strides.to_vec(),
        rhs_strides: out_grad.data().strides.to_vec(),
        lhs_layout: definition::MatrixLayout::ColMajor,
        rhs_layout: definition::MatrixLayout::RowMajor,
        // Per-dimension conv parameters, widened/narrowed to the kernel's
        // expected integer types (padding may be negative, hence i32).
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        // No extra channel padding is applied at this level.
        padded_channels: c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(client, &input, &out_grad, weight_grad, problem, dtypes)
}
205
/// Resolves line sizes, the algorithm's tuning selection, and the expanded
/// config for `Alg`, then launches the concrete backward-weight kernel.
///
/// `dtypes` is taken by value and may be adjusted in place by
/// `Alg::selection` before being used for the remaining setup steps.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weight_grad: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    mut dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    // Maximum plane (warp/subgroup) size reported by the hardware.
    let plane_dim = client.properties().hardware.plane_size_max;
    // Candidate vectorization widths, constrained by each operand's element
    // size and then by each tensor's actual strides/shape.
    // NOTE(review): lhs is filtered against `out_grad` and rhs against
    // `input` — the reverse of the argument order — presumably because
    // `out_grad` is the GEMM lhs in the backward-weight pass; confirm.
    let line_sizes = AvailableLineSizes::from_type_sizes(
        client,
        input.data().elem_size,
        out_grad.data().elem_size,
        weight_grad.elem_size,
    )
    .filter_lhs_with_tensor(
        out_grad.data().strides,
        out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        input.data().strides,
        input.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(weight_grad.strides, weight_grad.shape);

    // Let the algorithm veto candidates, then keep the largest remaining.
    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

    // `selection` may rewrite `dtypes`; the (possibly adjusted) problem must
    // therefore be derived afterwards.
    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
    let problem = Alg::Args::adjust_problem(client, problem, &selection, &dtypes);

    let config = Alg::expand_config(
        client.properties(),
        &problem,
        &selection,
        &line_sizes,
        &dtypes,
    )?;

    // Use the line sizes the config actually settled on — these may differ
    // from the ones picked above.
    let line_sizes = config.line_sizes();

    launch_kernel_concrete::<R, Alg>(
        client,
        input,
        out_grad,
        weight_grad,
        problem,
        line_sizes,
        selection,
        &dtypes,
    )
}