use cubecl_core::intrinsic;
use cubecl_core::ir::{IndexAssignOperator, IndexOperator, Instruction, Operator};
use cubecl_core::{self as cubecl, prelude::*};
/// Containers from which a `Vector<E, N>` can be read, and into which one can
/// be written, at a given index.
///
/// NOTE(review): the free functions backing the implementations in this file
/// only accept containers whose IR type is scalar; `index` is presumably
/// expressed in scalar elements rather than in whole vectors — confirm.
#[cube]
pub trait UnalignedVector<E: Scalar, N: Size>: CubeType + Sized {
/// Reads a `Vector<E, N>` from `self` at `index`.
fn unaligned_vector_read(&self, index: usize) -> Vector<E, N>;
/// Writes `value` into `self` at `index`.
fn unaligned_vector_write(&mut self, index: usize, value: Vector<E, N>);
}
// Generates, for a single-parameter container type `$container<E>`:
//   * a `<Container>Expand<E>` alias for `NativeExpand<$container<E>>`, and
//   * an `UnalignedVector<E, N>` impl that forwards both trait methods to the
//     free functions `unaligned_vector_read` / `unaligned_vector_write`.
macro_rules! impl_unaligned_vector {
    ($container:ident) => {
        paste::paste! {
            type [<$container Expand>]<E> = NativeExpand<$container<E>>;
        }

        #[cube]
        impl<E: Scalar, N: Size> UnalignedVector<E, N> for $container<E> {
            fn unaligned_vector_read(&self, index: usize) -> Vector<E, N> {
                unaligned_vector_read::<$container<E>, E, N>(self, index)
            }

            fn unaligned_vector_write(&mut self, index: usize, value: Vector<E, N>) {
                unaligned_vector_write::<$container<E>, E, N>(self, index, value)
            }
        }
    };
}
// Implement `UnalignedVector` for each of these container types.
impl_unaligned_vector!(Array);
impl_unaligned_vector!(Tensor);
impl_unaligned_vector!(SharedMemory);
/// Registers an `UncheckedIndex` instruction that loads a `Vector<E, N>` from
/// `this` at `index` and returns the local holding the result.
///
/// Expansion hits `todo!` unless the IR type of `this` is `Type::Scalar`.
#[cube]
#[allow(unused_variables)]
fn unaligned_vector_read<T: CubeType<ExpandType = NativeExpand<T>>, E: Scalar, N: Size>(
    this: &T,
    index: usize,
) -> Vector<E, N> {
    intrinsic!(|scope| {
        // Only scalar-typed lists are accepted as the source of the read.
        match this.expand.ty {
            cubecl::ir::Type::Scalar(_) => {}
            _ => todo!("Unaligned reads are only allowed on scalar arrays for now"),
        }

        // The destination local uses the list's type with its vector size set
        // to the value of `N`.
        let width = N::__expand_value(scope);
        let destination = scope.create_local(this.expand.ty.with_vector_size(width));

        let load = Operator::UncheckedIndex(IndexOperator {
            list: *this.expand,
            index: index.expand.consume(),
            vector_size: 0,
            unroll_factor: 1,
        });
        scope.register(Instruction::new(load, *destination));

        destination.into()
    })
}
/// Registers an `UncheckedIndexAssign` instruction that stores `value` into
/// `this` at `index`.
///
/// Expansion hits `todo!` unless the IR type of `this` is `Type::Scalar`.
#[cube]
#[allow(unused_variables)]
fn unaligned_vector_write<T: CubeType<ExpandType = NativeExpand<T>>, E: Scalar, N: Size>(
    this: &mut T,
    index: usize,
    value: Vector<E, N>,
) {
    intrinsic!(|scope| {
        // Only scalar-typed lists are accepted as the target of the write.
        if !matches!(this.expand.ty, cubecl::ir::Type::Scalar(_)) {
            // The message names the write path (it previously said "reads").
            todo!("Unaligned writes are only allowed on scalar arrays for now");
        }
        // NOTE(review): evaluated but not passed to the operator below; kept
        // in case `__expand_value` has effects on `scope` — confirm.
        let vector_size = N::__expand_value(scope);
        scope.register(Instruction::new(
            Operator::UncheckedIndexAssign(IndexAssignOperator {
                index: index.expand.consume(),
                value: value.expand.consume(),
                vector_size: 0,
                unroll_factor: 1,
            }),
            // The list being written to is the instruction's output variable.
            *this.expand,
        ));
    })
}