cubek_std/cube_count/hypercube/
cube_mapping.rs1use cubecl::prelude::*;
2
3use crate::cube_count::{CubeCountPlan, CubeCountPlanKind, GlobalOrder, swizzle};
4
5#[derive(CubeType, CubeLaunch)]
6pub struct CubeMapping {
13 strategy: CubeMappingStrategy,
14 #[cube(comptime)]
15 pub can_yield_extra_cubes: bool,
16 #[cube(comptime)]
17 global_order: GlobalOrder,
18}
19
20#[derive(CubeType, CubeLaunch)]
21#[allow(unused)] pub(crate) enum CubeMappingStrategy {
26 FromProblem,
27 SmFirst {
28 x_cubes: u32,
29 y_cubes: u32,
30 z_cubes: u32,
31 },
32 CubeFirst {
33 x_cubes: u32,
34 y_cubes: u32,
35 z_cubes: u32,
36 },
37 Flattened {
38 x_cubes: u32,
39 y_cubes: u32,
40 },
41 Spread {
42 x_cubes: u32,
43 y_cubes: u32,
44 z_cubes: u32,
45 },
46}
47
48#[cube]
49impl CubeMapping {
50 pub fn num_valid_cubes(&self) -> usize {
52 match &self.strategy {
53 CubeMappingStrategy::FromProblem | CubeMappingStrategy::Flattened { .. } => {
54 panic!("Shouldn't need to be called because the cube count should always be exact")
55 }
56 CubeMappingStrategy::SmFirst {
57 x_cubes,
58 y_cubes,
59 z_cubes,
60 }
61 | CubeMappingStrategy::CubeFirst {
62 x_cubes,
63 y_cubes,
64 z_cubes,
65 }
66 | CubeMappingStrategy::Spread {
67 x_cubes,
68 y_cubes,
69 z_cubes,
70 } => *x_cubes as usize * *y_cubes as usize * *z_cubes as usize,
71 }
72 }
73
74 pub fn cube_pos_to_xyz(&self) -> (u32, u32, u32) {
78 match &self.strategy {
79 CubeMappingStrategy::FromProblem => (CUBE_POS_X, CUBE_POS_Y, CUBE_POS_Z),
80
81 CubeMappingStrategy::SmFirst {
82 x_cubes, y_cubes, ..
83 } => {
84 self.strategy
85 .absolute_index_to_xyz(CUBE_POS, *x_cubes, *y_cubes, self.global_order)
86 }
87
88 CubeMappingStrategy::CubeFirst {
89 x_cubes, y_cubes, ..
90 } => self.strategy.absolute_index_to_xyz(
91 CUBE_POS_Y as usize * CUBE_COUNT_X as usize + CUBE_POS_X as usize,
92 *x_cubes,
93 *y_cubes,
94 self.global_order,
95 ),
96
97 CubeMappingStrategy::Flattened { x_cubes, y_cubes } => self
98 .strategy
99 .absolute_index_to_xyz(CUBE_POS_X as usize, *x_cubes, *y_cubes, self.global_order),
100
101 CubeMappingStrategy::Spread {
102 x_cubes, y_cubes, ..
103 } => {
104 self.strategy
105 .absolute_index_to_xyz(CUBE_POS, *x_cubes, *y_cubes, self.global_order)
106 }
107 }
108 }
109}
110
111#[cube]
112impl CubeMappingStrategy {
113 fn absolute_index_to_xyz(
114 &self,
115 absolute_index: usize,
116 x_cubes: u32,
117 y_cubes: u32,
118 #[comptime] global_order: GlobalOrder,
119 ) -> (u32, u32, u32) {
120 let z_stride = (x_cubes * y_cubes) as usize;
121 let z_pos = absolute_index / z_stride;
122 let xy_pos = absolute_index % z_stride;
123
124 let (x_pos, y_pos) = match comptime!(global_order) {
125 GlobalOrder::RowMajor => ((xy_pos / y_cubes as usize) as u32, xy_pos as u32 % y_cubes),
126 GlobalOrder::ColMajor => (xy_pos as u32 % x_cubes, (xy_pos / x_cubes as usize) as u32),
127 GlobalOrder::SwizzleRow(w) => {
128 let (x, y) = swizzle(xy_pos, y_cubes as usize, w);
129 (y, x)
130 }
131 GlobalOrder::SwizzleCol(w) => swizzle(xy_pos, x_cubes as usize, w),
132 };
133
134 (x_pos, y_pos, z_pos as u32)
135 }
136}
137
138pub fn cube_mapping_launch<R: Runtime>(cube_count_plan: &CubeCountPlan) -> CubeMappingLaunch<R> {
140 CubeMappingLaunch::new(
141 mapping_strategy(&cube_count_plan.kind),
142 cube_count_plan.kind.can_yield_extra_cubes(),
143 cube_count_plan.global_order,
144 )
145}
146
147fn mapping_strategy<R: Runtime>(
148 cube_count_plan_kind: &CubeCountPlanKind,
149) -> CubeMappingStrategyArgs<R> {
150 match cube_count_plan_kind {
151 CubeCountPlanKind::FromProblem { .. } => CubeMappingStrategyArgs::FromProblem,
152
153 CubeCountPlanKind::Sm {
154 cubes_first,
155 problem_count,
156 ..
157 } => {
158 if *cubes_first {
159 CubeMappingStrategyArgs::CubeFirst {
160 x_cubes: problem_count.x,
161 y_cubes: problem_count.y,
162 z_cubes: problem_count.z,
163 }
164 } else {
165 CubeMappingStrategyArgs::SmFirst {
166 x_cubes: problem_count.x,
167 y_cubes: problem_count.y,
168 z_cubes: problem_count.z,
169 }
170 }
171 }
172
173 CubeCountPlanKind::Flattened { problem_count, .. } => CubeMappingStrategyArgs::Flattened {
174 x_cubes: problem_count.x,
175 y_cubes: problem_count.y,
176 },
177
178 CubeCountPlanKind::Spread { problem_count, .. } => CubeMappingStrategyArgs::Spread {
179 x_cubes: problem_count.x,
180 y_cubes: problem_count.y,
181 z_cubes: problem_count.z,
182 },
183 }
184}