use std::collections::HashMap;
use std::sync::Arc;
use cudarc::driver::{CudaSlice, CudaStream};
use parking_lot::Mutex;
use crate::error::GpuError;
/// Default number of free slabs retained per size class by `WorkspacePool::new`.
pub const DEFAULT_POOL_CAPACITY_PER_CLASS: usize = 4;

/// Rounds a byte count up to the pool's size class.
///
/// The size class is the next power of two, with a floor of 1 KiB, so `0`, `1` and `1024`
/// all map to `1024`.
///
/// # Panics
///
/// Panics if the rounded size does not fit in `usize`, i.e. `bytes` is larger than the
/// biggest power of two representable in `usize`. Plain `next_power_of_two` would panic
/// only in debug builds and silently return `0` in release builds.
pub fn size_class(bytes: usize) -> usize {
    bytes
        .max(1024)
        .checked_next_power_of_two()
        .expect("size_class: byte count exceeds the largest power-of-two usize")
}
/// Mutable state shared by a `WorkspacePool` handle and every lease it hands out.
struct WorkspacePoolInner {
    /// Maximum number of free slabs kept per size class (clamped to at least 1 on creation).
    per_class_capacity: usize,
    /// Byte total of the slabs held in `free`, updated by `acquire` and by lease drop.
    bytes_pooled: usize,
    /// Free slabs, bucketed by size class (see `size_class`).
    free: HashMap<usize, Vec<Arc<CudaSlice<u8>>>>,
}
/// Size-class pool of reusable device byte buffers (used e.g. for cuBLASLt workspaces).
///
/// Cloning the handle is cheap: every clone shares the same underlying pool state.
#[derive(Clone)]
pub struct WorkspacePool {
    /// Shared, lock-protected pool state; leases keep their own reference to it.
    inner: Arc<Mutex<WorkspacePoolInner>>,
}
impl WorkspacePool {
pub fn new() -> Self {
Self::with_capacity(DEFAULT_POOL_CAPACITY_PER_CLASS)
}
pub fn with_capacity(per_class: usize) -> Self {
Self {
inner: Arc::new(Mutex::new(WorkspacePoolInner {
free: HashMap::new(),
per_class_capacity: per_class.max(1),
bytes_pooled: 0,
})),
}
}
pub fn acquire(
&self,
stream: &Arc<CudaStream>,
requested_bytes: usize,
) -> Result<WorkspaceLease, GpuError> {
let class = size_class(requested_bytes.max(1));
let pooled = {
let mut g = self.inner.lock();
if let Some(bucket) = g.free.get_mut(&class) {
let popped = bucket.pop();
if let Some(ref s) = popped {
g.bytes_pooled = g.bytes_pooled.saturating_sub(s.len());
}
popped
} else {
None
}
};
let slab = match pooled {
Some(s) => s,
None => {
let s = unsafe { stream.alloc::<u8>(class) }.map_err(|e| {
GpuError::OutOfMemory(format!("cublaslt workspace alloc {class}B: {e}"))
})?;
Arc::new(s)
}
};
Ok(WorkspaceLease {
slab: Some(slab),
class,
pool: self.inner.clone(),
})
}
pub fn pooled_slabs(&self) -> usize {
let g = self.inner.lock();
g.free.values().map(|v| v.len()).sum()
}
pub fn pooled_bytes(&self) -> usize {
self.inner.lock().bytes_pooled
}
}
impl Default for WorkspacePool {
fn default() -> Self {
Self::new()
}
}
/// A device slab leased from a `WorkspacePool`.
///
/// Dropping the lease offers the slab back to the pool it came from. Binding the lease to
/// `_` or discarding it ends the lease immediately, hence `#[must_use]`.
#[must_use = "dropping a WorkspaceLease immediately gives its slab back to the pool"]
pub struct WorkspaceLease {
    /// The leased slab; `None` only after `Drop` has taken it.
    slab: Option<Arc<CudaSlice<u8>>>,
    /// Size class of the slab, which is also its length in bytes.
    class: usize,
    /// Pool state that the slab is offered back to on drop.
    pool: Arc<Mutex<WorkspacePoolInner>>,
}
impl WorkspaceLease {
    /// Returns the device slab backing this lease.
    ///
    /// The slab is `size()` bytes long. That is the request rounded up to its size class,
    /// so it may be larger than what was asked for.
    ///
    /// # Panics
    ///
    /// Panics if the slab has already been taken. Only the `Drop` impl takes it.
    pub fn slice(&self) -> &Arc<CudaSlice<u8>> {
        match self.slab {
            Some(ref slab) => slab,
            None => panic!("WorkspaceLease::slice after Drop"),
        }
    }

    /// Length of the leased slab in bytes (its size class).
    pub fn size(&self) -> usize {
        let leased_bytes = self.class;
        leased_bytes
    }
}
impl Drop for WorkspaceLease {
fn drop(&mut self) {
let Some(slab) = self.slab.take() else {
return;
};
let mut g = self.pool.lock();
let cap = g.per_class_capacity;
let bucket = g.free.entry(self.class).or_default();
if bucket.len() < cap {
let bytes = slab.len();
bucket.push(slab);
g.bytes_pooled = g.bytes_pooled.saturating_add(bytes);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Simulates a slab of `slab_bytes` bytes being pooled under `class`, without
    /// allocating device memory.
    ///
    /// Only the byte counter and the bucket entry are touched. The bucket itself stays
    /// empty, so `pooled_slabs()` is unaffected and a later `acquire` would still allocate.
    fn seed_free_slot(pool: &WorkspacePool, class: usize, slab_bytes: usize) {
        let mut inner = pool.inner.lock();
        inner.free.entry(class).or_default();
        inner.bytes_pooled = inner.bytes_pooled.saturating_add(slab_bytes);
    }

    #[test]
    fn size_class_rounds_up() {
        const MIB: usize = 1024 * 1024;
        let cases = [
            (1, 1024),
            (1024, 1024),
            (1025, 2048),
            (4 * MIB, 4 * MIB),
            (4 * MIB + 1, 8 * MIB),
            (0, 1024),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(size_class(input), expected, "size_class({input})");
        }
    }

    /// Exercises the pool's bookkeeping only. Real slab recycling needs a GPU.
    #[test]
    fn workspace_pool_recycles() {
        const SMALL: usize = 4 * 1024 * 1024;
        const LARGE: usize = 33_554_432;

        let pool = WorkspacePool::with_capacity(2);
        assert_eq!(pool.pooled_slabs(), 0);

        seed_free_slot(&pool, size_class(SMALL), SMALL);
        assert_eq!(pool.pooled_bytes(), SMALL);

        seed_free_slot(&pool, size_class(LARGE), LARGE);
        assert_eq!(pool.pooled_bytes(), SMALL + LARGE);

        let inner = pool.inner.lock();
        assert!(inner.free.contains_key(&size_class(SMALL)));
        assert!(inner.free.contains_key(&size_class(LARGE)));
    }

    #[test]
    fn pool_capacity_clamps_to_one() {
        let capacity = WorkspacePool::with_capacity(0).inner.lock().per_class_capacity;
        assert_eq!(capacity, 1);
    }
}