use std::cell::UnsafeCell;
use std::mem::{ManuallyDrop, MaybeUninit};
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Weak};
/// Sentinel stored in `free_head` / `Slot::next` marking the end of the free list.
const NONE: usize = usize::MAX;

/// One pool slot: storage for a value plus the index of the next free slot.
///
/// Whenever the slot sits on the free list, `value` holds an initialized `T`; once the slot
/// is handed out, its value has been moved into a `Pooled` guard.
struct Slot<T> {
    value: UnsafeCell<MaybeUninit<T>>,
    next: AtomicUsize,
}

// SAFETY: a `Slot` owns at most one `T` by value, so moving it to another thread only
// requires that `T` itself may be moved across threads.
unsafe impl<T> Send for Slot<T> where T: Send {}

// SAFETY: the free-list protocol grants access to `value` to one thread at a time, and values
// migrate between threads through it (hence `T: Send`); `T: Sync` is required as well.
unsafe impl<T> Sync for Slot<T> where T: Send + Sync {}
/// Shared state behind a `Pool` (which holds an `Arc` to it) and its outstanding `Pooled`
/// guards (which hold a `Weak` to it).
struct Inner<T> {
    /// Fixed slot storage; a slot's index is its identity on the free list.
    slots: Box<[Slot<T>]>,
    /// Index of the first free slot, or `NONE` when every slot is checked out.
    free_head: AtomicUsize,
    /// Number of free slots, maintained next to the list and reported by `Pool::available`.
    free_count: AtomicUsize,
    /// Applied to every value as it is pushed back onto the free list.
    reset: Box<dyn Fn(&mut T) + Send + Sync>,
}
impl<T> Inner<T> {
    /// Resets `value`, stores it in slot `idx`, and pushes that slot onto the free list.
    ///
    /// The caller must own slot `idx`: it was removed from the free list by [`Inner::pop`]
    /// and its previous value has already been moved out, so nothing else touches the slot's
    /// storage concurrently. Any number of `push` calls may run at the same time.
    ///
    /// NOTE: if `reset` panics, `value` is dropped during unwinding and the slot is never
    /// linked back, so the pool permanently loses that slot.
    fn push(&self, idx: usize, mut value: T) {
        (self.reset)(&mut value);
        let slot = &self.slots[idx];
        // SAFETY: the caller owns slot `idx` and its old value was moved out, so this write
        // neither races with another thread nor overwrites a live value.
        unsafe {
            (*slot.value.get()).write(value);
        }
        // Count the slot *before* publishing it. `pop` decrements only after an acquiring CAS
        // that observes this push's release CAS, so the increment is always ordered before
        // the matching decrement and the counter can never wrap below zero. (Incrementing
        // after the CAS let a racing `pop` decrement first, briefly showing `usize::MAX`.)
        // Trade-off: `free_count` may momentarily include slots still being linked in.
        self.free_count.fetch_add(1, Ordering::Relaxed);
        let mut head = self.free_head.load(Ordering::Relaxed);
        loop {
            slot.next.store(head, Ordering::Relaxed);
            // Release publishes both the value write and `next` to whichever `pop` takes the slot.
            match self.free_head.compare_exchange_weak(
                head,
                idx,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return,
                Err(current) => {
                    head = current;
                    std::hint::spin_loop();
                }
            }
        }
    }

