cubecl_matmul/components/tile/mma/
matmul.rs

1use std::marker::PhantomData;
2
3use crate::components::tile::io::{Strided, TileKind};
4use crate::components::tile::{
5    TileConfig, TileMatmul,
6    mma::{reader::MmaStageReader, writer::MmaStageWriter},
7};
8use crate::components::tile::{mma::reader::MmaFragmentReader, tile_data::StridedTile};
9use crate::components::{StageIdent, tile::mma::config::MmaMatmulConfig};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, cmma::MmaDefinition, ir::MatrixIdent};
12
13/// Uses one plane to perform a small matmul using accelerated instructions, with manual register
14/// management.
15/// Currently requires matrix layout to match the platform's preferred layout.
16pub struct MmaMatmul<Acc: TileKind> {
17    _ty: PhantomData<Acc>,
18}
19
20#[cube]
21impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
22    for MmaMatmul<AccTile>
23where
24    MmaStageReader<AccTile>: MmaFragmentReader<TileKind = AccTile>,
25{
26    type Config = MmaMatmulConfig;
27    type LhsFragment = Sequence<Line<L>>;
28    type RhsFragment = Sequence<Line<R>>;
29    type AccFragment = Sequence<Line<A>>;
30
31    type LhsTile = Strided;
32    type RhsTile = Strided;
33    type AccTile = AccTile;
34    type OutTile = Strided;
35
36    fn execute(
37        lhs: &Self::LhsFragment,
38        rhs: &Self::RhsFragment,
39        out: &mut Self::AccFragment,
40        #[comptime] config: Self::Config,
41    ) {
42        let def = mma_definition(config);
43        let out_arr = def.execute(lhs, rhs, out);
44        let num_lines = def.lines_per_lane(MatrixIdent::Accumulator);
45
46        #[unroll]
47        for i in 0..num_lines {
48            *out.index_mut(i) = out_arr[i];
49        }
50    }
51
52    fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment {
53        let def = mma_definition::<L, R, A>(config);
54        let line_size = def.line_size(MatrixIdent::A);
55        let mut frag = Sequence::new();
56        #[unroll]
57        for _ in 0..def.lines_per_lane(MatrixIdent::A) {
58            // Needs to be mut because sequence is dodgy
59            #[allow(unused_mut)]
60            let mut reg = Line::empty(line_size);
61            frag.push(reg);
62        }
63        frag
64    }
65
66    fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment {
67        let def = mma_definition::<L, R, A>(config);
68        let line_size = def.line_size(MatrixIdent::B);
69        let mut frag = Sequence::new();
70        #[unroll]
71        for _ in 0..def.lines_per_lane(MatrixIdent::B) {
72            // Needs to be mut because sequence is dodgy
73            #[allow(unused_mut)]
74            let mut reg = Line::empty(line_size);
75            frag.push(reg);
76        }
77        frag
78    }
79
80    fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {
81        let def = mma_definition::<L, R, A>(config);
82        let line_size = def.line_size(MatrixIdent::Accumulator);
83        let mut frag = Sequence::new();
84        #[unroll]
85        for _ in 0..def.lines_per_lane(MatrixIdent::Accumulator) {
86            // Needs to be mut because sequence is dodgy
87            #[allow(unused_mut)]
88            let mut reg = Line::empty(line_size);
89            frag.push(reg);
90        }
91        frag
92    }
93
94    fn load_lhs<E: Numeric>(
95        tile: &StridedTile<E>,
96        lhs: &mut Self::LhsFragment,
97        #[comptime] config: Self::Config,
98    ) {
99        MmaStageReader::<Self::LhsTile>::load_fragment(
100            tile,
101            lhs,
102            mma_definition::<L, R, A>(config),
103            MatrixIdent::A,
104            config.matrix_layout(StageIdent::Lhs),
105        );
106    }
107
108    fn load_rhs<E: Numeric>(
109        tile: &StridedTile<E>,
110        rhs: &mut Self::RhsFragment,
111        #[comptime] config: Self::Config,
112    ) {
113        MmaStageReader::<Self::LhsTile>::load_fragment(
114            tile,
115            rhs,
116            mma_definition::<L, R, A>(config),
117            MatrixIdent::B,
118            config.matrix_layout(StageIdent::Rhs),
119        );
120    }
121
122    fn load_acc<E: Numeric>(
123        tile: &AccTile::Tile<E>,
124        acc: &mut Self::AccFragment,
125        #[comptime] config: Self::Config,
126    ) {
127        MmaStageReader::<Self::AccTile>::load_fragment(
128            tile,
129            acc,
130            mma_definition::<L, R, A>(config),
131            MatrixIdent::Accumulator,
132            config.matrix_layout(StageIdent::Acc),
133        );
134    }
135
136    fn write_results<E: Numeric>(
137        tile: &mut StridedTile<E, ReadWrite>,
138        out: &Self::AccFragment,
139        #[comptime] config: Self::Config,
140    ) {
141        MmaStageWriter::store_fragment(
142            tile,
143            out,
144            mma_definition::<L, R, A>(config),
145            MatrixIdent::Accumulator,
146            config.matrix_layout(StageIdent::Out),
147        );
148    }
149}
150
151#[cube]
152pub(super) fn mma_definition<L: Numeric, R: Numeric, A: Numeric>(
153    #[comptime] config: MmaMatmulConfig,
154) -> MmaDefinition<L, R, A> {
155    let size = config.tile_size();
156    MmaDefinition::new(size.m(), size.n(), size.k())
157}