1use super::error::TilingError;
4use super::geometry::TcbGeometry;
5use serde::{Deserialize, Serialize};
6
/// Named three-level tiling configuration (macro, midi and micro tiles) tied
/// to one execution backend.
///
/// [`TilingConfig::validate`] checks how the three tiles relate to each other;
/// constructing a value does not run that check.
///
/// NOTE(review): which hardware or cache level each tile is meant to map to is
/// not visible in this file — confirm against the geometry module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TilingConfig {
    /// Identifier of the configuration; the presets use strings such as
    /// `"Q4K_MatVec_GPU"`.
    pub name: String,
    /// Outermost tile; `validate` requires it to be at least as large as the
    /// midi tile in `m`, `n` and `k`.
    pub macro_tile: TcbGeometry,
    /// Middle tile; `validate` requires it to be at least as large as the
    /// micro tile in `m`, `n` and `k`.
    pub midi_tile: TcbGeometry,
    /// Innermost (smallest) tile.
    pub micro_tile: TcbGeometry,
    /// Backend this configuration is intended for.
    pub backend: TilingBackend,
}
24
/// Execution backend a [`TilingConfig`] is tagged with.
///
/// The descriptions below follow the variant names only; how each backend is
/// dispatched is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TilingBackend {
    /// CPU path named for the AVX2 instruction set.
    CpuAvx2,
    /// CPU path named for the AVX-512 instruction set.
    CpuAvx512,
    /// CPU path named for the NEON instruction set.
    CpuNeon,
    /// GPU path.
    Gpu,
    /// Scalar path.
    Scalar,
}
39
40impl TilingConfig {
41 #[must_use]
45 pub fn gpu_q4k_matvec() -> Self {
46 Self {
47 name: "Q4K_MatVec_GPU".into(),
48 macro_tile: TcbGeometry::with_alignment(1, 4096, 256, 64),
49 midi_tile: TcbGeometry::with_alignment(1, 256, 256, 64),
50 micro_tile: TcbGeometry::with_alignment(1, 32, 256, 64),
51 backend: TilingBackend::Gpu,
52 }
53 }
54
55 #[must_use]
59 pub fn gpu_q4k_matmul() -> Self {
60 Self {
61 name: "Q4K_MatMul_GPU".into(),
62 macro_tile: TcbGeometry::with_alignment(128, 128, 256, 64),
63 midi_tile: TcbGeometry::with_alignment(32, 32, 256, 64),
64 micro_tile: TcbGeometry::with_alignment(8, 8, 256, 64),
65 backend: TilingBackend::Gpu,
66 }
67 }
68
69 #[must_use]
71 pub fn gpu_softmax() -> Self {
72 Self {
73 name: "Softmax_GPU".into(),
74 macro_tile: TcbGeometry::with_alignment(1, 32000, 1, 64),
75 midi_tile: TcbGeometry::with_alignment(1, 1024, 1, 64),
76 micro_tile: TcbGeometry::with_alignment(1, 32, 1, 64),
77 backend: TilingBackend::Gpu,
78 }
79 }
80
81 #[must_use]
88 pub fn cpu_avx512_matmul() -> Self {
89 Self {
90 name: "MatMul_AVX512".into(),
91 macro_tile: TcbGeometry::with_alignment(512, 512, 512, 64),
92 midi_tile: TcbGeometry::with_alignment(128, 128, 128, 64),
93 micro_tile: TcbGeometry::with_alignment(4, 16, 128, 64),
95 backend: TilingBackend::CpuAvx512,
96 }
97 }
98
99 #[must_use]
107 pub fn cpu_avx512_q4k_matvec() -> Self {
108 Self {
109 name: "Q4K_MatVec_AVX512".into(),
110 macro_tile: TcbGeometry::with_alignment(4096, 1, 4096, 64),
112 midi_tile: TcbGeometry::with_alignment(64, 1, 256, 64),
115 micro_tile: TcbGeometry::with_alignment(4, 1, 256, 64),
117 backend: TilingBackend::CpuAvx512,
118 }
119 }
120
121 #[must_use]
129 pub fn cpu_avx512_vnni_q4k_q8k() -> Self {
130 Self {
131 name: "Q4K_Q8K_VNNI".into(),
132 macro_tile: TcbGeometry::with_alignment(4096, 1, 4096, 64),
133 midi_tile: TcbGeometry::with_alignment(64, 1, 256, 64),
134 micro_tile: TcbGeometry::with_alignment(4, 1, 256, 64),
136 backend: TilingBackend::CpuAvx512,
137 }
138 }
139
140 #[must_use]
142 pub fn cpu_avx2_matmul() -> Self {
143 Self {
144 name: "MatMul_AVX2".into(),
145 macro_tile: TcbGeometry::with_alignment(256, 256, 256, 32),
146 midi_tile: TcbGeometry::with_alignment(64, 64, 64, 32),
147 micro_tile: TcbGeometry::with_alignment(4, 8, 64, 32),
149 backend: TilingBackend::CpuAvx2,
150 }
151 }
152
153 #[must_use]
155 pub fn cpu_avx2_q4k_matvec() -> Self {
156 Self {
157 name: "Q4K_MatVec_AVX2".into(),
158 macro_tile: TcbGeometry::with_alignment(4096, 1, 4096, 32),
160 midi_tile: TcbGeometry::with_alignment(64, 1, 256, 32),
161 micro_tile: TcbGeometry::with_alignment(4, 1, 256, 32),
163 backend: TilingBackend::CpuAvx2,
164 }
165 }
166
167 #[must_use]
169 pub fn cpu_rmsnorm() -> Self {
170 Self {
171 name: "RMSNorm_CPU".into(),
172 macro_tile: TcbGeometry::with_alignment(1, 4096, 1, 32),
173 midi_tile: TcbGeometry::with_alignment(1, 256, 1, 32),
174 micro_tile: TcbGeometry::with_alignment(1, 16, 1, 32),
175 backend: TilingBackend::CpuAvx512,
176 }
177 }
178
179 pub fn validate(&self) -> Result<(), TilingError> {
181 if self.midi_tile.m > self.macro_tile.m
183 || self.midi_tile.n > self.macro_tile.n
184 || self.midi_tile.k > self.macro_tile.k
185 {
186 return Err(TilingError::InvalidHierarchy {
187 reason: "Midi-tile larger than macro-tile".into(),
188 });
189 }
190
191 if self.micro_tile.m > self.midi_tile.m
192 || self.micro_tile.n > self.midi_tile.n
193 || self.micro_tile.k > self.midi_tile.k
194 {
195 return Err(TilingError::InvalidHierarchy {
196 reason: "Micro-tile larger than midi-tile".into(),
197 });
198 }
199
200 if self.macro_tile.m % self.midi_tile.m != 0 {
202 return Err(TilingError::DivisibilityError {
203 level: "macro/midi",
204 dimension: "M",
205 larger: self.macro_tile.m,
206 smaller: self.midi_tile.m,
207 });
208 }
209
210 if self.midi_tile.m % self.micro_tile.m != 0 {
211 return Err(TilingError::DivisibilityError {
212 level: "midi/micro",
213 dimension: "M",
214 larger: self.midi_tile.m,
215 smaller: self.micro_tile.m,
216 });
217 }
218
219 Ok(())
220 }
221
222 #[must_use]
224 pub fn num_macro_tiles(&self, m: u32, n: u32) -> u32 {
225 let m_tiles = (m + self.macro_tile.m - 1) / self.macro_tile.m;
226 let n_tiles = (n + self.macro_tile.n - 1) / self.macro_tile.n;
227 m_tiles * n_tiles
228 }
229
230 #[must_use]
232 pub fn midi_tiles_per_macro(&self) -> u32 {
233 let m_tiles = self.macro_tile.m / self.midi_tile.m;
234 let n_tiles = self.macro_tile.n / self.midi_tile.n;
235 m_tiles * n_tiles
236 }
237
238 #[must_use]
240 pub fn micro_tiles_per_midi(&self) -> u32 {
241 let m_tiles = self.midi_tile.m / self.micro_tile.m;
242 let n_tiles = self.midi_tile.n / self.micro_tile.n;
243 m_tiles * n_tiles
244 }
245}