use crate::device::DeviceCapabilities;
pub mod cpu;
pub mod cuda;
#[cfg(feature = "cuda")]
pub mod cuda_kernels;
pub mod cuda_pool;
#[cfg(feature = "cudnn")]
pub mod cudnn_ops;
#[cfg(feature = "vulkan")]
pub mod vulkan;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(feature = "wgpu")]
pub mod wgpu_backend;
pub mod gpu_tests;
pub use cpu::CpuBackend;
pub use cuda::CudaBackend;
#[cfg(feature = "vulkan")]
pub use vulkan::VulkanBackend;
#[cfg(feature = "metal")]
pub use metal::MetalBackend;
#[cfg(feature = "wgpu")]
pub use wgpu_backend::WgpuBackend;
/// Device-agnostic interface implemented by every compute backend.
///
/// All raw-pointer methods operate on untyped byte buffers; callers are
/// responsible for size and alignment correctness. `Send + Sync` so a single
/// backend instance can be shared across threads.
pub trait Backend: Send + Sync {
    /// Short, stable backend identifier (e.g. "cpu").
    fn name(&self) -> &'static str;
    /// Whether this backend can actually be used on the current machine.
    fn is_available(&self) -> bool;
    /// Capability description of the underlying device.
    fn capabilities(&self) -> DeviceCapabilities;
    /// Allocates `size` bytes of backend memory and returns the raw pointer.
    // NOTE(review): null/zero-size semantics appear backend-defined — the CPU
    // tests in this file expect null for size 0; confirm other backends agree.
    fn allocate(&self, size: usize) -> *mut u8;
    /// Releases memory previously returned by `allocate` with the same `size`.
    fn deallocate(&self, ptr: *mut u8, size: usize);
    /// Copies `size` bytes from host memory `src` into device memory `dst`.
    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize);
    /// Copies `size` bytes from device memory `src` into host memory `dst`.
    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize);
    /// Copies `size` bytes between two device buffers.
    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize);
    /// Blocks until all outstanding work on the device has completed.
    fn synchronize(&self);
}
/// Handle describing a raw device (or host) allocation: pointer, size, origin.
///
/// Plain data holder — there is no `Drop` impl here, so it does not free the
/// allocation; release it through the owning backend's `deallocate`.
#[derive(Debug)]
pub struct GpuMemory {
    // Raw pointer to the start of the allocation (may be null).
    ptr: *mut u8,
    // Allocation size in bytes.
    size: usize,
    // Index of the device the memory lives on.
    device_index: usize,
    // Which backend produced this allocation.
    backend_type: BackendType,
}
/// Identifies which compute backend a resource (memory, stream) belongs to.
///
/// GPU variants only exist when the corresponding Cargo feature is enabled,
/// so exhaustive `match`es stay in sync with the compiled-in backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendType {
    /// Host CPU; always present regardless of features.
    Cpu,
    #[cfg(feature = "cuda")]
    Cuda,
    #[cfg(feature = "vulkan")]
    Vulkan,
    #[cfg(feature = "metal")]
    Metal,
    #[cfg(feature = "wgpu")]
    Wgpu,
}
impl GpuMemory {
    /// Wraps an existing allocation in a `GpuMemory` handle.
    ///
    /// Does not allocate and does not take ownership; `ptr` must have been
    /// produced by the backend identified by `backend_type`.
    // `#[must_use]` added for consistency with `GpuStream::new` and the
    // getters below — constructing and discarding a handle is always a bug.
    #[must_use]
    pub fn new(ptr: *mut u8, size: usize, device_index: usize, backend_type: BackendType) -> Self {
        Self {
            ptr,
            size,
            device_index,
            backend_type,
        }
    }

    /// Raw pointer to the start of the allocation.
    #[must_use]
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Allocation size in bytes.
    #[must_use]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Index of the device the memory lives on.
    #[must_use]
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Backend that produced this allocation.
    #[must_use]
    pub fn backend_type(&self) -> BackendType {
        self.backend_type
    }
}
/// Lightweight handle to a backend command stream / queue.
#[derive(Debug)]
pub struct GpuStream {
    // Opaque backend-specific stream identifier — presumably a native stream
    // pointer or queue id cast to usize; confirm against each backend impl.
    handle: usize,
    // Index of the device the stream belongs to.
    device_index: usize,
    // Backend that created the stream.
    backend_type: BackendType,
}
impl GpuStream {
    /// Wraps an existing backend stream handle; does not create a stream.
    #[must_use]
    pub fn new(handle: usize, device_index: usize, backend_type: BackendType) -> Self {
        Self {
            handle,
            device_index,
            backend_type,
        }
    }

    /// Opaque backend-specific stream handle.
    #[must_use]
    pub fn handle(&self) -> usize {
        self.handle
    }

