use std::mem;
use std::{cell::RefCell, rc::Rc};
use crate::DType;
// NOTE(review): `4 * 1024 * 1024 * 1024` does not fit in a 32-bit `usize`, so this presumably
// targets 64-bit platforms only — confirm.
/// Hard cap, in bytes, on the total capacity a pool may hold (4 GiB). `recycle_buffer`
/// discards a buffer instead of pooling it when adding it would push the total past this.
const MAX_BUFFERS_SIZE: usize = 4 * 1024 * 1024 * 1024;
/// Soft limit, in bytes (half of `MAX_BUFFERS_SIZE`). After a buffer has been pooled,
/// `trim_excess` evicts the largest buffers until the pooled total is no larger than this.
const TRIM_THRESHOLD: usize = MAX_BUFFERS_SIZE / 2;
/// Usage counters for a buffer pool.
///
/// All counters start at zero (see the `Default` impl) and are plain values, so a snapshot
/// can be cloned and compared with `==`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolMetrics {
    /// Total capacity, in bytes, of the idle buffers currently held by the pool.
    pub current_size: usize,
    /// Number of buffer requests that were served by reusing a pooled buffer.
    pub hits: usize,
    /// Number of buffer requests that had to allocate a fresh buffer.
    pub misses: usize,
    /// Number of recycled buffers that were discarded because pooling them would have
    /// exceeded the pool's size cap.
    pub drops: usize,
}
/// A collection of idle `Vec<T>` buffers kept around for reuse, plus usage counters.
#[derive(Debug)]
pub struct BufferPool<T> {
/// Idle buffers available for reuse. Entries are removed with `swap_remove`, so their
/// order carries no meaning.
pool: Vec<Vec<T>>,
/// Usage counters; `current_size` tracks the byte capacity of everything in `pool`.
pub metrics: PoolMetrics,
}
/// A [`BufferPool`] behind `Rc<RefCell<..>>`, so that several owners can share one pool and
/// mutate it through runtime-checked borrows.
pub type SharedPool<T> = Rc<RefCell<BufferPool<T>>>;
/// A `Vec<T>` that is handed back to its pool when the wrapper is dropped.
///
/// The `T: DType` bound sits on the struct itself because the `Drop` impl needs it, and a
/// `Drop` impl may not be more constrained than the type it is implemented for.
#[derive(Debug)]
pub struct PooledBuffer<T: DType> {
/// The wrapped buffer, reachable through `Deref`/`DerefMut`.
buf: Vec<T>,
/// Pool that receives `buf` on drop; `None` once the buffer has been detached with
/// `into_inner` or has already been returned.
pool: Option<SharedPool<T>>,
}
impl<T: DType> PooledBuffer<T> {
    /// Wraps `buf` so that it is returned to `pool` when the wrapper goes out of scope.
    pub fn new(buf: Vec<T>, pool: SharedPool<T>) -> Self {
        let pool = Some(pool);
        Self { buf, pool }
    }

    /// Consumes the wrapper and returns the underlying vector without recycling it.
    ///
    /// The pool handle is released here, so the destructor that runs when `self` goes out
    /// of scope finds no pool and leaves the returned vector alone.
    pub fn into_inner(mut self) -> Vec<T> {
        // Detach from the pool; the destructor then has nothing to recycle.
        self.pool = None;
        // Move the vector out, leaving an empty (allocation-free) one behind.
        std::mem::take(&mut self.buf)
    }
}
/// Read-only access to the wrapped vector, so a `PooledBuffer<T>` can be used wherever a
/// `&Vec<T>` (or `&[T]`) is expected.
impl<T: DType> std::ops::Deref for PooledBuffer<T> {
    type Target = Vec<T>;

    /// Borrows the wrapped vector.
    fn deref(&self) -> &Self::Target {
        let inner = &self.buf;
        inner
    }
}
/// Mutable access to the wrapped vector. Callers may resize or reallocate it freely.
impl<T: DType> std::ops::DerefMut for PooledBuffer<T> {
    /// Mutably borrows the wrapped vector.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let inner = &mut self.buf;
        inner
    }
}
/// Returns the wrapped vector to its pool when the wrapper is dropped.
impl<T: DType> Drop for PooledBuffer<T> {
    /// Hands the buffer to the pool, if the wrapper is still attached to one.
    ///
    /// Recycling is best-effort: when the pool's `RefCell` is already borrowed at the time
    /// of the drop, the buffer is simply freed. Panicking here instead would abort the
    /// process if the drop happened during unwinding.
    fn drop(&mut self) {
        // `take` detaches the pool so the buffer can never be recycled twice.
        let Some(pool) = self.pool.take() else {
            return;
        };
        let buf = std::mem::take(&mut self.buf);
        // `guard` is a named local (declared after `pool`), so it is released before the
        // `Rc` it borrows from.
        let Ok(mut guard) = pool.try_borrow_mut() else {
            // Pool is busy: let `buf` be deallocated normally.
            return;
        };
        guard.recycle_buffer(buf);
    }
}
impl<T: DType> BufferPool<T> {
    /// Creates an empty pool with every metric at zero.
    pub fn new() -> Self {
        BufferPool {
            pool: Vec::new(),
            metrics: PoolMetrics {
                current_size: 0,
                hits: 0,
                misses: 0,
                drops: 0,
            },
        }
    }

