use bytes::BytesMut;
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use crate::packet::Packet;
/// A thread-safe pool of `N`-byte packet buffers.
///
/// Every packet handed out by [`PacketBufPool::get`] carries a `ReturnToPool` guard
/// whose `Drop` pushes the buffer's allocation back into this pool's queue, so later
/// calls to `get` can reuse it instead of allocating. Cloning the pool is cheap: all
/// clones share the same queue through an [`Arc`].
#[derive(Clone, Debug)]
pub struct PacketBufPool<const N: usize = 4096> {
    /// Empty `BytesMut` handles, each pointing at the start of a reusable allocation.
    queue: Arc<Mutex<VecDeque<BytesMut>>>,
    /// Number of buffers the pool was created with.
    capacity: usize,
}
impl<const N: usize> PacketBufPool<N> {
pub fn new(capacity: usize) -> Self {
let mut queue = VecDeque::with_capacity(capacity);
for _ in 0..capacity {
queue.push_back(BytesMut::zeroed(N).split_to(0));
}
PacketBufPool {
queue: Arc::new(Mutex::new(queue)),
capacity,
}
}
pub fn capacity(&self) -> usize {
self.capacity
}
fn re_use(&self) -> Option<Packet<[u8]>> {
while let Some(mut pointer_to_start_of_allocation) =
{ self.queue.lock().unwrap().pop_front() }
{
debug_assert_eq!(pointer_to_start_of_allocation.len(), 0);
if pointer_to_start_of_allocation.try_reclaim(N) {
let mut buf = pointer_to_start_of_allocation.split_off(0);
debug_assert!(buf.capacity() >= N);
unsafe { buf.set_len(N) };
let return_to_pool = ReturnToPool {
pointer_to_start_of_allocation: Some(pointer_to_start_of_allocation),
queue: self.queue.clone(),
};
return Some(Packet::new_from_pool(return_to_pool, buf));
} else {
continue;
}
}
None
}
pub fn get(&self) -> Packet<[u8]> {
if let Some(packet) = self.re_use() {
return packet;
}
let mut buf = BytesMut::zeroed(N);
let pointer_to_start_of_allocation = buf.split_to(0);
debug_assert_eq!(pointer_to_start_of_allocation.len(), 0);
debug_assert_eq!(buf.len(), N);
let return_to_pool = ReturnToPool {
pointer_to_start_of_allocation: Some(pointer_to_start_of_allocation),
queue: self.queue.clone(),
};
Packet::new_from_pool(return_to_pool, buf)
}
}
/// Guard that hands a buffer's allocation back to its [`PacketBufPool`] when dropped.
///
/// It holds an empty `BytesMut` pointing at the start of the allocation, plus a shared
/// handle to the pool's queue.
#[derive(Debug)]
pub struct ReturnToPool {
    /// Start-of-allocation handle; `Some` until it is taken in `Drop`.
    pointer_to_start_of_allocation: Option<BytesMut>,
    /// The owning pool's queue of reusable allocations.
    queue: Arc<Mutex<VecDeque<BytesMut>>>,
}
impl Drop for ReturnToPool {
    /// Pushes the start-of-allocation handle back into the pool's queue so the
    /// allocation can be reclaimed by a later `get`.
    ///
    /// The handle is kept only while the queue has spare capacity; otherwise it is
    /// simply dropped.
    fn drop(&mut self) {
        if let Some(start) = self.pointer_to_start_of_allocation.take() {
            // Must not panic here: panicking while already unwinding aborts the process.
            // A poisoned lock is recovered because the queue is only ever modified by
            // single `push_back`/`pop_front` calls.
            let mut queue = self
                .queue
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            // Bounding by the deque's capacity means this push never reallocates. Note
            // that `VecDeque::with_capacity` guarantees *at least* the requested
            // capacity, so this bound is approximate.
            if queue.len() < queue.capacity() {
                queue.push_back(start);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{hint::black_box, thread};

    use super::PacketBufPool;

    /// `new` pre-allocates `capacity` buffers, which `re_use` hands out until exhausted.
    #[test]
    fn pool_prealloc() {
        const N: usize = 1024;
        let buffer_count = 10;
        let pool = PacketBufPool::<N>::new(buffer_count);
        // Keep every packet alive so none is returned to the pool during the test.
        let mut packets = vec![];
        for _ in 0..buffer_count {
            let packet = pool.re_use().expect("10 buffers was pre-allocated");
            assert_eq!(packet.buf().len(), N);
            packets.push(packet);
        }
        assert!(
            pool.re_use().is_none(),
            "pool is empty and a new packet must be allocated"
        );
    }

    /// A packet dropped on another thread returns its allocation: the next `get` reuses
    /// it (old contents intact), while a further `get` must allocate a fresh buffer.
    #[test]
    fn pool_buffer_recycle() {
        let pool = PacketBufPool::<4096>::new(1);
        for i in 0..10 {
            let mut packet1 = black_box(pool.get());
            let packet1_addr = packet1.buf().as_ptr();
            let data = format!("Hello there. x{i}\nGeneral Kenobi! You are a bold one.");
            let data = data.as_bytes();
            packet1.truncate(data.len());
            packet1.copy_from_slice(data);
            // Dropping on another thread exercises the cross-thread return path.
            thread::spawn(move || drop(packet1)).join().unwrap();
            let packet2 = black_box(pool.get());
            let packet2_addr = packet2.buf().as_ptr();
            let packet3 = black_box(pool.get());
            let packet3_addr = packet3.buf().as_ptr();
            assert!(
                packet2.starts_with(data),
                "old data should remain in the recycled buffer",
            );
            assert!(
                !packet3.starts_with(data),
                "old data should not exist in the new buffer",
            );
            assert_eq!(packet1_addr, packet2_addr);
            assert_ne!(packet1_addr, packet3_addr);
            drop((packet2, packet3));
        }
    }
}