cubecl_matmul/components/tile/mma/
matmul.rs1use std::marker::PhantomData;
2
3use crate::components::MatrixLayout;
4use crate::components::tile::io::{Filled, Strided, TileKind};
5use crate::components::tile::mma::config::MmaMatmulConfig;
6use crate::components::tile::{
7 TileMatmul,
8 mma::{reader::MmaStageReader, writer::MmaStageWriter},
9};
10use crate::components::tile::{mma::reader::MmaFragmentReader, tile_data::StridedTile};
11use cubecl_core::prelude::*;
12use cubecl_core::{self as cubecl, cmma::MmaDefinition, ir::MatrixIdent};
13
14pub struct MmaMatmul<Lhs: TileKind = Strided, Rhs: TileKind = Strided, Acc: TileKind = Filled> {
18 _ty: PhantomData<(Lhs, Rhs, Acc)>,
19}
20
21#[derive(CubeType)]
22pub struct MmaFragment<E: Numeric> {
23 fragment: Array<Line<E>>,
24 #[cube(comptime)]
25 layout: MatrixLayout,
26}
27
28#[cube]
29impl<L: Numeric, R: Numeric, A: Numeric, LhsTile: TileKind, RhsTile: TileKind, AccTile: TileKind>
30 TileMatmul<L, R, A> for MmaMatmul<LhsTile, RhsTile, AccTile>
31where
32 MmaStageReader<LhsTile>: MmaFragmentReader<TileKind = LhsTile>,
33 MmaStageReader<RhsTile>: MmaFragmentReader<TileKind = RhsTile>,
34 MmaStageReader<AccTile>: MmaFragmentReader<TileKind = AccTile>,
35{
36 type Config = MmaMatmulConfig;
37
38 type LhsFragment = MmaFragment<L>;
39 type RhsFragment = MmaFragment<R>;
40 type AccFragment = MmaFragment<A>;
41
42 type LhsTile = LhsTile;
43 type RhsTile = RhsTile;
44 type AccTile = AccTile;
45 type OutTile = Strided;
46
47 fn execute(
48 lhs: &Self::LhsFragment,
49 rhs: &Self::RhsFragment,
50 out: &mut Self::AccFragment,
51 #[comptime] config: Self::Config,
52 ) {
53 let def = mma_definition(config);
54 let out_arr = def.execute(&lhs.fragment, &rhs.fragment, &out.fragment);
55 let num_lines = def.lines_per_lane(MatrixIdent::Accumulator);
56
57 #[unroll]
58 for i in 0..num_lines {
59 out.fragment[i] = out_arr[i];
60 }
61 }
62
63 fn allocate_lhs(
64 #[comptime] layout: MatrixLayout,
65 #[comptime] config: Self::Config,
66 ) -> Self::LhsFragment {
67 let def = mma_definition::<L, R, A>(config);
68 let line_size = def.line_size(MatrixIdent::A);
69 let line_count = def.lines_per_lane(MatrixIdent::A);
70
71 MmaFragment::<L> {
72 fragment: Array::vectorized(line_count, line_size),
73 layout,
74 }
75 }
76
77 fn allocate_rhs(
78 #[comptime] layout: MatrixLayout,
79 #[comptime] config: Self::Config,
80 ) -> Self::RhsFragment {
81 let def = mma_definition::<L, R, A>(config);
82 let line_size = def.line_size(MatrixIdent::B);
83 let line_count = def.lines_per_lane(MatrixIdent::B);
84
85 MmaFragment::<R> {
86 fragment: Array::vectorized(line_count, line_size),
87 layout,
88 }
89 }
90
91 fn allocate_acc(
92 #[comptime] layout: MatrixLayout,
93 #[comptime] config: Self::Config,
94 ) -> Self::AccFragment {
95 let def = mma_definition::<L, R, A>(config);
96 let line_size = def.line_size(MatrixIdent::Accumulator);
97 let line_count = def.lines_per_lane(MatrixIdent::Accumulator);
98
99 MmaFragment::<A> {
100 fragment: Array::vectorized(line_count, line_size),
101 layout,
102 }
103 }
104
105 fn load_lhs<E: Numeric>(
106 tile: &LhsTile::Tile<E>,
107 lhs: &mut Self::LhsFragment,
108 #[comptime] config: Self::Config,
109 ) {
110 MmaStageReader::<Self::LhsTile>::load_fragment(
111 tile,
112 &mut lhs.fragment,
113 mma_definition::<L, R, A>(config),
114 MatrixIdent::A,
115 lhs.layout,
116 config,
117 );
118 }
119
120 fn load_rhs<E: Numeric>(
121 tile: &RhsTile::Tile<E>,
122 rhs: &mut Self::RhsFragment,
123 #[comptime] config: Self::Config,
124 ) {
125 MmaStageReader::<Self::RhsTile>::load_fragment(
126 tile,
127 &mut rhs.fragment,
128 mma_definition::<L, R, A>(config),
129 MatrixIdent::B,
130 rhs.layout,
131 config,
132 );
133 }
134
135 fn load_acc<E: Numeric>(
136 tile: &AccTile::Tile<E>,
137 acc: &mut Self::AccFragment,
138 #[comptime] config: Self::Config,
139 ) {
140 MmaStageReader::<Self::AccTile>::load_fragment(
141 tile,
142 &mut acc.fragment,
143 mma_definition::<L, R, A>(config),
144 MatrixIdent::Accumulator,
145 acc.layout,
146 config,
147 );
148 }
149
150 fn write_results<E: Numeric>(
151 tile: &mut StridedTile<E, ReadWrite>,
152 out: &Self::AccFragment,
153 #[comptime] config: Self::Config,
154 ) {
155 MmaStageWriter::store_fragment(
156 tile,
157 &out.fragment,
158 mma_definition::<L, R, A>(config),
159 MatrixIdent::Accumulator,
160 tile.layout,
161 );
162 }
163}
164
165#[cube]
166pub(super) fn mma_definition<L: Numeric, R: Numeric, A: Numeric>(
167 #[comptime] config: MmaMatmulConfig,
168) -> MmaDefinition<L, R, A> {
169 let size = config.shared.tile_size;
170 MmaDefinition::new(size.m(), size.n(), size.k())
171}