cubecl_matmul/components/
size.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
/// Matrix dimension specifier for matmul operations.
///
/// Besides being `Copy` and `Debug`, the selector is comparable and hashable so
/// it can be tested with `==` and used as a key in maps or sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatmulDim {
    /// Rows of the output matrix.
    M,
    /// Columns of the output matrix.
    N,
    /// Reduction dimension.
    K,
}
14
/// Generates a public struct `$name` with three public fields `m`, `n` and `k`
/// of type `$ty`, together with accessors that report every value as `u32`.
macro_rules! define_3d_size_base {
    ($name:ident, $ty:ty) => {
        #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
        pub struct $name {
            pub m: $ty,
            pub n: $ty,
            pub k: $ty,
        }

        impl $name {
            /// Creates a size from its `m`, `n` and `k` components.
            ///
            /// # Panics
            /// Panics if any component cannot be represented by the field type.
            pub fn new(m: u32, n: u32, k: u32) -> Self {
                let to_field = |component: u32| <$ty>::try_from(component).unwrap();
                Self {
                    m: to_field(m),
                    n: to_field(n),
                    k: to_field(k),
                }
            }

            /// Returns the component selected by `dim`, widened to `u32`.
            pub fn get(&self, dim: MatmulDim) -> u32 {
                let component = match dim {
                    MatmulDim::M => self.m,
                    MatmulDim::N => self.n,
                    MatmulDim::K => self.k,
                };
                component as u32
            }

            /// Returns the `m` component as `u32`.
            pub fn m(&self) -> u32 {
                self.m as u32
            }

            /// Returns the `n` component as `u32`.
            pub fn n(&self) -> u32 {
                self.n as u32
            }

            /// Returns the `k` component as `u32`.
            pub fn k(&self) -> u32 {
                self.k as u32
            }

            /// Returns the product `m * n`, computed in `u32`.
            pub fn mn(&self) -> u32 {
                self.m() * self.n()
            }

            /// Returns the product `m * k`, computed in `u32`.
            pub fn mk(&self) -> u32 {
                self.m() * self.k()
            }

            /// Returns the product `n * k`, computed in `u32`.
            pub fn nk(&self) -> u32 {
                self.n() * self.k()
            }

            /// Returns the product `m * n * k`, computed in `u32` as `(m * n) * k`.
            pub fn mnk(&self) -> u32 {
                self.mn() * self.k()
            }
        }
    };
}
71
/// Implements `From` in both directions between `$name` and an `(m, n, k)`
/// tuple whose three elements have type `$ty_tuple`.
///
/// `$ty_struct` must be the type of the `m`, `n` and `k` fields of `$name`.
///
/// Every component goes through a checked integer conversion. A value that
/// cannot be represented in the destination type triggers a panic instead of
/// being silently truncated or wrapped.
macro_rules! impl_from_tuple {
    ($name:ident, $ty_struct:ty, $ty_tuple:ty) => {
        impl From<($ty_tuple, $ty_tuple, $ty_tuple)> for $name {
            /// Builds the size from an `(m, n, k)` tuple.
            ///
            /// # Panics
            /// Panics if a component does not fit in the field type.
            fn from(value: ($ty_tuple, $ty_tuple, $ty_tuple)) -> Self {
                let (m, n, k) = value;
                Self {
                    m: <$ty_struct>::try_from(m).expect("`m` is out of range for the size type"),
                    n: <$ty_struct>::try_from(n).expect("`n` is out of range for the size type"),
                    k: <$ty_struct>::try_from(k).expect("`k` is out of range for the size type"),
                }
            }
        }

        impl From<$name> for ($ty_tuple, $ty_tuple, $ty_tuple) {
            /// Unpacks the size into an `(m, n, k)` tuple.
            ///
            /// # Panics
            /// Panics if a component does not fit in the tuple element type.
            fn from(value: $name) -> Self {
                (
                    <$ty_tuple>::try_from(value.m)
                        .expect("`m` is out of range for the tuple element type"),
                    <$ty_tuple>::try_from(value.n)
                        .expect("`n` is out of range for the tuple element type"),
                    <$ty_tuple>::try_from(value.k)
                        .expect("`k` is out of range for the tuple element type"),
                )
            }
        }
    };
}
95
96// Number of elements in a tile
97define_3d_size_base!(TileSize, u8);
98impl_from_tuple!(TileSize, u8, u8);
99impl_from_tuple!(TileSize, u8, u32);
100impl_from_tuple!(TileSize, u8, i32);
101impl_from_tuple!(TileSize, u8, u16);
102impl_from_tuple!(TileSize, u8, usize);
103
104// Number of tiles in a stage partition
105define_3d_size_base!(PartitionSize, u8);
106impl_from_tuple!(PartitionSize, u8, u8);
107impl_from_tuple!(PartitionSize, u8, u32);
108impl_from_tuple!(PartitionSize, u8, i32);
109impl_from_tuple!(PartitionSize, u8, u16);
110impl_from_tuple!(PartitionSize, u8, usize);
111
112// Number of partitions in a stage
113define_3d_size_base!(StageSize, u8);
114impl_from_tuple!(StageSize, u8, u8);
115impl_from_tuple!(StageSize, u8, u32);
116impl_from_tuple!(StageSize, u8, i32);
117impl_from_tuple!(StageSize, u8, u16);
118impl_from_tuple!(StageSize, u8, usize);
119
120// Shapes m,n,k of the problem
121define_3d_size_base!(MatmulProblemSize, u32);
122impl_from_tuple!(MatmulProblemSize, u32, u8);
123impl_from_tuple!(MatmulProblemSize, u32, u32);
124impl_from_tuple!(MatmulProblemSize, u32, i32);
125impl_from_tuple!(MatmulProblemSize, u32, u16);
126impl_from_tuple!(MatmulProblemSize, u32, usize);
127
/// Number of global matmul blocks computed by a single cube.
///
/// The count is split into three public `u32` components: `m`, `n` and
/// `batches`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct GlobalPartitionSize {
    pub m: u32,
    pub n: u32,
    pub batches: u32,
}
135
136impl GlobalPartitionSize {
137    pub fn new(m: u32, n: u32, batches: u32) -> Self {
138        Self { m, n, batches }
139    }
140}