use std::sync::Mutex;
use crate::error::Result;
use crate::metal::{GpuBuffer, MetalContext, PAGE_SIZE};
/// A pool of reusable [`GpuBuffer`]s, handed out by [`ScratchPool::checkout`] and
/// returned with [`ScratchPool::put_back`].
///
/// The free list is guarded by a [`Mutex`], so the pool can be shared by
/// reference (or through an `Arc`) across threads.
#[derive(Default, Debug)]
pub struct ScratchPool {
    /// Buffers currently available for reuse, kept in no particular order.
    free: Mutex<Vec<GpuBuffer>>,
}
impl ScratchPool {
    /// Creates an empty pool. No buffers are allocated until the first checkout.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Hands out a buffer whose length is set to `bytes`.
    ///
    /// The request is rounded up to a whole number of `PAGE_SIZE` pages, with a
    /// minimum of one page. The smallest free buffer with at least that capacity
    /// is reused. If none fits, a new buffer of the rounded capacity is allocated.
    /// A reused buffer keeps whatever bytes it held before; its contents are not
    /// cleared.
    ///
    /// If rounding `bytes` up to a page multiple would overflow `usize`, the pool
    /// is bypassed and `bytes` is passed unchanged to [`GpuBuffer::alloc`].
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`GpuBuffer::alloc`] when a new buffer is needed.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    pub fn checkout(&self, ctx: &MetalContext, bytes: usize) -> Result<GpuBuffer> {
        // `next_multiple_of` panics on overflow in debug builds and wraps in
        // release builds. A wrapped value would be clamped to one page below,
        // letting a one-page buffer be reused and then `set_len` past its
        // capacity. Detect the overflow instead and skip the pool entirely.
        let Some(rounded) = bytes.checked_next_multiple_of(PAGE_SIZE) else {
            return GpuBuffer::alloc(ctx, bytes);
        };
        // A zero-byte request rounds to 0; pooled buffers are always >= one page.
        let capacity = rounded.max(PAGE_SIZE);

        // The lock is released inside `take_best_fit`, so `set_len` below runs
        // outside the critical section.
        let mut buf = match self.take_best_fit(capacity) {
            Some(buf) => buf,
            None => GpuBuffer::alloc(ctx, capacity)?,
        };
        buf.set_len(bytes);
        Ok(buf)
    }

    /// Removes and returns the smallest free buffer whose capacity is at least
    /// `capacity`, or `None` if no free buffer is large enough.
    ///
    /// Ties go to the buffer that appears first in the free list.
    fn take_best_fit(&self, capacity: usize) -> Option<GpuBuffer> {
        let mut free = self.free.lock().expect("pool lock never poisoned");
        // `min_by_key` returns the first of several equal minima.
        let (index, _) = free
            .iter()
            .map(|buf| buf.capacity())
            .enumerate()
            .filter(|&(_, cap)| cap >= capacity)
            .min_by_key(|&(_, cap)| cap)?;
        // Order of the free list is irrelevant, so O(1) removal is fine.
        Some(free.swap_remove(index))
    }

    /// Returns `buf` to the free list so a later [`checkout`](Self::checkout) can
    /// reuse it. The buffer is stored as-is; its contents are not cleared.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    pub fn put_back(&self, buf: GpuBuffer) {
        self.free.lock().expect("pool lock never poisoned").push(buf);
    }

    /// Number of buffers currently sitting in the free list.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    #[must_use]
    pub fn free_len(&self) -> usize {
        self.free.lock().expect("pool lock never poisoned").len()
    }

    /// Overwrites the entire capacity of every free buffer with `byte`, then
    /// restores each buffer's previous length.
    ///
    /// Because checkout does not clear reused buffers, bytes a caller reads
    /// without writing them first will show up as `byte`.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    pub fn poison_free_buffers(&self, byte: u8) {
        let mut free = self.free.lock().expect("pool lock never poisoned");
        for buf in free.iter_mut() {
            let len = buf.len();
            // Widen to full capacity so bytes beyond the current length are
            // poisoned too; a later checkout may expose them.
            buf.set_len(buf.capacity());
            buf.contents_mut().fill(byte);
            buf.set_len(len);
        }
    }
}
/// Source from which scratch [`GpuBuffer`]s are obtained.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Alloc<'a> {
    /// Call [`GpuBuffer::alloc`] with the requested size every time.
    Direct,
    /// Check buffers out of the referenced [`ScratchPool`], reusing freed capacity.
    Pool(&'a ScratchPool),
}
impl Alloc<'_> {
    /// Obtains a buffer for a `bytes`-byte request from this allocation source.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`GpuBuffer::alloc`] (direct) or
    /// [`ScratchPool::checkout`] (pooled).
    pub(crate) fn buffer(&self, ctx: &MetalContext, bytes: usize) -> Result<GpuBuffer> {
        // `Alloc` is `Copy`, so matching on the dereferenced value is free.
        match *self {
            Self::Pool(pool) => pool.checkout(ctx, bytes),
            Self::Direct => GpuBuffer::alloc(ctx, bytes),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a Metal context, or `None` after printing a SKIP line when no
    /// usable device exists. With `METAL_JSON_REQUIRE_GPU=1` set, the missing
    /// device causes a panic instead of a skip.
    fn ctx_or_skip(test: &str) -> Option<MetalContext> {
        let err = match MetalContext::new() {
            Ok(ctx) => return Some(ctx),
            Err(err) => err,
        };
        let gpu_required =
            std::env::var_os("METAL_JSON_REQUIRE_GPU").is_some_and(|v| v == "1");
        if gpu_required {
            panic!("METAL_JSON_REQUIRE_GPU=1 but no usable Metal device: {err}");
        }
        eprintln!("SKIP {test}: no usable Metal device here ({err})");
        None
    }

    #[test]
    fn checkout_reuses_returned_capacity_best_fit() {
        let Some(ctx) = ctx_or_skip("checkout_reuses_returned_capacity_best_fit") else {
            return;
        };
        let pool = ScratchPool::new();

        // Two fresh allocations: one page and two pages.
        let one_page = pool.checkout(&ctx, 100).unwrap();
        let two_pages = pool.checkout(&ctx, PAGE_SIZE + 1).unwrap();
        assert_eq!(one_page.len(), 100);
        assert_eq!(one_page.capacity(), PAGE_SIZE);
        assert_eq!(two_pages.capacity(), 2 * PAGE_SIZE);

        pool.put_back(one_page);
        pool.put_back(two_pages);
        assert_eq!(pool.free_len(), 2);

        // A small request must take the one-page buffer, not the larger one.
        let reused_small = pool.checkout(&ctx, 50).unwrap();
        assert_eq!(reused_small.len(), 50);
        assert_eq!(reused_small.capacity(), PAGE_SIZE);
        assert_eq!(pool.free_len(), 1);

        // The remaining two-page buffer satisfies a just-over-one-page request.
        let reused_large = pool.checkout(&ctx, PAGE_SIZE + 5).unwrap();
        assert_eq!(reused_large.capacity(), 2 * PAGE_SIZE);
        assert_eq!(pool.free_len(), 0);
    }

    #[test]
    fn poison_survives_checkout() {
        let Some(ctx) = ctx_or_skip("poison_survives_checkout") else {
            return;
        };
        let pool = ScratchPool::new();
        pool.put_back(pool.checkout(&ctx, 64).unwrap());
        pool.poison_free_buffers(0xDB);

        let recycled = pool.checkout(&ctx, 32).unwrap();
        assert!(recycled.contents().iter().all(|&b| b == 0xDB));
    }

    #[test]
    fn pool_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<ScratchPool>();
        require_send_sync::<std::sync::Arc<ScratchPool>>();
    }
}