cubecl_matmul/components/batch/partitioned_matmul/hypercube/
cube_count_plan.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::MatmulProblem;
5use crate::components::batch::partitioned_matmul::hypercube::global_order::{GlobalOrder, swizzle};
6use crate::components::batch::partitioned_matmul::hypercube::sm_allocation::SmAllocation;
7use crate::components::batch::{HypercubeConfig, HypercubeSelection};
8
9#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
10/// Front-facing configuration when crafting a MatmulSelection
11/// Allows choosing a strategy before knowing actual values
12pub enum CubeCountPlanSelection {
13    #[default]
14    /// X: num cubes in m, Y: num cubes in n, Z: num cubes in batch
15    FromProblem,
16
17    /// If not cubes_first: X: num SMs, Y: num cubes per SM
18    /// If cubes_first: X: num cubes per SM, Y: num SMs
19    Sm {
20        cubes_first: bool,
21        num_sms: u32,
22        sm_usage: SmAllocation,
23    },
24
25    /// X: total cubes flattened (num SMs * num cubes per SM)
26    Flattened,
27
28    /// Heuristically find a balance for X, Y, Z that respects hardware limits
29    Spread,
30}
31
32#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
33/// Informations necessary in the computation of the CubeCount.
34/// Because this struct depends on the problem size, it is simplified into
35/// [CubeCountPlanConfig] to be injected as comptime in the kernel.
36///
37/// Refer to [CubeCountPlanSelection] for more details
38pub enum CubeCountPlan {
39    FromProblem {
40        m_cubes: u32,
41        n_cubes: u32,
42        batch_cubes: u32,
43    },
44    Sm {
45        cubes_first: bool,
46        num_sms_used: u32,
47        cubes_per_sm: u32,
48        m_cubes: u32,
49        n_cubes: u32,
50        batch_cubes: u32,
51        num_sms: u32,
52        sm_usage: SmAllocation,
53    },
54    Flattened {
55        m_cubes: u32,
56        n_cubes: u32,
57        batch_cubes: u32,
58    },
59    Spread {
60        m_cubes: u32,
61        n_cubes: u32,
62        batch_cubes: u32,
63        x: u32,
64        y: u32,
65        z: u32,
66    },
67}
68
69impl CubeCountPlan {
70    /// Whether the CubeCount will have more cubes than strictly necessary.
71    pub fn can_yield_extra_cubes(&self) -> bool {
72        match self {
73            CubeCountPlan::FromProblem { .. } | CubeCountPlan::Flattened { .. } => false,
74            CubeCountPlan::Sm {
75                num_sms_used,
76                cubes_per_sm,
77                m_cubes,
78                n_cubes,
79                batch_cubes,
80                ..
81            } => num_sms_used * cubes_per_sm != m_cubes * n_cubes * batch_cubes,
82            CubeCountPlan::Spread {
83                m_cubes,
84                n_cubes,
85                batch_cubes,
86                x,
87                y,
88                z,
89            } => m_cubes * n_cubes * batch_cubes != x * y * z,
90        }
91    }
92}
93
94#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
95/// Config derived from CubeCountPlan to be used comptime in kernels
96///
97/// Refer to [CubeCountPlanSelection] for more details
98pub enum CubeCountPlanConfig {
99    FromProblem,
100
101    Sm {
102        cubes_first: bool,
103        num_sms: u32,
104        sm_usage: SmAllocation,
105        can_yield_extra_cubes: bool,
106    },
107
108    Flattened,
109
110    Spread {
111        can_yield_extra_cubes: bool,
112    },
113}
114
#[derive(CubeType, CubeLaunch)]
/// CubeCountPlan stripped of non-essential runtime information.
///
/// This enum is given as runtime input to the matmul kernel, where
/// `cube_pos_to_tensor_pos` uses it to map each cube to its (m, n, batch) tile.
/// It is built on the host by `CubeCountPlan::as_args`.
pub enum CubeCountInput {
    // Built from `CubeCountPlan::FromProblem`: cube positions map directly to tiles.
    FromProblem,
    // Built from `CubeCountPlan::Sm` with `cubes_first == false`.
    SmFirst {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
    // Built from `CubeCountPlan::Sm` with `cubes_first == true`.
    CubeFirst {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
    // Built from `CubeCountPlan::Flattened`. No batch count is carried: the batch
    // index is derived by dividing the linear index by `m_cubes * n_cubes`.
    Flattened {
        m_cubes: u32,
        n_cubes: u32,
    },
    // Built from `CubeCountPlan::Spread`.
    Spread {
        m_cubes: u32,
        n_cubes: u32,
        batch_cubes: u32,
    },
}
141
142impl CubeCountPlan {
143    // Will check if the wanted cube count plan is possible, otherwise will fallback to spread
144    pub fn from_selection(
145        selection: &HypercubeSelection,
146        problem: &MatmulProblem,
147        max_cube_count: CubeCount,
148    ) -> CubeCountPlan {
149        let (max_x, max_y, max_z) = match max_cube_count {
150            CubeCount::Static(x, y, z) => (x, y, z),
151            CubeCount::Dynamic(_) => panic!("Dynamic cube count not supported for cube count plan"),
152        };
153
154        let m_cubes = (problem.m as u32).div_ceil(selection.cube_span.m);
155        let n_cubes = (problem.n as u32).div_ceil(selection.cube_span.n);
156        let batch_cubes = (problem.num_batches() as u32).div_ceil(selection.cube_span.batch);
157
158        let plan = match selection.cube_count_plan_selection {
159            CubeCountPlanSelection::FromProblem => {
160                if m_cubes > max_x || n_cubes > max_y || batch_cubes > max_z {
161                    None
162                } else {
163                    Some(CubeCountPlan::FromProblem {
164                        m_cubes,
165                        n_cubes,
166                        batch_cubes,
167                    })
168                }
169            }
170            CubeCountPlanSelection::Sm {
171                cubes_first,
172                num_sms,
173                sm_usage,
174            } => {
175                let (num_sms_used, cubes_per_sm) =
176                    sm_usage.allocate(num_sms, m_cubes * n_cubes * batch_cubes);
177
178                if (cubes_per_sm >= if cubes_first { max_x } else { max_y })
179                    || (num_sms_used >= if cubes_first { max_y } else { max_x })
180                {
181                    None
182                } else {
183                    Some(CubeCountPlan::Sm {
184                        cubes_first,
185                        num_sms_used,
186                        cubes_per_sm,
187                        m_cubes,
188                        n_cubes,
189                        batch_cubes,
190                        num_sms,
191                        sm_usage,
192                    })
193                }
194            }
195            CubeCountPlanSelection::Flattened => {
196                if m_cubes * n_cubes * batch_cubes >= max_x {
197                    None
198                } else {
199                    Some(CubeCountPlan::Flattened {
200                        m_cubes,
201                        n_cubes,
202                        batch_cubes,
203                    })
204                }
205            }
206            CubeCountPlanSelection::Spread => None,
207        };
208
209        plan.unwrap_or_else(|| {
210            spread_cube_count_plan(m_cubes, n_cubes, batch_cubes, max_x, max_y, max_z)
211        })
212    }
213
214    /// Because we don't want to store the CubeCountPlan values in config, we have to recompute it
215    ///
216    /// Assumes the hypercube config is valid
217    pub fn from_config(
218        config: &HypercubeConfig,
219        problem: &MatmulProblem,
220        max_cube_count: CubeCount,
221    ) -> CubeCountPlan {
222        let (max_x, max_y, max_z) = match max_cube_count {
223            CubeCount::Static(x, y, z) => (x, y, z),
224            CubeCount::Dynamic(_) => panic!("Dynamic cube count not supported for cube count plan"),
225        };
226
227        let m_cubes = (problem.m as u32).div_ceil(config.cube_span.m);
228        let n_cubes = (problem.n as u32).div_ceil(config.cube_span.n);
229        let batch_cubes = (problem.num_batches() as u32).div_ceil(config.cube_span.batch);
230
231        match config.cube_count_plan_config {
232            CubeCountPlanConfig::FromProblem => CubeCountPlan::FromProblem {
233                m_cubes,
234                n_cubes,
235                batch_cubes,
236            },
237            CubeCountPlanConfig::Sm {
238                cubes_first,
239                num_sms,
240                sm_usage,
241                ..
242            } => {
243                let (num_sms_used, cubes_per_sm) =
244                    sm_usage.allocate(num_sms, m_cubes * n_cubes * batch_cubes);
245                CubeCountPlan::Sm {
246                    cubes_first,
247                    num_sms_used,
248                    cubes_per_sm,
249                    m_cubes,
250                    n_cubes,
251                    batch_cubes,
252                    num_sms,
253                    sm_usage,
254                }
255            }
256            CubeCountPlanConfig::Flattened => CubeCountPlan::Flattened {
257                m_cubes,
258                n_cubes,
259                batch_cubes,
260            },
261            CubeCountPlanConfig::Spread { .. } => {
262                spread_cube_count_plan(m_cubes, n_cubes, batch_cubes, max_x, max_y, max_z)
263            }
264        }
265    }
266}
267
268impl CubeCountPlanConfig {
269    /// Whether the CubeCount will have more cubes than strictly necessary.
270    pub fn can_yield_extra_cubes(&self) -> bool {
271        match self {
272            CubeCountPlanConfig::FromProblem | CubeCountPlanConfig::Flattened => false,
273            CubeCountPlanConfig::Sm {
274                can_yield_extra_cubes,
275                ..
276            } => *can_yield_extra_cubes,
277            CubeCountPlanConfig::Spread {
278                can_yield_extra_cubes,
279            } => *can_yield_extra_cubes,
280        }
281    }
282
283    pub(crate) fn from_cube_count_plan(cube_count_plan: CubeCountPlan) -> CubeCountPlanConfig {
284        match cube_count_plan {
285            CubeCountPlan::FromProblem { .. } => CubeCountPlanConfig::FromProblem,
286            CubeCountPlan::Sm {
287                cubes_first,
288                num_sms,
289                sm_usage,
290                ..
291            } => CubeCountPlanConfig::Sm {
292                cubes_first,
293                num_sms,
294                sm_usage,
295                can_yield_extra_cubes: cube_count_plan.can_yield_extra_cubes(),
296            },
297            CubeCountPlan::Flattened { .. } => CubeCountPlanConfig::Flattened,
298            CubeCountPlan::Spread { .. } => CubeCountPlanConfig::Spread {
299                can_yield_extra_cubes: cube_count_plan.can_yield_extra_cubes(),
300            },
301        }
302    }
303}
304
305/// Heuristic algorithm to factor the total number of cubes into (x, y, z) dimensions
306/// such that no dimension surpasses its maximum.
307pub(crate) fn spread_cube_count_plan(
308    m_cubes: u32,
309    n_cubes: u32,
310    batch_cubes: u32,
311    max_x: u32,
312    max_y: u32,
313    max_z: u32,
314) -> CubeCountPlan {
315    let total_cubes = m_cubes * n_cubes * batch_cubes;
316
317    let mut best = None;
318
319    let mut z = max_z;
320    while z >= 1 {
321        let xy_cubes = total_cubes.div_ceil(z);
322
323        let mut y = max_y;
324        while y >= 1 {
325            let x = xy_cubes.div_ceil(y);
326            if x <= max_x {
327                let volume = x * y * z;
328                let score = (volume, std::cmp::Reverse(z), std::cmp::Reverse(y));
329
330                if best.is_none_or(|(_, _, _, _, best_score)| score < best_score) {
331                    best = Some((x, y, z, volume, score));
332                }
333            }
334
335            if y == 1 {
336                break;
337            }
338            y /= 2;
339        }
340
341        if z == 1 {
342            break;
343        }
344        z /= 2;
345    }
346
347    if let Some((x, y, z, _, _)) = best {
348        CubeCountPlan::Spread {
349            m_cubes,
350            n_cubes,
351            batch_cubes,
352            x,
353            y,
354            z,
355        }
356    } else {
357        panic!("No valid cube spread plan")
358    }
359}
360
361impl CubeCountPlan {
362    // Resolves the cube count plan into a concrete [`CubeCount`].
363    pub fn resolve(&self) -> CubeCount {
364        match self {
365            CubeCountPlan::FromProblem {
366                m_cubes,
367                n_cubes,
368                batch_cubes,
369            } => CubeCount::Static(*m_cubes, *n_cubes, *batch_cubes),
370            CubeCountPlan::Sm {
371                cubes_first,
372                num_sms_used,
373                cubes_per_sm,
374                ..
375            } => match cubes_first {
376                true => CubeCount::Static(*cubes_per_sm, *num_sms_used, 1),
377                false => CubeCount::Static(*num_sms_used, *cubes_per_sm, 1),
378            },
379            CubeCountPlan::Flattened {
380                m_cubes,
381                n_cubes,
382                batch_cubes,
383            } => CubeCount::Static(*m_cubes * *n_cubes * *batch_cubes, 1, 1),
384            CubeCountPlan::Spread { x, y, z, .. } => CubeCount::Static(*x, *y, *z),
385        }
386    }
387
388    /// Make a CubeCountInput from CubeCountPlan
389    pub fn as_args<'a, R: Runtime>(&self) -> CubeCountInputArgs<'a, R> {
390        match self {
391            CubeCountPlan::FromProblem { .. } => CubeCountInputArgs::FromProblem,
392            CubeCountPlan::Sm {
393                cubes_first,
394                m_cubes,
395                n_cubes,
396                batch_cubes,
397                ..
398            } => match cubes_first {
399                true => CubeCountInputArgs::CubeFirst {
400                    m_cubes: ScalarArg::new(*m_cubes),
401                    n_cubes: ScalarArg::new(*n_cubes),
402                    batch_cubes: ScalarArg::new(*batch_cubes),
403                },
404                false => CubeCountInputArgs::SmFirst {
405                    m_cubes: ScalarArg::new(*m_cubes),
406                    n_cubes: ScalarArg::new(*n_cubes),
407                    batch_cubes: ScalarArg::new(*batch_cubes),
408                },
409            },
410            CubeCountPlan::Flattened {
411                m_cubes, n_cubes, ..
412            } => CubeCountInputArgs::Flattened {
413                m_cubes: ScalarArg::new(*m_cubes),
414                n_cubes: ScalarArg::new(*n_cubes),
415            },
416            CubeCountPlan::Spread {
417                m_cubes,
418                n_cubes,
419                batch_cubes,
420                ..
421            } => CubeCountInputArgs::Spread {
422                m_cubes: ScalarArg::new(*m_cubes),
423                n_cubes: ScalarArg::new(*n_cubes),
424                batch_cubes: ScalarArg::new(*batch_cubes),
425            },
426        }
427    }
428}
429
#[cube]
impl CubeCountInput {
    /// Returns the number of valid cubes, i.e. `m_cubes * n_cubes * batch_cubes`.
    ///
    /// Only meaningful for the variants whose dispatch may contain extra cubes
    /// (`SmFirst`, `CubeFirst`, `Spread`). For `FromProblem` and `Flattened`, the
    /// arm body is a `panic!`, since those dispatches are expected to be exact.
    pub fn num_valid_cubes(&self) -> u32 {
        match self {
            CubeCountInput::FromProblem | CubeCountInput::Flattened { .. } => {
                panic!("Shouldn't need to be called because the cube count should always be exact")
            }
            CubeCountInput::SmFirst {
                m_cubes,
                n_cubes,
                batch_cubes,
            } => *m_cubes * *n_cubes * *batch_cubes,
            CubeCountInput::CubeFirst {
                m_cubes,
                n_cubes,
                batch_cubes,
            } => *m_cubes * *n_cubes * *batch_cubes,
            CubeCountInput::Spread {
                m_cubes,
                n_cubes,
                batch_cubes,
            } => *m_cubes * *n_cubes * *batch_cubes,
        }
    }

    /// Maps the current cube's position to its tensor tile coordinates `(m, n, batch)`.
    ///
    /// `FromProblem` uses `(CUBE_POS_X, CUBE_POS_Y, CUBE_POS_Z)` directly. Every other
    /// variant first forms a linear cube index, then decomposes it with
    /// `absolute_index_to_m_n_batch` according to `global_order`.
    pub fn cube_pos_to_tensor_pos(&self, #[comptime] global_order: GlobalOrder) -> (u32, u32, u32) {
        match self {
            CubeCountInput::FromProblem => (CUBE_POS_X, CUBE_POS_Y, CUBE_POS_Z),
            // NOTE(review): assumes CUBE_POS is cubecl's linearized cube index over
            // all three dispatch dimensions — confirm against the cubecl builtins.
            CubeCountInput::SmFirst {
                m_cubes, n_cubes, ..
            } => self.absolute_index_to_m_n_batch(CUBE_POS, *m_cubes, *n_cubes, global_order),
            // Sm dispatches use a single cube along Z (see `CubeCountPlan::resolve`),
            // so the index is linearized from X and Y only.
            CubeCountInput::CubeFirst {
                m_cubes, n_cubes, ..
            } => self.absolute_index_to_m_n_batch(
                CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X,
                *m_cubes,
                *n_cubes,
                global_order,
            ),
            // Flattened dispatches put every cube along X.
            CubeCountInput::Flattened { m_cubes, n_cubes } => {
                self.absolute_index_to_m_n_batch(CUBE_POS_X, *m_cubes, *n_cubes, global_order)
            }
            // Spread dispatches may contain extra cubes; their index is at least
            // `num_valid_cubes()`, which yields a batch position >= `batch_cubes`.
            // Presumably callers check `num_valid_cubes` — verify at call sites.
            CubeCountInput::Spread {
                m_cubes, n_cubes, ..
            } => self.absolute_index_to_m_n_batch(CUBE_POS, *m_cubes, *n_cubes, global_order),
        }
    }

    /// Decomposes a linear cube index into `(m, n, batch)` tile coordinates.
    ///
    /// The batch position is `absolute_index / (m_cubes * n_cubes)`. The remainder
    /// indexes a tile within one matrix and is mapped to `(m, n)` according to
    /// `global_order`. `self` is not read here.
    fn absolute_index_to_m_n_batch(
        &self,
        absolute_index: u32,
        m_cubes: u32,
        n_cubes: u32,
        #[comptime] global_order: GlobalOrder,
    ) -> (u32, u32, u32) {
        let batch_stride = m_cubes * n_cubes;
        let batch_pos = absolute_index / batch_stride;
        let matrix_pos = absolute_index % batch_stride;

        // The ordering is resolved at comptime, so only one arm is emitted in the kernel.
        let (m_pos, n_pos) = match comptime!(global_order) {
            // Consecutive indices walk along n first.
            GlobalOrder::RowMajor => (matrix_pos / n_cubes, matrix_pos % n_cubes),
            // Consecutive indices walk along m first.
            GlobalOrder::ColMajor => (matrix_pos % m_cubes, matrix_pos / m_cubes),
            // Delegates to `swizzle` over the n extent, then swaps the returned pair
            // into (m, n) order.
            GlobalOrder::SwizzleRowMajor(w) => {
                let (x, y) = swizzle(matrix_pos, n_cubes, w);
                (y, x)
            }
            // Delegates to `swizzle` over the m extent; the pair is used as (m, n) as-is.
            GlobalOrder::SwizzleColMajor(w) => swizzle(matrix_pos, m_cubes, w),
        };

        (m_pos, n_pos, batch_pos)
    }
}