use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
/// Description of an XOR-based swizzle applied to buffer offsets.
///
/// Every field is marked `#[cube(comptime)]`. [`Swizzle::apply`] selects some
/// bits of the byte offset with `yyy_mask`, shifts them by `shift`, and XORs
/// them back into the offset.
///
/// NOTE(review): the naming (`yyy_mask`, bits/base/shift) resembles CuTe's
/// `Swizzle<BBits, MBase, SShift>` — presumably the same scheme; confirm.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct Swizzle {
/// Mask selecting the bits of the byte offset that are shifted and XORed
/// back in. A value of zero makes `apply` return its input unchanged.
#[cube(comptime)]
yyy_mask: u32,
/// Magnitude of the shift applied to the masked bits (always non-negative).
#[cube(comptime)]
shift: u32,
/// `true` when the shift passed to `new` was negative: the masked bits are
/// then shifted left instead of right.
#[cube(comptime)]
invert_shift: bool,
/// Value reported by `repeats_after()`; a power of two when built by `new`,
/// `1` when built by `none`.
#[cube(comptime)]
repeats_after: u32,
}
#[cube]
impl Swizzle {
/// Builds a swizzle from comptime parameters.
///
/// * `bits` - width, in bits, of the mask (`(1 << bits) - 1`).
/// * `base` - lowest bit position used by the pattern; the mask is placed at
///   bit `base + max(shift, 0)`.
/// * `shift` - signed distance between the masked bits and the bits they are
///   XORed into. A positive value shifts the masked bits right in `apply`, a
///   negative value shifts them left.
///
/// NOTE(review): `1u32 << bits` requires `bits < 32`, and the shift amounts
/// built from `base`, `bits` and `shift` must also stay below 32 — not
/// checked here.
pub fn new(#[comptime] bits: u32, #[comptime] base: u32, #[comptime] shift: i32) -> Self {
// The sign of `shift` only selects the direction; the magnitude is stored
// separately below via `unsigned_abs`.
let invert_shift = shift < 0;
// `bits` consecutive low bits set.
let mask = (1u32 << bits) - 1;
// Only a positive shift moves the mask above `base`; a negative shift
// leaves it at `base`.
let yyy_mask = comptime![mask << (base + Ord::max(shift, 0) as u32)];
// With a non-empty mask this is the first power of two above the mask's
// highest bit; with an empty mask it is `1 << base`.
let repeats_after = comptime![if bits > 0 {
1u32 << (base + bits + Ord::max(shift, 0) as u32)
} else {
1u32 << base
}];
Swizzle {
yyy_mask,
shift: comptime![shift.unsigned_abs()],
invert_shift,
repeats_after,
}
}
/// Builds a swizzle that does nothing: the mask is zero, so `apply` returns
/// its input unchanged, and `repeats_after` is `1`.
pub fn none() -> Self {
Swizzle {
yyy_mask: 0u32,
shift: 0u32,
invert_shift: false,
repeats_after: 1u32,
}
}
/// Applies the swizzle to `offset` and returns the swizzled offset.
///
/// `offset` is multiplied by `type_size` before the bit manipulation and the
/// result is divided by `type_size` again, so the pattern operates on the
/// scaled value (presumably a byte offset, with `type_size` the element size
/// in bytes — confirm against callers).
///
/// NOTE(review): `offset * type_size` is not guarded against `u32` overflow.
pub fn apply(&self, offset: u32, #[comptime] type_size: usize) -> u32 {
// Empty mask: nothing to swizzle, skip the arithmetic entirely.
if comptime![self.yyy_mask == 0] {
offset
} else {
let offset_bytes = offset * type_size as u32;
// Isolate the bits selected by the mask...
let offset_masked = offset_bytes & self.yyy_mask;
// ...move them by `shift` in the configured direction...
let offset_shifted =
shift_right(offset_masked, self.shift, comptime![self.invert_shift]);
// ...and XOR them into the scaled offset.
let offset_bytes = offset_bytes ^ offset_shifted;
// Convert back to the unit `offset` was given in.
offset_bytes / type_size as u32
}
}
/// Returns the comptime `repeats_after` value computed at construction.
///
/// Presumably the period after which the swizzle pattern repeats, usable for
/// buffer alignment — confirm units with callers.
pub fn repeats_after(&self) -> comptime_type!(u32) {
self.repeats_after
}
}
/// Shifts `value` right by `shift` bits, or left by `shift` bits when the
/// comptime flag `invert` is set.
///
/// Used by `Swizzle::apply`, where `invert` carries the sign of the shift
/// originally passed to `Swizzle::new`.
#[cube]
fn shift_right(value: u32, shift: u32, #[comptime] invert: bool) -> u32 {
if invert {
// Negative shift: move the bits towards the most significant end.
value << shift
} else {
value >> shift
}
}