1pub mod ops;
2pub mod scope;
3pub mod variants;
4
5mod strided_tile;
6mod tile_kind;
7
8pub use ops::*;
9pub use scope::{Cube, Plane, Scope, ScopeKind, ScopeMarker, Unit};
10pub use strided_tile::*;
11pub use tile_kind::*;
12pub use variants::bounce_tile::*;
13pub use variants::cmma::*;
14pub use variants::interleaved::*;
15pub use variants::local_tile::*;
16pub use variants::mma::*;
17pub use variants::plane_vec_mat_inner_product::*;
18pub use variants::register::*;
19pub use variants::unit_tile::*;
20
21pub use variants::{cmma, interleaved, mma, plane_vec_mat_inner_product, register};
24
25use cubecl::cmma::Matrix as CubeMatrix;
26use cubecl::prelude::*;
27
28use crate::{MatrixLayout, StageIdent, tile::scope::Scope as TileScope};
29
/// The set of tile storage variants handled by this module.
///
/// * `N` is the element type of the payload.
/// * `Sc` is a scope marker; it is only referenced by the `_Phantom` variant.
/// * `IO` is the slice visibility of the `SharedMemory` payload.
///
/// `Tile::mma` and `Tile::copy_from` in this file dispatch on the variant and
/// support only specific variant combinations.
#[derive(CubeType)]
pub enum Tile<N: Numeric, Sc: TileScope, IO: SliceVisibility> {
    /// Tile described by a `SharedTile` with visibility `IO`.
    SharedMemory(SharedTile<N, IO>),
    /// Tile backed by a `cubecl::cmma::Matrix` (see [`CmmaTile`]).
    Cmma(CmmaTile<N>),
    /// Fragment used in the left-hand position of `mma_execute`.
    MmaLhs(MmaLhsTile<N>),
    /// Fragment used in the right-hand position of `mma_execute`.
    MmaRhs(MmaRhsTile<N>),
    /// Fragment used in the accumulator (mutable) position of `mma_execute`.
    MmaAcc(MmaAccTile<N>),
    /// Tile stored as a flat `Array<N>` (see [`RegisterTile`]).
    Register(RegisterTile<N>),
    /// Tile stored as an array of vectors (see [`PlaneVecTile`]).
    PlaneVec(PlaneVecTile<N>),
    /// Tile handled by the `interleaved_*` helpers.
    Interleaved(InterleavedTile<N>),
    /// Payload type `UnitTile`; not used by the dispatch code in this file.
    Unit(UnitTile<N>),
    /// Payload type `LocalTile`; not used by the dispatch code in this file.
    Local(LocalTile<N>),
    /// Payload type `BounceTile`; the dispatch code in this file reads its
    /// inner `cmma` tile.
    Bounce(BounceTile<N>),
    /// A single value wrapped in [`Value`].
    // NOTE(review): presumably every element of the tile equals this value;
    // confirm against the code that consumes this variant.
    Broadcasted(Value<N>),
    /// No data. As a `copy_from` source it selects the `*_load_zeros` helpers.
    None,
    /// Never carries tile data; exists so that the `Sc` parameter is used.
    _Phantom(ScopeMarker<Sc>),
}
53
/// Tile backed by a `cubecl::cmma::Matrix`, plus the compile-time metadata
/// that travels with it.
#[derive(CubeType)]
pub struct CmmaTile<N: Numeric> {
    /// The underlying cmma matrix; passed to `cmma_execute` and the
    /// `cmma_load_*` / `cmma_write_to_shared` helpers.
    pub matrix: CubeMatrix<N>,
    /// Layout of the matrix (comptime field).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Tile size descriptor (comptime field).
    #[cube(comptime)]
    pub tile_size: crate::TileSize,
}
62
/// Fragment passed as the first (left-hand) argument of `mma_execute`.
#[derive(CubeType)]
pub struct MmaLhsTile<N: Numeric> {
    /// Fragment storage: an array of `Vector<N, mma::NL>`.
    pub fragment: Array<Vector<N, mma::NL>>,
    /// Layout associated with the fragment (comptime field).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// `MmaMatmul` configuration forwarded to the `mma_*` helpers (comptime field).
    #[cube(comptime)]
    pub config: MmaMatmul,
}
71
/// Fragment passed as the second (right-hand) argument of `mma_execute`.
#[derive(CubeType)]
pub struct MmaRhsTile<N: Numeric> {
    /// Fragment storage: an array of `Vector<N, mma::NR>`.
    pub fragment: Array<Vector<N, mma::NR>>,
    /// Layout associated with the fragment (comptime field).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// `MmaMatmul` configuration forwarded to the `mma_*` helpers (comptime field).
    #[cube(comptime)]
    pub config: MmaMatmul,
}
80
/// Fragment passed as the mutable (accumulator) argument of `mma_execute`.
#[derive(CubeType)]
pub struct MmaAccTile<N: Numeric> {
    /// Fragment storage: an array of `Vector<N, mma::NA>`.
    pub fragment: Array<Vector<N, mma::NA>>,
    /// Layout associated with the fragment (comptime field); `Tile::mma`
    /// forwards this value to `mma_execute`.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// `MmaMatmul` configuration (comptime field); `Tile::mma` forwards this
    /// value to `mma_execute`.
    #[cube(comptime)]
    pub config: MmaMatmul,
}
89
/// Tile stored as a flat `Array<N>`, used with the `register_*` helpers.
#[derive(CubeType)]
pub struct RegisterTile<N: Numeric> {
    /// Element storage.
    pub data: Array<N>,
    /// Layout associated with the data (comptime field).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// `RegisterMatmul` configuration forwarded to the `register_*` helpers
    /// (comptime field).
    #[cube(comptime)]
    pub config: RegisterMatmul,
}
98
/// Tile stored as an array of vectors, used with the `planevec_*` helpers.
#[derive(CubeType)]
pub struct PlaneVecTile<N: Numeric> {
    /// Element storage: an array of `Vector<N, NPlaneVec>`.
    pub data: Array<Vector<N, NPlaneVec>>,
    /// Layout associated with the data (comptime field).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// `PlaneVecMatInnerProduct` configuration forwarded to the `planevec_*`
    /// helpers (comptime field).
    #[cube(comptime)]
    pub config: PlaneVecMatInnerProduct,
}
108
/// Wrapper around a single element; the payload of `Tile::Broadcasted`.
#[derive(CubeType)]
pub struct Value<E: Numeric> {
    /// The wrapped element.
    pub val: E,
}
114
115#[cube]
116impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
117 pub fn mma<L: Numeric, R: Numeric>(
119 &mut self,
120 lhs: &Tile<L, Sc, ReadWrite>,
121 rhs: &Tile<R, Sc, ReadWrite>,
122 ) {
123 match (lhs, rhs, self) {
124 (Tile::Cmma(l), Tile::Cmma(r), Tile::Cmma(a)) => {
125 cmma_execute(&l.matrix, &r.matrix, &mut a.matrix);
126 }
127 (Tile::Cmma(l), Tile::Cmma(r), Tile::Bounce(a)) => {
128 cmma_execute(&l.matrix, &r.matrix, &mut a.cmma.matrix);
129 }
130 (Tile::Bounce(l), Tile::Cmma(r), Tile::Bounce(a)) => {
131 cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.cmma.matrix);
132 }
133 (Tile::Bounce(l), Tile::Cmma(r), Tile::Cmma(a)) => {
134 cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.matrix);
135 }
136 (Tile::MmaLhs(l), Tile::MmaRhs(r), Tile::MmaAcc(a)) => {
137 mma_execute(
138 &l.fragment,
139 &r.fragment,
140 &mut a.fragment,
141 a.matrix_layout,
142 a.config,
143 );
144 }
145 (Tile::Register(l), Tile::Register(r), Tile::Register(a)) => {
146 register_execute(&l.data, &r.data, &mut a.data, a.config);
147 }
148 (Tile::PlaneVec(l), Tile::PlaneVec(r), Tile::PlaneVec(a)) => {
149 planevec_execute(&l.data, &r.data, &mut a.data, a.config);
150 }
151 (Tile::Interleaved(l), Tile::Interleaved(r), Tile::Interleaved(a)) => {
152 interleaved_execute(
153 &l.data,
154 l.matrix_layout,
155 &r.data,
156 r.matrix_layout,
157 &mut a.data,
158 a.matrix_layout,
159 a.config,
160 );
161 }
162 _ => panic!("Unsupported storage combination for mma"),
163 }
164 }
165
166 pub fn copy_from<
173 SE: Numeric,
174 SS: Size,
175 L: Numeric,
176 R: Numeric,
177 A: Numeric,
178 SIO: SliceVisibility,
179 >(
180 &mut self,
181 source: &Tile<SE, Sc, SIO>,
182 #[comptime] ident: StageIdent,
183 ) {
184 match (source, self) {
185 (Tile::SharedMemory(shared), Tile::Cmma(t)) => {
187 let shared = shared.view::<SS>();
188 cmma_load_from_shared::<SE, SS, N, SIO>(
189 &shared,
190 &mut t.matrix,
191 ident,
192 t.matrix_layout,
193 );
194 }
195 (Tile::None, Tile::Cmma(t)) => {
196 cmma_load_zeros::<N>(&mut t.matrix);
197 }
198
199 (Tile::SharedMemory(shared), Tile::Bounce(b)) => {
201 let shared = shared.view::<SS>();
202 cmma_load_from_shared::<SE, SS, N, SIO>(
203 &shared,
204 &mut b.cmma.matrix,
205 ident,
206 b.cmma.matrix_layout,
207 );
208 }
209 (Tile::None, Tile::Bounce(b)) => {
210 cmma_load_zeros::<N>(&mut b.cmma.matrix);
211 }
212
213 (Tile::SharedMemory(shared), Tile::MmaLhs(t)) => {
215 let shared = shared.view::<SS>();
216 mma_load_lhs_from_shared::<SE, SS, N, R, A, SIO>(
217 &shared,
218 &mut t.fragment,
219 t.matrix_layout,
220 t.config,
221 );
222 }
223 (Tile::SharedMemory(shared), Tile::MmaRhs(t)) => {
224 let shared = shared.view::<SS>();
225 mma_load_rhs_from_shared::<SE, SS, N, L, A, SIO>(
226 &shared,
227 &mut t.fragment,
228 t.matrix_layout,
229 t.config,
230 );
231 }
232 (Tile::SharedMemory(shared), Tile::MmaAcc(t)) => {
233 let shared = shared.view::<SS>();
234 mma_load_acc_from_shared::<SE, SS, N, L, R, SIO>(
235 &shared,
236 &mut t.fragment,
237 t.matrix_layout,
238 t.config,
239 );
240 }
241 (Tile::None, Tile::MmaAcc(t)) => {
242 mma_load_acc_zeros::<SE, SS, N, L, R>(&mut t.fragment, t.matrix_layout, t.config);
243 }
244
245 (Tile::SharedMemory(shared), Tile::Register(t)) => {
247 let shared = shared.view::<SS>();
248 register_load_from_shared::<SE, SS, N, SIO>(
249 &shared,
250 &mut t.data,
251 t.matrix_layout,
252 t.config,
253 ident,
254 );
255 }
256 (Tile::None, Tile::Register(t)) => {
257 register_load_zeros::<N>(&mut t.data, t.config, ident);
258 }
259
260 (Tile::SharedMemory(shared), Tile::PlaneVec(t)) => {
262 let shared = shared.view::<SS>();
263 planevec_load_from_shared::<SE, SS, N, SIO>(&shared, &mut t.data, t.config, ident);
264 }
265 (Tile::None, Tile::PlaneVec(t)) => {
266 planevec_load_zeros::<N>(&mut t.data, t.config);
267 }
268
269 (Tile::SharedMemory(shared), Tile::Interleaved(t)) => {
271 let shared = shared.view::<SS>();
272 interleaved_load_from_shared::<SE, SS, N, SIO>(
273 &shared,
274 &mut t.data,
275 t.config,
276 ident,
277 );
278 }
279 (Tile::None, Tile::Interleaved(t)) => {
280 interleaved_load_zeros::<N>(&mut t.data, t.config);
281 }
282
283 (Tile::Cmma(t), Tile::SharedMemory(shared)) => {
285 let mut shared = shared.view::<SS>();
286 cmma_write_to_shared::<N, SS, SE>(&mut shared, &t.matrix);
287 }
288 (Tile::Bounce(b), Tile::SharedMemory(shared)) => {
289 let mut shared = shared.view::<SS>();
290 cmma_write_to_shared::<N, SS, SE>(&mut shared, &b.cmma.matrix);
291 }
292 (Tile::MmaAcc(t), Tile::SharedMemory(shared)) => {
293 let mut shared = shared.view::<SS>();
294 mma_write_to_shared::<N, SS, SE, L, R>(&mut shared, &t.fragment, t.config);
295 }
296 (Tile::Register(t), Tile::SharedMemory(shared)) => {
297 let mut shared = shared.view::<SS>();
298 register_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
299 }
300 (Tile::PlaneVec(t), Tile::SharedMemory(shared)) => {
301 let mut shared = shared.view::<SS>();
302 planevec_write_to_shared::<SE, N, SS>(&mut shared, &t.data, t.config);
303 }
304 (Tile::Interleaved(t), Tile::SharedMemory(shared)) => {
305 let mut shared = shared.view::<SS>();
306 interleaved_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
307 }
308
309 _ => panic!("Unsupported storage pair for copy_from"),
310 }
311 }
312}