spry 0.0.4

Resilient, self-healing async process hierarchies in the style of Erlang/OTP
Documentation
use crate::{ChildName, Key};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use indexmap::IndexMap;
use pin_project::{pin_project, pinned_drop};
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::time::Duration;
use tokio::task::{AbortHandle, JoinError, JoinHandle};
use tokio::time::Instant;
use tokio_util::sync::WaitForCancellationFutureOwned;

use crate::error::{RestartReason, TerminationReason};
use crate::internal::{ChildConfig, ConcretePolicy};
use crate::signals::{SettlingToken, ShutdownToken, TerminationToken};

/// Per-child bookkeeping kept by [`Nursery`]: the child's current [`Liveness`], the config it
/// was registered with, and the token that is signalled once its termination is observed.
pub struct State(Liveness, ChildConfig, TerminationToken);

/// Lifecycle phase of a registered child, as tracked by the nursery.
pub enum Liveness {
  /// Running normally. Holds the token used to ask the child to settle and the handle
  /// used to abort its task.
  Free(SettlingToken, AbortHandle),
  /// Asked to settle. Holds the abort handle and the deadline it was given to finish
  /// settling.
  Settling(AbortHandle, Instant),
  /// The nursery has aborted the child's task.
  Aborted,
}

impl Liveness {
  /// Reduces this phase to a data-free [`Burden`] that can be carried in a termination report.
  pub fn burden(&self) -> Burden {
    match *self {
      Self::Free(..) => Burden::Free,
      Self::Settling(..) => Burden::Settling,
      Self::Aborted => Burden::Aborted,
    }
  }
}

/// Data-free summary of the phase a child was in when it terminated.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Burden {
  /// The child was running normally.
  Free,
  /// The child had been asked to settle.
  Settling,
  /// The child's task had been aborted.
  Aborted,
}

impl Burden {
  /// Short static label for this burden (presumably used as a tracing field value).
  ///
  /// Note that [`Burden::Free`] is labelled `"none"`, not `"free"`.
  pub fn tracing_name(&self) -> &'static str {
    match *self {
      Self::Aborted => "aborted",
      Self::Settling => "settling",
      Self::Free => "none",
    }
  }
}

/// Owning wrapper around a child's [`JoinHandle`] that aborts the child's task when the
/// wrapper is dropped, so dropping the wrapper (e.g. along with the nursery) cancels the task.
#[pin_project(PinnedDrop)]
struct ChildHandle(#[pin] JoinHandle<()>);

#[pinned_drop]
impl PinnedDrop for ChildHandle {
  fn drop(self: Pin<&mut Self>) {
    // tokio's JoinHandle only detaches its task when dropped, so cancel it explicitly.
    self.project().0.abort();
  }
}

impl Future for ChildHandle {
  type Output = Result<(), JoinError>;

  fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
    self.project().0.poll(cx)
  }
}

/// Future adapter that pairs the output of `future` with `key`, resolving to `(key, output)`.
///
/// The key is moved out when the inner future completes. Any later poll returns `Pending`
/// without polling `future` again.
#[pin_project]
struct Annotated<K, F> {
  // `Some` until the wrapped future completes and the key is handed out with its output.
  key: Option<K>,
  #[pin]
  future: F,
}

impl<K, F> Annotated<K, F> {
  pub fn new(key: K, future: F) -> Self {
    Self { key: Some(key), future }
  }
}

impl<K, F: Future> Future for Annotated<K, F> {
  type Output = (K, F::Output);

  /// Polls the wrapped future and, once it is ready, pairs its output with the key.
  ///
  /// After the key has been handed out, this returns `Pending` on every call without
  /// touching the inner future (and without arranging a wake-up). That keeps a finished
  /// future from being polled again.
  fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
    let this = self.project();
    if this.key.is_some() {
      match this.future.poll(cx) {
        std::task::Poll::Ready(output) => std::task::Poll::Ready((this.key.take().unwrap(), output)),
        std::task::Poll::Pending => std::task::Poll::Pending,
      }
    } else {
      // Already completed: act as a fused future.
      std::task::Poll::Pending
    }
  }
}

/// Bookkeeping for a set of running children, keyed by `K`.
///
/// Every registered child has an entry in `state` and one in `handles`, plus a watcher on its
/// shutdown token in `shutdown_requests`. `settling_timeouts` holds the settling-deadline
/// timers armed by `settle`.
///
/// NOTE(review): entries in the three `FuturesUnordered` queues are only removed when they
/// resolve, not when their child terminates.
pub struct Nursery<K> {
  // Per-child liveness, config and termination token, kept in insertion order.
  state: IndexMap<K, State>,
  // Resolves with a child's key once its shutdown token is cancelled.
  shutdown_requests: FuturesUnordered<Annotated<K, WaitForCancellationFutureOwned>>,
  // Resolves with a child's key once its settling deadline is reached.
  settling_timeouts: FuturesUnordered<Annotated<K, tokio::time::Sleep>>,
  // Resolves with a child's key and join result once its task ends.
  handles: FuturesUnordered<Annotated<K, ChildHandle>>,
}

