vyre-wgpu 0.1.0

wgpu backend for the vyre IR: implements VyreBackend and owns the GPU runtime, buffer pool, and pipeline cache.
//! Reusable GPU buffer pool keyed by device, size class, and usage flags.
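//!
//! Buffers are keyed by device identity, a size class rounded up to
//! `wgpu::COPY_BUFFER_ALIGNMENT`, and exact usage bits. Dropping a
//! [`PooledBuffer`] recycles it into the pool, which retains at most
//! `MAX_BUFFERS_PER_CLASS` idle buffers per key.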

use std::collections::HashMap;
use std::ops::Deref;
use std::sync::{Arc, LazyLock, Mutex, Weak};
use vyre::error::{Error, Result};

/// Maximum idle buffers retained per (device, size class, usage) key; buffers
/// released beyond this cap are dropped so the pool's GPU memory stays bounded.
const MAX_BUFFERS_PER_CLASS: usize = 8;

/// Pool key: buffers are shared only within the same device, size class, and
/// exact usage bits.
///
/// Note: storing `wgpu::Device` here relies on a wgpu release where `Device`
/// is `Clone + PartialEq + Eq + Hash` (identity-based, wgpu 23+).
#[derive(Clone, Hash, PartialEq, Eq)]
struct BufferKey {
    device: wgpu::Device,
    size_class: u64,
    usage_bits: u32,
}

#[derive(Default)]
struct BufferPoolInner {
    /// Idle buffers ready for reuse, grouped by pool key.
    buffers: Mutex<HashMap<BufferKey, Vec<wgpu::Buffer>>>,
}

/// Device-aware reusable GPU buffer pool.
///
/// Cloning is cheap; all clones share the same underlying pool via [`Arc`].
#[derive(Clone, Default)]
pub struct BufferPool {
    inner: Arc<BufferPoolInner>,
}

/// Buffer handle that returns to its originating [`BufferPool`] on drop.
///
/// If the pool has already been dropped, the buffer is destroyed instead of
/// being recycled.
pub struct PooledBuffer {
    key: BufferKey,
    buffer: Option<wgpu::Buffer>,
    pool: Weak<BufferPoolInner>,
}

impl BufferPool {
    /// Return the process-wide buffer pool.
    #[must_use]
    #[inline]
    pub fn global() -> &'static Self {
        static POOL: LazyLock<BufferPool> = LazyLock::new(BufferPool::new);
        &POOL
    }

    /// Create an empty buffer pool.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(BufferPoolInner::default()),
        }
    }

    /// Acquire a reusable buffer with at least `size` bytes and exactly the
    /// requested `usage` flags.
    ///
    /// Reused buffers are handed back as-is: their contents are whatever the
    /// previous user left behind, so overwrite before reading.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Gpu`] when the pool's internal mutex is poisoned.
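    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming this type is exposed at the crate root as
    /// `vyre_wgpu::BufferPool` and that a device and queue already exist
    /// (wgpu setup elided):
    ///
    /// ```no_run
    /// # fn demo(device: &wgpu::Device, queue: &wgpu::Queue, bytes: &[u8]) -> vyre::error::Result<()> {
    /// let buffer = vyre_wgpu::BufferPool::global().acquire(
    ///     device,
    ///     "upload",
    ///     bytes.len() as u64,
    ///     wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::STORAGE,
    /// )?;
    /// queue.write_buffer(&buffer, 0, bytes); // `Deref` yields `&wgpu::Buffer`
    /// # Ok(())
    /// # }
    /// ```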
    #[inline]
    pub fn acquire(
        &self,
        device: &wgpu::Device,
        label: &str,
        size: u64,
        usage: wgpu::BufferUsages,
    ) -> Result<PooledBuffer> {
        let key = BufferKey {
            device: device.clone(),
            size_class: size_class(size),
            usage_bits: usage.bits(),
        };
        let mut buffers = self.inner.buffers.lock().map_err(|source| Error::Gpu {
            message: format!(
                "GPU buffer pool mutex poisoned: {source}. Fix: restart the process or avoid panicking while holding the buffer pool lock."
            ),
        })?;
        let buffer = buffers
            .entry(key.clone())
            .or_default()
            .pop()
            .unwrap_or_else(|| {
                device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some(label),
                    size: key.size_class,
                    usage,
                    mapped_at_creation: false,
                })
            });
        Ok(PooledBuffer {
            key,
            buffer: Some(buffer),
            pool: Arc::downgrade(&self.inner),
        })
    }

    /// Release a buffer back to the pool immediately; equivalent to dropping it.
    #[inline]
    pub fn release(&self, buffer: PooledBuffer) {
        drop(buffer);
    }

    /// Acquire a buffer for the duration of `f` and release it afterward.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Gpu`] when buffer acquisition fails.
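    ///
    /// # Examples
    ///
    /// A minimal sketch under the same assumptions as [`BufferPool::acquire`]:
    ///
    /// ```no_run
    /// # fn demo(pool: &vyre_wgpu::BufferPool, device: &wgpu::Device) -> vyre::error::Result<()> {
    /// let len = pool.with_buffer(device, "scratch", 256, wgpu::BufferUsages::STORAGE, |buf| {
    ///     buf.size() // already the rounded size class, here exactly 256
    /// })?;
    /// assert_eq!(len, 256);
    /// # Ok(())
    /// # }
    /// ```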
    #[inline]
    pub fn with_buffer<R>(
        &self,
        device: &wgpu::Device,
        label: &str,
        size: u64,
        usage: wgpu::BufferUsages,
        f: impl FnOnce(&wgpu::Buffer) -> R,
    ) -> Result<R> {
        let buffer = self.acquire(device, label, size, usage)?;
        let result = f(&buffer);
        self.release(buffer);
        Ok(result)
    }
}

impl PooledBuffer {
    /// Return the allocated size in bytes (the rounded size class, which may
    /// exceed the size originally requested).
    #[must_use]
    #[inline]
    pub fn size(&self) -> u64 {
        self.key.size_class
    }

    /// Return the inner `wgpu::Buffer`.
    #[must_use]
    #[inline]
    pub fn buffer(&self) -> &wgpu::Buffer {
        self.buffer
            .as_ref()
            .expect("pooled buffer missing inner buffer. Fix: do not use PooledBuffer after drop.")
    }
}

impl Deref for PooledBuffer {
    type Target = wgpu::Buffer;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.buffer()
    }
}

impl Drop for PooledBuffer {
    fn drop(&mut self) {
        // Defensive: the inner buffer is present unless it was already taken.
        let Some(buffer) = self.buffer.take() else {
            return;
        };
        // If the pool is gone, let wgpu destroy the buffer.
        let Some(pool) = self.pool.upgrade() else {
            return;
        };
        // Never panic in drop: on a poisoned lock, silently destroy the buffer.
        let Ok(mut buffers) = pool.buffers.lock() else {
            return;
        };
        let class = buffers.entry(self.key.clone()).or_default();
        // Bound per-class retention; buffers over the cap are dropped and freed.
        if class.len() < MAX_BUFFERS_PER_CLASS {
            class.push(buffer);
        }
    }
}

/// Round `size` up to the next multiple of `wgpu::COPY_BUFFER_ALIGNMENT`,
/// mapping zero-sized requests to one alignment unit so every pooled buffer
/// has a valid, copy-aligned size.
fn size_class(size: u64) -> u64 {
    size.max(wgpu::COPY_BUFFER_ALIGNMENT)
        .next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT)
}

#[cfg(test)]
mod tests {
    use super::size_class;

    #[test]
    fn size_class_is_copy_aligned_and_nonzero() {
        assert_eq!(size_class(0), wgpu::COPY_BUFFER_ALIGNMENT);
        assert_eq!(size_class(1), wgpu::COPY_BUFFER_ALIGNMENT);
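        // Exact multiples of the alignment are already valid size classes.
        assert_eq!(
            size_class(2 * wgpu::COPY_BUFFER_ALIGNMENT),
            2 * wgpu::COPY_BUFFER_ALIGNMENT
        );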
        assert_eq!(
            size_class(wgpu::COPY_BUFFER_ALIGNMENT + 1),
            (wgpu::COPY_BUFFER_ALIGNMENT + 1).next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT)
        );
    }
}