#[cfg(feature = "vulkan")]
pub mod vulkan;
#[cfg(feature = "cuda")]
pub mod cuda;
use crate::error::Result;
/// Common interface implemented by each backend in this module.
///
/// A backend chooses its own buffer and pipeline representations through the
/// associated types. Every fallible operation reports failure through the
/// crate's [`Result`] alias.
pub trait Backend: Sized {
    /// Buffer type returned by [`upload`](Self::upload), [`alloc`](Self::alloc)
    /// and [`copy_buffer`](Self::copy_buffer). It must implement
    /// [`BackendBufferOps`].
    type Buffer: BackendBufferOps;

    /// Pipeline type returned by [`create_pipeline`](Self::create_pipeline) and
    /// accepted by [`dispatch_pipeline`](Self::dispatch_pipeline).
    type Pipeline;

    /// Builds a backend instance; takes no arguments.
    ///
    /// # Errors
    /// Returns an error when the implementation cannot construct itself.
    fn create() -> Result<Self>;

    /// Creates a buffer from the host bytes in `data`.
    ///
    /// NOTE(review): presumably the bytes are copied to device memory —
    /// confirm against the implementations.
    fn upload(&self, data: &[u8]) -> Result<Self::Buffer>;

    /// Creates a buffer for the requested `size`.
    ///
    /// NOTE(review): `size` is presumably a byte count, and the initial
    /// contents are not specified here — confirm against the implementations.
    fn alloc(&self, size: u64) -> Result<Self::Buffer>;

    /// Runs a kernel in a single call, without a pre-built pipeline.
    ///
    /// * `spirv` - shader code as 32-bit words (presumably a SPIR-V module).
    /// * `entry_point` - name of the entry point to run.
    /// * `buffers` - buffers made available to the kernel, in slice order.
    /// * `workgroups` - three dispatch dimensions (presumably x, y, z
    ///   workgroup counts).
    /// * `push_constants` - optional raw bytes passed alongside the dispatch.
    fn dispatch(
        &self,
        spirv: &[u32],
        entry_point: &str,
        buffers: &[&Self::Buffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()>;

    /// Builds a reusable [`Pipeline`](Self::Pipeline) from shader code.
    ///
    /// * `spirv` - shader code as 32-bit words (presumably a SPIR-V module).
    /// * `entry_point` - name of the entry point the pipeline will run.
    /// * `binding_count` - number of buffer bindings the pipeline expects
    ///   (presumably the length of `buffers` passed at dispatch time).
    /// * `push_constant_size` - size reserved for push constants (presumably
    ///   in bytes).
    fn create_pipeline(
        &self,
        spirv: &[u32],
        entry_point: &str,
        binding_count: usize,
        push_constant_size: u32,
    ) -> Result<Self::Pipeline>;

    /// Runs a pipeline previously obtained from
    /// [`create_pipeline`](Self::create_pipeline).
    ///
    /// `buffers`, `workgroups` and `push_constants` have the same shape as the
    /// matching parameters of [`dispatch`](Self::dispatch).
    fn dispatch_pipeline(
        &self,
        pipeline: &Self::Pipeline,
        buffers: &[&Self::Buffer],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<()>;

    /// Name of the device, borrowed from the backend.
    fn device_name(&self) -> &str;

    /// Amount of device memory (presumably in bytes — TODO confirm the unit).
    fn device_memory(&self) -> u64;

    /// Subgroup size reported by the backend.
    fn subgroup_size(&self) -> u32;

    /// Produces a new buffer from `src`.
    ///
    /// NOTE(review): `size` is presumably the number of bytes to copy —
    /// confirm against the implementations.
    fn copy_buffer(&self, src: &Self::Buffer, size: u64) -> Result<Self::Buffer>;
}
/// Operations that every backend buffer type has to provide.
pub trait BackendBufferOps {
    /// Reads the buffer back into an owned byte vector.
    ///
    /// # Errors
    /// Returns an error when the implementation cannot produce the bytes.
    fn read_back(&self) -> Result<Vec<u8>>;

    /// Size of the buffer as a `u64` (presumably in bytes, going by the name).
    // NOTE(review): the `dead_code` allow suggests some build configurations
    // never call this method — confirm the allow is still required.
    #[allow(dead_code)]
    fn byte_size(&self) -> u64;
}
/// Buffer wrapper with one variant per backend compiled into the build.
///
/// A variant exists only while its Cargo feature is enabled, so the set of
/// variants depends on the feature selection.
pub enum BackendBuffer {
    /// Wraps the Vulkan backend's buffer type (`vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::VulkanBuffer),

    /// Wraps the CUDA backend's buffer type (`cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaBuffer),
}
/// Kernel wrapper with one variant per backend compiled into the build.
///
/// A variant exists only while its Cargo feature is enabled, so the set of
/// variants depends on the feature selection.
pub enum BackendKernel {
    /// Wraps the Vulkan backend's kernel type (`vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan(vulkan::VulkanKernel),

    /// Wraps the CUDA backend's kernel type (`cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaKernel),
}
/// Forwards each [`BackendBufferOps`] call to the buffer held by the active
/// variant.
///
/// Both methods match on `*self` rather than `self`. When neither backend
/// feature is enabled every arm is compiled out, and an arm-less `match` is
/// accepted only on the uninhabited enum itself, not on a reference to it.
/// Matching on the place keeps this impl compiling in that configuration.
impl BackendBufferOps for BackendBuffer {
    /// Reads back through whichever backend buffer this value wraps.
    ///
    /// # Errors
    /// Propagates the error returned by the wrapped buffer's `read_back`.
    fn read_back(&self) -> Result<Vec<u8>> {
        match *self {
            #[cfg(feature = "vulkan")]
            Self::Vulkan(ref buf) => buf.read_back(),
            #[cfg(feature = "cuda")]
            Self::Cuda(ref buf) => buf.read_back(),
        }
    }

    /// Returns the size reported by whichever backend buffer this value wraps.
    fn byte_size(&self) -> u64 {
        match *self {
            #[cfg(feature = "vulkan")]
            Self::Vulkan(ref buf) => buf.byte_size(),
            #[cfg(feature = "cuda")]
            Self::Cuda(ref buf) => buf.byte_size(),
        }
    }
}