scry-gpu 0.1.0

Lightweight GPU compute — dispatch shaders without the graphics baggage.
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Batched dispatch — multiple dispatches in a single GPU submission.
//!
//! A [`Batch`] records multiple kernel dispatches into one command buffer,
//! then submits them all with a single fence wait. This eliminates the
//! per-dispatch synchronization overhead that dominates bandwidth-bound
//! workloads.
//!
//! # Example
//!
//! ```ignore
//! let mut batch = gpu.batch()?;
//! batch.run(&kernel, &[&input, &pass1], n)?;
//! batch.barrier();  // ensure pass1 finishes before pass2 reads it
//! batch.run(&kernel, &[&pass1, &pass2], pass1_n)?;
//! batch.submit()?;
//! ```

use crate::backend::BackendBuffer;
use crate::buffer::GpuBuf;
use crate::dispatch;
use crate::error::{GpuError, Result};
use crate::kernel::Kernel;
use crate::ticket::Ticket;

/// A batch of dispatches recorded into a single command buffer.
///
/// Created via [`Device::batch`](crate::Device::batch).
/// Use [`barrier`](Batch::barrier) between dispatches that have data
/// dependencies (where one dispatch reads from another's output).
pub struct Batch {
    // Backend-specific recording state; every method dispatches on this.
    inner: BatchInner,
}

/// Backend-specific batch state. Only the variants for the backend
/// features enabled at compile time exist, so a batch can never be
/// constructed for a backend that was not built in.
enum BatchInner {
    #[cfg(feature = "vulkan")]
    Vulkan(crate::backend::vulkan::VulkanBatch),
    #[cfg(feature = "cuda")]
    Cuda(crate::backend::cuda::CudaBatch),
}

impl Batch {
    /// Wrap a backend-specific Vulkan batch in the public [`Batch`] type.
    #[cfg(feature = "vulkan")]
    pub(crate) const fn new_vulkan(vk_batch: crate::backend::vulkan::VulkanBatch) -> Self {
        Self {
            inner: BatchInner::Vulkan(vk_batch),
        }
    }

    /// Wrap a backend-specific CUDA batch in the public [`Batch`] type.
    #[cfg(feature = "cuda")]
    pub(crate) const fn new_cuda(cuda_batch: crate::backend::cuda::CudaBatch) -> Self {
        Self {
            inner: BatchInner::Cuda(cuda_batch),
        }
    }

    /// Record a kernel dispatch with auto-calculated workgroups.
    ///
    /// The workgroup count is derived from `invocations` and the kernel's
    /// workgroup size via [`dispatch::calc_dispatch`].
    ///
    /// # Errors
    ///
    /// See [`run_configured`](Batch::run_configured).
    pub fn run(
        &mut self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
    ) -> Result<&mut Self> {
        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_configured(kernel, buffers, workgroups, None)
    }

    /// Record a kernel dispatch with push constants.
    ///
    /// # Errors
    ///
    /// See [`run_configured`](Batch::run_configured).
    pub fn run_with_push_constants(
        &mut self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        invocations: u32,
        push_constants: &[u8],
    ) -> Result<&mut Self> {
        let workgroups = dispatch::calc_dispatch(invocations, kernel.workgroup_size);
        self.run_configured(kernel, buffers, workgroups, Some(push_constants))
    }

