use std::sync::{Arc, Mutex, OnceLock};
/// Capacity, in bytes, that buffers from the small size class start with (1 KiB).
const SMALL_BUFFER_SIZE: usize = 1024;
/// Capacity, in bytes, that buffers from the medium size class start with (8 KiB).
const MEDIUM_BUFFER_SIZE: usize = 8 * 1024;
/// Capacity, in bytes, that buffers from the large size class start with (64 KiB).
const LARGE_BUFFER_SIZE: usize = 64 * 1024;
/// Maximum number of idle buffers kept per size class; further returned buffers are freed.
const MAX_POOL_SIZE: usize = 16;
/// A byte buffer checked out of one size class of a [`BufferPool`].
///
/// When dropped, the buffer is cleared and offered back to the free list it
/// was taken from so that its allocation can be reused.
#[derive(Debug)]
pub struct PooledBuffer {
    /// Owned storage; its allocation is what gets recycled.
    buffer: Vec<u8>,
    /// Free list of the size class this buffer was checked out from.
    pool: Arc<Mutex<Vec<Vec<u8>>>>,
}
impl PooledBuffer {
    /// Number of bytes currently stored.
    pub fn len(&self) -> usize {
        Vec::len(&self.buffer)
    }

    /// Returns `true` when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of bytes the buffer can hold without reallocating.
    pub fn capacity(&self) -> usize {
        Vec::capacity(&self.buffer)
    }

    /// Shared view of the stored bytes (`len()` bytes, not the full capacity).
    pub fn as_slice(&self) -> &[u8] {
        self.buffer.as_slice()
    }

