use crate::backend::dtype::Dtype;
use half::f16;
/// Buffer backed by an owned `Vec`, where the variant records the element type.
///
/// The variants mirror the `Dtype` values handled by `CpuBuf::alloc` and
/// `CpuBuf::dtype`.
pub enum CpuBuf {
/// `f32` elements.
F32(Vec<f32>),
/// `half::f16` elements.
F16(Vec<f16>),
/// `u32` elements.
U32(Vec<u32>),
/// `i32` elements.
I32(Vec<i32>),
/// `i8` elements.
I8(Vec<i8>),
}
impl CpuBuf {
    /// Allocates a zero-filled buffer holding `n` elements of `dtype`.
    pub fn alloc(dtype: Dtype, n: usize) -> Self {
        match dtype {
            Dtype::F32 => CpuBuf::F32(vec![0.0; n]),
            Dtype::F16 => CpuBuf::F16(vec![f16::ZERO; n]),
            Dtype::U32 => CpuBuf::U32(vec![0u32; n]),
            Dtype::I32 => CpuBuf::I32(vec![0i32; n]),
            Dtype::I8 => CpuBuf::I8(vec![0i8; n]),
        }
    }

    /// Returns the `Dtype` matching the variant currently stored.
    pub fn dtype(&self) -> Dtype {
        match self {
            CpuBuf::F32(_) => Dtype::F32,
            CpuBuf::F16(_) => Dtype::F16,
            CpuBuf::U32(_) => Dtype::U32,
            CpuBuf::I32(_) => Dtype::I32,
            CpuBuf::I8(_) => Dtype::I8,
        }
    }

    /// Returns the number of elements (not bytes) in the buffer.
    pub fn len(&self) -> usize {
        match self {
            CpuBuf::F32(v) => v.len(),
            CpuBuf::F16(v) => v.len(),
            CpuBuf::U32(v) => v.len(),
            CpuBuf::I32(v) => v.len(),
            CpuBuf::I8(v) => v.len(),
        }
    }

