use crate::utils::BufferSize;
use std::{alloc, ptr, sync, sync::atomic};
/// Raw pointer to the first byte of one buffer inside a `BufPoolAllocation`
/// (returned by `BufPoolAllocation::first` and yielded by `BufPoolAllocationIter`).
pub type BufferPointer = *mut u8;
/// Configuration passed to `BufPool::new`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BufPoolCfg {
/// Identifier of the owning module. NOTE(review): not read anywhere in this file — confirm
/// which consumer relies on it.
pub module_id: u8,
/// Size of every individual buffer; an allocation of `n` buffers occupies
/// `n * buffer_size.bytes()` contiguous bytes.
pub buffer_size: BufferSize,
/// Upper bound, in bytes, on the memory reserved by all live allocations of the pool.
/// `BufPool::allocate` blocks while a request would push the total above this value.
pub max_memory: usize,
}
/// Memory-bounded allocator that hands out groups of fixed-size buffers and blocks callers
/// (backpressure) while the `max_memory` budget is exhausted.
///
/// NOTE(review): every `BufPoolAllocation` keeps a raw, lifetime-less pointer back to its
/// pool, so a `BufPool` must not be moved while allocations from it are alive — nothing in
/// the type system enforces this.
#[derive(Debug)]
pub struct BufPool {
/// Number of `BufPoolAllocation`s not yet dropped; `Drop for BufPool` waits for zero.
active_allocations: atomic::AtomicUsize,
/// Signalled when an allocation is dropped so threads in `backpressure` re-check the budget.
allocation_cv: sync::Condvar,
/// Mutex paired with `allocation_cv`; it guards no data, only the wait/notify protocol.
allocation_lock: sync::Mutex<()>,
/// Bytes currently reserved by live allocations; `allocate` keeps it <= `cfg.max_memory`.
allocated_memory: atomic::AtomicUsize,
/// Configuration the pool was created with.
cfg: BufPoolCfg,
/// Signalled when `active_allocations` drops to zero, releasing a pending `Drop`.
shutdown_cv: sync::Condvar,
/// Mutex paired with `shutdown_cv`.
shutdown_lock: sync::Mutex<()>,
}
// SAFETY: shared access to a `BufPool` only goes through atomics, `Condvar`s and
// `Mutex<()>`, all of which are thread-safe, plus the plain-value `BufPoolCfg`.
// NOTE(review): those std field types are already `Send + Sync`, so these manual impls look
// redundant unless `BufferSize` is not; if it is a plain value type, removing them lets the
// compiler keep checking any fields added in the future.
unsafe impl Send for BufPool {}
unsafe impl Sync for BufPool {}
impl BufPool {
    /// Creates an empty pool that hands out buffers of `cfg.buffer_size` and keeps the
    /// total reserved memory at or below `cfg.max_memory` bytes.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `cfg.max_memory` is not strictly larger than one buffer.
    #[inline]
    pub fn new(cfg: BufPoolCfg) -> Self {
        debug_assert!(
            cfg.buffer_size.bytes() < cfg.max_memory,
            "MAX_MEMORY should always be larger then the BUFFER_SIZE"
        );
        Self {
            cfg,
            active_allocations: atomic::AtomicUsize::new(0),
            allocated_memory: atomic::AtomicUsize::new(0),
            allocation_cv: sync::Condvar::new(),
            allocation_lock: sync::Mutex::new(()),
            shutdown_cv: sync::Condvar::new(),
            shutdown_lock: sync::Mutex::new(()),
        }
    }

    /// Allocates `required` contiguous buffers as a single heap block, blocking the caller
    /// until the pool's memory budget has room for `required * buffer_size` bytes.
    ///
    /// The returned [`BufPoolAllocation`] gives the bytes back to the budget when dropped.
    ///
    /// # Panics
    ///
    /// Panics if `required` is 0 (a zero-sized layout must never reach the global
    /// allocator), if `required * buffer_size` overflows `usize`, or if the request exceeds
    /// `max_memory` (such a request could never be satisfied and would block forever).
    #[inline]
    pub fn allocate(&self, required: usize) -> BufPoolAllocation {
        assert!(required > 0, "required buffers must never be 0");
        let required_bytes = required
            .checked_mul(self.cfg.buffer_size.bytes())
            .expect("Total required bytes overflow usize");
        assert!(
            required_bytes <= self.cfg.max_memory,
            "Total required bytes must be smaller then the MAX_MEMORY allowed to avoid deadlock"
        );
        // Build the layout before touching the budget so a layout panic cannot leak a
        // reservation that nothing would ever release.
        let layout = create_layout(required_bytes);
        self.reserve(required_bytes);
        let pointer = allocate_layout(layout);
        self.active_allocations.fetch_add(1, atomic::Ordering::Relaxed);
        BufPoolAllocation {
            layout,
            pointer,
            required_bytes,
            buffers: required,
            pool: ptr::NonNull::from(self),
        }
    }

