use std::cell::RefCell;
use std::collections::HashMap;
/// A pool of reusable `Vec<f32>` buffers grouped into power-of-two size classes.
///
/// [`acquire`](Self::acquire) hands out a zero-filled buffer of the requested
/// length, reusing a retained allocation from the matching size class when one
/// is available. [`release`](Self::release) returns a buffer to the pool unless
/// retaining it would exceed the byte budget or the per-class buffer limit, in
/// which case the buffer is simply dropped.
#[derive(Debug)]
pub struct MemoryPool {
    /// Retained buffers keyed by size class (a power of two). Every buffer
    /// stored under class `c` has `capacity() >= c`, so it can serve any
    /// request of length `<= c` without reallocating.
    buffers: HashMap<usize, Vec<Vec<f32>>>,
    /// Maximum number of buffers retained in any single size class.
    max_per_class: usize,
    /// Maximum number of bytes (measured by capacity) the pool may retain.
    budget: usize,
    /// Bytes currently retained by the pool, measured by capacity.
    current_usage: usize,
}

impl MemoryPool {
    /// Number of buffers kept per size class by [`MemoryPool::new`].
    const DEFAULT_MAX_PER_CLASS: usize = 4;

    /// Creates an empty pool that retains at most `budget` bytes, keeping up
    /// to four buffers per size class.
    pub fn new(budget: usize) -> Self {
        Self::with_limits(budget, Self::DEFAULT_MAX_PER_CLASS)
    }

    /// Creates an empty pool that retains at most `budget` bytes and at most
    /// `max_per_class` buffers in each size class.
    pub fn with_limits(budget: usize, max_per_class: usize) -> Self {
        Self {
            buffers: HashMap::new(),
            max_per_class,
            budget,
            current_usage: 0,
        }
    }

    /// Creates an empty pool with a 256 MiB budget.
    pub fn default_budget() -> Self {
        Self::new(256 * 1024 * 1024)
    }

    /// Smallest power of two `>= n`, with `0` mapping to `1`.
    ///
    /// Values above the largest representable power of two are returned
    /// unchanged instead of overflowing. Unlike a hand-rolled bit smear, this
    /// is correct for any `usize` width (the old `v >> 32` did not compile on
    /// 32-bit targets).
    fn round_up_pow2(n: usize) -> usize {
        n.checked_next_power_of_two().unwrap_or(n)
    }

    /// Largest power of two `<= n`, or `0` when `n == 0`.
    fn round_down_pow2(n: usize) -> usize {
        match n {
            0 => 0,
            _ => 1usize << (usize::BITS - 1 - n.leading_zeros()),
        }
    }

    /// Bytes occupied by a buffer with the given capacity.
    fn bytes_for(capacity: usize) -> usize {
        capacity * std::mem::size_of::<f32>()
    }

    /// Returns a zero-filled buffer of length `len`.
    ///
    /// A retained buffer from size class `round_up_pow2(len)` is reused when
    /// available. Otherwise a new buffer is allocated with the full class
    /// capacity, so that once released it lands back in the same class.
    pub fn acquire(&mut self, len: usize) -> Vec<f32> {
        let class = Self::round_up_pow2(len);
        if let Some(mut buf) = self.buffers.get_mut(&class).and_then(Vec::pop) {
            self.current_usage -= Self::bytes_for(buf.capacity());
            // Wipe stale contents. capacity >= class >= len, so no reallocation.
            buf.clear();
            buf.resize(len, 0.0);
            return buf;
        }
        let mut buf = Vec::with_capacity(class);
        buf.resize(len, 0.0);
        buf
    }

    /// Returns `buf` to the pool for later reuse.
    ///
    /// The buffer is dropped instead of retained when any of these hold:
    /// - it has no capacity;
    /// - retaining it would push usage over the budget;
    /// - its size class already holds `max_per_class` buffers.
    pub fn release(&mut self, buf: Vec<f32>) {
        let capacity = buf.capacity();
        if capacity == 0 {
            // Nothing to reuse; pooling it would only waste a slot.
            return;
        }
        let size = Self::bytes_for(capacity);
        if self.current_usage + size > self.budget {
            return;
        }
        // Round *down* so that every buffer in class `c` has capacity >= c and
        // can therefore satisfy `acquire(len)` for any `len <= c` in place.
        let class = Self::round_down_pow2(capacity);
        let bin = self.buffers.entry(class).or_default();
        if bin.len() >= self.max_per_class {
            return;
        }
        self.current_usage += size;
        bin.push(buf);
    }

