1use crate::{
2 AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
3 backward_data::args::ConcreteArgs,
4 components::{ConvolutionOperation, global::args::RuntimeArgs},
5 kernels::algorithm::simple::*,
6};
7use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
8use crate::{
9 components::{ConvolutionProblem, Dimensionality},
10 kernels::algorithm::Algorithm,
11};
12use cubecl::{Runtime, client::ComputeClient, prelude::*};
13use cubek_matmul::{
14 components::tile_matmul::DispatchTileMatmul,
15 definition::{AvailableVectorSizes, MatmulElems, MatmulSetupError},
16 routines::{BlueprintStrategy, Routine, TilingArgs},
17};
18use cubek_std::{InputBinding, MatrixLayout};
19use derive_new::new;
20
21fn tile_kind_to_dispatch(kind: &AcceleratedTileKind) -> DispatchTileMatmul {
22 match kind {
23 AcceleratedTileKind::Cmma => DispatchTileMatmul::Cmma,
24 AcceleratedTileKind::Mma => DispatchTileMatmul::Mma,
25 }
26}
27
28#[allow(clippy::result_large_err, clippy::too_many_arguments)]
37pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
38 strategy: &Strategy,
39 client: &ComputeClient<R>,
40 out_grad: InputBinding<R>,
41 weights: InputBinding<R>,
42 in_grad: TensorBinding<R>,
43 args: ConvolutionArgs<N_SPATIAL>,
44 dtypes: MatmulElems,
45) -> Result<(), ConvSetupError> {
46 let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);
47
48 match strategy {
49 Strategy::Simple {
50 read_strategy,
51 tile_kind,
52 } => {
53 let kind = tile_kind_to_dispatch(tile_kind);
54 match read_strategy {
55 ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv>(kind),
56 ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv>(kind),
57 ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv>(kind),
58 ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv>(kind),
59 ReadingStrategy::AsyncStrided => backprop.launch::<SimpleAsyncStridedConv>(kind),
60 ReadingStrategy::Tma => {
61 Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
62 Box::new("Data backprop doesn't yet work with current TMA tiling strategy"),
63 )))
64 }
65 }
66 }
67 }
68}
69
/// Bundles the inputs of a backward-data convolution launch so the
/// algorithm-generic `launch` method can be monomorphized per reading
/// strategy without repeating this argument list.
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    // Compute client the kernel is dispatched on.
    client: &'a ComputeClient<R>,
    // Gradient w.r.t. the convolution output; used as the lhs operand of the
    // implicit matmul (see `lhs_strides` in `launch_with_algorithm`).
    out_grad: InputBinding<R>,
    // Convolution weights; used as the rhs operand of the implicit matmul.
    weights: InputBinding<R>,
    // Output tensor receiving the gradient w.r.t. the convolution input.
    in_grad: TensorBinding<R>,
    // Stride / padding / dilation for the `N_SPATIAL` spatial dimensions.
    args: ConvolutionArgs<N_SPATIAL>,
    // Element types of the matmul operands and accumulator.
    dtypes: MatmulElems,
}
79
80impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
81 fn launch<Alg: Algorithm>(self, tile_matmul: DispatchTileMatmul) -> Result<(), ConvSetupError>
82 where
83 Alg::Args: ConcreteArgs<Alg::Routine>,
84 <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
85 {
86 let ConvolutionArgs {
87 stride,
88 padding,
89 dilation,
90 } = self.args;
91
92 let dimensionality = match N_SPATIAL {
93 1 => Dimensionality::Dim1,
94 2 => Dimensionality::Dim2,
95 3 => Dimensionality::Dim3,
96 other => unimplemented!("Unsupported dimensionality {other}"),
97 };
98
99 let mut args = <Alg::Routine as Routine<RuntimeArgs>>::Strategy::default();
100 args.set_tile_matmul(tile_matmul);
101
102 launch_with_algorithm::<R, Alg>(
103 self.client,
104 self.out_grad,
105 self.weights,
106 self.in_grad,
107 (&stride, &padding, &dilation),
108 dimensionality,
109 &BlueprintStrategy::Inferred(args),
110 self.dtypes,
111 )
112 }
113}
114
115#[allow(clippy::too_many_arguments)]
116fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
117 client: &ComputeClient<R>,
118 out_grad: InputBinding<R>,
119 weights: InputBinding<R>,
120 in_grad: TensorBinding<R>,
121 (stride, padding, dilation): (&[usize], &[usize], &[usize]),
122 dimensionality: Dimensionality,
123 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
124 dtypes: MatmulElems,
125) -> Result<(), ConvSetupError>
126where
127 Alg::Args: ConcreteArgs<Alg::Routine>,
128{
129 let rank = in_grad.shape.len();
130 let dim_c = rank - 1;
131
132 let n = in_grad.shape[0];
133 let c = in_grad.shape[dim_c];
134
135 let out_c = out_grad.shape()[dim_c];
136
137 let in_shape = &in_grad.shape[1..dim_c];
138 let kernel_shape = &weights.shape()[1..dim_c];
139 let out_shape = &out_grad.shape()[1..dim_c];
140
141 let op = ConvolutionOperation::BackwardData;
142
143 let out_grad_tmp = out_grad.clone();
144 let weights_tmp = weights.clone();
145
146 let out_grad_data =
147 Alg::correct_layout(client, out_grad_tmp.into_data(), dtypes.lhs_global, op)?;
148 let weights_data = Alg::correct_layout(client, weights_tmp.into_data(), dtypes.rhs_global, op)?;
149
150 let mut out_grad = out_grad.clone();
151 let mut weights = weights.clone();
152
153 *out_grad.data_mut() = out_grad_data;
154 *weights.data_mut() = weights_data;
155
156 let address_type = out_grad
157 .required_address_type()
158 .max(weights.required_address_type())
159 .max(in_grad.required_address_type(dtypes.acc_global.size()));
160
161 let problem = ConvolutionProblem {
162 m: n * in_shape.iter().product::<usize>(),
163 n: c,
164 k: out_c * kernel_shape.iter().product::<usize>(),
165
166 lhs_strides: out_grad.data().strides.clone(),
167 rhs_strides: weights.data().strides.clone(),
168 lhs_layout: MatrixLayout::RowMajor,
169 rhs_layout: MatrixLayout::RowMajor,
170 kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
171 stride: stride.iter().map(|it| *it as u32).collect(),
172 padding: padding.iter().map(|it| *it as i32).collect(),
173 dilation: dilation.iter().map(|it| *it as u32).collect(),
174
175 batches: n,
176 in_shape: in_shape.into(),
177 out_shape: out_shape.into(),
178 channels: c,
179 out_channels: out_c,
180
181 padded_channels: out_c,
182 operation: op,
183
184 dimensionality,
185 global_dtypes: dtypes.as_global_elems(),
186 address_type,
187 };
188
189 launch_kernel::<R, Alg>(
190 client,
191 out_grad,
192 weights,
193 in_grad,
194 problem,
195 blueprint_strategy,
196 dtypes,
197 )
198}
199
200#[allow(clippy::result_large_err, clippy::too_many_arguments)]
201pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
202 client: &ComputeClient<R>,
203 out_grad: InputBinding<R>,
204 weights: InputBinding<R>,
205 in_grad: TensorBinding<R>,
206 problem: ConvolutionProblem,
207 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
208 dtypes: MatmulElems,
209) -> Result<(), ConvSetupError>
210where
211 Alg::Args: ConcreteArgs<Alg::Routine>,
212{
213 let vector_sizes = AvailableVectorSizes::from_type_sizes(
216 client,
217 out_grad.data_elem_size(),
218 weights.data_elem_size(),
219 dtypes.acc_global.size(),
220 )
221 .filter_lhs_with_tensor(
222 &out_grad.data().strides,
223 &out_grad.data().shape,
224 MatrixLayout::RowMajor,
225 )
226 .filter_rhs_with_tensor(
227 &weights.data().strides,
228 &weights.data().shape,
229 MatrixLayout::RowMajor,
230 )
231 .filter_out_with_tensor(&in_grad.strides, &in_grad.shape);
232
233 let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
234
235 launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
236 client,
237 out_grad,
238 weights,
239 in_grad,
240 problem,
241 vector_sizes,
242 blueprint_strategy,
243 &dtypes,
244 )
245}