cubecl_matmul/components/batch/partitioned_matmul/hypercube/
base.rs

1use cubecl_core::CubeCount;
2
3use crate::components::{
4    MatmulProblem, MatmulSetupError, TilingScheme,
5    batch::partitioned_matmul::hypercube::{
6        cube_count_plan::{CubeCountPlan, CubeCountPlanConfig, CubeCountPlanSelection},
7        global_order::{GlobalOrder, GlobalOrderSelection},
8    },
9};
10
11#[derive(Debug, Clone)]
12/// Determines how to launch the hypercube, i.e. anything
13/// relevant to CubeCount and where a Cube at a cube position should work
14pub struct HypercubeSelection {
15    pub cube_span: CubeSpan,
16    pub global_order: GlobalOrder,
17    pub cube_count_plan_selection: CubeCountPlanSelection,
18}
19
20/// Builder for creating a [HypercubeSelection]
21pub struct HypercubeSelectionBuilder<'a> {
22    tiling_scheme: &'a TilingScheme,
23    global_order: GlobalOrderSelection,
24    cube_count_plan_config: Option<CubeCountPlanSelection>,
25}
26
27#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
28/// Determines how to launch the hypercube, i.e. anything
29/// relevant to CubeCount and where a Cube at a cube position should work
30/// Similar to [HyperCubeSelection] but injected in kernel as comptime struct
31pub struct HypercubeConfig {
32    pub cube_span: CubeSpan,
33    pub global_order: GlobalOrder,
34    pub cube_count_plan_config: CubeCountPlanConfig,
35}
36
/// Number of elements each cube covers in the tensors.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct CubeSpan {
    /// Elements covered along `m`.
    pub m: u32,
    /// Elements covered along `n`.
    pub n: u32,
    /// Coverage along the batch dimension (presumably a number of batches —
    /// TODO confirm against the tiling scheme).
    pub batch: u32,
}
44
45impl HypercubeSelection {
46    /// Create a builder for HypercubeSelection
47    pub fn builder<'a>(tiling_scheme: &'a TilingScheme) -> HypercubeSelectionBuilder<'a> {
48        HypercubeSelectionBuilder::new(tiling_scheme)
49    }
50
51    pub(crate) fn to_hypercube_config(
52        &self,
53        problem: &MatmulProblem,
54        max_cube_count: CubeCount,
55    ) -> HypercubeConfig {
56        let cube_count_plan = CubeCountPlan::from_selection(self, problem, max_cube_count);
57        let cube_count_plan_config = CubeCountPlanConfig::from_cube_count_plan(cube_count_plan);
58
59        HypercubeConfig {
60            cube_span: self.cube_span,
61            global_order: self.global_order,
62            cube_count_plan_config,
63        }
64    }
65}
66
67impl HypercubeConfig {
68    /// Returns an error if:
69    /// - The global order is swizzle but its assumptions are not met
70    pub fn validate(&self, problem: &MatmulProblem) -> Result<(), MatmulSetupError> {
71        let m_cubes = (problem.m as u32).div_ceil(self.cube_span.m);
72        let n_cubes = (problem.n as u32).div_ceil(self.cube_span.n);
73
74        use GlobalOrder::*;
75
76        match self.global_order {
77            RowMajor | ColMajor => Ok(()),
78
79            SwizzleRowMajor(w) if m_cubes % w != 0 => {
80                Err(MatmulSetupError::InvalidConfig(Box::new(format!(
81                    "In swizzle row major, number of cubes in m {m_cubes:?} must be divisible by swizzle step length {w:?}."
82                ))))
83            }
84
85            SwizzleColMajor(w) if n_cubes % w != 0 => {
86                Err(MatmulSetupError::InvalidConfig(Box::new(format!(
87                    "In swizzle col major, number of cubes in n {n_cubes:?} must be divisible by swizzle step length {w:?}."
88                ))))
89            }
90
91            _ => Ok(()),
92        }
93    }
94}
95
96impl<'a> HypercubeSelectionBuilder<'a> {
97    fn new(tiling_scheme: &'a TilingScheme) -> Self {
98        Self {
99            tiling_scheme,
100            global_order: GlobalOrderSelection::default(),
101            cube_count_plan_config: None,
102        }
103    }
104
105    /// Set the [GlobalOrderSelection]
106    pub fn global_order(mut self, global_order: GlobalOrderSelection) -> Self {
107        self.global_order = global_order;
108        self
109    }
110
111    /// Set the [CubeCountPlanSelection]
112    pub fn cube_count_plan(mut self, cube_count_plan_config: CubeCountPlanSelection) -> Self {
113        self.cube_count_plan_config = Some(cube_count_plan_config);
114        self
115    }
116
117    /// Build the HypercubeSelection
118    pub fn build(self) -> HypercubeSelection {
119        let cube_span = CubeSpan {
120            m: self.tiling_scheme.elements_in_global_partition_m(),
121            n: self.tiling_scheme.elements_in_global_partition_n(),
122            batch: self.tiling_scheme.global_partition_size.batches,
123        };
124
125        let global_order = self.global_order.into_order(&cube_span);
126        let cube_pos_strategy = self.cube_count_plan_config.unwrap_or_default();
127
128        HypercubeSelection {
129            cube_span,
130            global_order,
131            cube_count_plan_selection: cube_pos_strategy,
132        }
133    }
134}
135
136impl HypercubeConfig {
137    /// Make a CubeCountPlan from the problem, constrained to not exceed the maximal cube count
138    pub fn cube_count_plan(
139        &self,
140        problem: &MatmulProblem,
141        max_cube_count: CubeCount,
142    ) -> CubeCountPlan {
143        CubeCountPlan::from_config(self, problem, max_cube_count)
144    }
145}