1use cubecl::prelude::*;
2
/// Selects one of the three matmul dimensions.
///
/// Passed to the `get` method of the 3-D size types to pick which component to read.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatmulDim {
    /// Selects the `m` component.
    M,
    /// Selects the `n` component.
    N,
    /// Selects the `k` component.
    K,
}
13
/// Defines a `Copy` size struct `$name` with public `m`, `n` and `k` fields of type `$ty`.
///
/// Every accessor on the generated type returns `u32`, whatever `$ty` is, so callers see a
/// uniform type. Widening goes through `u32::from`, so invoking this macro with a storage
/// type wider than `u32` fails to compile instead of silently truncating.
macro_rules! define_3d_size_base {
    ($name:ident, $ty:ty) => {
        /// Sizes along the `m`, `n` and `k` matmul dimensions.
        #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
        pub struct $name {
            /// Size along `m`.
            pub m: $ty,
            /// Size along `n`.
            pub n: $ty,
            /// Size along `k`.
            pub k: $ty,
        }

        impl $name {
            /// Creates a size from `u32` components.
            ///
            /// # Panics
            ///
            /// Panics if any component does not fit in this size's storage type.
            pub fn new(m: u32, n: u32, k: u32) -> Self {
                $name {
                    m: <$ty>::try_from(m).expect("`m` does not fit in the size's storage type"),
                    n: <$ty>::try_from(n).expect("`n` does not fit in the size's storage type"),
                    k: <$ty>::try_from(k).expect("`k` does not fit in the size's storage type"),
                }
            }

            /// Returns the component selected by `dim`, widened to `u32`.
            pub fn get(&self, dim: MatmulDim) -> u32 {
                u32::from(match dim {
                    MatmulDim::M => self.m,
                    MatmulDim::N => self.n,
                    MatmulDim::K => self.k,
                })
            }

            /// Returns the `m` component as `u32`.
            pub fn m(&self) -> u32 {
                self.get(MatmulDim::M)
            }

            /// Returns the `n` component as `u32`.
            pub fn n(&self) -> u32 {
                self.get(MatmulDim::N)
            }

            /// Returns the `k` component as `u32`.
            pub fn k(&self) -> u32 {
                self.get(MatmulDim::K)
            }

            /// Returns `m * n`.
            ///
            /// # Panics
            ///
            /// Panics if the product overflows `u32`. This is checked in release builds too,
            /// where a plain `*` would wrap silently.
            pub fn mn(&self) -> u32 {
                self.m()
                    .checked_mul(self.n())
                    .expect("`m * n` overflows u32")
            }

            /// Returns `m * k`.
            ///
            /// # Panics
            ///
            /// Panics if the product overflows `u32`.
            pub fn mk(&self) -> u32 {
                self.m()
                    .checked_mul(self.k())
                    .expect("`m * k` overflows u32")
            }

            /// Returns `n * k`.
            ///
            /// # Panics
            ///
            /// Panics if the product overflows `u32`.
            pub fn nk(&self) -> u32 {
                self.n()
                    .checked_mul(self.k())
                    .expect("`n * k` overflows u32")
            }

            /// Returns `m * n * k`.
            ///
            /// # Panics
            ///
            /// Panics if the product overflows `u32`. For `u32`-backed sizes this is reachable
            /// with realistic problem shapes (e.g. 4096 x 4096 x 4096).
            pub fn mnk(&self) -> u32 {
                self.mn()
                    .checked_mul(self.k())
                    .expect("`m * n * k` overflows u32")
            }
        }
    };
}
70
/// Implements conversions between the size struct `$name` (fields of type `$ty_struct`) and
/// an `(m, n, k)` tuple of `$ty_tuple`, in both directions.
///
/// Both directions are range-checked. A component that does not fit in the target type
/// panics; an `as` cast would instead silently truncate or sign-wrap it (e.g. `256u32` to a
/// `u8` field would become `0`, and `-1i32` to a `u32` field would become `u32::MAX`).
macro_rules! impl_from_tuple {
    ($name:ident, $ty_struct:ty, $ty_tuple:ty) => {
        impl From<($ty_tuple, $ty_tuple, $ty_tuple)> for $name {
            /// Builds a size from an `(m, n, k)` tuple.
            ///
            /// # Panics
            ///
            /// Panics if any component does not fit in the size's storage type.
            fn from((m, n, k): ($ty_tuple, $ty_tuple, $ty_tuple)) -> Self {
                Self {
                    m: <$ty_struct>::try_from(m)
                        .expect("`m` does not fit in the size's storage type"),
                    n: <$ty_struct>::try_from(n)
                        .expect("`n` does not fit in the size's storage type"),
                    k: <$ty_struct>::try_from(k)
                        .expect("`k` does not fit in the size's storage type"),
                }
            }
        }

        impl From<$name> for ($ty_tuple, $ty_tuple, $ty_tuple) {
            /// Converts a size into an `(m, n, k)` tuple.
            ///
            /// # Panics
            ///
            /// Panics if any component does not fit in the tuple's element type.
            fn from(value: $name) -> Self {
                (
                    <$ty_tuple>::try_from(value.m)
                        .expect("`m` does not fit in the tuple's element type"),
                    <$ty_tuple>::try_from(value.n)
                        .expect("`n` does not fit in the tuple's element type"),
                    <$ty_tuple>::try_from(value.k)
                        .expect("`k` does not fit in the tuple's element type"),
                )
            }
        }
    };
}
94
95define_3d_size_base!(TileSize, u32);
97impl_from_tuple!(TileSize, u32, u8);
98impl_from_tuple!(TileSize, u32, u32);
99impl_from_tuple!(TileSize, u32, i32);
100impl_from_tuple!(TileSize, u32, u16);
101impl_from_tuple!(TileSize, u32, usize);
102
103define_3d_size_base!(PartitionSize, u8);
105impl_from_tuple!(PartitionSize, u8, u8);
106impl_from_tuple!(PartitionSize, u8, u32);
107impl_from_tuple!(PartitionSize, u8, i32);
108impl_from_tuple!(PartitionSize, u8, u16);
109impl_from_tuple!(PartitionSize, u8, usize);
110
111define_3d_size_base!(StageSize, u8);
113impl_from_tuple!(StageSize, u8, u8);
114impl_from_tuple!(StageSize, u8, u32);
115impl_from_tuple!(StageSize, u8, i32);
116impl_from_tuple!(StageSize, u8, u16);
117impl_from_tuple!(StageSize, u8, usize);
118
119define_3d_size_base!(MatmulProblemSize, u32);
121impl_from_tuple!(MatmulProblemSize, u32, u8);
122impl_from_tuple!(MatmulProblemSize, u32, u32);
123impl_from_tuple!(MatmulProblemSize, u32, i32);
124impl_from_tuple!(MatmulProblemSize, u32, u16);
125impl_from_tuple!(MatmulProblemSize, u32, usize);
126
/// Sizes of a global partition along `m`, `n` and the batch dimension.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct GlobalPartitionSize {
    /// Size along `m`.
    pub m: u32,
    /// Size along `n`.
    pub n: u32,
    /// Size along the batch dimension.
    pub batches: u32,
}

impl GlobalPartitionSize {
    /// Creates a partition size from its `m`, `n` and `batches` components.
    pub fn new(m: u32, n: u32, batches: u32) -> Self {
        GlobalPartitionSize { batches, n, m }
    }
}