1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 prelude::*,
5 std::tensor::{
6 launch::ViewArg,
7 layout::{
8 VirtualLayoutLaunch,
9 chain::{Chain, ChainLaunch},
10 },
11 },
12 zspace::{shape, strides},
13};
14use cubek_matmul::{
15 components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16 definition::{Blueprint, MatmulElems, TilingBlueprint},
17 launch::*,
18 routines::Routine,
19};
20use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
21use enumset::EnumSet;
22
23use crate::components::{
24 ConvolutionParams, ConvolutionProblem,
25 global::{
26 args::{RuntimeArgs, RuntimeArgsLaunch},
27 layout::{
28 BiasLayout, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch,
29 OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
30 WeightLayoutLaunch,
31 },
32 },
33};
34
/// Glue trait tying a [`MatmulArgs`] family to a convolution [`Routine`] `A`:
/// the family's input type must come with a [`ConcreteInputsFactory`] and its
/// output type with a [`ConcreteOutputFactory`], and the family may pre-adjust
/// the [`ConvolutionProblem`] before launch.
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Returns `problem` adjusted for this argument family's alignment
    /// requirements (the impls in this file pad the channel count and
    /// recompute the reduction length `k` accordingly).
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
49
50impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
51 fn adjust_problem<R: Runtime>(
52 client: &ComputeClient<R>,
53 mut problem: ConvolutionProblem,
54 _blueprint: &A::Blueprint,
55 dtypes: &MatmulElems,
56 ) -> ConvolutionProblem {
57 let load_width = client.properties().hardware.load_width;
58 let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
59 let padded_channels = problem.channels.next_multiple_of(channel_align);
60 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
61
62 problem.k = shape_k;
63 problem.padded_channels = padded_channels;
64
65 problem
66 }
67}
68
69impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
70 for TensorMapArgs<RuntimeArgs>
71{
72 fn adjust_problem<R: Runtime>(
73 _client: &ComputeClient<R>,
74 mut problem: ConvolutionProblem,
75 blueprint: &TilingBlueprint,
76 _dtypes: &MatmulElems,
77 ) -> ConvolutionProblem {
78 let channel_align = match blueprint.swizzle_modes.lhs {
79 SwizzleMode::None => blueprint.tiling_scheme.tile_size.k() as usize,
80 _ => blueprint.tiling_scheme.elements_per_stage_along_k() as usize,
81 };
82 let padded_channels = problem.channels.next_multiple_of(channel_align);
83 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
84
85 problem.k = shape_k;
86 problem.padded_channels = padded_channels;
87
88 problem
89 }
90}
91
/// Factory assembling the launch-time input argument for a convolution
/// routine `A`, together with the kernel's [`RuntimeArgsLaunch`].
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds `Self`'s runtime argument from the raw bindings.
    ///
    /// * `bias` - optional bias tensor binding; `None` when the convolution
    ///   has no bias.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