    /// Atomically adds `required_bytes` to `allocated_memory`, parking in
    /// [`Self::backpressure`] whenever the addition would exceed `max_memory`.
    fn reserve(&self, required_bytes: usize) {
        let mut current = self.allocated_memory.load(atomic::Ordering::Acquire);
        loop {
            if !self.fits(current, required_bytes) {
                self.backpressure(required_bytes);
                current = self.allocated_memory.load(atomic::Ordering::Acquire);
                continue;
            }
            // `fits` guarantees the addition below does not overflow.
            match self.allocated_memory.compare_exchange_weak(
                current,
                current + required_bytes,
                atomic::Ordering::AcqRel,
                atomic::Ordering::Acquire,
            ) {
                Ok(_) => return,
                // Retry with the value another thread just stored instead of reloading.
                Err(observed) => current = observed,
            }
        }
    }

    /// Returns `true` when `required_bytes` more bytes fit in the budget on top of `current`.
    #[inline]
    fn fits(&self, current: usize, required_bytes: usize) -> bool {
        current
            .checked_add(required_bytes)
            .is_some_and(|total| total <= self.cfg.max_memory)
    }

    /// Blocks until `required_bytes` fit into the budget.
    ///
    /// The budget is checked while holding `allocation_lock`. Wakeups are only reliable if
    /// whoever releases memory acquires that same lock between updating `allocated_memory`
    /// and notifying `allocation_cv`.
    fn backpressure(&self, required_bytes: usize) {
        let mut guard = self
            .allocation_lock
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        // Explicit loop instead of `Condvar::wait_while`: `wait_while` returns early on a
        // poisoned mutex, whereas this keeps waiting until the condition actually holds.
        while !self.fits(
            self.allocated_memory.load(atomic::Ordering::Acquire),
            required_bytes,
        ) {
            guard = self
                .allocation_cv
                .wait(guard)
                .unwrap_or_else(|e| e.into_inner());
        }
    }
}
impl Drop for BufPool {
    /// Blocks until every `BufPoolAllocation` handed out by this pool has been dropped, so
    /// no allocation is left holding a pointer into a freed pool.
    ///
    /// Dropping the pool on the same thread that still owns one of its allocations therefore
    /// never returns.
    fn drop(&mut self) {
        let mut guard = match self.shutdown_lock.lock() {
            Ok(g) => g,
            Err(poisoned) => poisoned.into_inner(),
        };
        loop {
            let outstanding = self.active_allocations.load(atomic::Ordering::Acquire);
            if outstanding == 0 {
                break;
            }
            guard = match self.shutdown_cv.wait(guard) {
                Ok(g) => g,
                Err(poisoned) => poisoned.into_inner(),
            };
        }
    }
}
/// A contiguous heap block of `buffers` buffers whose bytes are reserved from a `BufPool`'s
/// memory budget. Dropping it frees the block and returns its bytes to the pool.
#[derive(Debug)]
pub struct BufPoolAllocation {
/// Number of buffers in the block.
buffers: usize,
/// Layout the block was allocated with; required again to free it.
layout: alloc::Layout,
/// Start of the block, i.e. the first buffer.
pointer: ptr::NonNull<u8>,
/// Back-pointer to the owning pool, dereferenced in `iter` and `drop`.
/// NOTE(review): carries no lifetime, so the pool must not be moved while this exists.
pool: ptr::NonNull<BufPool>,
/// Total size of the block in bytes (`buffers * buffer_size`).
required_bytes: usize,
}
// SAFETY: the allocation exclusively owns the heap block behind `pointer`, and `pool` is only
// ever used for shared (`&BufPool`) access, which is valid from any thread because `BufPool`
// is `Sync`.
unsafe impl Send for BufPoolAllocation {}
impl BufPoolAllocation {
    /// Returns an iterator over the start pointer of every buffer in this allocation, in
    /// address order.
    #[inline]
    pub fn iter(&self) -> BufPoolAllocationIter {
        // SAFETY: the pool's `Drop` waits until all of its allocations are gone, so the pool
        // is alive while `self` is (assuming the pool has not been moved since `allocate`).
        let stride = unsafe { self.pool.as_ref() }.cfg.buffer_size.bytes();
        BufPoolAllocationIter {
            remaining: self.buffers,
            buffer_size: stride,
            pointer: self.pointer,
        }
    }

