use parking_lot::Mutex;
use std::collections::VecDeque;
use std::sync::Arc;
use tracing::debug;
/// A thread-safe pool of reusable `Vec<u8>` buffers that share one target capacity.
///
/// Cloning a `BufferPool` is cheap. The clone shares the same underlying idle-buffer
/// storage (the `Arc`) instead of copying it.
#[derive(Debug, Clone)]
pub struct BufferPool {
    // Idle buffers waiting to be handed out. Outstanding `PooledBuffer`s hold a
    // `Weak` handle to this so they can push themselves back on drop.
    buffers: Arc<Mutex<VecDeque<Vec<u8>>>>,
    // Target capacity for buffers handed out by `get`.
    buffer_size: usize,
    // Upper bound on the number of idle buffers kept in `buffers`.
    max_pool_size: usize,
}
impl BufferPool {
    /// Creates an empty pool.
    ///
    /// * `buffer_size`: minimum capacity of every buffer returned by [`BufferPool::get`].
    /// * `max_pool_size`: maximum number of idle buffers the pool keeps. When the pool
    ///   is full, surplus buffers are freed instead of being returned. `0` disables reuse.
    ///
    /// Note: a dropped [`PooledBuffer`] is only put back into the pool when its
    /// capacity is at least 1024 bytes. Pools created with `buffer_size < 1024` may
    /// therefore never actually reuse buffers.
    pub fn new(buffer_size: usize, max_pool_size: usize) -> Self {
        Self {
            buffers: Arc::new(Mutex::new(VecDeque::with_capacity(max_pool_size))),
            buffer_size,
            max_pool_size,
        }
    }

    /// Checks out an empty buffer with `capacity() >= buffer_size`.
    ///
    /// An idle buffer is reused when one is available. Otherwise a new one is
    /// allocated. The returned [`PooledBuffer`] holds only a weak reference to
    /// the pool, so it stays usable even if the pool is dropped first.
    pub fn get(&self) -> PooledBuffer {
        // Hold the lock only long enough to pop. A fresh allocation, if needed,
        // happens after the guard has been released.
        let recycled = self.buffers.lock().pop_front();
        let mut buffer = recycled.unwrap_or_else(|| {
            debug!("Creating new buffer of size {}", self.buffer_size);
            Vec::with_capacity(self.buffer_size)
        });
        buffer.clear();
        // `Vec::reserve` is relative to `len()`, which is 0 after `clear`. This
        // call therefore guarantees `capacity() >= buffer_size`, even for a
        // recycled buffer whose capacity was shrunk while checked out. It is a
        // no-op when the buffer is already large enough.
        buffer.reserve(self.buffer_size);
        PooledBuffer {
            buffer,
            pool: Arc::downgrade(&self.buffers),
            max_pool_size: self.max_pool_size,
        }
    }

    /// Number of idle buffers currently held by the pool. Checked-out buffers
    /// are not counted.
    pub fn size(&self) -> usize {
        self.buffers.lock().len()
    }

    /// Pre-allocates idle buffers until the pool holds `count` of them, capped at
    /// `max_pool_size`.
    ///
    /// Buffers already in the pool count toward `count`, so calling this again
    /// once the target is met creates nothing.
    pub fn warm_up(&self, count: usize) {
        let mut buffers = self.buffers.lock();
        let target = count.min(self.max_pool_size);
        // `saturating_sub` cannot underflow even if the pool already holds more
        // than `target` buffers.
        let to_create = target.saturating_sub(buffers.len());
        buffers.extend((0..to_create).map(|_| Vec::with_capacity(self.buffer_size)));
        debug!("Buffer pool warmed up with {} buffers", to_create);
    }
}
impl Default for BufferPool {
fn default() -> Self {
Self::new(16384, 64) }
}
/// A buffer checked out of a [`BufferPool`].
///
/// When dropped, it may hand its storage back to the originating pool. This is
/// subject to the pool's size limit and a minimum capacity. If the pool no longer
/// exists, the storage is simply freed.
pub struct PooledBuffer {
    // The checked-out storage.
    buffer: Vec<u8>,
    // Non-owning handle to the pool's idle list. An outstanding buffer therefore
    // does not keep that list alive.
    pool: std::sync::Weak<Mutex<VecDeque<Vec<u8>>>>,
    // Copy of the owning pool's idle-buffer limit, checked when returning on drop.
    max_pool_size: usize,
}
impl PooledBuffer {
    /// Borrows the underlying vector mutably.
    ///
    /// This exposes the whole `Vec`, so callers can also change its capacity.
    /// Whatever capacity remains at drop time decides whether the storage goes
    /// back to the pool.
    pub fn as_mut_vec(&mut self) -> &mut Vec<u8> {
        &mut self.buffer
    }

