
cubek_convolution/kernels/forward/args.rs

use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
    zspace::{shape, strides},
};
use cubek_matmul::{
    components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
    definition::{Blueprint, MatmulElems, TilingBlueprint},
    launch::*,
    routines::Routine,
};
use cubek_std::launch::tma::remap_storage_for_tma;
use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
use enumset::EnumSet;

use crate::components::{
    ConvolutionParams, ConvolutionProblem,
    global::{
        args::{RuntimeArgs, RuntimeArgsLaunch},
        layout::{
            BiasLayout, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch,
            OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
            WeightLayoutLaunch,
        },
    },
};

/// Matmul argument families that can serve as the concrete (non-fused) entry point for
/// the convolution kernels, tying the matmul inputs and output to the factories below.
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Pad the problem's channel count and `k` dimension to the alignment required by
    /// this argument family.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}

impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        _blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
        // Pad the channel count to the hardware load width, measured in elements of the
        // lhs global type.
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
        let padded_channels = problem.channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
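        // Illustrative numbers (assumed, not from the source): an f16 lhs (16 bits)
        // with a 128-bit load width gives channel_align = 128 / 16 = 8, so
        // channels = 3 pads to 8, and a 3x3 kernel yields k = 3 * 3 * 8 = 72.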

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
    for TensorMapArgs<RuntimeArgs>
{
    fn adjust_problem<R: Runtime>(
        _client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        blueprint: &TilingBlueprint,
        _dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
        // The channel alignment depends on the lhs swizzle mode: a single tile along k
        // without swizzling, the full stage along k otherwise.
        let channel_align = match blueprint.swizzle_modes.lhs {
            SwizzleMode::None => blueprint.tiling_scheme.tile_size.k() as usize,
            _ => blueprint.tiling_scheme.elements_per_stage_along_k() as usize,
        };
        let padded_channels = problem.channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
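        // Illustrative numbers (assumed): with SwizzleMode::None and tile_size.k() = 16,
        // channels = 3 pads to 16 and a 3x3 kernel yields k = 3 * 3 * 16 = 144.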

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs
/// and output (not fused).
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}

/// Create the output runtime arguments for a matmul kernel that works on concrete inputs
/// and output (not fused).
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
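
// A rough (assumed) launch sequence for these factories, not taken from the source:
// the caller first lets the argument family pad the problem, then builds the launch
// views from the padded problem:
//
//     let problem = Args::adjust_problem(client, problem, &blueprint, &dtypes);
//     let (inputs, runtime_args) = Inputs::create(lhs, rhs, bias, &blueprint, &problem, &dtypes);
//     let output = Output::create(out, &blueprint, &problem, &dtypes);
//
// `Args`, `Inputs`, and `Output` stand in for a `ConcreteArgs` implementation and its
// associated input/output types; the real plumbing lives in the launch code.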

impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
{
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;
        let conv_params = ConvolutionParams::from_problem(problem);

        let layout_lhs = Im2colLayoutLaunch::from_args(
            problem,
            conv_params,
            blueprint.lhs_global_layout_config(),
        );
        let layout_rhs =
            WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());

        // Wrap each layout in an NHWC global layout, enabling only the bounds checks the
        // problem actually requires.
        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = NhwcLayoutLaunch::checked(checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = NhwcLayoutLaunch::checked(checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<LhsLayout>(lhs.into_data().into_tensor_arg(), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new_tensor::<RhsLayout>(rhs.into_data().into_tensor_arg(), layout_rhs),
            bias.as_ref()
                .map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
            bias.map(|bias| {
                ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ())
            })
            .into(),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            problem.k as u32,
            problem.channels as u32,
            padded_channels,
            conv_params.operation,
        );

        (inputs, runtime_args)
    }
}

impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::unchecked();
        let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<
    Lhs: CubePrimitive,
    Rhs: CubePrimitive,
    EO: CubePrimitive,
    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();

        // Mirrors the alignment choice in `adjust_problem`: one tile along k without
        // swizzling, the full stage along k otherwise.
        let tile_size_k = match blueprint.swizzle_modes.lhs {
            SwizzleMode::None => tiling_scheme.tile_size.k,
            _ => tiling_scheme.elements_per_stage_along_k(),
        };

        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_n as usize);
        stage_size_rhs.push(tile_size_k as usize);
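        // The resulting rhs box is `[stage_n, 1, ..., 1, tile_size_k]`: the full stage
        // along n and k, with unit extent in each spatial dimension.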

        let lhs_elem = remap_storage_for_tma(dtypes.lhs_stage);

        // Element strides for the im2col TMA descriptor: batch and channel keep stride 1,
        // while each spatial dimension uses the convolution stride.
        let mut elem_stride = strides![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.clone().into_data().into_tensor_arg(),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_swizzle(blueprint.swizzle_modes.lhs.into());

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            rhs.clone().into_data().into_tensor_arg(),
            dtypes.rhs_global,
        )
        .with_swizzle(blueprint.swizzle_modes.rhs.into());

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Im2col needs extra checking: when `k` runs out of bounds, it wraps around the
        // kernel positions and can load elements that are inside the tensor but not part
        // of the current kernel window. Other TMA layouts always land outside the shape
        // when any matrix dimension is out of bounds.
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
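        // Illustrative numbers (assumed): with shape_k = 144, one lhs stage, and 32
        // elements per stage along k, stages_size_k = 32; 144 is not a multiple of 32,
        // so the kernel-position check is enabled for the final, partial stage.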
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
        let rhs_layout = WeightLayoutLaunch::from_args(
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::ColMajor,
            },
        );

        let bias = bias
            .map(|bias| ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ()));

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
            bias.into(),
            ComptimeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            shape_k,
            problem.channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}

/// Lower corner of the im2col pixel box: the negated padding for each spatial dimension
/// (e.g. padding `[1, 2]` gives `[-1, -2]`).
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|padding| -*padding).collect()
}

/// Upper corner of the im2col pixel box: `padding - (kernel_size - 1) * dilation` per
/// spatial dimension (e.g. padding 1, kernel size 3, dilation 1 gives `1 - 2 = -1`).
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((padding, kernel_size), dilation)| {
            *padding - (*kernel_size - 1) as i32 * *dilation as i32
        })
        .collect()
}