cubecl_matmul/components/tile/mma/
matmul.rs

1use std::marker::PhantomData;
2
3use crate::components::MatrixLayout;
4use crate::components::tile::io::{Filled, Strided, TileKind};
5use crate::components::tile::mma::config::MmaMatmulConfig;
6use crate::components::tile::{
7    TileMatmul,
8    mma::{reader::MmaStageReader, writer::MmaStageWriter},
9};
10use crate::components::tile::{mma::reader::MmaFragmentReader, tile_data::StridedTile};
11use cubecl_core::prelude::*;
12use cubecl_core::{self as cubecl, cmma::MmaDefinition, ir::MatrixIdent};
13
14/// Uses one plane to perform a small matmul using accelerated instructions, with manual register
15/// management.
16/// Currently requires matrix layout to match the platform's preferred layout.
17pub struct MmaMatmul<Lhs: TileKind = Strided, Rhs: TileKind = Strided, Acc: TileKind = Filled> {
18    _ty: PhantomData<(Lhs, Rhs, Acc)>,
19}
20
21#[derive(CubeType)]
22pub struct MmaFragment<E: Numeric> {
23    fragment: Array<Line<E>>,
24    #[cube(comptime)]
25    layout: MatrixLayout,
26}
27
28#[cube]
29impl<L: Numeric, R: Numeric, A: Numeric, LhsTile: TileKind, RhsTile: TileKind, AccTile: TileKind>
30    TileMatmul<L, R, A> for MmaMatmul<LhsTile, RhsTile, AccTile>
31where
32    MmaStageReader<LhsTile>: MmaFragmentReader<TileKind = LhsTile>,
33    MmaStageReader<RhsTile>: MmaFragmentReader<TileKind = RhsTile>,
34    MmaStageReader<AccTile>: MmaFragmentReader<TileKind = AccTile>,
35{
36    type Config = MmaMatmulConfig;
37
38    type LhsFragment = MmaFragment<L>;
39    type RhsFragment = MmaFragment<R>;
40    type AccFragment = MmaFragment<A>;
41
42    type LhsTile = LhsTile;
43    type RhsTile = RhsTile;
44    type AccTile = AccTile;
45    type OutTile = Strided;
46
47    fn execute(
48        lhs: &Self::LhsFragment,
49        rhs: &Self::RhsFragment,
50        out: &mut Self::AccFragment,
51        #[comptime] config: Self::Config,
52    ) {
53        let def = mma_definition(config);
54        let out_arr = def.execute(&lhs.fragment, &rhs.fragment, &out.fragment);
55        let num_lines = def.lines_per_lane(MatrixIdent::Accumulator);
56
57        #[unroll]
58        for i in 0..num_lines {
59            out.fragment[i] = out_arr[i];
60        }
61    }
62
63    fn allocate_lhs(
64        #[comptime] layout: MatrixLayout,
65        #[comptime] config: Self::Config,
66    ) -> Self::LhsFragment {
67        let def = mma_definition::<L, R, A>(config);
68        let line_size = def.line_size(MatrixIdent::A);
69        let line_count = def.lines_per_lane(MatrixIdent::A);
70
71        MmaFragment::<L> {
72            fragment: Array::vectorized(line_count, line_size),
73            layout,
74        }
75    }
76
77    fn allocate_rhs(
78        #[comptime] layout: MatrixLayout,
79        #[comptime] config: Self::Config,
80    ) -> Self::RhsFragment {
81        let def = mma_definition::<L, R, A>(config);
82        let line_size = def.line_size(MatrixIdent::B);
83        let line_count = def.lines_per_lane(MatrixIdent::B);
84
85        MmaFragment::<R> {
86            fragment: Array::vectorized(line_count, line_size),
87            layout,
88        }
89    }
90
91    fn allocate_acc(
92        #[comptime] layout: MatrixLayout,
93        #[comptime] config: Self::Config,
94    ) -> Self::AccFragment {
95        let def = mma_definition::<L, R, A>(config);
96        let line_size = def.line_size(MatrixIdent::Accumulator);
97        let line_count = def.lines_per_lane(MatrixIdent::Accumulator);
98
99        MmaFragment::<A> {
100            fragment: Array::vectorized(line_count, line_size),
101            layout,
102        }
103    }
104
105    fn load_lhs<E: Numeric>(
106        tile: &LhsTile::Tile<E>,
107        lhs: &mut Self::LhsFragment,
108        #[comptime] config: Self::Config,
109    ) {
110        MmaStageReader::<Self::LhsTile>::load_fragment(
111            tile,
112            &mut lhs.fragment,
113            mma_definition::<L, R, A>(config),
114            MatrixIdent::A,
115            lhs.layout,
116            config,
117        );
118    }
119
120    fn load_rhs<E: Numeric>(
121        tile: &RhsTile::Tile<E>,
122        rhs: &mut Self::RhsFragment,
123        #[comptime] config: Self::Config,
124    ) {
125        MmaStageReader::<Self::RhsTile>::load_fragment(
126            tile,
127            &mut rhs.fragment,
128            mma_definition::<L, R, A>(config),
129            MatrixIdent::B,
130            rhs.layout,
131            config,
132        );
133    }
134
135    fn load_acc<E: Numeric>(
136        tile: &AccTile::Tile<E>,
137        acc: &mut Self::AccFragment,
138        #[comptime] config: Self::Config,
139    ) {
140        MmaStageReader::<Self::AccTile>::load_fragment(
141            tile,
142            &mut acc.fragment,
143            mma_definition::<L, R, A>(config),
144            MatrixIdent::Accumulator,
145            acc.layout,
146            config,
147        );
148    }
149
150    fn write_results<E: Numeric>(
151        tile: &mut StridedTile<E, ReadWrite>,
152        out: &Self::AccFragment,
153        #[comptime] config: Self::Config,
154    ) {
155        MmaStageWriter::store_fragment(
156            tile,
157            &out.fragment,
158            mma_definition::<L, R, A>(config),
159            MatrixIdent::Accumulator,
160            tile.layout,
161        );
162    }
163}
164
165#[cube]
166pub(super) fn mma_definition<L: Numeric, R: Numeric, A: Numeric>(
167    #[comptime] config: MmaMatmulConfig,
168) -> MmaDefinition<L, R, A> {
169    let size = config.shared.tile_size;
170    MmaDefinition::new(size.m(), size.n(), size.k())
171}