    /// Index of the device the stream belongs to.
    #[must_use]
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Backend that created this stream (added for API parity with
    /// `GpuMemory::backend_type`).
    #[must_use]
    pub fn backend_type(&self) -> BackendType {
        self.backend_type
    }

    /// Blocks until all work submitted to this stream has completed.
    /// A no-op for the CPU backend.
    pub fn synchronize(&self) {
        match self.backend_type {
            // Fixed formatting: the cuda `#[cfg]` attribute was jammed onto
            // the same line as this arm, making it easy to misread its target.
            BackendType::Cpu => {}
            #[cfg(feature = "cuda")]
            BackendType::Cuda => cuda::stream_synchronize(self.handle),
            #[cfg(feature = "vulkan")]
            BackendType::Vulkan => vulkan::queue_wait_idle(self.handle),
            #[cfg(feature = "metal")]
            BackendType::Metal => metal::command_buffer_wait(self.handle),
            // NOTE(review): submitting in a synchronize looks odd — confirm
            // `wgpu_backend::queue_submit` actually blocks until idle.
            #[cfg(feature = "wgpu")]
            BackendType::Wgpu => wgpu_backend::queue_submit(self.handle),
        }
    }
}
/// Picks the preferred compiled-in backend that is usable on this machine.
///
/// Preference order: CUDA, then Metal, then Vulkan, then wgpu; falls back to
/// the always-available CPU backend when nothing else is present.
#[must_use]
pub fn best_available_backend() -> BackendType {
    #[cfg(feature = "cuda")]
    {
        if cuda::is_available() {
            return BackendType::Cuda;
        }
    }
    #[cfg(feature = "metal")]
    {
        if metal::is_available() {
            return BackendType::Metal;
        }
    }
    #[cfg(feature = "vulkan")]
    {
        if vulkan::is_available() {
            return BackendType::Vulkan;
        }
    }
    #[cfg(feature = "wgpu")]
    {
        if wgpu_backend::is_available() {
            return BackendType::Wgpu;
        }
    }
    // CPU is the universal fallback.
    BackendType::Cpu
}
/// Total number of GPU devices visible across all compiled-in backends.
///
/// Returns 0 when no GPU feature is enabled. A physical device exposed by
/// more than one enabled backend is counted once per backend.
#[must_use]
pub fn gpu_count() -> usize {
    // `mut` is unused when no GPU feature is compiled in.
    #[allow(unused_mut)]
    let mut total = 0_usize;
    #[cfg(feature = "cuda")]
    {
        total += cuda::device_count();
    }
    #[cfg(feature = "vulkan")]
    {
        total += vulkan::device_count();
    }
    #[cfg(feature = "metal")]
    {
        total += metal::device_count();
    }
    #[cfg(feature = "wgpu")]
    {
        total += wgpu_backend::device_count();
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    // A null-pointer handle must keep all of its metadata intact.
    #[test]
    fn test_gpu_memory_creation() {
        let mem = GpuMemory::new(std::ptr::null_mut(), 1024, 0, BackendType::Cpu);
        assert!(mem.ptr().is_null());
        assert_eq!(mem.backend_type(), BackendType::Cpu);
        assert_eq!(mem.device_index(), 0);
        assert_eq!(mem.size(), 1024);
    }

    // A real host pointer round-trips unchanged through the handle.
    #[test]
    fn test_gpu_memory_nonzero_ptr() {
        let mut backing = vec![0u8; 256];
        let raw = backing.as_mut_ptr();
        let mem = GpuMemory::new(raw, 256, 0, BackendType::Cpu);
        assert_eq!(mem.size(), 256);
        assert_eq!(mem.ptr(), raw);
    }

    #[test]
    fn test_gpu_stream_creation() {
        let stream = GpuStream::new(42, 0, BackendType::Cpu);
        assert_eq!(stream.device_index(), 0);
        assert_eq!(stream.handle(), 42);
    }

    // CPU synchronize is a no-op; just make sure it doesn't panic.
    #[test]
    fn test_gpu_stream_cpu_sync() {
        GpuStream::new(0, 0, BackendType::Cpu).synchronize();
    }

    #[test]
    fn test_backend_type_equality() {
        assert_eq!(BackendType::Cpu, BackendType::Cpu);
        #[cfg(feature = "cuda")]
        assert_ne!(BackendType::Cpu, BackendType::Cuda);
    }

    // Smoke test: must return without panicking whatever features are on.
    #[test]
    fn test_best_available_backend() {
        let _ = best_available_backend();
    }

    #[test]
    fn test_gpu_count() {
        let count = gpu_count();
        assert!(count < 1000, "Sanity check: unreasonable GPU count");
    }

    #[test]
    fn test_cpu_backend_is_available() {
        let cpu = CpuBackend::new();
        assert_eq!(cpu.name(), "cpu");
        assert!(cpu.is_available());
    }

    #[test]
    fn test_cpu_backend_allocate_deallocate() {
        let cpu = CpuBackend::new();
        let block = cpu.allocate(256);
        assert!(!block.is_null());
        cpu.deallocate(block, 256);
    }

    // Zero-byte requests yield a null pointer rather than a real block.
    #[test]
    fn test_cpu_backend_zero_alloc() {
        let cpu = CpuBackend::new();
        assert!(cpu.allocate(0).is_null());
    }

    // Host -> "device" -> host copy through the CPU backend is lossless.
    #[test]
    fn test_cpu_backend_copy_round_trip() {
        let cpu = CpuBackend::new();
        let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let device_buf = cpu.allocate(16);
        cpu.copy_to_device(device_buf, input.as_ptr().cast::<u8>(), 16);
        let mut output = [0.0f32; 4];
        cpu.copy_to_host(output.as_mut_ptr().cast::<u8>(), device_buf.cast_const(), 16);
        assert_eq!(output, [1.0, 2.0, 3.0, 4.0]);
        cpu.deallocate(device_buf, 16);
    }

    #[test]
    fn test_cpu_backend_capabilities() {
        let caps = CpuBackend::new().capabilities();
        assert!(caps.total_memory > 0);
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
    }
}