Skip to main content

cubek_std/tile/data/
cmma.rs

1use cubecl::{
2    cmma::Matrix as CubeMatrix,
3    cmma::{self},
4    prelude::*,
5};
6
7use crate::{
8    MatrixLayout, SwizzleModes, TileSize, as_cmma_layout,
9    tile::{Tile, TileScope},
10};
11
12#[derive(CubeType)]
13pub struct CmmaTile<N: Numeric> {
14    pub matrix: CubeMatrix<N>,
15    #[cube(comptime)]
16    pub matrix_layout: MatrixLayout,
17    #[cube(comptime)]
18    pub tile_size: TileSize,
19}
20
21#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
22pub struct CmmaMatmul {
23    pub tile_size: TileSize,
24    pub plane_dim: u32,
25    pub swizzle_modes: SwizzleModes,
26}
27
28impl CmmaMatmul {
29    pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
30        Self {
31            tile_size,
32            plane_dim,
33            swizzle_modes,
34        }
35    }
36}
37
38#[cube]
39pub fn cmma_allocate_lhs<L: Numeric, Sc: TileScope>(
40    #[comptime] layout: MatrixLayout,
41    #[comptime] tile_size: TileSize,
42) -> Tile<L, Sc, ReadWrite> {
43    let fragment = unsafe {
44        cmma::Matrix::<L>::uninitialized(
45            cmma::MatrixIdent::A,
46            tile_size.m as usize,
47            tile_size.n as usize,
48            tile_size.k as usize,
49            as_cmma_layout(layout),
50        )
51    };
52    Tile::new_Cmma(CmmaTile::<L> {
53        matrix: fragment,
54        matrix_layout: layout,
55        tile_size,
56    })
57}
58
59#[cube]
60pub fn cmma_allocate_rhs<R: Numeric, Sc: TileScope>(
61    #[comptime] layout: MatrixLayout,
62    #[comptime] tile_size: TileSize,
63) -> Tile<R, Sc, ReadWrite> {
64    let fragment = unsafe {
65        cmma::Matrix::<R>::uninitialized(
66            cmma::MatrixIdent::B,
67            tile_size.m as usize,
68            tile_size.n as usize,
69            tile_size.k as usize,
70            as_cmma_layout(layout),
71        )
72    };
73    Tile::new_Cmma(CmmaTile::<R> {
74        matrix: fragment,
75        matrix_layout: layout,
76        tile_size,
77    })
78}
79
80#[cube]
81pub fn cmma_allocate_acc<A: Numeric, Sc: TileScope>(
82    #[comptime] layout: MatrixLayout,
83    #[comptime] tile_size: TileSize,
84) -> Tile<A, Sc, ReadWrite> {
85    let fragment = unsafe {
86        cmma::Matrix::<A>::uninitialized(
87            cmma::MatrixIdent::Accumulator,
88            tile_size.m as usize,
89            tile_size.n as usize,
90            tile_size.k as usize,
91            cmma::MatrixLayout::Undefined,
92        )
93    };
94    Tile::new_Cmma(CmmaTile::<A> {
95        matrix: fragment,
96        matrix_layout: layout,
97        tile_size,
98    })
99}