105
/// Factory assembling the launch-time output argument for a convolution
/// routine `A`.
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds `Self`'s runtime argument from the output tensor binding.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
116
117impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
118 ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
119{
120 fn create<R: Runtime>(
121 lhs: InputBinding<R>,
122 rhs: InputBinding<R>,
123 bias: Option<InputBinding<R>>,
124 blueprint: &A::Blueprint,
125 problem: &ConvolutionProblem,
126 _dtypes: &MatmulElems,
127 ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
128 type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
129 type RhsLayout = Chain<NhwcLayout, WeightLayout>;
130
131 let padded_channels = problem.padded_channels as u32;
132 let conv_params = ConvolutionParams::from_problem(problem);
133
134 let layout_lhs = Im2colLayoutLaunch::from_args(
135 problem,
136 conv_params,
137 blueprint.lhs_global_layout_config(),
138 );
139 let layout_rhs =
140 WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());
141
142 let layout_lhs = {
143 let mut checks = EnumSet::empty();
144 if problem.should_check_spatial_bounds() {
145 checks.insert(NhwcCheck::Spatial);
146 }
147 if problem.should_check_channel() {
148 checks.insert(NhwcCheck::Channel);
149 }
150 let global = NhwcLayoutLaunch::checked(checks);
151 ChainLaunch::new(global, layout_lhs)
152 };
153 let layout_rhs = {
154 let mut checks = EnumSet::empty();
155 if problem.should_check_channel() {
156 checks.insert(NhwcCheck::Channel);
157 }
158 let global = NhwcLayoutLaunch::checked(checks);
159 ChainLaunch::new(global, layout_rhs)
160 };
161
162 let inputs = TensorInputsLaunch::new(
163 VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
164 ViewArg::new_tensor::<LhsLayout>(lhs.into_data().into_tensor_arg(), layout_lhs),
165 VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
166 ViewArg::new_tensor::<RhsLayout>(rhs.into_data().into_tensor_arg(), layout_rhs),
167 bias.as_ref()
168 .map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
169 .into(),
170 bias.map(|bias| {
171 ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ())
172 })
173 .into(),
174 );
175
176 let runtime_args = RuntimeArgsLaunch::new(
177 problem.k as u32,
178 problem.channels as u32,
179 padded_channels,
180 conv_params.operation,
181 );
182
183 (inputs, runtime_args)
184 }
185}
186
187impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
188 fn create<R: Runtime>(
189 out: TensorBinding<R>,
190 blueprint: &A::Blueprint,
191 problem: &ConvolutionProblem,
192 _dtypes: &MatmulElems,
193 ) -> Self::RuntimeArg<R> {
194 type Layout = Chain<NhwcLayout, OutLayout>;
195
196 let global = NhwcLayoutLaunch::unchecked();
197 let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
198 let layout = ChainLaunch::new(global, layout);
199 let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
200 let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
201 TensorOutputLaunch::new(view, batch)
202 }
203}
204
205impl<
206 Lhs: CubePrimitive,
207 Rhs: CubePrimitive,
208 EO: CubePrimitive,
209 A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
210> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
211{
212 fn create<R: Runtime>(
213 lhs: InputBinding<R>,
214 rhs: InputBinding<R>,
215 bias: Option<InputBinding<R>>,
216 blueprint: &TilingBlueprint,
217 problem: &ConvolutionProblem,
218 dtypes: &MatmulElems,
219 ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
220 let tiling_scheme = blueprint.tiling_scheme;
221 let stage_m = tiling_scheme.elements_per_stage_along_m();
222 let stage_n = tiling_scheme.elements_per_stage_along_n();
223
224 let tile_size_k = match blueprint.swizzle_modes.lhs {
225 SwizzleMode::None => tiling_scheme.tile_size.k,
226 _ => tiling_scheme.elements_per_stage_along_k(),
227 };
228
229 let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
230 stage_size_rhs.insert(0, stage_n as usize);
231 stage_size_rhs.push(tile_size_k as usize);
232
233 let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked().storage_type() {
236 tf32::as_type_native_unchecked().storage_type()
237 } else {
238 dtypes.lhs_stage
239 };
240
241 let mut elem_stride = strides![1; 2 + problem.stride.len()];
242
243 for (i, stride) in problem.stride.iter().enumerate() {
244 elem_stride[i + 1] = *stride as usize;
245 }
246
247 let lhs = TensorMapArg::new(
248 Im2colArgs {
249 pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
250 pixel_box_upper_corner: calculate_upper_corner(
251 &problem.padding,
252 &problem.kernel_size,
253 &problem.dilation,
254 ),
255 channels_per_pixel: tile_size_k,
256 pixels_per_column: stage_m,
257 },
258 lhs.clone().into_data().into_tensor_arg(),
259 lhs_elem,
260 )
261 .with_elem_stride(elem_stride)
262 .with_swizzle(blueprint.swizzle_modes.lhs.into());
263
264 let rhs = TensorMapArg::new(
265 TiledArgs {
266 tile_size: stage_size_rhs,
267 },
268 rhs.clone().into_data().into_tensor_arg(),
269 dtypes.rhs_global,
270 )
271 .with_swizzle(blueprint.swizzle_modes.rhs.into());
272
273 let padded_channels = problem.padded_channels as u32;
274 let shape_k = problem.k as u32;
275
276 let stages_lhs = A::num_stages().lhs;
280 let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
281 let check_kernel = !shape_k.is_multiple_of(stages_size_k);
282 let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
283 let rhs_layout = WeightLayoutLaunch::from_args(
284 problem,
285 GlobalLayoutConfig {
286 check_row_bounds: false,
287 check_col_bounds: false,
288 matrix_layout: MatrixLayout::ColMajor,
289 },
290 );
291
292 let bias = bias
293 .map(|bias| ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ()));
294
295 let inputs = TensorMapInputsLaunch::new(
296 ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
297 ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
298 bias.into(),
299 ComptimeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
300 NoopLayoutLaunch::new(),
301 )),
302 );
303
304 let runtime_args = RuntimeArgsLaunch::new(
305 shape_k,
306 problem.channels as u32,
307 padded_channels,
308 problem.operation,
309 );
310
311 (inputs, runtime_args)
312 }
313}
314
/// Lower corner of the im2col pixel box: the negated padding per spatial
/// dimension (negative when padding is positive).
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
318
/// Upper corner of the im2col pixel box: per spatial dimension,
/// `padding - (kernel_size - 1) * dilation` (the dilated kernel extent
/// measured back from the padded edge). Truncates to the shortest slice.
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for ((&pad, &kernel), &dil) in padding.iter().zip(kernel_size).zip(dilation) {
        corner.push(pad - (kernel - 1) as i32 * dil as i32);
    }
    corner
}