//! Recording and submission of batched GPU kernel dispatches.

use crate::backend::BackendBuffer;
use crate::buffer::GpuBuf;
use crate::dispatch;
use crate::error::{GpuError, Result};
use crate::kernel::Kernel;
use crate::ticket::Ticket;
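
/// A recorded sequence of kernel dispatches that is submitted to the GPU as
/// a single unit of work.
///
/// Dispatches are recorded with [`Batch::run`] (or its variants), separated
/// by [`Batch::barrier`] where one dispatch must observe the results of a
/// previous one, and executed with [`Batch::submit`] or
/// [`Batch::submit_async`].
///
/// # Example
///
/// A minimal sketch; `device.create_batch()` and the kernel/buffer setup are
/// hypothetical stand-ins for the crate's actual construction API:
///
/// ```ignore
/// let mut batch = device.create_batch()?;
/// batch
///     .run(&scale_kernel, &[&input, &mid], 4096)?
///     .barrier()
///     .run(&sum_kernel, &[&mid, &output], 4096)?;
/// batch.submit()?;
/// ```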
pub struct Batch {
inner: BatchInner,
}
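
/// Backend-specific batch state; one variant per enabled backend feature.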
enum BatchInner {
#[cfg(feature = "vulkan")]
Vulkan(crate::backend::vulkan::VulkanBatch),
#[cfg(feature = "cuda")]
Cuda(crate::backend::cuda::CudaBatch),
}
impl Batch {
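/// Crate-internal constructor wrapping a Vulkan backend batch.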
#[cfg(feature = "vulkan")]
pub(crate) const fn new_vulkan(vk_batch: crate::backend::vulkan::VulkanBatch) -> Self {
Self {
inner: BatchInner::Vulkan(vk_batch),
}
}
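
/// Crate-internal constructor wrapping a CUDA backend batch.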
#[cfg(feature = "cuda")]
pub(crate) const fn new_cuda(cuda_batch: crate::backend::cuda::CudaBatch) -> Self {
Self {
inner: BatchInner::Cuda(cuda_batch),
}
}
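
/// Records a dispatch of `kernel` over the given buffers, computing the
/// workgroup count from `invocations` (the total invocation count) and the
/// kernel's workgroup size via `dispatch::calc_dispatch`.
///
/// Returns `&mut Self` so calls can be chained. A sketch (`batch`, `kernel`,
/// and the buffers are assumed to be set up elsewhere):
///
/// ```ignore
/// // One invocation per element of a 4096-element buffer.
/// batch.run(&kernel, &[&input, &output], 4096)?;
/// ```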
pub fn run(
&mut self,
kernel: &Kernel,
buffers: &[&dyn GpuBuf],
invocations: u32,
) -> Result<&mut Self> {
let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
self.run_configured(kernel, buffers, workgroups, None)
}
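
/// Like [`Batch::run`], but additionally passes `push_constants` to the
/// kernel as a raw byte slice. The byte layout must match whatever the
/// kernel declares.
///
/// A sketch, assuming the kernel expects a single `u32` constant:
///
/// ```ignore
/// let n: u32 = 4096;
/// batch.run_with_push_constants(&kernel, &[&buf], n, &n.to_ne_bytes())?;
/// ```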
pub fn run_with_push_constants(
&mut self,
kernel: &Kernel,
buffers: &[&dyn GpuBuf],
invocations: u32,
push_constants: &[u8],
) -> Result<&mut Self> {
let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
self.run_configured(kernel, buffers, workgroups, Some(push_constants))
}
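
/// Records a dispatch with an explicit `workgroups` count per dimension and
/// optional push constants; [`Batch::run`] and
/// [`Batch::run_with_push_constants`] both delegate here.
///
/// # Errors
///
/// Returns [`GpuError::BindingMismatch`] if `buffers.len()` differs from the
/// kernel's binding count, and [`GpuError::BackendUnavailable`] if the kernel
/// or any buffer was created for a different backend than this batch.
///
/// A sketch dispatching an 8x8 grid of workgroups with no push constants:
///
/// ```ignore
/// batch.run_configured(&kernel, &[&buf], [8, 8, 1], None)?;
/// ```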
pub fn run_configured(
&mut self,
kernel: &Kernel,
buffers: &[&dyn GpuBuf],
workgroups: [u32; 3],
push_constants: Option<&[u8]>,
) -> Result<&mut Self> {
// Validate the binding count before collecting backend buffers.
if kernel.binding_count != buffers.len() {
return Err(GpuError::BindingMismatch {
expected: kernel.binding_count,
got: buffers.len(),
});
}
let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();
match &mut self.inner {
#[cfg(feature = "vulkan")]
BatchInner::Vulkan(vk_batch) => {
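// Irrefutable when Vulkan is the only enabled backend, hence the allow.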
#[allow(irrefutable_let_patterns)]
let crate::backend::BackendKernel::Vulkan(vk_kernel) = &kernel.inner
else {
return Err(GpuError::BackendUnavailable(
"kernel was not compiled for Vulkan".into(),
));
};
let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = backend_bufs
.iter()
.map(|buf| match buf {
BackendBuffer::Vulkan(vb) => Ok(vb),
#[cfg(feature = "cuda")]
_ => Err(GpuError::BackendUnavailable(
"buffer/backend mismatch: expected Vulkan buffer".into(),
)),
})
.collect::<Result<Vec<_>>>()?;
vk_batch.record_dispatch(vk_kernel, &vk_bufs, workgroups, push_constants)?;
}
#[cfg(feature = "cuda")]
BatchInner::Cuda(cuda_batch) => {
// Irrefutable when CUDA is the only enabled backend, as in the Vulkan arm.
#[allow(irrefutable_let_patterns)]
let crate::backend::BackendKernel::Cuda(cuda_kernel) = &kernel.inner else {
return Err(GpuError::BackendUnavailable(
"kernel was not compiled for CUDA".into(),
));
};
let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = backend_bufs
.iter()
.map(|buf| match buf {
BackendBuffer::Cuda(cb) => Ok(cb),
#[cfg(feature = "vulkan")]
_ => Err(GpuError::BackendUnavailable(
"buffer/backend mismatch: expected CUDA buffer".into(),
)),
})
.collect::<Result<Vec<_>>>()?;
cuda_batch.record_dispatch(cuda_kernel, &cuda_bufs, workgroups, push_constants)?;
}
}
Ok(self)
}
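
/// Records an execution barrier between the dispatches recorded so far and
/// those recorded afterwards, so that later dispatches observe the writes of
/// earlier ones.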
pub fn barrier(&mut self) -> &mut Self {
match &mut self.inner {
#[cfg(feature = "vulkan")]
BatchInner::Vulkan(vk_batch) => vk_batch.record_barrier(),
#[cfg(feature = "cuda")]
BatchInner::Cuda(cuda_batch) => cuda_batch.record_barrier(),
}
self
}
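
/// Submits the batch and blocks until the GPU has finished executing it;
/// shorthand for `self.submit_async()?.wait()`.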
pub fn submit(self) -> Result<()> {
self.submit_async()?.wait()
}
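
/// Submits the batch without blocking and returns a [`Ticket`] that can be
/// waited on once the results are needed.
///
/// A sketch overlapping CPU work with GPU execution (`do_cpu_work` is
/// hypothetical):
///
/// ```ignore
/// let ticket = batch.submit_async()?;
/// do_cpu_work();
/// ticket.wait()?;
/// ```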
pub fn submit_async(self) -> Result<Ticket> {
match self.inner {
#[cfg(feature = "vulkan")]
BatchInner::Vulkan(vk_batch) => Ok(Ticket::new_vulkan(vk_batch.submit_async()?)),
#[cfg(feature = "cuda")]
BatchInner::Cuda(cuda_batch) => Ok(Ticket::new_cuda(cuda_batch.submit_async()?)),
}
}
}