use std::fmt;
use std::sync::Arc;
/// Error reported by a compute backend, tagged with the backend that raised it.
#[derive(Debug, Clone)]
pub struct GpuError {
    /// Backend the failure is attributed to.
    pub backend: BackendKind,
    /// Human-readable description of the failure.
    pub message: String,
}

impl GpuError {
    /// Creates an error attributed to `backend`.
    pub fn new(backend: BackendKind, msg: impl Into<String>) -> Self {
        Self { backend, message: msg.into() }
    }

    /// Shorthand for `GpuError::new(BackendKind::Cuda, msg)`.
    pub fn cuda(msg: impl Into<String>) -> Self {
        Self::new(BackendKind::Cuda, msg)
    }

    /// Shorthand for `GpuError::new(BackendKind::OpenCL, msg)`.
    pub fn opencl(msg: impl Into<String>) -> Self {
        Self::new(BackendKind::OpenCL, msg)
    }

    /// Shorthand for `GpuError::new(BackendKind::Rocm, msg)`.
    pub fn rocm(msg: impl Into<String>) -> Self {
        Self::new(BackendKind::Rocm, msg)
    }
}

impl fmt::Display for GpuError {
    /// Formats as `[<backend>] <message>`.
    ///
    /// The backend is rendered with `BackendKind`'s user-facing `Display`
    /// (e.g. `CUDA`, `ROCm/HIP`) rather than its `Debug` variant name, so
    /// error text matches the names used elsewhere for display.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[{}] {}", self.backend, self.message)
    }
}

impl std::error::Error for GpuError {}

/// Compute backend families that can be selected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendKind {
    /// CUDA backend (displayed as `CUDA`).
    Cuda,
    /// OpenCL backend (displayed as `OpenCL`).
    OpenCL,
    /// ROCm/HIP backend (displayed as `ROCm/HIP`).
    Rocm,
    /// Hardware-free fallback backend (displayed as `Stub`).
    Stub,
}

impl fmt::Display for BackendKind {
    /// Writes the human-readable backend name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BackendKind::Cuda => write!(f, "CUDA"),
            BackendKind::OpenCL => write!(f, "OpenCL"),
            BackendKind::Rocm => write!(f, "ROCm/HIP"),
            BackendKind::Stub => write!(f, "Stub"),
        }
    }
}
/// A memory allocation owned by a specific backend.
///
/// Buffers are handled through the shared, type-erased [`DeviceBuffer`]
/// handle, so implementations must be `Send + Sync`.
pub trait BackendBuffer: Send + Sync + fmt::Debug {
    /// Size of the allocation in bytes.
    fn size_bytes(&self) -> usize;

    /// Backend-specific pointer to the allocation (for the stub backend this
    /// is a host pointer).
    fn device_ptr(&self) -> *const std::ffi::c_void;

    /// Upcasts to `Any` so callers can downcast to the concrete buffer type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// Reference-counted, type-erased handle to a backend buffer.
pub type DeviceBuffer = Arc<dyn BackendBuffer>;
/// A backend stream handle on which asynchronous transfers are issued
/// (see `ComputeBackend::htod_async` / `ComputeBackend::dtoh_async`).
pub trait BackendStream: Send + Sync + fmt::Debug {
    /// Numeric identifier of this stream (the stub backend hands out
    /// increasing ids from a process-wide counter).
    fn id(&self) -> u64;

    /// Waits for work issued on this stream, reporting failures as
    /// [`GpuError`].
    fn synchronize(&self) -> Result<(), GpuError>;

    /// Upcasts to `Any` so callers can downcast to the concrete stream type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// Reference-counted, type-erased handle to a backend stream.
pub type DeviceStream = Arc<dyn BackendStream>;
/// Direction of a memory copy between host and device.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransferDir {
    /// Host memory into a device buffer.
    HostToDevice,
    /// Device buffer into host memory.
    DeviceToHost,
    /// One device buffer into another.
    DeviceToDevice,
}
/// A compute backend bound to one device (CUDA, ROCm/HIP, OpenCL or the stub).
///
/// Backends are handed out as `Arc<dyn ComputeBackend>` and may be shared
/// across threads.
pub trait ComputeBackend: Send + Sync + fmt::Debug + 'static {
    /// Backend family of this implementation.
    fn kind(&self) -> BackendKind;