    /// Borrows the underlying vector immutably.
    pub fn as_vec(&self) -> &Vec<u8> {
        &self.buffer
    }

    /// The bytes written so far, as a slice.
    pub fn as_slice(&self) -> &[u8] {
        self.buffer.as_slice()
    }

    /// Current allocated capacity in bytes.
    pub fn capacity(&self) -> usize {
        self.as_vec().capacity()
    }

    /// Number of bytes written so far.
    pub fn len(&self) -> usize {
        self.as_vec().len()
    }

    /// `true` when no bytes have been written.
    pub fn is_empty(&self) -> bool {
        self.as_vec().is_empty()
    }

    /// Discards the contents while keeping the allocation.
    pub fn clear(&mut self) {
        self.as_mut_vec().clear();
    }

    /// Appends every byte of `other`.
    pub fn extend_from_slice(&mut self, other: &[u8]) {
        self.as_mut_vec().extend_from_slice(other);
    }

    /// Appends a single byte.
    pub fn push(&mut self, byte: u8) {
        self.as_mut_vec().push(byte);
    }

    /// Ensures room for at least `additional` more bytes beyond the current length.
    pub fn reserve(&mut self, additional: usize) {
        self.as_mut_vec().reserve(additional);
    }
}
impl Drop for PooledBuffer {
    /// Hands the storage back to the originating pool when three conditions hold:
    /// the pool still exists, it holds fewer than `max_pool_size` idle buffers, and
    /// the buffer's capacity is at least 1024 bytes. Otherwise the storage is
    /// freed normally.
    fn drop(&mut self) {
        // Minimum capacity a buffer needs in order to be put back into the pool.
        const MIN_RETAINED_CAPACITY: usize = 1024;

        let Some(pool) = self.pool.upgrade() else {
            return;
        };
        let mut idle = pool.lock();
        let has_room = idle.len() < self.max_pool_size;
        let big_enough = self.buffer.capacity() >= MIN_RETAINED_CAPACITY;
        if !has_room || !big_enough {
            return;
        }
        let mut recycled = std::mem::take(&mut self.buffer);
        recycled.clear();
        idle.push_back(recycled);
        debug!("Buffer returned to pool");
    }
}
impl std::ops::Deref for PooledBuffer {
    type Target = [u8];

    /// Views the written bytes as a shared slice.
    fn deref(&self) -> &[u8] {
        self.buffer.as_slice()
    }
}
impl std::ops::DerefMut for PooledBuffer {
    /// Views the written bytes as a mutable slice (the length cannot change).
    fn deref_mut(&mut self) -> &mut [u8] {
        self.buffer.as_mut_slice()
    }
}
/// A size-tiered collection of [`BufferPool`]s.
///
/// Callers can pick a tier explicitly (`get_small`, `get_medium`, ...) or let
/// `get_buffer` choose one from an expected size.
#[derive(Debug)]
pub struct GlobalBufferPools {
    // Smallest size tier.
    small: BufferPool,
    // Second size tier.
    medium: BufferPool,
    // Third size tier.
    large: BufferPool,
    // Largest tier. It also serves requests bigger than any tier's buffer size.
    extra_large: BufferPool,
}
impl GlobalBufferPools {
    // Per-tier buffer capacity. Each value is both the capacity given to that
    // tier's pool and the largest `expected_size` that `get_buffer` routes to it.
    // Sharing one constant keeps construction and routing in agreement.
    const SMALL_BUFFER_SIZE: usize = 2 * 1024;
    const MEDIUM_BUFFER_SIZE: usize = 16 * 1024;
    const LARGE_BUFFER_SIZE: usize = 128 * 1024;
    const EXTRA_LARGE_BUFFER_SIZE: usize = 1024 * 1024;

