1use cubecl::{
2 cmma::MmaDefinition,
3 define_size,
4 ir::{DeviceProperties, MatrixIdent, StorageType},
5 prelude::*,
6};
7
8use crate::{
9 MatrixLayout, SwizzleModes, TileSize,
10 tile::{Tile, TileScope},
11};
12
// Size markers for the per-lane MMA fragment vectors. Each is the vector width of one
// `MmaFragment` variant (`Vector<N, NL>` / `Vector<N, NR>` / `Vector<N, NA>`), and its
// concrete value is bound via `scope.register_size` in `mma_register_vector_sizes`.
// NOTE(review): exact semantics of `define_size!` live in cubecl — confirm there.

// Vector size of Lhs (`MatrixIdent::A`) fragment elements.
define_size!(pub NL);
// Vector size of Rhs (`MatrixIdent::B`) fragment elements.
define_size!(pub NR);
// Vector size of accumulator (`MatrixIdent::Accumulator`) fragment elements.
define_size!(pub NA);
19
/// An MMA operand or accumulator tile held as a per-lane fragment, together with the
/// compile-time layout and configuration it was allocated with.
///
/// Instances are created by `mma_allocate_lhs`, `mma_allocate_rhs` and `mma_allocate_acc`.
#[derive(CubeType)]
pub struct MmaTile<N: Numeric> {
    /// Per-lane fragment storage; the variant tells which operand this tile is.
    pub fragment: MmaFragment<N>,
    /// Matrix layout supplied at allocation time (compile-time only).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// MMA configuration the fragment was sized from (compile-time only).
    #[cube(comptime)]
    pub config: MmaMatmul,
}
32
/// Per-lane fragment storage, tagged by which MMA operand it holds.
///
/// The `mma_allocate_*` functions size each array to `vectors_per_lane(ident)` of the
/// `MmaDefinition` built from the config. The vector width is the matching size marker
/// (`NL`, `NR` or `NA`).
#[derive(CubeType)]
pub enum MmaFragment<N: Numeric> {
    /// Left-hand operand (`MatrixIdent::A`) vectors.
    Lhs(Array<Vector<N, NL>>),
    /// Right-hand operand (`MatrixIdent::B`) vectors.
    Rhs(Array<Vector<N, NR>>),
    /// Accumulator (`MatrixIdent::Accumulator`) vectors.
    Acc(Array<Vector<N, NA>>),
}
39
/// Compile-time configuration of an MMA-based tile matmul.
///
/// Passed as a `#[comptime]` value to the `#[cube]` functions in this module and stored
/// on every `MmaTile`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct MmaMatmul {
    /// Tile shape; its `m`, `n` and `k` build the `MmaDefinition`.
    pub tile_size: TileSize,
    /// Plane size — presumably the number of lanes per plane; not read in this file.
    pub plane_dim: u32,
    /// Swizzle modes — not read in this file; presumably consumed by the load/store paths.
    pub swizzle_modes: SwizzleModes,
    /// How each fragment is loaded and how the accumulator is stored.
    pub mma_io_config: MmaIOConfig,
}
47
/// Per-operand choice between the device's matrix load/store instructions and the
/// manual fallback path. Build it from device capabilities with `MmaIOConfig::new`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct MmaIOConfig {
    /// How the Lhs (`MatrixIdent::A`) fragment is loaded.
    pub lhs_load_method: LoadMethod,
    /// How the Rhs (`MatrixIdent::B`) fragment is loaded.
    pub rhs_load_method: LoadMethod,
    /// How the accumulator (`MatrixIdent::Accumulator`) fragment is loaded.
    pub acc_load_method: LoadMethod,
    /// How the accumulator fragment is stored.
    pub store_method: StoreMethod,
}
55
/// Strategy used to fill an MMA fragment.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadMethod {
    /// Fill the fragment without the dedicated matrix-load instruction.
    Manual,
    /// Use the device's matrix-load instruction. `MmaIOConfig::new` selects this only for
    /// non-packed stage types listed in `features.matmul.ldmatrix`.
    LoadMatrix,
}
61
/// Strategy used to write an accumulator fragment back out.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum StoreMethod {
    /// Store the fragment without the dedicated matrix-store instruction.
    Manual,
    /// Use the device's matrix-store instruction. `MmaIOConfig::new` selects this only for
    /// non-packed stage types listed in `features.matmul.stmatrix`.
    StoreMatrix,
}
67
68impl MmaIOConfig {
69 pub fn new(
70 device_props: &DeviceProperties,
71 lhs_stage: StorageType,
72 rhs_stage: StorageType,
73 acc_stage: StorageType,
74 ) -> Self {
75 Self {
76 lhs_load_method: load_method(device_props, lhs_stage),
77 rhs_load_method: load_method(device_props, rhs_stage),
78 acc_load_method: load_method(device_props, acc_stage),
79 store_method: store_method(device_props, acc_stage),
80 }
81 }
82
83 pub fn load_method(&self, ident: MatrixIdent) -> LoadMethod {
84 match ident {
85 MatrixIdent::A => self.lhs_load_method,
86 MatrixIdent::B => self.rhs_load_method,
87 MatrixIdent::Accumulator => self.acc_load_method,
88 }
89 }
90
91 pub fn store_method(&self) -> StoreMethod {
92 self.store_method
93 }
94}
95
96fn load_method(device_props: &DeviceProperties, dtype: StorageType) -> LoadMethod {
97 if !matches!(dtype, StorageType::Packed(_, _))
98 && device_props.features.matmul.ldmatrix.contains(&dtype)
99 {
100 LoadMethod::LoadMatrix
101 } else {
102 LoadMethod::Manual
103 }
104}
105
106fn store_method(device_props: &DeviceProperties, dtype: StorageType) -> StoreMethod {
107 if !matches!(dtype, StorageType::Packed(_, _))
108 && device_props.features.matmul.stmatrix.contains(&dtype)
109 {
110 StoreMethod::StoreMatrix
111 } else {
112 StoreMethod::Manual
113 }
114}
115
/// Builds the `MmaDefinition` for Lhs/Rhs/Acc element types `L`, `R`, `A`.
///
/// The shape comes from `config.tile_size` and is passed as `(m, n, k)` in that order.
#[cube]
fn make_mma_definition<L: Numeric, R: Numeric, A: Numeric>(
    #[comptime] config: MmaMatmul,
) -> MmaDefinition<L, R, A> {
    MmaDefinition::new(
        config.tile_size.m() as usize,
        config.tile_size.n() as usize,
        config.tile_size.k() as usize,
    )
}
126
/// Binds the size markers to the vector sizes reported by `def`:
/// - `NL` to the `MatrixIdent::A` size,
/// - `NR` to the `MatrixIdent::B` size,
/// - `NA` to the `MatrixIdent::Accumulator` size.
///
/// Each `mma_allocate_*` function calls this before allocating its fragment. Presumably
/// this is what gives the `Vector<_, NL/NR/NA>` element types a concrete width.
#[cube]
// NOTE(review): the three sizes are only read inside `intrinsic!`. The allow presumably
// silences a spurious unused-variable warning from that expansion — confirm.
#[allow(unused_variables)]
pub fn mma_register_vector_sizes<L: Numeric, R: Numeric, A: Numeric>(def: MmaDefinition<L, R, A>) {
    let vector_size_a = def.vector_size(MatrixIdent::A);
    let vector_size_b = def.vector_size(MatrixIdent::B);
    let vector_size_acc = def.vector_size(MatrixIdent::Accumulator);
    intrinsic!(|scope| {
        scope.register_size::<NL>(vector_size_a);
        scope.register_size::<NR>(vector_size_b);
        scope.register_size::<NA>(vector_size_acc);
    });
}
139
/// Allocates an Lhs (`MatrixIdent::A`) MMA tile of element type `L`.
///
/// The fragment array has `vectors_per_lane(MatrixIdent::A)` entries, taken from the
/// `MmaDefinition` that `config` yields for element types `L`, `R`, `A`. `R` and `A` are
/// needed only to build that definition. `layout` and `config` are stored on the tile as
/// compile-time metadata.
#[cube]
pub fn mma_allocate_lhs<L: Numeric, R: Numeric, A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<L, Sc, ReadWrite> {
    let def = make_mma_definition::<L, R, A>(config);
    // Bind NL/NR/NA before `Vector<L, NL>` storage is allocated below.
    mma_register_vector_sizes(def);
    let vector_count = def.vectors_per_lane(MatrixIdent::A);

    Tile::new_Mma(MmaTile::<L> {
        fragment: MmaFragment::new_Lhs(Array::new(vector_count)),
        matrix_layout: layout,
        config,
    })
}
155
/// Allocates an Rhs (`MatrixIdent::B`) MMA tile of element type `R`.
///
/// The generic order puts the tile's own element type first (`R`), then `L` and `A`.
/// The definition itself is still built as `MmaDefinition<L, R, A>`. The fragment array
/// has `vectors_per_lane(MatrixIdent::B)` entries. `layout` and `config` are stored on
/// the tile as compile-time metadata.
#[cube]
pub fn mma_allocate_rhs<R: Numeric, L: Numeric, A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<R, Sc, ReadWrite> {
    let def = make_mma_definition::<L, R, A>(config);
    // Bind NL/NR/NA before `Vector<R, NR>` storage is allocated below.
    mma_register_vector_sizes(def);
    let vector_count = def.vectors_per_lane(MatrixIdent::B);

    Tile::new_Mma(MmaTile::<R> {
        fragment: MmaFragment::new_Rhs(Array::new(vector_count)),
        matrix_layout: layout,
        config,
    })
}
171
/// Allocates an accumulator (`MatrixIdent::Accumulator`) MMA tile of element type `A`.
///
/// The generic order puts the tile's own element type first (`A`), then `L` and `R`.
/// The definition itself is still built as `MmaDefinition<L, R, A>`. The fragment array
/// has `vectors_per_lane(MatrixIdent::Accumulator)` entries. `layout` and `config` are
/// stored on the tile as compile-time metadata.
#[cube]
pub fn mma_allocate_acc<A: Numeric, L: Numeric, R: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<A, Sc, ReadWrite> {
    let def = make_mma_definition::<L, R, A>(config);
    // Bind NL/NR/NA before `Vector<A, NA>` storage is allocated below.
    mma_register_vector_sizes(def);
    let vector_count = def.vectors_per_lane(MatrixIdent::Accumulator);

    Tile::new_Mma(MmaTile::<A> {
        fragment: MmaFragment::new_Acc(Array::new(vector_count)),
        matrix_layout: layout,
        config,
    })
}