use crate::error::RusTorchError;
use num_traits::Float;
use std::sync::Arc;
#[cfg(feature = "cuda")]
use cudarc::driver::{CudaDevice as CudarcDevice, CudaSlice, DeviceSlice};
#[cfg(feature = "metal")]
use metal::{Buffer, Device as MetalDeviceType};
#[cfg(feature = "opencl")]
use opencl3::{
context::Context as CLContext,
memory::{Buffer as CLBuffer, ClMem},
};
/// A buffer of `T` elements stored on one of several compute backends.
///
/// The `Cuda`, `Metal` and `OpenCL` variants only exist when the Cargo
/// feature of the same (lower-case) name is enabled; `Cpu` is always
/// available. Every payload is held behind an `Arc`.
pub enum GpuBuffer<T> {
/// CUDA-backed storage (requires the `cuda` feature).
#[cfg(feature = "cuda")]
Cuda {
/// Shared handle to the cudarc device slice holding the elements.
data: Arc<CudaSlice<T>>,
/// Shared handle to a cudarc device.
/// NOTE(review): presumably the device that owns `data` — confirm at the construction site.
device: Arc<CudarcDevice>,
},
/// Metal-backed storage (requires the `metal` feature).
#[cfg(feature = "metal")]
Metal {
/// Shared handle to the Metal buffer. The buffer type carries no element
/// type, so element counts are derived from its byte length.
buffer: Arc<Buffer>,
/// Shared handle to a Metal device.
/// NOTE(review): presumably the device that created `buffer` — confirm at the construction site.
device: Arc<MetalDeviceType>,
},
/// OpenCL-backed storage (requires the `opencl` feature).
#[cfg(feature = "opencl")]
OpenCL {
/// Shared handle to the typed OpenCL buffer.
buffer: Arc<CLBuffer<T>>,
/// Shared handle to an OpenCL context.
/// NOTE(review): presumably the context `buffer` was allocated in — confirm at the construction site.
context: Arc<CLContext>,
},
/// Storage backed by a shared `Vec<T>`; needs no optional feature.
Cpu(Arc<Vec<T>>),
}
impl<T> GpuBuffer<T> {
    /// Converts a backend-reported size in bytes into a number of `T` elements.
    ///
    /// Yields 0 when `T` is zero-sized: a byte count carries no element
    /// count for such a type, and a plain `/` would panic with a division
    /// by zero.
    #[cfg(any(feature = "cuda", feature = "metal", feature = "opencl"))]
    fn element_count(byte_len: usize) -> usize {
        byte_len
            .checked_div(std::mem::size_of::<T>())
            .unwrap_or(0)
    }

    /// Returns `true` if this is the `Cpu` variant.
    pub fn is_cpu(&self) -> bool {
        matches!(self, GpuBuffer::Cpu(_))
    }

    /// Returns `true` if this is the `Cuda` variant.
    ///
    /// Always `false` when the `cuda` feature is disabled, because the
    /// variant does not exist in that build.
    pub fn is_cuda(&self) -> bool {
        #[cfg(feature = "cuda")]
        {
            matches!(self, GpuBuffer::Cuda { .. })
        }
        #[cfg(not(feature = "cuda"))]
        {
            false
        }
    }

    /// Returns `true` if this is the `Metal` variant.
    ///
    /// Always `false` when the `metal` feature is disabled.
    pub fn is_metal(&self) -> bool {
        #[cfg(feature = "metal")]
        {
            matches!(self, GpuBuffer::Metal { .. })
        }
        #[cfg(not(feature = "metal"))]
        {
            false
        }
    }

    /// Returns `true` if this is the `OpenCL` variant.
    ///
    /// Always `false` when the `opencl` feature is disabled.
    pub fn is_opencl(&self) -> bool {
        #[cfg(feature = "opencl")]
        {
            matches!(self, GpuBuffer::OpenCL { .. })
        }
        #[cfg(not(feature = "opencl"))]
        {
            false
        }
    }

