1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 prelude::*,
5 std::{
6 CubeOptionArgs, FastDivmodArgs,
7 tensor::{
8 launch::ViewArg,
9 layout::{
10 VirtualLayoutLaunch,
11 chain::{Chain, ChainLaunch},
12 },
13 },
14 },
15};
16use cubek_matmul::{
17 components::{
18 global::{
19 GlobalConfig as _,
20 memory::{GlobalMemoryConfig, NoopLayout, NoopLayoutLaunch, ViewDirection},
21 },
22 stage::StageConfig as _,
23 },
24 definition::{MatmulElems, MatmulLineSizes, MatrixLayout, TilingBlueprint},
25 launch::{
26 MatmulArgs, MatmulInputHandleRef, TensorArgs, TensorInputs, TensorInputsLaunch,
27 TensorMapArgs, TensorMapInputs, TensorMapInputsLaunch, TensorOutput, TensorOutputLaunch,
28 },
29};
30use enumset::EnumSet;
31
32use crate::components::{
33 ConvGemmConfig, ConvolutionParams, ConvolutionProblem,
34 global::{
35 args::RuntimeArgsLaunch,
36 layout::{
37 Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
38 OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
39 WeightLayoutLaunch,
40 },
41 },
42};
43
/// Matmul argument families whose input and output launch types can be built
/// directly from concrete tensor handles for a convolution launch.
///
/// The associated-type bounds require that the concrete `Input`/`Output` types
/// of the underlying [`MatmulArgs`] implement the factory traits below.
pub trait ConcreteArgs:
    MatmulArgs<
        Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>: ConcreteInputsFactory,
        Output<NumericExpand<2>>: ConcreteOutputFactory,
    >
{
    /// Adjusts the convolution problem before launch so it matches this
    /// argument family's layout requirements. Both visible impls pad the
    /// channel count to an alignment and recompute the reduction size `k`
    /// as `product(kernel_size) * padded_channels`.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
57
58impl ConcreteArgs for TensorArgs {
59 fn adjust_problem<R: Runtime>(
60 client: &ComputeClient<R>,
61 mut problem: ConvolutionProblem,
62 _selection: &TilingBlueprint,
63 dtypes: &MatmulElems,
64 ) -> ConvolutionProblem {
65 let load_width = client.properties().hardware.load_width;
66 let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
67 let padded_channels = problem.out_channels.next_multiple_of(channel_align);
68 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
69
70 problem.k = shape_k;
71 problem.padded_channels = padded_channels;
72
73 problem
74 }
75}
76
77impl ConcreteArgs for TensorMapArgs {
78 fn adjust_problem<R: Runtime>(
79 _client: &ComputeClient<R>,
80 mut problem: ConvolutionProblem,
81 selection: &TilingBlueprint,
82 _dtypes: &MatmulElems,
83 ) -> ConvolutionProblem {
84 let channel_align = selection.tiling_scheme.tile_size.k() as usize;
85 let padded_channels = problem.out_channels.next_multiple_of(channel_align);
86 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
87
88 problem.k = shape_k;
89 problem.padded_channels = padded_channels;
90
91 problem
92 }
93}
94
/// Builds launch-time matmul input arguments from concrete tensor handles.
pub trait ConcreteInputsFactory: LaunchArg {
    /// Creates the input runtime argument plus the extra runtime arguments
    /// (reduction size, channel count, channel divmod, operation) the kernel
    /// consumes.
    ///
    /// The two tensor inputs are named `out_grad` and `weights`, so this
    /// presumably serves a backward-data convolution where the lhs is the
    /// output gradient — TODO confirm against callers.
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>);
}
110
/// Builds the launch-time matmul output argument from a concrete tensor handle.
pub trait ConcreteOutputFactory: LaunchArg {
    /// Creates the output runtime argument, wrapping `out` in whatever layout
    /// chain maps matmul coordinates back onto the output tensor.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}
123
// Plain-tensor inputs: both operands are viewed through a base NHWC layout
// chained with a matmul-facing layout (im2col for the lhs, weight layout for
// the rhs).
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        // Layout chains: base NHWC addressing, then the matmul-facing mapping.
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;

        // Shared constructor for the base NHWC layout of either operand.
        let layout_nhwc =
            |handle, line_size, checks| NhwcLayoutLaunch::from_handle(handle, line_size, checks);

