use crate::{ChildName, Key};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use indexmap::IndexMap;
use pin_project::{pin_project, pinned_drop};
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::time::Duration;
use tokio::task::{AbortHandle, JoinError, JoinHandle};
use tokio::time::Instant;
use tokio_util::sync::WaitForCancellationFutureOwned;
use crate::error::{RestartReason, TerminationReason};
use crate::internal::{ChildConfig, ConcretePolicy};
use crate::signals::{SettlingToken, ShutdownToken, TerminationToken};
/// Book-keeping the nursery holds for a single child: its current
/// [`Liveness`], the [`ChildConfig`] it was inserted with, and the
/// [`TerminationToken`] that is signalled once the child's task has finished.
pub struct State(Liveness, ChildConfig, TerminationToken);
/// How far along the shutdown path a child currently is.
pub enum Liveness {
    /// The child has not been asked to settle. Carries the token used to
    /// signal settlement and the abort handle of the child's task.
    Free(SettlingToken, AbortHandle),
    /// Settlement has been signalled. Carries the task's abort handle and the
    /// deadline that was computed when settling began.
    Settling(AbortHandle, Instant),
    /// An abort has been requested for the child's task. Also used as the
    /// short-lived placeholder while the state is being swapped.
    Aborted,
}

impl Liveness {
    /// Projects this state onto its payload-free [`Burden`] counterpart.
    pub fn burden(&self) -> Burden {
        match *self {
            Self::Free(..) => Burden::Free,
            Self::Settling(..) => Burden::Settling,
            Self::Aborted => Burden::Aborted,
        }
    }
}
/// Payload-free summary of a child's [`Liveness`], cheap to copy and compare.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Burden {
    /// The child was never asked to settle.
    Free,
    /// The child had been asked to settle.
    Settling,
    /// An abort had been requested for the child.
    Aborted,
}

impl Burden {
    /// Returns the fixed label used for this burden in tracing output.
    ///
    /// Note that [`Burden::Free`] is reported as `"none"`.
    pub fn tracing_name(&self) -> &'static str {
        // `Burden` is `Copy`, so matching on the dereferenced value is free.
        match *self {
            Self::Free => "none",
            Self::Settling => "settling",
            Self::Aborted => "aborted",
        }
    }
}
/// Wrapper around a child's [`JoinHandle`] whose pinned-drop implementation
/// aborts the task, so discarding the wrapper does not leave the task running
/// detached.
#[pin_project(PinnedDrop)]
struct ChildHandle(#[pin] JoinHandle<()>);
// Dropping a `ChildHandle` requests cancellation of the task it wraps.
#[pinned_drop]
impl PinnedDrop for ChildHandle {
    fn drop(self: Pin<&mut Self>) {
        // `abort` returns nothing, so there is no result to discard.
        let join_handle = self.project().0;
        join_handle.abort();
    }
}
impl Future for ChildHandle {
    type Output = Result<(), JoinError>;

    /// Polls the wrapped join handle and forwards its result unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
        let join_handle = self.project().0;
        join_handle.poll(cx)
    }
}
/// Pairs a future with a key so that, on completion, the key is yielded
/// together with the future's output.
#[pin_project]
struct Annotated<K, F> {
/// The key to hand out on completion; `None` once it has been yielded.
key: Option<K>,
/// The wrapped future.
#[pin]
future: F,
}
impl<K, F> Annotated<K, F> {
    /// Wraps `future` so that it resolves to `(key, output)`.
    pub fn new(key: K, future: F) -> Self {
        let key = Some(key);
        Self { key, future }
    }
}
impl<K, F: Future> Future for Annotated<K, F> {
    type Output = (K, F::Output);

    /// Polls the inner future and, once it is ready, yields its output paired
    /// with the key. After the key has been handed out, every further poll
    /// reports `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
        use std::task::Poll;