    /// Pops a free slot index off the free list, or returns `None` when the list is empty.
    ///
    /// On success the caller owns the slot and must move its value out (e.g. with
    /// [`Inner::read_value`]) before handing the index back to [`Inner::push`].
    ///
    /// **Not ABA-safe:** the head is a bare slot index with no version tag. If, between this
    /// call's load of `head`/`next` and its CAS, another `pop` takes the head and its
    /// successor and the head is then pushed back, the CAS still succeeds and installs a
    /// successor that is checked out, so one slot can be handed out twice. Callers must ensure
    /// at most one `pop` runs at a time. Concurrent `push`es are fine, because a slot that is
    /// on the list can only be pushed again after it has been popped.
    fn pop(&self) -> Option<usize> {
        loop {
            // Acquire pairs with the release CAS in `push`, making the slot's `next` and value
            // visible before they are read.
            let head = self.free_head.load(Ordering::Acquire);
            if head == NONE {
                return None;
            }
            let next = self.slots[head].next.load(Ordering::Relaxed);
            match self.free_head.compare_exchange_weak(
                head,
                next,
                Ordering::Acquire,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    self.free_count.fetch_sub(1, Ordering::Relaxed);
                    return Some(head);
                }
                Err(_) => {
                    std::hint::spin_loop();
                }
            }
        }
    }

    /// Moves the value out of slot `idx`.
    ///
    /// # Safety
    ///
    /// The caller must own slot `idx` (it was just returned by [`Inner::pop`]), and the slot
    /// must hold an initialized value that has not already been read out. Afterwards the slot
    /// is logically uninitialized until the next [`Inner::push`].
    unsafe fn read_value(&self, idx: usize) -> T {
        // SAFETY: upheld by the caller per the contract above.
        unsafe { (*self.slots[idx].value.get()).assume_init_read() }
    }
}
impl<T> Drop for Inner<T> {
    /// Drops every value still parked on the free list.
    ///
    /// Slots that are checked out had their values moved into `Pooled` guards, so they are
    /// skipped by walking only the list.
    fn drop(&mut self) {
        let mut cursor = *self.free_head.get_mut();
        while cursor != NONE {
            let slot = &mut self.slots[cursor];
            // SAFETY: every slot reachable from `free_head` holds an initialized value, and
            // `&mut self` guarantees no other thread can reach it any more.
            unsafe { slot.value.get_mut().assume_init_drop() };
            cursor = *slot.next.get_mut();
        }
    }
}
/// A fixed-capacity pool of reusable `T` values.
///
/// Values are handed out as [`Pooled`] guards. Dropping a guard, on any thread, resets the
/// value with the pool's `reset` callback and returns it to the pool; if the pool itself has
/// been dropped by then, the value is simply dropped.
pub struct Pool<T> {
    inner: Arc<Inner<T>>,
    /// Serializes `Inner::pop`. The free list is an index-based Treiber stack without ABA
    /// protection, so two concurrent pops can hand out the same slot twice (duplicating a
    /// `T`). `Pool<T>` is `Sync` whenever `T: Send + Sync`, so `try_acquire` can race with
    /// itself from safe code. Returning values (`Inner::push`) stays lock-free.
    pop_lock: Mutex<()>,
}

// SAFETY: the shared `Inner` only moves `T` values in and out of its slots by value (it never
// hands out `&T`), and its `reset` callback is `Send + Sync`, so sending the pool handle to
// another thread only requires `T: Send`. The lint is silenced because `Arc<Inner<T>>` is not
// auto-`Send` without `T: Sync`; the argument above covers that field.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T: Send> Send for Pool<T> {}

impl<T> Pool<T> {
    /// Creates a pool of `capacity` values, calling `init` once per slot.
    ///
    /// `reset` is applied to every value as it is returned to the pool.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero or not less than `usize::MAX` (reserved as the free-list
    /// terminator). If `init` panics, the values it already produced are dropped.
    pub fn new<I, R>(capacity: usize, mut init: I, reset: R) -> Self
    where
        I: FnMut() -> T,
        R: Fn(&mut T) + Send + Sync + 'static,
    {
        assert!(capacity > 0, "capacity must be non-zero");
        assert!(capacity < NONE, "capacity must be less than {}", NONE);
        // Build the values in a `Vec` first: if `init` panics part-way, the `Vec` drops the
        // values created so far, whereas values already parked in `MaybeUninit` would leak.
        let values: Vec<T> = (0..capacity).map(|_| init()).collect();
        let slots: Box<[Slot<T>]> = values
            .into_iter()
            .enumerate()
            .map(|(i, value)| Slot {
                value: UnsafeCell::new(MaybeUninit::new(value)),
                // Initially the free list threads through the slots in index order.
                next: AtomicUsize::new(if i + 1 < capacity { i + 1 } else { NONE }),
            })
            .collect();
        Self {
            inner: Arc::new(Inner {
                slots,
                free_head: AtomicUsize::new(0),
                free_count: AtomicUsize::new(capacity),
                reset: Box::new(reset),
            }),
            pop_lock: Mutex::new(()),
        }
    }

