use crate::device::DeviceCapabilities;
pub mod cpu;
pub mod cuda;
#[cfg(feature = "cuda")]
pub mod cuda_kernels;
pub mod cuda_pool;
#[cfg(feature = "cudnn")]
pub mod cudnn_ops;
#[cfg(feature = "vulkan")]
pub mod vulkan;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(feature = "wgpu")]
pub mod wgpu_backend;
pub mod gpu_tests;
pub use cpu::CpuBackend;
pub use cuda::CudaBackend;
#[cfg(feature = "vulkan")]
pub use vulkan::VulkanBackend;
#[cfg(feature = "metal")]
pub use metal::MetalBackend;
#[cfg(feature = "wgpu")]
pub use wgpu_backend::WgpuBackend;
/// Interface for a compute backend: identification, availability and
/// capability queries, raw memory management, copies and synchronization.
///
/// Implementors must be `Send + Sync`. Memory is handled through raw byte
/// pointers and no method returns a `Result`, so any failure signalling is
/// left to the implementation.
///
/// NOTE(review): the per-method descriptions below are inferred from the
/// signatures alone; confirm the exact contract against the implementing
/// backends.
pub trait Backend: Send + Sync {
/// Returns a static name identifying this backend.
fn name(&self) -> &'static str;
/// Reports whether this backend can be used.
fn is_available(&self) -> bool;
/// Returns the [`DeviceCapabilities`] reported by this backend.
fn capabilities(&self) -> DeviceCapabilities;
/// Allocates memory for `size` and returns a raw pointer to it.
///
/// NOTE(review): `size` is presumably a byte count and failure is presumably
/// signalled by a null pointer — confirm with the implementors.
fn allocate(&self, size: usize) -> *mut u8;
/// Releases the memory at `ptr`.
///
/// NOTE(review): `ptr` and `size` presumably have to match an earlier
/// `allocate` call on the same backend — TODO confirm.
fn deallocate(&self, ptr: *mut u8, size: usize);
/// Copies `size` (presumably bytes) from `src` to `dst`; going by the name,
/// `src` is host memory and `dst` is device memory — TODO confirm.
fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize);
/// Copies `size` (presumably bytes) from `src` to `dst`; going by the name,
/// `src` is device memory and `dst` is host memory — TODO confirm.
fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize);
/// Copies `size` (presumably bytes) from `src` to `dst`; going by the name,
/// both pointers refer to device memory — TODO confirm.
fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize);
/// Synchronizes with the backend.
///
/// NOTE(review): presumably blocks until previously issued work has
/// finished — confirm with the implementors.
fn synchronize(&self);
}
/// Record of a raw memory region: its pointer, size, device index and the
/// backend it is associated with.
///
/// NOTE(review): no `Drop` impl is visible in this part of the file, so this
/// struct presumably does not release `ptr` on its own — confirm who is
/// responsible for freeing the region.
#[derive(Debug)]
pub struct GpuMemory {
/// Raw pointer to the region.
ptr: *mut u8,
/// Size of the region — presumably in bytes; TODO confirm.
size: usize,
/// Index of the device the region is associated with.
device_index: usize,
/// Backend the region is associated with.
backend_type: BackendType,
}
/// Identifies which compute backend a resource belongs to.
///
/// `Cpu` is always present. Every other variant exists only when the Cargo
/// feature of the same name is enabled, so any `match` over this enum has to
/// gate the corresponding arms with the same `cfg` attribute.
///
/// The enum is a plain, field-less tag, so it derives the full set of value
/// traits — including `Hash`, which allows it to be used as a key in
/// `HashMap` / `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendType {
    /// Host CPU backend; available in every build configuration.
    Cpu,
    /// CUDA backend (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
    /// Vulkan backend (requires the `vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Metal backend (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal,
    /// wgpu backend (requires the `wgpu` feature).
    #[cfg(feature = "wgpu")]
    Wgpu,
}
impl GpuMemory {
    /// Wraps an existing raw memory region in a `GpuMemory` record.
    ///
    /// The arguments are stored as given; nothing is allocated, validated or
    /// dereferenced here.
    ///
    /// Marked `#[must_use]` like `GpuStream::new` and the accessors below:
    /// building the record and discarding it has no effect.
    #[must_use]
    pub fn new(ptr: *mut u8, size: usize, device_index: usize, backend_type: BackendType) -> Self {
        Self {
            ptr,
            size,
            device_index,
            backend_type,
        }
    }

