// cubek_std/tile/mod.rs

1pub mod ops;
2pub mod scope;
3pub mod variants;
4
5mod strided_tile;
6mod tile_kind;
7
8pub use ops::*;
9pub use scope::{Cube, Plane, Scope, ScopeKind, ScopeMarker, Unit};
10pub use strided_tile::*;
11pub use tile_kind::*;
12pub use variants::bounce_tile::*;
13pub use variants::cmma::*;
14pub use variants::interleaved::*;
15pub use variants::local_tile::*;
16pub use variants::mma::*;
17pub use variants::plane_vec_mat_inner_product::*;
18pub use variants::register::*;
19pub use variants::unit_tile::*;
20
21// Re-export the variant modules at their old paths so existing consumers keep
22// working (e.g. `cubek_std::tile::cmma`, `cubek_std::tile::mma`).
23pub use variants::{cmma, interleaved, mma, plane_vec_mat_inner_product, register};
24
25use cubecl::cmma::Matrix as CubeMatrix;
26use cubecl::prelude::*;
27
28use crate::{MatrixLayout, StageIdent, tile::scope::Scope as TileScope};
29
/// A tile of `N`-typed data held in one of several storage representations.
///
/// `Sc` is the tile's scope marker (e.g. `Unit`, `Plane`); some variants are
/// only valid for a particular scope, as noted on them. `IO` is the slice
/// visibility of the `SharedMemory` variant's backing tile. The `mma` and
/// `copy_from` operations are only provided for `IO = ReadWrite`.
///
/// Which variant combinations an operation accepts is decided by the match
/// dispatch in `Tile::mma` / `Tile::copy_from`. Any combination not listed
/// there hits that method's `panic!` fallback.
#[derive(CubeType)]
pub enum Tile<N: Numeric, Sc: TileScope, IO: SliceVisibility> {
    /// Backed by a shared-memory tile (`SharedTile`).
    SharedMemory(SharedTile<N, IO>),
    /// A `cubecl::cmma` matrix fragment; see [`CmmaTile`].
    Cmma(CmmaTile<N>),
    /// Lhs operand fragment for the `mma_execute` path; see [`MmaLhsTile`].
    MmaLhs(MmaLhsTile<N>),
    /// Rhs operand fragment for the `mma_execute` path; see [`MmaRhsTile`].
    MmaRhs(MmaRhsTile<N>),
    /// Accumulator fragment for the `mma_execute` path; see [`MmaAccTile`].
    MmaAcc(MmaAccTile<N>),
    /// Flat array of values for the `register_execute` path; see [`RegisterTile`].
    Register(RegisterTile<N>),
    /// Vectorized data for the `planevec_execute` path; see [`PlaneVecTile`].
    PlaneVec(PlaneVecTile<N>),
    /// Data for the `interleaved_execute` path (`InterleavedTile`).
    Interleaved(InterleavedTile<N>),
    /// Each unit holds a full row-major copy of the tile in registers.
    /// Only valid when `Sc = Unit`.
    Unit(UnitTile<N>),
    /// The tile is fragmented across plane units. Only valid when `Sc = Plane`.
    Local(LocalTile<N>),
    /// Bundles a cmma fragment, an smem scratch slice, and a `LocalTile` view.
    /// From the caller's perspective it is a single tile; the smem round-trip
    /// is internal to ops dispatch. Only valid when `Sc = Plane`.
    Bounce(BounceTile<N>),
    /// A single value wrapped in [`Value`]. Not matched by `Tile::mma` or
    /// `Tile::copy_from`.
    /// NOTE(review): going by the name, presumably one value standing for
    /// every element of the tile — confirm.
    Broadcasted(Value<N>),
    /// No backing data. As the `source` of `Tile::copy_from`, it loads zeros
    /// into the destination through the matching `*_load_zeros` helper.
    None,
    /// Carries the `Sc` type parameter, which no other variant's payload mentions.
    _Phantom(ScopeMarker<Sc>),
}
53
/// A tile held as a `cubecl::cmma::Matrix` fragment.
#[derive(CubeType)]
pub struct CmmaTile<N: Numeric> {
    /// The cmma fragment.
    /// - Filled by `cmma_load_from_shared` / `cmma_load_zeros`.
    /// - Used by `cmma_execute`.
    /// - Stored back by `cmma_write_to_shared`.
    pub matrix: CubeMatrix<N>,
    /// Comptime layout of the tile data, forwarded to `cmma_load_from_shared`.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime tile size (`crate::TileSize`).
    #[cube(comptime)]
    pub tile_size: crate::TileSize,
}
62
/// Lhs operand of the `mma_execute` path, stored as vectors of size `mma::NL`.
#[derive(CubeType)]
pub struct MmaLhsTile<N: Numeric> {
    /// Fragment data, filled by `mma_load_lhs_from_shared` and read by `mma_execute`.
    pub fragment: Array<Vector<N, mma::NL>>,
    /// Comptime layout, forwarded to `mma_load_lhs_from_shared`.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime `MmaMatmul` configuration, forwarded to `mma_load_lhs_from_shared`.
    #[cube(comptime)]
    pub config: MmaMatmul,
}
71
/// Rhs operand of the `mma_execute` path, stored as vectors of size `mma::NR`.
#[derive(CubeType)]
pub struct MmaRhsTile<N: Numeric> {
    /// Fragment data, filled by `mma_load_rhs_from_shared` and read by `mma_execute`.
    pub fragment: Array<Vector<N, mma::NR>>,
    /// Comptime layout, forwarded to `mma_load_rhs_from_shared`.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime `MmaMatmul` configuration, forwarded to `mma_load_rhs_from_shared`.
    #[cube(comptime)]
    pub config: MmaMatmul,
}
80
/// Accumulator of the `mma_execute` path, stored as vectors of size `mma::NA`.
#[derive(CubeType)]
pub struct MmaAccTile<N: Numeric> {
    /// Fragment data.
    /// - Filled by `mma_load_acc_from_shared` / `mma_load_acc_zeros`.
    /// - Accumulated into by `mma_execute`.
    /// - Stored back by `mma_write_to_shared`.
    pub fragment: Array<Vector<N, mma::NA>>,
    /// Comptime layout, forwarded to the accumulator loaders and to `mma_execute`.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime `MmaMatmul` configuration.
    /// Forwarded to the loaders, `mma_execute` and `mma_write_to_shared`.
    #[cube(comptime)]
    pub config: MmaMatmul,
}
89
/// Tile for the `register_execute` path, stored as a flat array of scalars.
#[derive(CubeType)]
pub struct RegisterTile<N: Numeric> {
    /// Tile values.
    /// - Filled by `register_load_from_shared` / `register_load_zeros`.
    /// - Used by `register_execute`.
    /// - Stored back by `register_write_to_shared`.
    pub data: Array<N>,
    /// Comptime layout, forwarded to `register_load_from_shared`.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime `RegisterMatmul` configuration, forwarded to every `register_*` helper.
    #[cube(comptime)]
    pub config: RegisterMatmul,
}
98
/// Tile for the `planevec_execute` (plane vector-matrix inner product) path.
#[derive(CubeType)]
pub struct PlaneVecTile<N: Numeric> {
    /// Fragment data; the inner vector size is `NPlaneVec` (= reduce_vector_size).
    pub data: Array<Vector<N, NPlaneVec>>,
    /// Comptime layout of the tile.
    /// Not read by the `Tile::mma` / `Tile::copy_from` dispatch.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime configuration, forwarded to every `planevec_*` helper.
    #[cube(comptime)]
    pub config: PlaneVecMatInnerProduct,
}
108
/// Newtype around a single value `val`, used as the payload of `Tile::Broadcasted`.
///
/// The wrapper exists to make the `CubeType` enum derive on `Tile` work.
/// NOTE(review): presumably a bare `E` is not accepted as a variant payload
/// there — confirm.
#[derive(CubeType)]
pub struct Value<E: Numeric> {
    /// The wrapped value.
    pub val: E,
}
114
// Tile operations. These are only provided for `ReadWrite` tiles: `self`, and
// both `mma` operands, are `Tile<_, Sc, ReadWrite>`. `copy_from`'s `source`
// may have any slice visibility `SIO`.
#[cube]
impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
    /// Executes `lhs · rhs`, accumulating the result into `self`.
    ///
    /// Supported storage combinations, written `(lhs, rhs, self)`:
    /// - `Cmma`|`Bounce`, `Cmma`, `Cmma`|`Bounce` → `cmma_execute`.
    ///   - A `Bounce` tile takes part through its inner cmma fragment.
    ///   - The rhs must be a plain `Cmma` tile.
    /// - `MmaLhs`, `MmaRhs`, `MmaAcc` → `mma_execute`, with the accumulator's
    ///   layout and config.
    /// - `Register` ×3 → `register_execute`, with the accumulator's config.
    /// - `PlaneVec` ×3 → `planevec_execute`, with the accumulator's config.
    /// - `Interleaved` ×3 → `interleaved_execute`, with each tile's own layout
    ///   and the accumulator's config.
    ///
    /// # Panics
    ///
    /// On any other combination of storage variants.
    pub fn mma<L: Numeric, R: Numeric>(
        &mut self,
        lhs: &Tile<L, Sc, ReadWrite>,
        rhs: &Tile<R, Sc, ReadWrite>,
    ) {
        match (lhs, rhs, self) {
            // --- cmma: Bounce tiles delegate to their inner cmma fragment ---
            (Tile::Cmma(l), Tile::Cmma(r), Tile::Cmma(a)) => {
                cmma_execute(&l.matrix, &r.matrix, &mut a.matrix);
            }
            (Tile::Cmma(l), Tile::Cmma(r), Tile::Bounce(a)) => {
                cmma_execute(&l.matrix, &r.matrix, &mut a.cmma.matrix);
            }
            (Tile::Bounce(l), Tile::Cmma(r), Tile::Bounce(a)) => {
                cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.cmma.matrix);
            }
            (Tile::Bounce(l), Tile::Cmma(r), Tile::Cmma(a)) => {
                cmma_execute(&l.cmma.matrix, &r.matrix, &mut a.matrix);
            }
            // --- mma: layout and config come from the accumulator only ---
            (Tile::MmaLhs(l), Tile::MmaRhs(r), Tile::MmaAcc(a)) => {
                mma_execute(
                    &l.fragment,
                    &r.fragment,
                    &mut a.fragment,
                    a.matrix_layout,
                    a.config,
                );
            }
            (Tile::Register(l), Tile::Register(r), Tile::Register(a)) => {
                register_execute(&l.data, &r.data, &mut a.data, a.config);
            }
            (Tile::PlaneVec(l), Tile::PlaneVec(r), Tile::PlaneVec(a)) => {
                planevec_execute(&l.data, &r.data, &mut a.data, a.config);
            }
            // --- interleaved: each operand's own layout is forwarded ---
            (Tile::Interleaved(l), Tile::Interleaved(r), Tile::Interleaved(a)) => {
                interleaved_execute(
                    &l.data,
                    l.matrix_layout,
                    &r.data,
                    r.matrix_layout,
                    &mut a.data,
                    a.matrix_layout,
                    a.config,
                );
            }
            _ => panic!("Unsupported storage combination for mma"),
        }
    }

    /// Copies data from `source` into `self`.
    ///
    /// Two directions are supported:
    /// - **Load**: `source` is `SharedMemory` and `self` is a `Cmma`, `Bounce`,
    ///   `MmaLhs`, `MmaRhs`, `MmaAcc`, `Register`, `PlaneVec` or `Interleaved`
    ///   tile.
    ///   - A `None` source instead loads zeros via the matching `*_load_zeros`
    ///     helper.
    ///   - Zero-loading is wired for `Cmma`, `Bounce`, `MmaAcc`, `Register`,
    ///     `PlaneVec` and `Interleaved` destinations (not `MmaLhs`/`MmaRhs`).
    /// - **Write**: `self` is `SharedMemory` and `source` is a `Cmma`, `Bounce`,
    ///   `MmaAcc`, `Register`, `PlaneVec` or `Interleaved` tile.
    ///
    /// Generic parameters:
    /// - `SE` is the element type of `source`. On a write this is the compute
    ///   tile's element type, while the shared-memory destination uses `N`.
    /// - `SS` is the vector size of the shared memory tile involved in the copy
    ///   (whether that's the source on a load, or the destination on a write).
    /// - `L`/`R`/`A` are the matrix-level numeric types needed by the MMA
    ///   readers/writers — they are unused on non-MMA paths.
    /// - `SIO` is the slice visibility of `source`.
    ///
    /// `ident` is forwarded to the cmma, register, plane-vec and interleaved
    /// shared-memory loaders and to `register_load_zeros`; it is ignored
    /// elsewhere.
    ///
    /// # Panics
    ///
    /// On any other `(source, self)` pair. Examples:
    /// - shared-to-shared copies;
    /// - writes from `MmaLhs`/`MmaRhs`;
    /// - any pair involving `Unit`, `Local` or `Broadcasted`.
    pub fn copy_from<
        SE: Numeric,
        SS: Size,
        L: Numeric,
        R: Numeric,
        A: Numeric,
        SIO: SliceVisibility,
    >(
        &mut self,
        source: &Tile<SE, Sc, SIO>,
        #[comptime] ident: StageIdent,
    ) {
        // Every shared-memory arm first takes `shared.view::<SS>()`, i.e. views
        // the shared tile at vector size `SS`, before calling the helper.
        match (source, self) {
            // --- Cmma loads ---
            (Tile::SharedMemory(shared), Tile::Cmma(t)) => {
                let shared = shared.view::<SS>();
                cmma_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut t.matrix,
                    ident,
                    t.matrix_layout,
                );
            }
            (Tile::None, Tile::Cmma(t)) => {
                cmma_load_zeros::<N>(&mut t.matrix);
            }

            // --- Bounce loads (delegate to inner cmma) ---
            (Tile::SharedMemory(shared), Tile::Bounce(b)) => {
                let shared = shared.view::<SS>();
                cmma_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut b.cmma.matrix,
                    ident,
                    b.cmma.matrix_layout,
                );
            }
            (Tile::None, Tile::Bounce(b)) => {
                cmma_load_zeros::<N>(&mut b.cmma.matrix);
            }

            // --- Mma loads ---
            // Each loader takes the element types of the other two operands.
            (Tile::SharedMemory(shared), Tile::MmaLhs(t)) => {
                let shared = shared.view::<SS>();
                mma_load_lhs_from_shared::<SE, SS, N, R, A, SIO>(
                    &shared,
                    &mut t.fragment,
                    t.matrix_layout,
                    t.config,
                );
            }
            (Tile::SharedMemory(shared), Tile::MmaRhs(t)) => {
                let shared = shared.view::<SS>();
                mma_load_rhs_from_shared::<SE, SS, N, L, A, SIO>(
                    &shared,
                    &mut t.fragment,
                    t.matrix_layout,
                    t.config,
                );
            }
            (Tile::SharedMemory(shared), Tile::MmaAcc(t)) => {
                let shared = shared.view::<SS>();
                mma_load_acc_from_shared::<SE, SS, N, L, R, SIO>(
                    &shared,
                    &mut t.fragment,
                    t.matrix_layout,
                    t.config,
                );
            }
            (Tile::None, Tile::MmaAcc(t)) => {
                mma_load_acc_zeros::<SE, SS, N, L, R>(&mut t.fragment, t.matrix_layout, t.config);
            }

            // --- Register loads ---
            (Tile::SharedMemory(shared), Tile::Register(t)) => {
                let shared = shared.view::<SS>();
                register_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut t.data,
                    t.matrix_layout,
                    t.config,
                    ident,
                );
            }
            (Tile::None, Tile::Register(t)) => {
                register_load_zeros::<N>(&mut t.data, t.config, ident);
            }

            // --- PlaneVec loads ---
            (Tile::SharedMemory(shared), Tile::PlaneVec(t)) => {
                let shared = shared.view::<SS>();
                planevec_load_from_shared::<SE, SS, N, SIO>(&shared, &mut t.data, t.config, ident);
            }
            (Tile::None, Tile::PlaneVec(t)) => {
                planevec_load_zeros::<N>(&mut t.data, t.config);
            }

            // --- Interleaved loads ---
            (Tile::SharedMemory(shared), Tile::Interleaved(t)) => {
                let shared = shared.view::<SS>();
                interleaved_load_from_shared::<SE, SS, N, SIO>(
                    &shared,
                    &mut t.data,
                    t.config,
                    ident,
                );
            }
            (Tile::None, Tile::Interleaved(t)) => {
                interleaved_load_zeros::<N>(&mut t.data, t.config);
            }

            // --- Writes: shared memory copies from a compute container ---
            // Here `self` is the shared tile (element type `N`) and `source` is
            // the compute tile (element type `SE`). Each helper's generic order
            // follows its own signature, e.g. `planevec_write_to_shared` differs
            // from the others.
            (Tile::Cmma(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                cmma_write_to_shared::<N, SS, SE>(&mut shared, &t.matrix);
            }
            (Tile::Bounce(b), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                cmma_write_to_shared::<N, SS, SE>(&mut shared, &b.cmma.matrix);
            }
            (Tile::MmaAcc(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                mma_write_to_shared::<N, SS, SE, L, R>(&mut shared, &t.fragment, t.config);
            }
            (Tile::Register(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                register_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
            }
            (Tile::PlaneVec(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                planevec_write_to_shared::<SE, N, SS>(&mut shared, &t.data, t.config);
            }
            (Tile::Interleaved(t), Tile::SharedMemory(shared)) => {
                let mut shared = shared.view::<SS>();
                interleaved_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
            }

            _ => panic!("Unsupported storage pair for copy_from"),
        }
    }
}