use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// A reusable byte buffer handed out by a `MemoryPool`.
///
/// `capacity` records the size the buffer was created for; the pool uses this
/// value (not the live `Vec` capacity) as the key when the buffer is returned.
/// A freshly created buffer is empty: `data()` and `as_mut_slice()` have length
/// zero until `resize` is called.
#[derive(Debug)]
pub struct PooledBuffer {
    data: Vec<u8>,
    capacity: usize,
}

/// Cloning preserves the reserved allocation.
///
/// A derived `Clone` would use `Vec::clone`, which does not guarantee to keep
/// spare capacity (in practice it allocates only for the current length). A
/// clone of an empty buffer would then report `capacity()` while owning no
/// room for it, which defeats the purpose of pooling.
impl Clone for PooledBuffer {
    fn clone(&self) -> Self {
        // `max` covers buffers that were `resize`d past their nominal capacity.
        let mut data = Vec::with_capacity(self.capacity.max(self.data.len()));
        data.extend_from_slice(&self.data);
        Self {
            data,
            capacity: self.capacity,
        }
    }
}

impl PooledBuffer {
    /// Creates an empty buffer with room for at least `capacity` bytes.
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// The size this buffer was created for (its pool size class).
    ///
    /// Unaffected by `resize`, even when growing past it reallocates.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Mutable view of the initialized bytes.
    ///
    /// Its length is `data().len()`, not `capacity()`; call `resize` first to
    /// obtain writable space in a fresh buffer.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.data.as_mut_slice()
    }

    /// Read-only view of the initialized bytes.
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Sets the initialized length to `new_size`, filling any new bytes with `value`.
    pub fn resize(&mut self, new_size: usize, value: u8) {
        self.data.resize(new_size, value);
    }
}
/// A cache of `PooledBuffer`s bucketed by size class, shared through mutexes.
///
/// Idle buffers and statistics sit behind separate locks. Methods that need
/// both always take `pools` first and `stats` second.
pub struct MemoryPool {
    /// Idle buffers, keyed by the `capacity()` they were filed under.
    pools: Arc<Mutex<HashMap<usize, Vec<PooledBuffer>>>>,
    /// Running counters; `stats()` hands out a copy.
    stats: Arc<Mutex<PoolStats>>,
    /// Limits and sizing policy supplied at construction.
    config: PoolConfig,
}
/// Limits and sizing policy for a `MemoryPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Most idle buffers retained for any single size class.
    pub max_buffers_per_size: usize,
    /// Most idle buffers retained across all size classes combined.
    pub max_total_buffers: usize,
    /// When set, requested sizes are rounded up to the next power of two.
    pub round_sizes: bool,
}

impl Default for PoolConfig {
    /// 16 buffers per size class, 256 overall, power-of-two rounding enabled.
    fn default() -> Self {
        const PER_SIZE_LIMIT: usize = 16;
        const TOTAL_LIMIT: usize = 256;
        PoolConfig {
            round_sizes: true,
            max_total_buffers: TOTAL_LIMIT,
            max_buffers_per_size: PER_SIZE_LIMIT,
        }
    }
}
/// Counters describing pool activity; `MemoryPool::stats` returns a copy.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of buffer requests made.
    pub total_gets: u64,
    /// Requests served by reusing a pooled buffer.
    pub cache_hits: u64,
    /// Requests that had to allocate a new buffer.
    pub cache_misses: u64,
    /// Buffers handed back, whether they were kept or dropped.
    pub total_returns: u64,
    /// Buffers currently held for reuse.
    pub buffers_in_pool: usize,
    /// Sum of `capacity()` over the buffers currently held.
    pub bytes_in_pool: usize,
}

impl PoolStats {
    /// Fraction of `total_gets` served from the pool; `0.0` if nothing was requested.
    pub fn hit_rate(&self) -> f64 {
        Self::ratio(self.cache_hits, self.total_gets)
    }

    /// Fraction of `total_gets` that allocated a new buffer; `0.0` if nothing was requested.
    ///
    /// Computed from `cache_misses` directly. The previous `1.0 - hit_rate()`
    /// reported a 100% miss rate for a pool that had never been used.
    pub fn miss_rate(&self) -> f64 {
        Self::ratio(self.cache_misses, self.total_gets)
    }

