cubek_convolution/kernels/forward/args.rs

use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{
        CubeOptionArgs, FastDivmodArgs,
        tensor::{
            launch::ViewArg,
            layout::{
                VirtualLayoutLaunch,
                chain::{Chain, ChainLaunch},
            },
        },
    },
};
use cubek_matmul::{
    components::{
        global::{
            GlobalConfig as _,
            memory::{GlobalMemoryConfig, NoopLayout, NoopLayoutLaunch, ViewDirection},
        },
        stage::StageConfig as _,
    },
    definition::{MatmulElems, MatmulLineSizes, MatrixLayout, TilingBlueprint},
    launch::{
        MatmulArgs, MatmulInputHandleRef, TensorArgs, TensorInputs, TensorInputsLaunch,
        TensorMapArgs, TensorMapInputs, TensorMapInputsLaunch, TensorOutput, TensorOutputLaunch,
    },
};
use enumset::EnumSet;

use crate::components::{
    ConvGemmConfig, ConvolutionParams, ConvolutionProblem,
    global::{
        args::RuntimeArgsLaunch,
        layout::{
            BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout,
            NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch,
            WeightLayout, WeightLayoutLaunch,
        },
    },
};

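/// [`MatmulArgs`] whose inputs and output can be built from concrete tensor handles, with a
/// hook to adjust the convolution problem before launch: channels are padded for alignment
/// and the GEMM `k` becomes the kernel volume times the padded channel count (e.g. a 3x3
/// kernel with channels padded to 4 gives k = 9 * 4 = 36).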
pub trait ConcreteArgs:
    MatmulArgs<
        Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>: ConcreteInputsFactory,
        Output<NumericExpand<2>>: ConcreteOutputFactory,
    >
{
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}

impl ConcreteArgs for TensorArgs {
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        _selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
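        // Pad the channel count to the hardware load width so each im2col row can be read
        // with full-width loads, e.g. 128-bit loads over 16-bit elements align to 8 channels.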
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
        let padded_channels = problem.channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

impl ConcreteArgs for TensorMapArgs {
    fn adjust_problem<R: Runtime>(
        _client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        _dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
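        // Channels are padded to the tile size along k, since the tensor map loads whole
        // k-tiles rather than individual load-width lines.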
        let channel_align = selection.tiling_scheme.tile_size.k() as usize;
        let padded_channels = problem.channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a MatmulInputHandleRef<'a, R>>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>);
}

/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a MatmulInputHandleRef<'a, R>>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
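        // Lhs is viewed through an im2col transform over the NHWC input, and rhs flattens
        // the weights into a k x n matrix, so the matmul kernel only ever sees 2D views.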
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;

        let layout_nhwc =
            |handle, line_size, checks| NhwcLayoutLaunch::from_handle(handle, line_size, checks);
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.params(),
            config.lhs_global_memory_config(),
        );
        let layout_rhs =
            WeightLayoutLaunch::from_args(client, problem, config.rhs_global_memory_config());
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.data().as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        );

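        // Scalar arguments for the kernel: the GEMM k extent, the unpadded channel count, a
        // fast divmod over the padded channels, and the configured operation.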
        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(problem.k as u32),
            ScalarArg::new(problem.channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
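        // The NHWC output is viewed as an (m, n) matrix: rows are output pixels, columns
        // are output channels.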
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out, EnumSet::empty());
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a MatmulInputHandleRef<'a, R>>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;

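        // Rhs tile shape for the tensor map: [stage_n, 1, ..., 1, tile_k], with a unit
        // extent for each kernel spatial dimension in between.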
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_n);
        stage_size_rhs.push(tile_size_k);

        // f32 is remapped to tf32 for the tensor map to ensure CUDA loads the values
        // correctly. It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

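        // Element strides for the im2col tensor map: unit stride for batch and channels,
        // and the convolution stride along each spatial dimension.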
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

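        // Precompute fast divmod constants for each output spatial size so the kernel can
        // split a linear output index into coordinates without expensive division.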
        let shape_out = problem
            .out_shape
            .iter()
            .map(|it| FastDivmodArgs::<u32>::new(client, *it as u32))
            .collect();

        // Im2col needs an extra check: an out-of-bounds `k` wraps around the kernel and can
        // load elements that are in bounds for the tensor but outside the kernel window.
        // Other TMA layouts always land outside the tensor shape when any matrix dim is out
        // of bounds.
        let stages_lhs = config.stage_config().lhs_smem_config().num_stages;
        let stages_size_k = selection.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let lhs_layout = TmaIm2colLayoutLaunch::new(
            shape_out,
            FastDivmodArgs::<u32>::new(client, padded_channels),
            ConvolutionParams::from_problem(problem),
            !shape_k.is_multiple_of(stages_size_k),
        );
        let rhs_layout = WeightLayoutLaunch::from_args(
            client,
            problem,
            GlobalMemoryConfig {
                line_size: line_sizes.rhs,
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
                view_direction: ViewDirection::default(),
                dtype: dtypes.rhs_global,
            },
        );

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.data().as_array_arg(line_sizes.out), layout)
        });

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(shape_k),
            ScalarArg::new(problem.channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}

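/// Lower corner of the TMA im2col pixel box: the negated padding along each spatial
/// dimension.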
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    padding.iter().map(|padding| -*padding).collect()
}

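/// Upper corner of the TMA im2col pixel box: `padding - (kernel_size - 1) * dilation` per
/// spatial dimension, e.g. padding 1, kernel 3, dilation 1 gives 1 - 2 = -1.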
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((padding, kernel_size), dilation)| {
            *padding - (*kernel_size - 1) as i32 * *dilation as i32
        })
        .collect()
}