// cubek_std/tile/data/mma.rs

1use cubecl::{
2    cmma::MmaDefinition,
3    define_size,
4    ir::{DeviceProperties, MatrixIdent, StorageType},
5    prelude::*,
6};
7
8use crate::{
9    MatrixLayout, SwizzleModes, TileSize,
10    tile::{Tile, TileScope},
11};
12
// Fragment inner vector sizes for the three MMA roles. Bound at allocation time
// via `mma_register_vector_sizes` to match the hardware's `def.vector_size(...)`
// for each role — these are independent of the outer Tile enum's stage vector `V`.
// Lhs fragment vector size, bound from `def.vector_size(MatrixIdent::A)`.
define_size!(pub NL);
// Rhs fragment vector size, bound from `def.vector_size(MatrixIdent::B)`.
define_size!(pub NR);
// Accumulator fragment vector size, bound from `def.vector_size(MatrixIdent::Accumulator)`.
define_size!(pub NA);
19
/// Single MMA tile carrier. The role (Lhs / Rhs / Acc) lives inside
/// [`MmaFragment`] because each role's fragment uses a different inner vector
/// size (`NL` / `NR` / `NA`); the outer carrier holds the shared comptime
/// metadata.
#[derive(CubeType)]
pub struct MmaTile<N: Numeric> {
    // Fragment storage tagged with its role; sized by `vectors_per_lane` at allocation.
    pub fragment: MmaFragment<N>,
    // Matrix layout supplied by the caller when the tile is allocated.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    // MMA configuration the tile was allocated with (tile size, plane dim,
    // swizzle modes, per-role IO methods).
    #[cube(comptime)]
    pub config: MmaMatmul,
}
32
/// Role-tagged MMA fragment storage: an array of vectors whose inner width is
/// the size marker bound for that role by `mma_register_vector_sizes`.
#[derive(CubeType)]
pub enum MmaFragment<N: Numeric> {
    // Lhs (`MatrixIdent::A`) fragment, vectors of width `NL`.
    Lhs(Array<Vector<N, NL>>),
    // Rhs (`MatrixIdent::B`) fragment, vectors of width `NR`.
    Rhs(Array<Vector<N, NR>>),
    // Accumulator (`MatrixIdent::Accumulator`) fragment, vectors of width `NA`.
    Acc(Array<Vector<N, NA>>),
}
39
/// Comptime configuration shared by every MMA tile of one matmul.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct MmaMatmul {
    /// Tile shape; its `m()`, `n()`, `k()` are passed to `MmaDefinition::new`.
    pub tile_size: TileSize,
    /// Plane dimension (presumably the number of units per plane — confirm);
    /// not read in this file.
    pub plane_dim: u32,
    /// Swizzle modes; not read in this file.
    pub swizzle_modes: SwizzleModes,
    /// Per-role load strategies and the accumulator store strategy.
    pub mma_io_config: MmaIOConfig,
}
47
/// How each MMA fragment role is loaded, and how the accumulator is stored.
/// Built from device features by `MmaIOConfig::new`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct MmaIOConfig {
    /// Load strategy for the Lhs (`MatrixIdent::A`) fragment.
    pub lhs_load_method: LoadMethod,
    /// Load strategy for the Rhs (`MatrixIdent::B`) fragment.
    pub rhs_load_method: LoadMethod,
    /// Load strategy for the accumulator fragment.
    pub acc_load_method: LoadMethod,
    /// Store strategy; chosen from the accumulator's stage storage type.
    pub store_method: StoreMethod,
}
55
/// Strategy used to load a fragment from its stage.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadMethod {
    /// Fallback, chosen for packed storage types or types the device does not
    /// list under `features.matmul.ldmatrix`.
    Manual,
    /// Chosen for non-packed storage types listed under `features.matmul.ldmatrix`.
    LoadMatrix,
}
61
/// Strategy used to store the accumulator fragment.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum StoreMethod {
    /// Fallback, chosen for packed storage types or types the device does not
    /// list under `features.matmul.stmatrix`.
    Manual,
    /// Chosen for non-packed storage types listed under `features.matmul.stmatrix`.
    StoreMatrix,
}
67
68impl MmaIOConfig {
69    pub fn new(
70        device_props: &DeviceProperties,
71        lhs_stage: StorageType,
72        rhs_stage: StorageType,
73        acc_stage: StorageType,
74    ) -> Self {
75        Self {
76            lhs_load_method: load_method(device_props, lhs_stage),
77            rhs_load_method: load_method(device_props, rhs_stage),
78            acc_load_method: load_method(device_props, acc_stage),
79            store_method: store_method(device_props, acc_stage),
80        }
81    }
82
83    pub fn load_method(&self, ident: MatrixIdent) -> LoadMethod {
84        match ident {
85            MatrixIdent::A => self.lhs_load_method,
86            MatrixIdent::B => self.rhs_load_method,
87            MatrixIdent::Accumulator => self.acc_load_method,
88        }
89    }
90
91    pub fn store_method(&self) -> StoreMethod {
92        self.store_method
93    }
94}
95
96fn load_method(device_props: &DeviceProperties, dtype: StorageType) -> LoadMethod {
97    if !matches!(dtype, StorageType::Packed(_, _))
98        && device_props.features.matmul.ldmatrix.contains(&dtype)
99    {
100        LoadMethod::LoadMatrix
101    } else {
102        LoadMethod::Manual
103    }
104}
105
106fn store_method(device_props: &DeviceProperties, dtype: StorageType) -> StoreMethod {
107    if !matches!(dtype, StorageType::Packed(_, _))
108        && device_props.features.matmul.stmatrix.contains(&dtype)
109    {
110        StoreMethod::StoreMatrix
111    } else {
112        StoreMethod::Manual
113    }
114}
115
/// Builds the `MmaDefinition` for the comptime tile shape in `config`,
/// passing `tile_size.m()`, `tile_size.n()`, `tile_size.k()` in that order.
#[cube]
fn make_mma_definition<L: Numeric, R: Numeric, A: Numeric>(
    #[comptime] config: MmaMatmul,
) -> MmaDefinition<L, R, A> {
    MmaDefinition::new(
        config.tile_size.m() as usize,
        config.tile_size.n() as usize,
        config.tile_size.k() as usize,
    )
}
126
/// Binds the fragment size markers `NL`, `NR`, `NA` in the current scope to
/// the vector sizes `def` reports for `MatrixIdent::A`, `MatrixIdent::B` and
/// `MatrixIdent::Accumulator` respectively.
///
/// Called by each `mma_allocate_*` function before its fragment array is
/// created, so `Vector<_, NL/NR/NA>` resolve to the hardware sizes.
#[cube]
// NOTE(review): presumably required because the three locals are only consumed
// inside `intrinsic!`, which the non-expanded code path does not see — confirm.
#[allow(unused_variables)]
pub fn mma_register_vector_sizes<L: Numeric, R: Numeric, A: Numeric>(def: MmaDefinition<L, R, A>) {
    let vector_size_a = def.vector_size(MatrixIdent::A);
    let vector_size_b = def.vector_size(MatrixIdent::B);
    let vector_size_acc = def.vector_size(MatrixIdent::Accumulator);
    intrinsic!(|scope| {
        // Register each role's size under its marker type in this scope.
        scope.register_size::<NL>(vector_size_a);
        scope.register_size::<NR>(vector_size_b);
        scope.register_size::<NA>(vector_size_acc);
    });
}
139
/// Allocates an Lhs MMA tile with element type `L`.
///
/// The function builds the `MmaDefinition<L, R, A>` for `config.tile_size` and
/// binds the fragment size markers. It then allocates `vectors_per_lane(A)`
/// vectors of width `NL`.
///
/// * `layout` — stored as the tile's `matrix_layout`.
/// * `config` — stored on the tile and used to size the fragment.
#[cube]
pub fn mma_allocate_lhs<L: Numeric, R: Numeric, A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<L, Sc, ReadWrite> {
    let def = make_mma_definition::<L, R, A>(config);
    // Must run before the array is created so `NL` has a bound size.
    mma_register_vector_sizes(def);
    let vector_count = def.vectors_per_lane(MatrixIdent::A);

    Tile::new_Mma(MmaTile::<L> {
        fragment: MmaFragment::new_Lhs(Array::new(vector_count)),
        matrix_layout: layout,
        config,
    })
}
155
/// Allocates an Rhs MMA tile with element type `R`.
///
/// The generic order is `<R, L, A, Sc>`, with the tile's own element type
/// first. The definition is still built as `MmaDefinition<L, R, A>`. The
/// function allocates `vectors_per_lane(B)` vectors of width `NR`.
///
/// * `layout` — stored as the tile's `matrix_layout`.
/// * `config` — stored on the tile and used to size the fragment.
#[cube]
pub fn mma_allocate_rhs<R: Numeric, L: Numeric, A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<R, Sc, ReadWrite> {
    let def = make_mma_definition::<L, R, A>(config);
    // Must run before the array is created so `NR` has a bound size.
    mma_register_vector_sizes(def);
    let vector_count = def.vectors_per_lane(MatrixIdent::B);

    Tile::new_Mma(MmaTile::<R> {
        fragment: MmaFragment::new_Rhs(Array::new(vector_count)),
        matrix_layout: layout,
        config,
    })
}
171
/// Allocates an accumulator MMA tile with element type `A`.
///
/// The generic order is `<A, L, R, Sc>`, with the tile's own element type
/// first. The definition is still built as `MmaDefinition<L, R, A>`. The
/// function allocates `vectors_per_lane(Accumulator)` vectors of width `NA`.
///
/// * `layout` — stored as the tile's `matrix_layout`.
/// * `config` — stored on the tile and used to size the fragment.
#[cube]
pub fn mma_allocate_acc<A: Numeric, L: Numeric, R: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<A, Sc, ReadWrite> {
    let def = make_mma_definition::<L, R, A>(config);
    // Must run before the array is created so `NA` has a bound size.
    mma_register_vector_sizes(def);
    let vector_count = def.vectors_per_lane(MatrixIdent::Accumulator);

    Tile::new_Mma(MmaTile::<A> {
        fragment: MmaFragment::new_Acc(Array::new(vector_count)),
        matrix_layout: layout,
        config,
    })
}