cubecl_matmul/components/tile/mma/
matmul.rs1use std::marker::PhantomData;
2
3use crate::components::tile::io::{Strided, TileKind};
4use crate::components::tile::{
5 TileConfig, TileMatmul,
6 mma::{reader::MmaStageReader, writer::MmaStageWriter},
7};
8use crate::components::tile::{mma::reader::MmaFragmentReader, tile_data::StridedTile};
9use crate::components::{StageIdent, tile::mma::config::MmaMatmulConfig};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, cmma::MmaDefinition, ir::MatrixIdent};
12
13pub struct MmaMatmul<Acc: TileKind> {
17 _ty: PhantomData<Acc>,
18}
19
20#[cube]
21impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
22 for MmaMatmul<AccTile>
23where
24 MmaStageReader<AccTile>: MmaFragmentReader<TileKind = AccTile>,
25{
26 type Config = MmaMatmulConfig;
27 type LhsFragment = Sequence<Line<L>>;
28 type RhsFragment = Sequence<Line<R>>;
29 type AccFragment = Sequence<Line<A>>;
30
31 type LhsTile = Strided;
32 type RhsTile = Strided;
33 type AccTile = AccTile;
34 type OutTile = Strided;
35
36 fn execute(
37 lhs: &Self::LhsFragment,
38 rhs: &Self::RhsFragment,
39 out: &mut Self::AccFragment,
40 #[comptime] config: Self::Config,
41 ) {
42 let def = mma_definition(config);
43 let out_arr = def.execute(lhs, rhs, out);
44 let num_lines = def.lines_per_lane(MatrixIdent::Accumulator);
45
46 #[unroll]
47 for i in 0..num_lines {
48 *out.index_mut(i) = out_arr[i];
49 }
50 }
51
52 fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment {
53 let def = mma_definition::<L, R, A>(config);
54 let line_size = def.line_size(MatrixIdent::A);
55 let mut frag = Sequence::new();
56 #[unroll]
57 for _ in 0..def.lines_per_lane(MatrixIdent::A) {
58 #[allow(unused_mut)]
60 let mut reg = Line::empty(line_size);
61 frag.push(reg);
62 }
63 frag
64 }
65
66 fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment {
67 let def = mma_definition::<L, R, A>(config);
68 let line_size = def.line_size(MatrixIdent::B);
69 let mut frag = Sequence::new();
70 #[unroll]
71 for _ in 0..def.lines_per_lane(MatrixIdent::B) {
72 #[allow(unused_mut)]
74 let mut reg = Line::empty(line_size);
75 frag.push(reg);
76 }
77 frag
78 }
79
80 fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {
81 let def = mma_definition::<L, R, A>(config);
82 let line_size = def.line_size(MatrixIdent::Accumulator);
83 let mut frag = Sequence::new();
84 #[unroll]
85 for _ in 0..def.lines_per_lane(MatrixIdent::Accumulator) {
86 #[allow(unused_mut)]
88 let mut reg = Line::empty(line_size);
89 frag.push(reg);
90 }
91 frag
92 }
93
94 fn load_lhs<E: Numeric>(
95 tile: &StridedTile<E>,
96 lhs: &mut Self::LhsFragment,
97 #[comptime] config: Self::Config,
98 ) {
99 MmaStageReader::<Self::LhsTile>::load_fragment(
100 tile,
101 lhs,
102 mma_definition::<L, R, A>(config),
103 MatrixIdent::A,
104 config.matrix_layout(StageIdent::Lhs),
105 );
106 }
107
108 fn load_rhs<E: Numeric>(
109 tile: &StridedTile<E>,
110 rhs: &mut Self::RhsFragment,
111 #[comptime] config: Self::Config,
112 ) {
113 MmaStageReader::<Self::LhsTile>::load_fragment(
114 tile,
115 rhs,
116 mma_definition::<L, R, A>(config),
117 MatrixIdent::B,
118 config.matrix_layout(StageIdent::Rhs),
119 );
120 }
121
122 fn load_acc<E: Numeric>(
123 tile: &AccTile::Tile<E>,
124 acc: &mut Self::AccFragment,
125 #[comptime] config: Self::Config,
126 ) {
127 MmaStageReader::<Self::AccTile>::load_fragment(
128 tile,
129 acc,
130 mma_definition::<L, R, A>(config),
131 MatrixIdent::Accumulator,
132 config.matrix_layout(StageIdent::Acc),
133 );
134 }
135
136 fn write_results<E: Numeric>(
137 tile: &mut StridedTile<E, ReadWrite>,
138 out: &Self::AccFragment,
139 #[comptime] config: Self::Config,
140 ) {
141 MmaStageWriter::store_fragment(
142 tile,
143 out,
144 mma_definition::<L, R, A>(config),
145 MatrixIdent::Accumulator,
146 config.matrix_layout(StageIdent::Out),
147 );
148 }
149}
150
151#[cube]
152pub(super) fn mma_definition<L: Numeric, R: Numeric, A: Numeric>(
153 #[comptime] config: MmaMatmulConfig,
154) -> MmaDefinition<L, R, A> {
155 let size = config.tile_size();
156 MmaDefinition::new(size.m(), size.n(), size.k())
157}