cubecl_matmul/components/tile/cmma/
matmul.rs1use std::marker::PhantomData;
2
3use crate::components::tile::{SharedTileConfig, TileMatmul, cmma::reader::CmmaFragmentReader};
4use crate::components::tile::{
5 cmma::reader::CmmaStageReader,
6 io::{Strided, TileKind},
7};
8use crate::components::tile::{cmma::writer::CmmaStageWriter, tile_data::StridedTile};
9use crate::components::{MatrixLayout, as_cmma_layout};
10use cubecl_core as cubecl;
11use cubecl_core::{cmma, prelude::*};
12use cubecl_std::CubeOption;
13
14pub struct CmmaMatmul<Acc: TileKind> {
16 _ty: PhantomData<Acc>,
17}
18
19#[derive(CubeType)]
20pub struct Fragment<E: Numeric> {
21 fragment: cmma::Matrix<E>,
22 #[cube(comptime)]
23 layout: MatrixLayout,
24}
25
26#[cube]
27impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
28 for CmmaMatmul<AccTile>
29where
30 CmmaStageReader<AccTile>: CmmaFragmentReader<TileKind = AccTile>,
31{
32 type Config = SharedTileConfig;
33
34 type LhsFragment = Fragment<L>;
35 type RhsFragment = Fragment<R>;
36 type AccFragment = Fragment<A>;
37
38 type LhsTile = Strided;
39 type RhsTile = Strided;
40 type AccTile = AccTile;
41 type OutTile = Strided;
42
43 fn execute(
44 lhs: &Self::LhsFragment,
45 rhs: &Self::RhsFragment,
46 out: &mut Self::AccFragment,
47 #[comptime] _config: Self::Config,
48 ) {
49 cmma::execute::<L, R, A, A>(&lhs.fragment, &rhs.fragment, &out.fragment, &out.fragment);
50 }
51
52 fn allocate_lhs(
53 #[comptime] layout: MatrixLayout,
54 #[comptime] config: Self::Config,
55 ) -> Self::LhsFragment {
56 let size = config.tile_size;
57
58 Fragment::<L> {
59 fragment: unsafe {
60 cmma::Matrix::<L>::uninitialized(
61 cmma::MatrixIdent::A,
62 size.m(),
63 size.n(),
64 size.k(),
65 as_cmma_layout(layout),
66 )
67 },
68 layout,
69 }
70 }
71
72 fn allocate_rhs(
73 #[comptime] layout: MatrixLayout,
74 #[comptime] config: Self::Config,
75 ) -> Self::RhsFragment {
76 let size = config.tile_size;
77
78 Fragment::<R> {
79 fragment: unsafe {
80 cmma::Matrix::<R>::uninitialized(
81 cmma::MatrixIdent::B,
82 size.m(),
83 size.n(),
84 size.k(),
85 as_cmma_layout(layout),
86 )
87 },
88 layout,
89 }
90 }
91
92 fn allocate_acc(
93 #[comptime] layout: MatrixLayout,
94 #[comptime] config: Self::Config,
95 ) -> Self::AccFragment {
96 let size = config.tile_size;
97
98 Fragment::<A> {
99 fragment: unsafe {
100 cmma::Matrix::<A>::uninitialized(
101 cmma::MatrixIdent::Accumulator,
102 size.m(),
103 size.n(),
104 size.k(),
105 cmma::MatrixLayout::Undefined,
106 )
107 },
108 layout,
109 }
110 }
111
112 fn load_lhs<E: Numeric>(
113 tile: &StridedTile<E>,
114 lhs: &mut Self::LhsFragment,
115 #[comptime] _config: Self::Config,
116 ) {
117 CmmaStageReader::<Self::LhsTile>::load_fragment(
118 tile,
119 &mut lhs.fragment,
120 CubeOption::new_None(),
121 );
122 }
123
124 fn load_rhs<E: Numeric>(
125 tile: &StridedTile<E>,
126 rhs: &mut Self::RhsFragment,
127 #[comptime] _config: Self::Config,
128 ) {
129 CmmaStageReader::<Self::RhsTile>::load_fragment(
130 tile,
131 &mut rhs.fragment,
132 CubeOption::new_None(),
133 );
134 }
135
136 fn load_acc<E: Numeric>(
137 tile: &AccTile::Tile<E>,
138 acc: &mut Self::AccFragment,
139 #[comptime] _config: Self::Config,
140 ) {
141 CmmaStageReader::<Self::AccTile>::load_fragment(
142 tile,
143 &mut acc.fragment,
144 CubeOption::new_Some(as_cmma_layout(acc.layout)),
145 );
146 }
147
148 fn write_results<E: Numeric>(
149 tile: &mut StridedTile<E, ReadWrite>,
150 out: &Self::AccFragment,
151 #[comptime] _config: Self::Config,
152 ) {
153 let out = cmma::cast::<A, E>(&out.fragment);
154 CmmaStageWriter::store_fragment(tile, &out);
155 }
156}