1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::MatmulProblem;
5use crate::components::batch::partitioned_matmul::hypercube::global_order::{GlobalOrder, swizzle};
6use crate::components::batch::partitioned_matmul::hypercube::sm_allocation::SmAllocation;
7use crate::components::batch::{HypercubeConfig, HypercubeSelection};
8
/// Requested strategy for laying out the cubes of a matmul launch.
///
/// A selection is turned into a concrete [`CubeCountPlan`] by
/// [`CubeCountPlan::from_selection`]. If the requested layout does not fit in the
/// maximum cube count, that function falls back to a spread plan.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum CubeCountPlanSelection {
    /// Use the problem's cube counts directly as the launch dimensions:
    /// x = cubes along m, y = cubes along n, z = cubes along batch.
    ///
    /// This is the default selection.
    #[default]
    FromProblem,

    /// Launch a two-dimensional grid of `cubes_per_sm` by `num_sms_used`, where both
    /// values come from `sm_usage.allocate(num_sms, total_cubes)`.
    Sm {
        /// When `true`, x holds `cubes_per_sm` and y holds `num_sms_used`;
        /// when `false`, the two axes are swapped.
        cubes_first: bool,
        /// SM count forwarded as the first argument of `SmAllocation::allocate`.
        num_sms: u32,
        /// Allocation policy whose `allocate` yields `(num_sms_used, cubes_per_sm)`.
        sm_usage: SmAllocation,
    },

    /// Put every cube on the x axis: `(m_cubes * n_cubes * batch_cubes, 1, 1)`.
    Flattened,

    /// Always distribute the total number of cubes over x, y and z with
    /// `spread_cube_count_plan`.
    Spread,
}
31
/// Concrete cube layout for one problem.
///
/// Every variant stores the number of cubes needed along m, n and batch
/// (`m_cubes`, `n_cubes`, `batch_cubes`). [`CubeCountPlan::resolve`] converts the plan
/// into the `CubeCount` to launch and [`CubeCountPlan::as_args`] into the kernel-side
/// [`CubeCountInput`] arguments.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum CubeCountPlan {
    /// Launch exactly `(m_cubes, n_cubes, batch_cubes)`.
    FromProblem {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
    /// Launch a `cubes_per_sm` by `num_sms_used` grid (z = 1).
    Sm {
        /// When `true`, x = `cubes_per_sm` and y = `num_sms_used`; otherwise swapped.
        cubes_first: bool,
        /// First value returned by `sm_usage.allocate(num_sms, total_cubes)`.
        num_sms_used: u32,
        /// Second value returned by `sm_usage.allocate(num_sms, total_cubes)`.
        cubes_per_sm: u32,
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
        /// SM count the allocation was computed from; kept so that a
        /// [`CubeCountPlanConfig`] can be rebuilt from the plan.
        num_sms: u32,
        /// Allocation policy the plan was computed with; kept for the same reason.
        sm_usage: SmAllocation,
    },
    /// Launch `(m_cubes * n_cubes * batch_cubes, 1, 1)`.
    Flattened {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
    /// Launch `(x, y, z)` as chosen by `spread_cube_count_plan`; `x * y * z` may be
    /// larger than `m_cubes * n_cubes * batch_cubes`.
    Spread {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
        x: u32,
        y: u32,
        z: u32,
    },
}
68
69impl CubeCountPlan {
70 pub fn can_yield_extra_cubes(&self) -> bool {
72 match self {
73 CubeCountPlan::FromProblem { .. } | CubeCountPlan::Flattened { .. } => false,
74 CubeCountPlan::Sm {
75 num_sms_used,
76 cubes_per_sm,
77 m_cubes,
78 n_cubes,
79 batch_cubes,
80 ..
81 } => num_sms_used * cubes_per_sm != m_cubes * n_cubes * batch_cubes,
82 CubeCountPlan::Spread {
83 m_cubes,
84 n_cubes,
85 batch_cubes,
86 x,
87 y,
88 z,
89 } => m_cubes * n_cubes * batch_cubes != x * y * z,
90 }
91 }
92}
93
/// Problem-independent description of a cube count plan.
///
/// Built from a [`CubeCountPlan`] with `CubeCountPlanConfig::from_cube_count_plan`, and
/// turned back into a plan for a given problem with [`CubeCountPlan::from_config`].
/// It keeps the strategy parameters but not the per-problem cube counts.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum CubeCountPlanConfig {
    /// Counterpart of [`CubeCountPlan::FromProblem`].
    FromProblem,

    /// Counterpart of [`CubeCountPlan::Sm`].
    Sm {
        /// Axis order flag copied from the plan.
        cubes_first: bool,
        /// SM count copied from the plan.
        num_sms: u32,
        /// Allocation policy copied from the plan.
        sm_usage: SmAllocation,
        /// Value of [`CubeCountPlan::can_yield_extra_cubes`] for the plan this config
        /// was built from.
        can_yield_extra_cubes: bool,
    },

    /// Counterpart of [`CubeCountPlan::Flattened`].
    Flattened,

    /// Counterpart of [`CubeCountPlan::Spread`].
    Spread {
        /// Value of [`CubeCountPlan::can_yield_extra_cubes`] for the plan this config
        /// was built from.
        can_yield_extra_cubes: bool,
    },
}
114
/// Kernel-side form of a [`CubeCountPlan`], produced by [`CubeCountPlan::as_args`].
///
/// It carries only the cube counts needed to map a cube position back to an
/// (m, n, batch) position; see `cube_pos_to_tensor_pos`.
#[derive(CubeType, CubeLaunch)]
pub enum CubeCountInput {
    /// The cube position is used directly as (m, n, batch); no counts are needed.
    FromProblem,
    /// Built from [`CubeCountPlan::Sm`] when `cubes_first` is `false`.
    SmFirst {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
    /// Built from [`CubeCountPlan::Sm`] when `cubes_first` is `true`.
    CubeFirst {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
    /// Built from [`CubeCountPlan::Flattened`]; the batch count is not carried because
    /// decoding the index only needs `m_cubes` and `n_cubes`.
    Flattened {
        m_cubes: u32,
        n_cubes: u32,
    },
    /// Built from [`CubeCountPlan::Spread`].
    Spread {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
}
141
142impl CubeCountPlan {
143 pub fn from_selection(
145 selection: &HypercubeSelection,
146 problem: &MatmulProblem,
147 max_cube_count: CubeCount,
148 ) -> CubeCountPlan {
149 let (max_x, max_y, max_z) = match max_cube_count {
150 CubeCount::Static(x, y, z) => (x, y, z),
151 CubeCount::Dynamic(_) => panic!("Dynamic cube count not supported for cube count plan"),
152 };
153
154 let m_cubes = (problem.m as u32).div_ceil(selection.cube_span.m);
155 let n_cubes = (problem.n as u32).div_ceil(selection.cube_span.n);
156 let batch_cubes = (problem.num_batches() as u32).div_ceil(selection.cube_span.batch);
157
158 let plan = match selection.cube_count_plan_selection {
159 CubeCountPlanSelection::FromProblem => {
160 if m_cubes > max_x || n_cubes > max_y || batch_cubes > max_z {
161 None
162 } else {
163 Some(CubeCountPlan::FromProblem {
164 m_cubes,
165 n_cubes,
166 batch_cubes,
167 })
168 }
169 }
170 CubeCountPlanSelection::Sm {
171 cubes_first,
172 num_sms,
173 sm_usage,
174 } => {
175 let (num_sms_used, cubes_per_sm) =
176 sm_usage.allocate(num_sms, m_cubes * n_cubes * batch_cubes);
177
178 if (cubes_per_sm >= if cubes_first { max_x } else { max_y })
179 || (num_sms_used >= if cubes_first { max_y } else { max_x })
180 {
181 None
182 } else {
183 Some(CubeCountPlan::Sm {
184 cubes_first,
185 num_sms_used,
186 cubes_per_sm,
187 m_cubes,
188 n_cubes,
189 batch_cubes,
190 num_sms,
191 sm_usage,
192 })
193 }
194 }
195 CubeCountPlanSelection::Flattened => {
196 if m_cubes * n_cubes * batch_cubes >= max_x {
197 None
198 } else {
199 Some(CubeCountPlan::Flattened {
200 m_cubes,
201 n_cubes,
202 batch_cubes,
203 })
204 }
205 }
206 CubeCountPlanSelection::Spread => None,
207 };
208
209 plan.unwrap_or_else(|| {
210 spread_cube_count_plan(m_cubes, n_cubes, batch_cubes, max_x, max_y, max_z)
211 })
212 }
213
214 pub fn from_config(
218 config: &HypercubeConfig,
219 problem: &MatmulProblem,
220 max_cube_count: CubeCount,
221 ) -> CubeCountPlan {
222 let (max_x, max_y, max_z) = match max_cube_count {
223 CubeCount::Static(x, y, z) => (x, y, z),
224 CubeCount::Dynamic(_) => panic!("Dynamic cube count not supported for cube count plan"),
225 };
226
227 let m_cubes = (problem.m as u32).div_ceil(config.cube_span.m);
228 let n_cubes = (problem.n as u32).div_ceil(config.cube_span.n);
229 let batch_cubes = (problem.num_batches() as u32).div_ceil(config.cube_span.batch);
230
231 match config.cube_count_plan_config {
232 CubeCountPlanConfig::FromProblem => CubeCountPlan::FromProblem {
233 m_cubes,
234 n_cubes,
235 batch_cubes,
236 },
237 CubeCountPlanConfig::Sm {
238 cubes_first,
239 num_sms,
240 sm_usage,
241 ..
242 } => {
243 let (num_sms_used, cubes_per_sm) =
244 sm_usage.allocate(num_sms, m_cubes * n_cubes * batch_cubes);
245 CubeCountPlan::Sm {
246 cubes_first,
247 num_sms_used,
248 cubes_per_sm,
249 m_cubes,
250 n_cubes,
251 batch_cubes,
252 num_sms,
253 sm_usage,
254 }
255 }
256 CubeCountPlanConfig::Flattened => CubeCountPlan::Flattened {
257 m_cubes,
258 n_cubes,
259 batch_cubes,
260 },
261 CubeCountPlanConfig::Spread { .. } => {
262 spread_cube_count_plan(m_cubes, n_cubes, batch_cubes, max_x, max_y, max_z)
263 }
264 }
265 }
266}
267
268impl CubeCountPlanConfig {
269 pub fn can_yield_extra_cubes(&self) -> bool {
271 match self {
272 CubeCountPlanConfig::FromProblem | CubeCountPlanConfig::Flattened => false,
273 CubeCountPlanConfig::Sm {
274 can_yield_extra_cubes,
275 ..
276 } => *can_yield_extra_cubes,
277 CubeCountPlanConfig::Spread {
278 can_yield_extra_cubes,
279 } => *can_yield_extra_cubes,
280 }
281 }
282
283 pub(crate) fn from_cube_count_plan(cube_count_plan: CubeCountPlan) -> CubeCountPlanConfig {
284 match cube_count_plan {
285 CubeCountPlan::FromProblem { .. } => CubeCountPlanConfig::FromProblem,
286 CubeCountPlan::Sm {
287 cubes_first,
288 num_sms,
289 sm_usage,
290 ..
291 } => CubeCountPlanConfig::Sm {
292 cubes_first,
293 num_sms,
294 sm_usage,
295 can_yield_extra_cubes: cube_count_plan.can_yield_extra_cubes(),
296 },
297 CubeCountPlan::Flattened { .. } => CubeCountPlanConfig::Flattened,
298 CubeCountPlan::Spread { .. } => CubeCountPlanConfig::Spread {
299 can_yield_extra_cubes: cube_count_plan.can_yield_extra_cubes(),
300 },
301 }
302 }
303}
304
305pub(crate) fn spread_cube_count_plan(
308 m_cubes: u32,
309 n_cubes: u32,
310 batch_cubes: u32,
311 max_x: u32,
312 max_y: u32,
313 max_z: u32,
314) -> CubeCountPlan {
315 let total_cubes = m_cubes * n_cubes * batch_cubes;
316
317 let mut best = None;
318
319 let mut z = max_z;
320 while z >= 1 {
321 let xy_cubes = total_cubes.div_ceil(z);
322
323 let mut y = max_y;
324 while y >= 1 {
325 let x = xy_cubes.div_ceil(y);
326 if x <= max_x {
327 let volume = x * y * z;
328 let score = (volume, std::cmp::Reverse(z), std::cmp::Reverse(y));
329
330 if best.is_none_or(|(_, _, _, _, best_score)| score < best_score) {
331 best = Some((x, y, z, volume, score));
332 }
333 }
334
335 if y == 1 {
336 break;
337 }
338 y /= 2;
339 }
340
341 if z == 1 {
342 break;
343 }
344 z /= 2;
345 }
346
347 if let Some((x, y, z, _, _)) = best {
348 CubeCountPlan::Spread {
349 m_cubes,
350 n_cubes,
351 batch_cubes,
352 x,
353 y,
354 z,
355 }
356 } else {
357 panic!("No valid cube spread plan")
358 }
359}
360
361impl CubeCountPlan {
362 pub fn resolve(&self) -> CubeCount {
364 match self {
365 CubeCountPlan::FromProblem {
366 m_cubes,
367 n_cubes,
368 batch_cubes,
369 } => CubeCount::Static(*m_cubes, *n_cubes, *batch_cubes),
370 CubeCountPlan::Sm {
371 cubes_first,
372 num_sms_used,
373 cubes_per_sm,
374 ..
375 } => match cubes_first {
376 true => CubeCount::Static(*cubes_per_sm, *num_sms_used, 1),
377 false => CubeCount::Static(*num_sms_used, *cubes_per_sm, 1),
378 },
379 CubeCountPlan::Flattened {
380 m_cubes,
381 n_cubes,
382 batch_cubes,
383 } => CubeCount::Static(*m_cubes * *n_cubes * *batch_cubes, 1, 1),
384 CubeCountPlan::Spread { x, y, z, .. } => CubeCount::Static(*x, *y, *z),
385 }
386 }
387
388 pub fn as_args<'a, R: Runtime>(&self) -> CubeCountInputArgs<'a, R> {
390 match self {
391 CubeCountPlan::FromProblem { .. } => CubeCountInputArgs::FromProblem,
392 CubeCountPlan::Sm {
393 cubes_first,
394 m_cubes,
395 n_cubes,
396 batch_cubes,
397 ..
398 } => match cubes_first {
399 true => CubeCountInputArgs::CubeFirst {
400 m_cubes: ScalarArg::new(*m_cubes),
401 n_cubes: ScalarArg::new(*n_cubes),
402 batch_cubes: ScalarArg::new(*batch_cubes),
403 },
404 false => CubeCountInputArgs::SmFirst {
405 m_cubes: ScalarArg::new(*m_cubes),
406 n_cubes: ScalarArg::new(*n_cubes),
407 batch_cubes: ScalarArg::new(*batch_cubes),
408 },
409 },
410 CubeCountPlan::Flattened {
411 m_cubes, n_cubes, ..
412 } => CubeCountInputArgs::Flattened {
413 m_cubes: ScalarArg::new(*m_cubes),
414 n_cubes: ScalarArg::new(*n_cubes),
415 },
416 CubeCountPlan::Spread {
417 m_cubes,
418 n_cubes,
419 batch_cubes,
420 ..
421 } => CubeCountInputArgs::Spread {
422 m_cubes: ScalarArg::new(*m_cubes),
423 n_cubes: ScalarArg::new(*n_cubes),
424 batch_cubes: ScalarArg::new(*batch_cubes),
425 },
426 }
427 }
428}
429
#[cube]
impl CubeCountInput {
    /// Returns the number of cubes that map to a real (m, n, batch) position, i.e.
    /// `m_cubes * n_cubes * batch_cubes`.
    ///
    /// Only meaningful for `SmFirst`, `CubeFirst` and `Spread`. For `FromProblem` and
    /// `Flattened` the launched cube count is exact, so this panics instead.
    pub fn num_valid_cubes(&self) -> u32 {
        match self {
            CubeCountInput::FromProblem | CubeCountInput::Flattened { .. } => {
                panic!("Shouldn't need to be called because the cube count should always be exact")
            }
            CubeCountInput::SmFirst {
                m_cubes,
                n_cubes,
                batch_cubes,
            } => *m_cubes * *n_cubes * *batch_cubes,
            CubeCountInput::CubeFirst {
                m_cubes,
                n_cubes,
                batch_cubes,
            } => *m_cubes * *n_cubes * *batch_cubes,
            CubeCountInput::Spread {
                m_cubes,
                n_cubes,
                batch_cubes,
            } => *m_cubes * *n_cubes * *batch_cubes,
        }
    }

    /// Maps the current cube's position to an `(m, n, batch)` position.
    ///
    /// `FromProblem` returns `(CUBE_POS_X, CUBE_POS_Y, CUBE_POS_Z)` unchanged and does not
    /// use `global_order`. Every other variant first forms a single linear index and
    /// then decodes it with `absolute_index_to_m_n_batch`.
    pub fn cube_pos_to_tensor_pos(&self, #[comptime] global_order: GlobalOrder) -> (u32, u32, u32) {
        match self {
            CubeCountInput::FromProblem => (CUBE_POS_X, CUBE_POS_Y, CUBE_POS_Z),
            // Uses the builtin `CUBE_POS` as the linear index (presumably the cube index
            // linearized over all three axes; confirm against the cubecl docs).
            CubeCountInput::SmFirst {
                m_cubes, n_cubes, ..
            } => self.absolute_index_to_m_n_batch(CUBE_POS, *m_cubes, *n_cubes, global_order),
            // Linear index built explicitly from the x and y positions only.
            CubeCountInput::CubeFirst {
                m_cubes, n_cubes, ..
            } => self.absolute_index_to_m_n_batch(
                CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X,
                *m_cubes,
                *n_cubes,
                global_order,
            ),
            // All cubes live on the x axis, so the x position is already the linear index.
            CubeCountInput::Flattened { m_cubes, n_cubes } => {
                self.absolute_index_to_m_n_batch(CUBE_POS_X, *m_cubes, *n_cubes, global_order)
            }
            CubeCountInput::Spread {
                m_cubes, n_cubes, ..
            } => self.absolute_index_to_m_n_batch(CUBE_POS, *m_cubes, *n_cubes, global_order),
        }
    }

    /// Decodes a linear cube index into `(m_pos, n_pos, batch_pos)`.
    ///
    /// The batch position is `absolute_index / (m_cubes * n_cubes)`; the remainder is the
    /// position inside one matrix, which is split into m and n according to
    /// `global_order`.
    fn absolute_index_to_m_n_batch(
        &self,
        absolute_index: u32,
        m_cubes: u32,
        n_cubes: u32,
        #[comptime] global_order: GlobalOrder,
    ) -> (u32, u32, u32) {
        let batch_stride = m_cubes * n_cubes;
        let batch_pos = absolute_index / batch_stride;
        let matrix_pos = absolute_index % batch_stride;

        let (m_pos, n_pos) = match comptime!(global_order) {
            // Consecutive indices walk along n first.
            GlobalOrder::RowMajor => (matrix_pos / n_cubes, matrix_pos % n_cubes),
            // Consecutive indices walk along m first.
            GlobalOrder::ColMajor => (matrix_pos % m_cubes, matrix_pos / m_cubes),
            // Delegates to `swizzle` with `n_cubes`; its result is swapped to get (m, n).
            GlobalOrder::SwizzleRowMajor(w) => {
                let (x, y) = swizzle(matrix_pos, n_cubes, w);
                (y, x)
            }
            // Delegates to `swizzle` with `m_cubes`; its result is used as (m, n) directly.
            GlobalOrder::SwizzleColMajor(w) => swizzle(matrix_pos, m_cubes, w),
        };

        (m_pos, n_pos, batch_pos)
    }
}