    /// Takes a value out of the pool, or returns `None` if every slot is checked out.
    pub fn try_acquire(&self) -> Option<Pooled<T>> {
        let idx = {
            // The mutex guards no data, so a poisoned lock is still perfectly usable.
            let _guard = self
                .pop_lock
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            self.inner.pop()
        }?;
        // SAFETY: `pop` unlinked `idx`, so this call now exclusively owns the slot, whose value
        // was initialized by `new` or by the `push` that last returned it.
        let value = unsafe { self.inner.read_value(idx) };
        Some(Pooled {
            value: ManuallyDrop::new(value),
            idx,
            inner: Arc::downgrade(&self.inner),
        })
    }

    /// Returns how many values are currently free.
    ///
    /// This is a relaxed snapshot and may already be stale when returned if other threads are
    /// returning values concurrently.
    pub fn available(&self) -> usize {
        self.inner.free_count.load(Ordering::Relaxed)
    }
}
/// A value checked out of a [`Pool`].
///
/// Dereferences to the value. When dropped, the value is reset and returned to the pool it
/// came from; if that pool no longer exists, the value is dropped instead.
pub struct Pooled<T> {
    /// The checked-out value; moved back into its slot (or dropped) by `Drop`.
    value: ManuallyDrop<T>,
    /// Index of the slot this value belongs to.
    idx: usize,
    /// Non-owning handle to the pool's shared state, so outstanding guards don't keep it alive.
    inner: Weak<Inner<T>>,
}

// SAFETY: dropping a guard on another thread moves its `T` back into the shared state (or
// drops it there), which needs only `T: Send`. The `Weak` is only upgraded to push the value
// back and is never used to share `&T`. The lint is silenced because `Weak<Inner<T>>` is not
// auto-`Send` without `T: Sync`.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T> Send for Pooled<T> where T: Send {}

// SAFETY: `&Pooled<T>` exposes only `&T` (through `Deref`), which needs `T: Sync`.
unsafe impl<T> Sync for Pooled<T> where T: Send + Sync {}
impl<T> Deref for Pooled<T> {
    type Target = T;

    /// Borrows the checked-out value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        &*self.value
    }
}

impl<T> DerefMut for Pooled<T> {
    /// Mutably borrows the checked-out value.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.value
    }
}
impl<T> Drop for Pooled<T> {
    /// Returns the value to its pool (which resets it), or drops it if the pool is gone.
    fn drop(&mut self) {
        match self.inner.upgrade() {
            Some(shared) => {
                // SAFETY: this is the last use of `value`; the guard never touches it again.
                let recycled = unsafe { ManuallyDrop::take(&mut self.value) };
                shared.push(self.idx, recycled);
            }
            // SAFETY: `value` is dropped exactly once here and never read afterwards.
            None => unsafe { ManuallyDrop::drop(&mut self.value) },
        }
    }
}
#[cfg(test)]
mod tests {
    // `AtomicUsize`, `Ordering` and `Arc` all come from the parent's imports via this glob.
    use super::*;
    use std::thread;

    #[test]
    fn basic_acquire_release() {
        let pool = Pool::new(3, || Vec::<u8>::with_capacity(16), Vec::clear);
        let mut a = pool.try_acquire().unwrap();
        a.extend_from_slice(b"hello");
        assert_eq!(&*a, b"hello");
        let _b = pool.try_acquire().unwrap();
        let _c = pool.try_acquire().unwrap();
        assert!(pool.try_acquire().is_none());
        drop(a);
        let d = pool.try_acquire().unwrap();
        assert!(d.is_empty());
    }

    #[test]
    fn cross_thread_return() {
        let pool = Pool::new(2, || 42u32, |_| {});
        let item = pool.try_acquire().unwrap();
        assert_eq!(*item, 42);
        thread::spawn(move || {
            assert_eq!(*item, 42);
            drop(item);
        })
        .join()
        .unwrap();
        let item2 = pool.try_acquire().unwrap();
        assert_eq!(*item2, 42);
    }

    #[test]
    fn acquirer_dropped_first() {
        let item;
        {
            let pool = Pool::new(1, || String::from("test"), String::clear);
            item = pool.try_acquire().unwrap();
        }
        assert_eq!(&*item, "test");
    }

