use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::task::Context;
use crate::internal::Mutex;
use crate::internal::WaitSet;
/// Shared state behind a countdown: a remaining count plus the set of tasks
/// waiting for that count to reach zero.
///
/// `decrement` lowers the count (saturating at zero), `spin_wait` busy-polls
/// for completion, and waiters park themselves via `register_waker` until
/// `wake_all` is called.
#[derive(Debug)]
pub(crate) struct CountdownState {
    /// Remaining count; `0` means the countdown has completed.
    state: AtomicU32,
    /// Registered waiters. Guarded by a mutex so that `register_waker` and
    /// `wake_all` never run concurrently.
    waiters: Mutex<WaitSet>,
}

impl CountdownState {
    /// Creates a countdown that starts at `count`.
    ///
    /// A `count` of `0` produces an already-completed countdown: `state()`
    /// reads `0`, `spin_wait` returns `Ok(())` immediately and `decrement`
    /// returns `false`.
    pub(crate) const fn new(count: u32) -> Self {
        Self {
            state: AtomicU32::new(count),
            waiters: Mutex::new(WaitSet::new()),
        }
    }

    /// Returns the current remaining count.
    ///
    /// The load is `Acquire`, so a caller that observes `0` also observes the
    /// writes published by the releasing updates that brought the count there.
    pub(crate) fn state(&self) -> u32 {
        self.state.load(Ordering::Acquire)
    }

    /// Attempts to replace the count `current` with `new`.
    ///
    /// Returns `Ok(())` if the swap happened, or `Err(actual)` carrying the
    /// value observed instead. Success is `Release`, failure is `Relaxed`.
    ///
    /// This is a *weak* compare-and-swap: it may fail spuriously even when the
    /// stored value equals `current` (then `actual == current`). Callers must
    /// retry in a loop rather than treat one failure as proof that the count
    /// changed.
    pub(crate) fn cas_state(&self, current: u32, new: u32) -> Result<(), u32> {
        self.state
            .compare_exchange_weak(current, new, Ordering::Release, Ordering::Relaxed)
            .map(|_| ())
    }

    /// Locks the wait set and forwards to `WaitSet::wake_all`.
    ///
    /// NOTE(review): presumably wakes every registered waker — confirm
    /// against `WaitSet`.
    pub(crate) fn wake_all(&self) {
        self.waiters.lock().wake_all();
    }

    /// Registers the waker from `cx` in the wait set (under the lock).
    ///
    /// `idx` is passed straight through to `WaitSet::register_waker`;
    /// presumably it records the waiter's slot so a later call can update the
    /// same entry — TODO confirm against `WaitSet`.
    ///
    /// NOTE(review): `wake_all` can run between a caller's check of `state()`
    /// and this registration; callers should re-check `state()` after
    /// registering so that wakeup is not lost — verify at call sites.
    pub(crate) fn register_waker(&self, idx: &mut Option<usize>, cx: &mut Context<'_>) {
        self.waiters.lock().register_waker(idx, cx);
    }

    /// Busy-waits for the countdown to reach zero.
    ///
    /// Polls `state()` up to `n` times, issuing a spin-loop hint after each
    /// non-zero read, then checks once more (so `n == 0` is a single check).
    ///
    /// # Errors
    ///
    /// Returns `Err(count)` with the last observed, non-zero count if the
    /// countdown has still not completed after the final check.
    pub(crate) fn spin_wait(&self, n: usize) -> Result<(), u32> {
        for _ in 0..n {
            if self.state() == 0 {
                return Ok(());
            }
            std::hint::spin_loop();
        }
        match self.state() {
            0 => Ok(()),
            remaining => Err(remaining),
        }
    }

    /// Lowers the count by `n`, saturating at zero.
    ///
    /// Returns `true` only for the call whose update moved the count from a
    /// non-zero value to zero. Returns `false` if the countdown was already
    /// finished (it is left at zero), if the count is still positive
    /// afterwards, or if `n == 0`.
    pub(crate) fn decrement(&self, n: u32) -> bool {
        // Subtracting zero can never complete the countdown; skip the
        // contended read-modify-write on the shared counter.
        if n == 0 {
            return false;
        }
        // `AcqRel` on success releases this caller's prior writes to anyone
        // who later `Acquire`-loads zero, and lets the completing caller
        // acquire earlier decrementers' writes. It is also required for
        // portability: Rust before 1.64 panics when a CAS failure ordering
        // (`Acquire` here) is stronger than its success ordering.
        let update = self
            .state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cnt| {
                if cnt == 0 {
                    // Already finished: leave the count untouched.
                    None
                } else {
                    Some(cnt.saturating_sub(n))
                }
            });
        match update {
            // The stored value became `prev - n` saturated at zero, which is
            // zero exactly when `prev <= n`.
            Ok(prev) => prev <= n,
            Err(_) => false,
        }
    }
}