cubecl_convolution/components/global/args.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_std::{
    CubeOptionArgs, FastDivmod, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulLineSizes, MatmulSelection,
        global::{
            GlobalConfig,
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
        stage::StageConfig as _,
    },
};
/// Runtime arguments (problem shapes and divisors) passed to the global convolution kernels.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    /// Size of the GEMM `m` dimension.
    pub shape_m: u32,
    /// Size of the GEMM `n` dimension.
    pub shape_n: u32,
    /// Size of the GEMM `k` dimension.
    pub shape_k: u32,
    /// Number of input channels.
    pub channels: u32,
    /// Channel count padded for tiling, wrapped for fast division/modulo.
    pub padded_channels: FastDivmod,
    /// Output spatial shape, one fast divmod divisor per spatial dimension.
    pub shape_out: Sequence<FastDivmod>,
}
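// Host-side construction sketch (assumption: as with the other `CubeLaunch` types in this
// module, the derive generates a `RuntimeArgsLaunch` struct with a field-wise `new`; the
// variable names below are illustrative only):
//
//     let runtime_args = RuntimeArgsLaunch::new(
//         ScalarArg::new(shape_m),
//         ScalarArg::new(shape_n),
//         ScalarArg::new(shape_k),
//         ScalarArg::new(channels),
//         FastDivmodArgs::new(client, padded_channels),
//         shape_out, // assumed to be a `SequenceArg` of `FastDivmodArgs`, one per spatial dim
//     );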
/// Create the input runtime arguments for a convolution kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
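// Usage sketch (hypothetical caller, for illustration only; the real wiring lives in the
// launch code): building the concrete input argument might look roughly like
//
//     let inputs = <TensorInputs<Lhs, Rhs, EO> as ConcreteInputsFactory>::create(
//         client, &lhs, &rhs, bias, &selection, &problem, &line_sizes, config, &dtypes,
//     );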
/// Create the output runtime arguments for a convolution kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
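// Usage sketch (hypothetical caller, as above): the matching concrete output argument
// would be built with
//
//     let output = <TensorOutput<EO> as ConcreteOutputFactory>::create(
//         client, &out, &selection, &problem, &line_sizes, config, &dtypes,
//     );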
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        // Each input is viewed through a layout chain: the global NHWC layout composed with the
        // im2col (lhs) or weight (rhs) mapping into GEMM coordinates.
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();

        let layout_nhwc = |handle, line_size, check_spatial| {
            NhwcLayoutLaunch::from_handle(
                handle,
                line_size as u32,
                check_spatial,
                // Channel bounds need checking when the channel count isn't aligned to the
                // hardware load width.
                !problem.channels.is_multiple_of(channel_align),
            )
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.lhs_global_memory_config(),
            dtypes,
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.rhs_global_memory_config(),
            dtypes,
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

        // The batch layouts are no-ops; the bias view is only present when a bias is provided.
        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}
impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        // The output is viewed through the global NHWC layout composed with the mapping from
        // GEMM coordinates back to output positions.
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false, false);
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;

        // TMA tile shape for the weights: the full stage along `n` first, one unit dim per
        // kernel spatial dimension, then the tile size along `k` last.
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims() as usize];
        stage_size_rhs.insert(0, stage_n);
        stage_size_rhs.push(tile_size_k);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads it correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if *dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            *dtypes.lhs_stage
        };

        // Element strides for the im2col tensor map: batch and channel strides stay at 1, the
        // spatial dimensions use the convolution strides.
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            *dtypes.rhs_global,
        );

        let padded_channels = (problem.channels as u32)
            .next_multiple_of(config.matmul_config().stage_config().elements_in_tile_k());

        // Dummy layout, since im2col loading isn't supported yet.
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(
            FastDivmodArgs::new(client, padded_channels),
            problem.kernel_size.clone(),
        );

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}