use std::sync::{Condvar, Mutex};
use std::task::Waker;
use std::time::{Duration, Instant};
#[cfg(feature = "async")]
use std::future::Future;
#[cfg(feature = "async")]
use std::pin::Pin;
#[cfg(feature = "async")]
use std::sync::Arc;
#[cfg(feature = "async")]
use std::task::{Context, Poll};
use crate::{Error, ExitStatus, Result};
/// One-shot exit outcome that can be waited on by blocking threads and async tasks.
///
/// Blocking waiters sleep on `condvar`. Async waiters park their `Waker`s inside the
/// pending state. Once an outcome is recorded, later notifications are ignored.
pub(crate) struct ExitNotifier {
    /// Current outcome (or pending waiters); always accessed under this lock.
    state: Mutex<ExitNotifierState>,
    /// Signalled (`notify_all`) whenever the state leaves `Pending`.
    condvar: Condvar,
}
/// Lifecycle of an [`ExitNotifier`]: it starts `Pending` and moves to exactly one
/// terminal variant.
enum ExitNotifierState {
    /// No outcome yet.
    Pending {
        /// Next registration id handed to an async waiter.
        #[cfg_attr(
            not(feature = "async"),
            expect(
                dead_code,
                reason = "key allocator only used by the async ExitWaitFuture path"
            )
        )]
        next_key: u64,
        /// Async waiters to wake on resolution, tagged with their registration id.
        wakers: Vec<(u64, Waker)>,
    },
    /// The exit was observed with this status.
    Done(ExitStatus),
    /// The exit status could not be obtained; waiters receive an error instead.
    Failed,
}
fn failure_error() -> Error {
Error::ExitStatus(tastty::Error::ExitStatusUnavailable)
}
impl ExitNotifier {
pub(crate) fn new() -> Self {
Self {
state: Mutex::new(ExitNotifierState::Pending {
next_key: 0,
wakers: Vec::new(),
}),
condvar: Condvar::new(),
}
}
pub(crate) fn notify_exit(&self, status: ExitStatus) {
let wakers = {
let mut state = self
.state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match std::mem::replace(&mut *state, ExitNotifierState::Done(status)) {
ExitNotifierState::Pending { wakers, .. } => wakers,
other => {
*state = other;
return;
}
}
};
self.condvar.notify_all();
for (_, waker) in wakers {
waker.wake();
}
}
pub(crate) fn notify_failure(&self) {
let wakers = {
let mut state = self
.state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
match std::mem::replace(&mut *state, ExitNotifierState::Failed) {
ExitNotifierState::Pending { wakers, .. } => wakers,
other => {
*state = other;
return;
}
}
};
self.condvar.notify_all();
for (_, waker) in wakers {
waker.wake();
}
}
pub(crate) fn wait_blocking(&self) -> Result<ExitStatus> {
let mut guard = self
.state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
loop {
match &*guard {
ExitNotifierState::Done(status) => return Ok(*status),
ExitNotifierState::Failed => return Err(failure_error()),
ExitNotifierState::Pending { .. } => {
guard = self.condvar.wait(guard).unwrap_or_else(|p| p.into_inner());
}
}
}
}
pub(crate) fn wait_blocking_timeout(&self, timeout: Duration) -> Option<Result<ExitStatus>> {
let deadline = Instant::now().checked_add(timeout);
let mut guard = self
.state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
loop {
match &*guard {
ExitNotifierState::Done(status) => return Some(Ok(*status)),
ExitNotifierState::Failed => return Some(Err(failure_error())),
ExitNotifierState::Pending { .. } => {
let remaining = match deadline {
Some(deadline) => deadline.saturating_duration_since(Instant::now()),
None => Duration::MAX,
};
if remaining.is_zero() {
return None;
}
let (next, result) = self
.condvar
.wait_timeout(guard, remaining)
.unwrap_or_else(|p| p.into_inner());
guard = next;
if result.timed_out() && matches!(*guard, ExitNotifierState::Pending { .. }) {
return None;
}
}
}
}
}
}
#[cfg(feature = "async")]
impl ExitNotifier {
    /// Polls the outcome on behalf of a single async waiter.
    ///
    /// `key_slot` carries the waiter's registration id between polls:
    /// - If it names a live registration, the stored waker is refreshed (only when it
    ///   would not wake the same task).
    /// - Otherwise a fresh id is allocated, the waker is registered, and the id is
    ///   written back into `key_slot`.
    fn poll_status(&self, key_slot: &mut Option<u64>, waker: &Waker) -> Poll<Result<ExitStatus>> {
        let mut guard = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let (next_key, wakers) = match &mut *guard {
            ExitNotifierState::Done(status) => return Poll::Ready(Ok(*status)),
            ExitNotifierState::Failed => return Poll::Ready(Err(failure_error())),
            ExitNotifierState::Pending { next_key, wakers } => (next_key, wakers),
        };

        if let Some(key) = *key_slot {
            if let Some((_, registered)) = wakers.iter_mut().find(|(id, _)| *id == key) {
                if !registered.will_wake(waker) {
                    *registered = waker.clone();
                }
                return Poll::Pending;
            }
        }

        // No live registration for this waiter yet: allocate one.
        let fresh = *next_key;
        *next_key += 1;
        wakers.push((fresh, waker.clone()));
        *key_slot = Some(fresh);
        Poll::Pending
    }

    /// Removes the waker registered under `key`, if the notifier is still pending.
    ///
    /// Once resolved, the waker list has already been drained, so there is nothing to do.
    fn unregister(&self, key: u64) {
        let mut guard = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        match &mut *guard {
            ExitNotifierState::Pending { wakers, .. } => wakers.retain(|&(id, _)| id != key),
            ExitNotifierState::Done(_) | ExitNotifierState::Failed => {}
        }
    }
}
#[cfg(feature = "async")]
/// Future resolving to the outcome recorded by an [`ExitNotifier`].
pub(crate) struct ExitWaitFuture {
    /// Shared notifier this future waits on.
    notifier: Arc<ExitNotifier>,
    /// Registration id of this future's waker, if one is currently registered.
    key: Option<u64>,
    /// Set once the future has returned `Poll::Ready`.
    done: bool,
}
#[cfg(feature = "async")]
impl ExitWaitFuture {
    /// Creates a future that waits on `notifier`.
    ///
    /// No waker is registered until the first poll.
    pub(crate) fn new(notifier: Arc<ExitNotifier>) -> Self {
        let key = None;
        let done = false;
        Self {
            notifier,
            key,
            done,
        }
    }
}
#[cfg(feature = "async")]
impl Future for ExitWaitFuture {
    type Output = Result<ExitStatus>;

    /// Polls the notifier.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let fut = self.get_mut();
        assert!(!fut.done, "ExitWaitFuture polled after completion");
        let outcome = fut.notifier.poll_status(&mut fut.key, cx.waker());
        if outcome.is_ready() {
            // The notifier drained its waker list on resolution, so the key is stale.
            fut.done = true;
            fut.key = None;
        }
        outcome
    }
}
#[cfg(feature = "async")]
impl Drop for ExitWaitFuture {
    /// Withdraws this future's waker registration, if it still has one, so an abandoned
    /// wait does not leave an entry in the notifier's waker list.
    fn drop(&mut self) {
        let Some(key) = self.key else {
            return;
        };
        self.notifier.unregister(key);
    }
}