cubek_std/tile/mma/
config.rs1use cubecl::ir::{DeviceProperties, MatrixIdent, StorageType};
2
3#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
4pub struct MmaIOConfig {
5 pub lhs_load_method: LoadMethod,
6 pub rhs_load_method: LoadMethod,
7 pub acc_load_method: LoadMethod,
8 pub store_method: StoreMethod,
9}
10
/// Strategy used to load an MMA operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadMethod {
    /// Load without the device's dedicated matrix-load (`ldmatrix`) support.
    Manual,
    /// Load through the device's `ldmatrix` support for the stage type.
    LoadMatrix,
}
16
/// Strategy used to store an MMA result.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoreMethod {
    /// Store without the device's dedicated matrix-store (`stmatrix`) support.
    Manual,
    /// Store through the device's `stmatrix` support for the stage type.
    StoreMatrix,
}
22
23impl MmaIOConfig {
24 pub fn new(
25 device_props: &DeviceProperties,
26 lhs_stage: StorageType,
27 rhs_stage: StorageType,
28 acc_stage: StorageType,
29 ) -> Self {
30 Self {
31 lhs_load_method: load_method(device_props, lhs_stage),
32 rhs_load_method: load_method(device_props, rhs_stage),
33 acc_load_method: load_method(device_props, acc_stage),
34 store_method: store_method(device_props, acc_stage),
35 }
36 }
37
38 pub fn load_method(&self, ident: MatrixIdent) -> LoadMethod {
39 match ident {
40 MatrixIdent::A => self.lhs_load_method,
41 MatrixIdent::B => self.rhs_load_method,
42 MatrixIdent::Accumulator => self.acc_load_method,
43 }
44 }
45
46 pub fn store_method(&self) -> StoreMethod {
47 self.store_method
48 }
49}
50
51fn load_method(device_props: &DeviceProperties, dtype: StorageType) -> LoadMethod {
52 if !matches!(dtype, StorageType::Packed(_, _))
53 && device_props.features.matmul.ldmatrix.contains(&dtype)
54 {
55 LoadMethod::LoadMatrix
56 } else {
57 LoadMethod::Manual
58 }
59}
60
61fn store_method(device_props: &DeviceProperties, dtype: StorageType) -> StoreMethod {
62 if !matches!(dtype, StorageType::Packed(_, _))
63 && device_props.features.matmul.stmatrix.contains(&dtype)
64 {
65 StoreMethod::StoreMatrix
66 } else {
67 StoreMethod::Manual
68 }
69}