    /// `part / whole` as `f64`, or `0.0` when `whole` is zero.
    fn ratio(part: u64, whole: u64) -> f64 {
        if whole == 0 {
            0.0
        } else {
            part as f64 / whole as f64
        }
    }
}
impl MemoryPool {
    /// Creates a pool with `PoolConfig::default()` limits.
    pub fn new() -> Self {
        Self::with_config(PoolConfig::default())
    }

    /// Creates an empty pool governed by `config`.
    pub fn with_config(config: PoolConfig) -> Self {
        Self {
            pools: Arc::new(Mutex::new(HashMap::new())),
            stats: Arc::new(Mutex::new(PoolStats::default())),
            config,
        }
    }

    /// Returns a buffer whose `capacity()` is the size class for `size`.
    ///
    /// Reuses an idle buffer of that class when one exists (a cache hit).
    /// Otherwise it allocates a new one (a cache miss). Either way the returned
    /// buffer's `data()` is empty.
    ///
    /// # Panics
    /// Panics if an internal mutex is poisoned.
    pub fn get(&self, size: usize) -> PooledBuffer {
        let size_class = self.size_class(size);
        {
            // Lock order is always `pools` then `stats`.
            let mut pools = self.pools.lock().unwrap();
            let mut stats = self.stats.lock().unwrap();
            stats.total_gets += 1;
            if let Some(buffer) = pools.get_mut(&size_class).and_then(Vec::pop) {
                stats.cache_hits += 1;
                stats.buffers_in_pool -= 1;
                stats.bytes_in_pool -= buffer.capacity();
                return buffer;
            }
            stats.cache_misses += 1;
        }
        // Allocate after the guards are released so a miss does not block
        // other callers for the duration of the allocation.
        PooledBuffer::new(size_class)
    }

    /// Hands `buffer` back to the pool for reuse.
    ///
    /// The buffer is filed under its `capacity()`, and its contents are cleared
    /// (its allocation is kept) so the next `get` never sees the previous
    /// user's bytes. If its size class already holds
    /// `config.max_buffers_per_size` buffers, or the pool already holds
    /// `config.max_total_buffers`, the buffer is dropped instead.
    /// `total_returns` counts both outcomes.
    ///
    /// # Panics
    /// Panics if an internal mutex is poisoned.
    pub fn return_buffer(&self, mut buffer: PooledBuffer) {
        let size_class = buffer.capacity();
        let mut pools = self.pools.lock().unwrap();
        let mut stats = self.stats.lock().unwrap();
        stats.total_returns += 1;

        // Check the limits without inserting, so a rejected return does not
        // leave an empty bucket behind in the map.
        let class_len = pools.get(&size_class).map_or(0, Vec::len);
        if class_len >= self.config.max_buffers_per_size
            || stats.buffers_in_pool >= self.config.max_total_buffers
        {
            return;
        }

        buffer.data.clear();
        stats.buffers_in_pool += 1;
        stats.bytes_in_pool += size_class;
        pools.entry(size_class).or_default().push(buffer);
    }

    /// Returns a copy of the current counters.
    ///
    /// # Panics
    /// Panics if the stats mutex is poisoned.
    pub fn stats(&self) -> PoolStats {
        self.stats.lock().unwrap().clone()
    }

    /// Drops every idle buffer and zeroes the occupancy counters.
    ///
    /// Cumulative counters (`total_gets`, `cache_hits`, `cache_misses`,
    /// `total_returns`) are left untouched.
    ///
    /// # Panics
    /// Panics if an internal mutex is poisoned.
    pub fn clear(&self) {
        let mut pools = self.pools.lock().unwrap();
        let mut stats = self.stats.lock().unwrap();
        pools.clear();
        stats.buffers_in_pool = 0;
        stats.bytes_in_pool = 0;
    }

    /// Maps a requested size to the key used in `pools`.
    ///
    /// With `round_sizes`, rounds up to the next power of two so that nearby
    /// requests share buffers (0 maps to 1). A size with no representable
    /// power of two above it is used as-is rather than overflowing.
    fn size_class(&self, size: usize) -> usize {
        if self.config.round_sizes {
            size.checked_next_power_of_two().unwrap_or(size)
        } else {
            size
        }
    }
}
impl Default for MemoryPool {
fn default() -> Self {
Self::new()
}
}
/// Process-wide pool, built with `MemoryPool::new` on first access.
static GLOBAL_POOL: once_cell::sync::Lazy<MemoryPool> =
    once_cell::sync::Lazy::new(MemoryPool::new);

