cubecl_convolution/components/global/args.rs

use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    CubeOptionArgs, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulIdent, MatmulLineSizes, MatmulSelection,
        global::{
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
    },
};

/// Create the input runtime arguments for a matmul kernel that works on concrete input and
/// output tensors (i.e., not fused).
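///
/// # Example
///
/// A minimal sketch of a call site, assuming `client`, `lhs`, `rhs`, `bias`,
/// `selection`, `problem`, `line_sizes`, `config`, and `dtypes` already exist in
/// scope (the variable names here are illustrative, not part of this API):
///
/// ```ignore
/// let inputs = <TensorInputs<f32, f32, f32> as ConcreteInputsFactory>::create(
///     &client, &lhs, &rhs, bias, &selection, &problem, &line_sizes, config, &dtypes,
/// );
/// ```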
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime arguments for a matmul kernel that works on concrete input and
/// output tensors (i.e., not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
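        // Both operands are read through a layout chain: `NhwcLayout` addresses the
        // raw NHWC tensor, and the second stage views it as a GEMM operand
        // (`Im2colLayout` for the input, `WeightLayout` for the weights), so no
        // im2col matrix is materialized in memory.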
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let layout_nhwc = |handle, line_size, check| {
            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Lhs),
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Rhs),
        );
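        // The bias spans the GEMM `n` dimension (one value per output channel).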
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
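        // Same chaining as the inputs: the NHWC output tensor is viewed as the GEMM
        // result matrix through `OutLayout`.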
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
        let layout = OutLayoutLaunch::from_args(
            client,
            problem,
            config.global_memory_config(MatmulIdent::Out),
        );
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let tile_size_k = tiling_scheme.elements_in_tile_k();
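        // TMA load box for the weights: a full stage along `n` and a single `k` tile,
        // with the middle dimension loaded one element wide.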
        let stage_size_rhs = vec![stage_n, 1, tile_size_k];

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

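        // Select the widest TMA prefetch that fits within one contiguous row of the
        // box, rounding down to the supported 64/128/256-byte granularities.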
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

        let prefetch_lhs = prefetch(tile_size_k as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

        // f32 gets remapped to tf32 for the tensor map, just to ensure CUDA loads it
        // correctly. It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

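        // Element strides for the im2col tensor map: the batch and channel dimensions
        // advance one element at a time, while each spatial dimension advances by the
        // convolution stride.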
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

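        // The im2col pixel box bounds the receptive field around each output position;
        // its corners are derived from the padding, kernel size, and dilation.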
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_prefetch(prefetch_lhs);

        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            dtypes.rhs_global,
        )
        .with_prefetch(prefetch_rhs);

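        // Pad the channel count up to a multiple of the `k` tile size; the weight
        // layout uses this padded count for its divmod arithmetic.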
        let padded_channels =
            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());

        // Dummy layout, since im2col loading isn't supported yet.
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}