// cubek_convolution/kernels/forward/args.rs
1use cubecl::{
2    Runtime,
3    client::ComputeClient,
4    prelude::*,
5    std::tensor::{
6        launch::ViewArg,
7        layout::{
8            VirtualLayoutLaunch,
9            chain::{Chain, ChainLaunch},
10        },
11    },
12    zspace::{shape, strides},
13};
14use cubek_matmul::{
15    components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16    definition::{Blueprint, MatmulElems, TilingBlueprint},
17    launch::*,
18    routines::Routine,
19};
20use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
21use enumset::EnumSet;
22
23use crate::components::{
24    ConvolutionParams, ConvolutionProblem,
25    global::{
26        args::{RuntimeArgs, RuntimeArgsLaunch},
27        layout::{
28            BiasLayout, Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch,
29            OutLayout, OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
30            WeightLayoutLaunch,
31        },
32    },
33};
34
/// A [`MatmulArgs`] flavour whose inputs and output can be built from concrete
/// (non-fused) tensor bindings for a convolution, and whose comptime config is
/// [`RuntimeArgs`].
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Adjusts the convolution problem before launch for this argument type: both
    /// implementations below pad the channel count to an alignment derived from the
    /// hardware/blueprint and recompute the reduction size `k` accordingly.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        blueprint: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
49
50impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
51    fn adjust_problem<R: Runtime>(
52        client: &ComputeClient<R>,
53        mut problem: ConvolutionProblem,
54        _blueprint: &A::Blueprint,
55        dtypes: &MatmulElems,
56    ) -> ConvolutionProblem {
57        let load_width = client.properties().hardware.load_width;
58        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
59        let padded_channels = problem.channels.next_multiple_of(channel_align);
60        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
61
62        problem.k = shape_k;
63        problem.padded_channels = padded_channels;
64
65        problem
66    }
67}
68
69impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
70    for TensorMapArgs<RuntimeArgs>
71{
72    fn adjust_problem<R: Runtime>(
73        _client: &ComputeClient<R>,
74        mut problem: ConvolutionProblem,
75        blueprint: &TilingBlueprint,
76        _dtypes: &MatmulElems,
77    ) -> ConvolutionProblem {
78        let channel_align = match blueprint.swizzle_modes.lhs {
79            SwizzleMode::None => blueprint.tiling_scheme.tile_size.k() as usize,
80            _ => blueprint.tiling_scheme.elements_per_stage_along_k() as usize,
81        };
82        let padded_channels = problem.channels.next_multiple_of(channel_align);
83        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
84
85        problem.k = shape_k;
86        problem.padded_channels = padded_channels;
87
88        problem
89    }
90}
91
/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch-time input views (lhs, rhs, optional bias) together with the
    /// [`RuntimeArgsLaunch`] carrying the scalar runtime arguments for this problem.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
