cubecl_matmul/components/batch/partitioned_matmul/hypercube/
base.rs1use cubecl_core::CubeCount;
2
3use crate::components::{
4 MatmulProblem, MatmulSetupError, TilingScheme,
5 batch::partitioned_matmul::hypercube::{
6 cube_count_plan::{CubeCountPlan, CubeCountPlanConfig, CubeCountPlanSelection},
7 global_order::{GlobalOrder, GlobalOrderSelection},
8 },
9};
10
11#[derive(Debug, Clone)]
12pub struct HypercubeSelection {
15 pub cube_span: CubeSpan,
16 pub global_order: GlobalOrder,
17 pub cube_count_plan_selection: CubeCountPlanSelection,
18}
19
20pub struct HypercubeSelectionBuilder<'a> {
22 tiling_scheme: &'a TilingScheme,
23 global_order: GlobalOrderSelection,
24 cube_count_plan_config: Option<CubeCountPlanSelection>,
25}
26
27#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
28pub struct HypercubeConfig {
32 pub cube_span: CubeSpan,
33 pub global_order: GlobalOrder,
34 pub cube_count_plan_config: CubeCountPlanConfig,
35}
36
/// Extent covered by a single cube along each matmul dimension.
///
/// When produced by [`HypercubeSelectionBuilder::build`], `m` and `n` are the number of
/// elements of a global partition along those axes, and `batch` is the number of batches
/// in a global partition.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CubeSpan {
    /// Extent along the `m` (row) dimension.
    pub m: u32,
    /// Extent along the `n` (column) dimension.
    pub n: u32,
    /// Extent along the batch dimension.
    pub batch: u32,
}
44
45impl HypercubeSelection {
46 pub fn builder<'a>(tiling_scheme: &'a TilingScheme) -> HypercubeSelectionBuilder<'a> {
48 HypercubeSelectionBuilder::new(tiling_scheme)
49 }
50
51 pub(crate) fn to_hypercube_config(
52 &self,
53 problem: &MatmulProblem,
54 max_cube_count: CubeCount,
55 ) -> HypercubeConfig {
56 let cube_count_plan = CubeCountPlan::from_selection(self, problem, max_cube_count);
57 let cube_count_plan_config = CubeCountPlanConfig::from_cube_count_plan(cube_count_plan);
58
59 HypercubeConfig {
60 cube_span: self.cube_span,
61 global_order: self.global_order,
62 cube_count_plan_config,
63 }
64 }
65}
66
67impl HypercubeConfig {
68 pub fn validate(&self, problem: &MatmulProblem) -> Result<(), MatmulSetupError> {
71 let m_cubes = (problem.m as u32).div_ceil(self.cube_span.m);
72 let n_cubes = (problem.n as u32).div_ceil(self.cube_span.n);
73
74 use GlobalOrder::*;
75
76 match self.global_order {
77 RowMajor | ColMajor => Ok(()),
78
79 SwizzleRowMajor(w) if m_cubes % w != 0 => {
80 Err(MatmulSetupError::InvalidConfig(Box::new(format!(
81 "In swizzle row major, number of cubes in m {m_cubes:?} must be divisible by swizzle step length {w:?}."
82 ))))
83 }
84
85 SwizzleColMajor(w) if n_cubes % w != 0 => {
86 Err(MatmulSetupError::InvalidConfig(Box::new(format!(
87 "In swizzle col major, number of cubes in n {n_cubes:?} must be divisible by swizzle step length {w:?}."
88 ))))
89 }
90
91 _ => Ok(()),
92 }
93 }
94}
95
96impl<'a> HypercubeSelectionBuilder<'a> {
97 fn new(tiling_scheme: &'a TilingScheme) -> Self {
98 Self {
99 tiling_scheme,
100 global_order: GlobalOrderSelection::default(),
101 cube_count_plan_config: None,
102 }
103 }
104
105 pub fn global_order(mut self, global_order: GlobalOrderSelection) -> Self {
107 self.global_order = global_order;
108 self
109 }
110
111 pub fn cube_count_plan(mut self, cube_count_plan_config: CubeCountPlanSelection) -> Self {
113 self.cube_count_plan_config = Some(cube_count_plan_config);
114 self
115 }
116
117 pub fn build(self) -> HypercubeSelection {
119 let cube_span = CubeSpan {
120 m: self.tiling_scheme.elements_in_global_partition_m(),
121 n: self.tiling_scheme.elements_in_global_partition_n(),
122 batch: self.tiling_scheme.global_partition_size.batches,
123 };
124
125 let global_order = self.global_order.into_order(&cube_span);
126 let cube_pos_strategy = self.cube_count_plan_config.unwrap_or_default();
127
128 HypercubeSelection {
129 cube_span,
130 global_order,
131 cube_count_plan_selection: cube_pos_strategy,
132 }
133 }
134}
135
136impl HypercubeConfig {
137 pub fn cube_count_plan(
139 &self,
140 problem: &MatmulProblem,
141 max_cube_count: CubeCount,
142 ) -> CubeCountPlan {
143 CubeCountPlan::from_config(self, problem, max_cube_count)
144 }
145}