use std::any::Any;
use thiserror::Error;
use crate::device::DeviceType;
use crate::gpu::device::GpuDevice;
use crate::gpu::kernel::{KernelConfig, KernelLaunchResult};
use crate::gpu::memory::GpuBuffer;
/// Errors produced by GPU backends and by the typed helper functions in this module.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum GpuError {
/// No usable GPU is present; `reason` is interpolated into the message.
/// `CudaStub` returns this variant from every fallible backend operation.
#[error("GPU not available: {reason}")]
NotAvailable { reason: String },
/// No device exists at the given `index`.
#[error("GPU device {index} not found")]
DeviceNotFound { index: usize },
/// A request for `requested` bytes could not be satisfied with `available` bytes.
/// `allocate` and `transfer_to_host` also use this variant (with
/// `requested: usize::MAX`, `available: 0`) when a byte-size computation overflows.
#[error("GPU out of memory: requested {requested} bytes, available {available}")]
OutOfMemory { requested: usize, available: usize },
/// A kernel launch failed; the payload is the message text.
#[error("Kernel launch failed: {0}")]
KernelLaunchFailed(String),
/// A data transfer failed; the payload is the message text.
#[error("Data transfer failed: {0}")]
TransferFailed(String),
/// A synchronization call failed; the payload is the message text.
#[error("GPU synchronization failed: {0}")]
SyncFailed(String),
}
/// Operations a GPU backend must provide.
///
/// Memory is referred to through type-erased handles (`Box<dyn Any + Send>` /
/// `&dyn Any`); the concrete handle type is chosen by each implementation.
/// The `Send` supertrait lets a backend value be moved to another thread.
pub trait GpuBackend: Send {
/// Requests an allocation of `byte_count` bytes and returns an opaque handle to it.
fn allocate_raw(&mut self, byte_count: usize) -> Result<Box<dyn Any + Send>, GpuError>;
/// Consumes `handle` to release it.
/// NOTE(review): presumably `handle` must be one previously returned by this
/// backend's `allocate_raw` / `transfer_to_device_raw` — confirm with implementors.
fn deallocate_raw(&mut self, handle: Box<dyn Any + Send>) -> Result<(), GpuError>;
/// Hands `bytes` to the backend and returns an opaque handle for the uploaded data.
fn transfer_to_device_raw(&mut self, bytes: &[u8]) -> Result<Box<dyn Any + Send>, GpuError>;
/// Returns bytes read back for `handle`; `byte_count` is the number of bytes requested.
/// NOTE(review): whether the returned vector must be exactly `byte_count` long is
/// not stated here; `transfer_to_host` in this file accepts any length.
fn transfer_to_host_raw(
&self,
handle: &dyn Any,
byte_count: usize,
) -> Result<Vec<u8>, GpuError>;
/// Launches a kernel described by `config` and returns the launch result.
fn launch_kernel(&mut self, config: &KernelConfig) -> Result<KernelLaunchResult, GpuError>;
/// Synchronizes with the device.
/// NOTE(review): presumably blocks until outstanding work finishes — confirm.
fn synchronize(&mut self) -> Result<(), GpuError>;
/// Returns the descriptor of the device this backend is bound to.
fn device_info(&self) -> &GpuDevice;
/// Reports whether this backend type is usable.
///
/// The `Self: Sized` bound excludes this associated function from trait objects,
/// which keeps `dyn GpuBackend` usable; call it on a concrete type instead.
fn is_available() -> bool
where
Self: Sized;
}
pub fn allocate<T: Send + 'static>(
backend: &mut dyn GpuBackend,
count: usize,
) -> Result<GpuBuffer<T>, GpuError> {
let byte_count = count
.checked_mul(std::mem::size_of::<T>())
.ok_or(GpuError::OutOfMemory {
requested: usize::MAX,
available: 0,
})?;
backend.allocate_raw(byte_count)?;
Ok(GpuBuffer::new(backend.device_info().index, byte_count))
}
pub fn transfer_to_device<T: Copy + Send + 'static>(
backend: &mut dyn GpuBackend,
host: &[T],
) -> Result<GpuBuffer<T>, GpuError> {
let byte_len = std::mem::size_of_val(host);
let bytes: &[u8] = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, byte_len) };
backend.transfer_to_device_raw(bytes)?;
Ok(GpuBuffer::new(backend.device_info().index, bytes.len()))
}
/// Copies `len` elements of type `T` from the backend back to the host.
///
/// Requests `len * size_of::<T>()` bytes for `handle` through
/// [`GpuBackend::transfer_to_host_raw`] and reinterprets the returned bytes as
/// consecutive `T` values in native in-memory layout.
///
/// # Arguments
///
/// * `backend` - backend that performs the raw read-back.
/// * `_buf` - typed buffer descriptor; it is not read and only fixes the element type.
/// * `len` - number of elements requested.
/// * `handle` - opaque backend handle identifying the data to read.
///
/// # Returns
///
/// One `T` for every complete `size_of::<T>()`-byte group the backend returned.
/// Trailing bytes that do not form a whole element are ignored, so the result holds
/// fewer than `len` elements if the backend returns fewer bytes than requested.
/// For zero-sized `T` the result is always empty.
///
/// # Errors
///
/// * [`GpuError::OutOfMemory`] with `requested: usize::MAX` and `available: 0` when the
///   byte size does not fit in `usize`; the backend is not called in that case.
/// * Any error returned by [`GpuBackend::transfer_to_host_raw`].
pub fn transfer_to_host<T: Copy + Send + 'static>(
    backend: &dyn GpuBackend,
    _buf: &GpuBuffer<T>,
    len: usize,
    handle: &dyn Any,
) -> Result<Vec<T>, GpuError> {
    let elem_size = std::mem::size_of::<T>();
    let byte_count = len.checked_mul(elem_size).ok_or(GpuError::OutOfMemory {
        requested: usize::MAX,
        available: 0,
    })?;
    let raw_bytes = backend.transfer_to_host_raw(handle, byte_count)?;

    if elem_size == 0 {
        // `chunks_exact(0)` panics, and no bytes can encode a zero-sized value.
        return Ok(Vec::new());
    }

    let result = raw_bytes
        .chunks_exact(elem_size)
        .map(|chunk| {
            // SAFETY: `chunk` is exactly `size_of::<T>()` readable bytes.
            // `read_unaligned` is required because a `Vec<u8>` allocation is only
            // guaranteed 1-byte alignment, so the chunk may be misaligned for `T`.
            // `T: Copy`, so duplicating the bits cannot create double ownership.
            // NOTE(review): this still relies on the backend returning bytes that
            // form valid `T` values (e.g. not arbitrary bytes for `bool`/`char`).
            unsafe { std::ptr::read_unaligned(chunk.as_ptr().cast::<T>()) }
        })
        .collect();
    Ok(result)
}
/// Placeholder backend used when no real CUDA support is present.
///
/// Its `GpuBackend` implementation returns `GpuError::NotAvailable` from every
/// fallible operation and reports `false` from `is_available`.
pub struct CudaStub {
// Descriptor handed out by `device_info`; built from `CudaStub::pseudo_device`
// with the index overridden in `CudaStub::new`.
pseudo_device: GpuDevice,
}
impl CudaStub {
    /// Returns the descriptor of the placeholder device.
    ///
    /// The descriptor has index 0, type [`DeviceType::Cuda`], zero total and free
    /// memory, no compute capability, and no fp16/bf16 support.
    pub fn pseudo_device() -> GpuDevice {
        GpuDevice {
            index: 0,
            device_type: DeviceType::Cuda,
            name: String::from("CUDA Stub (not available)"),
            total_memory_bytes: 0,
            free_memory_bytes: 0,
            compute_capability: None,
            supports_fp16: false,
            supports_bf16: false,
        }
    }

