cubek_convolution/components/global/args.rs

use cubecl::prelude::*;
use cubecl::std::{
    CubeOptionArgs, FastDivmod, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubek_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulLineSizes, MatmulSelection,
        global::{
            GlobalConfig,
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
        stage::StageConfig as _,
    },
};

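/// Runtime arguments for the convolution kernels: the implicit GEMM shape
/// (`shape_m`/`shape_n`/`shape_k`), the input channel count, and fast divmod helpers
/// for the padded channel count and the output shape.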
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    pub shape_m: u32,
    pub shape_n: u32,
    pub shape_k: u32,
    pub channels: u32,
    pub padded_channels: FastDivmod,
    pub shape_out: Sequence<FastDivmod>,
}

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
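        // Both operands are viewed through a chain of layouts: the global NHWC layout
        // composed with an im2col layout (lhs) or weight-matrix layout (rhs), so the
        // matmul kernel sees them as plain matrices.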
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

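        // Channel alignment in elements, derived from the hardware load width; the NHWC
        // layouts are told whether the channel count breaks this alignment.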
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();

        let layout_nhwc = |handle, line_size, check_spatial| {
            NhwcLayoutLaunch::from_handle(
                handle,
                line_size as u32,
                check_spatial,
                !problem.channels.is_multiple_of(channel_align),
            )
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.lhs_global_memory_config(),
            dtypes,
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.rhs_global_memory_config(),
            dtypes,
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

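        // Assemble the launch arguments: each view pairs the tensor data with its chained
        // layout, the extra per-operand layouts are no-ops, and the bias view is only
        // provided when a bias handle is present.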
        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
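        // The output is viewed through the global NHWC layout chained with `OutLayout`,
        // which maps matmul output coordinates back onto the NHWC output tensor.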
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false, false);
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;

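        // TMA tile shape for the weights: the n-stage size first, a 1 for every spatial
        // dimension, and the k tile size last.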
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims() as usize];
        stage_size_rhs.insert(0, stage_n);
        stage_size_rhs.push(tile_size_k);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if *dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            *dtypes.lhs_stage
        };
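        // Element strides for the im2col tensor map: batch and channels keep a stride of 1,
        // each spatial dimension uses its convolution stride.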
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

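        // The lhs tensor map uses the im2col format: the pixel box corners are derived from
        // padding, kernel size and dilation, with `tile_size_k` channels per pixel and
        // `stage_m` pixels per column.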
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

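        // The weights use a plain tiled tensor map with the stage-sized tile computed above.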
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            *dtypes.rhs_global,
        );

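        // The channel count is padded up to the k tile size; the TMA weight layout divides
        // by this padded count using a fast divmod.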
        let padded_channels = (problem.channels as u32)
            .next_multiple_of(config.matmul_config().stage_config().elements_in_tile_k());

        // Dummy layout, since im2col loading isn't supported yet
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(
            FastDivmodArgs::new(client, padded_channels),
            problem.kernel_size.clone(),
        );

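        // Optional bias view over the `n` dimension.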
        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

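        // Assemble the TMA launch arguments from the tensor maps, their layouts, and the
        // optional bias.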
        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}