    /// Record a kernel dispatch with explicit workgroups and optional push constants.
    ///
    /// # Errors
    ///
    /// * [`GpuError::BindingMismatch`] if `buffers.len()` differs from the
    ///   kernel's binding count.
    /// * [`GpuError::BackendUnavailable`] if the kernel or any buffer was
    ///   created for a different backend than this batch.
    pub fn run_configured(
        &mut self,
        kernel: &Kernel,
        buffers: &[&dyn GpuBuf],
        workgroups: [u32; 3],
        push_constants: Option<&[u8]>,
    ) -> Result<&mut Self> {
        // Validate the binding count up front, before doing any per-buffer
        // work — no point collecting backend handles for a doomed dispatch.
        if kernel.binding_count != buffers.len() {
            return Err(GpuError::BindingMismatch {
                expected: kernel.binding_count,
                got: buffers.len(),
            });
        }
        let backend_bufs: Vec<&BackendBuffer> = buffers.iter().map(|b| b.raw()).collect();

        match &mut self.inner {
            #[cfg(feature = "vulkan")]
            BatchInner::Vulkan(vk_batch) => {
                // With only the Vulkan feature enabled, `BackendKernel` has a
                // single variant and this let-else pattern is irrefutable.
                #[allow(irrefutable_let_patterns)]
                let crate::backend::BackendKernel::Vulkan(vk_kernel) = &kernel.inner
                else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for Vulkan".into(),
                    ));
                };
                let vk_bufs: Vec<&crate::backend::vulkan::VulkanBuffer> = backend_bufs
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Vulkan(vb) => Ok(vb),
                        #[cfg(feature = "cuda")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected Vulkan buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                vk_batch.record_dispatch(vk_kernel, &vk_bufs, workgroups, push_constants)?;
            }
            #[cfg(feature = "cuda")]
            BatchInner::Cuda(cuda_batch) => {
                // Same as the Vulkan arm: the pattern is irrefutable when CUDA
                // is the only backend compiled in, so silence that lint here
                // too (keeps cuda-only builds warning-free).
                #[allow(irrefutable_let_patterns)]
                let crate::backend::BackendKernel::Cuda(cuda_kernel) = &kernel.inner else {
                    return Err(GpuError::BackendUnavailable(
                        "kernel was not compiled for CUDA".into(),
                    ));
                };
                let cuda_bufs: Vec<&crate::backend::cuda::CudaBuffer> = backend_bufs
                    .iter()
                    .map(|buf| match buf {
                        BackendBuffer::Cuda(cb) => Ok(cb),
                        #[cfg(feature = "vulkan")]
                        _ => Err(GpuError::BackendUnavailable(
                            "buffer/backend mismatch: expected CUDA buffer".into(),
                        )),
                    })
                    .collect::<Result<Vec<_>>>()?;
                cuda_batch.record_dispatch(cuda_kernel, &cuda_bufs, workgroups, push_constants)?;
            }
        }

        Ok(self)
    }

    /// Insert a compute-to-compute barrier.
    ///
    /// Use this between dispatches where a later dispatch reads from an
    /// earlier dispatch's output buffer. Without a barrier, the GPU may
    /// execute dispatches out of order or overlap writes with reads.
    pub fn barrier(&mut self) -> &mut Self {
        match &mut self.inner {
            #[cfg(feature = "vulkan")]
            BatchInner::Vulkan(vk_batch) => vk_batch.record_barrier(),
            #[cfg(feature = "cuda")]
            BatchInner::Cuda(cuda_batch) => cuda_batch.record_barrier(),
        }
        self
    }

    /// Submit all recorded dispatches and wait for completion.
    ///
    /// All dispatches execute in a single command buffer with one fence wait,
    /// eliminating per-dispatch synchronization overhead.
    ///
    /// Equivalent to `self.submit_async()?.wait()`.
    ///
    /// # Errors
    ///
    /// Propagates any backend submission or wait failure.
    pub fn submit(self) -> Result<()> {
        self.submit_async()?.wait()
    }

    /// Submit all recorded dispatches and return a [`Ticket`] for
    /// non-blocking completion tracking.
    ///
    /// The GPU work is queued immediately. Use [`Ticket::wait`] to block
    /// until completion, or [`Ticket::is_ready`] to poll.
    ///
    /// # Errors
    ///
    /// Propagates any backend submission failure.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut batch = gpu.batch()?;
    /// batch.run(&kernel, &[&input, &output], n)?;
    /// let ticket = batch.submit_async()?;
    ///
    /// // ... CPU work while GPU runs ...
    ///
    /// ticket.wait()?;
    /// let result: Vec<f32> = output.download()?;
    /// ```
    pub fn submit_async(self) -> Result<Ticket> {
        match self.inner {
            #[cfg(feature = "vulkan")]
            BatchInner::Vulkan(vk_batch) => Ok(Ticket::new_vulkan(vk_batch.submit_async()?)),
            #[cfg(feature = "cuda")]
            BatchInner::Cuda(cuda_batch) => Ok(Ticket::new_cuda(cuda_batch.submit_async()?)),
        }
    }
}