cubek_std/cube_count/hypercube/global_order.rs
1use cubecl::{prelude::*, std::tensor::layout::Coords2d};
2
/// Order in which the global problem is visited as the flattened cube position grows.
///
/// Available orders:
/// - `RowMajor`: standard row-first traversal
/// - `ColMajor`: standard column-first traversal
/// - `SwizzleRow(w)`: zigzag pattern across rows, advancing in steps that are `w` wide
/// - `SwizzleCol(w)`: zigzag pattern down columns, advancing in steps that are `w` wide
///
/// Degenerate widths:
/// - `SwizzleRow(1)` is equivalent to `RowMajor`
/// - `SwizzleCol(1)` is equivalent to `ColMajor`
///
/// A swizzled order may fail when its `w` does not divide the problem well.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
// NOTE(review): the reason for silencing this lint is not visible here; it may be a
// leftover from earlier variant names -- confirm it is still required.
#[allow(clippy::enum_variant_names)]
pub enum GlobalOrder {
    /// Row-first traversal; this is the default order.
    #[default]
    RowMajor,
    /// Column-first traversal.
    ColMajor,
    /// Zigzag across rows; the payload is the step width `w`.
    SwizzleRow(u32),
    /// Zigzag down columns; the payload is the step width `w`.
    SwizzleCol(u32),
}
24
25impl GlobalOrder {
26 /// Since they are equivalent but the latter form will skip some calculations,
27 /// - `SwizzleColMajor(1)` becomes `ColMajor`
28 /// - `SwizzleRowMajor(1)` becomes `RowMajor`
29 pub fn canonicalize(self) -> Self {
30 match self {
31 GlobalOrder::SwizzleCol(1) => GlobalOrder::ColMajor,
32 GlobalOrder::SwizzleRow(1) => GlobalOrder::RowMajor,
33 _ => self,
34 }
35 }
36}
37
38#[cube]
39/// Maps a linear `index` to 2D zigzag coordinates `(x, y)` within horizontal or vertical strips.
40///
41/// Each strip is made of `num_steps` steps, each of length `step_length`.
42/// Strips alternate direction: even strips go top-down, odd strips bottom-up.
43/// Steps alternate direction: even steps go left-to-right, odd steps right-to-left.
44///
45/// - Prefer **odd `num_steps`** for smoother transitions between strips.
46/// - Prefer **power-of-two `step_length`** for better performance.
47///
48/// # Parameters
49/// - `index`: linear input index
50/// - `num_steps`: number of snaking steps in a strip
51/// - `step_length`: number of elements in each step (must be > 0)
52///
53/// # Returns
54/// `(x, y)` coordinates after swizzling
55pub fn swizzle(index: usize, num_steps: usize, #[comptime] step_length: u32) -> Coords2d {
56 comptime!(assert!(step_length > 0));
57
58 let num_elements_per_strip = num_steps * step_length as usize;
59 let strip_index = (index / num_elements_per_strip) as u32;
60 let pos_in_strip = (index % num_elements_per_strip) as u32;
61 let strip_offset = step_length * strip_index;
62
63 // Indices without regards to direction
64 let abs_step_index = pos_in_strip / step_length;
65 let abs_pos_in_step = pos_in_strip % step_length;
66
67 // Top-down (0) or Bottom-up (1)
68 let strip_direction = strip_index % 2;
69 // Left-right (0) or Right-left (1)
70 let step_direction = abs_step_index % 2;
71
72 // Update indices with direction
73 let step_index = strip_direction * (num_steps as u32 - abs_step_index - 1)
74 + (1 - strip_direction) * abs_step_index;
75
76 let pos_in_step = if step_length & (step_length - 1) == 0 {
77 abs_pos_in_step ^ (step_direction * (step_length - 1))
78 } else {
79 step_direction * (step_length - abs_pos_in_step - 1)
80 + (1 - step_direction) * abs_pos_in_step
81 };
82
83 (step_index, pos_in_step + strip_offset)
84}