1use cubecl::{
2 Runtime,
3 client::ComputeClient,
4 prelude::*,
5 std::tensor::{
6 launch::ViewArg,
7 layout::{
8 VirtualLayoutLaunch,
9 chain::{Chain, ChainLaunch},
10 },
11 },
12 zspace::{shape, strides},
13};
14use cubek_matmul::{
15 components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16 definition::{Blueprint, MatmulElems, TilingBlueprint},
17 launch::*,
18 routines::Routine,
19};
20use cubek_std::launch::tma::remap_storage_for_tma;
21use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
22use enumset::EnumSet;
23
24use crate::components::{
25 ConvolutionParams, ConvolutionProblem,
26 global::{
27 args::{RuntimeArgs, RuntimeArgsLaunch},
28 layout::{
29 BiasLayout, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch,
30 OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
31 WeightLayoutLaunch,
32 },
33 },
34};
35
/// Argument-family selector for running a convolution through the matmul
/// pipeline.
///
/// Implementors tie a [`MatmulArgs`] family (plain tensors or TMA tensor
/// maps) to concrete input/output factories, and get a chance to rewrite the
/// [`ConvolutionProblem`] before launch (e.g. channel padding — see the impls
/// below).
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Adjusts `problem` for this argument family before launch.
    ///
    /// Both provided impls pad the channel count to an alignment derived from
    /// the hardware/blueprint and recompute the reduction size `k`
    /// accordingly; other fields are passed through unchanged.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
50
51impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
52 fn adjust_problem<R: Runtime>(
53 client: &ComputeClient<R>,
54 mut problem: ConvolutionProblem,
55 _blueprint: &A::Blueprint,
56 dtypes: &MatmulElems,
57 ) -> ConvolutionProblem {
58 let load_width = client.properties().hardware.load_width;
59 let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
60 let padded_channels = problem.channels.next_multiple_of(channel_align);
61 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
62
63 problem.k = shape_k;
64 problem.padded_channels = padded_channels;
65
66 problem
67 }
68}
69
70impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
71 for TensorMapArgs<RuntimeArgs>
72{
73 fn adjust_problem<R: Runtime>(
74 _client: &ComputeClient<R>,
75 mut problem: ConvolutionProblem,
76 blueprint: &TilingBlueprint,
77 _dtypes: &MatmulElems,
78 ) -> ConvolutionProblem {
79 let channel_align = match blueprint.swizzle_modes.lhs {
80 SwizzleMode::None => blueprint.tiling_scheme.tile_size.k() as usize,
81 _ => blueprint.tiling_scheme.elements_per_stage_along_k() as usize,
82 };
83 let padded_channels = problem.channels.next_multiple_of(channel_align);
84 let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
85
86 problem.k = shape_k;
87 problem.padded_channels = padded_channels;
88
89 problem
90 }
91}
92
/// Factory that turns raw input bindings into this argument family's
/// launch-ready input struct, plus the kernel's [`RuntimeArgsLaunch`].
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch arguments for lhs (activations), rhs (weights) and
    /// an optional bias, using `blueprint`/`problem`/`dtypes` to configure
    /// the layouts.
    ///
    /// `problem` is expected to have gone through
    /// `ConcreteArgs::adjust_problem` first (impls read `padded_channels`
    /// and `k` — TODO confirm with callers).
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
106
/// Factory that turns the output tensor binding into this argument family's
/// launch-ready output struct.
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch argument for the convolution output, using
    /// `blueprint`/`problem`/`dtypes` to configure the output layout.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
