1use crate::components::{ConvolutionProblem, Dimensionality};
2use crate::{
3 AcceleratedTileKind, ReadingStrategy, algorithm::Algorithm,
4 components::global::args::RuntimeArgs,
5};
6use crate::{
7 ConvolutionArgs, Strategy, backward_weight::args::ConcreteArgs,
8 components::ConvolutionOperation, kernels::algorithm::simple::*,
9};
10use crate::{
11 components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
12};
13use cubecl::{Runtime, client::ComputeClient, prelude::*};
14use cubek_matmul::components::tile_matmul::DispatchTileMatmul;
15use cubek_matmul::{
16 definition::{AvailableVectorSizes, MatmulElems},
17 routines::{BlueprintStrategy, Routine, TilingArgs},
18};
19use cubek_std::{InputBinding, MatrixLayout};
20use derive_new::new;
21
22fn tile_kind_to_dispatch(kind: &AcceleratedTileKind) -> DispatchTileMatmul {
23 match kind {
24 AcceleratedTileKind::Cmma => DispatchTileMatmul::Cmma,
25 AcceleratedTileKind::Mma => DispatchTileMatmul::Mma,
26 }
27}
28
29#[allow(clippy::result_large_err, clippy::too_many_arguments)]
38pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
39 strategy: &Strategy,
40 client: &ComputeClient<R>,
41 input: InputBinding<R>,
42 out_grad: InputBinding<R>,
43 weight_grad: TensorBinding<R>,
44 args: ConvolutionArgs<N_SPATIAL>,
45 dtypes: MatmulElems,
46) -> Result<(), ConvSetupError> {
47 let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);
48
49 match strategy {
50 Strategy::Simple {
51 read_strategy,
52 tile_kind,
53 } => {
54 let kind = tile_kind_to_dispatch(tile_kind);
55 match read_strategy {
56 ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv>(kind),
57 ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv>(kind),
58 ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv>(kind),
59 ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv>(kind),
60 ReadingStrategy::AsyncStrided => backprop.launch::<SimpleAsyncStridedConv>(kind),
61 ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv>(kind),
62 }
63 }
64 }
65}
66
/// Bundles all state needed to launch one backward-weight convolution so
/// the generic [`BackwardsWeight::launch`] can be dispatched per algorithm.
///
/// Constructed via the `derive(new)`-generated
/// `new(client, input, out_grad, weight_grad, args, dtypes)` constructor;
/// field order therefore fixes the constructor's argument order.
#[derive(new)]
struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
    // Compute client the kernel is launched on.
    client: &'a ComputeClient<R>,
    // Forward-pass input tensor binding.
    input: InputBinding<R>,
    // Gradient w.r.t. the layer output, flowing backward.
    out_grad: InputBinding<R>,
    // Destination: gradient w.r.t. the convolution weights.
    weight_grad: TensorBinding<R>,
    // Stride / padding / dilation, one entry per spatial dimension.
    args: ConvolutionArgs<N_SPATIAL>,
    // Element types used at each stage of the underlying matmul.
    dtypes: MatmulElems,
}
76
77impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
78 fn launch<Alg: Algorithm>(self, tile_matmul: DispatchTileMatmul) -> Result<(), ConvSetupError>
79 where
80 Alg::Args: ConcreteArgs<Alg::Routine>,
81 <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
82 {
83 let ConvolutionArgs {
84 stride,
85 padding,
86 dilation,
87 } = self.args;
88
89 let dimensionality = match N_SPATIAL {
90 1 => Dimensionality::Dim1,
91 2 => Dimensionality::Dim2,
92 3 => Dimensionality::Dim3,
93 other => unimplemented!("Unsupported dimensionality {other}"),
94 };
95
96 launch_with_algorithm::<R, Alg>(
97 self.client,
98 self.input,
99 self.out_grad,
100 self.weight_grad,
101 (&stride, &padding, &dilation),
102 dimensionality,
103 tile_matmul,
104 self.dtypes,
105 )
106 }
107}
108
109#[allow(clippy::too_many_arguments)]
110fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
111 client: &ComputeClient<R>,
112 input: InputBinding<R>,
113 out_grad: InputBinding<R>,
114 weight_grad: TensorBinding<R>,
115 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
116 dimensionality: Dimensionality,
117 tile_matmul: DispatchTileMatmul,
118 dtypes: MatmulElems,
119) -> Result<(), ConvSetupError>
120where
121 Alg::Args: ConcreteArgs<Alg::Routine>,
122 <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
123{
124 let rank = input.data().shape.len();
125 let dim_c = rank - 1;
126
127 let n = input.shape()[0];
128 let c = input.shape()[dim_c];
129
130 let out_c = out_grad.shape()[dim_c];
131
132 let in_shape = &input.shape()[1..dim_c];
133 let kernel_shape = &weight_grad.shape[1..dim_c];
134 let out_shape = &out_grad.shape()[1..dim_c];
135
136 let op = ConvolutionOperation::BackwardWeight;
137
138 let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
139 let out_grad_data =
140 Alg::correct_layout(client, out_grad.clone().into_data(), dtypes.rhs_global, op)?;
141
142 let mut input = input.clone();
143 let mut out_grad = out_grad.clone();
144
145 *input.data_mut() = input_data;
146 *out_grad.data_mut() = out_grad_data;
147
148 let address_type = input
149 .required_address_type()
150 .max(out_grad.required_address_type())
151 .max(weight_grad.required_address_type(dtypes.acc_global.size()));
152
153 let problem = ConvolutionProblem {
154 m: out_c,
155 n: c * kernel_shape.iter().product::<usize>(),
156 k: n * out_shape.iter().product::<usize>(),
157 lhs_strides: input.data().strides.clone(),
158 rhs_strides: out_grad.data().strides.clone(),
159 lhs_layout: MatrixLayout::ColMajor,
160 rhs_layout: MatrixLayout::RowMajor,
161 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
162 stride: stride.iter().map(|it| *it as u32).collect(),
163 padding: padding.iter().map(|it| *it as i32).collect(),
164 dilation: dilation.iter().map(|it| *it as u32).collect(),
165
166 batches: n,
167 in_shape: in_shape.into(),
168 out_shape: out_shape.into(),
169 channels: c,
170 out_channels: out_c,
171
172 padded_channels: c,
173 operation: op,
174
175 dimensionality,
176 global_dtypes: dtypes.as_global_elems(),
177 address_type,
178 };
179
180 let mut args = <Alg::Routine as Routine<RuntimeArgs>>::Strategy::default();
181 args.set_tile_matmul(tile_matmul);
182
183 launch_kernel::<R, Alg>(
184 client,
185 input,
186 out_grad,
187 weight_grad,
188 problem,
189 &BlueprintStrategy::Inferred(args),
190 dtypes,
191 )
192}
193
/// Picks vector (line) sizes compatible with all three operand tensors,
/// lets `Alg` veto unsupported sizes, then hands off to the concrete
/// launcher for `Alg`.
///
/// # Errors
/// Returns `ConvSetupError` if `pick_max` finds no surviving vector size
/// or if the concrete kernel launch fails setup.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    out_grad: InputBinding<R>,
    weight_grad: TensorBinding<R>,
    problem: ConvolutionProblem,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    // Candidate sizes derived from the byte width of each operand, then
    // narrowed by each tensor's strides/shape.
    // NOTE(review): `filter_lhs_with_tensor` is fed `out_grad` while
    // `filter_rhs_with_tensor` is fed `input` — the opposite pairing of
    // `from_type_sizes` (input width first) and of `problem.lhs_strides`,
    // which is built from `input` upstream. Also, `MatrixLayout::RowMajor`
    // is passed for the lhs here while the problem declares
    // `lhs_layout: ColMajor`. Confirm these swaps are intentional.
    let vector_sizes = AvailableVectorSizes::from_type_sizes(
        client,
        input.data_elem_size(),
        out_grad.data_elem_size(),
        dtypes.acc_global.size(),
    )
    .filter_lhs_with_tensor(
        &out_grad.data().strides,
        &out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        &input.data().strides,
        &input.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(&weight_grad.strides, &weight_grad.shape);

    // Algorithm-specific veto, then take the largest remaining size.
    let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;

    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
        client,
        input,
        out_grad,
        weight_grad,
        problem,
        vector_sizes,
        blueprint_strategy,
        &dtypes,
    )
}