    /// Human-readable backend name (used in the backend-selection log line).
    fn name(&self) -> &str;

    /// Index of the device this backend was created for.
    fn device_id(&self) -> usize;

    /// Allocates a buffer of `size` bytes on the device.
    ///
    /// # Errors
    ///
    /// Returns a [`GpuError`] if the allocation cannot be satisfied.
    fn alloc_bytes(&self, size: usize) -> Result<DeviceBuffer, GpuError>;

    /// Synchronous host-to-device copy of `src_bytes` bytes starting at `src`
    /// into `dst`.
    ///
    /// NOTE(review): this is a safe `fn` that takes a raw pointer, so safe
    /// code can pass a dangling `src`. Any implementation that dereferences
    /// it is unsound. Consider making it `unsafe fn` (like `htod_async`) or
    /// taking `&[u8]`.
    ///
    /// # Errors
    ///
    /// Returns a [`GpuError`] if the transfer fails.
    fn htod_sync(
        &self,
        src: *const std::ffi::c_void,
        src_bytes: usize,
        dst: &DeviceBuffer,
    ) -> Result<(), GpuError>;

    /// Synchronous device-to-host copy from `src` into the host region at
    /// `dst`, which is `dst_bytes` bytes long.
    ///
    /// NOTE(review): same soundness concern as `htod_sync` — a safe `fn`
    /// that receives a raw `*mut` pointer.
    ///
    /// # Errors
    ///
    /// Returns a [`GpuError`] if the transfer fails.
    fn dtoh_sync(
        &self,
        src: &DeviceBuffer,
        dst: *mut std::ffi::c_void,
        dst_bytes: usize,
    ) -> Result<(), GpuError>;

    /// Issues a host-to-device copy of `src_bytes` bytes from `src` into `dst`
    /// on `stream`.
    ///
    /// # Safety
    ///
    /// `src` must be valid for reads of `src_bytes` bytes. That memory must
    /// stay alive and unmodified until the copy has completed (e.g. until
    /// `stream.synchronize()` returns).
    /// NOTE(review): this contract is inferred from the asynchronous
    /// signature — confirm it against every backend implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`GpuError`] if the transfer cannot be issued or fails.
    unsafe fn htod_async(
        &self,
        src: *const std::ffi::c_void,
        src_bytes: usize,
        dst: &DeviceBuffer,
        stream: &DeviceStream,
    ) -> Result<(), GpuError>;

    /// Issues a device-to-host copy from `src` into the `dst_bytes`-byte host
    /// region at `dst` on `stream`.
    ///
    /// # Safety
    ///
    /// `dst` must be valid for writes of `dst_bytes` bytes. The caller must
    /// not read or free that memory until the copy has completed (e.g. until
    /// `stream.synchronize()` returns).
    /// NOTE(review): this contract is inferred from the asynchronous
    /// signature — confirm it against every backend implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`GpuError`] if the transfer cannot be issued or fails.
    unsafe fn dtoh_async(
        &self,
        src: &DeviceBuffer,
        dst: *mut std::ffi::c_void,
        dst_bytes: usize,
        stream: &DeviceStream,
    ) -> Result<(), GpuError>;

    /// Creates a new stream for asynchronous transfers.
    fn create_stream(&self) -> Result<DeviceStream, GpuError>;

    /// Waits for outstanding work on the device, per the name. The stub
    /// backend returns immediately.
    fn synchronize_device(&self) -> Result<(), GpuError>;