105
/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch-time output view for the convolution result tensor.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
116
117impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
118    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
119{
120    fn create<R: Runtime>(
121        lhs: InputBinding<R>,
122        rhs: InputBinding<R>,
123        bias: Option<InputBinding<R>>,
124        blueprint: &A::Blueprint,
125        problem: &ConvolutionProblem,
126        _dtypes: &MatmulElems,
127    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
128        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
129        type RhsLayout = Chain<NhwcLayout, WeightLayout>;
130
131        let padded_channels = problem.padded_channels as u32;
132        let conv_params = ConvolutionParams::from_problem(problem);
133
134        let layout_lhs = Im2colLayoutLaunch::from_args(
135            problem,
136            conv_params,
137            blueprint.lhs_global_layout_config(),
138        );
139        let layout_rhs =
140            WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());
141
142        let layout_lhs = {
143            let mut checks = EnumSet::empty();
144            if problem.should_check_spatial_bounds() {
145                checks.insert(NhwcCheck::Spatial);
146            }
147            if problem.should_check_channel() {
148                checks.insert(NhwcCheck::Channel);
149            }
150            let global = NhwcLayoutLaunch::checked(checks);
151            ChainLaunch::new(global, layout_lhs)
152        };
153        let layout_rhs = {
154            let mut checks = EnumSet::empty();
155            if problem.should_check_channel() {
156                checks.insert(NhwcCheck::Channel);
157            }
158            let global = NhwcLayoutLaunch::checked(checks);
159            ChainLaunch::new(global, layout_rhs)
160        };
161
162        let inputs = TensorInputsLaunch::new(
163            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
164            ViewArg::new_tensor::<LhsLayout>(lhs.into_data().into_tensor_arg(), layout_lhs),
165            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
166            ViewArg::new_tensor::<RhsLayout>(rhs.into_data().into_tensor_arg(), layout_rhs),
167            bias.as_ref()
168                .map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
169                .into(),
170            bias.map(|bias| {
171                ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ())
172            })
173            .into(),
174        );
175
176        let runtime_args = RuntimeArgsLaunch::new(
177            problem.k as u32,
178            problem.channels as u32,
179            padded_channels,
180            conv_params.operation,
181        );
182
183        (inputs, runtime_args)
184    }
185}
186
187impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
188    fn create<R: Runtime>(
189        out: TensorBinding<R>,
190        blueprint: &A::Blueprint,
191        problem: &ConvolutionProblem,
192        _dtypes: &MatmulElems,
193    ) -> Self::RuntimeArg<R> {
194        type Layout = Chain<NhwcLayout, OutLayout>;
195
196        let global = NhwcLayoutLaunch::unchecked();
197        let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
198        let layout = ChainLaunch::new(global, layout);
199        let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
200        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
201        TensorOutputLaunch::new(view, batch)
202    }
203}
204
/// TMA (tensor-map) implementation: encodes lhs as an im2col tensor-map descriptor and
/// rhs as a tiled tensor-map descriptor, so the hardware performs the gather.
impl<
    Lhs: CubePrimitive,
    Rhs: CubePrimitive,
    EO: CubePrimitive,
    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        bias: Option<InputBinding<R>>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();

        // Same rule as `adjust_problem` for `TensorMapArgs`: with lhs swizzling, one
        // descriptor load covers the whole stage along k instead of a single tile.
        // NOTE(review): this accesses `tile_size.k` as a field while the impl above
        // calls `tile_size.k()` — confirm both spellings are intended.
        let tile_size_k = match blueprint.swizzle_modes.lhs {
            SwizzleMode::None => tiling_scheme.tile_size.k,
            _ => tiling_scheme.elements_per_stage_along_k(),
        };

        // Rhs tile shape ends up as `[stage_n, 1, ..., 1, tile_size_k]`, with one `1`
        // per spatial dimension of the problem.
        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_n as usize);
        stage_size_rhs.push(tile_size_k as usize);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked().storage_type() {
            tf32::as_type_native_unchecked().storage_type()
        } else {
            dtypes.lhs_stage
        };

        // Element strides: 1 everywhere, except the spatial dims (indices 1..) which
        // carry the convolution stride.
        let mut elem_stride = strides![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        // Im2col descriptor: the pixel box corners describe how far the receptive
        // field extends past the input on each side (see helpers at the bottom).
        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.clone().into_data().into_tensor_arg(),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_swizzle(blueprint.swizzle_modes.lhs.into());

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            rhs.clone().into_data().into_tensor_arg(),
            dtypes.rhs_global,
        )
        .with_swizzle(blueprint.swizzle_modes.rhs.into());

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Im2col needs extra checking because if `k` is OOB it wraps around the kernel and can load
        // in-bounds but not in-kernel elements. Other TMA layouts are always outside the shape if
        // any matrix dim is out of bounds.
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
        // Rhs bounds are handled by the TMA descriptor itself, so the weight layout
        // runs unchecked, in column-major order.
        let rhs_layout = WeightLayoutLaunch::from_args(
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::ColMajor,
            },
        );

        // Bias (if any) is a flat per-channel view with a trivial layout.
        let bias = bias
            .map(|bias| ViewArg::new_tensor::<BiasLayout>(bias.into_data().into_tensor_arg(), ()));

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<TmaIm2colLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<WeightLayout>(rhs, rhs_layout),
            bias.into(),
            ComptimeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        );

        let runtime_args = RuntimeArgsLaunch::new(
            shape_k,
            problem.channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
314
/// Computes the lower corner of the TMA im2col pixel box: the negated padding for
/// each spatial dimension.
fn calculate_lower_corner(padding: &[i32]) -> Vec<i32> {
    let mut corner = Vec::with_capacity(padding.len());
    for &pad in padding {
        corner.push(-pad);
    }
    corner
}
318
/// Computes the upper corner of the TMA im2col pixel box.
///
/// For each spatial dimension this is `padding - (kernel_size - 1) * dilation`: how far
/// the receptive field of the last kernel tap reaches past the input on the high side.
///
/// The `kernel_size - 1` term is computed in `i32` so a (degenerate) kernel size of 0
/// cannot underflow `u32` — which would panic in debug builds and silently wrap in
/// release. For all valid kernel sizes (>= 1) the result is unchanged.
fn calculate_upper_corner(padding: &[i32], kernel_size: &[u32], dilation: &[u32]) -> Vec<i32> {
    padding
        .iter()
        .zip(kernel_size)
        .zip(dilation)
        .map(|((&padding, &kernel_size), &dilation)| {
            padding - (kernel_size as i32 - 1) * dilation as i32
        })
        .collect()
}