use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
/// Default slab length in bytes, used by `NoPool::default` and
/// `SlabPool::for_wireguard`.
pub const MAX_DATAGRAM: usize = 2 * 1024;
/// Source of fixed-length byte slabs.
///
/// Implementors must be shareable across threads (`Send + Sync`) and printable
/// for diagnostics (`Debug`), so a pool can be used behind `dyn BufferPool`.
pub trait BufferPool: Send + Sync + fmt::Debug {
    /// Obtains a slab from the pool.
    fn take(&self) -> Box<[u8]>;

    /// Hands a slab back to the pool; the implementation decides whether to
    /// keep it for reuse or free it.
    fn give(&self, slab: Box<[u8]>);

    /// Length, in bytes, of the slabs this pool produces.
    fn slab_len(&self) -> usize;
}
/// A buffer pool that never recycles: every `take` allocates a fresh
/// zero-filled buffer and every `give` simply frees it.
#[derive(Clone, Copy, Debug)]
pub struct NoPool {
    /// Length in bytes of each buffer handed out.
    slab_len: usize,
}
impl NoPool {
#[must_use]
pub const fn new(slab_len: usize) -> Self {
Self { slab_len }
}
}
impl Default for NoPool {
fn default() -> Self {
Self::new(MAX_DATAGRAM)
}
}
impl BufferPool for NoPool {
    fn slab_len(&self) -> usize {
        self.slab_len
    }

    /// Always allocates a fresh, zero-filled buffer of `slab_len` bytes.
    fn take(&self) -> Box<[u8]> {
        let zeroed = vec![0u8; self.slab_len];
        zeroed.into_boxed_slice()
    }

    /// Nothing is retained: the returned buffer is freed immediately.
    fn give(&self, buf: Box<[u8]>) {
        drop(buf);
    }
}
/// Thread-safe buffer pool that recycles slabs through a bounded free list.
///
/// Slabs handed back via `give` are retained (up to `max_idle`) and handed out
/// again by `take` without being cleared.
pub struct SlabPool {
    /// Length in bytes of every slab this pool hands out.
    slab_len: usize,
    /// Upper bound on how many returned slabs are kept on the free list.
    max_idle: usize,
    /// Idle slabs awaiting reuse, used as a LIFO stack (push/pop).
    free: Mutex<Vec<Box<[u8]>>>,
    /// Count of slabs handed out by `take` and not yet returned via `give`.
    outstanding: std::sync::atomic::AtomicUsize,
}
impl SlabPool {
    /// Creates a shared pool that hands out `slab_len`-byte slabs and keeps at
    /// most `max_idle` returned slabs for reuse.
    ///
    /// Capacity for `max_idle` free-list entries is reserved up front.
    #[must_use]
    pub fn new(slab_len: usize, max_idle: usize) -> Arc<Self> {
        Arc::new(Self {
            slab_len,
            max_idle,
            free: Mutex::new(Vec::with_capacity(max_idle)),
            outstanding: std::sync::atomic::AtomicUsize::new(0),
        })
    }

    /// Pool of `MAX_DATAGRAM`-byte slabs that retains up to 256 idle slabs.
    #[must_use]
    pub fn for_wireguard() -> Arc<Self> {
        Self::new(MAX_DATAGRAM, 256)
    }

    /// Allocates zero-filled slabs until the free list holds `min(n, max_idle)`
    /// entries. Does nothing if it already holds at least that many.
    pub fn prefill(&self, n: usize) {
        let target = n.min(self.max_idle);
        // The free list is only ever pushed to or popped from, so its contents
        // stay valid even if an earlier lock holder panicked. Recover the guard
        // instead of silently skipping the prefill for the rest of the pool's life.
        let mut free = self
            .free
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        while free.len() < target {
            free.push(vec![0u8; self.slab_len].into_boxed_slice());
        }
    }

    /// Takes a slab from the pool and wraps it in a [`PooledBuf`] that gives it
    /// back on drop. The returned buffer starts with `len() == 0`.
    #[must_use]
    pub fn get(self: &Arc<Self>) -> PooledBuf {
        let storage = BufferPool::take(self.as_ref());
        PooledBuf {
            storage: Some(storage),
            len: 0,
            pool: Arc::clone(self),
        }
    }

