use std::alloc::{self, Layout};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use crossbeam_queue::ArrayQueue;
use super::aligned::{AlignedBuffer, RawAllocation};
use crate::Error;
use crate::Result;
#[derive(Clone)]
pub(crate) struct AlignedBufferPool {
inner: Arc<PoolInner>,
}
impl AlignedBufferPool {
pub(crate) fn new(capacity: usize, block_size: usize, block_align: usize) -> Result<Self> {
if block_align == 0 || !block_align.is_power_of_two() {
return Err(Error::AlignmentRequired {
detail: "buffer pool block_align must be a non-zero power of two",
});
}
if block_size == 0 || block_size % block_align != 0 {
return Err(Error::AlignmentRequired {
detail: "buffer pool block_size must be a non-zero multiple of block_align",
});
}
let block_layout = Layout::from_size_align(block_size, block_align).map_err(|_| {
Error::AlignmentRequired {
detail: "buffer pool layout rejected by Layout::from_size_align",
}
})?;
if capacity == 0 {
return Err(Error::AlignmentRequired {
detail: "buffer pool capacity must be > 0",
});
}
Ok(Self {
inner: Arc::new(PoolInner {
free: ArrayQueue::new(capacity),
allocated: AtomicUsize::new(0),
capacity,
block_layout,
wait_lock: Mutex::new(()),
wait_cv: Condvar::new(),
}),
})
}
pub(crate) fn lease(&self) -> AlignedBuffer {
self.inner.lease()
}
#[allow(dead_code)] pub(crate) fn block_size(&self) -> usize {
self.inner.block_layout.size()
}
#[allow(dead_code)] pub(crate) fn block_align(&self) -> usize {
self.inner.block_layout.align()
}
#[allow(dead_code)] pub(crate) fn capacity(&self) -> usize {
self.inner.capacity
}
}
/// Shared state behind every `AlignedBufferPool` clone and every buffer leased from it.
///
/// NOTE(review): `pub(super)` presumably lets the sibling `aligned` module name this type
/// (`AlignedBuffer::new` receives an `Arc<PoolInner>`); confirm against that module.
pub(super) struct PoolInner {
/// Returned blocks waiting to be reused; created with room for `capacity` entries.
free: ArrayQueue<RawAllocation>,
/// Slots in use: blocks currently allocated (leased or parked in `free`) plus any slot a
/// `lease` call has reserved while it allocates. `lease` only increments this while it is
/// below `capacity`, so it never exceeds `capacity`.
allocated: AtomicUsize,
/// Upper bound on `allocated`; always > 0 (rejected otherwise by `AlignedBufferPool::new`).
capacity: usize,
/// Size and alignment of every block; the size is non-zero and a multiple of the alignment.
block_layout: Layout,
/// Protects no data (`()`); exists only to pair with `wait_cv` so exhausted leasers can park.
wait_lock: Mutex<()>,
/// Notified by `return_allocation` each time a block comes back to the pool.
wait_cv: Condvar,
}
impl PoolInner {
    /// Hands out one block, blocking while the pool is exhausted.
    ///
    /// Order of preference:
    /// 1. recycle a block parked in `free`;
    /// 2. reserve a slot in `allocated` (only while below `capacity`) and allocate a fresh,
    ///    zero-filled block;
    /// 3. park in `wait_for_return` until a block is returned, then start over.
    fn lease(self: &Arc<Self>) -> AlignedBuffer {
        loop {
            if let Some(raw) = self.free.pop() {
                return AlignedBuffer::new(raw, Arc::clone(self));
            }

            let allocated = self.allocated.load(Ordering::Acquire);
            if allocated < self.capacity {
                // Reserve the slot *before* allocating so concurrent leasers can never push
                // `allocated` past `capacity`. A failed (or spuriously failed) CAS only means
                // the count moved under us; the loop simply retries from the top.
                let reserved = self
                    .allocated
                    .compare_exchange_weak(
                        allocated,
                        allocated + 1,
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    )
                    .is_ok();
                if reserved {
                    if let Some(raw) = self.alloc_fresh() {
                        return AlignedBuffer::new(raw, Arc::clone(self));
                    }
                    // Allocation failed: give the reserved slot back and retry.
                    // NOTE(review): under persistent OOM `wait_for_return` usually returns at
                    // once (a slot is free again), so this retries without back-off.
                    self.allocated.fetch_sub(1, Ordering::AcqRel);
                    self.wait_for_return();
                }
                continue;
            }

            self.wait_for_return();
        }
    }