impl<K> Default for Nursery<K> {
  /// Creates a nursery with no registered children and nothing to wait on.
  fn default() -> Self {
    let state = IndexMap::default();
    let shutdown_requests = FuturesUnordered::new();
    let settling_timeouts = FuturesUnordered::new();
    let handles = FuturesUnordered::new();
    Self { state, shutdown_requests, settling_timeouts, handles }
  }
}

/// Error returned by [`Nursery::try_insert`] when a child is already registered under `key`.
///
/// Hands back the arguments the nursery did not take ownership of, so the caller can decide
/// what to do with them.
#[derive(Debug)]
pub struct AlreadyExists<K> {
  /// The key that was already in use.
  pub key: K,
  /// The config of the rejected child.
  pub config: ChildConfig,
  /// The settling token of the rejected child.
  pub settling_token: SettlingToken,
  /// The join handle of the rejected child's task (not aborted by the nursery).
  pub handle: JoinHandle<()>,
}

impl<K: Hash + Eq + Clone> Nursery<K> {
  /// Offset used to record a settling deadline for a child that has no settling timeout, or
  /// whose timeout cannot be added to the current instant without overflowing. No timer is
  /// armed for such a child, so it is never aborted for overrunning its settling window.
  const UNBOUNDED_SETTLING: Duration = Duration::from_secs(60 * 60 * 24 * 365 * 30);

  /// Registers a freshly spawned child under `key`.
  ///
  /// The child starts out as [`Liveness::Free`]. Its shutdown token is watched so that a
  /// shutdown request moves it into the settling phase. Its join handle is watched so that
  /// [`Nursery::next_termination`] can report when it ends.
  ///
  /// # Errors
  ///
  /// Returns [`AlreadyExists`] if a child is already registered under `key`. The key, config,
  /// settling token and join handle are handed back; the shutdown and termination tokens are
  /// dropped.
  pub fn try_insert(
    &mut self,
    key: K,
    config: ChildConfig,
    shutdown_token: ShutdownToken,
    settling_token: SettlingToken,
    termination_token: TerminationToken,
    handle: JoinHandle<()>,
  ) -> Result<(), AlreadyExists<K>> {
    if self.state.contains_key(&key) {
      Err(AlreadyExists { key, config, settling_token, handle })
    } else {
      let state = State(Liveness::Free(settling_token, handle.abort_handle()), config, termination_token);
      self.state.insert(key.clone(), state);
      self.shutdown_requests.push(Annotated::new(key.clone(), shutdown_token.0.cancelled_owned()));
      self.handles.push(Annotated::new(key, ChildHandle(handle)));

      Ok(())
    }
  }

  /// Asks the child under `key` to settle.
  ///
  /// Only a [`Liveness::Free`] child is affected. Its settling token is signalled and it moves
  /// to [`Liveness::Settling`]. If its config has a settling timeout, a timer is armed that
  /// aborts it once the deadline passes. Settling or aborted children, and unknown keys, are
  /// left untouched.
  pub fn settle(&mut self, key: &K) {
    if let Some(State(liveness, config, _)) = self.state.get_mut(key) {
      match std::mem::replace(liveness, Liveness::Aborted) {
        Liveness::Free(settling_token, handle) => {
          settling_token.signal_settlement();
          let now = Instant::now();
          // `Instant + Duration` panics on overflow, so `now + Duration::MAX` cannot stand in
          // for "no timeout". Only arm a timer when the deadline is representable.
          let deadline = match config.settling_timeout.and_then(|timeout| now.checked_add(timeout)) {
            Some(deadline) => {
              self.settling_timeouts.push(Annotated::new(key.clone(), tokio::time::sleep_until(deadline)));
              deadline
            }
            None => now + Self::UNBOUNDED_SETTLING,
          };
          *liveness = Liveness::Settling(handle, deadline);
        }
        lv => *liveness = lv,
      }
    }
  }

  /// Aborts the task of the child under `key` and marks it [`Liveness::Aborted`].
  ///
  /// Does nothing for unknown keys or children that were already aborted. The child stays
  /// registered until its termination is reported by [`Nursery::next_termination`].
  pub fn abort(&mut self, key: &K) {
    if let Some(State(liveness, _, _)) = self.state.get_mut(key) {
      let lv = std::mem::replace(liveness, Liveness::Aborted);
      if let Liveness::Free(_, handle) | Liveness::Settling(handle, _) = lv {
        handle.abort();
      }
    }
  }

