1use crate::{
2 AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
3 backward_data::args::ConcreteArgs,
4 components::{ConvolutionOperation, global::args::RuntimeArgs},
5 kernels::algorithm::simple::*,
6};
7use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
8use crate::{
9 components::{ConvolutionProblem, Dimensionality},
10 kernels::algorithm::Algorithm,
11};
12use cubecl::{Runtime, client::ComputeClient, prelude::*};
13use cubek_matmul::{
14 components::tile::{cmma::CmmaMatmul, mma::MmaMatmul},
15 definition::{AvailableVectorSizes, MatmulElems, MatmulSetupError},
16 routines::BlueprintStrategy,
17};
18use cubek_std::{InputBinding, MatrixLayout, tile::Strided};
19use derive_new::new;
20
/// Monomorphizes `$launch` over the tile-matmul implementation selected at
/// runtime by `$kind`.
///
/// For each [`AcceleratedTileKind`] variant it declares a local type alias
/// `$T` bound to the matching concrete tile matmul (CMMA or MMA) and then
/// invokes the `$launch` closure, so the closure body can refer to `$T` as a
/// compile-time type even though the kind is only known at runtime.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // CMMA path: only the accumulator tile is (optionally) strided.
                type $T = CmmaMatmul<Option<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                // MMA path: strided lhs/rhs tiles, optionally strided accumulator.
                type $T = MmaMatmul<Strided, Strided, Option<Strided>>;
                ($launch)()
            }
        }
    };
}
35
36#[allow(clippy::result_large_err, clippy::too_many_arguments)]
45pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
46 strategy: &Strategy,
47 client: &ComputeClient<R>,
48 out_grad: InputBinding<R>,
49 weights: InputBinding<R>,
50 in_grad: TensorBinding<R>,
51 args: ConvolutionArgs<N_SPATIAL>,
52 dtypes: MatmulElems,
53) -> Result<(), ConvSetupError> {
54 let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);
55
56 match strategy {
57 Strategy::Simple {
58 read_strategy,
59 tile_kind,
60 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
61 ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
62 ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
63 ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
64 ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
65 ReadingStrategy::AsyncStrided =>
66 backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
67 ReadingStrategy::Tma => Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
68 Box::new("Data backprop doesn't yet work with current TMA tiling strategy")
69 ))),
70 }),
71 }
72}
73
/// Bundle of everything needed for one data-backward convolution launch.
///
/// Exists so that [`launch_ref`] can build it once and then monomorphize
/// [`BackwardsData::launch`] per algorithm without repeating the argument
/// list. Field order matters: `#[derive(new)]` generates `new()` with the
/// fields in declaration order.
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    // Compute client used for all kernel setup/launch calls.
    client: &'a ComputeClient<R>,
    // Gradient w.r.t. the forward output; used as the lhs GEMM operand downstream.
    out_grad: InputBinding<R>,
    // Forward weights; used as the rhs GEMM operand downstream.
    weights: InputBinding<R>,
    // Gradient w.r.t. the forward input; the tensor written by the kernel.
    in_grad: TensorBinding<R>,
    // Stride/padding/dilation for the N_SPATIAL spatial dimensions.
    args: ConvolutionArgs<N_SPATIAL>,
    // Global/accumulator element types for the underlying matmul.
    dtypes: MatmulElems,
}
83
84impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
85 fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
86 where
87 Alg::Args: ConcreteArgs<Alg::Routine>,
88 {
89 let ConvolutionArgs {
90 stride,
91 padding,
92 dilation,
93 } = self.args;
94
95 let dimensionality = match N_SPATIAL {
96 1 => Dimensionality::Dim1,
97 2 => Dimensionality::Dim2,
98 3 => Dimensionality::Dim3,
99 other => unimplemented!("Unsupported dimensionality {other}"),
100 };
101
102 launch_with_algorithm::<R, Alg>(
103 self.client,
104 self.out_grad,
105 self.weights,
106 self.in_grad,
107 (&stride, &padding, &dilation),
108 dimensionality,
109 &BlueprintStrategy::Inferred(Default::default()),
110 self.dtypes,
111 )
112 }
113}
114
115#[allow(clippy::too_many_arguments)]
116fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
117 client: &ComputeClient<R>,
118 out_grad: InputBinding<R>,
119 weights: InputBinding<R>,
120 in_grad: TensorBinding<R>,
121 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
122 dimensionality: Dimensionality,
123 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
124 dtypes: MatmulElems,
125) -> Result<(), ConvSetupError>
126where
127 Alg::Args: ConcreteArgs<Alg::Routine>,
128{
129 let rank = in_grad.shape.len();
130 let dim_c = rank - 1;
131
132 let n = in_grad.shape[0];
133 let c = in_grad.shape[dim_c];
134
135 let out_c = out_grad.shape()[dim_c];
136
137 let in_shape = &in_grad.shape[1..dim_c];
138 let kernel_shape = &weights.shape()[1..dim_c];
139 let out_shape = &out_grad.shape()[1..dim_c];
140
141 let op = ConvolutionOperation::BackwardData;
142
143 let out_grad_tmp = out_grad.clone();
144 let weights_tmp = weights.clone();
145
146 let out_grad_data =
147 Alg::correct_layout(client, out_grad_tmp.into_data(), dtypes.lhs_global, op)?;
148 let weights_data = Alg::correct_layout(client, weights_tmp.into_data(), dtypes.rhs_global, op)?;
149
150 let mut out_grad = out_grad.clone();
151 let mut weights = weights.clone();
152
153 *out_grad.data_mut() = out_grad_data;
154 *weights.data_mut() = weights_data;
155
156 let address_type = out_grad
157 .required_address_type()
158 .max(weights.required_address_type())
159 .max(in_grad.required_address_type(dtypes.acc_global.size()));
160
161 let problem = ConvolutionProblem {
162 m: n * in_shape.iter().product::<usize>(),
163 n: c,
164 k: out_c * kernel_shape.iter().product::<usize>(),
165
166 lhs_strides: out_grad.data().strides.clone(),
167 rhs_strides: weights.data().strides.clone(),
168 lhs_layout: MatrixLayout::RowMajor,
169 rhs_layout: MatrixLayout::RowMajor,
170 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
171 stride: stride.iter().map(|it| *it as u32).collect(),
172 padding: padding.iter().map(|it| *it as i32).collect(),
173 dilation: dilation.iter().map(|it| *it as u32).collect(),
174
175 batches: n,
176 in_shape: in_shape.into(),
177 out_shape: out_shape.into(),
178 channels: c,
179 out_channels: out_c,
180
181 padded_channels: out_c,
182 operation: op,
183
184 dimensionality,
185 global_dtypes: dtypes.as_global_elems(),
186 address_type,
187 };
188
189 launch_kernel::<R, Alg>(
190 client,
191 out_grad,
192 weights,
193 in_grad,
194 problem,
195 blueprint_strategy,
196 dtypes,
197 )
198}
199
200#[allow(clippy::result_large_err, clippy::too_many_arguments)]
201pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
202 client: &ComputeClient<R>,
203 out_grad: InputBinding<R>,
204 weights: InputBinding<R>,
205 in_grad: TensorBinding<R>,
206 problem: ConvolutionProblem,
207 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
208 dtypes: MatmulElems,
209) -> Result<(), ConvSetupError>
210where
211 Alg::Args: ConcreteArgs<Alg::Routine>,
212{
213 let vector_sizes = AvailableVectorSizes::from_type_sizes(
216 client,
217 out_grad.data_elem_size(),
218 weights.data_elem_size(),
219 dtypes.acc_global.size(),
220 )
221 .filter_lhs_with_tensor(
222 &out_grad.data().strides,
223 &out_grad.data().shape,
224 MatrixLayout::RowMajor,
225 )
226 .filter_rhs_with_tensor(
227 &weights.data().strides,
228 &weights.data().shape,
229 MatrixLayout::RowMajor,
230 )
231 .filter_out_with_tensor(&in_grad.strides, &in_grad.shape);
232
233 let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
234
235 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
236 client,
237 out_grad,
238 weights,
239 in_grad,
240 problem,
241 vector_sizes,
242 blueprint_strategy,
243 &dtypes,
244 )
245}