use std::collections::HashMap;
use std::sync::atomic::{AtomicPtr, AtomicU8, AtomicUsize, Ordering};
use std::sync::Mutex;
use std::thread::{self, Thread};
/// Blocks threads on a condition over an atomic value and wakes them when that
/// value changes.
///
/// Waiting threads are registered under the address of the atomic they wait on.
#[derive(Default)]
pub struct Parker {
    /// Lock-protected registry of parked threads.
    state: Mutex<State>,
    /// Never less than the number of threads registered in `state`; when it is
    /// zero, `unpark` returns without taking the lock.
    pending: AtomicUsize,
}
/// Mutable bookkeeping for a `Parker`; it lives inside the parker's mutex.
#[derive(Default)]
struct State {
    /// Parked threads, keyed first by the address of the atomic they wait on and
    /// then by the registration id drawn from `count`.
    threads: HashMap<usize, HashMap<u64, Thread>>,
    /// Last registration id handed out; incremented before each registration.
    count: u64,
}
impl Parker {
    /// Blocks the current thread until `should_park` returns `false` for the
    /// value currently stored in `atomic`.
    ///
    /// The thread registers itself under the address of `atomic` *before*
    /// checking the condition. It only sleeps while it is still registered, and
    /// only [`unpark`](Self::unpark) on the same atomic removes registrations of
    /// sleeping threads. After being woken, the condition is re-checked (with
    /// `Acquire`); if it still says "park", the thread registers and sleeps again.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn park<T>(&self, atomic: &impl Atomic<T>, should_park: impl Fn(T) -> bool) {
        let key = atomic as *const _ as usize;
        // Looked up once, before touching any shared state. A panic here (if any)
        // then cannot poison the mutex or leave `pending` incremented.
        let current = thread::current();

        loop {
            // Announce ourselves before registering. `unpark` decrements `pending`
            // by the number of threads it removes, so incrementing first keeps the
            // counter from dropping below the number of registered threads.
            // SeqCst: ordered against the SeqCst fast-path load in `unpark`.
            self.pending.fetch_add(1, Ordering::SeqCst);

            let id = {
                let state = &mut *self.state.lock().unwrap();
                state.count += 1;
                let id = state.count;
                state
                    .threads
                    .entry(key)
                    .or_default()
                    .insert(id, current.clone());
                id
            };

            // Check the condition only after registering. An `unpark` that misses
            // this load must therefore find us in the registry.
            if !should_park(atomic.load(Ordering::SeqCst)) {
                // No need to sleep. Deregister, unless `unpark` already removed us;
                // in that case it has already decremented `pending` for us.
                let removed = {
                    let state = &mut *self.state.lock().unwrap();
                    match state.threads.get_mut(&key) {
                        Some(threads) => {
                            let removed = threads.remove(&id).is_some();
                            // Drop the per-key map once it is empty. Otherwise an
                            // entry would linger for every distinct atomic that took
                            // this path and never saw an `unpark` with `pending > 0`.
                            if threads.is_empty() {
                                state.threads.remove(&key);
                            }
                            removed
                        }
                        None => false,
                    }
                };

                if removed {
                    self.pending.fetch_sub(1, Ordering::Relaxed);
                }

                return;
            }

            // Sleep until `unpark` has removed our registration. Any wakeup while we
            // are still registered is spurious, so go back to sleep.
            loop {
                thread::park();

                let state = self.state.lock().unwrap();
                let still_registered = state
                    .threads
                    .get(&key)
                    .is_some_and(|threads| threads.contains_key(&id));
                if !still_registered {
                    break;
                }
            }

            // We were removed by an `unpark` on this atomic. The condition may have
            // flipped back since then, so re-check it.
            // Acquire: synchronize with the store that changed the condition.
            if !should_park(atomic.load(Ordering::Acquire)) {
                return;
            }
        }
    }

    /// Wakes every thread currently parked on `atomic`.
    ///
    /// Woken threads are removed from the registry *before* being unparked; that
    /// removal is how [`park`](Self::park) tells a real wakeup from a spurious one.
    ///
    /// Note: the lock-free fast path (`pending == 0`) is ordered against `park`
    /// only through the `SeqCst` operations on `pending` and on the atomic. The
    /// store that changes the park condition should therefore be `SeqCst`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn unpark<T>(&self, atomic: &impl Atomic<T>) {
        let key = atomic as *const _ as usize;

        // Fast path: no thread is registered on any key, so skip the lock.
        if self.pending.load(Ordering::SeqCst) == 0 {
            return;
        }

        // Take the whole per-key map out under the lock, then wake its threads
        // after the lock has been released.
        let threads = {
            let mut state = self.state.lock().unwrap();
            state.threads.remove(&key)
        };

        if let Some(threads) = threads {
            self.pending.fetch_sub(threads.len(), Ordering::Relaxed);
            for thread in threads.into_values() {
                thread.unpark();
            }
        }
    }
}
/// An atomic cell whose current value can be read with a chosen memory ordering.
///
/// `Parker` uses the address of the implementor as its wait key, and calls
/// `load` to re-check the caller's park condition.
pub trait Atomic<T> {
/// Loads the current value with the given `ordering`.
fn load(&self, ordering: Ordering) -> T;
}
impl<T> Atomic<*mut T> for AtomicPtr<T> {
    /// Reads the stored pointer by delegating to the inherent `AtomicPtr::load`.
    fn load(&self, order: Ordering) -> *mut T {
        // Spelled as a path call so it is plain that this targets the inherent
        // method (which takes precedence) rather than recursing into `Atomic::load`.
        AtomicPtr::load(self, order)
    }
}
impl Atomic<u8> for AtomicU8 {
    /// Reads the stored byte by delegating to the inherent `AtomicU8::load`.
    fn load(&self, order: Ordering) -> u8 {
        // Path call makes it explicit that the inherent method is used, not a
        // recursive call into `Atomic::load`.
        AtomicU8::load(self, order)
    }
}