1use crate::components::{ConvolutionProblem, Dimensionality};
2use crate::{
3 AcceleratedTileKind, ReadingStrategy, algorithm::Algorithm,
4 components::global::args::RuntimeArgs,
5};
6use crate::{
7 ConvolutionArgs, Strategy, backward_weight::args::ConcreteArgs,
8 components::ConvolutionOperation, kernels::algorithm::simple::*,
9};
10use crate::{
11 components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
12};
13use cubecl::{
14 Runtime,
15 client::ComputeClient,
16 prelude::*,
17 std::{CubeOption, tensor::TensorHandle},
18};
19use cubek_matmul::launch::{MatmulInputHandle, MatmulInputHandleRef};
20use cubek_matmul::{
21 components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
22 definition,
23};
24use cubek_matmul::{
25 definition::{AvailableLineSizes, MatmulElems, MatrixLayout},
26 routines::BlueprintStrategy,
27};
28use derive_new::new;
29
/// Binds the type alias `$T` to the concrete accelerated tile-matmul type
/// selected by `$kind`, then invokes the `$launch` closure with that alias in
/// scope.
///
/// A macro (rather than a generic function) is required because the two arms
/// bind `$T` to different concrete types, which cannot be expressed with a
/// single generic parameter.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // CMMA path: accumulator tile is optionally strided.
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                // MMA path: strided lhs/rhs tiles, optionally strided accumulator.
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}
44
45#[allow(clippy::result_large_err, clippy::too_many_arguments)]
46pub fn launch<R: Runtime, const N_SPATIAL: usize>(
47 strategy: &Strategy,
48 client: &ComputeClient<R>,
49 input: MatmulInputHandle<R>,
50 out_grad: MatmulInputHandle<R>,
51 weight_grad: TensorHandle<R>,
52 args: ConvolutionArgs<N_SPATIAL>,
53 dtypes: MatmulElems,
54) -> Result<(), ConvSetupError> {
55 launch_ref(
56 strategy,
57 client,
58 &input.as_ref(),
59 &out_grad.as_ref(),
60 &weight_grad.as_ref(),
61 args,
62 dtypes,
63 )
64}
65
66#[allow(clippy::result_large_err, clippy::too_many_arguments)]
75pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
76 strategy: &Strategy,
77 client: &ComputeClient<R>,
78 input: &MatmulInputHandleRef<'_, R>,
79 out_grad: &MatmulInputHandleRef<'_, R>,
80 weight_grad: &TensorHandleRef<'_, R>,
81 args: ConvolutionArgs<N_SPATIAL>,
82 dtypes: MatmulElems,
83) -> Result<(), ConvSetupError> {
84 let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);
85
86 match strategy {
87 Strategy::Simple {
88 read_strategy,
89 tile_kind,
90 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
91 ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
92 ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
93 ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
94 ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
95 ReadingStrategy::AsyncStrided =>
96 backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
97 ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv<Accelerated>>(),
98 }),
99 }
100}
101
/// Bundles everything needed for one backward-weight convolution launch, so
/// the algorithm-generic [`BackwardsWeight::launch`] can be called from the
/// monomorphized `with_tile_kind!` arms without re-threading seven arguments.
///
/// Field order matters: `#[derive(new)]` generates `new()` with parameters in
/// declaration order.
#[derive(new)]
struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    // Forward-pass input tensor; one operand of the implicit matmul.
    input: &'a MatmulInputHandleRef<'a, R>,
    // Gradient w.r.t. the convolution output; the other matmul operand.
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    // Destination tensor: gradient w.r.t. the weights.
    weight_grad: &'a TensorHandleRef<'a, R>,
    // Stride / padding / dilation for each of the N_SPATIAL spatial dims.
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}
111
112impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
113 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
114 where
115 Alg::Args: ConcreteArgs<Alg::Routine>,
116 {
117 let ConvolutionArgs {
118 stride,
119 padding,
120 dilation,
121 } = self.args;
122
123 let dimensionality = match N_SPATIAL {
124 1 => Dimensionality::Dim1,
125 2 => Dimensionality::Dim2,
126 3 => Dimensionality::Dim3,
127 other => unimplemented!("Unsupported dimensionality {other}"),
128 };
129
130 launch_with_algorithm::<R, Alg>(
131 self.client,
132 self.input,
133 self.out_grad,
134 self.weight_grad,
135 (&stride, &padding, &dilation),
136 dimensionality,
137 self.dtypes,
138 )
139 }
140}
141
142#[allow(clippy::too_many_arguments)]
143fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
144 client: &ComputeClient<R>,
145 input: &MatmulInputHandleRef<'_, R>,
146 out_grad: &MatmulInputHandleRef<'_, R>,
147 weight_grad: &TensorHandleRef<'_, R>,
148 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
149 dimensionality: Dimensionality,
150 dtypes: MatmulElems,
151) -> Result<(), ConvSetupError>
152where
153 Alg::Args: ConcreteArgs<Alg::Routine>,
154{
155 let rank = input.data().shape.len();
156 let dim_c = rank - 1;
157
158 let n = input.shape()[0];
159 let c = input.shape()[dim_c];
160
161 let out_c = out_grad.shape()[dim_c];
162
163 let in_shape = &input.shape()[1..dim_c];
164 let kernel_shape = &weight_grad.shape[1..dim_c];
165 let out_shape = &out_grad.shape()[1..dim_c];
166
167 let op = ConvolutionOperation::BackwardWeight;
168
169 let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
170 let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.rhs_global, op)?;
171
172 let mut input = *input;
173 let mut out_grad = *out_grad;
174
175 *input.data_mut() = input_data.as_ref();
176 *out_grad.data_mut() = out_grad_data.as_ref();
177
178 let problem = ConvolutionProblem {
179 m: out_c,
180 n: c * kernel_shape.iter().product::<usize>(),
181 k: n * out_shape.iter().product::<usize>(),
182 lhs_strides: input.data().strides.to_vec(),
183 rhs_strides: out_grad.data().strides.to_vec(),
184 lhs_layout: definition::MatrixLayout::ColMajor,
185 rhs_layout: definition::MatrixLayout::RowMajor,
186 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
187 stride: stride.iter().map(|it| *it as u32).collect(),
188 padding: padding.iter().map(|it| *it as i32).collect(),
189 dilation: dilation.iter().map(|it| *it as u32).collect(),
190
191 batches: n,
192 in_shape: in_shape.to_vec(),
193 out_shape: out_shape.to_vec(),
194 channels: c,
195 out_channels: out_c,
196
197 padded_channels: c,
198 operation: op,
199
200 dimensionality,
201 global_dtypes: dtypes.as_global_elems(),
202 };
203
204 launch_kernel::<R, Alg>(
205 client,
206 &input,
207 &out_grad,
208 weight_grad,
209 problem,
210 &BlueprintStrategy::Inferred(Default::default()),
211 dtypes,
212 )
213}
214
215#[allow(clippy::result_large_err, clippy::too_many_arguments)]
216pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
217 client: &ComputeClient<R>,
218 input: &MatmulInputHandleRef<'_, R>,
219 out_grad: &MatmulInputHandleRef<'_, R>,
220 weight_grad: &TensorHandleRef<'_, R>,
221 problem: ConvolutionProblem,
222 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
223 dtypes: MatmulElems,
224) -> Result<(), ConvSetupError>
225where
226 Alg::Args: ConcreteArgs<Alg::Routine>,
227{
228 let line_sizes = AvailableLineSizes::from_type_sizes(
231 client,
232 input.data().elem_size,
233 out_grad.data().elem_size,
234 weight_grad.elem_size,
235 )
236 .filter_lhs_with_tensor(
237 out_grad.data().strides,
238 out_grad.data().shape,
239 MatrixLayout::RowMajor,
240 )
241 .filter_rhs_with_tensor(
242 input.data().strides,
243 input.data().shape,
244 MatrixLayout::RowMajor,
245 )
246 .filter_out_with_tensor(weight_grad.strides, weight_grad.shape);
247
248 let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
249
250 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
251 client,
252 input,
253 out_grad,
254 weight_grad,
255 problem,
256 line_sizes,
257 blueprint_strategy,
258 &dtypes,
259 )
260}