    /// Parks the caller on `wait_cv` unless a block is already free or a fresh block may
    /// still be allocated.
    ///
    /// The predicate is tested while holding `wait_lock`, and `wait` releases that lock
    /// atomically. Because `wake_one_waiter` takes the same lock before notifying, a return
    /// cannot slip in between the test and the wait. Spurious wake-ups are harmless: `lease`
    /// re-checks everything on its next loop iteration.
    fn wait_for_return(&self) {
        let guard = self
            .wait_lock
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if !self.free.is_empty() || self.allocated.load(Ordering::Acquire) < self.capacity {
            return;
        }
        // The re-acquired guard (or poison error) is released immediately; the caller loops.
        drop(self.wait_cv.wait(guard));
    }

    /// Allocates one zero-filled block with `block_layout`, or `None` if the allocator fails.
    fn alloc_fresh(&self) -> Option<RawAllocation> {
        // SAFETY: `alloc_zeroed` requires a non-zero-sized layout. `PoolInner`'s fields are
        // private to this module and its only constructor (`AlignedBufferPool::new`) rejects
        // `block_size == 0`.
        let ptr = unsafe { alloc::alloc_zeroed(self.block_layout) };
        NonNull::new(ptr).map(|ptr| RawAllocation {
            ptr,
            layout: self.block_layout,
        })
    }

    /// Takes back a block released by a lease and wakes one thread blocked in `lease`.
    ///
    /// The block is normally parked in `free` for reuse. `free` holds `capacity` entries and
    /// `allocated` never exceeds `capacity`, so the overflow branch should be unreachable
    /// when every block came from `lease`. It is kept as a defensive fallback that frees the
    /// block and releases its slot.
    pub(super) fn return_allocation(&self, raw: RawAllocation) {
        if let Err(raw) = self.free.push(raw) {
            // SAFETY: NOTE(review): assumes `raw` was produced by `alloc_fresh`, i.e. by the
            // global allocator with exactly `raw.layout`. Confirm that `aligned` only returns
            // blocks it received from this pool.
            unsafe {
                alloc::dealloc(raw.ptr.as_ptr(), raw.layout);
            }
            self.allocated.fetch_sub(1, Ordering::AcqRel);
        }
        self.wake_one_waiter();
    }