    /// Mutable view of the stored bytes (`len()` bytes, not the full capacity).
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_mut_slice()
    }

    /// Shared access to the underlying `Vec`.
    pub fn as_vec(&self) -> &Vec<u8> {
        &self.buffer
    }

    /// Mutable access to the underlying `Vec`, for operations not wrapped here
    /// (for example `extend_from_slice`).
    pub fn as_mut_vec(&mut self) -> &mut Vec<u8> {
        &mut self.buffer
    }

    /// Removes every stored byte while keeping the allocation.
    pub fn clear(&mut self) {
        self.buffer.truncate(0);
    }

    /// Sets the length to exactly `new_len`, filling any new bytes with `value`.
    pub fn resize(&mut self, new_len: usize, value: u8) {
        Vec::resize(&mut self.buffer, new_len, value);
    }
}
impl Drop for PooledBuffer {
fn drop(&mut self) {
self.buffer.clear();
if let Ok(mut pool) = self.pool.lock() {
if pool.len() < MAX_POOL_SIZE {
pool.push(std::mem::take(&mut self.buffer));
}
}
}
}
/// A pool of reusable byte buffers split into small, medium and large size classes.
///
/// Each class keeps its own mutex-protected free list. The list is shared (via
/// `Arc`) with every buffer checked out of it, so dropped buffers know where to
/// return.
#[derive(Debug)]
pub struct BufferPool {
    /// Idle buffers of the small class.
    small_buffers: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Idle buffers of the medium class.
    medium_buffers: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Idle buffers of the large class.
    large_buffers: Arc<Mutex<Vec<Vec<u8>>>>,
}
impl BufferPool {
    /// Creates a pool whose small, medium and large free lists all start empty.
    ///
    /// Nothing is allocated up front. Buffers are created on first request and
    /// offered back to their free list when the [`PooledBuffer`] is dropped.
    pub fn new() -> Self {
        Self {
            small_buffers: Arc::new(Mutex::new(Vec::new())),
            medium_buffers: Arc::new(Mutex::new(Vec::new())),
            large_buffers: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Checks out a buffer with capacity of at least `SMALL_BUFFER_SIZE` bytes.
    pub fn get_small_buffer(&self) -> PooledBuffer {
        self.get_buffer_from_pool(&self.small_buffers, SMALL_BUFFER_SIZE)
    }

    /// Checks out a buffer with capacity of at least `MEDIUM_BUFFER_SIZE` bytes.
    pub fn get_medium_buffer(&self) -> PooledBuffer {
        self.get_buffer_from_pool(&self.medium_buffers, MEDIUM_BUFFER_SIZE)
    }

    /// Checks out a buffer with capacity of at least `LARGE_BUFFER_SIZE` bytes.
    pub fn get_large_buffer(&self) -> PooledBuffer {
        self.get_buffer_from_pool(&self.large_buffers, LARGE_BUFFER_SIZE)
    }

    /// Checks out a buffer whose capacity is at least `capacity` bytes.
    ///
    /// The request is served from the smallest size class that fits. Requests
    /// above `LARGE_BUFFER_SIZE` come from the large class and are then grown to
    /// the requested capacity, so the guarantee holds for every size.
    ///
    /// # Panics
    ///
    /// Panics if the required capacity exceeds `isize::MAX` bytes, as
    /// `Vec::reserve` does.
    pub fn get_buffer_with_capacity(&self, capacity: usize) -> PooledBuffer {
        let mut pooled = if capacity <= SMALL_BUFFER_SIZE {
            self.get_small_buffer()
        } else if capacity <= MEDIUM_BUFFER_SIZE {
            self.get_medium_buffer()
        } else {
            self.get_large_buffer()
        };
        // `reserve(n)` guarantees `capacity() >= len() + n`, hence `>= capacity`.
        // It does not reallocate when that already holds, which is the case for
        // every request that fits within its size class.
        pooled.buffer.reserve(capacity);
        pooled
    }

    /// Pops an idle buffer from `pool` (or allocates a new one) and ties it to
    /// `pool` so it is offered back there on drop.
    ///
    /// The returned buffer always has capacity of at least `default_capacity`.
    fn get_buffer_from_pool(
        &self,
        pool: &Arc<Mutex<Vec<Vec<u8>>>>,
        default_capacity: usize,
    ) -> PooledBuffer {
        // The guard lives only inside the closure, so the lock is released before
        // any allocation below. A poisoned lock only costs reuse, not correctness.
        let recycled = pool.lock().ok().and_then(|mut free| free.pop());
        let buffer = match recycled {
            Some(mut buf) => {
                // Callers can shrink or replace the Vec through `as_mut_vec`, so a
                // recycled allocation may be smaller than its class; top it back up.
                buf.reserve(default_capacity);
                buf
            }
            None => Vec::with_capacity(default_capacity),
        };
        PooledBuffer {
            buffer,
            pool: Arc::clone(pool),
        }
    }

    /// Returns a snapshot of how many idle buffers each size class holds.
    ///
    /// A class whose mutex is poisoned is reported as holding 0 buffers.
    pub fn stats(&self) -> BufferPoolStats {
        let small_count = self.small_buffers.lock().map(|p| p.len()).unwrap_or(0);
        let medium_count = self.medium_buffers.lock().map(|p| p.len()).unwrap_or(0);
        let large_count = self.large_buffers.lock().map(|p| p.len()).unwrap_or(0);
        BufferPoolStats {
            small_buffers_pooled: small_count,
            medium_buffers_pooled: medium_count,
            large_buffers_pooled: large_count,
        }
    }
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of how many idle buffers each size class of a [`BufferPool`] holds,
/// as returned by [`BufferPool::stats`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct BufferPoolStats {
    /// Idle buffers in the small size class.
    pub small_buffers_pooled: usize,
    /// Idle buffers in the medium size class.
    pub medium_buffers_pooled: usize,
    /// Idle buffers in the large size class.
    pub large_buffers_pooled: usize,
}
/// Process-wide pool, created lazily on first access.
static GLOBAL_BUFFER_POOL: OnceLock<BufferPool> = OnceLock::new();

/// Returns the process-wide [`BufferPool`], initialising it on the first call.
///
/// Every call returns a reference to the same pool.
pub fn global_buffer_pool() -> &'static BufferPool {
    GLOBAL_BUFFER_POOL.get_or_init(BufferPool::default)
}
pub mod global {
use super::*;
pub fn get_small_buffer() -> PooledBuffer {
global_buffer_pool().get_small_buffer()
}
pub fn get_medium_buffer() -> PooledBuffer {
global_buffer_pool().get_medium_buffer()
}
pub fn get_large_buffer() -> PooledBuffer {
global_buffer_pool().get_large_buffer()
}
pub fn get_buffer_with_capacity(capacity: usize) -> PooledBuffer {
global_buffer_pool().get_buffer_with_capacity(capacity)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_pool_basic() {
        let pool = BufferPool::new();

        let mut first = pool.get_small_buffer();
        first.as_mut_vec().extend_from_slice(b"hello");
        assert_eq!(first.len(), 5);
        assert_eq!(first.as_slice(), b"hello");
        // Dropping the buffer offers its allocation back to the small pool.
        drop(first);

        let second = pool.get_small_buffer();
        assert_eq!(second.len(), 0);
        assert!(second.capacity() >= SMALL_BUFFER_SIZE);
    }

    #[test]
    fn test_buffer_pool_stats() {
        let pool = BufferPool::new();
        assert_eq!(pool.stats().small_buffers_pooled, 0);

        drop(pool.get_small_buffer());

        assert_eq!(pool.stats().small_buffers_pooled, 1);
    }

    #[test]
    fn test_global_buffer_pool() {
        // All three buffers are held at once while their capacities are checked.
        let checked_out = [
            (global::get_small_buffer(), SMALL_BUFFER_SIZE),
            (global::get_medium_buffer(), MEDIUM_BUFFER_SIZE),
            (global::get_large_buffer(), LARGE_BUFFER_SIZE),
        ];
        for (buffer, minimum) in &checked_out {
            assert!(buffer.capacity() >= *minimum);
        }
    }
}