use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use parking_lot::{Condvar, Mutex};
/// A group of workers that can park themselves and be unparked by others.
///
/// This is a handle to shared state; workers are obtained with
/// [`ParkGroup::new_worker`] and woken (or have their pending park attempts
/// cancelled) via [`ParkGroup::unpark_one`].
///
/// The "recruiter" role is tracked by a single bit in the shared state, so at
/// most one worker is expected to hold it at a time.
#[derive(Default)]
pub struct ParkGroup {
    // Shared with every `ParkGroupWorker` created from this group.
    inner: Arc<ParkGroupInner>,
}
/// State shared between a `ParkGroup` and all workers created from it.
#[derive(Default)]
struct ParkGroupInner {
    /// Packed lock-free state word:
    /// bits 0..32 hold the idle-worker count, bit 32 is `ACTIVE_RECRUITER_BIT`,
    /// bit 33 is `PREPARING_TO_PARK_BIT`, and bits 34..64 hold a version that
    /// `unpark_one` bumps to cancel in-flight park attempts.
    state: AtomicU64,
    /// Number of workers handed out by `ParkGroup::new_worker`. It is only
    /// incremented in this file and exists to enforce the 2^32 - 1 worker limit.
    num_workers: AtomicU32,
    /// Parked workers block on this condvar; it is notified one waiter at a time.
    condvar: Condvar,
    /// `(pending notification count, whether the next worker to consume a
    /// notification becomes the recruiter)`.
    notifications: Mutex<(u32, bool)>,
}
// Layout of the packed `ParkGroupInner::state` word:
//   bits  0..32  number of idle workers, counted in `IDLE_UNIT`s
//   bit      32  `ACTIVE_RECRUITER_BIT`: some worker currently holds the recruiter role
//   bit      33  `PREPARING_TO_PARK_BIT`: some worker has announced it is about to park
//   bits 34..64  version, advanced in `VERSION_UNIT` steps (wrapping) to cancel parks
const IDLE_UNIT: u64 = 1;
const ACTIVE_RECRUITER_BIT: u64 = 1 << 32;
const PREPARING_TO_PARK_BIT: u64 = 1 << 33;
/// Bit position where the version field starts. This is the single source of
/// truth shared by `VERSION_UNIT` and `state_version`.
const VERSION_SHIFT: u32 = 34;
const VERSION_UNIT: u64 = 1 << VERSION_SHIFT;

// The flag bits must sit directly above the 32-bit idle counter and directly
// below the version field; otherwise the extractors below would read
// overlapping fields.
const _: () = assert!(ACTIVE_RECRUITER_BIT == 1 << u32::BITS);
const _: () = assert!(PREPARING_TO_PARK_BIT == ACTIVE_RECRUITER_BIT << 1);
const _: () = assert!(VERSION_UNIT == PREPARING_TO_PARK_BIT << 1);

/// Returns the idle-worker count stored in the low 32 bits of `state`.
fn state_num_idle(state: u64) -> u32 {
    // Truncating cast intentionally keeps only bits 0..32.
    state as u32
}

