// cubek_convolution/kernels/backward_data/args.rs
use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 prelude::*,
5 std::tensor::{
6 launch::ViewArg,
7 layout::{
8 VirtualLayoutLaunch,
9 chain::{Chain, ChainLaunch},
10 },
11 },
12 zspace::{shape, strides},
13};
14use cubek_matmul::{
15 components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16 definition::{Blueprint, MatmulElems, TilingBlueprint},
17 launch::*,
18 routines::Routine,
19};
20use cubek_std::launch::tma::remap_storage_for_tma;
21use cubek_std::{InputBinding, MatrixLayout};
22use enumset::EnumSet;
23
24use crate::components::{
25 ConvolutionParams, ConvolutionProblem,
26 global::{
27 args::{RuntimeArgs, RuntimeArgsLaunch},
28 layout::{
29 Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
30 OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
31 WeightLayoutLaunch,
32 },
33 },
34};
35
/// Extension of [`MatmulArgs`] for the backward-data convolution kernels:
/// ties the matmul input/output factories to a convolution [`Routine`] and
/// lets the implementation adjust the problem definition before launch.
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Adjust `problem` before launch — in the provided impls this pads the
    /// channel count and recomputes the reduction size `k` to match the
    /// alignment required by the blueprint / element types.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
50
51impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
52 fn adjust_problem<R: Runtime>(
53 client: &ComputeClient<R>,
54 mut problem: ConvolutionProblem,
55 _blueprint: &A::Blueprint,
56 dtypes: &MatmulElems,
57 ) -> ConvolutionProblem {
58 let load_width = client.properties().hardware.load_width;
59 let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
60 let padded_channels = problem.out_channels.next_multiple_of(channel_align);
61 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
62
63 problem.k = shape_k;
64 problem.padded_channels = padded_channels;
65
66 problem
67 }
68}
69
70impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
71 for TensorMapArgs<RuntimeArgs>
72{
73 fn adjust_problem<R: Runtime>(
74 _client: &ComputeClient<R>,
75 mut problem: ConvolutionProblem,
76 blueprint: &TilingBlueprint,
77 _dtypes: &MatmulElems,
78 ) -> ConvolutionProblem {
79 let channel_align = blueprint.tiling_scheme.tile_size.k() as usize;
80 let padded_channels = problem.out_channels.next_multiple_of(channel_align);
81 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
82
83 problem.k = shape_k;
84 problem.padded_channels = padded_channels;
85
86 problem
87 }
88}
89
/// Factory that turns the convolution bindings (output gradient + weights)
/// into the concrete launch-time matmul inputs, along with the runtime args.
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    /// Build the launch argument for the matmul inputs plus the shared
    /// [`RuntimeArgsLaunch`] describing the (padded) problem dimensions.
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
102
/// Factory that turns the output tensor binding into the concrete
/// launch-time matmul output argument.
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Build the launch argument for the matmul output view.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<R>;
}
112
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
{
    /// Builds plain-tensor matmul inputs for backward data:
    /// lhs = output gradient viewed through an im2col layout,
    /// rhs = weights viewed through a weight layout.
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        // Both operands go through a base NHWC layout chained with an
        // operand-specific layout (im2col for lhs, weight for rhs).
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;
        let params = ConvolutionParams::from_problem(problem);

        let layout_lhs =
            Im2colLayoutLaunch::from_args(problem, params, blueprint.lhs_global_layout_config());
        let layout_rhs =
            WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());

        // Enable NHWC bounds checks only when the problem requires them, so
        // the fully-aligned fast path skips the checks at compile time.
        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = NhwcLayoutLaunch::checked(checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            // NOTE(review): a channel-bounds requirement maps to `Batch` for
            // the weight tensor — presumably because the padded channel dim is
            // the weight's outermost dim. Confirm against WeightLayout.
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Batch);
            }
            let global = NhwcLayoutLaunch::checked(checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            // Batch layouts are no-ops here; batching is handled inside the
            // operand views themselves.
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<LhsLayout>(out_grad.into_data().into_tensor_arg(), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<RhsLayout>(weights.into_data().into_tensor_arg(), layout_rhs),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        );

        // Runtime scalars consumed by the kernel: padded reduction size,
        // original channel count, padded channel count, and the operation.
        let runtime_args = RuntimeArgsLaunch::new(
            problem.k as u32,
            problem.out_channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
173
174impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
175 fn create<R: Runtime>(
176 out: TensorBinding<R>,
177 blueprint: &A::Blueprint,
178 problem: &ConvolutionProblem,
179 ) -> Self::RuntimeArg<R> {
180 type Layout = Chain<NhwcLayout, OutLayout>;
181
182 let global = NhwcLayoutLaunch::unchecked();
183 let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
184 let layout = ChainLaunch::new(global, layout);
185 let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
186 let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
187 TensorOutputLaunch::new(view, batch)
188 }
189}
190
impl<
    Lhs: CubePrimitive,
    Rhs: CubePrimitive,
    EO: CubePrimitive,
    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds TMA (tensor-map) matmul inputs for backward data:
    /// lhs = output gradient via a hardware im2col tensor map,
    /// rhs = weights via a tiled tensor map.
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

        // Rhs stage box: [stage_k, 1, ..., 1, stage_n] — one slot per
        // spatial dim, with k prepended and n appended.
        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k as usize);
        stage_size_rhs.push(stage_n as usize);

        // TMA may require a different storage element type than the stage
        // dtype (e.g. packed sub-byte types).
        let lhs_elem = remap_storage_for_tma(dtypes.lhs_stage);

        // Element strides for the im2col map: 1 everywhere except the
        // spatial dims, which carry the convolution stride.
        let mut elem_stride = strides![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                // Pixel box corners define the receptive field the hardware
                // im2col gathers from, accounting for padding and dilation.
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.into_data().into_tensor_arg(),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.into_data().into_tensor_arg(),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Only enable kernel bounds checks when k is not a whole number of
        // lhs stages, i.e. the last stage would read out of bounds.
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
        // TMA handles out-of-bounds accesses in hardware, so the rhs layout
        // disables software bounds checks entirely.
        let rhs_layout = WeightLayoutLaunch::from_args(
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            shape_k,
            problem.out_channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
282
283#[allow(clippy::needless_range_loop)]
284fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
285 let mut out = vec![0; problem.padding.len()];
286 for i in 0..problem.padding.len() {
287 out[i] =
288 problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
289 }
290 out
291}
292
293#[allow(clippy::needless_range_loop)]
294fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
295 let mut out = vec![0; problem.padding.len()];
296 for i in 0..problem.padding.len() {
297 out[i] = problem.padding[i]
298 - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
299 + problem.in_shape[i] as i32
300 - problem.out_shape[i] as i32;
301 }
302 out
303}