cubek_convolution/kernels/backward_data/
args.rs1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 prelude::*,
5 std::tensor::{
6 launch::ViewArg,
7 layout::{
8 VirtualLayoutLaunch,
9 chain::{Chain, ChainLaunch},
10 },
11 },
12 zspace::{shape, strides},
13};
14use cubek_matmul::{
15 components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16 definition::{Blueprint, MatmulElems, TilingBlueprint},
17 launch::*,
18 routines::Routine,
19};
20use cubek_std::{InputBinding, MatrixLayout};
21use enumset::EnumSet;
22
23use crate::components::{
24 ConvolutionParams, ConvolutionProblem,
25 global::{
26 args::{RuntimeArgs, RuntimeArgsLaunch},
27 layout::{
28 Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
29 OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
30 WeightLayoutLaunch,
31 },
32 },
33};
34
/// Matmul argument families that can launch a backward-data convolution.
///
/// Ties the generic [`MatmulArgs`] machinery to convolution-aware factories:
/// the matmul inputs must be buildable by a [`ConcreteInputsFactory`] and the
/// output by a [`ConcreteOutputFactory`], with [`RuntimeArgs`] as the runtime
/// config passed to the kernel.
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Adjust `problem` for this argument family before launch.
    ///
    /// Implementations pad `out_channels` to an alignment dictated by the
    /// hardware or the tiling blueprint and recompute the reduction size `k`
    /// from the padded channel count (see the two impls below).
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
49
50impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
51 fn adjust_problem<R: Runtime>(
52 client: &ComputeClient<R>,
53 mut problem: ConvolutionProblem,
54 _blueprint: &A::Blueprint,
55 dtypes: &MatmulElems,
56 ) -> ConvolutionProblem {
57 let load_width = client.properties().hardware.load_width;
58 let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
59 let padded_channels = problem.out_channels.next_multiple_of(channel_align);
60 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
61
62 problem.k = shape_k;
63 problem.padded_channels = padded_channels;
64
65 problem
66 }
67}
68
69impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
70 for TensorMapArgs<RuntimeArgs>
71{
72 fn adjust_problem<R: Runtime>(
73 _client: &ComputeClient<R>,
74 mut problem: ConvolutionProblem,
75 blueprint: &TilingBlueprint,
76 _dtypes: &MatmulElems,
77 ) -> ConvolutionProblem {
78 let channel_align = blueprint.tiling_scheme.tile_size.k() as usize;
79 let padded_channels = problem.out_channels.next_multiple_of(channel_align);
80 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
81
82 problem.k = shape_k;
83 problem.padded_channels = padded_channels;
84
85 problem
86 }
87}
88
/// Factory building the matmul *input* launch arguments for a backward-data
/// convolution: the lhs is the output gradient (viewed through an im2col-style
/// layout) and the rhs is the weight tensor.
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Build the input launch arguments plus the runtime scalar arguments.
    ///
    /// * `out_grad` - gradient w.r.t. the convolution output (matmul lhs).
    /// * `weights` - convolution weights (matmul rhs).
    /// * `blueprint` - routine selection / tiling information.
    /// * `problem` - the convolution problem (already adjusted via
    ///   [`ConcreteArgs::adjust_problem`] — TODO confirm callers guarantee this).
    /// * `dtypes` - element types at each matmul stage.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
