use parking_lot::Mutex;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use crate::EvictionPolicy;
use crate::config::PoolConfig;
use crate::tls::{SHARD_AFFINITY, TLS_CACHE, TLS_LIMIT};
use crate::utils::pin_buffer;
/// A pooled byte buffer plus the bookkeeping the eviction scan needs.
#[derive(Debug)]
pub(crate) struct BufferEntry {
/// The reusable allocation; its length is overwritten on checkout.
pub(crate) buffer: Vec<u8>,
/// Saturating popularity counter: bumped on checkout (`mark_accessed`),
/// aged during eviction scans (`decay`); the lowest-count entry is
/// replaced first in `evict_and_insert`.
pub(crate) access_count: u8,
}
impl BufferEntry {
    /// Wraps a freshly pooled buffer; new entries start "warm" with a
    /// popularity count of 1.
    #[inline]
    fn new(buffer: Vec<u8>) -> Self {
        Self { buffer, access_count: 1 }
    }

    /// Records a checkout. The counter sticks at `u8::MAX` instead of
    /// wrapping around.
    #[inline]
    fn mark_accessed(&mut self) {
        if let Some(bumped) = self.access_count.checked_add(1) {
            self.access_count = bumped;
        }
    }

    /// Ages the entry by one step; the counter bottoms out at 0.
    #[inline]
    fn decay(&mut self) {
        if self.access_count > 0 {
            self.access_count -= 1;
        }
    }

    /// Capacity of the underlying allocation, in bytes.
    #[inline]
    fn capacity(&self) -> usize {
        self.buffer.capacity()
    }
}
/// One independently locked partition of the pool.
#[derive(Debug)]
pub(crate) struct Shard {
/// The entries held by this shard, behind a `parking_lot` mutex.
pub(crate) buffers: Mutex<Vec<BufferEntry>>,
/// Mirror of `buffers.len()`, maintained with relaxed atomics next to
/// every push/pop so `len`/`is_empty` can read it without locking.
pub(crate) count: AtomicUsize,
}
impl Shard {
fn new() -> Self {
Self {
buffers: Mutex::new(Vec::new()),
count: AtomicUsize::new(0),
}
}
}
/// A sharded pool of reusable byte buffers with a per-thread fast path.
///
/// Cloning is cheap: every clone shares the same shards through the `Arc`.
#[derive(Clone, Debug)]
pub struct BufferPool {
    /// Shared shards; each has its own lock and lock-free counter.
    pub(crate) shards: Arc<Vec<Shard>>,
    /// Pool-wide tuning knobs.
    pub(crate) config: PoolConfig,
    /// Bitmask applied to thread-id hashes to pick a shard index.
    shard_mask: usize,
}
impl BufferPool {
    /// Creates a pool with the default configuration.
    #[inline]
    pub fn new() -> Self {
        crate::Builder::default().build()
    }

    /// Returns a builder for customizing the pool configuration.
    #[inline]
    pub fn builder() -> crate::Builder {
        crate::Builder::default()
    }

    /// Builds a pool from a finalized `PoolConfig`.
    ///
    /// Also seeds the *calling* thread's TLS cache limit; other threads pick
    /// the limit up lazily on their first `put`.
    pub(crate) fn with_config(config: PoolConfig) -> Self {
        let shards: Vec<Shard> = (0..config.num_shards).map(|_| Shard::new()).collect();
        TLS_LIMIT.with(|limit| {
            limit.set(config.tls_cache_size);
        });
        Self {
            shards: Arc::new(shards),
            // NOTE(review): masking only distributes evenly when num_shards is
            // a power of two; for other values some shards are never hashed to
            // (the index still stays in bounds, since mask = num_shards - 1).
            // Confirm the builder enforces a power-of-two shard count.
            shard_mask: config.num_shards - 1,
            config,
        }
    }

    /// Returns this thread's shard index, computing it once (hash of the
    /// thread id, masked into range) and caching it in TLS thereafter.
    #[inline(always)]
    fn get_shard_index(&self) -> usize {
        SHARD_AFFINITY.with(|affinity| {
            if let Some(idx) = affinity.get() {
                if idx < self.shards.len() {
                    return idx;
                }
            }
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            thread::current().id().hash(&mut hasher);
            let idx = (hasher.finish() as usize) & self.shard_mask;
            affinity.set(Some(idx));
            idx
        })
    }

