Skip to main content

cubek_std/
size.rs

1use cubecl::prelude::*;
2
/// Names one of the three dimensions of a matmul.
///
/// The 3D size types in this file take it as the argument of their `get`
/// method to select which extent to read.
#[derive(Clone, Copy, Debug)]
pub enum MatmulDim {
    /// The output matrix's row dimension.
    M,
    /// The output matrix's column dimension.
    N,
    /// The dimension that gets reduced.
    K,
}
13
// Generates a public struct `$name` holding one `$ty` extent per matmul
// dimension (`m`, `n`, `k`), together with `u32`-returning accessors and
// product helpers.
macro_rules! define_3d_size_base {
    ($name:ident, $ty:ty) => {
        /// Three extents, one for each matmul dimension (`m`, `n`, `k`).
        #[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
        pub struct $name {
            pub m: $ty,
            pub n: $ty,
            pub k: $ty,
        }

        impl $name {
            /// Builds the size from `u32` extents.
            ///
            /// # Panics
            ///
            /// Panics if any extent cannot be represented by the field type.
            pub fn new(m: u32, n: u32, k: u32) -> Self {
                // Converted in m, n, k order so the first offending value is
                // the one that triggers the panic.
                let m = <$ty>::try_from(m).unwrap();
                let n = <$ty>::try_from(n).unwrap();
                let k = <$ty>::try_from(k).unwrap();
                Self { m, n, k }
            }

            /// Returns the extent selected by `dim`, cast to `u32` with `as`.
            pub fn get(&self, dim: MatmulDim) -> u32 {
                let extent = match dim {
                    MatmulDim::M => self.m,
                    MatmulDim::N => self.n,
                    MatmulDim::K => self.k,
                };
                extent as u32
            }

            /// Extent along `m`, as `u32`.
            pub fn m(&self) -> u32 {
                self.m as u32
            }

            /// Extent along `n`, as `u32`.
            pub fn n(&self) -> u32 {
                self.n as u32
            }

            /// Extent along `k`, as `u32`.
            pub fn k(&self) -> u32 {
                self.k as u32
            }

            /// Product `m * n`, computed in `u32`.
            pub fn mn(&self) -> u32 {
                self.m() * self.n()
            }

            /// Product `m * k`, computed in `u32`.
            pub fn mk(&self) -> u32 {
                self.m() * self.k()
            }

            /// Product `n * k`, computed in `u32`.
            pub fn nk(&self) -> u32 {
                self.n() * self.k()
            }

            /// Product `m * n * k`, computed in `u32` as `(m * n) * k`.
            pub fn mnk(&self) -> u32 {
                self.mn() * self.k()
            }
        }
    };
}
70
// Implements `From` in both directions between `$name` (whose `m`, `n`, `k`
// fields have type `$ty_struct`) and a `($ty_tuple, $ty_tuple, $ty_tuple)`
// tuple ordered as `(m, n, k)`.
//
// The conversions are checked: a value that does not fit in the destination
// integer type panics instead of being silently truncated or wrapped, which
// is the same policy the generated `new` constructor applies to its inputs.
macro_rules! impl_from_tuple {
    ($name:ident, $ty_struct:ty, $ty_tuple:ty) => {
        impl From<($ty_tuple, $ty_tuple, $ty_tuple)> for $name {
            /// Builds the size from an `(m, n, k)` tuple.
            ///
            /// # Panics
            ///
            /// Panics if a component is negative or too large for the
            /// struct's field type.
            fn from(value: ($ty_tuple, $ty_tuple, $ty_tuple)) -> Self {
                let (m, n, k) = value;
                Self {
                    m: <$ty_struct>::try_from(m)
                        .expect("`m` does not fit in the size's field type"),
                    n: <$ty_struct>::try_from(n)
                        .expect("`n` does not fit in the size's field type"),
                    k: <$ty_struct>::try_from(k)
                        .expect("`k` does not fit in the size's field type"),
                }
            }
        }

        impl From<$name> for ($ty_tuple, $ty_tuple, $ty_tuple) {
            /// Unpacks the size into an `(m, n, k)` tuple.
            ///
            /// # Panics
            ///
            /// Panics if a field value cannot be represented by the tuple's
            /// element type.
            fn from(value: $name) -> Self {
                (
                    <$ty_tuple>::try_from(value.m)
                        .expect("`m` does not fit in the tuple element type"),
                    <$ty_tuple>::try_from(value.n)
                        .expect("`n` does not fit in the tuple element type"),
                    <$ty_tuple>::try_from(value.k)
                        .expect("`k` does not fit in the tuple element type"),
                )
            }
        }
    };
}
94
95// Number of elements in a tile
96define_3d_size_base!(TileSize, u32);
97impl_from_tuple!(TileSize, u32, u8);
98impl_from_tuple!(TileSize, u32, u32);
99impl_from_tuple!(TileSize, u32, i32);
100impl_from_tuple!(TileSize, u32, u16);
101impl_from_tuple!(TileSize, u32, usize);
102
103// Number of tiles in a stage partition
104define_3d_size_base!(PartitionSize, u8);
105impl_from_tuple!(PartitionSize, u8, u8);
106impl_from_tuple!(PartitionSize, u8, u32);
107impl_from_tuple!(PartitionSize, u8, i32);
108impl_from_tuple!(PartitionSize, u8, u16);
109impl_from_tuple!(PartitionSize, u8, usize);
110
111// Number of partitions in a stage
112define_3d_size_base!(StageSize, u8);
113impl_from_tuple!(StageSize, u8, u8);
114impl_from_tuple!(StageSize, u8, u32);
115impl_from_tuple!(StageSize, u8, i32);
116impl_from_tuple!(StageSize, u8, u16);
117impl_from_tuple!(StageSize, u8, usize);
118
119// Shapes m,n,k of the problem
120define_3d_size_base!(MatmulProblemSize, u32);
121impl_from_tuple!(MatmulProblemSize, u32, u8);
122impl_from_tuple!(MatmulProblemSize, u32, u32);
123impl_from_tuple!(MatmulProblemSize, u32, i32);
124impl_from_tuple!(MatmulProblemSize, u32, u16);
125impl_from_tuple!(MatmulProblemSize, u32, usize);
126
/// How many global matmul blocks a single cube computes.
///
/// The count is split into three components: `m`, `n` and `batches`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GlobalPartitionSize {
    /// Component for the `m` dimension.
    pub m: u32,
    /// Component for the `n` dimension.
    pub n: u32,
    /// Component for the batches.
    pub batches: u32,
}

impl GlobalPartitionSize {
    /// Creates a partition size from its three components, stored as given.
    pub fn new(m: u32, n: u32, batches: u32) -> Self {
        GlobalPartitionSize {
            m: m,
            n: n,
            batches: batches,
        }
    }
}