101
/// Factory building the matmul *output* launch argument for a backward-data
/// convolution (the input-gradient tensor being written).
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Build the output launch argument for `out`, laid out according to the
    /// blueprint's global output layout configuration.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<R>;
}
111
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
{
    /// Builds plain-tensor matmul inputs: the out-grad is viewed through an
    /// NHWC + im2col layout chain and the weights through an NHWC + weight
    /// layout chain, each with only the bounds checks the problem requires.
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        // Layout chains: raw NHWC indexing first, then the matmul-facing view.
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;
        let params = ConvolutionParams::from_problem(problem);

        let layout_lhs =
            Im2colLayoutLaunch::from_args(problem, params, blueprint.lhs_global_layout_config());
        let layout_rhs =
            WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());

        // Enable only the NHWC bounds checks the problem actually needs, so
        // fully in-bounds problems skip the check cost entirely.
        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = NhwcLayoutLaunch::checked(checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            // NOTE(review): channel padding maps to a *batch* check on the
            // weight tensor — presumably its leading dim holds out_channels
            // in backward-data; confirm against WeightLayout/NhwcLayout.
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Batch);
            }
            let global = NhwcLayoutLaunch::checked(checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            // Batch layouts are no-ops for both operands: batching is folded
            // into the matmul `m` dimension by the im2col view.
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<LhsLayout>(out_grad.into_data().into_tensor_arg(), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<RhsLayout>(weights.into_data().into_tensor_arg(), layout_rhs),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        );

        // Scalar arguments read by the kernel at runtime: reduction size,
        // logical vs padded channel counts, and the operation variant.
        let runtime_args = RuntimeArgsLaunch::new(
            problem.k as u32,
            problem.out_channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
172
173impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
174 fn create<R: Runtime>(
175 out: TensorBinding<R>,
176 blueprint: &A::Blueprint,
177 problem: &ConvolutionProblem,
178 ) -> Self::RuntimeArg<R> {
179 type Layout = Chain<NhwcLayout, OutLayout>;
180
181 let global = NhwcLayoutLaunch::unchecked();
182 let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
183 let layout = ChainLaunch::new(global, layout);
184 let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
185 let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
186 TensorOutputLaunch::new(view, batch)
187 }
188}
189
impl<
    Lhs: CubePrimitive,
    Rhs: CubePrimitive,
    EO: CubePrimitive,
    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds tensor-map (TMA) matmul inputs: the out-grad is described as a
    /// hardware im2col tensor map and the weights as a tiled tensor map.
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

        // Rhs tile shape: (stage_k, 1, ..., 1, stage_n) — the spatial dims of
        // the weight tensor are loaded one element at a time.
        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k as usize);
        stage_size_rhs.push(stage_n as usize);

        // f32 lhs stages are declared as tf32 in the tensor map — presumably
        // required by the tensor-core load path; confirm against the
        // TensorMapArg/TMA encoding rules.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked().storage_type() {
            tf32::as_type_native_unchecked().storage_type()
        } else {
            dtypes.lhs_stage
        };

        // Element strides for the im2col box: 1 everywhere except the spatial
        // dims, which advance by the convolution stride.
        let mut elem_stride = strides![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                // Corners bound the input region each output pixel can touch
                // (see `calculate_lower_corner` / `calculate_upper_corner`).
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.into_data().into_tensor_arg(),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.into_data().into_tensor_arg(),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Bounds checks along k are only needed when the total reduction size
        // is not a whole number of k-stages across all lhs stages.
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
        // The tensor map hardware handles rhs bounds, so the layout itself
        // runs unchecked.
        let rhs_layout = WeightLayoutLaunch::from_args(
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        );

        // Scalar arguments read by the kernel at runtime: reduction size,
        // logical vs padded channel counts, and the operation variant.
        let runtime_args = RuntimeArgsLaunch::new(
            shape_k,
            problem.out_channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
287
288#[allow(clippy::needless_range_loop)]
289fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
290 let mut out = vec![0; problem.padding.len()];
291 for i in 0..problem.padding.len() {
292 out[i] =
293 problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
294 }
295 out
296}
297
298#[allow(clippy::needless_range_loop)]
299fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
300 let mut out = vec![0; problem.padding.len()];
301 for i in 0..problem.padding.len() {
302 out[i] = problem.padding[i]
303 - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
304 + problem.in_shape[i] as i32
305 - problem.out_shape[i] as i32;
306 }
307 out
308}