use std::{
future::Future,
hash::Hash,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc
},
task::{Context, Poll, Waker},
thread,
time::Duration
};
use parking_lot::{Condvar, Mutex};
use hashbrown::{HashMap, HashSet};
/// Mutable state of a checkpoint set; `Shared` keeps it behind a mutex.
struct Inner<T> {
    /// Checkpoints that have been recorded so far.
    checkpoints: HashSet<T>,

    /// Wakers registered by pending futures, keyed by a numeric id.
    wakers: HashMap<usize, Waker>,
}
impl<T> Inner<T> {
    /// Creates a state with no recorded checkpoints and no registered
    /// wakers.
    fn new() -> Self {
        let checkpoints = HashSet::new();
        let wakers = HashMap::new();
        Self { checkpoints, wakers }
    }
}
/// Everything that sits behind the `Arc` held by `Checkpoint` handles and
/// by the futures they hand out.
struct Shared<T> {
    /// Recorded checkpoints and registered wakers, guarded by a mutex.
    inner: Mutex<Inner<T>>,

    /// Optional pause; when set, `Checkpoint::reached` and
    /// `Checkpoint::waitfor` sleep for this long before returning.
    delay: Option<Duration>,

    /// Condition variable on which blocking waiters park.
    signal: Condvar,

    /// Counter from which the keys of `Inner::wakers` are drawn.
    idgen: AtomicUsize,
}
impl<T> Shared<T> {
    /// Builds the shared state around `inner`.
    ///
    /// `delay` is stored unchanged; the id counter starts at zero.
    fn new(inner: Inner<T>, delay: Option<Duration>) -> Self {
        let inner = Mutex::new(inner);
        let signal = Condvar::new();
        let idgen = AtomicUsize::new(0);
        Self {
            inner,
            delay,
            signal,
            idgen,
        }
    }

    /// Notifies every waiter that the set of checkpoints may have changed.
    ///
    /// Threads parked on the condition variable are notified first, then
    /// every registered waker is removed from `inner` and woken. The caller
    /// passes in the state it already has locked.
    fn wake_all(&self, inner: &mut Inner<T>) {
        self.signal.notify_all();
        inner.wakers.drain().for_each(|(_, waker)| waker.wake());
    }
}
/// Handle to a shared set of reached checkpoints.
///
/// The handle wraps an `Arc`, so clones refer to the same underlying state.
pub struct Checkpoint<T>(Arc<Shared<T>>);

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, but cloning a handle only bumps the reference count and
// never clones a `T`.
impl<T> Clone for Checkpoint<T> {
    /// Returns another handle to the same shared state.
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
impl<T> Checkpoint<T>
where
T: PartialEq + Eq + Hash
{
pub fn new() -> Self {
Self::default()
}
pub fn with_delay(dur: Duration) -> Self {
let inner = Inner::new();
let shared = Shared::new(inner, Some(dur));
Self(Arc::new(shared))
}
pub fn reached(&self, checkpoint: T) -> &Self {
let mut inner = self.0.inner.lock();
inner.checkpoints.insert(checkpoint);
self.0.wake_all(&mut inner);
drop(inner);
if let Some(dur) = self.0.delay {
thread::sleep(dur);
}
self
}
pub fn waitfor<I>(&self, checkpoints: I) -> &Self
where
I: IntoIterator<Item = T>
{
let checkpoints = HashSet::from_iter(checkpoints);
let mut inner = self.0.inner.lock();
while !checkpoints.is_subset(&inner.checkpoints) {
self.0.signal.wait(&mut inner);
}
drop(inner);
if let Some(dur) = self.0.delay {
thread::sleep(dur);
}
self
}
pub fn async_waitfor<I>(&self, checkpoints: I) -> WaitForCheckpoints<T>
where
I: IntoIterator<Item = T>
{
WaitForCheckpoints {
sh: Arc::clone(&self.0),
checkpoints: HashSet::from_iter(checkpoints),
id: None
}
}
}
impl<T> Default for Checkpoint<T> {
    /// Creates an empty checkpoint set with no delay configured.
    fn default() -> Self {
        Self(Arc::new(Shared::new(Inner::new(), None)))
    }
}
/// Future created by `Checkpoint::async_waitfor`.
///
/// It becomes ready once all of its `checkpoints` are present in the shared
/// set of reached checkpoints.
pub struct WaitForCheckpoints<T>
where
    T: PartialEq + Eq + Hash,
{
    /// State shared with the `Checkpoint` this future was created from.
    sh: Arc<Shared<T>>,

    /// Checkpoints that must all be reached before the future completes.
    checkpoints: HashSet<T>,

    /// Key under which this future's waker was stored, once it has returned
    /// `Pending` from a poll.
    id: Option<usize>,
}
impl<T> Future for WaitForCheckpoints<T>
where
    T: PartialEq + Eq + Hash + Unpin,
{
    type Output = ();

    /// Completes once every awaited checkpoint has been reached.
    ///
    /// While the checkpoints are not all reached, the waker from `ctx` is
    /// stored in the shared waker map and `Poll::Pending` is returned. The
    /// future keeps a single slot in that map: a repeated poll overwrites
    /// the waker stored by the previous poll instead of adding a new entry,
    /// so stale wakers do not pile up when the future is polled more than
    /// once between wake-ups.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Self` is `Unpin`, so a plain `&mut Self` is available; it lets the
        // lock guard (borrowing `sh`) coexist with updates to `id`.
        let this = self.get_mut();
        let mut inner = this.sh.inner.lock();

        if this.checkpoints.is_subset(&inner.checkpoints) {
            // Give up the slot now so that `Drop` has nothing left to clean
            // up and no waker for a finished future lingers in the map.
            if let Some(id) = this.id.take() {
                inner.wakers.remove(&id);
            }
            return Poll::Ready(());
        }

        let id = match this.id {
            // Already registered by an earlier poll: reuse the same key so
            // the insert below replaces the previous waker.
            Some(id) => id,
            // First pending poll: draw ids until one is not in use.
            None => loop {
                let candidate = this.sh.idgen.fetch_add(1, Ordering::SeqCst);
                if !inner.wakers.contains_key(&candidate) {
                    break candidate;
                }
            },
        };
        inner.wakers.insert(id, ctx.waker().clone());
        drop(inner);

        this.id = Some(id);
        Poll::Pending
    }
}
impl<T> Drop for WaitForCheckpoints<T>
where
T: Eq + Hash
{
fn drop(&mut self) {
if let Some(id) = self.id {
let mut inner = self.sh.inner.lock();
inner.wakers.remove(&id);
}
}
}