Skip to main content

cubek_std/cube_count/hypercube/
global_order.rs

1use cubecl::{prelude::*, std::tensor::layout::Coords2d};
2
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// Describes the global traversal order as flattened cube position increases.
///
/// - `RowMajor`: standard row-first traversal (this is the `Default` value)
/// - `ColMajor`: standard column-first traversal
/// - `SwizzleCol(w)`: zigzag pattern down columns, with `w`-wide steps
/// - `SwizzleRow(w)`: zigzag pattern across rows, with `w`-wide steps
///
/// Special cases:
/// - `SwizzleCol(1)` is equivalent to `ColMajor`
/// - `SwizzleRow(1)` is equivalent to `RowMajor`
///
/// Swizzle modes may fail if their `w` does not divide the problem well.
// NOTE(review): the exact divisibility requirement on `w` is not visible in this
// file — confirm against the code that consumes `GlobalOrder`.
// NOTE(review): the current variant names do not seem to share a common
// prefix/suffix, so this `allow` may be a leftover from earlier `*Major`
// names — confirm with clippy before removing it.
#[allow(clippy::enum_variant_names)]
pub enum GlobalOrder {
    /// Standard row-first traversal; the value produced by `GlobalOrder::default()`.
    #[default]
    RowMajor,
    /// Standard column-first traversal.
    ColMajor,
    /// Zigzag pattern across rows; the payload is the step width `w`.
    SwizzleRow(u32),
    /// Zigzag pattern down columns; the payload is the step width `w`.
    SwizzleCol(u32),
}
24
25impl GlobalOrder {
26    /// Since they are equivalent but the latter form will skip some calculations,
27    /// - `SwizzleColMajor(1)` becomes `ColMajor`
28    /// - `SwizzleRowMajor(1)` becomes `RowMajor`
29    pub fn canonicalize(self) -> Self {
30        match self {
31            GlobalOrder::SwizzleCol(1) => GlobalOrder::ColMajor,
32            GlobalOrder::SwizzleRow(1) => GlobalOrder::RowMajor,
33            _ => self,
34        }
35    }
36}
37
38#[cube]
39/// Maps a linear `index` to 2D zigzag coordinates `(x, y)` within horizontal or vertical strips.
40///
41/// Each strip is made of `num_steps` steps, each of length `step_length`.
42/// Strips alternate direction: even strips go top-down, odd strips bottom-up.
43/// Steps alternate direction: even steps go left-to-right, odd steps right-to-left.
44///
45/// - Prefer **odd `num_steps`** for smoother transitions between strips.
46/// - Prefer **power-of-two `step_length`** for better performance.
47///
48/// # Parameters
49/// - `index`: linear input index
50/// - `num_steps`: number of snaking steps in a strip
51/// - `step_length`: number of elements in each step (must be > 0)
52///
53/// # Returns
54/// `(x, y)` coordinates after swizzling
55pub fn swizzle(index: usize, num_steps: usize, #[comptime] step_length: u32) -> Coords2d {
56    comptime!(assert!(step_length > 0));
57
58    let num_elements_per_strip = num_steps * step_length as usize;
59    let strip_index = (index / num_elements_per_strip) as u32;
60    let pos_in_strip = (index % num_elements_per_strip) as u32;
61    let strip_offset = step_length * strip_index;
62
63    // Indices without regards to direction
64    let abs_step_index = pos_in_strip / step_length;
65    let abs_pos_in_step = pos_in_strip % step_length;
66
67    // Top-down (0) or Bottom-up (1)
68    let strip_direction = strip_index % 2;
69    // Left-right (0) or Right-left (1)
70    let step_direction = abs_step_index % 2;
71
72    // Update indices with direction
73    let step_index = strip_direction * (num_steps as u32 - abs_step_index - 1)
74        + (1 - strip_direction) * abs_step_index;
75
76    let pos_in_step = if step_length & (step_length - 1) == 0 {
77        abs_pos_in_step ^ (step_direction * (step_length - 1))
78    } else {
79        step_direction * (step_length - abs_pos_in_step - 1)
80            + (1 - step_direction) * abs_pos_in_step
81    };
82
83    (step_index, pos_in_step + strip_offset)
84}