use std::hint::spin_loop;
use std::sync::atomic::{AtomicUsize, Ordering};
/// High bit of `MaintenanceGate::state`: set while a writer is waiting for
/// readers to drain or holds the gate exclusively.
const WRITE_BIT: usize = 1usize << (usize::BITS - 1);
/// Remaining low bits of `MaintenanceGate::state`: the number of live shared
/// guards.
const COUNT_MASK: usize = WRITE_BIT - 1;

/// A spin-based reader/writer gate packed into a single atomic word.
///
/// Any number of [`MaintenanceReadGuard`]s may be held at once, or exactly one
/// [`MaintenanceWriteGuard`]. An exclusive entrant sets `WRITE_BIT` *before*
/// waiting for existing readers to leave, so new shared entrants are turned
/// away as soon as a writer is pending (writer preference).
///
/// Waiting never parks the thread. Waiters spin briefly and then call
/// [`std::thread::yield_now`] on every further round.
///
/// The gate is not reentrant:
/// - A thread that holds a shared guard and then calls
///   [`enter_exclusive`](Self::enter_exclusive) never returns, because the
///   writer waits for that same guard to be dropped.
/// - Re-entering [`enter_shared`](Self::enter_shared) while already holding a
///   shared guard blocks forever if another thread has a writer pending.
#[derive(Debug)]
pub(crate) struct MaintenanceGate {
    /// `WRITE_BIT` plus a reader count in the `COUNT_MASK` bits.
    state: AtomicUsize,
}

/// Equivalent to [`MaintenanceGate::new`].
impl Default for MaintenanceGate {
    fn default() -> Self {
        Self::new()
    }
}

impl MaintenanceGate {
    /// Busy-wait rounds a waiter performs before it starts yielding its time
    /// slice on every further round.
    const SPIN_LIMIT: u32 = 64;

    /// Creates an open gate with no readers and no writer.
    #[must_use]
    pub(crate) const fn new() -> Self {
        Self {
            state: AtomicUsize::new(0),
        }
    }

    /// Enters the gate in shared mode.
    ///
    /// Waits while a writer is pending or active. The returned guard leaves
    /// the gate when dropped.
    pub(crate) fn enter_shared(&self) -> MaintenanceReadGuard<'_> {
        let mut spins = 0;
        loop {
            let cur = self.state.load(Ordering::Acquire);
            // Writer preference: stay out while WRITE_BIT is set. Also refuse
            // to let the reader count carry into WRITE_BIT.
            if cur & WRITE_BIT != 0 || cur & COUNT_MASK == COUNT_MASK {
                Self::relax(&mut spins);
                continue;
            }
            // The CAS only succeeds if `state` is still exactly `cur`, so a
            // reader can never slip in after a writer has set WRITE_BIT. The
            // weak variant may fail spuriously; the loop simply retries.
            if self
                .state
                .compare_exchange_weak(cur, cur + 1, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                return MaintenanceReadGuard { gate: self };
            }
        }
    }

    /// Enters the gate exclusively.
    ///
    /// First claims `WRITE_BIT`, waiting for any other writer to finish. This
    /// stops new shared entries. It then waits until every existing shared
    /// guard has been dropped. The returned guard reopens the gate when
    /// dropped.
    pub(crate) fn enter_exclusive(&self) -> MaintenanceWriteGuard<'_> {
        let mut spins = 0;
        loop {
            let cur = self.state.load(Ordering::Acquire);
            if cur & WRITE_BIT != 0 {
                // Another writer is pending or active.
                Self::relax(&mut spins);
                continue;
            }
            if self
                .state
                .compare_exchange_weak(cur, cur | WRITE_BIT, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }
        }
        // WRITE_BIT is ours, so the reader count can now only go down. This
        // Acquire load pairs with the Release `fetch_sub` in `leave_shared`.
        while self.state.load(Ordering::Acquire) & COUNT_MASK != 0 {
            Self::relax(&mut spins);
        }
        MaintenanceWriteGuard { gate: self }
    }

    /// Backs off before a waiter re-reads `state`.
    ///
    /// The first `SPIN_LIMIT` calls (counted in `spins`) issue a CPU spin hint.
    /// After that, each call yields the time slice to the OS scheduler, so a
    /// waiter stuck behind a long critical section does not monopolize a
    /// core.
    fn relax(spins: &mut u32) {
        if *spins < Self::SPIN_LIMIT {
            *spins += 1;
            spin_loop();
        } else {
            std::thread::yield_now();
        }
    }

    /// Releases one shared entry. Called from `MaintenanceReadGuard`'s `Drop`.
    fn leave_shared(&self) {
        // Release publishes this reader's critical section to the writer that
        // later observes the count reach zero.
        self.state.fetch_sub(1, Ordering::Release);
    }

    /// Reopens the gate after exclusive use. Called from
    /// `MaintenanceWriteGuard`'s `Drop`.
    fn leave_exclusive(&self) {
        // A plain store is enough. While WRITE_BIT is set, no reader can
        // increment the count, because their CAS requires the bit clear.
        // `enter_exclusive` also already waited for the count to reach zero,
        // so `state` is exactly WRITE_BIT here.
        self.state.store(0, Ordering::Release);
    }

    /// Returns `true` while `WRITE_BIT` is set. That is, a writer is either
    /// waiting for readers to drain or currently holds the gate.
    #[cfg(test)]
    fn writer_pending_for_test(&self) -> bool {
        self.state.load(Ordering::Acquire) & WRITE_BIT != 0
    }
}

