use std::collections::HashMap;
use std::ops::Deref;
use std::sync::{Arc, LazyLock, Mutex, Weak};
use vyre::error::{Error, Result};
/// Cap on cached buffers per (device, size-class, usage) bucket; a buffer
/// released into a full bucket is dropped instead of pooled.
const MAX_BUFFERS_PER_CLASS: usize = 8;
/// Identity of a pool bucket: buffers are only reused across requests made on
/// the same device with the same rounded size class and identical usage flags.
// NOTE(review): the derives rely on `wgpu::Device` implementing
// Clone/Hash/PartialEq/Eq — present in recent wgpu releases; confirm against
// the pinned wgpu version.
#[derive(Clone, Hash, PartialEq, Eq)]
struct BufferKey {
    // Owned device handle cloned per key; pooled buckets therefore hold a
    // reference to their device for as long as they exist in the map.
    device: wgpu::Device,
    // Allocation size rounded up by `size_class()` — always a nonzero
    // multiple of `wgpu::COPY_BUFFER_ALIGNMENT`.
    size_class: u64,
    // Raw `wgpu::BufferUsages` bits for this bucket.
    usage_bits: u32,
}
/// Shared pool state: buckets of reusable buffers, keyed by [`BufferKey`].
#[derive(Default)]
struct BufferPoolInner {
    // Mutex-protected bucket map; lock failures surface as `Error::Gpu` in
    // `acquire` and are silently tolerated on the drop/release path.
    buffers: Mutex<HashMap<BufferKey, Vec<wgpu::Buffer>>>,
}
/// Cheaply clonable handle to a buffer pool; all clones share one inner state
/// via `Arc`.
#[derive(Clone, Default)]
pub struct BufferPool {
    inner: Arc<BufferPoolInner>,
}
/// RAII handle to a pooled `wgpu::Buffer`; on drop, the buffer is returned to
/// its pool bucket (if the pool still exists and the bucket has room).
pub struct PooledBuffer {
    // Bucket this buffer came from and will be returned to.
    key: BufferKey,
    // Invariant: `Some` for the handle's whole life; taken exactly once in
    // `Drop`, so accessors `expect` on it.
    buffer: Option<wgpu::Buffer>,
    // Weak reference so an outstanding handle does not keep the pool alive.
    pool: Weak<BufferPoolInner>,
}
impl BufferPool {
    /// Returns the process-wide shared pool, created lazily on first use.
    #[must_use]
    #[inline]
    pub fn global() -> &'static Self {
        static POOL: LazyLock<BufferPool> = LazyLock::new(BufferPool::new);
        &POOL
    }

    /// Creates an empty pool with no cached buffers.
    #[must_use]
    #[inline]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(BufferPoolInner::default()),
        }
    }

    /// Acquires a buffer of at least `size` bytes with the given `usage`,
    /// reusing a cached buffer from the matching bucket when one exists.
    ///
    /// The returned [`PooledBuffer`] gives the buffer back to this pool when
    /// dropped. Note that a reused buffer keeps the label it was created
    /// with; `label` only applies to freshly created buffers.
    ///
    /// # Errors
    /// Returns [`Error::Gpu`] if the pool mutex is poisoned.
    #[inline]
    pub fn acquire(
        &self,
        device: &wgpu::Device,
        label: &str,
        size: u64,
        usage: wgpu::BufferUsages,
    ) -> Result<PooledBuffer> {
        let key = BufferKey {
            device: device.clone(),
            size_class: size_class(size),
            usage_bits: usage.bits(),
        };
        // Pop a cached buffer inside a narrow lock scope. `create_buffer` is
        // a potentially expensive driver call, so it must run AFTER the guard
        // is dropped; otherwise every acquire/release on this pool serializes
        // behind GPU allocation. Using `get_mut` (not `entry`) also avoids
        // permanently inserting an empty bucket on every cache miss.
        let cached = {
            let mut buffers = self.inner.buffers.lock().map_err(|source| Error::Gpu {
                message: format!(
                    "GPU buffer pool mutex poisoned: {source}. Fix: restart the process or avoid panicking while holding the buffer pool lock."
                ),
            })?;
            buffers.get_mut(&key).and_then(Vec::pop)
        };
        let buffer = cached.unwrap_or_else(|| {
            device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                // Allocate the full size class so the buffer is reusable for
                // any request that rounds to the same class.
                size: key.size_class,
                usage,
                mapped_at_creation: false,
            })
        });
        Ok(PooledBuffer {
            key,
            buffer: Some(buffer),
            pool: Arc::downgrade(&self.inner),
        })
    }

    /// Returns `buffer` to the pool. Equivalent to dropping it; provided for
    /// call sites that prefer an explicit release.
    #[inline]
    pub fn release(&self, buffer: PooledBuffer) {
        drop(buffer);
    }

    /// Acquires a buffer, runs `f` with it, and releases it afterwards.
    ///
    /// # Errors
    /// Returns [`Error::Gpu`] if the pool mutex is poisoned during acquire.
    #[inline]
    pub fn with_buffer<R>(
        &self,
        device: &wgpu::Device,
        label: &str,
        size: u64,
        usage: wgpu::BufferUsages,
        f: impl FnOnce(&wgpu::Buffer) -> R,
    ) -> Result<R> {
        let buffer = self.acquire(device, label, size, usage)?;
        let result = f(&buffer);
        self.release(buffer);
        Ok(result)
    }
}
impl PooledBuffer {
#[must_use]
#[inline]
pub fn size(&self) -> u64 {
self.key.size_class
}
#[must_use]
#[inline]
pub fn buffer(&self) -> &wgpu::Buffer {
self.buffer
.as_ref()
.expect("pooled buffer missing inner buffer. Fix: do not use PooledBuffer after drop.")
}
}
// Lets a `PooledBuffer` be used anywhere a `&wgpu::Buffer` is expected.
// Delegates to `buffer()`, so dereferencing after drop panics with the same
// message as that accessor.
impl Deref for PooledBuffer {
    type Target = wgpu::Buffer;
    fn deref(&self) -> &Self::Target {
        self.buffer()
    }
}
impl Drop for PooledBuffer {
    /// Returns the buffer to its pool bucket. The buffer is silently dropped
    /// instead when the pool no longer exists, the pool lock is poisoned, or
    /// the bucket is already at `MAX_BUFFERS_PER_CLASS` capacity.
    fn drop(&mut self) {
        if let Some(returned) = self.buffer.take() {
            if let Some(pool) = self.pool.upgrade() {
                if let Ok(mut map) = pool.buffers.lock() {
                    let bucket = map.entry(self.key.clone()).or_default();
                    if bucket.len() < MAX_BUFFERS_PER_CLASS {
                        bucket.push(returned);
                    }
                }
            }
        }
    }
}
/// Rounds a requested byte size up to a nonzero multiple of
/// `wgpu::COPY_BUFFER_ALIGNMENT`; these rounded values define the pool's
/// size classes (and guarantee a zero-size request still gets a real buffer).
fn size_class(size: u64) -> u64 {
    let aligned = size.next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT);
    aligned.max(wgpu::COPY_BUFFER_ALIGNMENT)
}
#[cfg(test)]
mod tests {
    use super::size_class;
    // Pins the size-class invariants: zero and sub-alignment requests are
    // bumped to one full alignment unit, and sizes just past a multiple
    // round up to the next one.
    #[test]
    fn size_class_is_copy_aligned_and_nonzero() {
        assert_eq!(size_class(0), wgpu::COPY_BUFFER_ALIGNMENT);
        assert_eq!(size_class(1), wgpu::COPY_BUFFER_ALIGNMENT);
        assert_eq!(
            size_class(wgpu::COPY_BUFFER_ALIGNMENT + 1),
            (wgpu::COPY_BUFFER_ALIGNMENT + 1).next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT)
        );
    }
}