    /// Pops an entry with `capacity >= size` out of `entries`, preferring the
    /// most recently pushed entry as a fast path, and returns its buffer with
    /// the length set to `size`.
    ///
    /// The returned bytes are NOT initialized; callers hand the buffer out as
    /// scratch space the user is expected to overwrite before reading.
    #[inline]
    fn take_fitting(entries: &mut Vec<BufferEntry>, size: usize) -> Option<Vec<u8>> {
        let idx = match entries.last() {
            // Fast path: the entry on top already fits (swap_remove of the
            // last index is equivalent to pop).
            Some(entry) if entry.capacity() >= size => entries.len() - 1,
            // Slow path: linear scan for any entry that fits.
            _ => entries.iter().position(|e| e.capacity() >= size)?,
        };
        let mut entry = entries.swap_remove(idx);
        entry.mark_accessed();
        let mut buf = entry.buffer;
        // SAFETY: `capacity() >= size` was checked above, so the first `size`
        // bytes lie within the allocation. Contents are deliberately left
        // uninitialized (see the uninit_vec allow in `get`).
        unsafe {
            buf.set_len(size);
        }
        Some(buf)
    }

    /// Checks out a buffer of length `size`.
    ///
    /// Lookup order: thread-local cache (no locking) → this thread's shard
    /// (mutex) → fresh allocation. The returned buffer's contents are
    /// uninitialized scratch space.
    #[inline]
    #[must_use]
    pub fn get(&self, size: usize) -> crate::buffer::PooledBuffer {
        // 1. Thread-local cache: no locking, no atomics.
        let tls_hit =
            TLS_CACHE.with(|tls| Self::take_fitting(&mut tls.borrow_mut().buffers, size));
        if let Some(buf) = tls_hit {
            return crate::buffer::PooledBuffer::new(buf, self.clone());
        }
        // 2. This thread's shard.
        let shard_idx = self.get_shard_index();
        debug_assert!(shard_idx < self.shards.len(), "Shard index out of bounds");
        let shard = &self.shards[shard_idx];
        let mut buffers = shard.buffers.lock();
        if let Some(buffer) = Self::take_fitting(&mut buffers, size) {
            shard.count.fetch_sub(1, Ordering::Relaxed);
            return crate::buffer::PooledBuffer::new(buffer, self.clone());
        }
        // 3. Miss everywhere: release the lock, then allocate fresh without
        // zeroing.
        drop(buffers);
        let mut v = Vec::with_capacity(size);
        #[allow(clippy::uninit_vec)]
        unsafe {
            v.set_len(size);
        }
        crate::buffer::PooledBuffer::new(v, self.clone())
    }

    /// Returns a buffer to the pool.
    ///
    /// The buffer is cleared, then cached thread-locally if there is room;
    /// otherwise it falls through to this thread's shard, where it is stored,
    /// inserted via eviction (`ClockPro`), or dropped.
    #[inline]
    pub(crate) fn put(&self, mut buffer: Vec<u8>) {
        buffer.clear();
        // TLS_LIMIT is 0 on threads that never constructed a pool; lazily
        // seed it from this pool's config on first use.
        let limit = TLS_LIMIT.with(|limit| {
            let current = limit.get();
            if current == 0 {
                limit.set(self.config.tls_cache_size);
                self.config.tls_cache_size
            } else {
                current
            }
        });
        // Fast path: stash in the thread-local cache; hand the buffer back
        // out if the cache is already at its limit.
        // NOTE(review): TLS-cached buffers skip the min_buffer_size filter
        // and pin_buffer below — presumably intentional; confirm.
        let buffer = TLS_CACHE.with(|tls| {
            let mut cache = tls.borrow_mut();
            if cache.buffers.len() < limit {
                cache.buffers.push(BufferEntry::new(buffer));
                None
            } else {
                Some(buffer)
            }
        });
        let Some(buffer) = buffer else {
            return;
        };
        // Too small to be worth pooling globally.
        if buffer.capacity() < self.config.min_buffer_size {
            return;
        }
        if self.config.pinned_memory {
            pin_buffer(&buffer);
        }
        let shard_idx = self.get_shard_index();
        debug_assert!(shard_idx < self.shards.len(), "Shard index out of bounds in put");
        let shard = &self.shards[shard_idx];
        let mut buffers = shard.buffers.lock();
        if buffers.len() < self.config.max_buffers_per_shard {
            buffers.push(BufferEntry::new(buffer));
            shard.count.fetch_add(1, Ordering::Relaxed);
        } else if self.config.eviction_policy == EvictionPolicy::ClockPro {
            // Shard full: replace the least-popular entry (count unchanged).
            Self::evict_and_insert(&mut buffers, buffer);
        }
        // Shard full under a non-evicting policy: the buffer is dropped.
    }