/// Shared entry into a [`MaintenanceGate`]; leaves the gate when dropped.
#[derive(Debug)]
#[must_use = "the gate is left as soon as the guard is dropped"]
pub(crate) struct MaintenanceReadGuard<'a> {
    gate: &'a MaintenanceGate,
}

impl Drop for MaintenanceReadGuard<'_> {
    fn drop(&mut self) {
        self.gate.leave_shared();
    }
}

/// Exclusive entry into a [`MaintenanceGate`]; reopens the gate when dropped.
#[derive(Debug)]
#[must_use = "the gate is reopened as soon as the guard is dropped"]
pub(crate) struct MaintenanceWriteGuard<'a> {
    gate: &'a MaintenanceGate,
}

impl Drop for MaintenanceWriteGuard<'_> {
    fn drop(&mut self) {
        self.gate.leave_exclusive();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::sync_channel;
    use std::sync::Arc;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Waits until `gate` reports `WRITE_BIT` set.
    ///
    /// Fails the test, instead of hanging forever, if no writer shows up
    /// within `timeout`.
    fn wait_for_writer_pending(gate: &MaintenanceGate, timeout: Duration) {
        let deadline = Instant::now() + timeout;
        while !gate.writer_pending_for_test() {
            assert!(
                Instant::now() < deadline,
                "exclusive entrant never set the write bit within {:?}",
                timeout
            );
            thread::yield_now();
        }
    }

    #[test]
    fn exclusive_waits_for_shared_guard() {
        let gate = Arc::new(MaintenanceGate::new());
        let shared = gate.enter_shared();
        let worker_gate = Arc::clone(&gate);
        let (started_tx, started_rx) = sync_channel(0);
        let (done_tx, done_rx) = sync_channel(0);
        let handle = thread::spawn(move || {
            started_tx.send(()).unwrap();
            let _exclusive = worker_gate.enter_exclusive();
            done_tx.send(()).unwrap();
        });
        started_rx.recv().unwrap();
        // Negative check: the writer must still be blocked on `shared`
        // during this window.
        assert!(done_rx.recv_timeout(Duration::from_millis(50)).is_err());
        drop(shared);
        done_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn pending_exclusive_blocks_new_shared_entries() {
        let gate = Arc::new(MaintenanceGate::new());
        let shared = gate.enter_shared();
        let exclusive_gate = Arc::clone(&gate);
        let (exclusive_started_tx, exclusive_started_rx) = sync_channel(0);
        let (release_tx, release_rx) = sync_channel(0);
        let exclusive = thread::spawn(move || {
            exclusive_started_tx.send(()).unwrap();
            let _exclusive = exclusive_gate.enter_exclusive();
            release_rx.recv().unwrap();
        });
        exclusive_started_rx.recv().unwrap();
        // Bounded, so a regression fails the test instead of hanging it.
        wait_for_writer_pending(&gate, Duration::from_secs(1));
        let shared_gate = Arc::clone(&gate);
        let (shared_done_tx, shared_done_rx) = sync_channel(0);
        let shared_waiter = thread::spawn(move || {
            let _shared = shared_gate.enter_shared();
            shared_done_tx.send(()).unwrap();
        });
        drop(shared);
        assert!(
            shared_done_rx
                .recv_timeout(Duration::from_millis(50))
                .is_err(),
            "new shared entrant must wait behind pending exclusive"
        );
        release_tx.send(()).unwrap();
        exclusive.join().unwrap();
        shared_done_rx.recv_timeout(Duration::from_secs(1)).unwrap();
        shared_waiter.join().unwrap();
    }
}