1use crate::components::{ConvolutionProblem, Dimensionality};
2use crate::{
3 AcceleratedTileKind, ReadingStrategy, algorithm::Algorithm,
4 components::global::args::RuntimeArgs,
5};
6use crate::{
7 ConvolutionArgs, Strategy, backward_weight::args::ConcreteArgs,
8 components::ConvolutionOperation, kernels::algorithm::simple::*,
9};
10use crate::{
11 components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
12};
13use cubecl::{Runtime, client::ComputeClient, prelude::*};
14use cubek_matmul::components::tile::{cmma::CmmaMatmul, mma::MmaMatmul};
15use cubek_matmul::{
16 definition::{AvailableVectorSizes, MatmulElems},
17 routines::BlueprintStrategy,
18};
19use cubek_std::{
20 InputBinding,
21 {MatrixLayout, tile::Strided},
22};
23use derive_new::new;
24
/// Binds `$Tile` to the concrete accelerated tile-matmul type selected by
/// `$tile_kind` (CMMA or MMA), then evaluates the `$body` closure with that
/// type alias in scope. Used to monomorphize launches over the tile kind.
macro_rules! with_tile_kind {
    ($tile_kind:expr, $Tile:ident, $body:expr) => {
        match $tile_kind {
            AcceleratedTileKind::Mma => {
                type $Tile = MmaMatmul<Strided, Strided, Option<Strided>>;
                ($body)()
            }
            AcceleratedTileKind::Cmma => {
                type $Tile = CmmaMatmul<Option<Strided>>;
                ($body)()
            }
        }
    };
}
39
40#[allow(clippy::result_large_err, clippy::too_many_arguments)]
49pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
50 strategy: &Strategy,
51 client: &ComputeClient<R>,
52 input: InputBinding<R>,
53 out_grad: InputBinding<R>,
54 weight_grad: TensorBinding<R>,
55 args: ConvolutionArgs<N_SPATIAL>,
56 dtypes: MatmulElems,
57) -> Result<(), ConvSetupError> {
58 let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);
59
60 match strategy {
61 Strategy::Simple {
62 read_strategy,
63 tile_kind,
64 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
65 ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
66 ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
67 ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
68 ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
69 ReadingStrategy::AsyncStrided =>
70 backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
71 ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv<Accelerated>>(),
72 }),
73 }
74}
75
/// Launch context for a backward-weight convolution: the compute client plus
/// all tensor bindings and parameters needed to dispatch the kernel.
///
/// NOTE: field order matters — `#[derive(new)]` generates `new()` with
/// parameters in declaration order, and `launch_ref` relies on that order.
#[derive(new)]
struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
    /// Client used to compile and dispatch the kernel.
    client: &'a ComputeClient<R>,
    /// Forward-pass input tensor (read-only).
    input: InputBinding<R>,
    /// Gradient of the convolution output (read-only).
    out_grad: InputBinding<R>,
    /// Destination binding for the computed weight gradient.
    weight_grad: TensorBinding<R>,
    /// Stride/padding/dilation for the `N_SPATIAL` spatial dimensions.
    args: ConvolutionArgs<N_SPATIAL>,
    /// Element types used at each stage of the underlying matmul.
    dtypes: MatmulElems,
}
85
86impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
87 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
88 where
89 Alg::Args: ConcreteArgs<Alg::Routine>,
90 {
91 let ConvolutionArgs {
92 stride,
93 padding,
94 dilation,
95 } = self.args;
96
97 let dimensionality = match N_SPATIAL {
98 1 => Dimensionality::Dim1,
99 2 => Dimensionality::Dim2,
100 3 => Dimensionality::Dim3,
101 other => unimplemented!("Unsupported dimensionality {other}"),
102 };
103
104 launch_with_algorithm::<R, Alg>(
105 self.client,
106 self.input,
107 self.out_grad,
108 self.weight_grad,
109 (&stride, &padding, &dilation),
110 dimensionality,
111 self.dtypes,
112 )
113 }
114}
115
/// Builds the GEMM-view `ConvolutionProblem` for a backward-weight pass and
/// launches it via `launch_kernel`.
///
/// Tensors are channels-last: channels live in the final dimension
/// (`dim_c = rank - 1`), batch in dimension 0, and the spatial dims in
/// `1..dim_c`. Problem sizes are `m = out_channels`,
/// `n = channels * prod(kernel_spatial)`, `k = batches * prod(out_spatial)`.
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    out_grad: InputBinding<R>,
    weight_grad: TensorBinding<R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    let rank = input.data().shape.len();
    // Channels-last: the channel axis is the last one.
    let dim_c = rank - 1;

    let n = input.shape()[0];
    let c = input.shape()[dim_c];

    let out_c = out_grad.shape()[dim_c];

    // Spatial extents only (batch and channel axes stripped). These borrow
    // the ORIGINAL bindings and stay alive until the problem is built below,
    // which is why the post-layout bindings must be fresh clones, not moves.
    let in_shape = &input.shape()[1..dim_c];
    let kernel_shape = &weight_grad.shape[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardWeight;

    // Let the algorithm rewrite each operand into its preferred memory
    // layout; `into_data` consumes, so it operates on clones.
    let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
    let out_grad_data =
        Alg::correct_layout(client, out_grad.clone().into_data(), dtypes.rhs_global, op)?;

    // Shadow with clones (the originals are still borrowed by the shape
    // slices above) and splice in the layout-corrected data.
    let mut input = input.clone();
    let mut out_grad = out_grad.clone();

    *input.data_mut() = input_data;
    *out_grad.data_mut() = out_grad_data;

    // Widest address type required by any operand.
    let address_type = input
        .required_address_type()
        .max(out_grad.required_address_type())
        .max(weight_grad.required_address_type(dtypes.acc_global.size()));

    let problem = ConvolutionProblem {
        m: out_c,
        n: c * kernel_shape.iter().product::<usize>(),
        k: n * out_shape.iter().product::<usize>(),
        // NOTE(review): lhs strides come from `input` (col-major view) and
        // rhs from `out_grad`, while `launch_kernel` filters lhs vector
        // sizes against `out_grad` — presumably the GEMM lhs/rhs naming
        // differs between the two layers; verify the mapping is intended.
        lhs_strides: input.data().strides.clone(),
        rhs_strides: out_grad.data().strides.clone(),
        lhs_layout: MatrixLayout::ColMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        // Shapes taken from the ORIGINAL (pre-layout-correction) bindings.
        in_shape: in_shape.into(),
        out_shape: out_shape.into(),
        channels: c,
        out_channels: out_c,

        // No channel padding is applied here; padded == logical count.
        padded_channels: c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
        address_type,
    };

    launch_kernel::<R, Alg>(
        client,
        input,
        out_grad,
        weight_grad,
        problem,
        // Let the routine infer its own blueprint.
        &BlueprintStrategy::Inferred(Default::default()),
        dtypes,
    )
}
195
196#[allow(clippy::result_large_err, clippy::too_many_arguments)]
197pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
198 client: &ComputeClient<R>,
199 input: InputBinding<R>,
200 out_grad: InputBinding<R>,
201 weight_grad: TensorBinding<R>,
202 problem: ConvolutionProblem,
203 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
204 dtypes: MatmulElems,
205) -> Result<(), ConvSetupError>
206where
207 Alg::Args: ConcreteArgs<Alg::Routine>,
208{
209 let vector_sizes = AvailableVectorSizes::from_type_sizes(
212 client,
213 input.data_elem_size(),
214 out_grad.data_elem_size(),
215 dtypes.acc_global.size(),
216 )
217 .filter_lhs_with_tensor(
218 &out_grad.data().strides,
219 &out_grad.data().shape,
220 MatrixLayout::RowMajor,
221 )
222 .filter_rhs_with_tensor(
223 &input.data().strides,
224 &input.data().shape,
225 MatrixLayout::RowMajor,
226 )
227 .filter_out_with_tensor(&weight_grad.strides, &weight_grad.shape);
228
229 let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
230
231 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
232 client,
233 input,
234 out_grad,
235 weight_grad,
236 problem,
237 vector_sizes,
238 blueprint_strategy,
239 &dtypes,
240 )
241}