cubecl_convolution/components/global/args.rs

use std::any::TypeId;

use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    FastDivmodArgs,
    tensor::{
        View,
        launch::ViewArg,
        layout::{
            Coords3d,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulIdent, MatmulLineSizes, MatmulSelection,
        global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
    },
};

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}
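
// Example (sketch): a launch routine would typically build both runtime
// arguments through these factories before dispatching the kernel. The names
// `launch_conv`, `Input`, and `Output` below are hypothetical and only
// illustrate the intended call pattern:
//
//     let inputs = <Input as ConcreteInputsFactory>::create(
//         client, &lhs, &rhs, bias, &selection, &problem, &line_sizes, config,
//     );
//     let output = <Output as ConcreteOutputFactory>::create(
//         client, &out, &selection, &problem, &line_sizes, config,
//     );
//     launch_conv::<Input, Output, R>(client, cube_count, cube_dim, inputs, output);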

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        // Each operand is viewed through a two-stage layout: a base NHWC
        // linearization chained with the GEMM-facing view (im2col for the
        // activations, a flattened weight matrix for the filters).
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let layout_nhwc = |handle, line_size, check| {
            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Lhs),
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Rhs),
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        // Bounds checking is only needed on the activation side, where padding
        // can push the im2col window outside the spatial extent of the input.
        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for View<Line<EG>, Coords3d, ReadWrite> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        // Same two-stage chaining as the inputs: an NHWC linearization
        // composed with the GEMM output layout, without bounds checking.
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
        let layout = OutLayoutLaunch::from_args(
            client,
            problem,
            config.global_memory_config(MatmulIdent::Out),
        );
        let layout = ChainLaunch::new(global, layout);
        ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let tile_size_k = tiling_scheme.elements_in_tile_k();
        let stage_size_rhs = vec![stage_n, 1, tile_size_k];

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

        // Select the widest TMA prefetch size that fits within one tile row's
        // worth of bytes.
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }
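        // For example (a sketch of the arithmetic): with f16 inputs and
        // tile_size_k = 32, one tile row spans 32 * 2 = 64 bytes, which
        // selects TensorMapPrefetch::B64.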

        let prefetch_lhs = prefetch(tile_size_k as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

        // f32 gets remapped to tf32 for the tensor map, just to ensure CUDA
        // loads it correctly. It shouldn't matter, but it's better to be safe.
        let lhs_elem = if TypeId::of::<Lhs>() == TypeId::of::<f32>() {
            tf32::as_type_native_unchecked()
        } else {
            Lhs::as_type_native_unchecked()
        };

        // Element strides for the im2col tensor map: every dimension defaults
        // to dense, then each spatial dimension is set to its convolution
        // stride.
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }
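        // For a 2D convolution with strides [s_h, s_w], this yields
        // elem_stride = [1, s_h, s_w, 1]: batch (first) and channels (last)
        // stay dense while the spatial dimensions step by the stride.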

        // The activations use the hardware im2col tensor map format; the pixel
        // box corners describe the window offsets induced by padding, kernel
        // size, and dilation.
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_prefetch(prefetch_lhs);

        // The weights use a plain tiled tensor map, one stage tile at a time.
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            Rhs::as_type_native_unchecked(),
        )
        .with_prefetch(prefetch_rhs);

        let padded_channels =
            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());
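        // Channels are padded to a whole number of k-tiles so that the weight
        // layout's divmod below always lands on tile-aligned boundaries.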

        // Dummy layout, since im2col loading isn't supported yet.
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
        )
    }
}