    /// Total number of bytes backing this allocation (`length() * buffer_size`).
    #[inline]
    pub const fn allocated_bytes(&self) -> usize {
        self.required_bytes
    }

    /// How many buffers this allocation holds.
    #[inline]
    pub const fn length(&self) -> usize {
        self.buffers
    }

    /// Raw pointer to the first buffer; the remaining buffers follow it contiguously.
    #[inline]
    pub const fn first(&self) -> BufferPointer {
        self.pointer.as_ptr()
    }
}
impl Drop for BufPoolAllocation {
    /// Frees the heap block, returns its bytes to the pool budget and wakes blocked threads.
    fn drop(&mut self) {
        // SAFETY: `pool` was derived from `&BufPool` in `BufPool::allocate`, and the pool's
        // `Drop` waits for `active_allocations` to reach zero, so the pool is alive here.
        // NOTE(review): this only holds if the pool was not *moved* after allocating — the
        // raw pointer carries no lifetime, so the compiler cannot enforce it.
        let pool = unsafe { self.pool.as_ref() };
        deallocate_memory(self.pointer, self.layout);
        pool.allocated_memory
            .fetch_sub(self.required_bytes, atomic::Ordering::Release);

        // Notify while holding `allocation_lock`. `backpressure` checks the budget under this
        // lock, so a waiter is either already parked in `wait` (and receives this
        // notification) or has not checked yet (and will see the decremented value). Without
        // the lock, a notification sent between a waiter's check and its `wait` is lost and
        // the waiter can sleep forever. Waiters may need different amounts of memory, so wake
        // all of them rather than one that might still not fit.
        {
            let _guard = pool
                .allocation_lock
                .lock()
                .unwrap_or_else(|e| e.into_inner());
            pool.allocation_cv.notify_all();
        }

        // Decrement under `shutdown_lock` for the same lost-wakeup reason, and so that
        // `BufPool::drop` cannot observe zero, return and free the pool while this function is
        // still using it.
        let shutdown_guard = pool
            .shutdown_lock
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        if pool.active_allocations.fetch_sub(1, atomic::Ordering::Release) == 1 {
            pool.shutdown_cv.notify_all();
        }
        // The pool may be freed as soon as this guard is released; `pool` must not be used
        // after this line.
        drop(shutdown_guard);
    }
}
/// Iterator over the start pointers of the buffers in a `BufPoolAllocation`, created by
/// `BufPoolAllocation::iter`.
#[derive(Debug)]
pub struct BufPoolAllocationIter {
/// Start of the next buffer to yield.
pointer: ptr::NonNull<u8>,
/// Distance in bytes between consecutive buffers.
buffer_size: usize,
/// Number of buffers still to be yielded.
remaining: usize,
}
impl Iterator for BufPoolAllocationIter {
    type Item = BufferPointer;

    /// Yields the current buffer pointer and advances by one buffer.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }
        let current = self.pointer;
        // SAFETY: for an iterator built by `BufPoolAllocation::iter`, `remaining > 0` means
        // `current` starts a buffer inside the block, so stepping by `buffer_size` lands at
        // most one-past-the-end of the block, which `add` permits.
        self.pointer = unsafe { self.pointer.add(self.buffer_size) };
        self.remaining -= 1;
        Some(current.as_ptr())
    }

    /// The number of remaining buffers is known exactly, letting `collect` preallocate.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

// `size_hint` is exact, so `len()` is available to callers.
impl ExactSizeIterator for BufPoolAllocationIter {}