    /// Drops every retained buffer and resets usage to zero.
    pub fn clear(&mut self) {
        self.buffers.clear();
        self.current_usage = 0;
    }

    /// Bytes currently retained by the pool, measured by buffer capacity.
    pub fn usage(&self) -> usize {
        self.current_usage
    }

    /// Total number of buffers currently retained across all size classes.
    pub fn buffer_count(&self) -> usize {
        self.buffers.values().map(|v| v.len()).sum()
    }
}

impl Default for MemoryPool {
    /// Equivalent to [`MemoryPool::default_budget`].
    fn default() -> Self {
        Self::default_budget()
    }
}
thread_local! {
    /// Per-thread pool backing the free functions in this file and `PooledBuffer`.
    static POOL: RefCell<MemoryPool> = RefCell::new(MemoryPool::default());
}
/// Takes a zero-filled buffer of length `len` from this thread's pool.
///
/// Falls back to a fresh allocation in two cases: the pool is already
/// borrowed, or the thread-local pool is no longer accessible (e.g. during
/// thread-local destruction, where `LocalKey::with` would panic).
pub fn acquire_buffer(len: usize) -> Vec<f32> {
    POOL.try_with(|pool| pool.try_borrow_mut().ok().map(|mut p| p.acquire(len)))
        .ok()
        .flatten()
        .unwrap_or_else(|| vec![0.0f32; len])
}
/// Returns `buf` to this thread's pool.
///
/// The buffer is simply freed in two cases: the pool is already borrowed, or
/// the thread-local pool is no longer accessible.
pub fn release_buffer(buf: Vec<f32>) {
    // This is called from `PooledBuffer::drop`. If such a guard is dropped
    // while thread-locals are being destroyed, `POOL.with` would panic inside
    // a destructor. `try_with` lets us fall back to dropping `buf` instead.
    let _ = POOL.try_with(|pool| {
        if let Ok(mut p) = pool.try_borrow_mut() {
            p.release(buf);
        }
    });
}
pub fn clear_pool() {
POOL.with(|pool| {
if let Ok(mut p) = pool.try_borrow_mut() {
p.clear();
}
});
}
/// Returns the number of bytes currently retained by this thread's pool.
///
/// # Panics
///
/// Panics in either of these cases:
/// - the pool is currently mutably borrowed;
/// - this is called after the thread-local pool has been destroyed.
pub fn pool_usage() -> usize {
    POOL.with(|cell| {
        let pool = cell.borrow();
        pool.usage()
    })
}
/// A buffer taken from this thread's pool that is handed back to it on drop.
///
/// Dereferences to `[f32]`. Call [`PooledBuffer::take`] to keep the
/// underlying `Vec` instead of returning it to the pool.
#[derive(Debug)]
pub struct PooledBuffer {
    /// `Some` for the guard's whole usable lifetime. It is emptied only by
    /// `take` (which consumes `self`) or by `drop`.
    buffer: Option<Vec<f32>>,
}

impl PooledBuffer {
    /// Acquires a zero-filled buffer of length `len` via `acquire_buffer`.
    pub fn new(len: usize) -> Self {
        Self {
            buffer: Some(acquire_buffer(len)),
        }
    }

    /// Borrows the buffer contents.
    // The `expect` is unreachable: `buffer` is only emptied by `take`, which
    // consumes `self`, so no `&self` can observe `None`.
    #[allow(clippy::expect_used)]
    pub fn as_slice(&self) -> &[f32] {
        self.buffer.as_ref().expect("buffer already taken")
    }

    /// Mutably borrows the buffer contents.
    // Unreachable `expect`: same invariant as `as_slice`.
    #[allow(clippy::expect_used)]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        self.buffer.as_mut().expect("buffer already taken")
    }

    /// Takes ownership of the underlying vector.
    ///
    /// The returned vector will not be released back to the pool when the
    /// guard is dropped.
    // Unreachable `expect`: `take` consumes `self`, so it runs at most once.
    #[allow(clippy::expect_used)]
    pub fn take(mut self) -> Vec<f32> {
        self.buffer.take().expect("buffer already taken")
    }
}

