Skip to main content

cubek_std/cube_count/hypercube/
cube_mapping.rs

1use cubecl::prelude::*;
2
3use crate::cube_count::{CubeCountPlan, CubeCountPlanKind, GlobalOrder, swizzle};
4
5#[derive(CubeType, CubeLaunch)]
6/// Runtime-side counterpart of [CubeCountPlan]: given the cube position,
7/// resolves the conceptual `(x, y, z)` coordinates in problem space.
8///
9/// Each operation is responsible for mapping the returned generic `(x, y, z)`
10/// tuple to its own domain axes (e.g. matmul interprets them as `(m, n, batch)`,
11/// gemv as `(matrix_axis, _, batch)`, attention as `(seq_q, batch_heads, _)`).
12pub struct CubeMapping {
13    strategy: CubeMappingStrategy,
14    #[cube(comptime)]
15    pub can_yield_extra_cubes: bool,
16    #[cube(comptime)]
17    global_order: GlobalOrder,
18}
19
20#[derive(CubeType, CubeLaunch)]
21/// [CubeCountPlanKind] stripped of non-essential runtime information.
22///
23/// Given as runtime input to kernels.
24#[allow(unused)] // Constructed via CubeMappingStrategyArgs only
25pub(crate) enum CubeMappingStrategy {
26    FromProblem,
27    SmFirst {
28        x_cubes: u32,
29        y_cubes: u32,
30        z_cubes: u32,
31    },
32    CubeFirst {
33        x_cubes: u32,
34        y_cubes: u32,
35        z_cubes: u32,
36    },
37    Flattened {
38        x_cubes: u32,
39        y_cubes: u32,
40    },
41    Spread {
42        x_cubes: u32,
43        y_cubes: u32,
44        z_cubes: u32,
45    },
46}
47
48#[cube]
49impl CubeMapping {
50    /// Returns the number of valid cubes (problem-space volume).
51    pub fn num_valid_cubes(&self) -> usize {
52        match &self.strategy {
53            CubeMappingStrategy::FromProblem | CubeMappingStrategy::Flattened { .. } => {
54                panic!("Shouldn't need to be called because the cube count should always be exact")
55            }
56            CubeMappingStrategy::SmFirst {
57                x_cubes,
58                y_cubes,
59                z_cubes,
60            }
61            | CubeMappingStrategy::CubeFirst {
62                x_cubes,
63                y_cubes,
64                z_cubes,
65            }
66            | CubeMappingStrategy::Spread {
67                x_cubes,
68                y_cubes,
69                z_cubes,
70            } => *x_cubes as usize * *y_cubes as usize * *z_cubes as usize,
71        }
72    }
73
74    /// Given a cube position, returns the generic problem-space coordinates `(x, y, z)`.
75    ///
76    /// Consumers assign meaning to `x/y/z` (matmul: `m/n/batch`, gemv: `matrix/_/batch`, etc.).
77    pub fn cube_pos_to_xyz(&self) -> (u32, u32, u32) {
78        match &self.strategy {
79            CubeMappingStrategy::FromProblem => (CUBE_POS_X, CUBE_POS_Y, CUBE_POS_Z),
80
81            CubeMappingStrategy::SmFirst {
82                x_cubes, y_cubes, ..
83            } => {
84                self.strategy
85                    .absolute_index_to_xyz(CUBE_POS, *x_cubes, *y_cubes, self.global_order)
86            }
87
88            CubeMappingStrategy::CubeFirst {
89                x_cubes, y_cubes, ..
90            } => self.strategy.absolute_index_to_xyz(
91                CUBE_POS_Y as usize * CUBE_COUNT_X as usize + CUBE_POS_X as usize,
92                *x_cubes,
93                *y_cubes,
94                self.global_order,
95            ),
96
97            CubeMappingStrategy::Flattened { x_cubes, y_cubes } => self
98                .strategy
99                .absolute_index_to_xyz(CUBE_POS_X as usize, *x_cubes, *y_cubes, self.global_order),
100
101            CubeMappingStrategy::Spread {
102                x_cubes, y_cubes, ..
103            } => {
104                self.strategy
105                    .absolute_index_to_xyz(CUBE_POS, *x_cubes, *y_cubes, self.global_order)
106            }
107        }
108    }
109}
110
111#[cube]
112impl CubeMappingStrategy {
113    fn absolute_index_to_xyz(
114        &self,
115        absolute_index: usize,
116        x_cubes: u32,
117        y_cubes: u32,
118        #[comptime] global_order: GlobalOrder,
119    ) -> (u32, u32, u32) {
120        let z_stride = (x_cubes * y_cubes) as usize;
121        let z_pos = absolute_index / z_stride;
122        let xy_pos = absolute_index % z_stride;
123
124        let (x_pos, y_pos) = match comptime!(global_order) {
125            GlobalOrder::RowMajor => ((xy_pos / y_cubes as usize) as u32, xy_pos as u32 % y_cubes),
126            GlobalOrder::ColMajor => (xy_pos as u32 % x_cubes, (xy_pos / x_cubes as usize) as u32),
127            GlobalOrder::SwizzleRow(w) => {
128                let (x, y) = swizzle(xy_pos, y_cubes as usize, w);
129                (y, x)
130            }
131            GlobalOrder::SwizzleCol(w) => swizzle(xy_pos, x_cubes as usize, w),
132        };
133
134        (x_pos, y_pos, z_pos as u32)
135    }
136}
137
138/// Build a [CubeMappingLaunch] from a resolved [CubeCountPlan].
139pub fn cube_mapping_launch<R: Runtime>(cube_count_plan: &CubeCountPlan) -> CubeMappingLaunch<R> {
140    CubeMappingLaunch::new(
141        mapping_strategy(&cube_count_plan.kind),
142        cube_count_plan.kind.can_yield_extra_cubes(),
143        cube_count_plan.global_order,
144    )
145}
146
147fn mapping_strategy<R: Runtime>(
148    cube_count_plan_kind: &CubeCountPlanKind,
149) -> CubeMappingStrategyArgs<R> {
150    match cube_count_plan_kind {
151        CubeCountPlanKind::FromProblem { .. } => CubeMappingStrategyArgs::FromProblem,
152
153        CubeCountPlanKind::Sm {
154            cubes_first,
155            problem_count,
156            ..
157        } => {
158            if *cubes_first {
159                CubeMappingStrategyArgs::CubeFirst {
160                    x_cubes: problem_count.x,
161                    y_cubes: problem_count.y,
162                    z_cubes: problem_count.z,
163                }
164            } else {
165                CubeMappingStrategyArgs::SmFirst {
166                    x_cubes: problem_count.x,
167                    y_cubes: problem_count.y,
168                    z_cubes: problem_count.z,
169                }
170            }
171        }
172
173        CubeCountPlanKind::Flattened { problem_count, .. } => CubeMappingStrategyArgs::Flattened {
174            x_cubes: problem_count.x,
175            y_cubes: problem_count.y,
176        },
177
178        CubeCountPlanKind::Spread { problem_count, .. } => CubeMappingStrategyArgs::Spread {
179            x_cubes: problem_count.x,
180            y_cubes: problem_count.y,
181            z_cubes: problem_count.z,
182        },
183    }
184}