    /// Returns `true` when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the contents as an `f32` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F32` variant. The panic location is
    /// the caller's, thanks to `#[track_caller]`.
    #[track_caller]
    pub fn as_f32(&self) -> &[f32] {
        match self {
            CpuBuf::F32(v) => v,
            _ => panic!("CpuBuf::as_f32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the contents as an `f32` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F32` variant.
    #[track_caller]
    pub fn as_f32_mut(&mut self) -> &mut [f32] {
        match self {
            CpuBuf::F32(v) => v,
            _ => panic!("CpuBuf::as_f32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the contents as an `f16` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F16` variant.
    #[track_caller]
    pub fn as_f16(&self) -> &[f16] {
        match self {
            CpuBuf::F16(v) => v,
            _ => panic!("CpuBuf::as_f16 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the contents as an `f16` slice.
    ///
    /// Provided for parity with the other `*_mut` accessors, since `alloc`
    /// can produce `F16` buffers.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F16` variant.
    #[track_caller]
    pub fn as_f16_mut(&mut self) -> &mut [f16] {
        match self {
            CpuBuf::F16(v) => v,
            _ => panic!("CpuBuf::as_f16_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the contents as a `u32` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `U32` variant.
    #[track_caller]
    pub fn as_u32(&self) -> &[u32] {
        match self {
            CpuBuf::U32(v) => v,
            _ => panic!("CpuBuf::as_u32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the contents as a `u32` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `U32` variant.
    #[track_caller]
    pub fn as_u32_mut(&mut self) -> &mut [u32] {
        match self {
            CpuBuf::U32(v) => v,
            _ => panic!("CpuBuf::as_u32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the contents as an `i32` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I32` variant.
    #[track_caller]
    pub fn as_i32(&self) -> &[i32] {
        match self {
            CpuBuf::I32(v) => v,
            _ => panic!("CpuBuf::as_i32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the contents as an `i32` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I32` variant.
    #[track_caller]
    pub fn as_i32_mut(&mut self) -> &mut [i32] {
        match self {
            CpuBuf::I32(v) => v,
            _ => panic!("CpuBuf::as_i32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the contents as an `i8` slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I8` variant.
    #[track_caller]
    pub fn as_i8(&self) -> &[i8] {
        match self {
            CpuBuf::I8(v) => v,
            _ => panic!("CpuBuf::as_i8 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the contents as an `i8` slice.
    ///
    /// Provided for parity with the other `*_mut` accessors, since `alloc`
    /// can produce `I8` buffers.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I8` variant.
    #[track_caller]
    pub fn as_i8_mut(&mut self) -> &mut [i8] {
        match self {
            CpuBuf::I8(v) => v,
            _ => panic!("CpuBuf::as_i8_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Wraps an existing `Vec<f32>` without copying.
    pub fn from_f32(data: Vec<f32>) -> Self {
        CpuBuf::F32(data)
    }

    /// Wraps an existing `Vec<f16>` without copying.
    pub fn from_f16(data: Vec<f16>) -> Self {
        CpuBuf::F16(data)
    }

    /// Wraps an existing `Vec<u32>` without copying.
    pub fn from_u32(data: Vec<u32>) -> Self {
        CpuBuf::U32(data)
    }

    /// Wraps an existing `Vec<i32>` without copying.
    pub fn from_i32(data: Vec<i32>) -> Self {
        CpuBuf::I32(data)
    }

    /// Wraps an existing `Vec<i8>` without copying.
    pub fn from_i8(data: Vec<i8>) -> Self {
        CpuBuf::I8(data)
    }
}
/// Buffer wrapping a cudarc `CudaSlice`, where the variant records the
/// element type.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub enum CudaBuf {
/// `f32` elements.
F32(cudarc::driver::CudaSlice<f32>),
/// `half::f16` elements.
F16(cudarc::driver::CudaSlice<f16>),
/// `u32` elements.
U32(cudarc::driver::CudaSlice<u32>),
/// `i32` elements.
I32(cudarc::driver::CudaSlice<i32>),
/// `i8` elements.
I8(cudarc::driver::CudaSlice<i8>),
}
#[cfg(feature = "cuda")]
impl CudaBuf {
    /// Returns the `Dtype` matching the variant currently stored.
    pub fn dtype(&self) -> Dtype {
        match self {
            CudaBuf::F32(_) => Dtype::F32,
            CudaBuf::F16(_) => Dtype::F16,
            CudaBuf::U32(_) => Dtype::U32,
            CudaBuf::I32(_) => Dtype::I32,
            CudaBuf::I8(_) => Dtype::I8,
        }
    }

    /// Returns the length reported by the wrapped slice.
    pub fn len(&self) -> usize {
        match self {
            CudaBuf::F32(s) => s.len(),
            CudaBuf::F16(s) => s.len(),
            CudaBuf::U32(s) => s.len(),
            CudaBuf::I32(s) => s.len(),
            CudaBuf::I8(s) => s.len(),
        }
    }

    /// Returns `true` when the wrapped slice has length zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the wrapped `CudaSlice<f32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F32` variant. The panic location is
    /// the caller's, thanks to `#[track_caller]`.
    #[track_caller]
    pub fn as_f32(&self) -> &cudarc::driver::CudaSlice<f32> {
        match self {
            CudaBuf::F32(s) => s,
            _ => panic!("CudaBuf::as_f32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<f32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F32` variant.
    #[track_caller]
    pub fn as_f32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<f32> {
        match self {
            CudaBuf::F32(s) => s,
            _ => panic!("CudaBuf::as_f32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<f16>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F16` variant.
    #[track_caller]
    pub fn as_f16(&self) -> &cudarc::driver::CudaSlice<f16> {
        match self {
            CudaBuf::F16(s) => s,
            _ => panic!("CudaBuf::as_f16 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<f16>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F16` variant.
    #[track_caller]
    pub fn as_f16_mut(&mut self) -> &mut cudarc::driver::CudaSlice<f16> {
        match self {
            CudaBuf::F16(s) => s,
            _ => panic!("CudaBuf::as_f16_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<u32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `U32` variant.
    #[track_caller]
    pub fn as_u32(&self) -> &cudarc::driver::CudaSlice<u32> {
        match self {
            CudaBuf::U32(s) => s,
            _ => panic!("CudaBuf::as_u32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<u32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `U32` variant.
    #[track_caller]
    pub fn as_u32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<u32> {
        match self {
            CudaBuf::U32(s) => s,
            _ => panic!("CudaBuf::as_u32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<i32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I32` variant.
    #[track_caller]
    pub fn as_i32(&self) -> &cudarc::driver::CudaSlice<i32> {
        match self {
            CudaBuf::I32(s) => s,
            _ => panic!("CudaBuf::as_i32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<i32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I32` variant.
    #[track_caller]
    pub fn as_i32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i32> {
        match self {
            CudaBuf::I32(s) => s,
            _ => panic!("CudaBuf::as_i32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<i8>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I8` variant.
    #[track_caller]
    pub fn as_i8(&self) -> &cudarc::driver::CudaSlice<i8> {
        match self {
            CudaBuf::I8(s) => s,
            _ => panic!("CudaBuf::as_i8 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<i8>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I8` variant.
    #[track_caller]
    pub fn as_i8_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i8> {
        match self {
            CudaBuf::I8(s) => s,
            _ => panic!("CudaBuf::as_i8_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Wraps an existing `CudaSlice<f16>`.
    pub fn from_f16(s: cudarc::driver::CudaSlice<f16>) -> Self {
        CudaBuf::F16(s)
    }

    /// Wraps an existing `CudaSlice<f32>`.
    pub fn from_f32(s: cudarc::driver::CudaSlice<f32>) -> Self {
        CudaBuf::F32(s)
    }

    /// Wraps an existing `CudaSlice<u32>`.
    pub fn from_u32(s: cudarc::driver::CudaSlice<u32>) -> Self {
        CudaBuf::U32(s)
    }

    /// Wraps an existing `CudaSlice<i32>`.
    pub fn from_i32(s: cudarc::driver::CudaSlice<i32>) -> Self {
        CudaBuf::I32(s)
    }

    /// Wraps an existing `CudaSlice<i8>`.
    pub fn from_i8(s: cudarc::driver::CudaSlice<i8>) -> Self {
        CudaBuf::I8(s)
    }

    /// Reinterprets the wrapped slice as a view of `len` elements of `T`,
    /// regardless of which variant is stored.
    ///
    /// The call is forwarded unchanged to the wrapped slice's `transmute`,
    /// and its result is returned as-is.
    /// NOTE(review): presumably `None` means the allocation is too small for
    /// `len` values of `T` — confirm against the cudarc version in use.
    ///
    /// # Safety
    ///
    /// This method performs no dtype check of its own. The caller must
    /// uphold every requirement of the wrapped slice's `transmute`; in
    /// particular, the stored data must be valid when read as values of `T`.
    pub unsafe fn transmute<T>(&self, len: usize) -> Option<cudarc::driver::CudaView<'_, T>> {
        match self {
            // SAFETY: the caller upholds the contract in `# Safety`.
            CudaBuf::F16(s) => unsafe { s.transmute(len) },
            // SAFETY: the caller upholds the contract in `# Safety`.
            CudaBuf::F32(s) => unsafe { s.transmute(len) },
            // SAFETY: the caller upholds the contract in `# Safety`.
            CudaBuf::U32(s) => unsafe { s.transmute(len) },
            // SAFETY: the caller upholds the contract in `# Safety`.
            CudaBuf::I32(s) => unsafe { s.transmute(len) },
            // SAFETY: the caller upholds the contract in `# Safety`.
            CudaBuf::I8(s) => unsafe { s.transmute(len) },
        }
    }
}
#[cfg(feature = "cuda")]
// SAFETY: this impl does no argument handling of its own; every arm hands the
// typed slice reference to the `PushKernelArg` impl that accepts it.
// NOTE(review): the trait's safety contract is defined in cudarc — confirm
// that pure delegation is sufficient to satisfy it.
unsafe impl<'a, 'b: 'a> cudarc::driver::PushKernelArg<&'b CudaBuf>
    for cudarc::driver::LaunchArgs<'a>
{
    /// Pushes the slice wrapped by `arg` as a kernel argument, dispatching on
    /// the stored variant.
    fn arg(&mut self, arg: &'b CudaBuf) -> &mut Self {
        // Arms follow the enum's declaration order.
        match arg {
            CudaBuf::F32(slice) => self.arg(slice),
            CudaBuf::F16(slice) => self.arg(slice),
            CudaBuf::U32(slice) => self.arg(slice),
            CudaBuf::I32(slice) => self.arg(slice),
            CudaBuf::I8(slice) => self.arg(slice),
        }
    }
}
#[cfg(feature = "cuda")]
// SAFETY: this impl does no argument handling of its own; every arm hands the
// typed mutable slice reference to the `PushKernelArg` impl that accepts it.
// NOTE(review): the trait's safety contract is defined in cudarc — confirm
// that pure delegation is sufficient to satisfy it.
unsafe impl<'a, 'b: 'a> cudarc::driver::PushKernelArg<&'b mut CudaBuf>
    for cudarc::driver::LaunchArgs<'a>
{
    /// Pushes the slice wrapped by `arg` as a mutable kernel argument,
    /// dispatching on the stored variant.
    fn arg(&mut self, arg: &'b mut CudaBuf) -> &mut Self {
        // Arms follow the enum's declaration order.
        match arg {
            CudaBuf::F32(slice) => self.arg(slice),
            CudaBuf::F16(slice) => self.arg(slice),
            CudaBuf::U32(slice) => self.arg(slice),
            CudaBuf::I32(slice) => self.arg(slice),
            CudaBuf::I8(slice) => self.arg(slice),
        }
    }
}