cubek-std 0.2.0-pre.5

CubeK: Standard Library
Documentation
use cubecl::{prelude::*, std::tensor::layout::Coords2d};

#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// Describes the global traversal order as flattened cube position increases.
///
/// - `RowMajor`: standard row-first traversal
/// - `ColMajor`: standard column-first traversal
/// - `SwizzleCol(w)`: zigzag pattern down columns, with `w`-wide steps
/// - `SwizzleRow(w)`: zigzag pattern across rows, with `w`-wide steps
///
/// Special cases:
/// - `SwizzleCol(1)` is equivalent to `ColMajor`
/// - `SwizzleRow(1)` is equivalent to `RowMajor`
///
/// Swizzle modes may fail if their `w` does not divide the problem well.
#[allow(clippy::enum_variant_names)]
pub enum GlobalOrder {
    #[default]
    RowMajor,
    ColMajor,
    SwizzleRow(u32),
    SwizzleCol(u32),
}

impl GlobalOrder {
    /// Since they are equivalent but the latter form will skip some calculations,
    /// - `SwizzleColMajor(1)` becomes `ColMajor`
    /// - `SwizzleRowMajor(1)` becomes `RowMajor`
    pub fn canonicalize(self) -> Self {
        match self {
            GlobalOrder::SwizzleCol(1) => GlobalOrder::ColMajor,
            GlobalOrder::SwizzleRow(1) => GlobalOrder::RowMajor,
            _ => self,
        }
    }
}

#[cube]
/// Maps a linear `index` to 2D zigzag coordinates `(x, y)` within horizontal or vertical strips.
///
/// Each strip is made of `num_steps` steps, each of length `step_length`.
/// Strips alternate direction: even strips go top-down, odd strips bottom-up.
/// Steps alternate direction: even steps go left-to-right, odd steps right-to-left.
///
/// - Prefer **odd `num_steps`** for smoother transitions between strips.
/// - Prefer **power-of-two `step_length`** for better performance.
///
/// # Parameters
/// - `index`: linear input index
/// - `num_steps`: number of snaking steps in a strip
/// - `step_length`: number of elements in each step (must be > 0)
///
/// # Returns
/// `(x, y)` coordinates after swizzling
pub fn swizzle(index: usize, num_steps: usize, #[comptime] step_length: u32) -> Coords2d {
    comptime!(assert!(step_length > 0));

    let num_elements_per_strip = num_steps * step_length as usize;
    let strip_index = (index / num_elements_per_strip) as u32;
    let pos_in_strip = (index % num_elements_per_strip) as u32;
    let strip_offset = step_length * strip_index;

    // Indices without regards to direction
    let abs_step_index = pos_in_strip / step_length;
    let abs_pos_in_step = pos_in_strip % step_length;

    // Top-down (0) or Bottom-up (1)
    let strip_direction = strip_index % 2;
    // Left-right (0) or Right-left (1)
    let step_direction = abs_step_index % 2;

    // Update indices with direction
    let step_index = strip_direction * (num_steps as u32 - abs_step_index - 1)
        + (1 - strip_direction) * abs_step_index;

    let pos_in_step = if step_length & (step_length - 1) == 0 {
        abs_pos_in_step ^ (step_direction * (step_length - 1))
    } else {
        step_direction * (step_length - abs_pos_in_step - 1)
            + (1 - step_direction) * abs_pos_in_step
    };

    (step_index, pos_in_step + strip_offset)
}