// cubecl_matmul/components/size.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

/// Selects one of the three extents (`m`, `n` or `k`) of a matmul size.
///
/// Passed to the `get` accessor of the size types defined in this module. It derives
/// `PartialEq`, `Eq` and `Hash` like those size types, so dimensions can be compared and
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatmulDim {
    /// Selects the `m` extent.
    M,
    /// Selects the `n` extent.
    N,
    /// Selects the `k` extent.
    K,
}

/// Defines a `Copy` size struct `$name` with public `m`, `n` and `k` fields of type `$ty`,
/// together with a `u32`-based constructor and `u32` accessors.
///
/// Requirements on `$ty`:
/// - `u32: From<$ty>`, so that widening to `u32` is lossless (e.g. `u8`, `u16`, `u32`).
/// - `$ty: TryFrom<u32>`, which `new` uses for narrowing.
macro_rules! define_3d_size_base {
    ($name:ident, $ty:ty) => {
        #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
        pub struct $name {
            pub m: $ty,
            pub n: $ty,
            pub k: $ty,
        }

        impl $name {
            /// Builds a size from `u32` extents, narrowing each one to the field type.
            ///
            /// # Panics
            ///
            /// Panics if any extent does not fit in the field type.
            pub fn new(m: u32, n: u32, k: u32) -> Self {
                $name {
                    m: <$ty>::try_from(m).expect(concat!(
                        stringify!($name),
                        "::new: `m` does not fit in ",
                        stringify!($ty)
                    )),
                    n: <$ty>::try_from(n).expect(concat!(
                        stringify!($name),
                        "::new: `n` does not fit in ",
                        stringify!($ty)
                    )),
                    k: <$ty>::try_from(k).expect(concat!(
                        stringify!($name),
                        "::new: `k` does not fit in ",
                        stringify!($ty)
                    )),
                }
            }

            /// Returns the extent selected by `dim`, widened to `u32`.
            pub fn get(&self, dim: MatmulDim) -> u32 {
                let extent = match dim {
                    MatmulDim::M => self.m,
                    MatmulDim::N => self.n,
                    MatmulDim::K => self.k,
                };
                // `u32::from` instead of `as`: a field type that cannot be widened losslessly
                // becomes a compile error rather than a silent truncation.
                u32::from(extent)
            }

            /// Returns the `m` extent as `u32`.
            pub fn m(&self) -> u32 {
                self.get(MatmulDim::M)
            }

            /// Returns the `n` extent as `u32`.
            pub fn n(&self) -> u32 {
                self.get(MatmulDim::N)
            }

            /// Returns the `k` extent as `u32`.
            pub fn k(&self) -> u32 {
                self.get(MatmulDim::K)
            }

            // The products below are computed in `u32`. With overflow checks enabled (the
            // default for debug builds), an overflowing product panics; otherwise it wraps.
            // Sizes backed by `u8` cannot overflow (255^3 < 2^32). `u32`-backed sizes can.

            /// Returns `m * n` as `u32`.
            pub fn mn(&self) -> u32 {
                self.m() * self.n()
            }

            /// Returns `m * k` as `u32`.
            pub fn mk(&self) -> u32 {
                self.m() * self.k()
            }

            /// Returns `n * k` as `u32`.
            pub fn nk(&self) -> u32 {
                self.n() * self.k()
            }

            /// Returns `m * n * k` as `u32`.
            pub fn mnk(&self) -> u32 {
                self.m() * self.n() * self.k()
            }
        }
    };
}

/// Implements conversions in both directions between a size struct `$name` (whose `m`, `n`
/// and `k` fields are `$ty_struct`) and a `($ty_tuple, $ty_tuple, $ty_tuple)` tuple ordered
/// as `(m, n, k)`.
///
/// Both directions use checked integer conversion. A component that is negative or too
/// large for the target type panics instead of being silently truncated or sign-wrapped.
/// This matches the size types' `new` constructors, which also panic on out-of-range input.
macro_rules! impl_from_tuple {
    ($name:ident, $ty_struct:ty, $ty_tuple:ty) => {
        impl From<($ty_tuple, $ty_tuple, $ty_tuple)> for $name {
            // Panics if a component does not fit in the struct's field type.
            fn from(value: ($ty_tuple, $ty_tuple, $ty_tuple)) -> Self {
                let (m, n, k) = value;
                Self {
                    m: <$ty_struct>::try_from(m).expect(concat!(
                        "`m` out of range when converting a tuple into ",
                        stringify!($name)
                    )),
                    n: <$ty_struct>::try_from(n).expect(concat!(
                        "`n` out of range when converting a tuple into ",
                        stringify!($name)
                    )),
                    k: <$ty_struct>::try_from(k).expect(concat!(
                        "`k` out of range when converting a tuple into ",
                        stringify!($name)
                    )),
                }
            }
        }

        impl From<$name> for ($ty_tuple, $ty_tuple, $ty_tuple) {
            // Panics if a field does not fit in the tuple's component type.
            fn from(value: $name) -> Self {
                (
                    <$ty_tuple>::try_from(value.m).expect(concat!(
                        stringify!($name),
                        "::m out of range for ",
                        stringify!($ty_tuple)
                    )),
                    <$ty_tuple>::try_from(value.n).expect(concat!(
                        stringify!($name),
                        "::n out of range for ",
                        stringify!($ty_tuple)
                    )),
                    <$ty_tuple>::try_from(value.k).expect(concat!(
                        stringify!($name),
                        "::k out of range for ",
                        stringify!($ty_tuple)
                    )),
                )
            }
        }
    };
}

96define_3d_size_base!(TileSize, u8);
98impl_from_tuple!(TileSize, u8, u8);
99impl_from_tuple!(TileSize, u8, u32);
100impl_from_tuple!(TileSize, u8, i32);
101impl_from_tuple!(TileSize, u8, u16);
102impl_from_tuple!(TileSize, u8, usize);
103
104define_3d_size_base!(PartitionSize, u8);
106impl_from_tuple!(PartitionSize, u8, u8);
107impl_from_tuple!(PartitionSize, u8, u32);
108impl_from_tuple!(PartitionSize, u8, i32);
109impl_from_tuple!(PartitionSize, u8, u16);
110impl_from_tuple!(PartitionSize, u8, usize);
111
112define_3d_size_base!(StageSize, u8);
114impl_from_tuple!(StageSize, u8, u8);
115impl_from_tuple!(StageSize, u8, u32);
116impl_from_tuple!(StageSize, u8, i32);
117impl_from_tuple!(StageSize, u8, u16);
118impl_from_tuple!(StageSize, u8, usize);
119
120define_3d_size_base!(MatmulProblemSize, u32);
122impl_from_tuple!(MatmulProblemSize, u32, u8);
123impl_from_tuple!(MatmulProblemSize, u32, u32);
124impl_from_tuple!(MatmulProblemSize, u32, i32);
125impl_from_tuple!(MatmulProblemSize, u32, u16);
126impl_from_tuple!(MatmulProblemSize, u32, usize);
127
/// A size made of `m`, `n` and `batches` extents, each stored as `u32`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct GlobalPartitionSize {
    pub m: u32,
    pub n: u32,
    pub batches: u32,
}

impl GlobalPartitionSize {
    /// Creates a `GlobalPartitionSize` from its three extents, stored unchanged.
    pub fn new(m: u32, n: u32, batches: u32) -> Self {
        GlobalPartitionSize { batches, n, m }
    }
}