cubecl_matmul/components/tile/accelerated/
matmul.rs

1use std::marker::PhantomData;
2
3use crate::components::tile::{TileConfig, TileMatmul, accelerated::reader::CmmaFragmentReader};
4use crate::components::tile::{accelerated::writer::CmmaStageWriter, tile_data::StridedTile};
5use crate::components::tile::{
6    accelerated::{config::AcceleratedConfig, reader::CmmaStageReader},
7    io::{Strided, TileKind},
8};
9use crate::components::{StageIdent, as_cmma_layout};
10use cubecl_core as cubecl;
11use cubecl_core::{cmma, prelude::*};
12use cubecl_std::CubeOption;
13
14/// Uses one plane to perform a small matmul using accelerated instructions.
///
/// The type is a zero-sized marker: its only field is a `PhantomData`, so it carries no
/// runtime data and serves purely as the type on which `TileMatmul` is implemented.
/// `Acc` is the [`TileKind`] that the `TileMatmul` implementation for this type exposes
/// as its accumulator tile kind (`AccTile`).
15pub struct AcceleratedMatmul<Acc: TileKind> {
    // Ties the otherwise unused `Acc` type parameter to the struct.
16    _ty: PhantomData<Acc>,
17}
18
// `TileMatmul` implementation backed by the `cmma` module: all three fragments are
// `cmma::Matrix` values and the multiply itself is a single `cmma::execute` call.
//
// The `where` bound requires `CmmaStageReader<AccTile>` to implement
// `CmmaFragmentReader` for the same tile kind; that reader is the one `load_acc` calls.
19#[cube]
20impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
21    for AcceleratedMatmul<AccTile>
22where
23    CmmaStageReader<AccTile>: CmmaFragmentReader<TileKind = AccTile>,
24{
    // Configuration type and the fragment types, one `cmma::Matrix` per operand,
    // typed with the lhs (`L`), rhs (`R`) and accumulator (`A`) element types.
25    type Config = AcceleratedConfig;
26    type LhsFragment = cmma::Matrix<L>;
27    type RhsFragment = cmma::Matrix<R>;
28    type AccFragment = cmma::Matrix<A>;
29
    // Tile kinds: lhs, rhs and output tiles are always `Strided`; only the accumulator
    // tile kind is left to the `AccTile` type parameter.
30    type LhsTile = Strided;
31    type RhsTile = Strided;
32    type AccTile = AccTile;
33    type OutTile = Strided;
34
    /// Multiplies the `lhs` and `rhs` fragments with `cmma::execute`.
    ///
    /// `out` is passed for both of the last two arguments of `cmma::execute`.
    /// NOTE(review): presumably these are the accumulator input and the destination, so
    /// the product is accumulated into `out` in place — confirm against `cmma::execute`.
    /// The config is not used.
35    fn execute(
36        lhs: &Self::LhsFragment,
37        rhs: &Self::RhsFragment,
38        out: &mut Self::AccFragment,
39        #[comptime] _config: Self::Config,
40    ) {
41        cmma::execute::<L, R, A, A>(lhs, rhs, out, out);
42    }
43
    /// Creates the lhs fragment without initializing it.
    ///
    /// The fragment is declared as `MatrixIdent::A`, with the `m`, `n`, `k` values of
    /// `config.tile_size()` and the `StageIdent::Lhs` layout from the config, converted
    /// with `as_cmma_layout`.
44    fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment {
45        let size = config.tile_size();
46        let layout = config.matrix_layout(StageIdent::Lhs);
        // SAFETY(review): `uninitialized` returns a fragment without filling it. This
        // assumes callers load it (see `load_lhs`) before `execute` reads it — TODO confirm.
47        unsafe {
48            cmma::Matrix::<L>::uninitialized(
49                cmma::MatrixIdent::A,
50                size.m(),
51                size.n(),
52                size.k(),
53                as_cmma_layout(layout),
54            )
55        }
56    }
57
    /// Creates the rhs fragment without initializing it.
    ///
    /// Same as `allocate_lhs`, but the fragment is declared as `MatrixIdent::B` and uses
    /// the `StageIdent::Rhs` layout from the config.
58    fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment {
59        let size = config.tile_size();
60        let layout = config.matrix_layout(StageIdent::Rhs);
        // SAFETY(review): `uninitialized` returns a fragment without filling it. This
        // assumes callers load it (see `load_rhs`) before `execute` reads it — TODO confirm.
61        unsafe {
62            cmma::Matrix::<R>::uninitialized(
63                cmma::MatrixIdent::B,
64                size.m(),
65                size.n(),
66                size.k(),
67                as_cmma_layout(layout),
68            )
69        }
70    }
71
    /// Loads a strided tile into the lhs fragment via `CmmaStageReader::load_fragment`.
    ///
    /// No layout is passed (`CubeOption::new_None()`); the lhs fragment already received
    /// its layout in `allocate_lhs`. The config is not used.
72    fn load_lhs<E: Numeric>(
73        tile: &StridedTile<E>,
74        lhs: &mut Self::LhsFragment,
75        #[comptime] _config: Self::Config,
76    ) {
77        CmmaStageReader::<Self::LhsTile>::load_fragment(tile, lhs, CubeOption::new_None());
78    }
79
    /// Loads a strided tile into the rhs fragment via `CmmaStageReader::load_fragment`.
    ///
    /// No layout is passed (`CubeOption::new_None()`); the rhs fragment already received
    /// its layout in `allocate_rhs`. The config is not used.
80    fn load_rhs<E: Numeric>(
81        tile: &StridedTile<E>,
82        rhs: &mut Self::RhsFragment,
83        #[comptime] _config: Self::Config,
84    ) {
85        CmmaStageReader::<Self::RhsTile>::load_fragment(tile, rhs, CubeOption::new_None());
86    }
87
    /// Loads an accumulator tile into the accumulator fragment.
    ///
    /// Unlike the lhs/rhs loaders, a layout is passed explicitly: the accumulator
    /// fragment is allocated with `MatrixLayout::Undefined` in `allocate_acc`, so the
    /// `StageIdent::Acc` layout from the config is supplied here instead.
88    fn load_acc<E: Numeric>(
89        tile: &AccTile::Tile<E>,
90        acc: &mut Self::AccFragment,
91        #[comptime] config: Self::Config,
92    ) {
        // The layout is computed inside `comptime!` from the comptime config.
93        let layout = comptime!(as_cmma_layout(config.matrix_layout(StageIdent::Acc)));
94        CmmaStageReader::<Self::AccTile>::load_fragment(tile, acc, CubeOption::new_Some(layout));
95    }
96
    /// Writes the accumulator fragment into a read-write strided tile.
    ///
    /// The fragment is first cast from the accumulator element type `A` to the tile's
    /// element type `E` with `cmma::cast`, then stored with
    /// `CmmaStageWriter::store_fragment`. The config is not used.
97    fn write_results<E: Numeric>(
98        tile: &mut StridedTile<E, ReadWrite>,
99        out: &Self::AccFragment,
100        #[comptime] _config: Self::Config,
101    ) {
        // Shadows the parameter with the cast fragment; the original is not modified.
102        let out = cmma::cast::<A, E>(out);
103        CmmaStageWriter::store_fragment(tile, &out);
104    }
105
    /// Creates the accumulator fragment without initializing it.
    ///
    /// The fragment is declared as `MatrixIdent::Accumulator` with the `m`, `n`, `k`
    /// values of `config.tile_size()` and `MatrixLayout::Undefined`; the layout is
    /// provided later by `load_acc`.
106    fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {
107        let size = config.tile_size();
        // SAFETY(review): `uninitialized` returns a fragment without filling it. This
        // assumes callers load it (see `load_acc`) before `execute` reads it — TODO confirm.
108        unsafe {
109            cmma::Matrix::<A>::uninitialized(
110                cmma::MatrixIdent::Accumulator,
111                size.m(),
112                size.n(),
113                size.k(),
114                cmma::MatrixLayout::Undefined,
115            )
116        }
117    }
118}