use std::collections::HashMap;
use std::fmt::Display;
use std::future::Future;
use std::hash::Hash;
use std::marker::PhantomData;
use std::panic::AssertUnwindSafe;
use futures::FutureExt;
use tokio::task::JoinHandle;
use crate::breaker::{Breaker, RateLimited};
use crate::builder::{ChildBuilder, SystemBuilder};
use crate::error::Escalation;
use crate::internal;
use crate::nursery::Nursery;
use crate::policy::Permanent;
use crate::signals::{SettlingToken, ShutdownToken, TerminationToken};
use crate::system::SystemSupervisor;
/// A group of processes whose members are identified by keys of type `K`.
///
/// Implementors describe the members of the system inside [`System::start`],
/// using the [`Scope`] they are handed.
pub trait System<K>: Send + Sync + 'static
where
    K: Key,
{
    /// Starts the system.
    ///
    /// `scope` is the handle through which children ([`Scope::child`]) and
    /// nested systems ([`Scope::system`]) are declared. The returned future
    /// must be `Send`.
    fn start(&self, scope: &mut Scope<K>) -> impl Future<Output = ()> + Send;
}
/// Something that can be started as a child, bounded by the lifetime `'a`.
///
/// Starting consumes the value and yields a [`MainLoop`] carrying a value of
/// type `R`.
pub trait Child<'a, R>
where
    Self: Send + 'a,
{
    /// Consumes the child and starts it with the given [`Context`].
    ///
    /// The returned future must be `Send` and valid for `'a`; it resolves to
    /// the child's [`MainLoop`].
    fn start(self, ctx: Context) -> impl Future<Output = MainLoop<R>> + Send + 'a;
}
/// Any `Send` one-shot closure taking a [`Context`] and returning a `Send`
/// future of [`MainLoop<R>`] is itself a [`Child`].
impl<'a, F, Fut, R> Child<'a, R> for F
where
    R: 'a,
    F: FnOnce(Context) -> Fut + Send + 'a,
    Fut: Future<Output = MainLoop<R>> + Send + 'a,
{
    /// Invokes the closure with `ctx` and awaits the future it returns.
    ///
    /// The closure is not called until the returned future is first polled.
    async fn start(self, ctx: Context) -> MainLoop<R> {
        let starting = self(ctx);
        starting.await
    }
}
/// Identifier type for the members of a system.
///
/// Keys must be cloneable, hashable, comparable for equality, `Send` and
/// `'static`.
pub trait Key
where
    Self: Clone + Hash + Eq + Send + 'static,
{
    /// Returns the textual name of this key.
    fn name(&self) -> String;
}
/// The unit type is a key whose name is always `"unit"`.
impl Key for () {
    fn name(&self) -> String {
        String::from("unit")
    }
}
// Implements `Key` for each listed type, using the type's `ToString`
// output as the key name.
macro_rules! key_from_to_string {
    ($($key_ty:ty),*) => {
        $(
            impl Key for $key_ty {
                fn name(&self) -> String {
                    ToString::to_string(self)
                }
            }
        )*
    };
}

