cubecl_matmul/components/global/
args.rs

use std::any::TypeId;

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded};
use cubecl_std::{
    CubeOption, CubeOptionArgs,
    tensor::{View, launch::ViewArg, layout::Coords3d},
};

use crate::{
    MatmulInputHandleRef,
    components::{
        self, MatmulIdent, MatmulLineSizes, MatmulProblem, MatmulSelection,
        batch::BatchConfig,
        global::{
            GlobalConfig,
            memory::{
                BatchedGlobalLayout, BatchedGlobalLayoutLaunch, BatchedGlobalScaleLayout,
                SimpleTmaGlobalLayout, SimpleTmaGlobalLayoutLaunch,
            },
        },
    },
};

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
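///
/// # Example
///
/// A minimal sketch of host-side usage, assuming `client` and the handle,
/// selection, problem, and config values are already in scope:
///
/// ```ignore
/// let input = <TensorInputs<f32, f32, f32> as ConcreteInputsFactory>::create(
///     &client, &lhs, &rhs, &selection, &problem, &line_sizes, config,
/// );
/// ```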
pub trait ConcreteInputsFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
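///
/// # Example
///
/// A minimal sketch, assuming the same host-side values as above plus an output
/// tensor handle `out`:
///
/// ```ignore
/// let output = <TensorOutput<f32> as ConcreteOutputFactory>::create(
///     &client, &out, &selection, &problem, &line_sizes, config,
/// );
/// ```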
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R>;
}

#[cube]
/// Arguments for the matrix multiplication algorithm.
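///
/// A sketch of how a kernel generic over `Args: MatmulArgs` is expected to use
/// this trait (the surrounding kernel names are assumptions):
///
/// ```ignore
/// let mut state = Args::init_state::<Lhs, Rhs, EO, G>(&input, &mut output, config);
/// let lhs = Args::view_lhs(&state); // read-only (batch, row, col) view
/// let rhs = Args::view_rhs(&state);
/// let acc = Args::view_acc(&state); // optional accumulator to load from
/// let out = Args::view_out(&mut state); // read-write output view
/// ```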
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;
    /// Type used for the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;
    /// Inner state that is used to create [tensor inputs](TensorInput) and
    /// [tensor outputs](TensorOutput).
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;

    /// Initialize the state.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] config: G,
    ) -> Self::State<Lhs, Rhs, EO>;

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        unexpanded!()
    }
    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        unexpanded!()
    }
    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        unexpanded!()
    }
    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        unexpanded!()
    }
}

#[derive(Clone, Copy)]
/// Identification of the [tensor input](TensorInput).
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorArgs;

#[derive(CubeLaunch, CubeType)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    /// The lhs tensor.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// The rhs tensor.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// The tensor for loading the accumulator, if present.
    pub acc: CubeOption<View<Line<Acc>, Coords3d>>,
}

/// Output representation for [TensorArgs]: a read-write 3D view over the output tensor.
pub type TensorOutput<EO> = View<Line<EO>, Coords3d, ReadWrite>;

impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
    for TensorInputs<Lhs, Rhs, Acc>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
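        // Build a batched 3D view for one input. Quantized handles pair the data
        // view with a scales view under the handle's quantization scheme.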
        let view = |handle: &'a MatmulInputHandleRef<'a, R>, ident, line_size| match handle {
            MatmulInputHandleRef::Normal(handle) => {
                let layout = BatchedGlobalLayoutLaunch::from_handle(
                    client,
                    handle,
                    problem,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                ViewArg::new::<BatchedGlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => {
                let (data_layout, scales_layout) = BatchedGlobalLayoutLaunch::from_quantized_handle(
                    client,
                    data,
                    scale,
                    shape,
                    problem,
                    **scheme,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                let data_view =
                    ViewArg::new::<BatchedGlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<BatchedGlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };

        TensorInputsLaunch::new(
            view(lhs, MatmulIdent::Lhs, line_sizes.lhs),
            view(rhs, MatmulIdent::Rhs, line_sizes.rhs),
            CubeOptionArgs::None,
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for View<Line<EG>, Coords3d, ReadWrite> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
        let layout = BatchedGlobalLayoutLaunch::from_handle(
            client,
            out,
            problem,
            line_sizes.out,
            config.global_memory_config(MatmulIdent::Out).into(),
        );
        ViewArg::new::<BatchedGlobalLayout>(out.as_array_arg(line_sizes.out), layout)
    }
}

#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = (
        View<Line<Lhs>, Coords3d>,
        View<Line<Rhs>, Coords3d>,
        CubeOption<View<Line<EO>, Coords3d>>,
        View<Line<EO>, Coords3d, ReadWrite>,
    );

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (input.lhs, input.rhs, input.acc, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.1
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.2
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.3
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs;

#[derive(CubeLaunch, CubeType)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    /// The lhs tensor.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// The rhs tensor.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// The tensor for loading the accumulator, if present.
    pub acc: CubeOption<View<Line<EO>, Coords3d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R::Server>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        _config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let stage_k = tiling_scheme.elements_in_stage_k();
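        // TMA box (tile) sizes, as [batch, rows, cols]: a full stage along the
        // outer dimension and a single tile along the contiguous one.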
        let stage_size_lhs = match problem.lhs_layout {
            components::MatrixLayout::RowMajor => {
                vec![1, stage_m, tiling_scheme.elements_in_tile_k()]
            }
            components::MatrixLayout::ColMajor => {
                vec![1, stage_k, tiling_scheme.elements_in_tile_m()]
            }
        };
        let stage_size_rhs = match problem.rhs_layout {
            components::MatrixLayout::RowMajor => {
                vec![1, stage_k, tiling_scheme.elements_in_tile_n()]
            }
            components::MatrixLayout::ColMajor => {
                vec![1, stage_n, tiling_scheme.elements_in_tile_k()]
            }
        };

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

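        // Flatten each handle to a 3D [batch, rows, cols] descriptor: the batch
        // size comes from the problem, the last two dims from the handle. Rank-2
        // tensors get a synthetic batch stride.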
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches[0],
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![1, lhs.strides[lhs_rank - 2], lhs.strides[lhs_rank - 1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches[0],
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![1, rhs.strides[rhs_rank - 2], rhs.strides[rhs_rank - 1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        // TMA assumes the innermost dimension is contiguous and doesn't take its stride,
        // so col-major tensors are mapped with transposed shape and strides. The tensor
        // metadata still has the logical layout.
        if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
            // Swap the row/col entries of the 3D descriptors built above.
            lhs_shape.swap(1, 2);
            lhs_strides.swap(1, 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
            rhs_shape.swap(1, 2);
            rhs_strides.swap(1, 2);
            rhs_transposed = true;
        }

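        // Pick the TMA prefetch granularity from the byte width of the
        // contiguous (innermost) tile dimension.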
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

        let prefetch_lhs = prefetch(stage_size_lhs[2] as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads it correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if TypeId::of::<Lhs>() == TypeId::of::<f32>() {
            tf32::as_type_native_unchecked()
        } else {
            Lhs::as_type_native_unchecked()
        };
        let rhs_elem = if TypeId::of::<Rhs>() == TypeId::of::<f32>() {
            tf32::as_type_native_unchecked()
        } else {
            Rhs::as_type_native_unchecked()
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_lhs,
            },
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            prefetch: prefetch_lhs,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            prefetch: prefetch_rhs,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
        };

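        // Wrap each tensor map in a 3D view over its logical (batches, rows, cols)
        // shape; the `transposed` flag lets the layout account for the swapped
        // mapping used for col-major tensors.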
        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0] as u32);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            CubeOptionArgs::None,
        )
    }
}

#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = (
        View<Line<Lhs>, Coords3d>,
        View<Line<Rhs>, Coords3d>,
        CubeOption<View<Line<EO>, Coords3d>>,
        View<Line<EO>, Coords3d, ReadWrite>,
    );

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (input.lhs, input.rhs, input.acc, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.1
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.2
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.3
    }
}