use core::marker::PhantomData;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, ir::VectorSize, unexpanded};
#[derive(CubeType)]
pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
slice: Slice<S>,
#[cube(comptime)]
vector_size: VectorSize,
#[cube(comptime)]
load_many: Option<usize>,
#[cube(comptime)]
_phantom: PhantomData<T>,
}
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
pub fn new(slice: Slice<S>) -> ReinterpretSlice<S, T> {
let in_vector_size = slice.vector_size();
let source_size = S::Scalar::type_size();
let target_size = T::Scalar::type_size();
let (optimized_vector_size, load_many) = comptime!(optimize_vector_size(
source_size,
in_vector_size,
target_size
));
match comptime!(optimized_vector_size) {
Some(vector_size) => {
let size!(N2) = vector_size;
let slice = slice.into_vectorized().with_vector_size::<N2>();
ReinterpretSlice::<S, T> {
slice: unsafe { slice.downcast_unchecked() },
vector_size,
load_many,
_phantom: PhantomData,
}
}
None => ReinterpretSlice::<S, T> {
slice,
vector_size: in_vector_size,
load_many,
_phantom: PhantomData,
},
}
}
pub fn read(&self, index: usize) -> T {
let size!(N) = self.vector_size;
let slice = self.slice.into_vectorized().with_vector_size::<N>();
match comptime!(self.load_many) {
Some(amount) => {
let first = index * amount;
let size!(N2) = comptime!(amount * self.vector_size);
let mut vector = Vector::<S::Scalar, N2>::empty();
#[unroll]
for k in 0..amount {
let elem = slice[first + k];
#[unroll]
for j in 0..self.vector_size {
vector[k * self.vector_size + j] = elem[j];
}
}
T::reinterpret(vector)
}
None => T::reinterpret(slice[index]),
}
}
}
#[derive(CubeType)]
pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
slice: SliceMut<S>,
#[cube(comptime)]
vector_size: VectorSize,
#[cube(comptime)]
load_many: Option<usize>,
#[cube(comptime)]
_phantom: PhantomData<T>,
}
#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
pub fn new(slice: SliceMut<S>) -> ReinterpretSliceMut<S, T> {
let in_vector_size = slice.vector_size();
let source_size = S::Scalar::type_size();
let target_size = T::Scalar::type_size();
let (optimized_vector_size, load_many) = comptime!(optimize_vector_size(
source_size,
in_vector_size,
target_size
));
match comptime!(optimized_vector_size) {
Some(vector_size) => {
let size!(N2) = vector_size;
let slice = slice.into_vectorized().with_vector_size::<N2>();
ReinterpretSliceMut::<S, T> {
slice: unsafe { slice.downcast_unchecked() },
vector_size,
load_many,
_phantom: PhantomData,
}
}
None => ReinterpretSliceMut::<S, T> {
slice,
vector_size: in_vector_size,
load_many,
_phantom: PhantomData,
},
}
}
pub fn read(&self, index: usize) -> T {
let size!(N) = self.vector_size;
let slice = self.slice.into_vectorized().with_vector_size::<N>();
match comptime!(self.load_many) {
Some(amount) => {
let first = index * amount;
let size!(N2) = comptime!(amount * self.vector_size);
let mut vector = Vector::<S::Scalar, N2>::empty();
#[unroll]
for k in 0..amount {
let elem = slice[first + k];
#[unroll]
for j in 0..self.vector_size {
vector[k * self.vector_size + j] = elem[j];
}
}
T::reinterpret(vector)
}
None => T::reinterpret(slice[index]),
}
}
pub fn write(&mut self, index: usize, value: T) {
let size!(N) = self.vector_size;
let mut slice = self.slice.into_vectorized().with_vector_size::<N>();
let size!(N1) = S::reinterpret_vectorization::<T>();
let reinterpreted = Vector::<S::Scalar, N1>::reinterpret(value);
match comptime!(self.load_many) {
Some(amount) => {
let first = index * amount;
let vector_size = comptime!(reinterpreted.size() / amount);
#[unroll]
for k in 0..amount {
let mut vector = Vector::empty();
#[unroll]
for j in 0..vector_size {
vector[j] = reinterpreted[k * vector_size + j];
}
slice[first + k] = vector;
}
}
None => slice[index] = Vector::cast_from(reinterpreted),
}
}
}
fn optimize_vector_size(
source_size: usize,
vector_size: VectorSize,
target_size: usize,
) -> (Option<usize>, Option<usize>) {
let vector_source_size = source_size * vector_size;
match vector_source_size.cmp(&target_size) {
core::cmp::Ordering::Less => {
if !target_size.is_multiple_of(vector_source_size) {
panic!("incompatible number of bytes");
}
let ratio = target_size / vector_source_size;
(None, Some(ratio))
}
core::cmp::Ordering::Greater => {
if !vector_source_size.is_multiple_of(target_size) {
panic!("incompatible number of bytes");
}
let ratio = vector_source_size / target_size;
(Some(vector_size / ratio), None)
}
core::cmp::Ordering::Equal => (None, None),
}
}
pub fn size_of<S: CubePrimitive>() -> u32 {
unexpanded!()
}
pub mod size_of {
use super::*;
#[allow(unused, clippy::all)]
pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
S::as_type(context).size() as u32
}
}