Skip to main content

cubek_std/cube_count/hypercube/cube_count/
plan.rs

1use cubecl::CubeCount;
2
3use crate::cube_count::{CubeCountStrategy, GlobalOrder, HypercubeBlueprint, SmAllocation};
4
5#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
6pub struct CubeCountPlan {
7    pub global_order: GlobalOrder,
8    pub kind: CubeCountPlanKind,
9}
10
/// A count of cubes along the three launch axes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Count3d {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}

impl From<(u32, u32, u32)> for Count3d {
    /// Builds a count from an `(x, y, z)` tuple.
    fn from((x, y, z): (u32, u32, u32)) -> Self {
        Self { x, y, z }
    }
}

impl Count3d {
    /// Total number of cubes, `x * y * z`.
    ///
    /// Computed in `u64` because the product of large `u32` counts can exceed `u32::MAX`
    /// (it can still overflow `u64` if all three axes are close to `u32::MAX`).
    pub(crate) fn total(&self) -> u64 {
        [self.x, self.y, self.z]
            .iter()
            .copied()
            .map(u64::from)
            .product()
    }
}
34
35#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
36pub enum CubeCountPlanKind {
37    FromProblem {
38        problem_count: Count3d,
39    },
40    Sm {
41        cubes_first: bool,
42        num_sms_used: u32,
43        cubes_per_sm: u32,
44        problem_count: Count3d,
45        num_sms: u32,
46        sm_usage: SmAllocation,
47    },
48    Flattened {
49        problem_count: Count3d,
50    },
51    Spread {
52        problem_count: Count3d,
53        spread_count: Count3d,
54    },
55}
56
57impl CubeCountPlan {
58    // Will check if the wanted cube count plan is possible, otherwise will fallback to spread
59    pub fn from_blueprint(
60        blueprint: &HypercubeBlueprint,
61        problem_count: Count3d,
62        max_cube_count: &(u32, u32, u32),
63    ) -> CubeCountPlan {
64        let (max_x, max_y, max_z) = *max_cube_count;
65
66        let plan_kind = match blueprint.cube_count_strategy {
67            CubeCountStrategy::FromProblem => {
68                if problem_count.x > max_x || problem_count.y > max_y || problem_count.z > max_z {
69                    None
70                } else {
71                    Some(CubeCountPlanKind::FromProblem { problem_count })
72                }
73            }
74            CubeCountStrategy::Sm {
75                cubes_first,
76                num_sms,
77                sm_usage,
78            } => {
79                let (num_sms_used, cubes_per_sm) =
80                    sm_usage.allocate(num_sms, problem_count.total() as usize);
81
82                if (cubes_per_sm >= if cubes_first { max_x } else { max_y })
83                    || (num_sms_used >= if cubes_first { max_y } else { max_x })
84                {
85                    None
86                } else {
87                    Some(CubeCountPlanKind::Sm {
88                        cubes_first,
89                        num_sms_used,
90                        cubes_per_sm,
91                        problem_count,
92                        num_sms,
93                        sm_usage,
94                    })
95                }
96            }
97            CubeCountStrategy::Flattened => {
98                if problem_count.total() >= max_x as u64 {
99                    None
100                } else {
101                    Some(CubeCountPlanKind::Flattened { problem_count })
102                }
103            }
104            CubeCountStrategy::Spread => None,
105        };
106
107        // Validate swizzle: fall back to non-swizzled order when the swizzle width
108        // does not evenly divide the problem dimension (m_cubes for Row, n_cubes for Col).
109        // Without this check, the swizzle produces incorrect cube-to-tile mappings.
110        let global_order = match blueprint.global_order {
111            GlobalOrder::SwizzleRow(w) if !problem_count.x.is_multiple_of(w) => {
112                GlobalOrder::RowMajor
113            }
114            GlobalOrder::SwizzleCol(w) if !problem_count.y.is_multiple_of(w) => {
115                GlobalOrder::ColMajor
116            }
117            other => other,
118        }
119        .canonicalize();
120
121        CubeCountPlan {
122            global_order,
123            kind: plan_kind
124                .unwrap_or_else(|| spread_cube_count_plan(problem_count, max_x, max_y, max_z)),
125        }
126    }
127
128    pub fn new_from_problem(target_count: Count3d) -> Self {
129        Self {
130            global_order: Default::default(),
131            kind: CubeCountPlanKind::FromProblem {
132                problem_count: target_count,
133            },
134        }
135    }
136
137    pub fn can_yield_extra_cubes(&self) -> bool {
138        self.kind.can_yield_extra_cubes()
139    }
140
141    pub fn resolve(&self) -> CubeCount {
142        self.kind.resolve()
143    }
144}
145
146impl CubeCountPlanKind {
147    pub fn can_yield_extra_cubes(&self) -> bool {
148        match self {
149            CubeCountPlanKind::FromProblem { .. } | CubeCountPlanKind::Flattened { .. } => false,
150
151            CubeCountPlanKind::Sm {
152                num_sms_used,
153                cubes_per_sm,
154                problem_count,
155                ..
156            } => (num_sms_used * cubes_per_sm) as u64 != problem_count.total(),
157
158            CubeCountPlanKind::Spread {
159                problem_count,
160                spread_count,
161            } => problem_count.total() != spread_count.total(),
162        }
163    }
164
165    fn resolve(&self) -> CubeCount {
166        match self {
167            CubeCountPlanKind::FromProblem { problem_count } => {
168                CubeCount::Static(problem_count.x, problem_count.y, problem_count.z)
169            }
170
171            CubeCountPlanKind::Sm {
172                cubes_first,
173                num_sms_used,
174                cubes_per_sm,
175                ..
176            } => {
177                if *cubes_first {
178                    CubeCount::Static(*cubes_per_sm, *num_sms_used, 1)
179                } else {
180                    CubeCount::Static(*num_sms_used, *cubes_per_sm, 1)
181                }
182            }
183
184            CubeCountPlanKind::Flattened { problem_count } => {
185                CubeCount::Static(problem_count.total() as u32, 1, 1)
186            }
187
188            CubeCountPlanKind::Spread { spread_count, .. } => {
189                CubeCount::Static(spread_count.x, spread_count.y, spread_count.z)
190            }
191        }
192    }
193}
194
195/// Heuristic algorithm to factor the total number of cubes into (x, y, z) dimensions
196/// such that no dimension surpasses its maximum.
197fn spread_cube_count_plan(
198    problem_count: Count3d,
199    max_x: u32,
200    max_y: u32,
201    max_z: u32,
202) -> CubeCountPlanKind {
203    let mut best = None;
204
205    let mut z = max_z;
206    while z >= 1 {
207        let xy_cubes = problem_count.total().div_ceil(z as u64);
208
209        let mut y = max_y;
210        while y >= 1 {
211            let x64 = xy_cubes.div_ceil(y as u64);
212            if x64 <= max_x as u64 {
213                let x = x64 as u32;
214                let volume = x as u64 * y as u64 * z as u64;
215                let score = (volume, std::cmp::Reverse(z), std::cmp::Reverse(y));
216
217                if best.is_none_or(|(_, _, _, _, best_score)| score < best_score) {
218                    best = Some((x, y, z, volume, score));
219                }
220            }
221
222            if y == 1 {
223                break;
224            }
225            y /= 2;
226        }
227
228        if z == 1 {
229            break;
230        }
231        z /= 2;
232    }
233
234    if let Some((x, y, z, _, _)) = best {
235        CubeCountPlanKind::Spread {
236            problem_count,
237            spread_count: Count3d { x, y, z },
238        }
239    } else {
240        panic!("No valid cube spread plan")
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_CUBE_COUNT: (u32, u32, u32) = (65535, 65535, 65535);

    /// Builds a plan for `problem_count` from a default blueprint with the given order.
    fn plan_with_order(order: GlobalOrder, problem_count: Count3d) -> CubeCountPlan {
        let blueprint = HypercubeBlueprint::builder().global_order(order).build();
        CubeCountPlan::from_blueprint(&blueprint, problem_count, &MAX_CUBE_COUNT)
    }

    #[test]
    fn swizzle_row_falls_back_when_m_cubes_not_divisible_by_w() {
        // m_cubes=3 is not divisible by w=4, must fall back to RowMajor
        let plan = plan_with_order(GlobalOrder::SwizzleRow(4), Count3d { x: 3, y: 5, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::RowMajor);
    }

    #[test]
    fn swizzle_row_kept_when_m_cubes_divisible_by_w() {
        let plan = plan_with_order(GlobalOrder::SwizzleRow(4), Count3d { x: 8, y: 5, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::SwizzleRow(4));
    }

    #[test]
    fn swizzle_col_falls_back_when_n_cubes_not_divisible_by_w() {
        let plan = plan_with_order(GlobalOrder::SwizzleCol(4), Count3d { x: 8, y: 3, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::ColMajor);
    }

    #[test]
    fn swizzle_col_kept_when_n_cubes_divisible_by_w() {
        // n_cubes=8 is divisible by w=4, the swizzle must be preserved
        let plan = plan_with_order(GlobalOrder::SwizzleCol(4), Count3d { x: 5, y: 8, z: 1 });
        assert_eq!(plan.global_order, GlobalOrder::SwizzleCol(4));
    }

    #[test]
    fn spread_finds_minimal_volume_within_limits() {
        // 10 cubes with every axis capped at 4: the tightest cover has 12 cubes.
        let CubeCountPlanKind::Spread { spread_count, .. } =
            spread_cube_count_plan(Count3d { x: 10, y: 1, z: 1 }, 4, 4, 4)
        else {
            panic!("spread_cube_count_plan must return a Spread plan");
        };
        assert!(spread_count.x <= 4 && spread_count.y <= 4 && spread_count.z <= 4);
        assert_eq!(spread_count.total(), 12);
    }

    #[test]
    fn spread_plan_reports_extra_cubes() {
        let kind = spread_cube_count_plan(Count3d { x: 10, y: 1, z: 1 }, 4, 4, 4);
        assert!(kind.can_yield_extra_cubes());
    }

    #[test]
    #[should_panic(expected = "No valid cube spread plan")]
    fn spread_panics_when_problem_exceeds_limits() {
        spread_cube_count_plan(Count3d { x: 2, y: 1, z: 1 }, 1, 1, 1);
    }
}