use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{
        CubeOptionArgs, FastDivmodArgs,
        tensor::{
            launch::ViewArg,
            layout::{
                VirtualLayoutLaunch,
                chain::{Chain, ChainLaunch},
            },
        },
    },
};
use cubek_matmul::{
    components::{
        global::{
            GlobalConfig as _,
            memory::{GlobalMemoryConfig, NoopLayout, NoopLayoutLaunch, ViewDirection},
        },
        stage::StageConfig as _,
    },
    definition::{MatmulElems, MatmulLineSizes, MatrixLayout, TilingBlueprint},
    launch::{
        MatmulArgs, MatmulInputHandleRef, TensorArgs, TensorInputs, TensorInputsLaunch,
        TensorMapArgs, TensorMapInputs, TensorMapInputsLaunch, TensorOutput, TensorOutputLaunch,
    },
};
use enumset::EnumSet;

use crate::components::{
    ConvGemmConfig, ConvolutionParams, ConvolutionProblem,
    global::{
        args::RuntimeArgsLaunch,
        layout::{
            BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout,
            NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch,
            WeightLayout, WeightLayoutLaunch,
        },
    },
};

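/// [`MatmulArgs`] whose inputs and output can be constructed from concrete tensor handles, and
/// that can adjust the convolution problem to match the launch requirements of those arguments.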
pub trait ConcreteArgs:
    MatmulArgs<
        Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>: ConcreteInputsFactory,
        Output<NumericExpand<2>>: ConcreteOutputFactory,
    >
{
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}

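// Plain tensor inputs: pad the channel count so each line along `k` is aligned to the hardware's
// load width.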
impl ConcreteArgs for TensorArgs {
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        _selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
        let padded_channels = problem.channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

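// TMA inputs: the im2col tensor map loads whole boxes of `tile_size.k()` channels per pixel, so
// channels are padded to a multiple of the tile size along `k`.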
impl ConcreteArgs for TensorMapArgs {
    fn adjust_problem<R: Runtime>(
        _client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        _dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
        let channel_align = selection.tiling_scheme.tile_size.k() as usize;
        let padded_channels = problem.channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

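/// Factory for constructing the concrete input arguments of a convolution kernel launch.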
pub trait ConcreteInputsFactory: LaunchArg {
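    /// Create the launch arguments for the lhs, rhs and optional bias inputs, along with the
    /// shared [`RuntimeArgsLaunch`] metadata.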
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a MatmulInputHandleRef<'a, R>>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>);
}

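/// Factory for constructing the concrete output argument of a convolution kernel launch.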
pub trait ConcreteOutputFactory: LaunchArg {
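    /// Create the launch argument for the output tensor.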
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

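// Plain tensor path: lhs is viewed through an implicit im2col layout and rhs through a weight
// layout, each chained onto a base NHWC layout that handles bounds checking.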
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a MatmulInputHandleRef<'a, R>>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;

        let layout_nhwc =
            |handle, line_size, checks| NhwcLayoutLaunch::from_handle(handle, line_size, checks);
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.params(),
            config.lhs_global_memory_config(),
        );
        let layout_rhs =
            WeightLayoutLaunch::from_args(client, problem, config.rhs_global_memory_config());
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

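        // Only enable the NHWC bounds checks the problem actually needs; fully aligned shapes
        // skip the checks entirely.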
        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.data().as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(problem.k as u32),
            ScalarArg::new(problem.channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}

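// Plain tensor output: viewed through an NHWC layout chained with the GEMM output layout.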
impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out, EnumSet::empty());
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

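// TMA path: lhs is loaded through a hardware im2col tensor map, rhs through a tiled tensor map.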
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a MatmulInputHandleRef<'a, R>>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;

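        // Box shape for rhs loads: `stage_n` along the leading dimension, one element per spatial
        // kernel dimension, and `tile_size_k` along channels (assuming a [n, kernel..., c]
        // weight order).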
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_n);
        stage_size_rhs.push(tile_size_k);

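        // An f32 lhs stage is backed by tf32 data for TMA, so the tensor map element type is
        // adjusted accordingly.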
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

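        // Per-dimension element strides for the im2col map: the convolution strides apply to the
        // spatial dimensions, while batch and channels stay contiguous.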
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

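        // Precompute fast division helpers for decomposing a linear output index into spatial
        // coordinates on the GPU.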
        let shape_out = problem
            .out_shape
            .iter()
            .map(|it| FastDivmodArgs::<u32>::new(client, *it as u32))
            .collect();

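        // The final layout flag presumably enables bounds checking along k; that's only needed
        // when the combined stage span along k doesn't evenly divide `shape_k`.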
        let stages_lhs = config.stage_config().lhs_smem_config().num_stages;
        let stages_size_k = selection.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let lhs_layout = TmaIm2colLayoutLaunch::new(
            shape_out,
            FastDivmodArgs::<u32>::new(client, padded_channels),
            ConvolutionParams::from_problem(problem),
            !shape_k.is_multiple_of(stages_size_k),
        );
        let rhs_layout = WeightLayoutLaunch::from_args(
            client,
            problem,
            GlobalMemoryConfig {
                line_size: line_sizes.rhs,
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
                view_direction: ViewDirection::default(),
                dtype: dtypes.rhs_global,
            },
        );

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.data().as_array_arg(line_sizes.out), layout)
        });

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(shape_k),
            ScalarArg::new(problem.channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}

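/// Lower corner of the TMA im2col pixel box: the negated padding along each spatial dimension.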
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|padding| -*padding).collect()
}

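/// Upper corner of the TMA im2col pixel box: padding minus the dilated kernel extent along each
/// spatial dimension.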
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((padding, kernel_size), dilation)| {
            *padding - (*kernel_size - 1) as i32 * *dilation as i32
        })
        .collect()
}