use std::collections::HashMap;
use parking_lot::Mutex;
use sapient_core::buffer::BufferHandle;
use sapient_core::DType;
/// A released buffer waiting in the pool to be handed out again.
struct PoolEntry {
    /// The cached buffer itself.
    handle: BufferHandle,
    /// When the buffer was returned to the pool (set in `release`); eviction
    /// compares these timestamps to choose a victim.
    last_used: std::time::Instant,
    /// Size of the buffer in bytes, as computed by `DType::byte_count` at release.
    capacity: usize,
}
/// A byte-capacity-bounded cache of released buffers, bucketed by byte size.
///
/// Buffers handed back via `release` can be reused by a later `acquire` that
/// asks for exactly the same byte size. All state lives behind a single mutex.
pub struct PoolAllocator {
    inner: Mutex<PoolInner>,
}
/// Mutable pool state, guarded by `PoolAllocator::inner`.
struct PoolInner {
    /// Cached buffers keyed by their byte size; within a bucket, entries are
    /// appended in release order.
    free: HashMap<usize, Vec<PoolEntry>>,
    /// Total bytes of all buffers currently cached in `free` (released but not
    /// yet re-acquired or evicted).
    used_bytes: usize,
    /// Upper bound on `used_bytes`; `release` only caches a buffer if it fits.
    capacity: usize,
}
impl PoolAllocator {
    /// Creates an empty pool that caches at most `capacity_bytes` bytes of
    /// released buffers.
    pub fn new(capacity_bytes: usize) -> Self {
        Self {
            inner: Mutex::new(PoolInner {
                free: HashMap::new(),
                used_bytes: 0,
                capacity: capacity_bytes,
            }),
        }
    }

    /// Takes a cached buffer of exactly `dtype.byte_count(numel)` bytes out of
    /// the pool, or returns `None` if no buffer of that size is cached.
    ///
    /// The most recently released buffer of that size is handed out first. Its
    /// bytes stop counting towards [`used_bytes`](Self::used_bytes).
    pub fn acquire(&self, numel: usize, dtype: DType) -> Option<BufferHandle> {
        let byte_size = dtype.byte_count(numel);
        let mut guard = self.inner.lock();
        // Reborrow the guard as a plain `&mut PoolInner` so fields can be borrowed
        // independently.
        let inner = &mut *guard;
        let entries = inner.free.get_mut(&byte_size)?;
        let entry = entries.pop()?;
        if entries.is_empty() {
            // Drop empty buckets so the map (and the eviction scan) doesn't grow
            // with every distinct size ever released.
            inner.free.remove(&byte_size);
        }
        inner.used_bytes = inner.used_bytes.saturating_sub(entry.capacity);
        Some(entry.handle)
    }

    /// Returns a buffer of `dtype.byte_count(numel)` bytes to the pool so a later
    /// [`acquire`](Self::acquire) of the same size can reuse it.
    ///
    /// If caching the buffer would exceed the pool's capacity, the least recently
    /// released buffers are evicted (dropped) until it fits. A buffer larger than
    /// the whole capacity is dropped immediately without disturbing the cache,
    /// since no amount of eviction could make room for it.
    pub fn release(&self, handle: BufferHandle, numel: usize, dtype: DType) {
        let byte_size = dtype.byte_count(numel);
        let mut guard = self.inner.lock();
        let inner = &mut *guard;
        if byte_size > inner.capacity {
            // Evicting everything would still not make room; let `handle` drop.
            return;
        }
        while !Self::has_room(inner, byte_size) {
            if !Self::evict_lru(inner) {
                break;
            }
        }
        if Self::has_room(inner, byte_size) {
            inner.used_bytes += byte_size;
            inner.free.entry(byte_size).or_default().push(PoolEntry {
                handle,
                last_used: std::time::Instant::now(),
                capacity: byte_size,
            });
        }
    }

    /// Whether `byte_size` more bytes can be cached without exceeding capacity.
    ///
    /// Written as a subtraction so it cannot overflow; `used_bytes <= capacity`
    /// holds because `release` only adds bytes that fit.
    fn has_room(inner: &PoolInner, byte_size: usize) -> bool {
        byte_size <= inner.capacity.saturating_sub(inner.used_bytes)
    }

    /// Drops the least recently released cached buffer.
    ///
    /// Returns `false` if the pool holds no buffers.
    fn evict_lru(inner: &mut PoolInner) -> bool {
        // Each bucket is ordered oldest -> newest: `release` appends with a
        // monotonic `Instant::now()`, while `acquire` pops from the back and this
        // function removes from the front. So each bucket's oldest entry is its
        // first, and the global LRU entry is the oldest bucket head.
        let oldest_key = inner
            .free
            .iter()
            .filter_map(|(&key, entries)| entries.first().map(|entry| (entry.last_used, key)))
            .min_by_key(|&(last_used, _)| last_used)
            .map(|(_, key)| key);
        let key = match oldest_key {
            Some(key) => key,
            None => return false,
        };
        let entries = match inner.free.get_mut(&key) {
            Some(entries) => entries,
            None => return false,
        };
        // Non-empty: `filter_map` above only yielded buckets with a first entry.
        // `remove(0)` (not `swap_remove`) keeps the bucket in release order.
        let evicted = entries.remove(0);
        if entries.is_empty() {
            inner.free.remove(&key);
        }
        inner.used_bytes = inner.used_bytes.saturating_sub(evicted.capacity);
        true
    }

    /// Total bytes currently cached in the pool (released but not yet
    /// re-acquired or evicted).
    pub fn used_bytes(&self) -> usize {
        self.inner.lock().used_bytes
    }

    /// Maximum number of bytes the pool will cache.
    pub fn capacity(&self) -> usize {
        self.inner.lock().capacity
    }
}
impl std::fmt::Debug for PoolAllocator {
    /// Renders the pool's byte counters as `PoolAllocator { used_bytes, capacity }`.
    ///
    /// Both counters are copied out in one short critical section, so the
    /// formatter itself runs without the pool lock held.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (used_bytes, capacity) = {
            let state = self.inner.lock();
            (state.used_bytes, state.capacity)
        };
        let mut builder = f.debug_struct("PoolAllocator");
        builder.field("used_bytes", &used_bytes);
        builder.field("capacity", &capacity);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sapient_core::buffer::CpuBuffer;

    #[test]
    fn acquire_release_cycle() {
        const NUMEL: usize = 16;
        let pool = PoolAllocator::new(1 << 20);

        // Nothing has been released yet, so there is nothing to hand out.
        let before_release = pool.acquire(NUMEL, DType::F32);
        assert!(before_release.is_none());

        // Releasing 16 f32 elements caches 64 bytes.
        let handle = BufferHandle::new(CpuBuffer::zeros(NUMEL, DType::F32).unwrap());
        pool.release(handle, NUMEL, DType::F32);
        assert_eq!(pool.used_bytes(), 64);

        // Re-acquiring the same size reuses the cached buffer and empties the pool.
        let reused = pool.acquire(NUMEL, DType::F32);
        assert!(reused.is_some());
        assert_eq!(pool.used_bytes(), 0);
    }
}