cubek_std/tile/data/
cmma.rs1use cubecl::{
2 cmma::Matrix as CubeMatrix,
3 cmma::{self},
4 prelude::*,
5};
6
7use crate::{
8 MatrixLayout, SwizzleModes, TileSize, as_cmma_layout,
9 tile::{Tile, TileScope},
10};
11
/// Tile whose data lives in a `cmma` matrix fragment.
///
/// Alongside the fragment, the comptime matrix layout and tile size the tile
/// was allocated with are kept so later operations can read them.
#[derive(CubeType)]
pub struct CmmaTile<N: Numeric> {
    /// The underlying `cmma` matrix fragment holding the tile's elements.
    pub matrix: CubeMatrix<N>,
    /// Comptime layout recorded for this tile. This may differ from the layout
    /// the fragment itself was created with (e.g. accumulator fragments are
    /// created with `cmma::MatrixLayout::Undefined` but still record a layout here).
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Comptime `(m, n, k)` tile dimensions used when allocating the fragment.
    #[cube(comptime)]
    pub tile_size: TileSize,
}
20
21#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
22pub struct CmmaMatmul {
23 pub tile_size: TileSize,
24 pub plane_dim: u32,
25 pub swizzle_modes: SwizzleModes,
26}
27
28impl CmmaMatmul {
29 pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
30 Self {
31 tile_size,
32 plane_dim,
33 swizzle_modes,
34 }
35 }
36}
37
38#[cube]
39pub fn cmma_allocate_lhs<L: Numeric, Sc: TileScope>(
40 #[comptime] layout: MatrixLayout,
41 #[comptime] tile_size: TileSize,
42) -> Tile<L, Sc, ReadWrite> {
43 let fragment = unsafe {
44 cmma::Matrix::<L>::uninitialized(
45 cmma::MatrixIdent::A,
46 tile_size.m as usize,
47 tile_size.n as usize,
48 tile_size.k as usize,
49 as_cmma_layout(layout),
50 )
51 };
52 Tile::new_Cmma(CmmaTile::<L> {
53 matrix: fragment,
54 matrix_layout: layout,
55 tile_size,
56 })
57}
58
59#[cube]
60pub fn cmma_allocate_rhs<R: Numeric, Sc: TileScope>(
61 #[comptime] layout: MatrixLayout,
62 #[comptime] tile_size: TileSize,
63) -> Tile<R, Sc, ReadWrite> {
64 let fragment = unsafe {
65 cmma::Matrix::<R>::uninitialized(
66 cmma::MatrixIdent::B,
67 tile_size.m as usize,
68 tile_size.n as usize,
69 tile_size.k as usize,
70 as_cmma_layout(layout),
71 )
72 };
73 Tile::new_Cmma(CmmaTile::<R> {
74 matrix: fragment,
75 matrix_layout: layout,
76 tile_size,
77 })
78}
79
80#[cube]
81pub fn cmma_allocate_acc<A: Numeric, Sc: TileScope>(
82 #[comptime] layout: MatrixLayout,
83 #[comptime] tile_size: TileSize,
84) -> Tile<A, Sc, ReadWrite> {
85 let fragment = unsafe {
86 cmma::Matrix::<A>::uninitialized(
87 cmma::MatrixIdent::Accumulator,
88 tile_size.m as usize,
89 tile_size.n as usize,
90 tile_size.k as usize,
91 cmma::MatrixLayout::Undefined,
92 )
93 };
94 Tile::new_Cmma(CmmaTile::<A> {
95 matrix: fragment,
96 matrix_layout: layout,
97 tile_size,
98 })
99}