// cubecl_matmul/components/tile/cmma/matmul.rs

1use std::marker::PhantomData;
2
3use crate::components::tile::{SharedTileConfig, TileMatmul, cmma::reader::CmmaFragmentReader};
4use crate::components::tile::{
5    cmma::reader::CmmaStageReader,
6    io::{Strided, TileKind},
7};
8use crate::components::tile::{cmma::writer::CmmaStageWriter, tile_data::StridedTile};
9use crate::components::{MatrixLayout, as_cmma_layout};
10use cubecl_core as cubecl;
11use cubecl_core::{cmma, prelude::*};
12use cubecl_std::CubeOption;
13
14/// Uses one plane to perform a small matmul using accelerated instructions.
15pub struct CmmaMatmul<Acc: TileKind> {
16    _ty: PhantomData<Acc>,
17}
18
19#[derive(CubeType)]
20pub struct Fragment<E: Numeric> {
21    fragment: cmma::Matrix<E>,
22    #[cube(comptime)]
23    layout: MatrixLayout,
24}
25
26#[cube]
27impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
28    for CmmaMatmul<AccTile>
29where
30    CmmaStageReader<AccTile>: CmmaFragmentReader<TileKind = AccTile>,
31{
32    type Config = SharedTileConfig;
33
34    type LhsFragment = Fragment<L>;
35    type RhsFragment = Fragment<R>;
36    type AccFragment = Fragment<A>;
37
38    type LhsTile = Strided;
39    type RhsTile = Strided;
40    type AccTile = AccTile;
41    type OutTile = Strided;
42
43    fn execute(
44        lhs: &Self::LhsFragment,
45        rhs: &Self::RhsFragment,
46        out: &mut Self::AccFragment,
47        #[comptime] _config: Self::Config,
48    ) {
49        cmma::execute::<L, R, A, A>(&lhs.fragment, &rhs.fragment, &out.fragment, &out.fragment);
50    }
51
52    fn allocate_lhs(
53        #[comptime] layout: MatrixLayout,
54        #[comptime] config: Self::Config,
55    ) -> Self::LhsFragment {
56        let size = config.tile_size;
57
58        Fragment::<L> {
59            fragment: unsafe {
60                cmma::Matrix::<L>::uninitialized(
61                    cmma::MatrixIdent::A,
62                    size.m(),
63                    size.n(),
64                    size.k(),
65                    as_cmma_layout(layout),
66                )
67            },
68            layout,
69        }
70    }
71
72    fn allocate_rhs(
73        #[comptime] layout: MatrixLayout,
74        #[comptime] config: Self::Config,
75    ) -> Self::RhsFragment {
76        let size = config.tile_size;
77
78        Fragment::<R> {
79            fragment: unsafe {
80                cmma::Matrix::<R>::uninitialized(
81                    cmma::MatrixIdent::B,
82                    size.m(),
83                    size.n(),
84                    size.k(),
85                    as_cmma_layout(layout),
86                )
87            },
88            layout,
89        }
90    }
91
92    fn allocate_acc(
93        #[comptime] layout: MatrixLayout,
94        #[comptime] config: Self::Config,
95    ) -> Self::AccFragment {
96        let size = config.tile_size;
97
98        Fragment::<A> {
99            fragment: unsafe {
100                cmma::Matrix::<A>::uninitialized(
101                    cmma::MatrixIdent::Accumulator,
102                    size.m(),
103                    size.n(),
104                    size.k(),
105                    cmma::MatrixLayout::Undefined,
106                )
107            },
108            layout,
109        }
110    }
111
112    fn load_lhs<E: Numeric>(
113        tile: &StridedTile<E>,
114        lhs: &mut Self::LhsFragment,
115        #[comptime] _config: Self::Config,
116    ) {
117        CmmaStageReader::<Self::LhsTile>::load_fragment(
118            tile,
119            &mut lhs.fragment,
120            CubeOption::new_None(),
121        );
122    }
123
124    fn load_rhs<E: Numeric>(
125        tile: &StridedTile<E>,
126        rhs: &mut Self::RhsFragment,
127        #[comptime] _config: Self::Config,
128    ) {
129        CmmaStageReader::<Self::RhsTile>::load_fragment(
130            tile,
131            &mut rhs.fragment,
132            CubeOption::new_None(),
133        );
134    }
135
136    fn load_acc<E: Numeric>(
137        tile: &AccTile::Tile<E>,
138        acc: &mut Self::AccFragment,
139        #[comptime] _config: Self::Config,
140    ) {
141        CmmaStageReader::<Self::AccTile>::load_fragment(
142            tile,
143            &mut acc.fragment,
144            CubeOption::new_Some(as_cmma_layout(acc.layout)),
145        );
146    }
147
148    fn write_results<E: Numeric>(
149        tile: &mut StridedTile<E, ReadWrite>,
150        out: &Self::AccFragment,
151        #[comptime] _config: Self::Config,
152    ) {
153        let out = cmma::cast::<A, E>(&out.fragment);
154        CmmaStageWriter::store_fragment(tile, &out);
155    }
156}