use std::{
collections::{hash_map::Entry, HashMap},
future::Future,
pin::Pin,
sync::{Arc, Condvar, Mutex},
task::{Context, Poll, Waker},
};
/// Blocking, thread-based turn sequencer.
///
/// Threads call [`TurnBasedSync::wait_for_turn`] with a ticket number and are released in
/// ticket order as [`TurnBasedSync::advance_turn`] increments the shared counter, which
/// starts at `0`.
pub struct TurnBasedSync {
    /// The turn that is currently allowed to run.
    pub current_turn: Mutex<usize>,
    /// Notified (with `notify_all`) every time `current_turn` is incremented.
    pub cv: Condvar,
}

impl TurnBasedSync {
    /// Creates a sequencer whose first turn is `0`.
    pub fn new() -> Self {
        TurnBasedSync { current_turn: Mutex::new(0), cv: Condvar::new() }
    }

    /// Blocks the calling thread until the current turn equals `my_turn`.
    ///
    /// The counter only ever increases, so if `my_turn` has already passed this call
    /// blocks forever.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    pub fn wait_for_turn(&self, my_turn: usize) {
        let guard = self.current_turn.lock().unwrap();
        // `wait_while` re-checks the predicate after every wake-up, so spurious wakeups and
        // notifications intended for other turns just put us back to sleep.
        let _guard = self.cv.wait_while(guard, |turn| *turn != my_turn).unwrap();
    }

    /// Returns a snapshot of the current turn number.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    pub fn current_turn(&self) -> usize {
        *self.current_turn.lock().unwrap()
    }

    /// Moves to the next turn and wakes all waiting threads.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    pub fn advance_turn(&self) {
        let mut turn = self.current_turn.lock().unwrap();
        *turn += 1;
        // Every waiter re-checks whether the new value is its own turn.
        self.cv.notify_all();
    }
}

impl Default for TurnBasedSync {
    /// Equivalent to [`TurnBasedSync::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Async turn sequencer: futures obtained from [`AsyncTurn::wait_for_turn`] resolve in
/// turn order, starting from turn `0`.
///
/// Cloning is cheap, and every clone shares the same turn counter.
#[derive(Clone)]
pub struct AsyncTurn {
    inner: Arc<Mutex<AsyncTurnInner>>,
}

impl AsyncTurn {
    /// Creates a sequencer whose first turn is `0`.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(AsyncTurnInner { current_turn: 0, wakers: HashMap::new() })),
        }
    }

    /// Returns a future that resolves to an [`AsyncTurnGuard`] once the current turn equals
    /// `my_turn`. Dropping that guard advances the shared counter to the next turn.
    pub fn wait_for_turn(&self, my_turn: usize) -> AsyncTurnFuture {
        AsyncTurnFuture { inner: Arc::clone(&self.inner), my_turn }
    }
}

impl Default for AsyncTurn {
    /// Equivalent to [`AsyncTurn::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Shared state behind an [`AsyncTurn`] and the futures and guards it hands out.
pub struct AsyncTurnInner {
    /// The turn that is currently allowed to proceed.
    current_turn: usize,
    /// Waker registered by a pending future, keyed by the turn it is waiting for.
    wakers: HashMap<usize, Waker>,
}
#[must_use = "Futures do nothing unless `await`ed"]
pub struct AsyncTurnFuture {
inner: Arc<Mutex<AsyncTurnInner>>,
my_turn: usize,
}
impl Future for AsyncTurnFuture {
type Output = AsyncTurnGuard;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let mut inner = this.inner.lock().expect("AsyncTurnFuture poisoned");
if inner.current_turn == this.my_turn {
return Poll::Ready(AsyncTurnGuard { inner: this.inner.clone() });
}
match inner.wakers.entry(this.my_turn) {
Entry::Vacant(v) => {
v.insert(cx.waker().clone());
}
Entry::Occupied(mut o) => {
let _ = o.insert(cx.waker().clone());
}
}
if inner.current_turn > this.my_turn {
#[cold]
#[inline(never)]
fn panic_turn_passed(turn: usize) -> ! {
panic!("AsyncTurnFuture: turn {turn} has already passed");
}
panic_turn_passed(this.my_turn);
} else {
Poll::Pending
}
}
}
/// Held by the task whose turn it currently is; dropping it hands the turn to the next one.
pub struct AsyncTurnGuard {
    inner: Arc<Mutex<AsyncTurnInner>>,
}

impl Drop for AsyncTurnGuard {
    /// Advances the shared counter by one. If a future has registered a waker for the new
    /// turn, this removes that waker and wakes it.
    fn drop(&mut self) {
        let next_waker = {
            // Recover from poisoning instead of panicking. A panic inside `drop` while the
            // thread is already unwinding aborts the process. Every critical section on this
            // state makes single self-contained updates (a counter bump, one map
            // insert/remove), so the data is still consistent after a poisoning panic.
            let mut state = self.inner.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            state.current_turn += 1;
            let next_turn = state.current_turn;
            // Take the waker out so the map does not keep it once it has been used.
            state.wakers.remove(&next_turn)
        };
        // Wake only after the lock is released, so a `wake` implementation that re-enters
        // this mutex (e.g. by polling the woken future inline) cannot deadlock.
        if let Some(waker) = next_waker {
            waker.wake();
        }
    }
}