use core::mem;
/// Reads a `T` located `offset` bytes past the start of `buffer`.
///
/// The body reinterprets the slice as raw bytes, advances by `offset` bytes and reads a `T`
/// from that address. No bounds or alignment validation happens here.
///
/// NOTE(review): the `#[spirv(buffer_load_intrinsic)]` attribute suggests the SPIR-V backend
/// treats this function specially — confirm against the backend before relying on the body.
///
/// # Safety
///
/// The caller must guarantee that `offset..offset + size_of::<T>()` lies within `buffer`
/// (measured in bytes) and that the bytes found there form a valid `T`.
#[spirv(buffer_load_intrinsic)]
// NOTE(review): presumably kept out of line so calls stay recognisable — confirm.
#[inline(never)]
#[spirv_std_macros::gpu_only]
unsafe fn buffer_load_intrinsic<T>(buffer: &[u32], offset: u32) -> T {
    let base = buffer.as_ptr().cast::<u8>();
    unsafe {
        let src = base.add(offset as usize).cast::<T>();
        src.read()
    }
}
/// Writes `value` into `buffer`, `offset` bytes past its start.
///
/// The body reinterprets the slice as raw bytes, advances by `offset` bytes and writes `value`
/// at that address. No bounds or alignment validation happens here.
///
/// NOTE(review): the `#[spirv(buffer_store_intrinsic)]` attribute suggests the SPIR-V backend
/// treats this function specially — confirm against the backend before relying on the body.
///
/// # Safety
///
/// The caller must guarantee that `offset..offset + size_of::<T>()` lies within `buffer`
/// (measured in bytes).
#[spirv(buffer_store_intrinsic)]
// NOTE(review): presumably kept out of line so calls stay recognisable — confirm.
#[inline(never)]
#[spirv_std_macros::gpu_only]
unsafe fn buffer_store_intrinsic<T>(buffer: &mut [u32], offset: u32, value: T) {
    let base = buffer.as_mut_ptr().cast::<u8>();
    unsafe {
        let dst = base.add(offset as usize).cast::<T>();
        dst.write(value);
    }
}
/// Wrapper around a buffer that exposes loads and stores addressed by byte offset.
///
/// The methods in this file are implemented for `T = &[u32]` (read-only access) and
/// `T = &mut [u32]` (read/write access). `#[repr(transparent)]` gives the wrapper the same
/// layout as its single field.
#[repr(transparent)]
pub struct ByteAddressableBuffer<T> {
/// The underlying buffer the byte offsets refer to.
pub data: T,
}
/// Validates that an access of `size_of::<T>()` bytes at `byte_index` stays inside `data`.
///
/// The buffer length in bytes is computed as `data.len() * 4`, one `u32` being four bytes.
///
/// # Panics
///
/// - if `byte_index` is not a multiple of 4;
/// - if the `size_of::<T>()` bytes starting at `byte_index` would extend past the end of
///   `data`, including the case where `byte_index + size_of::<T>()` does not fit in a `u32`.
fn bounds_check<T>(data: &[u32], byte_index: u32) {
    let sizeof = mem::size_of::<T>() as u32;
    if !byte_index.is_multiple_of(4) {
        panic!("`byte_index` should be a multiple of 4");
    }
    let len = data.len() as u32 * 4;
    // Overflow-free form of `byte_index + sizeof > len`. Adding the two `u32`s directly can
    // wrap around for large indices, which would let an out-of-range access pass the check
    // (or, with overflow checks enabled, panic with an unrelated arithmetic message).
    if sizeof > len || byte_index > len - sizeof {
        // Only used for the diagnostic: clamp to `u32::MAX` instead of wrapping so the
        // reported end position is never smaller than `byte_index`.
        let last_byte = if byte_index > u32::MAX - sizeof {
            u32::MAX
        } else {
            byte_index + sizeof
        };
        panic!(
            "index out of bounds: the len is {} but loading {} bytes at `byte_index` {} reads until {} (exclusive)",
            len, sizeof, byte_index, last_byte,
        );
    }
}
impl<'a> ByteAddressableBuffer<&'a [u32]> {
#[inline]
pub fn from_slice(data: &'a [u32]) -> Self {
Self { data }
}
pub unsafe fn load<T>(&self, byte_index: u32) -> T {
bounds_check::<T>(self.data, byte_index);
unsafe { buffer_load_intrinsic(self.data, byte_index) }
}
pub unsafe fn load_unchecked<T>(&self, byte_index: u32) -> T {
unsafe { buffer_load_intrinsic(self.data, byte_index) }
}
}
impl<'a> ByteAddressableBuffer<&'a mut [u32]> {
    /// Wraps a mutable `u32` slice so it can be read and written by byte offset.
    #[inline]
    pub fn from_mut_slice(data: &'a mut [u32]) -> Self {
        ByteAddressableBuffer { data }
    }

    /// Borrows this buffer as a read-only `ByteAddressableBuffer` over the same slice.
    #[inline]
    pub fn as_ref(&self) -> ByteAddressableBuffer<&[u32]> {
        let words: &[u32] = &*self.data;
        ByteAddressableBuffer { data: words }
    }

    /// Loads a `T` starting at `byte_index` by delegating to the read-only view's `load`.
    ///
    /// # Panics
    ///
    /// Panics if `byte_index` is not a multiple of 4, or if reading `size_of::<T>()` bytes
    /// from `byte_index` would run past the end of the buffer.
    ///
    /// # Safety
    ///
    /// The bytes at `byte_index` are reinterpreted as a `T`; the caller must ensure they
    /// form a valid value of that type.
    #[inline]
    pub unsafe fn load<T>(&self, byte_index: u32) -> T {
        let view = self.as_ref();
        unsafe { view.load::<T>(byte_index) }
    }

    /// Loads a `T` starting at `byte_index` by delegating to the read-only view's
    /// `load_unchecked`; the offset is not validated.
    ///
    /// # Safety
    ///
    /// In addition to the requirements of `load`, the caller must ensure that the
    /// `size_of::<T>()` bytes starting at `byte_index` lie inside the buffer.
    #[inline]
    pub unsafe fn load_unchecked<T>(&self, byte_index: u32) -> T {
        let view = self.as_ref();
        unsafe { view.load_unchecked::<T>(byte_index) }
    }

    /// Stores `value` starting at `byte_index`, after validating the access with
    /// `bounds_check`.
    ///
    /// # Panics
    ///
    /// Panics if `byte_index` is not a multiple of 4, or if writing `size_of::<T>()` bytes
    /// at `byte_index` would run past the end of the buffer.
    ///
    /// # Safety
    ///
    /// Declared `unsafe`; the write itself is performed by `buffer_store_intrinsic`, whose
    /// requirements the caller takes on.
    pub unsafe fn store<T>(&mut self, byte_index: u32, value: T) {
        bounds_check::<T>(&*self.data, byte_index);
        unsafe { buffer_store_intrinsic::<T>(&mut *self.data, byte_index, value) }
    }

    /// Stores `value` starting at `byte_index` without any validation of the offset.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the `size_of::<T>()` bytes starting at `byte_index` lie
    /// inside the buffer, because no check is performed here.
    pub unsafe fn store_unchecked<T>(&mut self, byte_index: u32, value: T) {
        unsafe { buffer_store_intrinsic::<T>(&mut *self.data, byte_index, value) }
    }
}