cubek_std/cube_count/hypercube/cube_count/
plan.rs1use cubecl::CubeCount;
2
3use crate::cube_count::{CubeCountStrategy, GlobalOrder, HypercubeBlueprint, SmAllocation};
4
5#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
6pub struct CubeCountPlan {
7 pub global_order: GlobalOrder,
8 pub kind: CubeCountPlanKind,
9}
10
/// A count of cubes along the x, y and z axes.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct Count3d {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}

impl From<(u32, u32, u32)> for Count3d {
    /// Reads the tuple as `(x, y, z)`.
    fn from((x, y, z): (u32, u32, u32)) -> Self {
        Self { x, y, z }
    }
}

impl Count3d {
    /// Total number of cubes, `x * y * z`.
    ///
    /// Each axis is widened to `u64` before multiplying, so the product of two
    /// or three large `u32` extents does not wrap at `u32`.
    pub(crate) fn total(&self) -> u64 {
        [self.x, self.y, self.z]
            .iter()
            .map(|&axis| u64::from(axis))
            .product()
    }
}
34
35#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
36pub enum CubeCountPlanKind {
37 FromProblem {
38 problem_count: Count3d,
39 },
40 Sm {
41 cubes_first: bool,
42 num_sms_used: u32,
43 cubes_per_sm: u32,
44 problem_count: Count3d,
45 num_sms: u32,
46 sm_usage: SmAllocation,
47 },
48 Flattened {
49 problem_count: Count3d,
50 },
51 Spread {
52 problem_count: Count3d,
53 spread_count: Count3d,
54 },
55}
56
57impl CubeCountPlan {
58 pub fn from_blueprint(
60 blueprint: &HypercubeBlueprint,
61 problem_count: Count3d,
62 max_cube_count: &(u32, u32, u32),
63 ) -> CubeCountPlan {
64 let (max_x, max_y, max_z) = *max_cube_count;
65
66 let plan_kind = match blueprint.cube_count_strategy {
67 CubeCountStrategy::FromProblem => {
68 if problem_count.x > max_x || problem_count.y > max_y || problem_count.z > max_z {
69 None
70 } else {
71 Some(CubeCountPlanKind::FromProblem { problem_count })
72 }
73 }
74 CubeCountStrategy::Sm {
75 cubes_first,
76 num_sms,
77 sm_usage,
78 } => {
79 let (num_sms_used, cubes_per_sm) =
80 sm_usage.allocate(num_sms, problem_count.total() as usize);
81
82 if (cubes_per_sm >= if cubes_first { max_x } else { max_y })
83 || (num_sms_used >= if cubes_first { max_y } else { max_x })
84 {
85 None
86 } else {
87 Some(CubeCountPlanKind::Sm {
88 cubes_first,
89 num_sms_used,
90 cubes_per_sm,
91 problem_count,
92 num_sms,
93 sm_usage,
94 })
95 }
96 }
97 CubeCountStrategy::Flattened => {
98 if problem_count.total() >= max_x as u64 {
99 None
100 } else {
101 Some(CubeCountPlanKind::Flattened { problem_count })
102 }
103 }
104 CubeCountStrategy::Spread => None,
105 };
106
107 let global_order = match blueprint.global_order {
111 GlobalOrder::SwizzleRow(w) if !problem_count.x.is_multiple_of(w) => {
112 GlobalOrder::RowMajor
113 }
114 GlobalOrder::SwizzleCol(w) if !problem_count.y.is_multiple_of(w) => {
115 GlobalOrder::ColMajor
116 }
117 other => other,
118 }
119 .canonicalize();
120
121 CubeCountPlan {
122 global_order,
123 kind: plan_kind
124 .unwrap_or_else(|| spread_cube_count_plan(problem_count, max_x, max_y, max_z)),
125 }
126 }
127
128 pub fn new_from_problem(target_count: Count3d) -> Self {
129 Self {
130 global_order: Default::default(),
131 kind: CubeCountPlanKind::FromProblem {
132 problem_count: target_count,
133 },
134 }
135 }
136
137 pub fn can_yield_extra_cubes(&self) -> bool {
138 self.kind.can_yield_extra_cubes()
139 }
140
141 pub fn resolve(&self) -> CubeCount {
142 self.kind.resolve()
143 }
144}
145
146impl CubeCountPlanKind {
147 pub fn can_yield_extra_cubes(&self) -> bool {
148 match self {
149 CubeCountPlanKind::FromProblem { .. } | CubeCountPlanKind::Flattened { .. } => false,
150
151 CubeCountPlanKind::Sm {
152 num_sms_used,
153 cubes_per_sm,
154 problem_count,
155 ..
156 } => (num_sms_used * cubes_per_sm) as u64 != problem_count.total(),
157
158 CubeCountPlanKind::Spread {
159 problem_count,
160 spread_count,
161 } => problem_count.total() != spread_count.total(),
162 }
163 }
164
165 fn resolve(&self) -> CubeCount {
166 match self {
167 CubeCountPlanKind::FromProblem { problem_count } => {
168 CubeCount::Static(problem_count.x, problem_count.y, problem_count.z)
169 }
170
171 CubeCountPlanKind::Sm {
172 cubes_first,
173 num_sms_used,
174 cubes_per_sm,
175 ..
176 } => {
177 if *cubes_first {
178 CubeCount::Static(*cubes_per_sm, *num_sms_used, 1)
179 } else {
180 CubeCount::Static(*num_sms_used, *cubes_per_sm, 1)
181 }
182 }
183
184 CubeCountPlanKind::Flattened { problem_count } => {
185 CubeCount::Static(problem_count.total() as u32, 1, 1)
186 }
187
188 CubeCountPlanKind::Spread { spread_count, .. } => {
189 CubeCount::Static(spread_count.x, spread_count.y, spread_count.z)
190 }
191 }
192 }
193}
194
195fn spread_cube_count_plan(
198 problem_count: Count3d,
199 max_x: u32,
200 max_y: u32,
201 max_z: u32,
202) -> CubeCountPlanKind {
203 let mut best = None;
204
205 let mut z = max_z;
206 while z >= 1 {
207 let xy_cubes = problem_count.total().div_ceil(z as u64);
208
209 let mut y = max_y;
210 while y >= 1 {
211 let x64 = xy_cubes.div_ceil(y as u64);
212 if x64 <= max_x as u64 {
213 let x = x64 as u32;
214 let volume = x as u64 * y as u64 * z as u64;
215 let score = (volume, std::cmp::Reverse(z), std::cmp::Reverse(y));
216
217 if best.is_none_or(|(_, _, _, _, best_score)| score < best_score) {
218 best = Some((x, y, z, volume, score));
219 }
220 }
221
222 if y == 1 {
223 break;
224 }
225 y /= 2;
226 }
227
228 if z == 1 {
229 break;
230 }
231 z /= 2;
232 }
233
234 if let Some((x, y, z, _, _)) = best {
235 CubeCountPlanKind::Spread {
236 problem_count,
237 spread_count: Count3d { x, y, z },
238 }
239 } else {
240 panic!("No valid cube spread plan")
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_CUBE_COUNT: (u32, u32, u32) = (65535, 65535, 65535);

    /// Plans `problem_count` with a blueprint that only overrides the global order.
    fn plan_with_order(global_order: GlobalOrder, problem_count: Count3d) -> CubeCountPlan {
        let blueprint = HypercubeBlueprint::builder()
            .global_order(global_order)
            .build();
        CubeCountPlan::from_blueprint(&blueprint, problem_count, &MAX_CUBE_COUNT)
    }

    #[test]
    fn swizzle_row_falls_back_when_m_cubes_not_divisible_by_w() {
        let plan = plan_with_order(GlobalOrder::SwizzleRow(4), Count3d { x: 3, y: 5, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::RowMajor);
    }

    #[test]
    fn swizzle_row_kept_when_m_cubes_divisible_by_w() {
        let plan = plan_with_order(GlobalOrder::SwizzleRow(4), Count3d { x: 8, y: 5, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::SwizzleRow(4));
    }

    #[test]
    fn swizzle_col_falls_back_when_n_cubes_not_divisible_by_w() {
        let plan = plan_with_order(GlobalOrder::SwizzleCol(4), Count3d { x: 8, y: 3, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::ColMajor);
    }

    #[test]
    fn spread_picks_smallest_volume_preferring_larger_z() {
        // Volume-12 candidates are (3,1,4), (3,2,2) and (3,4,1); the largest z wins.
        let problem_count = Count3d { x: 10, y: 1, z: 1 };
        let kind = spread_cube_count_plan(problem_count, 4, 4, 4);
        assert_eq!(
            kind,
            CubeCountPlanKind::Spread {
                problem_count,
                spread_count: Count3d { x: 3, y: 1, z: 4 },
            }
        );
        assert!(kind.can_yield_extra_cubes());
    }

    #[test]
    fn spread_exact_fit_yields_no_extra_cubes() {
        let problem_count = Count3d { x: 8, y: 1, z: 1 };
        let kind = spread_cube_count_plan(problem_count, 4, 4, 4);
        assert_eq!(
            kind,
            CubeCountPlanKind::Spread {
                problem_count,
                spread_count: Count3d { x: 1, y: 2, z: 4 },
            }
        );
        assert!(!kind.can_yield_extra_cubes());
    }

    #[test]
    #[should_panic(expected = "No valid cube spread plan")]
    fn spread_panics_when_problem_exceeds_limits() {
        let _ = spread_cube_count_plan(Count3d { x: 100, y: 1, z: 1 }, 2, 2, 2);
    }
}