    /// Returns the number of `T` elements in the buffer.
    ///
    /// The `Cpu` variant reports the length of its `Vec`. The GPU variants
    /// derive the count from the size in bytes reported by the backend; for
    /// OpenCL a failed size query is reported as 0 rather than as an error.
    /// GPU variants report 0 for a zero-sized `T` instead of panicking.
    pub fn len(&self) -> usize {
        match self {
            GpuBuffer::Cpu(data) => data.len(),
            #[cfg(feature = "cuda")]
            GpuBuffer::Cuda { data, .. } => Self::element_count(data.num_bytes()),
            #[cfg(feature = "metal")]
            GpuBuffer::Metal { buffer, .. } => Self::element_count(buffer.length() as usize),
            #[cfg(feature = "opencl")]
            GpuBuffer::OpenCL { buffer, .. } => {
                // Best effort: `len` has no error channel, so an unreadable
                // size is treated as an empty buffer.
                Self::element_count(buffer.size().unwrap_or(0))
            }
        }
    }

    /// Returns `true` when the buffer holds no elements (`len() == 0`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T: Clone> Clone for GpuBuffer<T> {
fn clone(&self) -> Self {
match self {
GpuBuffer::Cpu(data) => GpuBuffer::Cpu(data.clone()),
#[cfg(feature = "cuda")]
GpuBuffer::Cuda { data, device } => GpuBuffer::Cuda {
data: data.clone(),
device: device.clone(),
},
#[cfg(feature = "metal")]
GpuBuffer::Metal { buffer, device } => GpuBuffer::Metal {
buffer: buffer.clone(),
device: device.clone(),
},
#[cfg(feature = "opencl")]
GpuBuffer::OpenCL { buffer, context } => GpuBuffer::OpenCL {
buffer: buffer.clone(),
context: context.clone(),
},
}
}
}
/// Debug output names the active backend and the element count; the
/// elements themselves are never printed, so `T` needs no `Debug` bound.
impl<T> std::fmt::Debug for GpuBuffer<T> {
    /// Writes `GpuBuffer::<Backend> { len: N }`; the CUDA variant also gets a
    /// fixed `device: "CUDA"` field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // One source of truth for the element count instead of repeating the
        // per-backend byte arithmetic here.
        let len = self.len();
        match self {
            GpuBuffer::Cpu(_) => f
                .debug_struct("GpuBuffer::Cpu")
                .field("len", &len)
                .finish(),
            #[cfg(feature = "cuda")]
            GpuBuffer::Cuda { .. } => f
                .debug_struct("GpuBuffer::Cuda")
                .field("len", &len)
                .field("device", &"CUDA")
                .finish(),
            #[cfg(feature = "metal")]
            GpuBuffer::Metal { .. } => f
                .debug_struct("GpuBuffer::Metal")
                .field("len", &len)
                .finish(),
            #[cfg(feature = "opencl")]
            GpuBuffer::OpenCL { .. } => f
                .debug_struct("GpuBuffer::OpenCL")
                .field("len", &len)
                .finish(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A populated CPU buffer identifies as CPU only and reports its length.
    #[test]
    fn test_cpu_buffer() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let shared = Arc::new(values);
        let buffer = GpuBuffer::Cpu(shared);

        // Backend predicates: exactly the CPU one holds.
        assert!(buffer.is_cpu());
        assert!(!buffer.is_cuda());
        assert!(!buffer.is_metal());
        assert!(!buffer.is_opencl());

        // Size queries.
        assert_eq!(buffer.len(), 4);
        assert!(!buffer.is_empty());
    }

    /// A CPU buffer wrapping an empty vector is empty and has length zero.
    #[test]
    fn test_empty_buffer() {
        let nothing: Vec<f32> = Vec::new();
        let buffer = GpuBuffer::Cpu(Arc::new(nothing));

        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);
    }
}