Skip to main content

cubek_std/tile/mma/
config.rs

1use cubecl::ir::{DeviceProperties, MatrixIdent, StorageType};
2
3#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
4pub struct MmaIOConfig {
5    pub lhs_load_method: LoadMethod,
6    pub rhs_load_method: LoadMethod,
7    pub acc_load_method: LoadMethod,
8    pub store_method: StoreMethod,
9}
10
/// Strategy used to load an operand fragment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    /// Generic fallback, selected whenever `LoadMatrix` cannot be used.
    Manual,
    /// Load through the device's `ldmatrix` capability.
    LoadMatrix,
}
16
/// Strategy used to store the tile's output fragment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoreMethod {
    /// Generic fallback, selected whenever `StoreMatrix` cannot be used.
    Manual,
    /// Store through the device's `stmatrix` capability.
    StoreMatrix,
}
22
23impl MmaIOConfig {
24    pub fn new(
25        device_props: &DeviceProperties,
26        lhs_stage: StorageType,
27        rhs_stage: StorageType,
28        acc_stage: StorageType,
29    ) -> Self {
30        Self {
31            lhs_load_method: load_method(device_props, lhs_stage),
32            rhs_load_method: load_method(device_props, rhs_stage),
33            acc_load_method: load_method(device_props, acc_stage),
34            store_method: store_method(device_props, acc_stage),
35        }
36    }
37
38    pub fn load_method(&self, ident: MatrixIdent) -> LoadMethod {
39        match ident {
40            MatrixIdent::A => self.lhs_load_method,
41            MatrixIdent::B => self.rhs_load_method,
42            MatrixIdent::Accumulator => self.acc_load_method,
43        }
44    }
45
46    pub fn store_method(&self) -> StoreMethod {
47        self.store_method
48    }
49}
50
51fn load_method(device_props: &DeviceProperties, dtype: StorageType) -> LoadMethod {
52    if !matches!(dtype, StorageType::Packed(_, _))
53        && device_props.features.matmul.ldmatrix.contains(&dtype)
54    {
55        LoadMethod::LoadMatrix
56    } else {
57        LoadMethod::Manual
58    }
59}
60
61fn store_method(device_props: &DeviceProperties, dtype: StorageType) -> StoreMethod {
62    if !matches!(dtype, StorageType::Packed(_, _))
63        && device_props.features.matmul.stmatrix.contains(&dtype)
64    {
65        StoreMethod::StoreMatrix
66    } else {
67        StoreMethod::Manual
68    }
69}