use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Scope guard for the on-demand gate flag.
///
/// A guard is only produced by `OnDemandGateGuard::try_acquire`, after it has
/// atomically flipped the flag from `false` to `true`. Dropping the guard stores
/// `false` back, so at most one guard per flag is alive at any time.
///
/// `#[must_use]`: a guard that is not bound to a variable is dropped at the end
/// of the statement, which releases the gate immediately.
#[derive(Debug)]
#[must_use = "the gate is released as soon as the guard is dropped"]
pub(super) struct OnDemandGateGuard<'a> {
    /// The shared flag this guard holds; `true` while the guard is alive.
    gate: &'a Arc<AtomicBool>,
}
impl<'a> OnDemandGateGuard<'a> {
    /// Attempts to take the gate without blocking.
    ///
    /// Returns `Some(guard)` if `gate` was `false`. In that case it has been
    /// atomically set to `true` and stays held for as long as the guard lives.
    /// Returns `None` if `gate` was already `true`; its value is left unchanged.
    ///
    /// The result must be bound to a variable. Discarding it drops the guard
    /// immediately, which acquires and then releases the gate in one statement.
    #[must_use = "dropping the returned guard releases the gate immediately"]
    pub(super) fn try_acquire(gate: &'a Arc<AtomicBool>) -> Option<Self> {
        // `swap` is a single atomic read-modify-write, so of several concurrent
        // callers exactly one observes the previous value `false`. The acquire
        // half pairs with the `Release` store done when a guard is dropped.
        if gate.swap(true, Ordering::AcqRel) {
            None
        } else {
            Some(Self { gate })
        }
    }
}
impl Drop for OnDemandGateGuard<'_> {
fn drop(&mut self) {
self.gate.store(false, Ordering::Release);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn try_acquire_on_free_gate_succeeds_and_holds() {
        let gate = Arc::new(AtomicBool::new(false));
        let guard = OnDemandGateGuard::try_acquire(&gate).expect("free gate");
        assert!(gate.load(Ordering::Acquire), "gate held after acquire");
        drop(guard);
    }

    #[test]
    fn try_acquire_on_held_gate_returns_none() {
        let gate = Arc::new(AtomicBool::new(true));
        let opt = OnDemandGateGuard::try_acquire(&gate);
        assert!(opt.is_none(), "held gate must reject");
        assert!(gate.load(Ordering::Acquire), "gate still held");
    }

    #[test]
    fn drop_releases_gate_for_next_acquire() {
        let gate = Arc::new(AtomicBool::new(false));
        let guard = OnDemandGateGuard::try_acquire(&gate).expect("acquire 1");
        drop(guard);
        assert!(!gate.load(Ordering::Acquire), "gate released by Drop");
        let _guard2 = OnDemandGateGuard::try_acquire(&gate).expect("acquire 2 after release");
        assert!(gate.load(Ordering::Acquire), "gate re-held by acquire 2");
    }

    #[test]
    fn drop_on_panic_releases_gate() {
        let gate = Arc::new(AtomicBool::new(false));
        let gate_for_catch = Arc::clone(&gate);
        let result = std::panic::catch_unwind(move || {
            let _guard =
                OnDemandGateGuard::try_acquire(&gate_for_catch).expect("acquire before panic");
            panic!("simulated panic mid-critical-section");
        });
        assert!(result.is_err(), "catch_unwind must see the panic");
        assert!(
            !gate.load(Ordering::Acquire),
            "gate released even under panic unwind"
        );
    }

    #[test]
    fn sequential_acquire_drop_cycles() {
        let gate = Arc::new(AtomicBool::new(false));
        for _ in 0..10 {
            let guard = OnDemandGateGuard::try_acquire(&gate).expect("acquire in loop");
            assert!(gate.load(Ordering::Acquire));
            drop(guard);
            assert!(!gate.load(Ordering::Acquire));
        }
    }

    /// Source-level lint: `mod.rs` must never mutate the gate flag directly.
    ///
    /// Scans `mod.rs` for the gate field name followed by a mutating atomic
    /// method. Whitespace is skipped before the method, so a `.method(` chained
    /// onto the next line is caught too. Reads such as `.load(` are tolerated.
    /// Acquisition must go through `OnDemandGateGuard::try_acquire`, and release
    /// through the guard's `Drop`. Failure messages list the 1-based line numbers
    /// of the offending occurrences.
    #[test]
    fn no_bare_gate_atomic_mutations_in_mod_rs() {
        // Mutating `AtomicBool` methods other than `swap`/`store`. Calling any of
        // these on the gate field would bypass the guard just the same.
        const OTHER_MUTATORS: &[&str] = &[
            ".compare_exchange(",
            ".compare_exchange_weak(",
            ".fetch_and(",
            ".fetch_nand(",
            ".fetch_or(",
            ".fetch_xor(",
            ".fetch_update(",
        ];
        let src = include_str!("mod.rs");
        let needle = "freeze_coord_on_demand_in_flight";
        // 1-based line number of a byte offset into `src`, used for diagnostics.
        let line_of = |offset: usize| src[..offset].matches('\n').count() + 1;

        let mut bare_swap = Vec::new();
        let mut bare_store = Vec::new();
        let mut bare_other = Vec::new();
        let mut cursor = 0;
        while let Some(pos) = src[cursor..].find(needle) {
            let hit = cursor + pos;
            cursor = hit + needle.len();
            let tail = src[cursor..].trim_start();
            if tail.starts_with(".swap(") {
                bare_swap.push(line_of(hit));
            } else if tail.starts_with(".store(") {
                bare_store.push(line_of(hit));
            } else if OTHER_MUTATORS.iter().any(|m| tail.starts_with(*m)) {
                bare_other.push(line_of(hit));
            }
        }

        assert!(
            bare_swap.is_empty(),
            "mod.rs must not call freeze_coord_on_demand_in_flight.swap() — \
             use OnDemandGateGuard::try_acquire() instead (lines {:?})",
            bare_swap
        );
        assert!(
            bare_store.is_empty(),
            "mod.rs must not call freeze_coord_on_demand_in_flight.store() — \
             the guard's Drop is the only release path (lines {:?})",
            bare_store
        );
        assert!(
            bare_other.is_empty(),
            "mod.rs must not mutate freeze_coord_on_demand_in_flight directly \
             (compare_exchange/fetch_*) — use OnDemandGateGuard (lines {:?})",
            bare_other
        );
    }

    #[test]
    fn concurrent_try_acquire_is_mutually_exclusive() {
        use std::sync::atomic::AtomicUsize;
        use std::sync::Barrier;
        use std::thread;

        const THREADS: usize = 4;
        const ITERS: usize = 1000;

        let gate = Arc::new(AtomicBool::new(false));
        let in_critical = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));
        // Release all workers at once. Otherwise each thread could finish its
        // loop before the next one is even spawned, and the test would not
        // exercise contention.
        let start = Arc::new(Barrier::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let gate = Arc::clone(&gate);
                let in_critical = Arc::clone(&in_critical);
                let max_concurrent = Arc::clone(&max_concurrent);
                let start = Arc::clone(&start);
                thread::spawn(move || {
                    start.wait();
                    for _ in 0..ITERS {
                        if let Some(_guard) = OnDemandGateGuard::try_acquire(&gate) {
                            let cur = in_critical.fetch_add(1, Ordering::AcqRel) + 1;
                            max_concurrent.fetch_max(cur, Ordering::AcqRel);
                            // Leave the critical section before `_guard` drops
                            // at the end of this block.
                            in_critical.fetch_sub(1, Ordering::AcqRel);
                        }
                    }
                })
            })
            // Collect first so every thread is spawned before any join.
            .collect();

        for h in handles {
            h.join().expect("worker thread panicked");
        }
        assert_eq!(
            max_concurrent.load(Ordering::Acquire),
            1,
            "at most one guard alive at any time"
        );
    }
}