    /// Returns the raw pointer supplied at construction.
    #[must_use]
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Returns the size supplied at construction.
    #[must_use]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the index of the device this region is associated with.
    #[must_use]
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Returns the backend this region is associated with.
    #[must_use]
    pub fn backend_type(&self) -> BackendType {
        self.backend_type
    }
}
/// Handle to a backend-specific stream or queue, tagged with the device index
/// and backend it belongs to.
#[derive(Debug)]
pub struct GpuStream {
/// Opaque backend handle; it is passed unchanged to the backend module's
/// wait routine by `synchronize`.
handle: usize,
/// Index of the device the stream is associated with.
device_index: usize,
/// Backend that `handle` belongs to; selects the routine used by `synchronize`.
backend_type: BackendType,
}
impl GpuStream {
#[must_use]
pub fn new(handle: usize, device_index: usize, backend_type: BackendType) -> Self {
Self {
handle,
device_index,
backend_type,
}
}
#[must_use]
pub fn handle(&self) -> usize {
self.handle
}
#[must_use]
pub fn device_index(&self) -> usize {
self.device_index
}
pub fn synchronize(&self) {
match self.backend_type {
BackendType::Cpu => {} #[cfg(feature = "cuda")]
BackendType::Cuda => cuda::stream_synchronize(self.handle),
#[cfg(feature = "vulkan")]
BackendType::Vulkan => vulkan::queue_wait_idle(self.handle),
#[cfg(feature = "metal")]
BackendType::Metal => metal::command_buffer_wait(self.handle),
#[cfg(feature = "wgpu")]
BackendType::Wgpu => wgpu_backend::queue_submit(self.handle),
}
}
}
/// Picks the preferred backend among those compiled in and reporting
/// themselves as available.
///
/// Backends are probed in a fixed order — CUDA, Metal, Vulkan, wgpu — and the
/// first one whose `is_available()` returns `true` wins. Probes for backends
/// whose Cargo feature is disabled are compiled out entirely. If nothing
/// matches, the CPU backend is returned.
#[must_use]
pub fn best_available_backend() -> BackendType {
    #[cfg(feature = "cuda")]
    {
        let cuda_ready = cuda::is_available();
        if cuda_ready {
            return BackendType::Cuda;
        }
    }

    #[cfg(feature = "metal")]
    {
        let metal_ready = metal::is_available();
        if metal_ready {
            return BackendType::Metal;
        }
    }

    #[cfg(feature = "vulkan")]
    {
        let vulkan_ready = vulkan::is_available();
        if vulkan_ready {
            return BackendType::Vulkan;
        }
    }

    #[cfg(feature = "wgpu")]
    {
        let wgpu_ready = wgpu_backend::is_available();
        if wgpu_ready {
            return BackendType::Wgpu;
        }
    }

    // No GPU backend compiled in or available: fall back to the host.
    BackendType::Cpu
}
/// Returns the sum of the device counts reported by every GPU backend that is
/// compiled in (CUDA, Vulkan, Metal, wgpu — queried in that order).
///
/// With no GPU feature enabled the list below is empty and the result is `0`.
///
/// NOTE(review): the per-backend counts are simply added, so a device that is
/// visible through more than one enabled backend would be counted once per
/// backend — confirm that callers expect this.
#[must_use]
pub fn gpu_count() -> usize {
    // Entries for disabled features are removed at compile time by `cfg`.
    let per_backend: &[usize] = &[
        #[cfg(feature = "cuda")]
        cuda::device_count(),
        #[cfg(feature = "vulkan")]
        vulkan::device_count(),
        #[cfg(feature = "metal")]
        metal::device_count(),
        #[cfg(feature = "wgpu")]
        wgpu_backend::device_count(),
    ];
    per_backend.iter().sum()
}