        // Matmul-facing layers of each chain.
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.params(),
            config.lhs_global_memory_config(),
        );
        let layout_rhs =
            WeightLayoutLaunch::from_args(client, problem, config.rhs_global_memory_config());

        // Lhs (out_grad): enable spatial/channel bounds checks only when the
        // problem says they can actually be violated.
        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(out_grad.data(), line_sizes.lhs, checks);
            ChainLaunch::new(global, layout_lhs)
        };
        // Rhs (weights): inserts a Batch check under should_check_channel() —
        // presumably because the weight tensor's batch dimension corresponds
        // to the lhs channel dimension. NOTE(review): confirm this is not a
        // copy-paste of the Channel check above.
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Batch);
            }
            let global = layout_nhwc(weights.data(), line_sizes.rhs, checks);
            ChainLaunch::new(global, layout_rhs)
        };

        // Batch layouts are no-ops; the trailing options are unused extras.
        let inputs = TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(out_grad.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(weights.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        // Scalars the kernel needs at runtime: reduction size, true channel
        // count, and a fast divmod helper for the padded channel count.
        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(problem.k as u32),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}
191
192impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
193 fn create<'a, R: Runtime>(
194 client: &ComputeClient<R>,
195 out: &'a TensorHandleRef<'a, R>,
196 _selection: &TilingBlueprint,
197 problem: &ConvolutionProblem,
198 line_sizes: &MatmulLineSizes,
199 config: impl ConvGemmConfig,
200 ) -> Self::RuntimeArg<'a, R> {
201 type Layout = Chain<NhwcLayout, OutLayout>;
202
203 let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out, EnumSet::empty());
204 let layout = OutLayoutLaunch::from_args(client, problem, config.rhs_global_memory_config());
205 let layout = ChainLaunch::new(global, layout);
206 let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
207 let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
208 TensorOutputLaunch::new(view, batch)
209 }
210}
211
// Tensor-map (TMA) inputs: the lhs is described to the hardware as an im2col
// tensor map, the rhs as a tiled tensor map; layouts handle the remainder.
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        // Per-stage extents from the selected tiling scheme.
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

        // Rhs stage box: [stage_k, 1 × num_spatial_dims, stage_n] — only the
        // k and n extents are tiled; spatial dims load one element each.
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k);
        stage_size_rhs.push(stage_n);

        // TMA descriptors represent f32 stages as tf32.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

        // Element strides for the im2col descriptor: 1 for the leading and
        // trailing dims, the convolution stride for each spatial dim between.
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        // Lhs (out_grad) as an im2col tensor map: the pixel box corners bound
        // the input offsets reachable by any kernel tap.
        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        // Rhs (weights) as a plain tiled tensor map.
        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.data().as_tensor_arg(line_sizes.rhs),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Fast divmod helpers for each output spatial extent.
        let shape_out = problem
            .out_shape
            .iter()
            .map(|it| FastDivmodArgs::<u32>::new(client, *it as u32))
            .collect();

        // The layout needs a remainder flag when k is not a whole number of
        // (multi-stage) k slices.
        let stages_lhs = config.stage_config().lhs_smem_config().num_stages;
        let stages_size_k = selection.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let lhs_layout = TmaIm2colLayoutLaunch::new(
            shape_out,
            FastDivmodArgs::<u32>::new(client, padded_channels),
            ConvolutionParams::from_problem(problem),
            !shape_k.is_multiple_of(stages_size_k),
        );
        // Bounds checks are disabled here: TMA handles out-of-bounds accesses
        // itself (presumably — TODO confirm), so the layout config is built
        // inline with checks off rather than taken from `config`.
        let rhs_layout = WeightLayoutLaunch::from_args(
            client,
            problem,
            GlobalMemoryConfig {
                line_size: line_sizes.rhs,
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
                view_direction: ViewDirection::default(),
                dtype: dtypes.rhs_global,
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        // Same runtime scalars as the plain-tensor path.
        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(shape_k),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}
322
323#[allow(clippy::needless_range_loop)]
324fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
325 let mut out = vec![0; problem.padding.len()];
326 for i in 0..problem.padding.len() {
327 out[i] =
328 problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
329 }
330 out
331}
332
333#[allow(clippy::needless_range_loop)]
334fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
335 let mut out = vec![0; problem.padding.len()];
336 for i in 0..problem.padding.len() {
337 out[i] = problem.padding[i]
338 - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
339 + problem.in_shape[i] as i32
340 - problem.out_shape[i] as i32;
341 }
342 out
343}