    #[test]
    fn reset_called_on_return() {
        let reset_count = Arc::new(AtomicUsize::new(0));
        let reset_count_clone = Arc::clone(&reset_count);
        let pool = Pool::new(
            2,
            || 0u32,
            move |_| {
                reset_count_clone.fetch_add(1, Ordering::Relaxed);
            },
        );
        let a = pool.try_acquire().unwrap();
        assert_eq!(reset_count.load(Ordering::Relaxed), 0);
        drop(a);
        assert_eq!(reset_count.load(Ordering::Relaxed), 1);
        let b = pool.try_acquire().unwrap();
        let c = pool.try_acquire().unwrap();
        drop(b);
        drop(c);
        assert_eq!(reset_count.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn lifo_ordering() {
        // A no-op reset keeps each value intact, so the slots can be told apart.
        let pool = Pool::new(3, || 0u32, |_| {});
        let mut first = pool.try_acquire().unwrap();
        let mut second = pool.try_acquire().unwrap();
        let mut third = pool.try_acquire().unwrap();
        *first = 1;
        *second = 2;
        *third = 3;
        drop(first);
        drop(second);
        drop(third);
        // The most recently returned value must be handed out first.
        let again_1 = pool.try_acquire().unwrap();
        let again_2 = pool.try_acquire().unwrap();
        let again_3 = pool.try_acquire().unwrap();
        assert_eq!((*again_1, *again_2, *again_3), (3, 2, 1));
    }

    #[test]
    fn returned_values_are_reset() {
        let pool = Pool::new(3, Vec::<u8>::new, Vec::clear);
        let mut guard_a = pool.try_acquire().unwrap();
        let mut guard_b = pool.try_acquire().unwrap();
        let mut guard_c = pool.try_acquire().unwrap();
        guard_a.push(1);
        guard_b.push(2);
        guard_c.push(3);
        drop(guard_a);
        drop(guard_b);
        drop(guard_c);
        let reacquired_1 = pool.try_acquire().unwrap();
        assert!(reacquired_1.is_empty());
        let reacquired_2 = pool.try_acquire().unwrap();
        assert!(reacquired_2.is_empty());
        let reacquired_3 = pool.try_acquire().unwrap();
        assert!(reacquired_3.is_empty());
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = Pool::new(0, || (), |()| {});
    }

    #[test]
    fn stress_single_thread() {
        let pool = Pool::new(100, || Vec::<u8>::with_capacity(64), Vec::clear);
        for _ in 0..10_000 {
            let mut items: Vec<_> = (0..50).filter_map(|_| pool.try_acquire()).collect();
            for item in &mut items {
                item.extend_from_slice(b"data");
            }
            drop(items);
        }
        assert_eq!(pool.available(), 100);
    }

    #[test]
    fn stress_multi_thread_return() {
        let pool = Pool::new(
            100,
            || AtomicUsize::new(0),
            |v| {
                v.store(0, Ordering::Relaxed);
            },
        );
        let returned = Arc::new(AtomicUsize::new(0));
        thread::scope(|s| {
            let (tx, rx) = std::sync::mpsc::channel();
            let returned_clone = Arc::clone(&returned);
            s.spawn(move || {
                while let Ok(item) = rx.recv() {
                    let _item: Pooled<AtomicUsize> = item;
                    returned_clone.fetch_add(1, Ordering::Relaxed);
                }
            });
            let mut sent = 0;
            while sent < 1000 {
                if let Some(item) = pool.try_acquire() {
                    tx.send(item).unwrap();
                    sent += 1;
                } else {
                    thread::yield_now();
                }
            }
        });
        assert_eq!(returned.load(Ordering::Relaxed), 1000);
    }

    #[test]
    fn stress_concurrent_return() {
        let pool = Pool::new(1000, || 0u64, |_| {});
        let items: Vec<_> = (0..1000).filter_map(|_| pool.try_acquire()).collect();
        assert_eq!(items.len(), 1000);
        let items_per_thread = 250;
        let mut item_chunks: Vec<Vec<_>> = Vec::new();
        let mut iter = items.into_iter();
        for _ in 0..4 {
            item_chunks.push(iter.by_ref().take(items_per_thread).collect());
        }
        thread::scope(|s| {
            for chunk in item_chunks {
                s.spawn(move || {
                    for item in chunk {
                        drop(item);
                    }
                });
            }
        });
        assert_eq!(pool.available(), 1000);
    }
}