    /// Number of slabs handed out by `take` and not yet given back.
    ///
    /// This is a point-in-time snapshot (relaxed load). Under concurrent use it
    /// may already be stale when read.
    #[must_use]
    pub fn outstanding(&self) -> usize {
        self.outstanding.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Number of slabs currently sitting on the free list.
    ///
    /// Reports the true count even if the lock was poisoned; the list is
    /// structurally valid in that case, as explained in `prefill`.
    #[must_use]
    pub fn idle(&self) -> usize {
        self.free
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    }
}
impl BufferPool for SlabPool {
    fn slab_len(&self) -> usize {
        self.slab_len
    }

    /// Hands out a slab, popping the most recently returned one from the free
    /// list when available and allocating a fresh zero-filled slab otherwise.
    ///
    /// Recycled slabs are NOT cleared: they still hold whatever bytes their
    /// previous user wrote.
    fn take(&self) -> Box<[u8]> {
        self.outstanding
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        // The free list is only ever pushed/popped, so it is consistent even if a
        // previous lock holder panicked. Recover the guard rather than disabling
        // recycling forever. The temporary guard is released at the end of this
        // statement, before any allocation below.
        let recycled = self
            .free
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .pop();
        recycled.unwrap_or_else(|| vec![0u8; self.slab_len].into_boxed_slice())
    }

    /// Returns a slab to the pool.
    ///
    /// Slabs whose length differs from `slab_len` are dropped. So are slabs that
    /// arrive once `max_idle` slabs are already waiting.
    fn give(&self, buf: Box<[u8]>) {
        // Saturating decrement: `give` is callable with buffers that never came
        // from `take` (e.g. from another pool). Those must not wrap the counter
        // around to `usize::MAX`. The `Err` case is exactly "already zero".
        let _ = self.outstanding.fetch_update(
            std::sync::atomic::Ordering::Relaxed,
            std::sync::atomic::Ordering::Relaxed,
            |n| n.checked_sub(1),
        );
        if buf.len() != self.slab_len {
            return;
        }
        let mut free = self
            .free
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if free.len() < self.max_idle {
            free.push(buf);
        }
    }
}
impl fmt::Debug for SlabPool {
    /// Prints the configuration plus a snapshot of the idle and outstanding
    /// counts. The slab contents are never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let idle_now = self.idle();
        let handed_out = self.outstanding();
        let mut dbg = f.debug_struct("SlabPool");
        dbg.field("slab_len", &self.slab_len);
        dbg.field("max_idle", &self.max_idle);
        dbg.field("idle", &idle_now);
        dbg.field("outstanding", &handed_out);
        dbg.finish()
    }
}
/// A slab borrowed from a [`SlabPool`] that is handed back when dropped.
///
/// Dereferences to the first `len` bytes of the slab.
pub struct PooledBuf {
    /// The backing slab. `Some` for the buffer's whole life; taken in `Drop`.
    storage: Option<Box<[u8]>>,
    /// Number of valid bytes at the front of `storage`; never exceeds its length.
    len: usize,
    /// Pool the slab is returned to.
    pool: Arc<SlabPool>,
}
impl PooledBuf {
    /// Mutable view of the whole backing slab, ignoring `len`.
    ///
    /// Write into this, then call [`set_len`](Self::set_len) to mark how many
    /// bytes are valid. Slabs recycled by a [`SlabPool`] are not cleared, so any
    /// bytes the caller does not overwrite may still hold data from an earlier use.
    #[must_use]
    pub fn spare_mut(&mut self) -> &mut [u8] {
        match self.storage.as_deref_mut() {
            Some(slab) => slab,
            None => &mut [],
        }
    }

    /// Sets the number of valid bytes, silently clamped to
    /// [`capacity`](Self::capacity).
    pub fn set_len(&mut self, len: usize) {
        let cap = self.capacity();
        self.len = if len > cap { cap } else { len };
    }

    /// Number of valid bytes, i.e. the length of the slice this buffer derefs to.
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when no bytes are marked valid.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Length of the backing slab, or 0 if there is none.
    #[must_use]
    pub fn capacity(&self) -> usize {
        match &self.storage {
            Some(slab) => slab.len(),
            None => 0,
        }
    }

    /// Takes a buffer from `pool` and copies `data` into it.
    ///
    /// If `data` is longer than the slab, only the first `capacity()` bytes are
    /// copied (silent truncation); `len()` reflects the number actually copied.
    #[must_use]
    pub fn copy_from(pool: &Arc<SlabPool>, data: &[u8]) -> Self {
        let mut buf = pool.get();
        let count = buf.capacity().min(data.len());
        // `count <= data.len()`, so this prefix always exists.
        let src = data.get(..count).unwrap_or_default();
        if let Some(dst) = buf.spare_mut().get_mut(..count) {
            dst.copy_from_slice(src);
        }
        buf.set_len(count);
        buf
    }
}
impl Deref for PooledBuf {
    type Target = [u8];