key_from_to_string!(&'static str, String, usize, u64, u32, u16, u8, isize, i64, i32, i16, i8);
/// Name of a child, stored as a plain string.
#[derive(Clone, Debug)]
pub struct ChildName(pub(crate) String);

impl Display for ChildName {
    /// Writes the wrapped string verbatim.
    ///
    /// Formatter flags such as width or fill are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Handle through which a [`System`] declares its members.
///
/// A `&mut Scope<K>` is passed to [`System::start`]; [`Scope::child`] and
/// [`Scope::system`] return builders for a child or a nested system keyed by `K`.
pub struct Scope<K> {
/// Context associated with this scope.
pub(crate) context: Context,
/// Process outcomes recorded per key.
/// NOTE(review): presumably outcomes of earlier runs — confirm against the supervisor code.
pub(crate) history: HashMap<K, internal::ProcessOutcome>,
/// Nursery keyed by `K`.
/// NOTE(review): presumably tracks the running members — confirm in `crate::nursery`.
pub(crate) nursery: Nursery<K>,
}
impl<K> Scope<K>
where
    K: Key,
{
    /// Begins the declaration of a child under `key`.
    ///
    /// `child` is wrapped in a boxed one-shot starter that, given a context,
    /// calls [`Child::start`] and pins the resulting future on the heap. The
    /// returned [`ChildBuilder`] borrows this scope for `'a`.
    pub fn child<'a, R, C>(&'a mut self, key: K, child: C) -> ChildBuilder<'a, K, Permanent, R>
    where
        C: Child<'a, R>,
    {
        ChildBuilder::new(self, key, Box::new(move |cx| Box::pin(child.start(cx))))
    }

    /// Begins the declaration of a nested system under `key`.
    ///
    /// The nested system uses its own key type `Kp`. The returned
    /// [`SystemBuilder`] is typed with the [`RateLimited`] breaker.
    pub fn system<Kp, S>(&mut self, key: K, system: S) -> SystemBuilder<K, Permanent, S, RateLimited>
    where
        Kp: Key,
        S: System<Kp>,
    {
        SystemBuilder::new(key, system, self)
    }
}
/// Per-child handle bundling the child's name with its lifecycle tokens.
///
/// An instance is passed to [`Child::start`].
pub struct Context {
/// Name of the child; derived contexts are named `parent.extension`
/// (see `new_child_context`).
name: ChildName,
/// Token exposed through `settled`, `is_settled` and `settling_token`.
settling_token: SettlingToken,
/// Token exposed through `shutdown_token`.
shutdown_token: ShutdownToken,
/// Token exposed through `termination_token`.
termination_token: TerminationToken,
}
impl Context {
pub fn name(&self) -> &ChildName {
&self.name
}
pub async fn settled(&self) {
self.settling_token.settled().await
}
pub fn is_settled(&self) -> bool {
self.settling_token.is_settled()
}
pub fn settling_token(&self) -> &SettlingToken {
&self.settling_token
}
pub fn termination_token(&self) -> &TerminationToken {
&self.termination_token
}
pub fn shutdown_token(&self) -> &ShutdownToken {
&self.shutdown_token
}
pub(crate) fn new_child_context(&self, name_extension: String) -> Self {
let name = ChildName(format!("{parent}.{name_extension}", parent = self.name.0));
Self {
name,
settling_token: SettlingToken::new(),
shutdown_token: ShutdownToken::new(),
termination_token: TerminationToken::new(),
}
}
}
/// What a [`Child`] produces when started: a task join handle plus a value of type `R`.
pub struct MainLoop<R> {
/// Join handle of the child's task.
pub(crate) handle: JoinHandle<()>,
/// Value handed back alongside the handle.
pub(crate) return_value: R,
}
impl MainLoop<()> {
pub fn new(handle: JoinHandle<()>) -> Self {
Self { handle, return_value: () }
}
}
impl<R> MainLoop<R> {
    /// Builds a `MainLoop` from a return value and the task's join handle.
    pub fn new_returning(value: R, handle: JoinHandle<()>) -> Self {
        MainLoop {
            handle,
            return_value: value,
        }
    }

    /// Replaces the carried value with `return_value`, keeping the same handle.
    ///
    /// The previously carried value is dropped.
    pub fn with<Rp>(self, return_value: Rp) -> MainLoop<Rp> {
        // Only the handle is moved out; the old value stays in `self` and is
        // dropped when this function returns.
        let Self { handle, .. } = self;
        MainLoop {
            handle,
            return_value,
        }
    }
}
impl From<JoinHandle<()>> for MainLoop<()> {
fn from(handle: JoinHandle<()>) -> Self {
Self { handle, return_value: () }
}
}
/// Root of a process tree: a named system together with its tokens and breaker.
///
/// Created with `Toplevel::new`, optionally given another breaker via
/// `with_breaker`, and run with `start`.
pub struct Toplevel<K, S, B> {
/// Name used for the root context.
name: ChildName,
/// Settling token of the root.
settling_token: SettlingToken,
/// Shutdown token handed out by `shutdown_token()`.
shutdown_token: ShutdownToken,
/// Termination token handed out by `termination_token()`; signalled when `start` finishes.
termination_token: TerminationToken,
/// Breaker passed to the supervisor on `start`.
breaker: B,
/// The system to run.
system: S,
/// Ties the key type `K` to this value without storing a key.
phantom: PhantomData<K>,
}
impl<K, S> Toplevel<K, S, RateLimited> {
pub fn new<N: ToString>(name: N, system: S) -> Self {
let name = ChildName(name.to_string());
let settling_token = SettlingToken::new();
let shutdown_token = settling_token.clone().into_shutdown_token();
let termination_token = TerminationToken::new();
let breaker = RateLimited::default();
Self { name, settling_token, shutdown_token, termination_token, breaker, system, phantom: PhantomData }
}
}
impl<K, S, P> Toplevel<K, S, P> {
    /// Replaces the breaker, keeping the name, tokens and system unchanged.
    ///
    /// The previous breaker is dropped.
    pub fn with_breaker<Pp>(self, breaker: Pp) -> Toplevel<K, S, Pp>
    where
        Pp: Breaker<K>,
    {
        // The old breaker is deliberately left behind in `self`.
        let Self {
            name,
            settling_token,
            shutdown_token,
            termination_token,
            system,
            ..
        } = self;
        Toplevel {
            breaker,
            name,
            settling_token,
            shutdown_token,
            termination_token,
            system,
            phantom: PhantomData,
        }
    }

    /// Borrows the shutdown token of this toplevel.
    pub fn shutdown_token(&self) -> &ShutdownToken {
        &self.shutdown_token
    }

    /// Borrows the termination token of this toplevel.
    pub fn termination_token(&self) -> &TerminationToken {
        &self.termination_token
    }
}
impl<K, S, P> Toplevel<K, S, P>
where
    K: Key,
    S: System<K>,
    P: Breaker<K>,
{
    /// Runs the system under a supervisor until its main future finishes.
    ///
    /// The termination token is signalled once the main future has completed
    /// or panicked.
    ///
    /// # Errors
    ///
    /// If the main future panics, the panic payload is returned as the error:
    /// directly when it is an [`Escalation`], otherwise wrapped in
    /// `Escalation::Unknown`.
    #[tracing::instrument(target = "spry", level = "info", name = "toplevel", skip_all)]
    pub async fn start(self) -> Result<(), Escalation> {
        let termination_token = self.termination_token.clone();
        // NOTE(review): the stored `self.shutdown_token` is not used here; a
        // shutdown token is derived from the settling token again. Confirm
        // that both derivations observe the same shutdown signal.
        let context = Context {
            name: self.name,
            settling_token: self.settling_token.clone(),
            shutdown_token: self.settling_token.into_shutdown_token(),
            termination_token: termination_token.clone(),
        };
        let supervisor = SystemSupervisor::new(self.system, self.breaker);
        let main = supervisor.start_inner(context).await;
        // Panics raised while polling `main` are captured instead of unwinding
        // through the caller.
        let outcome = AssertUnwindSafe(main).catch_unwind().await;
        termination_token.signal_termination();
        outcome.map_err(|payload| match payload.downcast::<Escalation>() {
            Ok(escalation) => *escalation,
            Err(opaque) => Escalation::Unknown(opaque),
        })
    }
}