use std::{
future::Future,
num::NonZeroUsize,
pin::Pin,
sync::atomic::Ordering,
sync::Arc,
task::{Context, Poll}
};
use super::Shared;
/// Cloneable handle to the shared kill state.
///
/// Call [`KillWait::wait`] to obtain a future that resolves once the shared `triggered` flag
/// is set. All clones refer to the same `Shared` state through one `Arc`.
#[derive(Clone)]
pub struct KillWait {
    /// Shared kill state; a clone of this `Arc` is also held by every future from `wait`.
    pub(super) ctx: Arc<Shared>,
}

impl KillWait {
    /// Returns a future that resolves once the shared `triggered` flag is set.
    ///
    /// The future is lazy: it registers its waker in the shared state only when it is first
    /// polled and returns `Pending`.
    pub fn wait(&self) -> KillWaitFuture {
        KillWaitFuture {
            ctx: Arc::clone(&self.ctx),
            id: None,
        }
    }
}
impl Drop for KillWait {
fn drop(&mut self) {
let mut state = self.ctx.state.lock();
if Arc::<super::Shared>::strong_count(&self.ctx) == 2
&& state.waker.is_some()
{
if let Some(waker) = state.waker.take() {
waker.wake();
}
}
}
}
/// Future returned by `KillWait::wait`; resolves once the shared `triggered` flag is set.
///
/// While pending, it keeps a waker registered in the shared state's `waiting` map. The entry
/// is removed when the future is dropped.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct KillWaitFuture {
    /// Shared kill state this future watches.
    ctx: Arc<Shared>,
    /// Key of this future's waker in `state.waiting`; `None` until the first pending poll.
    id: Option<NonZeroUsize>,
}
impl Future for KillWaitFuture {
    type Output = ();

    /// Completes once the shared `triggered` flag is set.
    ///
    /// On the first pending poll, this allocates a non-zero id via `Shared::id` that is not
    /// already a key in `state.waiting`, and registers the task's waker under it. Later
    /// pending polls refresh the stored waker, so the waker from the most recent poll is the
    /// one that gets woken, as the `Future::poll` contract requires.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Fast path: once triggered, there is no need to take the lock.
        if self.ctx.triggered.load(Ordering::SeqCst) {
            return Poll::Ready(());
        }

        // `KillWaitFuture` is `Unpin` (all its fields are), so a plain `&mut` is available.
        // It lets the lock guard (borrowing `ctx`) coexist with the write to `id` below.
        let this = &mut *self;
        let mut state = this.ctx.state.lock();

        // Re-check under the lock. Presumably the trigger side wakes the wakers in `waiting`.
        // If it ran after the load above but before we locked, a waker registered now would
        // never be woken.
        if this.ctx.triggered.load(Ordering::SeqCst) {
            return Poll::Ready(());
        }

        match this.id {
            Some(id) => {
                // Already registered: the task (and thus the waker) may have changed since.
                let waker = cx.waker();
                match state.waiting.get_mut(&id.get()) {
                    Some(stored) => {
                        if !stored.will_wake(waker) {
                            *stored = waker.clone();
                        }
                    }
                    None => {
                        // Entry vanished without a trigger; re-register so we can still be woken.
                        state.waiting.insert(id.get(), waker.clone());
                    }
                }
            }
            None => {
                // Skip 0 (not representable as `NonZeroUsize`) and ids already in use.
                let id = loop {
                    if let Some(candidate) = NonZeroUsize::new(this.ctx.id()) {
                        if !state.waiting.contains_key(&candidate.get()) {
                            break candidate;
                        }
                    }
                };
                state.waiting.insert(id.get(), cx.waker().clone());
                this.id = Some(id);
            }
        }
        Poll::Pending
    }
}
impl Drop for KillWaitFuture {
    /// Unregisters this future's waker from `state.waiting`.
    ///
    /// A future that was never polled to `Pending` never registered, so it has nothing to
    /// remove and does not touch the lock.
    fn drop(&mut self) {
        let registered = match self.id {
            Some(id) => id.get(),
            None => return,
        };
        self.ctx.state.lock().waiting.remove(&registered);
    }
}