    // Maximum number of idle buffers retained per tier.
    const SMALL_MAX_POOLED: usize = 32;
    const MEDIUM_MAX_POOLED: usize = 64;
    const LARGE_MAX_POOLED: usize = 16;
    const EXTRA_LARGE_MAX_POOLED: usize = 8;

    /// Creates the four size tiers (2 KiB, 16 KiB, 128 KiB, 1 MiB), all
    /// initially empty. Call [`GlobalBufferPools::warm_up`] to pre-populate them.
    pub fn new() -> Self {
        Self {
            small: BufferPool::new(Self::SMALL_BUFFER_SIZE, Self::SMALL_MAX_POOLED),
            medium: BufferPool::new(Self::MEDIUM_BUFFER_SIZE, Self::MEDIUM_MAX_POOLED),
            large: BufferPool::new(Self::LARGE_BUFFER_SIZE, Self::LARGE_MAX_POOLED),
            extra_large: BufferPool::new(
                Self::EXTRA_LARGE_BUFFER_SIZE,
                Self::EXTRA_LARGE_MAX_POOLED,
            ),
        }
    }

    /// Returns a buffer from the smallest tier whose buffer size is at least
    /// `expected_size`.
    ///
    /// Requests larger than the large tier go to the extra-large tier. Those
    /// buffers start with that tier's capacity and grow as data is written
    /// beyond it.
    pub fn get_buffer(&self, expected_size: usize) -> PooledBuffer {
        if expected_size <= Self::SMALL_BUFFER_SIZE {
            self.small.get()
        } else if expected_size <= Self::MEDIUM_BUFFER_SIZE {
            self.medium.get()
        } else if expected_size <= Self::LARGE_BUFFER_SIZE {
            self.large.get()
        } else {
            self.extra_large.get()
        }
    }

    /// Checks out a buffer from the 2 KiB tier.
    pub fn get_small(&self) -> PooledBuffer {
        self.small.get()
    }

    /// Checks out a buffer from the 16 KiB tier.
    pub fn get_medium(&self) -> PooledBuffer {
        self.medium.get()
    }

    /// Checks out a buffer from the 128 KiB tier.
    pub fn get_large(&self) -> PooledBuffer {
        self.large.get()
    }

    /// Checks out a buffer from the 1 MiB tier.
    pub fn get_extra_large(&self) -> PooledBuffer {
        self.extra_large.get()
    }

    /// Pre-populates every tier with idle buffers. Each requested count is
    /// further capped by that tier's retention limit.
    pub fn warm_up(&self) {
        self.small.warm_up(16);
        self.medium.warm_up(32);
        self.large.warm_up(8);
        self.extra_large.warm_up(4);
    }

    /// Snapshot of the idle-buffer count in each tier.
    ///
    /// Each tier is read separately, so under concurrent use the counts are
    /// not captured atomically as a group.
    pub fn stats(&self) -> BufferPoolStats {
        BufferPoolStats {
            small_pool_size: self.small.size(),
            medium_pool_size: self.medium.size(),
            large_pool_size: self.large.size(),
            extra_large_pool_size: self.extra_large.size(),
        }
    }
}
impl Default for GlobalBufferPools {
fn default() -> Self {
Self::new()
}
}
/// Idle-buffer counts for each tier of a `GlobalBufferPools`, as produced by its
/// `stats` method.
///
/// Only buffers sitting idle in a pool are counted. Buffers currently checked
/// out are not included.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BufferPoolStats {
    /// Idle buffers in the small tier.
    pub small_pool_size: usize,
    /// Idle buffers in the medium tier.
    pub medium_pool_size: usize,
    /// Idle buffers in the large tier.
    pub large_pool_size: usize,
    /// Idle buffers in the extra-large tier.
    pub extra_large_pool_size: usize,
}
use std::sync::OnceLock;
static GLOBAL_POOLS: OnceLock<GlobalBufferPools> = OnceLock::new();

/// Returns the process-wide [`GlobalBufferPools`], building and warming it up on
/// first access.
///
/// Initialisation runs at most once, even when several threads race on the
/// first call. Every call returns the same instance.
pub fn global_pools() -> &'static GlobalBufferPools {
    GLOBAL_POOLS.get_or_init(build_warmed_global_pools)
}