    /// Creates a stub whose device descriptor reports `device_index` as its index.
    ///
    /// All other descriptor fields are those of [`CudaStub::pseudo_device`].
    pub fn new(device_index: usize) -> Self {
        let mut pseudo_device = Self::pseudo_device();
        pseudo_device.index = device_index;
        Self { pseudo_device }
    }
}
/// Reason text carried by every `NotAvailable` error the CUDA stub returns.
const CUDA_STUB_UNAVAILABLE: &str = "CUDA not available in stub mode";

/// Builds the `NotAvailable` failure shared by all stubbed operations.
fn cuda_stub_unavailable<T>() -> Result<T, GpuError> {
    Err(GpuError::NotAvailable {
        reason: CUDA_STUB_UNAVAILABLE.to_string(),
    })
}

/// Stub implementation: every fallible operation fails with
/// `GpuError::NotAvailable`, `device_info` returns the stored pseudo device, and
/// `is_available` is always `false`.
impl GpuBackend for CudaStub {
    /// Always fails with `GpuError::NotAvailable`.
    fn allocate_raw(&mut self, _byte_count: usize) -> Result<Box<dyn Any + Send>, GpuError> {
        cuda_stub_unavailable()
    }

    /// Always fails with `GpuError::NotAvailable`; the handle is simply dropped.
    fn deallocate_raw(&mut self, _handle: Box<dyn Any + Send>) -> Result<(), GpuError> {
        cuda_stub_unavailable()
    }