impl Drop for PooledBuffer {
    /// Hands the buffer back to the thread-local pool via `release_buffer`,
    /// unless it was moved out with `take`.
    fn drop(&mut self) {
        if let Some(buf) = self.buffer.take() {
            release_buffer(buf);
        }
    }
}

impl std::ops::Deref for PooledBuffer {
    type Target = [f32];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl std::ops::DerefMut for PooledBuffer {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_round_up_pow2() {
        assert_eq!(MemoryPool::round_up_pow2(0), 1);
        assert_eq!(MemoryPool::round_up_pow2(1), 1);
        assert_eq!(MemoryPool::round_up_pow2(2), 2);
        assert_eq!(MemoryPool::round_up_pow2(3), 4);
        assert_eq!(MemoryPool::round_up_pow2(5), 8);
        assert_eq!(MemoryPool::round_up_pow2(1000), 1024);
    }

    #[test]
    fn test_pool_reuse() {
        let mut pool = MemoryPool::new(1024 * 1024);
        let buf = pool.acquire(100);
        assert_eq!(buf.len(), 100);
        let first_ptr = buf.as_ptr();
        pool.release(buf);
        let buf2 = pool.acquire(100);
        assert_eq!(buf2.len(), 100);
        // Same length, same size class: the released allocation is handed back.
        assert_eq!(buf2.as_ptr(), first_ptr);
    }

    #[test]
    fn test_acquire_zeroes_reused_buffer() {
        let mut pool = MemoryPool::new(1024 * 1024);
        let mut buf = pool.acquire(16);
        buf.iter_mut().for_each(|x| *x = 7.0);
        pool.release(buf);
        let again = pool.acquire(16);
        assert_eq!(again, vec![0.0f32; 16]);
    }

    #[test]
    fn test_usage_tracks_retained_capacity() {
        let mut pool = MemoryPool::new(1024 * 1024);
        let buf = pool.acquire(32);
        let bytes = buf.capacity() * std::mem::size_of::<f32>();
        pool.release(buf);
        assert_eq!(pool.usage(), bytes);
        assert_eq!(pool.buffer_count(), 1);
        let _reacquired = pool.acquire(32);
        assert_eq!(pool.usage(), 0);
        assert_eq!(pool.buffer_count(), 0);
    }

    #[test]
    fn test_release_respects_budget() {
        let budget = 64 * std::mem::size_of::<f32>();
        let mut pool = MemoryPool::new(budget);
        let first = pool.acquire(64);
        let second = pool.acquire(64);
        pool.release(first);
        // Retaining the second buffer would exceed the budget, so it is dropped.
        pool.release(second);
        assert_eq!(pool.buffer_count(), 1);
        assert_eq!(pool.usage(), budget);
    }

    #[test]
    fn test_release_respects_per_class_limit() {
        let mut pool = MemoryPool::new(1024 * 1024);
        let bufs: Vec<Vec<f32>> = (0..6).map(|_| pool.acquire(8)).collect();
        for buf in bufs {
            pool.release(buf);
        }
        assert_eq!(pool.buffer_count(), 4);
    }

    #[test]
    fn test_clear_drops_everything() {
        let mut pool = MemoryPool::new(1024 * 1024);
        let buf = pool.acquire(10);
        pool.release(buf);
        pool.clear();
        assert_eq!(pool.usage(), 0);
        assert_eq!(pool.buffer_count(), 0);
    }

    #[test]
    fn test_pooled_buffer_guard() {
        clear_pool();
        assert_eq!(pool_usage(), 0);
        {
            let mut buf = PooledBuffer::new(1000);
            buf[0] = 1.0;
            buf[999] = 2.0;
        }
        assert!(pool_usage() > 0);
        clear_pool();
        assert_eq!(pool_usage(), 0);
    }

    #[test]
    fn test_pooled_buffer_take_skips_pool() {
        clear_pool();
        let mut guard = PooledBuffer::new(8);
        guard[3] = 4.0;
        let owned = guard.take();
        let mut expected = vec![0.0f32; 8];
        expected[3] = 4.0;
        assert_eq!(owned, expected);
        // `take` moved the vector out, so nothing was returned to the pool.
        assert_eq!(pool_usage(), 0);
    }
}