    /// Replaces the entry with the lowest access count by `new_buffer`,
    /// decaying every entry's counter along the way (CLOCK-style aging).
    /// No-op — dropping `new_buffer` — when `entries` is empty.
    fn evict_and_insert(entries: &mut [BufferEntry], new_buffer: Vec<u8>) {
        if entries.is_empty() {
            return;
        }
        let mut min_idx = 0;
        let mut min_count = u8::MAX;
        for (idx, entry) in entries.iter_mut().enumerate() {
            entry.decay();
            if entry.access_count < min_count {
                min_count = entry.access_count;
                min_idx = idx;
            }
        }
        entries[min_idx] = BufferEntry::new(new_buffer);
    }

    /// Warms the pool with roughly `count` empty buffers of capacity at least
    /// `size` (rounded up so every shard receives the same number).
    ///
    /// NOTE(review): this bypasses `max_buffers_per_shard` — presumably
    /// intentional for warm-up; confirm.
    pub fn preallocate(&self, count: usize, size: usize) {
        let per_shard = count.div_ceil(self.config.num_shards);
        let total_to_allocate = per_shard * self.config.num_shards;
        // Allocate everything up front, outside any shard lock.
        let mut all_buffers: Vec<BufferEntry> = Vec::with_capacity(total_to_allocate);
        for _ in 0..total_to_allocate {
            let mut buf = Vec::with_capacity(size.max(self.config.min_buffer_size));
            if self.config.pinned_memory {
                // pin_buffer takes the buffer by reference; temporarily
                // extend to full capacity so the whole allocation is covered,
                // then restore the empty length.
                unsafe {
                    buf.set_len(buf.capacity());
                }
                pin_buffer(&buf);
                buf.clear();
            }
            all_buffers.push(BufferEntry::new(buf));
        }
        // Hand per_shard entries to each shard. Iterating shards in reverse
        // drains the tail ranges of `all_buffers` first, so the start indices
        // of the remaining ranges stay valid.
        for shard_idx in (0..self.config.num_shards).rev() {
            debug_assert!(
                shard_idx < self.shards.len(),
                "Shard index out of bounds in preallocate"
            );
            let shard = &self.shards[shard_idx];
            let mut buffers = shard.buffers.lock();
            buffers.reserve(per_shard);
            let start = shard_idx * per_shard;
            buffers.extend(all_buffers.drain(start..start + per_shard));
            shard.count.fetch_add(per_shard, Ordering::Relaxed);
        }
    }

    /// Number of buffers currently held in the shards (thread-local caches
    /// are not counted).
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.count.load(Ordering::Relaxed)).sum()
    }

    /// True when no shard holds any buffer (thread-local caches excluded).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.count.load(Ordering::Relaxed) == 0)
    }

    /// Drops every buffer held by the shards (thread-local caches untouched).
    pub fn clear(&self) {
        for shard in self.shards.iter() {
            let mut buffers = shard.buffers.lock();
            let count = buffers.len();
            buffers.clear();
            shard.count.fetch_sub(count, Ordering::Relaxed);
        }
    }

    /// Releases excess capacity in each shard's entry vector (the pooled
    /// buffers themselves are untouched).
    pub fn shrink_to_fit(&self) {
        for shard in self.shards.iter() {
            shard.buffers.lock().shrink_to_fit();
        }
    }
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}