    /// Always fails with `GpuError::NotAvailable`.
    fn transfer_to_device_raw(&mut self, _bytes: &[u8]) -> Result<Box<dyn Any + Send>, GpuError> {
        cuda_stub_unavailable()
    }

    /// Always fails with `GpuError::NotAvailable`.
    fn transfer_to_host_raw(
        &self,
        _handle: &dyn Any,
        _byte_count: usize,
    ) -> Result<Vec<u8>, GpuError> {
        cuda_stub_unavailable()
    }

    /// Always fails with `GpuError::NotAvailable`.
    fn launch_kernel(&mut self, _config: &KernelConfig) -> Result<KernelLaunchResult, GpuError> {
        cuda_stub_unavailable()
    }

    /// Always fails with `GpuError::NotAvailable`.
    fn synchronize(&mut self) -> Result<(), GpuError> {
        cuda_stub_unavailable()
    }

    /// Returns the pseudo device descriptor stored in the stub.
    fn device_info(&self) -> &GpuDevice {
        &self.pseudo_device
    }

    /// The stub never reports itself as available.
    fn is_available() -> bool {
        false
    }
}
/// Creates the GPU backend for `device_index`.
///
/// Currently this always returns a boxed [`CudaStub`], so the call itself never
/// fails; the fallible operations on the returned backend report
/// `GpuError::NotAvailable`.
pub fn create_gpu_backend(device_index: usize) -> Result<Box<dyn GpuBackend>, GpuError> {
    let backend: Box<dyn GpuBackend> = Box::new(CudaStub::new(device_index));
    Ok(backend)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceType;
    use crate::gpu::kernel::KernelConfig;

    /// Fails the test unless `result` is `Err(GpuError::NotAvailable { .. })`.
    fn assert_not_available<T: std::fmt::Debug>(result: Result<T, GpuError>) {
        assert!(result.is_err());
        if !matches!(result, Err(GpuError::NotAvailable { .. })) {
            panic!("Expected NotAvailable, got {:?}", result);
        }
    }

    /// A freshly built stub exposes the requested index and a CUDA pseudo device.
    #[test]
    fn test_cuda_stub_creation() {
        let stub = CudaStub::new(0);
        let info = stub.device_info();
        assert_eq!(info.index, 0);
        assert_eq!(info.device_type, DeviceType::Cuda);
        assert!(!info.name.is_empty());
    }

    /// The stub type never claims to be available.
    #[test]
    fn test_cuda_stub_is_not_available() {
        assert!(!CudaStub::is_available());
    }

    /// Raw allocation on the stub fails with `NotAvailable`.
    #[test]
    fn test_allocate_returns_error() {
        let mut stub = CudaStub::new(0);
        assert_not_available(stub.allocate_raw(1024));
    }

    /// The typed upload helper surfaces the stub's `NotAvailable` error.
    #[test]
    fn test_transfer_to_device_returns_error() {
        let mut stub = CudaStub::new(0);
        let data = vec![1.0_f32, 2.0, 3.0];
        assert_not_available(transfer_to_device(&mut stub, &data));
    }

    /// Synchronizing the stub fails with `NotAvailable`.
    #[test]
    fn test_synchronize_returns_error() {
        let mut stub = CudaStub::new(0);
        assert_not_available(stub.synchronize());
    }

    /// Display output embeds the variant's fields.
    #[test]
    fn test_error_messages_contain_reason() {
        let msg = GpuError::NotAvailable {
            reason: "no CUDA driver found".to_string(),
        }
        .to_string();
        assert!(msg.contains("no CUDA driver found"));

        let oom_msg = GpuError::OutOfMemory {
            requested: 1024,
            available: 512,
        }
        .to_string();
        assert!(oom_msg.contains("1024"));
        assert!(oom_msg.contains("512"));
    }

    /// Backend creation succeeds, but launching a kernel on it fails.
    #[test]
    fn test_create_gpu_backend_fails_gracefully() {
        let backend_result = create_gpu_backend(0);
        assert!(backend_result.is_ok());
        let mut backend = backend_result
            .unwrap_or_else(|e| panic!("create_gpu_backend should not fail: {e}"));

        let config = KernelConfig::new("noop");
        let launch_result = backend.launch_kernel(&config);
        assert!(launch_result.is_err());
    }
}