cubecl_convolution/components/global/args.rs

use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    CubeOptionArgs, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulLineSizes, MatmulSelection,
        global::{
            GlobalConfig,
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
        stage::StageConfig as _,
    },
};

/// Create the input runtime arguments for a convolution kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime arguments for a convolution kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
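        // Both operands are read through a layout chain: the base NHWC layout composed with an
        // im2col view (lhs) or a weight view (rhs).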
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let layout_nhwc = |handle, line_size, check| {
            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.lhs_global_memory_config(),
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.rhs_global_memory_config(),
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

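        // Chain each operand-specific view onto the base NHWC layout of its tensor.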
        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

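        // Assemble the launch arguments: each operand view is paired with a `NoopLayout`, and
        // the optional bias view becomes a `CubeOption` argument.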
        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
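        // The output is written through the base NHWC layout chained with the output layout.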
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
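        // Stage dimensions for the TMA tensor maps are derived from the selected tiling scheme.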
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;
        let stage_size_rhs = vec![stage_n, 1, tile_size_k];

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

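        // Pick a TMA prefetch width from the number of contiguous bytes covered along the
        // innermost dimension of each load.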
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

        let prefetch_lhs = prefetch(tile_size_k as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

        // f32 is remapped to tf32 for the tensor map, just to ensure CUDA loads the values
        // correctly. It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

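        // Element strides for the im2col tensor map: batch and channel strides stay 1, while
        // the spatial dimensions use the convolution strides.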
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

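        // The lhs is loaded through the im2col tensor-map format; the pixel box corners are
        // derived from the padding, kernel size and dilation.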
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_prefetch(prefetch_lhs);

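        // The weights use a plain tiled tensor-map format with a stage-sized tile.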
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            dtypes.rhs_global,
        )
        .with_prefetch(prefetch_rhs);

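        // Pad the channel count up to the k tile size; the TMA weight layout divides by this
        // padded count.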
        let padded_channels = (problem.channels as u32)
            .next_multiple_of(config.matmul_config().stage_config().elements_in_tile_k());

        // Dummy layout, since im2col loading isn't supported yet.
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));

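        // The optional bias is viewed through a `BiasLayout` sized to the problem's n dimension.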
        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

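        // Assemble the launch arguments from the two tensor maps, the optional bias view, and a
        // no-op layout.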
        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}