// Once `remaining` hits zero it never increases, so `next` keeps returning `None`.
impl std::iter::FusedIterator for BufPoolAllocationIter {}
/// Builds the byte layout (alignment 1) for a block of `required_bytes` bytes.
///
/// # Panics
///
/// Panics with `Invalid Layout: …` when `required_bytes` is too large to form a valid
/// `Layout`.
#[inline]
fn create_layout(required_bytes: usize) -> alloc::Layout {
    alloc::Layout::array::<u8>(required_bytes)
        .unwrap_or_else(|err| panic!("Invalid Layout: {err}"))
}
/// Allocates uninitialised memory for `layout` from the global allocator.
///
/// # Panics
///
/// Panics if `layout` has size 0, because `alloc::alloc` has undefined behaviour for
/// zero-sized layouts. On allocation failure, defers to [`alloc::handle_alloc_error`].
#[inline]
fn allocate_layout(layout: alloc::Layout) -> ptr::NonNull<u8> {
    assert!(layout.size() != 0, "cannot allocate a zero-sized layout");
    // SAFETY: `layout` has a non-zero size, which is the only precondition of `alloc::alloc`.
    let pointer = unsafe { alloc::alloc(layout) };
    ptr::NonNull::new(pointer).unwrap_or_else(|| alloc::handle_alloc_error(layout))
}
/// Returns a block previously obtained from `allocate_layout` to the global allocator.
///
/// NOTE(review): this is a *safe* wrapper around `alloc::dealloc`; soundness relies on every
/// caller passing a live pointer that was allocated with exactly `layout`.
#[inline]
fn deallocate_memory(pointer: ptr::NonNull<u8>, layout: alloc::Layout) {
    let raw = pointer.as_ptr();
    // SAFETY: callers guarantee `raw` came from the global allocator with this `layout` and
    // has not been freed yet.
    unsafe {
        alloc::dealloc(raw, layout);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MOD_ID: u8 = 0;
    const BUF_SIZE: BufferSize = BufferSize::S32;

    /// Builds a pool of `BUF_SIZE` buffers whose memory budget is `max_mem` bytes.
    #[inline]
    fn create_bufpool(max_mem: usize) -> BufPool {
        BufPool::new(BufPoolCfg {
            buffer_size: BUF_SIZE,
            max_memory: max_mem,
            module_id: MOD_ID,
        })
    }

    #[test]
    #[should_panic(expected = "MAX_MEMORY should always be larger")]
    #[cfg(debug_assertions)]
    fn err_new_with_invalid_cfg() {
        create_bufpool(BUF_SIZE.bytes() >> 1);
    }

    // The pool here is valid (budget > one buffer), so the panic must come from `allocate`
    // itself rather than from `BufPool::new`'s configuration check.
    #[test]
    #[should_panic(expected = "required buffers must never be 0")]
    #[cfg(debug_assertions)]
    fn err_alloc_zero() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 2);
        let _ = bpool.allocate(0);
    }

    #[test]
    #[should_panic(expected = "MAX_MEMORY allowed to avoid deadlock")]
    #[cfg(debug_assertions)]
    fn err_alloc_more_than_max_memory() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 2);
        let _ = bpool.allocate(3);
    }

    #[test]
    fn ok_alloc_single() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 2);
        let alloc = bpool.allocate(1);
        assert_eq!(alloc.buffers, 1);
        assert_eq!(alloc.required_bytes, BUF_SIZE.bytes());
    }

    #[test]
    fn ok_alloc_multiple() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 0x14);
        let alloc = bpool.allocate(0x10);
        assert_eq!(alloc.buffers, 0x10);
        assert_eq!(alloc.required_bytes, BUF_SIZE.bytes() * 0x10);
    }

    #[test]
    fn ok_alloc_max_memory() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 0x0A);
        let alloc = bpool.allocate(0x0A);
        assert_eq!(alloc.buffers, 0x0A);
        assert_eq!(alloc.required_bytes, BUF_SIZE.bytes() * 0x0A);
    }

    #[test]
    fn ok_alloc_updates_memory_accounting() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 0x14);
        let alloc = bpool.allocate(0x10);
        assert_eq!(
            bpool.allocated_memory.load(atomic::Ordering::Acquire),
            BUF_SIZE.bytes() * 0x10
        );
        drop(alloc);
        assert_eq!(bpool.allocated_memory.load(atomic::Ordering::Acquire), 0);
    }

    #[test]
    fn ok_alloc_updates_active_allocation_tracking() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 0x2A);
        let alloc1 = bpool.allocate(0x10);
        let alloc2 = bpool.allocate(0x10);
        assert_eq!(bpool.active_allocations.load(atomic::Ordering::Acquire), 2);
        drop(alloc1);
        assert_eq!(bpool.active_allocations.load(atomic::Ordering::Acquire), 1);
        drop(alloc2);
        assert_eq!(bpool.active_allocations.load(atomic::Ordering::Acquire), 0);
    }

    #[test]
    fn ok_alloc_decrements_allocated_memory_after_deallocations() {
        let bpool = create_bufpool(BUF_SIZE.bytes() * 0x80);
        let allocations: Vec<_> = (0..0x20).map(|_| bpool.allocate(2)).collect();
        assert_eq!(
            bpool.allocated_memory.load(atomic::Ordering::Acquire),
            0x20 * 2 * BUF_SIZE.bytes()
        );
        drop(allocations);
        assert_eq!(bpool.allocated_memory.load(atomic::Ordering::Acquire), 0);
    }

    #[test]
    fn ok_backpressure_blocks_till_memory_is_deallocated() {
        let bpool = sync::Arc::new(create_bufpool(BUF_SIZE.bytes() * 2));
        let alloc = bpool.allocate(1);
        let pool2 = sync::Arc::clone(&bpool);
        let barrier = sync::Arc::new(sync::Barrier::new(2));
        let barrier2 = sync::Arc::clone(&barrier);
        let handle = std::thread::spawn(move || {
            // Start the clock *before* the barrier. The main thread only starts sleeping after
            // the barrier, so the measured time can never undercount the blocking period.
            let start = std::time::Instant::now();
            barrier2.wait();
            let _alloc = pool2.allocate(2);
            start.elapsed()
        });
        barrier.wait();
        std::thread::sleep(std::time::Duration::from_millis(100));
        drop(alloc);
        let elapsed = handle.join().expect("allocation thread should not panic");
        assert!(elapsed >= std::time::Duration::from_millis(100));
    }

    #[test]
    fn ok_concurrent_allocations() {
        let pool = sync::Arc::new(create_bufpool(BUF_SIZE.bytes() * 0x1000));
        let handles: Vec<_> = (0..0x0A)
            .map(|_| {
                let pool = sync::Arc::clone(&pool);
                std::thread::spawn(move || {
                    for _ in 0..0x64 {
                        drop(pool.allocate(1));
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(pool.allocated_memory.load(atomic::Ordering::Acquire), 0);
        assert_eq!(pool.active_allocations.load(atomic::Ordering::Acquire), 0);
    }

    // A tight budget with requests of different sizes drives threads through `backpressure`
    // repeatedly. A lost wakeup would park a thread forever; the timeout turns that hang
    // into a test failure.
    #[test]
    fn ok_contended_allocations_of_mixed_sizes_complete() {
        let pool = sync::Arc::new(create_bufpool(BUF_SIZE.bytes() * 4));
        let (done_tx, done_rx) = sync::mpsc::channel();
        let workers: Vec<_> = (1..=4usize)
            .flat_map(|size| [size, size])
            .map(|size| {
                let pool = sync::Arc::clone(&pool);
                let done_tx = done_tx.clone();
                std::thread::spawn(move || {
                    for _ in 0..0x40 {
                        drop(pool.allocate(size));
                    }
                    done_tx.send(()).unwrap();
                })
            })
            .collect();
        drop(done_tx);
        for _ in 0..workers.len() {
            done_rx
                .recv_timeout(std::time::Duration::from_secs(10))
                .expect("contended allocations should not deadlock");
        }
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(pool.allocated_memory.load(atomic::Ordering::Acquire), 0);
        assert_eq!(pool.active_allocations.load(atomic::Ordering::Acquire), 0);
    }

    mod drop {
        use super::*;

        #[test]
        fn ok_partial_drop_updates_accounting() {
            let bpool = create_bufpool(BUF_SIZE.bytes() * 0x0A);
            let alloc1 = bpool.allocate(2);
            let alloc2 = bpool.allocate(2);
            assert_eq!(
                bpool.allocated_memory.load(atomic::Ordering::Acquire),
                BUF_SIZE.bytes() * 4
            );
            drop(alloc1);
            assert_eq!(
                bpool.allocated_memory.load(atomic::Ordering::Acquire),
                BUF_SIZE.bytes() * 2
            );
            drop(alloc2);
            assert_eq!(bpool.allocated_memory.load(atomic::Ordering::Acquire), 0);
        }

        #[test]
        fn ok_drop_waits_for_active_allocations() {
            let bpool = sync::Arc::new(create_bufpool(BUF_SIZE.bytes() * 0x1A));
            let alloc = bpool.allocate(0x10);
            let handle = std::thread::spawn(move || {
                drop(bpool);
            });
            std::thread::sleep(std::time::Duration::from_millis(0x64));
            assert!(!handle.is_finished());
            drop(alloc);
            handle.join().unwrap();
        }
    }

    mod memory_tests {
        use super::*;

        #[test]
        fn ok_create_layout() {
            let layout = create_layout(0x1000);
            assert_eq!(layout.align(), 1);
            assert_eq!(layout.size(), 0x1000);
        }

        #[test]
        #[should_panic(expected = "Invalid Layout")]
        fn err_create_layout() {
            create_layout(usize::MAX);
        }

        #[test]
        fn ok_allocate_layout() {
            let layout = create_layout(0x10);
            let pointer = allocate_layout(layout);
            let raw_ptr = pointer.as_ptr();
            assert!(!raw_ptr.is_null());
            deallocate_memory(pointer, layout);
        }

        #[test]
        fn ok_allocate_layout_allows_write() {
            let layout = create_layout(0x80);
            let pointer = allocate_layout(layout);
            unsafe {
                pointer.as_ptr().write(0x40);
                assert_eq!(pointer.as_ptr().read(), 0x40);
            }
            deallocate_memory(pointer, layout);
        }

        #[test]
        fn ok_allocate_layout_allows_write_to_entire_slice() {
            let layout = create_layout(0x200);
            let pointer = allocate_layout(layout);
            unsafe {
                for i in 0..0x200 {
                    pointer.as_ptr().add(i).write((i % 0xFF) as u8);
                }
                for i in 0..0x200 {
                    assert_eq!(pointer.as_ptr().add(i).read(), (i % 0xFF) as u8);
                }
            }
            deallocate_memory(pointer, layout);
        }
    }

    mod alloc_struct {
        use super::*;

        #[test]
        fn ok_first_returns_ptr_to_first_buf_from_alloc() {
            let bpool = create_bufpool(BUF_SIZE.bytes() * 0x20);
            let alloc = bpool.allocate(0x10);
            assert_eq!(alloc.first(), alloc.pointer.as_ptr());
        }

        #[test]
        fn ok_length_returns_length_of_alloc() {
            let bpool = create_bufpool(BUF_SIZE.bytes() * 0x20);
            let alloc = bpool.allocate(0x10);
            assert_eq!(alloc.length(), alloc.buffers);
        }

        #[test]
        fn ok_allocated_bytes_return_total_allocated_bytes() {
            let bpool = create_bufpool(BUF_SIZE.bytes() * 0x20);
            let alloc = bpool.allocate(0x10);
            assert_eq!(alloc.allocated_bytes(), alloc.buffers * BUF_SIZE.bytes());
        }

        #[test]
        fn ok_alloc_can_be_sent_across_threads() {
            let pool = sync::Arc::new(create_bufpool(BUF_SIZE.bytes() * 2));
            let alloc = pool.allocate(1);
            std::thread::spawn(move || {
                drop(alloc);
            })
            .join()
            .unwrap();
        }
    }

    mod iterator {
        use super::*;

        #[test]
        fn ok_iter_yields_all_buffers() {
            let bpool = create_bufpool(BUF_SIZE.bytes() * 0x0A);
            let alloc = bpool.allocate(4);
            let ptrs: Vec<_> = alloc.iter().collect();
            assert_eq!(ptrs.len(), 4);
            assert_eq!(ptrs[0], alloc.first());
            for pair in ptrs.windows(2) {
                assert_eq!(pair[1] as usize - pair[0] as usize, BUF_SIZE.bytes());
            }
        }

        #[test]
        fn ok_iter_yields_none_when_exhausted() {
            let bpool = create_bufpool(BUF_SIZE.bytes() * 0x0A);
            let alloc = bpool.allocate(2);
            let mut iter = alloc.iter();
            assert!(iter.next().is_some());
            assert!(iter.next().is_some());
            assert!(iter.next().is_none());
            assert!(iter.next().is_none());
        }
    }
}