    /// Returns a pair of byte counts describing device memory.
    ///
    /// NOTE(review): presumably `(free, total)`. The order is not visible
    /// here and the stub reports the same value for both — confirm.
    fn memory_info(&self) -> Result<(usize, usize), GpuError>;
}
/// Selects and initialises a compute backend for `device_id`.
///
/// With `preferred = Some(kind)`, only that backend is tried. With `None`,
/// backends are tried in the order CUDA → ROCm → OpenCL → Stub, and the first
/// one that initialises wins. Because the stub backend cannot fail to
/// initialise, the automatic path always succeeds.
///
/// # Errors
///
/// Returns a [`GpuError`] tagged with the requested backend (or `Stub` when no
/// preference was given). Its message lists why each attempted backend failed,
/// so a failed explicit request no longer loses the underlying reason.
pub fn probe_backend(
    device_id: usize,
    preferred: Option<BackendKind>,
) -> Result<Arc<dyn ComputeBackend>, GpuError> {
    // Real hardware first; the always-available stub is the last resort.
    const DEFAULT_ORDER: [BackendKind; 4] = [
        BackendKind::Cuda,
        BackendKind::Rocm,
        BackendKind::OpenCL,
        BackendKind::Stub,
    ];

    let order: &[BackendKind] = match &preferred {
        Some(kind) => std::slice::from_ref(kind),
        None => &DEFAULT_ORDER,
    };

    // One entry per failed backend, surfaced in the final error.
    let mut reasons: Vec<String> = Vec::with_capacity(order.len());
    for &kind in order {
        match try_init_backend(kind, device_id) {
            Ok(backend) => {
                log::info!("GPU backend selected: {} on device {}", backend.name(), device_id);
                return Ok(backend);
            }
            Err(e) => {
                log::debug!("Backend {:?} unavailable: {}", kind, e);
                reasons.push(format!("{}: {}", e.backend, e.message));
            }
        }
    }

    Err(GpuError::new(
        preferred.unwrap_or(BackendKind::Stub),
        format!(
            "No compute backend available for device {} ({})",
            device_id,
            reasons.join("; ")
        ),
    ))
}
fn try_init_backend(
kind: BackendKind,
device_id: usize,
) -> Result<Arc<dyn ComputeBackend>, GpuError> {
match kind {
#[cfg(feature = "gpu")]
BackendKind::Cuda => {
use crate::gpu_cuda_backend::CudaBackend;
let b = CudaBackend::new(device_id)?;
Ok(Arc::new(b) as Arc<dyn ComputeBackend>)
}
#[cfg(not(feature = "gpu"))]
BackendKind::Cuda => Err(GpuError::cuda("CUDA feature not compiled")),
#[cfg(feature = "opencl")]
BackendKind::OpenCL => {
use crate::gpu_opencl::OpenCLBackend;
let b = OpenCLBackend::new(device_id)?;
Ok(Arc::new(b) as Arc<dyn ComputeBackend>)
}
#[cfg(not(feature = "opencl"))]
BackendKind::OpenCL => Err(GpuError::opencl("OpenCL feature not compiled")),
#[cfg(feature = "rocm")]
BackendKind::Rocm => {
use crate::gpu_rocm::RocmBackend;
let b = RocmBackend::new(device_id)?;
Ok(Arc::new(b) as Arc<dyn ComputeBackend>)
}
#[cfg(not(feature = "rocm"))]
BackendKind::Rocm => Err(GpuError::rocm("ROCm feature not compiled")),
BackendKind::Stub => {
use self::stub::StubBackend;
Ok(Arc::new(StubBackend::new(device_id)) as Arc<dyn ComputeBackend>)
}
}
}
/// Hardware-free [`ComputeBackend`], used as the final entry of the automatic
/// probe order in `probe_backend`.
///
/// Behaviour:
/// - Buffers are zero-initialised host `Vec<u8>`s.
/// - Transfers are accepted and ignored: no bytes are copied in either direction.
/// - Stream and device synchronisation always succeed immediately.
pub mod stub {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Process-wide counter that hands out stream ids, starting at 1.
    static STREAM_ID: AtomicU64 = AtomicU64::new(1);