    /// Wakes one thread parked in `wait_for_return`, without the lost-wake-up race.
    ///
    /// `wait_for_return` tests its predicate and starts waiting while holding `wait_lock`.
    /// Acquiring that lock here, *after* the state change the waiter is looking for, means
    /// every waiter is in one of two states:
    /// - already parked, in which case it receives this notification;
    /// - not yet past its predicate test, in which case it will observe the new state.
    ///
    /// Notifying without the lock could land in the gap between a waiter's test and its
    /// `wait`, leaving it asleep even though a block is available.
    fn wake_one_waiter(&self) {
        drop(
            self.wait_lock
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner()),
        );
        self.wait_cv.notify_one();
    }
}
impl Drop for PoolInner {
    /// Returns every block still parked in `free` to the global allocator.
    fn drop(&mut self) {
        // NOTE(review): each lease is handed an `Arc` clone of this pool. Assuming
        // `AlignedBuffer` keeps that clone until it gives its block back, no block can still
        // be leased when this runs, so draining `free` releases everything the pool owns.
        for block in std::iter::from_fn(|| self.free.pop()) {
            // SAFETY: blocks in `free` are assumed to come from `alloc_fresh`, i.e. allocated
            // by the global allocator with `block.layout`; each is popped (and freed) once.
            unsafe { alloc::dealloc(block.ptr.as_ptr(), block.layout) };
        }
    }
}
// Compile-time check that pool handles may be shared and sent across threads, and that a
// leased buffer may be sent to another thread. The helpers are never called; type-checking
// their bodies is what enforces the bounds.
const _: () = {
    fn require_send<T: Send>() {}
    fn require_send_sync<T: Send + Sync>() {}
    fn thread_safety_contract() {
        require_send_sync::<AlignedBufferPool>();
        require_send::<AlignedBuffer>();
    }
    // Referencing the function keeps it from being reported as dead code.
    let _ = thread_safety_contract;
};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_new_validates_alignment_power_of_two() {
        let r = AlignedBufferPool::new(4, 4096, 7);
        assert!(r.is_err(), "non-power-of-two alignment must error");
    }

    #[test]
    fn test_pool_new_validates_zero_alignment() {
        let r = AlignedBufferPool::new(4, 4096, 0);
        assert!(r.is_err(), "zero alignment must error");
    }

    #[test]
    fn test_pool_new_validates_size_multiple_of_align() {
        let r = AlignedBufferPool::new(4, 4097, 512);
        assert!(r.is_err(), "size not a multiple of align must error");
    }

    #[test]
    fn test_pool_new_validates_zero_capacity() {
        let r = AlignedBufferPool::new(0, 4096, 512);
        assert!(r.is_err(), "zero capacity must error");
    }

    #[test]
    fn test_pool_new_validates_zero_block_size() {
        let r = AlignedBufferPool::new(4, 0, 512);
        assert!(r.is_err(), "zero block size must error");
    }

    #[test]
    fn test_pool_new_succeeds_with_valid_inputs() {
        let p = AlignedBufferPool::new(64, 4096, 512).expect("valid pool");
        assert_eq!(p.capacity(), 64);
        assert_eq!(p.block_size(), 4096);
        assert_eq!(p.block_align(), 512);
    }

    #[test]
    fn test_lease_returns_aligned_buffer_with_correct_size() {
        let p = AlignedBufferPool::new(4, 4096, 512).expect("pool");
        let b = p.lease();
        assert_eq!(b.len(), 4096);
        assert_eq!(b.align(), 512);
        assert_eq!((b.as_ptr() as usize) % 512, 0);
    }

    #[test]
    fn test_lease_returns_zero_initialised_buffer() {
        let p = AlignedBufferPool::new(4, 4096, 512).expect("pool");
        let b = p.lease();
        assert!(b.as_slice().iter().all(|&x| x == 0));
    }

    #[test]
    fn test_drop_returns_buffer_to_pool() {
        let p = AlignedBufferPool::new(2, 4096, 512).expect("pool");
        let b1 = p.lease();
        let b2 = p.lease();
        let p1 = b1.as_ptr() as usize;
        let p2 = b2.as_ptr() as usize;
        drop(b1);
        drop(b2);
        let b3 = p.lease();
        let b4 = p.lease();
        let returned = [b3.as_ptr() as usize, b4.as_ptr() as usize];
        assert!(returned.contains(&p1), "buf 1 should be reused");
        assert!(returned.contains(&p2), "buf 2 should be reused");
    }

    #[test]
    fn test_concurrent_lease_and_return() {
        // The pool handle is already a cheap `Arc`-backed clone; no extra `Arc` needed.
        let pool = AlignedBufferPool::new(4, 4096, 512).expect("pool");
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let p = pool.clone();
                std::thread::spawn(move || {
                    for _ in 0..16 {
                        let mut b = p.lease();
                        b.as_mut_slice()[0] = 0x42;
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread");
        }
        let _b = pool.lease();
        assert!(pool.inner.allocated.load(Ordering::Acquire) <= 4);
    }

    #[test]
    fn test_single_slot_pool_under_contention_completes() {
        // Capacity 1 forces almost every lease through the wait/notify path, which is where a
        // lost wake-up shows up as a hang.
        let pool = AlignedBufferPool::new(1, 4096, 512).expect("pool");
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let p = pool.clone();
                std::thread::spawn(move || {
                    for _ in 0..200 {
                        let mut b = p.lease();
                        b.as_mut_slice()[0] = 0x42;
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread");
        }
        // The single block is allocated once and recycled for every subsequent lease.
        assert_eq!(pool.inner.allocated.load(Ordering::Acquire), 1);
    }

    #[test]
    fn test_lazy_allocation_no_buffers_until_first_lease() {
        let p = AlignedBufferPool::new(64, 4096, 512).expect("pool");
        assert_eq!(p.inner.allocated.load(Ordering::Acquire), 0);
        let _b = p.lease();
        assert_eq!(p.inner.allocated.load(Ordering::Acquire), 1);
    }

    #[test]
    fn test_lease_blocks_when_capacity_exhausted_then_succeeds_after_return() {
        let pool = AlignedBufferPool::new(1, 4096, 512).expect("pool");
        let b1 = pool.lease();
        let p1_ptr = b1.as_ptr() as usize;
        let p_clone = pool.clone();
        let handle = std::thread::spawn(move || {
            let b2 = p_clone.lease();
            b2.as_ptr() as usize
        });
        std::thread::sleep(std::time::Duration::from_millis(50));
        drop(b1);
        let p2_ptr = handle.join().expect("thread");
        assert_eq!(p1_ptr, p2_ptr);
    }
}