/// Builds a fresh set of tiered pools and pre-populates them via `warm_up`.
fn build_warmed_global_pools() -> GlobalBufferPools {
    let warmed = GlobalBufferPools::new();
    warmed.warm_up();
    warmed
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_pool_basic() {
        let pool = BufferPool::new(1024, 4);
        {
            let mut buf1 = pool.get();
            buf1.extend_from_slice(b"hello");
            assert_eq!(buf1.len(), 5);
            assert!(buf1.capacity() >= 1024);
        }
        // The buffer went back to the pool on drop...
        assert_eq!(pool.size(), 1);
        // ...and is handed out again empty.
        let mut buf2 = pool.get();
        assert_eq!(buf2.len(), 0);
        buf2.extend_from_slice(b"world");
        assert_eq!(buf2.as_slice(), b"world");
    }

    #[test]
    fn test_buffer_pool_max_size() {
        let pool = BufferPool::new(1024, 2);
        let buf1 = pool.get();
        let buf2 = pool.get();
        let buf3 = pool.get();
        drop(buf1);
        drop(buf2);
        drop(buf3);
        // Only `max_pool_size` buffers are retained. The third is freed.
        assert_eq!(pool.size(), 2);
    }

    #[test]
    fn test_warm_up_respects_max_pool_size() {
        let pool = BufferPool::new(1024, 4);
        pool.warm_up(10);
        assert_eq!(pool.size(), 4);
        // Already at the limit, so nothing more is created.
        pool.warm_up(2);
        assert_eq!(pool.size(), 4);
    }

    #[test]
    fn test_buffer_outlives_pool() {
        let pool = BufferPool::new(1024, 4);
        let mut buf = pool.get();
        drop(pool);
        buf.extend_from_slice(b"still usable");
        assert_eq!(buf.as_slice(), b"still usable");
        // Dropping after the pool is gone must not panic.
        drop(buf);
    }

    #[test]
    fn test_global_pools() {
        let pools = global_pools();
        assert!(std::ptr::eq(pools, global_pools()));
        let mut small_buf = pools.get_small();
        small_buf.extend_from_slice(b"small");
        let mut medium_buf = pools.get_medium();
        medium_buf.extend_from_slice(b"medium data");
        let mut large_buf = pools.get_large();
        large_buf.extend_from_slice(b"large data payload");
        assert_eq!(small_buf.as_slice(), b"small");
        assert_eq!(medium_buf.as_slice(), b"medium data");
        assert_eq!(large_buf.as_slice(), b"large data payload");
        assert!(small_buf.capacity() >= 2048);
        assert!(medium_buf.capacity() >= 16384);
        assert!(large_buf.capacity() >= 131072);
    }

    #[test]
    fn test_buffer_selection() {
        let pools = GlobalBufferPools::new();
        let buf1 = pools.get_buffer(512);
        let buf2 = pools.get_buffer(4096);
        let buf3 = pools.get_buffer(32768);
        assert!(buf1.capacity() >= 2048);
        assert!(buf2.capacity() >= 16384);
        assert!(buf3.capacity() >= 131072);
    }

    #[test]
    fn test_buffer_selection_boundaries() {
        // (expected_size, idle counts [small, medium, large, extra_large] after
        // checking out one buffer of that size and dropping it)
        let cases: [(usize, [usize; 4]); 6] = [
            (2048, [1, 0, 0, 0]),
            (2049, [0, 1, 0, 0]),
            (16384, [0, 1, 0, 0]),
            (16385, [0, 0, 1, 0]),
            (131072, [0, 0, 1, 0]),
            (131073, [0, 0, 0, 1]),
        ];
        for &(size, expected) in cases.iter() {
            let pools = GlobalBufferPools::new();
            drop(pools.get_buffer(size));
            let stats = pools.stats();
            let observed = [
                stats.small_pool_size,
                stats.medium_pool_size,
                stats.large_pool_size,
                stats.extra_large_pool_size,
            ];
            assert_eq!(observed, expected, "expected_size = {size}");
        }
    }
}