/// Returns the version stored in bits 34..64 of `state`.
///
/// The version field is 30 bits wide, so the result is always `< 2^30` and
/// the narrowing cast below is lossless.
fn state_version(state: u64) -> u32 {
    (state >> VERSION_SHIFT) as u32
}
/// A worker handle belonging to a `ParkGroup`, used to park, recruit, and be
/// recruited.
pub struct ParkGroupWorker {
    /// Version observed by the most recent `prepare_park`; a later park is
    /// cancelled if the shared version no longer matches.
    version: u32,
    /// Whether this worker currently holds the recruiter role.
    recruiter: bool,
    /// Shared state of the owning group.
    inner: Arc<ParkGroupInner>,
}
impl ParkGroup {
pub fn new() -> Self {
Self::default()
}
pub fn new_worker(&self) -> ParkGroupWorker {
self.inner
.num_workers
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |w| w.checked_add(1))
.expect("can't have more than 2^32 - 1 workers");
ParkGroupWorker {
version: 0,
inner: Arc::clone(&self.inner),
recruiter: false,
}
}
pub fn unpark_one(&self) {
self.inner.unpark_one();
}
}
impl ParkGroupWorker {
    /// Announces that this worker intends to park and returns a [`ParkAttempt`].
    ///
    /// Records the current state version; [`ParkAttempt::park`] is cancelled if
    /// that version has changed by the time it runs. While announcing, the worker
    /// also claims the recruiter role if no worker currently holds it.
    ///
    /// If another worker has already set `PREPARING_TO_PARK_BIT`, this call leaves
    /// the shared state untouched (and does not claim the recruiter role). The
    /// recorded version still lets a concurrent unpark cancel the attempt.
    pub fn prepare_park(&mut self) -> ParkAttempt<'_> {
        let mut state = self.inner.state.load(Ordering::SeqCst);
        self.version = state_version(state);
        // Stop retrying once someone else has set PREPARING_TO_PARK_BIT, or once an
        // unpark has bumped the version (which cancels this attempt anyway).
        while state & PREPARING_TO_PARK_BIT == 0 && state_version(state) == self.version {
            // Mark a park as in progress and, if the role is free, become the recruiter.
            let new_state = state | PREPARING_TO_PARK_BIT | ACTIVE_RECRUITER_BIT;
            // NOTE(review): a failure ordering stronger than the success ordering is
            // only accepted by `compare_exchange(_weak)` since Rust 1.64.
            match self.inner.state.compare_exchange_weak(
                state,
                new_state,
                Ordering::Relaxed,
                Ordering::SeqCst,
            ) {
                Ok(s) => {
                    // We own the recruiter role only if we were the ones to set the bit.
                    if s & ACTIVE_RECRUITER_BIT == 0 {
                        self.recruiter = true;
                    }
                    break;
                },
                // Lost a race (or spurious weak failure): re-check against fresh state.
                Err(s) => state = s,
            }
        }
        ParkAttempt { worker: self }
    }
    /// Hands off the recruiter role after this worker has found work.
    ///
    /// Does nothing unless this worker is the current recruiter. If any workers
    /// are idle, one is claimed and notified as the new recruiter; otherwise the
    /// recruiter role is released. Either way this worker stops being the recruiter.
    pub fn recruit_next(&mut self) {
        if !self.recruiter {
            return;
        }
        // The closure may run several times under contention; this flag ends up
        // holding the decision from the attempt that finally succeeded.
        let mut recruit_next = false;
        // The closure always returns `Some`, so the update always succeeds.
        let _ = self
            .inner
            .state
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
                debug_assert!(state & ACTIVE_RECRUITER_BIT != 0);
                recruit_next = state_num_idle(state) > 0;
                // Either claim one idle worker (keeping ACTIVE_RECRUITER_BIT set so the
                // role transfers to it) or clear the bit because nobody is left to recruit.
                let bit = if recruit_next {
                    IDLE_UNIT
                } else {
                    ACTIVE_RECRUITER_BIT
                };
                Some(state - bit)
            });
        if recruit_next {
            self.inner.unpark_one_slow_as_recruiter();
        }
        self.recruiter = false;
    }
}
/// A pending park created by [`ParkGroupWorker::prepare_park`].
///
/// Holds a mutable borrow of the worker until [`ParkAttempt::park`] consumes it.
/// There is no `Drop` impl in this file. Dropping an attempt without parking
/// therefore just ends the borrow: the worker keeps any recruiter role it
/// claimed, and any `PREPARING_TO_PARK_BIT` it set stays in place.
pub struct ParkAttempt<'a> {
    worker: &'a mut ParkGroupWorker,
}
impl ParkAttempt<'_> {
    /// Parks the worker unless the attempt has been cancelled.
    ///
    /// If the state version changed since [`ParkGroupWorker::prepare_park`] (only
    /// `unpark_one` bumps it), this returns immediately without parking. Otherwise
    /// the worker adds itself to the idle count and blocks until it can consume a
    /// notification. If it was the recruiter, it gives up that role in the same
    /// atomic update.
    pub fn park(mut self) {
        let state = &self.worker.inner.state;
        let update = state.fetch_update(Ordering::Relaxed, Ordering::SeqCst, |state| {
            if state_version(state) != self.worker.version {
                // An unpark happened after prepare_park: cancel this park.
                None
            } else if self.worker.recruiter {
                // Become idle and release the recruiter role in one atomic step.
                Some(state + IDLE_UNIT - ACTIVE_RECRUITER_BIT)
            } else {
                Some(state + IDLE_UNIT)
            }
        });
        if update.is_ok() {
            self.park_slow()
        }
    }
    /// Blocks on the group's condvar until a notification is pending, then consumes it.
    #[cold]
    fn park_slow(&mut self) {
        let condvar = &self.worker.inner.condvar;
        let mut notifications = self.worker.inner.notifications.lock();
        // Waiting on a counter rather than a bare notify means a notification posted
        // between registering as idle (in `park`) and taking this lock is not lost.
        condvar.wait_while(&mut notifications, |n| n.0 == 0);
        // Consume one notification and inherit the recruiter role if it carried one.
        self.worker.recruiter = notifications.1;
        notifications.0 -= 1;
        notifications.1 = false;
    }
}
impl ParkGroupInner {
    /// Wakes a recruiter, or cancels pending park attempts.
    ///
    /// - If there are idle workers and no active recruiter, one idle worker is
    ///   claimed and notified as the new recruiter.
    /// - Otherwise, if `PREPARING_TO_PARK_BIT` is set, the version is bumped and the
    ///   bit cleared, so any park attempt prepared against the old version is cancelled.
    /// - Otherwise the state is left unchanged (e.g. a recruiter is already active,
    ///   or no worker is idle or preparing to park).
    fn unpark_one(&self) {
        // The closure may run several times under contention; this flag ends up
        // holding the decision from the final invocation.
        let mut should_unpark = false;
        let _ = self
            .state
            .fetch_update(Ordering::Release, Ordering::SeqCst, |state| {
                should_unpark = state_num_idle(state) > 0 && state & ACTIVE_RECRUITER_BIT == 0;
                if should_unpark {
                    // Claim one idle worker and hand it the recruiter role.
                    Some(state - IDLE_UNIT + ACTIVE_RECRUITER_BIT)
                } else if state & PREPARING_TO_PARK_BIT == PREPARING_TO_PARK_BIT {
                    // Bump the (wrapping) version to cancel in-flight park attempts.
                    Some(state.wrapping_add(VERSION_UNIT) & !PREPARING_TO_PARK_BIT)
                } else {
                    None
                }
            });
        if should_unpark {
            self.unpark_one_slow_as_recruiter();
        }
    }
    /// Posts one notification whose receiver becomes the recruiter, and wakes one
    /// waiter on the condvar.
    ///
    /// Both callers in this file first remove one worker from the idle count in
    /// `state`, so each notification corresponds to a claimed idle worker.
    #[cold]
    fn unpark_one_slow_as_recruiter(&self) {
        let mut notifications = self.notifications.lock();
        notifications.0 += 1;
        notifications.1 = true;
        self.condvar.notify_one();
    }
}