    /// Views the valid prefix `..len` of the slab; empty if there is no slab.
    fn deref(&self) -> &[u8] {
        let filled = self.len;
        match self.storage.as_deref() {
            Some(slab) => slab.get(..filled).unwrap_or(&[]),
            None => &[],
        }
    }
}
impl DerefMut for PooledBuf {
    /// Mutable view of the valid prefix `..len`; empty if there is no slab.
    fn deref_mut(&mut self) -> &mut [u8] {
        let filled = self.len;
        match self.storage.as_deref_mut() {
            Some(slab) => slab.get_mut(..filled).unwrap_or(&mut []),
            None => &mut [],
        }
    }
}
impl Drop for PooledBuf {
    /// Hands the slab back to its pool.
    ///
    /// `storage` is taken (left as `None`), so the slab is returned at most once.
    fn drop(&mut self) {
        let pool = &self.pool;
        if let Some(slab) = self.storage.take() {
            pool.give(slab);
        }
    }
}
impl fmt::Debug for PooledBuf {
    /// Reports only `len` and `capacity`; the buffer's bytes are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("PooledBuf");
        dbg.field("len", &self.len);
        dbg.field("capacity", &self.capacity());
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    // In tests an unwrap or out-of-range index failing is itself a test failure,
    // so the crate's panic-avoidance lints are relaxed here.
    #![allow(clippy::unwrap_used, clippy::indexing_slicing)]
    use super::*;

    /// Address of a buffer's backing slab, used to prove slabs are reused/pinned.
    fn slab_addr(buf: &PooledBuf) -> usize {
        buf.storage.as_ref().unwrap().as_ptr() as usize
    }

    #[test]
    fn slab_pool_recycles() {
        let pool = SlabPool::new(128, 4);
        assert_eq!(pool.idle(), 0);

        let first_round: Vec<PooledBuf> =
            std::iter::repeat_with(|| pool.get()).take(4).collect();
        let addrs: Vec<usize> = first_round.iter().map(slab_addr).collect();
        assert_eq!(pool.outstanding(), 4);

        drop(first_round);
        assert_eq!(pool.idle(), 4);
        assert_eq!(pool.outstanding(), 0);

        // The free list is a Vec used with push/pop, so slabs come back in the
        // reverse of the order they were returned.
        let second_round: Vec<PooledBuf> = addrs
            .iter()
            .rev()
            .map(|&expected| {
                let buf = pool.get();
                assert_eq!(slab_addr(&buf), expected);
                buf
            })
            .collect();
        assert_eq!(pool.idle(), 0);
        drop(second_round);
    }

    #[test]
    fn pool_is_bounded() {
        let pool = SlabPool::new(64, 2);
        let mut live = Vec::with_capacity(10);
        for _ in 0..10 {
            live.push(pool.get());
        }
        assert_eq!(pool.outstanding(), 10);
        live.clear();
        assert_eq!(pool.idle(), 2, "freelist must not exceed max_idle");
        assert_eq!(pool.outstanding(), 0);
    }

    #[test]
    fn pooled_buf_len_and_deref() {
        let pool = SlabPool::new(32, 4);
        let mut buf = pool.get();
        assert_eq!(buf.len(), 0);
        assert_eq!(buf.capacity(), 32);

        let greeting = b"hello";
        buf.spare_mut()[..greeting.len()].copy_from_slice(greeting);
        buf.set_len(greeting.len());
        assert_eq!(&*buf, b"hello");

        buf.set_len(1000);
        assert_eq!(buf.len(), 32, "clamped to capacity");
    }

    #[test]
    fn copy_from_helper() {
        let pool = SlabPool::new(32, 4);
        let copied = PooledBuf::copy_from(&pool, b"wireguard");
        assert_eq!(&*copied, b"wireguard");
    }

    #[test]
    fn stable_address_across_moves() {
        let pool = SlabPool::new(64, 1);
        let original = pool.get();
        let addr = slab_addr(&original);

        let moved = original;
        assert_eq!(slab_addr(&moved), addr);

        let boxed = Box::new(moved);
        assert_eq!(slab_addr(&boxed), addr);
    }
}