  /// Handles a fired settling timer for `key`.
  ///
  /// Timers are never removed from `settling_timeouts` when their child terminates. A timer
  /// may therefore fire after a different child has been registered under the same key. The
  /// abort only goes ahead if the current child is settling and its own deadline has passed.
  /// A re-registered child that is still free, or is settling with a later deadline, is left
  /// alone.
  fn abort_expired(&mut self, key: &K) {
    let expired = matches!(
      self.state.get(key),
      Some(State(Liveness::Settling(_, deadline), _, _)) if *deadline <= Instant::now()
    );
    if expired {
      self.abort(key);
    }
  }

  /// Returns whether a child is currently registered under `key`.
  ///
  /// A child stays registered until its termination is reported by
  /// [`Nursery::next_termination`].
  pub fn contains_key(&self, key: &K) -> bool {
    self.state.contains_key(key)
  }

  /// Waits for the next child to terminate, removes it from the nursery and reports it.
  ///
  /// While waiting:
  /// - a cancelled shutdown token moves its child into the settling phase (see
  ///   [`Nursery::settle`]);
  /// - an expired settling timer aborts its child.
  ///
  /// Sources are polled in this order: shutdown requests, settling timers, child terminations.
  /// The terminated child's termination token is signalled before the report is returned.
  ///
  /// Returns `None` once all internal queues are exhausted.
  ///
  /// NOTE(review): shutdown watchers are not removed when their child terminates. After the
  /// last child ends, this keeps waiting until every outstanding shutdown token has been
  /// cancelled and every armed timer has fired. Confirm callers expect that.
  pub async fn next_termination(&mut self) -> Option<TerminationReport<K>> {
    loop {
      tokio::select! {
        biased;
        Some((key, ())) = self.shutdown_requests.next() => self.settle(&key),
        Some((key, ())) = self.settling_timeouts.next() => self.abort_expired(&key),
        Some((key, result)) = self.handles.next() => {
          // Every handle is pushed together with its state entry in `try_insert`, so the
          // lookup is expected to succeed.
          let State(liveness, config, term_token) = self.state.shift_remove(&key)?;
          term_token.signal_termination();
          let reason = match result {
            Ok(()) => TerminationReason::Normal,
            Err(e) if e.is_cancelled() => TerminationReason::Aborted,
            Err(e) => TerminationReason::Panicked(e.into_panic()),
          };
          return Some(TerminationReport { key, config, burden: liveness.burden(), reason })
        },
        else => return None
      }
    }
  }

  /// Drives [`Nursery::next_termination`] until the child under `key` terminates, and returns
  /// its report.
  ///
  /// Reports for other children that terminate in the meantime are discarded. Their
  /// termination tokens have still been signalled. Returns `None` if `next_termination` runs
  /// out before `key` terminates.
  pub async fn wait_for_key(&mut self, key: K) -> Option<TerminationReport<K>> {
    while let Some(report) = self.next_termination().await {
      if report.key == key {
        return Some(report);
      }
    }
    // all children are dead
    None
  }

  /// Returns the key of the most recently registered child that is still in the nursery.
  ///
  /// Insertion order is preserved because removals use `shift_remove`.
  pub fn last(&self) -> Option<&K> {
    self.state.last().map(|(k, _)| k)
  }
}

/// Describes a child that has terminated, as produced by `Nursery::next_termination`.
pub struct TerminationReport<K> {
  /// Key the child was registered under.
  pub key: K,
  /// Configuration the child was registered with.
  pub config: ChildConfig,
  /// The child's liveness phase at the moment its termination was observed.
  pub burden: Burden,
  /// How the child's task ended.
  pub reason: TerminationReason,
}

impl<K: Key> TerminationReport<K> {
  /// Decides, from the child's restart policy, whether it should be restarted.
  ///
  /// - `Permanent` children are always restarted, whatever the termination reason.
  /// - `Transient` children are restarted only if they panicked. Any other termination
  ///   reason (such as `Normal` or `Aborted`) does not trigger a restart.
  /// - `Temporary` children are never restarted.
  ///
  /// Returns the [`RestartReason`] to restart with, or `None` if no restart is wanted.
  pub fn into_restart_reason(self) -> Option<RestartReason> {
    let name = ChildName(self.key.name());
    match self.config.policy {
      ConcretePolicy::Permanent => Some(RestartReason::PermanentChildTerminated(name, self.reason)),
      ConcretePolicy::Transient => match self.reason {
        TerminationReason::Panicked(payload) => Some(RestartReason::TransientChildFailed(name, payload)),
        _ => None,
      },
      ConcretePolicy::Temporary => None,
    }
  }
}