use std::{
cell::UnsafeCell,
collections::VecDeque,
mem::transmute,
process::abort,
ptr::NonNull,
sync::atomic::{AtomicUsize, Ordering::*},
};
/// Number of `Pause` guards currently alive, across all threads.
///
/// `Pause::new` increments it and dropping a `Pause` decrements it. The
/// garbage queues are only flushed when a load of this counter returns zero.
static PAUSED_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Hands `ptr` over to the calling thread's garbage queue, to be destroyed
/// later by `dropper`.
///
/// The entry is appended to the thread-local queue. If the global pause
/// counter reads zero right afterwards, the whole local queue (including the
/// entry just added) is flushed before this function returns; otherwise the
/// entry stays queued until a later flush.
///
/// # Safety
///
/// `dropper(ptr)` is executed at an unspecified later point: possibly during
/// this call, possibly during a later `add`/`try_force` on this thread, or
/// when the thread-local queue is destroyed. The caller must guarantee that
/// this call is sound whenever it happens, i.e. `ptr` stays valid for
/// `dropper` until then and is not released by any other means.
pub unsafe fn add<T>(ptr: NonNull<T>, dropper: unsafe fn(NonNull<T>)) {
    // Erase `T` so that garbage of every type fits in the same queue. The
    // dropper is only ever called with the pointer it was registered with,
    // so it always receives a pointer to a real `T` again.
    let erased_ptr = ptr.cast::<u8>();
    let erased_dropper =
        transmute::<unsafe fn(NonNull<T>), unsafe fn(NonNull<u8>)>(dropper);

    LOCAL_DELETION.with(|queue| {
        queue.add(Garbage { ptr: erased_ptr, dropper: erased_dropper });

        let nobody_paused = PAUSED_COUNT.load(Acquire) == 0;
        if nobody_paused {
            queue.delete();
        }
    })
}
/// Tries to flush the calling thread's garbage queue right now.
///
/// Returns `true` when the pause counter was observed as zero and the local
/// queue was flushed. Returns `false` (leaving the queue untouched) when the
/// counter was non-zero.
pub fn try_force() -> bool {
    LOCAL_DELETION.with(|queue| {
        if PAUSED_COUNT.load(Acquire) != 0 {
            return false;
        }
        queue.delete();
        true
    })
}
/// Runs `exec` while the global pause counter is held incremented and returns
/// whatever `exec` returns.
///
/// The counter is decremented again once `exec` has finished, including when
/// `exec` unwinds, because the decrement lives in the guard's destructor.
#[inline]
pub fn pause<F, T>(exec: F) -> T
where
    F: FnOnce() -> T,
{
    // Bound to a name (not `_`) so the guard lives until the end of scope.
    let _guard = Pause::new();
    exec()
}
/// Guard that keeps the global pause counter incremented for as long as it
/// is alive.
struct Pause;

impl Pause {
    /// Increments the pause counter and returns the guard.
    ///
    /// Aborts the process if the counter was already at its maximum value,
    /// since the increment would have wrapped it around.
    pub fn new() -> Self {
        let previous = PAUSED_COUNT.fetch_add(1, Acquire);
        if previous == usize::max_value() {
            // The counter has wrapped to zero, which would read as "nobody
            // is paused"; there is no way to recover from that.
            abort();
        }
        Self
    }
}

impl Drop for Pause {
    /// Gives back the increment taken in `Pause::new`.
    fn drop(&mut self) {
        let _ = PAUSED_COUNT.fetch_sub(1, Release);
    }
}
/// A type-erased pointer together with the function that destroys it.
struct Garbage {
    /// Pointer to be reclaimed, with its pointee type erased to `u8`.
    ptr: NonNull<u8>,
    /// Function invoked with `ptr` when the entry is flushed.
    dropper: unsafe fn(NonNull<u8>),
}

