use std::cell::UnsafeCell;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicU32, Ordering};
use crate::error::{Error, Result};
use super::ring::MpmcRing;
use super::{DEFAULT_POOL_CAPACITY, MAX_PACKET_SIZE};
/// One fixed-size packet slot owned by a [`PacketPool`].
///
/// `#[repr(C, align(64))]` pins the field order and aligns every buffer to a
/// 64-byte boundary (presumably the target's cache-line size — confirm).
#[repr(C, align(64))]
pub struct PacketBuffer {
    /// Backing storage; only the first `len` bytes are the payload.
    data: [u8; MAX_PACKET_SIZE],
    /// Payload length in bytes; the setters never let it exceed `MAX_PACKET_SIZE`.
    len: u32,
    /// This buffer's slot index inside the owning pool, fixed at construction.
    index: u32,
    /// Reference count: set to 1 by the pool on allocation, cleared to 0 by `reset`.
    refcount: AtomicU32,
}
impl PacketBuffer {
    /// Creates an empty, zero-filled buffer for pool slot `index`.
    #[inline]
    // Built by value, so the whole `data` array briefly lives on the stack.
    #[allow(clippy::large_stack_arrays)]
    const fn new(index: u32) -> Self {
        Self {
            data: [0; MAX_PACKET_SIZE],
            len: 0,
            index,
            refcount: AtomicU32::new(0),
        }
    }

    /// Slot index of this buffer within its pool.
    #[inline]
    #[must_use]
    pub const fn index(&self) -> u32 {
        self.index
    }

