use cubecl::{prelude::*, std::tensor::layout::Coords2d};
/// Selects one of four orderings: plain row-major, plain column-major, or a
/// swizzled variant of either.
///
/// NOTE(review): the exact meaning of each ordering is decided by the code that
/// consumes this enum, which is not visible here. The `u32` payload of the
/// swizzle variants is presumably the swizzle step length — confirm against
/// the call sites.
// NOTE(review): the reason for silencing `enum_variant_names` is not recorded
// in this file; confirm it is still needed before removing it.
#[allow(clippy::enum_variant_names)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GlobalOrder {
    /// Plain row-major ordering. This is the value returned by `Default`.
    #[default]
    RowMajor,
    /// Plain column-major ordering.
    ColMajor,
    /// Swizzled counterpart of [`GlobalOrder::RowMajor`]; a payload of `1` is
    /// collapsed to `RowMajor` by `canonicalize`.
    SwizzleRow(u32),
    /// Swizzled counterpart of [`GlobalOrder::ColMajor`]; a payload of `1` is
    /// collapsed to `ColMajor` by `canonicalize`.
    SwizzleCol(u32),
}
impl GlobalOrder {
pub fn canonicalize(self) -> Self {
match self {
GlobalOrder::SwizzleCol(1) => GlobalOrder::ColMajor,
GlobalOrder::SwizzleRow(1) => GlobalOrder::RowMajor,
_ => self,
}
}
}
/// Maps a linear `index` to a 2D coordinate by walking a zig-zag path.
///
/// The index space is cut into consecutive *strips* of
/// `num_steps * step_length` elements. Each strip is cut into `num_steps`
/// *steps* of `step_length` elements each.
///
/// - The first returned component is the step within the strip, in
///   `0..num_steps`. Even-numbered strips count it upwards, odd-numbered
///   strips count it downwards.
/// - The second returned component is the position within the step plus
///   `step_length * strip_index`. Steps alternate direction too: steps with an
///   even (pre-reversal) index run forwards, odd ones run backwards.
///
/// With an odd `num_steps`, consecutive indices always land on neighbouring
/// coordinates, including across strip boundaries. With an even `num_steps`
/// the second component jumps by `step_length` at each strip boundary.
///
/// # Parameters
/// - `index`: linear index to convert.
/// - `num_steps`: number of steps per strip. It is used as a divisor (through
///   `num_elements_per_strip`) and is not checked here, so it must be non-zero.
/// - `step_length`: number of elements per step; `#[comptime]`, asserted to be
///   greater than zero.
///
/// # Returns
/// The tuple `(step_index, pos_in_step + strip_offset)`.
///
/// NOTE(review): which component is the row and which is the column is decided
/// by the caller — confirm against the call sites. Intermediate values are
/// narrowed with `as u32`; this assumes they fit in 32 bits.
#[cube]
pub fn swizzle(index: usize, num_steps: usize, #[comptime] step_length: u32) -> Coords2d {
// A zero step length would make every division/modulo below divide by zero.
comptime!(assert!(step_length > 0));
// Split the linear index into (which strip, offset inside that strip).
let num_elements_per_strip = num_steps * step_length as usize;
let strip_index = (index / num_elements_per_strip) as u32;
let pos_in_strip = (index % num_elements_per_strip) as u32;
// Each strip covers its own `step_length`-wide band of the second coordinate.
let strip_offset = step_length * strip_index;
// Step and in-step position before any direction reversal is applied.
let abs_step_index = pos_in_strip / step_length;
let abs_pos_in_step = pos_in_strip % step_length;
// 0 = keep order, 1 = reverse order. Note that the step direction is taken
// from the un-reversed step index.
let strip_direction = strip_index % 2;
let step_direction = abs_step_index % 2;
// Branch-free select: reversed step index on odd strips, unchanged on even ones.
let step_index = strip_direction * (num_steps as u32 - abs_step_index - 1)
+ (1 - strip_direction) * abs_step_index;
// `step_length` is a power of two exactly when it has a single bit set. In
// that case `x ^ (step_length - 1)` equals `step_length - 1 - x` for any
// `x < step_length`, so the reversal can be done with one XOR; the mask is 0
// when the step is not reversed. Otherwise fall back to the same arithmetic
// select used for `step_index`.
let pos_in_step = if step_length & (step_length - 1) == 0 {
abs_pos_in_step ^ (step_direction * (step_length - 1))
} else {
step_direction * (step_length - abs_pos_in_step - 1)
+ (1 - step_direction) * abs_pos_in_step
};
(step_index, pos_in_step + strip_offset)
}