    /// Memory size the stub reports from `memory_info`: 8 GiB, clamped to
    /// `usize::MAX` where `usize` is too narrow to hold it.
    ///
    /// Writing `8 * 1024 * 1024 * 1024` directly in a `usize` position is a
    /// deny-by-default `arithmetic_overflow` compile error on 32-bit targets
    /// (e.g. wasm32). The value is therefore computed in `u64` first.
    const STUB_MEMORY_BYTES: usize = {
        const EIGHT_GIB: u64 = 8 * 1024 * 1024 * 1024;
        if EIGHT_GIB > usize::MAX as u64 {
            usize::MAX
        } else {
            EIGHT_GIB as usize
        }
    };

    /// Host-backed stand-in for a device allocation.
    #[derive(Debug)]
    pub struct StubBuffer {
        data: Vec<u8>,
    }

    impl BackendBuffer for StubBuffer {
        fn size_bytes(&self) -> usize {
            self.data.len()
        }

        /// Address of the host `Vec` storage (there is no device memory).
        fn device_ptr(&self) -> *const std::ffi::c_void {
            self.data.as_ptr() as *const _
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Stream handle whose synchronisation is a no-op.
    #[derive(Debug)]
    pub struct StubStream {
        /// Id assigned by `StubBackend::create_stream` from the process-wide counter.
        pub id: u64,
    }

    impl BackendStream for StubStream {
        fn id(&self) -> u64 {
            self.id
        }

        fn synchronize(&self) -> Result<(), GpuError> {
            Ok(())
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Backend that performs no hardware work (see the module docs).
    #[derive(Debug)]
    pub struct StubBackend {
        device_id: usize,
    }

    impl StubBackend {
        /// Creates a stub bound to `device_id`. This never fails.
        pub fn new(device_id: usize) -> Self {
            Self { device_id }
        }
    }

    impl ComputeBackend for StubBackend {
        fn kind(&self) -> BackendKind {
            BackendKind::Stub
        }

        fn name(&self) -> &str {
            "Stub (no hardware)"
        }

        fn device_id(&self) -> usize {
            self.device_id
        }

        /// Allocates `size` zeroed host bytes.
        fn alloc_bytes(&self, size: usize) -> Result<DeviceBuffer, GpuError> {
            Ok(Arc::new(StubBuffer { data: vec![0u8; size] }))
        }

        /// No-op. `src` is never read, so this is sound for any pointer.
        fn htod_sync(
            &self,
            _src: *const std::ffi::c_void,
            _src_bytes: usize,
            _dst: &DeviceBuffer,
        ) -> Result<(), GpuError> {
            Ok(())
        }

        /// No-op. `dst` is never written, so the host region keeps whatever it
        /// held before the call.
        fn dtoh_sync(
            &self,
            _src: &DeviceBuffer,
            _dst: *mut std::ffi::c_void,
            _dst_bytes: usize,
        ) -> Result<(), GpuError> {
            Ok(())
        }

        /// Delegates to `htod_sync`; the stream is ignored.
        unsafe fn htod_async(
            &self,
            src: *const std::ffi::c_void,
            src_bytes: usize,
            dst: &DeviceBuffer,
            _stream: &DeviceStream,
        ) -> Result<(), GpuError> {
            self.htod_sync(src, src_bytes, dst)
        }

        /// Delegates to `dtoh_sync`; the stream is ignored.
        unsafe fn dtoh_async(
            &self,
            src: &DeviceBuffer,
            dst: *mut std::ffi::c_void,
            dst_bytes: usize,
            _stream: &DeviceStream,
        ) -> Result<(), GpuError> {
            self.dtoh_sync(src, dst, dst_bytes)
        }

        fn create_stream(&self) -> Result<DeviceStream, GpuError> {
            // Relaxed is enough: only uniqueness of the id matters, and
            // fetch_add is a single atomic read-modify-write at any ordering.
            let id = STREAM_ID.fetch_add(1, Ordering::Relaxed);
            Ok(Arc::new(StubStream { id }))
        }

        fn synchronize_device(&self) -> Result<(), GpuError> {
            Ok(())
        }

        /// Reports `STUB_MEMORY_BYTES` for both elements of the pair.
        fn memory_info(&self) -> Result<(usize, usize), GpuError> {
            Ok((STUB_MEMORY_BYTES, STUB_MEMORY_BYTES))
        }
    }
}