    /// Current payload length in bytes.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len as usize
    }

    /// Returns `true` when the payload length is zero.
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Sets the payload length, silently clamping it to `MAX_PACKET_SIZE`.
    ///
    /// The buffer contents are not touched, so bytes beyond the previous length
    /// keep whatever value they last held.
    #[inline]
    pub fn set_len(&mut self, len: usize) {
        self.len = len.min(MAX_PACKET_SIZE) as u32;
    }

    /// The payload: the first `len()` bytes of the buffer.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        &self.data[..self.len as usize]
    }

    /// Mutable view of the payload (the first `len()` bytes).
    #[inline]
    #[must_use]
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data[..self.len as usize]
    }

    /// The entire backing array, regardless of the current length.
    #[inline]
    #[must_use]
    pub fn as_full_slice(&self) -> &[u8] {
        &self.data
    }

    /// Mutable view of the entire backing array; pair with `set_len` after writing.
    #[inline]
    #[must_use]
    pub fn as_full_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Raw pointer to the start of the backing array.
    #[inline]
    #[must_use]
    pub fn as_ptr(&self) -> *const u8 {
        self.data.as_ptr()
    }

    /// Mutable raw pointer to the start of the backing array.
    #[inline]
    #[must_use]
    pub fn as_mut_ptr(&mut self) -> *mut u8 {
        self.data.as_mut_ptr()
    }

    /// Increments the reference count.
    #[inline]
    pub fn add_ref(&self) {
        self.refcount.fetch_add(1, Ordering::AcqRel);
    }

    /// Decrements the reference count, returning `true` if this call dropped it to zero.
    ///
    /// Releasing a buffer whose count is already zero is a logic error (a double
    /// release). Debug builds catch it with an assertion. In release builds the
    /// counter wraps to `u32::MAX`, so `release` can never report the last reference again.
    #[inline]
    pub fn release(&self) -> bool {
        let previous = self.refcount.fetch_sub(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "PacketBuffer::release called with refcount 0");
        previous == 1
    }

    /// Snapshot of the current reference count.
    #[inline]
    #[must_use]
    pub fn refcount(&self) -> u32 {
        self.refcount.load(Ordering::Acquire)
    }

    /// Clears the payload length and reference count; the data bytes are left as-is.
    #[inline]
    pub fn reset(&mut self) {
        self.len = 0;
        self.refcount.store(0, Ordering::Release);
    }

    /// Copies `data` into the front of the buffer and sets the length to `data.len()`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PacketPool`] if `data` is longer than `MAX_PACKET_SIZE`.
    /// In that case the buffer is left unchanged.
    pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
        if data.len() > MAX_PACKET_SIZE {
            return Err(Error::PacketPool(format!(
                "data too large: {} > {}",
                data.len(),
                MAX_PACKET_SIZE
            )));
        }
        self.data[..data.len()].copy_from_slice(data);
        self.len = data.len() as u32;
        Ok(())
    }
}
// Only metadata is printed; the `data` array is deliberately left out.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for PacketBuffer {
    /// Formats the length, pool index and a relaxed snapshot of the refcount.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let refs = self.refcount.load(Ordering::Relaxed);
        let mut out = f.debug_struct("PacketBuffer");
        out.field("len", &self.len);
        out.field("index", &self.index);
        out.field("refcount", &refs);
        out.finish()
    }
}
/// Owning handle to one allocated buffer of a [`PacketPool`].
///
/// Dereferences to [`PacketBuffer`]. Dropping the handle returns the buffer to the
/// pool; [`PacketRef::into_index`] gives up the handle without freeing the buffer.
pub struct PacketRef<'pool> {
    /// Pool the buffer was allocated from.
    pool: &'pool PacketPool,
    /// Slot index of the owned buffer within `pool`.
    idx: u32,
}
impl PacketRef<'_> {
    /// Slot index of the owned buffer within its pool.
    #[inline]
    #[must_use]
    pub fn index(&self) -> u32 {
        let Self { idx, .. } = self;
        *idx
    }

    /// Consumes the handle and returns its slot index *without* freeing the buffer.
    ///
    /// The caller now owns the index and must eventually return it with
    /// `PacketPool::free_by_index`; otherwise the buffer never rejoins the free list.
    #[inline]
    #[must_use]
    pub fn into_index(self) -> u32 {
        // Suppressing `Drop` keeps the index off the free ring.
        ManuallyDrop::new(self).idx
    }
}
impl Deref for PacketRef<'_> {
    type Target = PacketBuffer;

    /// Borrows the pooled buffer owned by this handle.
    #[inline]
    fn deref(&self) -> &PacketBuffer {
        let cell = &self.pool.buffers[self.idx as usize];
        // SAFETY: a `PacketRef` is only created for an index just taken off the
        // free ring, so under the pool's ownership contract no other `&mut` to
        // this buffer exists while `self` is borrowed.
        unsafe { &*cell.get() }
    }
}
impl DerefMut for PacketRef<'_> {
    /// Mutably borrows the pooled buffer owned by this handle.
    #[inline]
    fn deref_mut(&mut self) -> &mut PacketBuffer {
        let slot = self.idx as usize;
        // SAFETY: this handle exclusively owns `slot`, and `&mut self` guarantees
        // no other borrow obtained through this handle is alive.
        unsafe { &mut *self.pool.buffers[slot].get() }
    }
}
impl Drop for PacketRef<'_> {
    /// Returns the buffer to its pool.
    ///
    /// The buffer is freed unconditionally; its reference count is not consulted.
    fn drop(&mut self) {
        let (pool, idx) = (self.pool, self.idx);
        // SAFETY: this handle is the sole owner of `idx`, and once `drop` runs no
        // borrow of the buffer through the handle can remain.
        unsafe { pool.free_by_index(idx) };
    }
}
impl std::fmt::Debug for PacketRef<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PacketRef")
.field("idx", &self.idx)
.field("buffer", &**self)
.finish()
}
}
/// Fixed-capacity pool of [`PacketBuffer`]s whose free slots are tracked as
/// indices in an `MpmcRing`.
pub struct PacketPool {
    /// All buffers, addressed by slot index. `UnsafeCell` because handles mutate
    /// buffers through a shared `&PacketPool`.
    buffers: Box<[UnsafeCell<PacketBuffer>]>,
    /// Indices of the buffers that are currently free.
    free_indices: MpmcRing<u32>,
    /// Number of buffers in `buffers`.
    capacity: usize,
}
impl PacketPool {
pub fn new(capacity: usize) -> Result<Self> {
let capacity = capacity.max(1);
let buffers: Vec<UnsafeCell<PacketBuffer>> = (0..capacity)
.map(|i| UnsafeCell::new(PacketBuffer::new(i as u32)))
.collect();
let free_indices = MpmcRing::new(capacity);
for i in 0..capacity {
let _ = free_indices.enqueue(i as u32);
}
Ok(Self {
buffers: buffers.into_boxed_slice(),
free_indices,
capacity,
})
}
pub fn with_default_capacity() -> Result<Self> {
Self::new(DEFAULT_POOL_CAPACITY)
}
#[inline]
#[must_use]
pub const fn capacity(&self) -> usize {
self.capacity
}
#[inline]
#[must_use]
pub fn free_count(&self) -> usize {
self.free_indices.len()
}
#[inline]
#[must_use]
pub fn allocated_count(&self) -> usize {
self.capacity - self.free_indices.len()
}
pub fn alloc(&self) -> Option<PacketRef<'_>> {
let idx = self.free_indices.dequeue()?;
unsafe {
(*self.buffers[idx as usize].get())
.refcount
.store(1, Ordering::Release);
}
Some(PacketRef { pool: self, idx })
}
pub fn alloc_index(&self) -> Option<u32> {
let idx = self.free_indices.dequeue()?;
unsafe {
(*self.buffers[idx as usize].get())
.refcount
.store(1, Ordering::Release);
}
Some(idx)
}
pub fn alloc_with_data(&self, data: &[u8]) -> Result<PacketRef<'_>> {
let mut pkt = self
.alloc()
.ok_or_else(|| Error::PacketPool("pool exhausted".to_string()))?;
pkt.copy_from_slice(data)?;
Ok(pkt)
}
pub(crate) unsafe fn free(&self, buffer: &mut PacketBuffer) {
let idx = buffer.index;
debug_assert!((idx as usize) < self.capacity);
buffer.reset();
while self.free_indices.enqueue(idx).is_err() {
std::hint::spin_loop();
}
}
pub unsafe fn free_by_index(&self, idx: u32) {
debug_assert!((idx as usize) < self.capacity);
let buffer = unsafe { &mut *self.buffers[idx as usize].get() };
unsafe { self.free(buffer) };
}
#[must_use]
pub unsafe fn get(&self, idx: u32) -> &PacketBuffer {
debug_assert!((idx as usize) < self.capacity);
unsafe { &*self.buffers[idx as usize].get() }
}
#[must_use]
#[allow(clippy::mut_from_ref)] pub unsafe fn get_mut(&self, idx: u32) -> &mut PacketBuffer {
debug_assert!((idx as usize) < self.capacity);
unsafe { &mut *self.buffers[idx as usize].get() }
}
pub fn alloc_batch_indices(&self, out: &mut [u32]) -> usize {
let mut count = 0;
for slot in out.iter_mut() {
if let Some(idx) = self.alloc_index() {
*slot = idx;
count += 1;
} else {
break;
}
}
count
}
}
// The buffers and ring are summarised by counts rather than dumped.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for PacketPool {
    /// Formats the capacity together with the current free and allocated counts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (free, allocated) = (self.free_count(), self.allocated_count());
        f.debug_struct("PacketPool")
            .field("capacity", &self.capacity)
            .field("free_count", &free)
            .field("allocated_count", &allocated)
            .finish()
    }
}
// SAFETY: the `UnsafeCell<PacketBuffer>` slots are what make the auto traits
// unavailable. Each buffer is only mutated by whoever holds its index, and an
// index is handed out by the free ring to one owner at a time. The other shared
// state touched through `&PacketPool` is the atomic `refcount` and the ring itself.
// NOTE(review): this assumes `MpmcRing<u32>` is itself safe to share and send
// between threads — confirm against its definition.
unsafe impl Send for PacketPool {}
unsafe impl Sync for PacketPool {}
/// Compile-time proof that a [`PacketRef`] can be moved to another thread.
///
/// This holds because the handle stores `&PacketPool`, which is `Send` exactly
/// when `PacketPool: Sync`.
#[allow(dead_code)]
const _ASSERT_PACKET_REF_SEND: () = {
    const fn require_send<T: Send>() {}
    require_send::<PacketRef<'_>>();
};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_size() {
        let size = std::mem::size_of::<PacketBuffer>();
        assert_eq!(std::mem::align_of::<PacketBuffer>(), 64);
        // `align(64)` rounds the size up to a multiple of 64 that still holds the
        // payload plus the three 4-byte header fields.
        assert_eq!(size % 64, 0);
        assert!(size >= MAX_PACKET_SIZE + 12);
    }

    #[test]
    fn test_pool_creation() {
        let pool = PacketPool::new(100).unwrap();
        assert_eq!(pool.capacity(), 100);
        assert_eq!(pool.free_count(), 100);
        assert_eq!(pool.allocated_count(), 0);
    }

    #[test]
    fn test_zero_capacity_rounds_up_to_one() {
        let pool = PacketPool::new(0).unwrap();
        assert_eq!(pool.capacity(), 1);
        assert_eq!(pool.free_count(), 1);
        assert!(pool.alloc().is_some());
    }

    #[test]
    fn test_pool_alloc_drop() {
        let pool = PacketPool::new(10).unwrap();
        let buf = pool.alloc().unwrap();
        assert_eq!(buf.refcount(), 1);
        assert_eq!(pool.free_count(), 9);
        assert_eq!(pool.allocated_count(), 1);
        drop(buf);
        assert_eq!(pool.free_count(), 10);
        assert_eq!(pool.allocated_count(), 0);
        let buf2 = pool.alloc().unwrap();
        assert!(buf2.index() < 10);
    }

    #[test]
    fn test_packet_ref_into_index() {
        let pool = PacketPool::new(4).unwrap();
        let buf = pool.alloc().unwrap();
        let idx = buf.index();
        let extracted = buf.into_index();
        assert_eq!(extracted, idx);
        assert_eq!(pool.free_count(), 3);
        // SAFETY: `extracted` came from `into_index` and is freed exactly once.
        unsafe { pool.free_by_index(extracted) };
        assert_eq!(pool.free_count(), 4);
    }

    #[test]
    fn test_pool_exhaustion() {
        let pool = PacketPool::new(2).unwrap();
        let _buf1 = pool.alloc().unwrap();
        let _buf2 = pool.alloc().unwrap();
        assert!(pool.alloc().is_none());
        assert_eq!(pool.free_count(), 0);
    }

    #[test]
    fn test_buffer_copy() {
        let pool = PacketPool::new(1).unwrap();
        let mut buf = pool.alloc().unwrap();
        let data = [1u8, 2, 3, 4, 5];
        buf.copy_from_slice(&data).unwrap();
        assert_eq!(buf.len(), 5);
        assert_eq!(buf.as_slice(), &data);
    }

    #[test]
    fn test_copy_rejects_oversized() {
        let pool = PacketPool::new(1).unwrap();
        let mut buf = pool.alloc().unwrap();
        buf.copy_from_slice(&[1, 2]).unwrap();
        let big = vec![0u8; MAX_PACKET_SIZE + 1];
        assert!(buf.copy_from_slice(&big).is_err());
        // A rejected copy leaves the previous payload intact.
        assert_eq!(buf.as_slice(), &[1, 2]);
        drop(buf);
        // The failed `alloc_with_data` must hand its buffer back to the pool.
        assert!(pool.alloc_with_data(&big).is_err());
        assert_eq!(pool.free_count(), 1);
    }

    #[test]
    fn test_set_len_clamps() {
        let pool = PacketPool::new(1).unwrap();
        let mut buf = pool.alloc().unwrap();
        buf.set_len(MAX_PACKET_SIZE + 10);
        assert_eq!(buf.len(), MAX_PACKET_SIZE);
        assert_eq!(buf.as_slice().len(), MAX_PACKET_SIZE);
    }

    #[test]
    fn test_alloc_with_data() {
        let pool = PacketPool::new(1).unwrap();
        let data = [0xAB; 100];
        let buf = pool.alloc_with_data(&data).unwrap();
        assert_eq!(buf.len(), 100);
        assert_eq!(buf.as_slice(), &data);
    }

    #[test]
    fn test_refcount_and_reuse() {
        let pool = PacketPool::new(1).unwrap();
        let mut buf = pool.alloc().unwrap();
        buf.copy_from_slice(&[9, 9, 9]).unwrap();
        buf.add_ref();
        assert_eq!(buf.refcount(), 2);
        assert!(!buf.release());
        assert!(buf.release());
        drop(buf);
        // A recycled buffer comes back empty with a fresh refcount.
        let again = pool.alloc().unwrap();
        assert!(again.is_empty());
        assert_eq!(again.refcount(), 1);
    }

    #[test]
    fn test_batch_alloc_indices() {
        let pool = PacketPool::new(5).unwrap();
        let mut indices = [0u32; 10];
        let count = pool.alloc_batch_indices(&mut indices);
        assert_eq!(count, 5);
        assert_eq!(pool.free_count(), 0);
        for idx in &indices[..5] {
            assert!(*idx < 5);
        }
    }

    #[test]
    fn test_batch_alloc_then_free_by_index() {
        let pool = PacketPool::new(3).unwrap();
        assert_eq!(pool.alloc_batch_indices(&mut []), 0);
        let mut indices = [u32::MAX; 2];
        assert_eq!(pool.alloc_batch_indices(&mut indices), 2);
        assert_eq!(pool.allocated_count(), 2);
        for &idx in &indices {
            // SAFETY: each index was just allocated and is freed exactly once.
            unsafe { pool.free_by_index(idx) };
        }
        assert_eq!(pool.free_count(), 3);
        assert_eq!(pool.allocated_count(), 0);
    }

    #[test]
    fn test_pool_debug() {
        let pool = PacketPool::new(2).unwrap();
        let _held = pool.alloc().unwrap();
        assert_eq!(
            format!("{:?}", pool),
            "PacketPool { capacity: 2, free_count: 1, allocated_count: 1 }"
        );
    }

    #[test]
    fn test_concurrent_alloc_drop() {
        use std::sync::Arc;
        let pool = Arc::new(PacketPool::new(64).unwrap());
        let iterations = 1000;
        let threads = 4;
        let handles: Vec<_> = (0..threads)
            .map(|_| {
                let pool = Arc::clone(&pool);
                std::thread::spawn(move || {
                    for _ in 0..iterations {
                        if let Some(mut pkt) = pool.alloc() {
                            pkt.set_len(4);
                            pkt.as_full_mut_slice()[..4].copy_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
                            assert_eq!(pkt.as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF]);
                        }
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(pool.free_count(), 64);
        assert_eq!(pool.allocated_count(), 0);
    }

    #[test]
    fn test_concurrent_alloc_into_index_free() {
        use std::sync::Arc;
        let pool = Arc::new(PacketPool::new(32).unwrap());
        let iterations = 500;
        let threads = 4;
        let handles: Vec<_> = (0..threads)
            .map(|_| {
                let pool = Arc::clone(&pool);
                std::thread::spawn(move || {
                    for _ in 0..iterations {
                        if let Some(pkt) = pool.alloc() {
                            let idx = pkt.into_index();
                            assert!(idx < 32);
                            // SAFETY: `idx` is owned by this thread and freed once.
                            unsafe { pool.free_by_index(idx) };
                        }
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(pool.free_count(), 32);
    }
}