/// FIFO queue of garbage waiting to be destroyed.
///
/// All methods take `&self` and mutate through an `UnsafeCell`. That is only
/// sound because the cell makes the type `!Sync` (a queue is never accessed
/// from two threads at once) and because no reference into the cell is ever
/// kept alive while foreign code runs (see `delete`).
struct GarbageQueue {
    inner: UnsafeCell<VecDeque<Garbage>>,
}

impl GarbageQueue {
    /// Creates an empty queue with room for 16 entries.
    fn new() -> Self {
        Self { inner: UnsafeCell::new(VecDeque::with_capacity(16)) }
    }

    /// Appends `garbage` at the back of the queue.
    fn add(&self, garbage: Garbage) {
        // SAFETY: the queue is `!Sync`, so no other thread touches the cell,
        // and the exclusive reference only lives for this `push_back`, which
        // does not call back into the queue.
        unsafe { &mut *self.inner.get() }.push_back(garbage);
    }

    /// Pops entries front to back and runs each dropper until the queue is
    /// empty.
    ///
    /// A dropper is allowed to push new entries onto this same queue; those
    /// entries are destroyed by the same call.
    fn delete(&self) {
        loop {
            // The exclusive borrow must end before the dropper runs: a
            // dropper may re-enter this queue through `add`, and two live
            // `&mut` to the same deque would be undefined behaviour.
            //
            // SAFETY: the queue is `!Sync`, and this reference is dropped at
            // the end of the statement, before any foreign code executes.
            let next = unsafe { &mut *self.inner.get() }.pop_front();

            match next {
                // SAFETY: whoever enqueued the entry promised that calling
                // `dropper` with `ptr` is sound; the entry was removed from
                // the queue above, so it is dropped exactly once.
                Some(garbage) => unsafe { (garbage.dropper)(garbage.ptr) },
                None => break,
            }
        }
    }
}
impl Drop for GarbageQueue {
    /// Flushes whatever is still queued when the queue is destroyed.
    ///
    /// Blocks until the global pause counter is observed as zero, because
    /// the remaining entries cannot be handed to anyone else once the queue
    /// is gone.
    fn drop(&mut self) {
        while PAUSED_COUNT.load(Acquire) != 0 {
            // Let the paused threads run instead of spinning at full speed:
            // a tight empty loop burns a core and can delay the very thread
            // whose un-pause is being waited for.
            ::std::thread::yield_now();
        }
        self.delete();
    }
}
// Per-thread garbage queue used by `add` and `try_force`. Being a
// thread-local, it is never shared between threads; its `Drop` implementation
// flushes any entries still queued when the thread-local is destroyed.
thread_local! {
static LOCAL_DELETION: GarbageQueue = GarbageQueue::new();
}
#[cfg(test)]
mod test {
    use super::*;
    use alloc::*;
    use std::thread;

    /// With no pause in progress, forcing a flush succeeds both on an empty
    /// queue and after a batch of allocations has been queued.
    ///
    /// NOTE(review): `PAUSED_COUNT` is process-global and the test harness
    /// runs tests on parallel threads by default, so these assertions can
    /// presumably race with `count_is_gt_0_when_pausing` — consider
    /// serializing the two tests.
    #[test]
    fn try_force_succeeds_in_single_threaded() {
        const COUNT: usize = 16;

        assert!(try_force());

        let allocs: Vec<_> =
            (0 .. COUNT).map(|i| unsafe { alloc(i) }).collect();

        // A pause that has already ended must not block later deletion.
        pause(|| ());

        for ptr in allocs {
            unsafe { add(ptr, dealloc) };
        }

        assert!(try_force());
    }

    /// While a thread is inside `pause`, the global counter is non-zero.
    #[test]
    fn count_is_gt_0_when_pausing() {
        const NTHREADS: usize = 20;

        // Collect first so that every thread is spawned before any join.
        let handles: Vec<_> = (0 .. NTHREADS)
            .map(|_| {
                thread::spawn(|| {
                    pause(|| {
                        assert!(PAUSED_COUNT.load(SeqCst) > 0);
                    })
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("sub-thread panicked");
        }
    }
}