    /// Returns an empty vector (`len() == 0`) whose capacity is at least `capacity`.
    ///
    /// The pool is searched for the best fit: the idle buffer with the smallest capacity
    /// that still satisfies the request (the first such buffer on ties). Buffers are
    /// compared by `capacity()`, not `len()`, because that is what determines whether a
    /// buffer can be reused without reallocating and it is the quantity the byte accounting
    /// tracks. A reused buffer counts as a hit; otherwise a fresh vector is allocated and a
    /// miss is recorded.
    pub fn get_empty_buffer(&mut self, capacity: usize) -> Vec<T> {
        let best_fit = self
            .pool
            .iter()
            .enumerate()
            .filter(|(_, candidate)| candidate.capacity() >= capacity)
            // `min_by_key` keeps the first of several equally small candidates.
            .min_by_key(|(_, candidate)| candidate.capacity())
            .map(|(idx, _)| idx);

        let Some(idx) = best_fit else {
            self.metrics.misses += 1;
            return Vec::with_capacity(capacity);
        };

        self.metrics.hits += 1;
        // Order of idle buffers is irrelevant, so the O(1) `swap_remove` is fine.
        let mut buf = self.pool.swap_remove(idx);
        self.metrics.current_size = self
            .metrics
            .current_size
            .saturating_sub(Self::bytes_of(&buf));
        // Discard stale contents; the capacity (already >= `capacity`) is retained.
        buf.clear();
        self.debug_assert_accounting();
        buf
    }

    /// Returns a vector with `len() == capacity`, reusing a pooled allocation if possible.
    ///
    /// The contents are NOT initialised: they are whatever the allocation last held. Callers
    /// must overwrite every element before reading it.
    pub fn get_buffer(&mut self, capacity: usize) -> Vec<T> {
        let mut buf = self.get_empty_buffer(capacity);
        debug_assert!(buf.capacity() >= capacity);
        // SAFETY: `get_empty_buffer` returns a vector with `capacity() >= capacity`, so the
        // new length stays within the allocation.
        // NOTE(review): `Vec::set_len` also requires elements `0..capacity` to be
        // initialised, which is not the case here. Soundness therefore depends on `DType`
        // implementors being plain data and on callers writing before reading — neither can
        // be verified from this file; confirm.
        unsafe {
            buf.set_len(capacity);
        }
        buf
    }

    /// Offers `buf` to the pool for later reuse.
    ///
    /// The buffer is discarded (and `drops` incremented) when pooling it would push the
    /// tracked total past `MAX_BUFFERS_SIZE`. Otherwise it is stored and the pool is then
    /// trimmed back to `TRIM_THRESHOLD`, largest buffers first.
    pub fn recycle_buffer(&mut self, buf: Vec<T>) {
        let buffer_bytes = Self::bytes_of(&buf);
        // `checked_add` treats an overflowing total as "does not fit" instead of
        // panicking (debug) or wrapping (release).
        let fits = self
            .metrics
            .current_size
            .checked_add(buffer_bytes)
            .is_some_and(|total| total <= MAX_BUFFERS_SIZE);
        if !fits {
            self.metrics.drops += 1;
            return;
        }

        self.metrics.current_size += buffer_bytes;
        self.pool.push(buf);
        self.debug_assert_accounting();
        self.trim_excess();
    }

    /// Evicts the largest idle buffers until the tracked total is at most `TRIM_THRESHOLD`.
    fn trim_excess(&mut self) {
        while self.metrics.current_size > TRIM_THRESHOLD {
            let largest = self
                .pool
                .iter()
                .enumerate()
                .map(|(idx, buf)| (idx, Self::bytes_of(buf)))
                // `Reverse` + `min_by_key` selects the first of several equally large buffers.
                .min_by_key(|&(_, bytes)| std::cmp::Reverse(bytes));

            let Some((idx, bytes)) = largest.filter(|&(_, bytes)| bytes > 0) else {
                // Nothing left that would free memory (pool empty or only zero-byte
                // buffers), so the real pooled size is zero. The counter can only disagree
                // if the public `metrics` field was modified externally; resynchronise
                // instead of looping forever or indexing an empty pool.
                self.metrics.current_size = 0;
                break;
            };

            self.pool.swap_remove(idx);
            self.metrics.current_size = self.metrics.current_size.saturating_sub(bytes);
        }
        self.debug_assert_accounting();
    }

    /// Capacity of `buf` expressed in bytes.
    fn bytes_of(buf: &Vec<T>) -> usize {
        buf.capacity() * mem::size_of::<T>()
    }

    /// Debug-only check that `metrics.current_size` matches the bytes actually pooled.
    fn debug_assert_accounting(&self) {
        debug_assert_eq!(
            self.metrics.current_size,
            self.pool.iter().map(Self::bytes_of).sum::<usize>()
        );
    }

    /// Returns a snapshot (clone) of the current metrics.
    // Presumably not called in every build configuration, hence the lint allowance.
    #[allow(unused)]
    pub fn metrics(&self) -> PoolMetrics {
        self.metrics.clone()
    }
}