        let this = self.project();
        // The key is gone only after a previous poll already completed.
        if this.key.is_none() {
            return Poll::Pending;
        }
        match this.future.poll(cx) {
            Poll::Ready(output) => {
                let key = this.key.take().unwrap();
                Poll::Ready((key, output))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Tracks a set of child tasks keyed by `K`, together with the futures that
/// drive their shutdown and report their termination.
pub struct Nursery<K> {
/// Per-child state, stored in an insertion-ordered map.
state: IndexMap<K, State>,
/// One future per inserted child; resolves when that child's shutdown token
/// is cancelled, which causes the child to be settled.
shutdown_requests: FuturesUnordered<Annotated<K, WaitForCancellationFutureOwned>>,
/// One timer per child that has started settling; when it fires the child is
/// aborted.
settling_timeouts: FuturesUnordered<Annotated<K, tokio::time::Sleep>>,
/// The children's join handles; a completed handle produces a termination
/// report.
handles: FuturesUnordered<Annotated<K, ChildHandle>>,
}
impl<K> Default for Nursery<K> {
    /// Creates a nursery that tracks no children and has no pending futures.
    fn default() -> Self {
        let state = IndexMap::default();
        let shutdown_requests = FuturesUnordered::new();
        let settling_timeouts = FuturesUnordered::new();
        let handles = FuturesUnordered::new();
        Self { state, shutdown_requests, settling_timeouts, handles }
    }
}
/// Error returned by `Nursery::try_insert` when the key is already in use.
/// It hands the listed arguments back to the caller so they are not lost.
#[derive(Debug)]
pub struct AlreadyExists<K> {
/// The key that was already present.
pub key: K,
/// The configuration that was passed to the rejected insert.
pub config: ChildConfig,
/// The settling token that was passed to the rejected insert.
pub settling_token: SettlingToken,
/// The join handle of the task that was not adopted by the nursery.
pub handle: JoinHandle<()>,
}
impl<K: Hash + Eq + Clone> Nursery<K> {
pub fn try_insert(
&mut self,
key: K,
config: ChildConfig,
shutdown_token: ShutdownToken,
settling_token: SettlingToken,
termination_token: TerminationToken,
handle: JoinHandle<()>,
) -> Result<(), AlreadyExists<K>> {
if self.state.contains_key(&key) {
Err(AlreadyExists { key, config, settling_token, handle })
} else {
let state = State(Liveness::Free(settling_token, handle.abort_handle()), config, termination_token);
self.state.insert(key.clone(), state);
self.shutdown_requests.push(Annotated::new(key.clone(), shutdown_token.0.cancelled_owned()));
self.handles.push(Annotated::new(key, ChildHandle(handle)));
Ok(())
}
}
pub fn settle(&mut self, key: &K) {
if let Some(State(liveness, config, _)) = self.state.get_mut(key) {
match std::mem::replace(liveness, Liveness::Aborted) {
Liveness::Free(settling_token, handle) => {
settling_token.signal_settlement();
let deadline = Instant::now() + config.settling_timeout.unwrap_or(Duration::MAX);
self.settling_timeouts.push(Annotated::new(key.clone(), tokio::time::sleep_until(deadline)));
*liveness = Liveness::Settling(handle, deadline);
}
lv => *liveness = lv,
}
}
}
pub fn abort(&mut self, key: &K) {
if let Some(State(liveness, _, _)) = self.state.get_mut(key) {
let lv = std::mem::replace(liveness, Liveness::Aborted);
if let Liveness::Free(_, handle) | Liveness::Settling(handle, _) = lv {
handle.abort();
}
}
}
pub fn contains_key(&self, key: &K) -> bool {
self.state.contains_key(key)
}
pub async fn next_termination(&mut self) -> Option<TerminationReport<K>> {
loop {
tokio::select! {
biased;
Some((key, ())) = self.shutdown_requests.next() => self.settle(&key),
Some((key, ())) = self.settling_timeouts.next() => self.abort(&key),
Some((key, result)) = self.handles.next() => {
let State(liveness, config, term_token) = self.state.shift_remove(&key)?;
term_token.signal_termination();
let reason = match result {
Ok(()) => TerminationReason::Normal,
Err(e) if e.is_cancelled() => TerminationReason::Aborted,
Err(e) => TerminationReason::Panicked(e.into_panic()),
};
return Some(TerminationReport { key, config, burden: liveness.burden(), reason })
},
else => return None
}
}
}
pub async fn wait_for_key(&mut self, key: K) -> Option<TerminationReport<K>> {
loop {
match self.next_termination().await {
None => return None, Some(result) if result.key == key => return Some(result),
_ => continue,
}
}
}
pub fn last(&self) -> Option<&K> {
self.state.last().map(|(k, _)| k)
}
}
/// Describes a child whose task has finished and which has been removed from
/// the nursery.
pub struct TerminationReport<K> {
/// The key the child was registered under.
pub key: K,
/// The configuration the child was inserted with.
pub config: ChildConfig,
/// Summary of the child's liveness at the moment its task finished.
pub burden: Burden,
/// Whether the task returned normally, was aborted, or panicked.
pub reason: TerminationReason,
}
impl<K: Key> TerminationReport<K> {
    /// Decides whether this termination calls for a restart, based on the
    /// child's policy.
    ///
    /// * `Permanent` children always yield
    ///   `RestartReason::PermanentChildTerminated`, whatever the reason.
    /// * `Transient` children yield `RestartReason::TransientChildFailed`
    ///   only when they panicked.
    /// * `Temporary` children never yield a restart reason.
    pub fn into_restart_reason(self) -> Option<RestartReason> {
        let Self { key, config, reason, .. } = self;
        let name = ChildName(key.name());
        match config.policy {
            ConcretePolicy::Permanent => Some(RestartReason::PermanentChildTerminated(name, reason)),
            ConcretePolicy::Transient => match reason {
                TerminationReason::Panicked(payload) => Some(RestartReason::TransientChildFailed(name, payload)),
                _ => None,
            },
            ConcretePolicy::Temporary => None,
        }
    }
}