use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{
        CubeOptionArgs, FastDivmodArgs,
        tensor::{
            launch::ViewArg,
            layout::{
                VirtualLayoutLaunch,
                chain::{Chain, ChainLaunch},
            },
        },
    },
};
use cubek_matmul::{
    components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
    definition::{Blueprint, MatmulElems, MatmulLineSizes, MatrixLayout, TilingBlueprint},
    launch::{
        MatmulArgs, MatmulInputHandleRef, TensorArgs, TensorInputs, TensorInputsLaunch,
        TensorMapArgs, TensorMapInputs, TensorMapInputsLaunch, TensorOutput, TensorOutputLaunch,
    },
    routines::Routine,
};
use enumset::EnumSet;

use crate::components::{
    ConvolutionParams, ConvolutionProblem,
    global::{
        args::{RuntimeArgs, RuntimeArgsLaunch},
        layout::{
            Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
            OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
            WeightLayoutLaunch,
        },
    },
};

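/// [`MatmulArgs`] whose input and output types can be built directly from convolution tensor
/// handles, with a hook to adjust the [`ConvolutionProblem`] (e.g. pad the channel count)
/// before launch.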
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>: ConcreteInputsFactory<A>,
        Output<NumericExpand<2>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}

impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        _blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
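        // Align the channel count to the number of elements that fit in one hardware load.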
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
    for TensorMapArgs<RuntimeArgs>
{
    fn adjust_problem<R: Runtime>(
        _client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        blueprint: &TilingBlueprint,
        _dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
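        // For TMA, align the channel count to the tile size along k.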
        let channel_align = blueprint.tiling_scheme.tile_size.k() as usize;
        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

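/// Builds the concrete matmul input launch arguments (lhs = `out_grad`, rhs = `weights`) and
/// the convolution runtime arguments for a given routine.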
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>);
}

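/// Builds the concrete matmul output launch argument from the convolution output handle.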
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric, A: Routine<RuntimeArgs>> ConcreteInputsFactory<A>
    for TensorInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;
        let params = ConvolutionParams::from_problem(problem);

        let layout_nhwc =
            |handle, line_size, checks| NhwcLayoutLaunch::from_handle(handle, line_size, checks);

        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            params,
            blueprint.lhs_global_layout_config(),
        );
        let layout_rhs =
            WeightLayoutLaunch::from_args(client, problem, blueprint.rhs_global_layout_config());

        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(out_grad.data(), line_sizes.lhs, checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Batch);
            }
            let global = layout_nhwc(weights.data(), line_sizes.rhs, checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(out_grad.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(weights.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(problem.k as u32),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            problem.operation,
        );

        (inputs, runtime_args)
    }
}

impl<EG: Numeric, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
    ) -> Self::RuntimeArg<'a, R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out, EnumSet::empty());
        let layout =
            OutLayoutLaunch::from_args(client, problem, blueprint.out_global_layout_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric, A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k);
        stage_size_rhs.push(stage_n);

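        // An `f32` lhs stage is represented as `tf32` in the TMA tensor map.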
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.data().as_tensor_arg(line_sizes.rhs),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        let stages_lhs = A::num_stages().lhs;
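        // Checks are only required when the k extent is not a whole multiple of the combined
        // lhs stage size.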
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(client, problem, check_kernel);
        let rhs_layout = WeightLayoutLaunch::from_args(
            client,
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(shape_k),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            problem.operation,
        );

        (inputs, runtime_args)
    }
}

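/// Lower corner of the im2col pixel box: `padding - (kernel_size - 1) * dilation` for each
/// spatial dimension.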
#[allow(clippy::needless_range_loop)]
fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
    let mut out = vec![0; problem.padding.len()];
    for i in 0..problem.padding.len() {
        out[i] =
            problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
    }
    out
}

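/// Upper corner of the im2col pixel box: the lower corner shifted by the difference between
/// the input and output spatial shapes.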
#[allow(clippy::needless_range_loop)]
fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
    let mut out = vec![0; problem.padding.len()];
    for i in 0..problem.padding.len() {
        out[i] = problem.padding[i]
            - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
            + problem.in_shape[i] as i32
            - problem.out_shape[i] as i32;
    }
    out
}