117
118impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
119 ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
120{
121 fn create<R: Runtime>(
122 lhs: InputBinding<R>,
123 rhs: InputBinding<R>,
124 bias: Option<InputBinding<R>>,
125 blueprint: &A::Blueprint,
126 problem: &ConvolutionProblem,
127 _dtypes: &MatmulElems,
128 ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
129 type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
130 type RhsLayout = Chain<NhwcLayout, WeightLayout>;
131
132 let padded_channels = problem.padded_channels as u32;
133 let conv_params = ConvolutionParams::from_problem(problem);
134
135 let layout_lhs = Im2colLayoutLaunch::from_args(
136 problem,
137 conv_params,
138 blueprint.lhs_global_layout_config(),
139 );
140 let layout_rhs =
141 WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());
142
143 let layout_lhs = {
144 let mut checks = EnumSet::empty();
145 if problem.should_check_spatial_bounds() {
146 checks.insert(NhwcCheck::Spatial);
147 }
148 if problem.should_check_channel() {
149 checks.insert(NhwcCheck::Channel);
150 }
151 let global = NhwcLayoutLaunch::checked(checks);
152 ChainLaunch::new(global, layout_lhs)
153 };
154 let layout_rhs = {
155 let mut checks = EnumSet::empty();
156 if problem.should_check_channel() {
157 checks.insert(NhwcCheck::Channel);
158 }
159 let global = NhwcLayoutLaunch::checked(checks);
160 ChainLaunch::new(global, layout_rhs)
161 };
162
163 let inputs = TensorInputsLaunch::new(
164 VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
165 ViewArg::new_tensor::<LhsLayout>(lhs.into_data().into_tensor_arg(), layout_lhs),
166 VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
167 ViewArg::new_tensor::<RhsLayout>(rhs.into_data().into_tensor_arg(), layout_rhs),
168 bias.as_ref()
169 .map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
170 .into(),
171 bias.map(|bias| {
172 ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ())
173 })
174 .into(),
175 );
176
177 let runtime_args = RuntimeArgsLaunch::new(
178 problem.k as u32,
179 problem.channels as u32,
180 padded_channels,
181 conv_params.operation,
182 );
183
184 (inputs, runtime_args)
185 }
186}
187
188impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
189 fn create<R: Runtime>(
190 out: TensorBinding<R>,
191 blueprint: &A::Blueprint,
192 problem: &ConvolutionProblem,
193 _dtypes: &MatmulElems,
194 ) -> Self::RuntimeArg<R> {
195 type Layout = Chain<NhwcLayout, OutLayout>;
196
197 let global = NhwcLayoutLaunch::unchecked();
198 let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
199 let layout = ChainLaunch::new(global, layout);
200 let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
201 let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
202 TensorOutputLaunch::new(view, batch)
203 }
204}
205
206impl<
207 Lhs: CubePrimitive,
208 Rhs: CubePrimitive,
209 EO: CubePrimitive,
210 A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
211> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
212{
213 fn create<R: Runtime>(
214 lhs: InputBinding<R>,
215 rhs: InputBinding<R>,
216 bias: Option<InputBinding<R>>,
217 blueprint: &TilingBlueprint,
218 problem: &ConvolutionProblem,
219 dtypes: &MatmulElems,
220 ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
221 let tiling_scheme = blueprint.tiling_scheme;
222 let stage_m = tiling_scheme.elements_per_stage_along_m();
223 let stage_n = tiling_scheme.elements_per_stage_along_n();
224
225 let tile_size_k = match blueprint.swizzle_modes.lhs {
226 SwizzleMode::None => tiling_scheme.tile_size.k,
227 _ => tiling_scheme.elements_per_stage_along_k(),
228 };
229
230 let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
231 stage_size_rhs.insert(0, stage_n as usize);
232 stage_size_rhs.push(tile_size_k as usize);
233
234 let lhs_elem = remap_storage_for_tma(dtypes.lhs_stage);
235
236 let mut elem_stride = strides![1; 2 + problem.stride.len()];
237
238 for (i, stride) in problem.stride.iter().enumerate() {
239 elem_stride[i + 1] = *stride as usize;
240 }
241
242 let lhs = TensorMapArg::new(
243 Im2colArgs {
244 pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
245 pixel_box_upper_corner: calculate_upper_corner(
246 &problem.padding,
247 &problem.kernel_size,
248 &problem.dilation,
249 ),
250 channels_per_pixel: tile_size_k,
251 pixels_per_column: stage_m,
252 },
253 lhs.clone().into_data().into_tensor_arg(),
254 lhs_elem,
255 )
256 .with_elem_stride(elem_stride)
257 .with_swizzle(blueprint.swizzle_modes.lhs.into());
258
259 let rhs = TensorMapArg::new(
260 TiledArgs {
261 tile_size: stage_size_rhs,
262 },
263 rhs.clone().into_data().into_tensor_arg(),
264 dtypes.rhs_global,
265 )
266 .with_swizzle(blueprint.swizzle_modes.rhs.into());
267
268 let padded_channels = problem.padded_channels as u32;
269 let shape_k = problem.k as u32;
270
271 let stages_lhs = A::num_stages().lhs;
275 let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
276 let check_kernel = !shape_k.is_multiple_of(stages_size_k);
277 let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
278 let rhs_layout = WeightLayoutLaunch::from_args(
279 problem,
280 GlobalLayoutConfig {
281 check_row_bounds: false,
282 check_col_bounds: false,
283 matrix_layout: MatrixLayout::ColMajor,
284 },
285 );
286
287 let bias = bias
288 .map(|bias| ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ()));
289
290 let inputs = TensorMapInputsLaunch::new(
291 ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
292 ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
293 bias.into(),
294 ComptimeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
295 NoopLayoutLaunch::new(),
296 )),
297 );
298
299 let runtime_args = RuntimeArgsLaunch::new(
300 shape_k,
301 problem.channels as u32,
302 padded_channels,
303 problem.operation,
304 );
305
306 (inputs, runtime_args)
307 }
308}
309
/// Lower corner of the TMA im2col pixel box: the negated padding along each
/// spatial dimension.
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|&pad| -pad).collect()
}
313
/// Upper corner of the TMA im2col pixel box per spatial dimension:
/// `padding - (kernel_size - 1) * dilation` (the dilated kernel span offset
/// from the padded edge). Output length is the shortest of the three inputs.
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    let mut corners = Vec::with_capacity(padding.len());
    for ((&pad, &k), &dil) in padding.iter().zip(kernel_size).zip(dilation) {
        let kernel_span = (k - 1) as i32 * dil as i32;
        corners.push(pad - kernel_span);
    }
    corners
}