/// Returns the process-wide pool.
///
/// The first call initializes it; every call yields a reference to the same
/// instance.
pub fn global_pool() -> &'static MemoryPool {
    once_cell::sync::Lazy::force(&GLOBAL_POOL)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_creation() {
        let pool = MemoryPool::new();
        let stats = pool.stats();
        assert_eq!(stats.total_gets, 0);
        assert_eq!(stats.cache_hits, 0);
    }

    #[test]
    fn test_get_and_return() {
        let pool = MemoryPool::new();
        let buf = pool.get(1024);
        assert_eq!(buf.capacity(), 1024);
        let stats = pool.stats();
        assert_eq!(stats.total_gets, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.cache_hits, 0);

        pool.return_buffer(buf);
        let stats = pool.stats();
        assert_eq!(stats.total_returns, 1);
        assert_eq!(stats.buffers_in_pool, 1);

        let buf2 = pool.get(1024);
        assert_eq!(buf2.capacity(), 1024);
        let stats = pool.stats();
        assert_eq!(stats.total_gets, 2);
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn test_size_rounding() {
        let pool = MemoryPool::new();
        let buf = pool.get(1000);
        assert_eq!(buf.capacity(), 1024);
    }

    /// With rounding disabled, the requested size is the size class.
    #[test]
    fn test_no_rounding() {
        let pool = MemoryPool::with_config(PoolConfig {
            round_sizes: false,
            ..Default::default()
        });
        assert_eq!(pool.get(1000).capacity(), 1000);
    }

    #[test]
    fn test_pool_limit() {
        let config = PoolConfig {
            max_buffers_per_size: 2,
            ..Default::default()
        };
        let pool = MemoryPool::with_config(config);
        pool.return_buffer(PooledBuffer::new(1024));
        pool.return_buffer(PooledBuffer::new(1024));
        pool.return_buffer(PooledBuffer::new(1024));
        let stats = pool.stats();
        assert_eq!(stats.buffers_in_pool, 2);
    }

    /// `max_total_buffers` caps occupancy across size classes; rejected
    /// returns are still counted in `total_returns`.
    #[test]
    fn test_total_limit() {
        let config = PoolConfig {
            max_total_buffers: 2,
            ..Default::default()
        };
        let pool = MemoryPool::with_config(config);
        pool.return_buffer(PooledBuffer::new(64));
        pool.return_buffer(PooledBuffer::new(128));
        pool.return_buffer(PooledBuffer::new(256));
        let stats = pool.stats();
        assert_eq!(stats.total_returns, 3);
        assert_eq!(stats.buffers_in_pool, 2);
        assert_eq!(stats.bytes_in_pool, 64 + 128);
    }

    #[test]
    fn test_multiple_sizes() {
        let pool = MemoryPool::new();
        let buf1 = pool.get(1024);
        let buf2 = pool.get(2048);
        let buf3 = pool.get(4096);
        pool.return_buffer(buf1);
        pool.return_buffer(buf2);
        pool.return_buffer(buf3);
        let stats = pool.stats();
        assert_eq!(stats.buffers_in_pool, 3);
        assert_eq!(stats.bytes_in_pool, 1024 + 2048 + 4096);
    }

    /// A pooled buffer is only reused for a request of the same size class.
    #[test]
    fn test_miss_on_different_size_class() {
        let pool = MemoryPool::new();
        pool.return_buffer(PooledBuffer::new(1024));
        let buf = pool.get(2048);
        assert_eq!(buf.capacity(), 2048);
        let stats = pool.stats();
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.buffers_in_pool, 1);
    }

    #[test]
    fn test_clear() {
        let pool = MemoryPool::new();
        pool.return_buffer(PooledBuffer::new(1024));
        pool.return_buffer(PooledBuffer::new(2048));
        assert_eq!(pool.stats().buffers_in_pool, 2);
        pool.clear();
        let stats = pool.stats();
        assert_eq!(stats.buffers_in_pool, 0);
        assert_eq!(stats.bytes_in_pool, 0);
        // Cumulative counters survive a clear.
        assert_eq!(stats.total_returns, 2);
    }

    #[test]
    fn test_global_pool() {
        let pool1 = global_pool();
        let pool2 = global_pool();
        assert!(std::ptr::eq(pool1, pool2));
    }
}