use cudarc::driver::CudaSlice;
use crate::device::KaioDevice;
use crate::error::Result;
/// Newtype wrapper that owns a single cudarc [`CudaSlice<T>`].
///
/// `#[repr(transparent)]` gives this struct the same layout and ABI as its
/// only field, so a `GpuBuffer<T>` is laid out exactly like the
/// `CudaSlice<T>` it wraps.
// NOTE(review): presumably the transparent layout is relied on elsewhere to
// reinterpret between the two types — confirm against callers before
// adding any further fields here.
#[repr(transparent)]
pub struct GpuBuffer<T> {
// The wrapped slice; every method on `GpuBuffer` delegates to it.
inner: CudaSlice<T>,
}
impl<T> GpuBuffer<T> {
pub fn from_cuda_slice(inner: CudaSlice<T>) -> Self {
Self { inner }
}
pub fn into_cuda_slice(self) -> CudaSlice<T> {
self.inner
}
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn is_empty(&self) -> bool {
self.inner.len() == 0
}
pub fn inner(&self) -> &CudaSlice<T> {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut CudaSlice<T> {
&mut self.inner
}
}
/// Host read-back, available for element types that satisfy the bounds below.
impl<T> GpuBuffer<T>
where
    T: cudarc::driver::DeviceRepr + Default + Clone + Unpin,
{
    /// Reads the buffer back as a host `Vec<T>`.
    ///
    /// The copy is performed by calling `clone_dtoh` on the stream obtained
    /// from `device`.
    ///
    /// # Errors
    ///
    /// Any error returned by `clone_dtoh` is converted into this crate's
    /// error type and propagated to the caller.
    pub fn to_host(&self, device: &KaioDevice) -> Result<Vec<T>> {
        let stream = device.stream();
        let host_copy = stream.clone_dtoh(&self.inner)?;
        Ok(host_copy)
    }
}
/// Compile-time checks that `GpuBuffer<T>` has the same size and alignment as
/// `CudaSlice<T>` for a representative set of element types.
#[cfg(test)]
mod repr_soundness {
    use super::GpuBuffer;
    use cudarc::driver::CudaSlice;
    use half::f16;
    use static_assertions::{assert_eq_align, assert_eq_size};

    // Expands to one size assertion and one alignment assertion per listed
    // element type, comparing the wrapper against the slice it wraps.
    macro_rules! assert_same_layout {
        ($($elem:ty),+ $(,)?) => {
            $(
                assert_eq_size!(GpuBuffer<$elem>, CudaSlice<$elem>);
                assert_eq_align!(GpuBuffer<$elem>, CudaSlice<$elem>);
            )+
        };
    }

    assert_same_layout!(f32, f16, i8, u32);
}