super-visor 0.2.1

Simple ordered startup and shutdown for long-running tokio processes
Documentation
mod ordered_select_all;

use crate::ordered_select_all::ordered_select_all;
use anyhow::Result;
use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
use std::{
    pin::{pin, Pin},
    task::{Context, Poll},
};
use tokio::signal;
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};

/// An awaitable handle around a [`CancellationToken`].
///
/// Awaiting the signal waits on the token's `cancelled_owned()` future, which is
/// only created the first time the signal is polled.
pub struct ShutdownSignal {
    // Token whose cancellation this signal reports.
    token: CancellationToken,
    // Cached cancellation future; stays `None` until the first poll.
    future: Option<Pin<Box<WaitForCancellationFutureOwned>>>,
}

impl ShutdownSignal {
    pub fn new(token: CancellationToken) -> Self {
        Self {
            token,
            future: None,
        }
    }

    pub fn is_cancelled(&self) -> bool {
        self.token.is_cancelled()
    }

    pub fn token(&self) -> &CancellationToken {
        &self.token
    }
}

impl Future for ShutdownSignal {
    type Output = ();

    /// Polls the token's owned cancellation future, creating and caching it on the
    /// first call so later polls reuse the same future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `ShutdownSignal` is `Unpin`, so a plain mutable reference is available.
        let this = self.get_mut();
        let token = &this.token;

        this.future
            .get_or_insert_with(|| Box::pin(token.clone().cancelled_owned()))
            .as_mut()
            .poll(cx)
    }
}

impl Clone for ShutdownSignal {
    /// Produces a signal over a clone of the same token.
    ///
    /// The cached future is not carried over; the clone starts with none and
    /// creates its own when it is first polled.
    fn clone(&self) -> Self {
        Self::new(self.token.clone())
    }
}

// Explicitly mark the signal as `Unpin`: its `Future::poll` implementation mutates
// fields through `Pin<&mut Self>`, and the cached future is itself held in a `Pin<Box<_>>`.
impl Unpin for ShutdownSignal {}

/// A unit of work that a [`Supervisor`] can start and later ask to stop.
pub trait ManagedProc {
    /// Consumes the boxed process and returns the (non-`Send`) future that drives it.
    ///
    /// `shutdown` is the signal handed to the process; the returned future's
    /// `Result` is what the supervisor inspects when the process finishes.
    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>>;
}

/// Owns a list of [`ManagedProc`]s, kept in the order they were added.
pub struct Supervisor {
    // Registered processes, in insertion order.
    procs: Vec<Box<dyn ManagedProc>>,
}

/// A supervisor is itself a managed process, so supervisors can be nested.
impl ManagedProc for Supervisor {
    /// Runs this supervisor's processes, using the parent's `shutdown` signal as the
    /// trigger for stopping them.
    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
        let supervisor = *self;
        let parent_shutdown: LocalBoxFuture<'static, ()> = Box::pin(shutdown);
        supervisor.do_run(parent_shutdown).boxed_local()
    }
}

/// Chainable builder that collects processes before producing a [`Supervisor`].
pub struct SupervisorBuilder {
    // Processes collected so far, in the order they were added.
    procs: Vec<Box<dyn ManagedProc>>,
}

/// A started process: its running future paired with the token used to ask it to stop.
struct CancelableLocalFuture {
    // Cancelled by `stop_all` to request that the process shut down.
    cancel_token: CancellationToken,
    // The future returned by the process's `run_proc`.
    future: LocalBoxFuture<'static, Result<()>>,
}

impl Future for CancelableLocalFuture {
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        pin!(&mut self.future).poll(ctx)
    }
}

/// Lets any one-shot callable that turns a [`ShutdownSignal`] into a
/// `Result<()>`-returning future be used directly as a [`ManagedProc`].
impl<O, P> ManagedProc for P
where
    P: FnOnce(ShutdownSignal) -> O,
    O: Future<Output = Result<()>> + 'static,
{
    /// Invokes the callable with `shutdown` and boxes the future it returns.
    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
        let entry_point = *self;
        entry_point(shutdown).boxed_local()
    }
}

impl Default for Supervisor {
    /// Builds a supervisor with no registered processes, same as `Supervisor::new`.
    fn default() -> Self {
        Self { procs: Vec::new() }
    }
}

impl Supervisor {
    /// Creates a supervisor with no registered processes.
    pub fn new() -> Self {
        let procs = Vec::new();
        Self { procs }
    }

    /// Returns an empty [`SupervisorBuilder`] for chained construction.
    pub fn builder() -> SupervisorBuilder {
        let procs = Vec::new();
        SupervisorBuilder { procs }
    }

    /// Appends `proc` after all previously registered processes.
    pub fn add(&mut self, proc: impl ManagedProc + 'static) {
        let boxed: Box<dyn ManagedProc> = Box::new(proc);
        self.procs.push(boxed);
    }

    /// Runs every registered process until all finish, one fails, or the program
    /// receives SIGTERM or Ctrl-C.
    ///
    /// # Errors
    ///
    /// Returns an error if the SIGTERM handler cannot be installed, or whatever
    /// error the supervised run produces.
    pub async fn start(self) -> Result<()> {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;

        let on_sigterm = Box::pin(async move { sigterm.recv().await });
        let on_ctrl_c = Box::pin(signal::ctrl_c());
        // Whichever signal arrives first triggers shutdown; which one is irrelevant.
        let shutdown = futures::future::select(on_sigterm, on_ctrl_c).map(|_| ());

        self.do_run(Box::pin(shutdown)).await
    }

    /// Drives the started processes until none remain, `shutdown` resolves, or one
    /// of them returns an error.
    ///
    /// On shutdown the still-running processes are stopped via `stop_all` and its
    /// result is returned. On a process error the rest are stopped and the original
    /// error is returned.
    async fn do_run(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
        let mut running = start_futures(self.procs);

        while !running.is_empty() {
            let mut next_done = ordered_select_all(running);

            tokio::select! {
                // Check the shutdown signal before process completions.
                biased;
                _ = &mut shutdown => return stop_all(next_done.into_inner()).await,
                (outcome, _index, rest) = &mut next_done => {
                    if let Err(err) = outcome {
                        // Errors raised while stopping the rest are deliberately
                        // dropped so the first failure is the one reported.
                        let _ = stop_all(rest).await;
                        return Err(err);
                    }
                    running = rest;
                }
            }
        }

        Ok(())
    }
}

impl SupervisorBuilder {
    /// Appends `proc` to the list of processes and returns the builder for chaining.
    pub fn add_proc(self, proc: impl ManagedProc + 'static) -> Self {
        let Self { mut procs } = self;
        procs.push(Box::new(proc));
        Self { procs }
    }

    /// Finishes building, handing the collected processes to a [`Supervisor`].
    pub fn build(self) -> Supervisor {
        let Self { procs } = self;
        Supervisor { procs }
    }
}

/// Starts every process in order, pairing each resulting future with its own
/// freshly created cancellation token.
///
/// Each process receives a [`ShutdownSignal`] backed by a child of that token.
fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
    let mut started = Vec::with_capacity(procs.len());

    for proc in procs {
        let cancel_token = CancellationToken::new();
        let shutdown = ShutdownSignal::new(cancel_token.child_token());
        let future = proc.run_proc(shutdown);
        started.push(CancelableLocalFuture {
            cancel_token,
            future,
        });
    }

    started
}

/// Stops the given processes one at a time in reverse order: each token is
/// cancelled and its future awaited to completion before moving to the next.
///
/// # Errors
///
/// Every process is stopped regardless of failures; afterwards the first error
/// encountered (in stop order) is returned, if any.
async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
    let shut_down_one = |proc: CancelableLocalFuture| async move {
        proc.cancel_token.cancel();
        proc.future.await
    };

    let last_to_first = futures::stream::iter(procs.into_iter().rev());
    let outcomes: Vec<Result<()>> = last_to_first.then(shut_down_one).collect().await;

    outcomes.into_iter().collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::TryFutureExt;
    use tokio::sync::mpsc;

    /// Test process that waits for shutdown or for `delay` milliseconds, then
    /// announces its `name` on `sender` and finishes with `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    impl TestProc {
        /// Builds a test process that reports on its own clone of `sender`.
        fn new(
            name: &'static str,
            delay: u64,
            result: Result<()>,
            sender: &mpsc::Sender<&'static str>,
        ) -> Self {
            Self {
                name,
                delay,
                result,
                sender: sender.clone(),
            }
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(
            self: Box<Self>,
            shutdown: ShutdownSignal,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let TestProc {
                name,
                delay,
                result,
                sender,
            } = *self;

            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(delay)) => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            // Flatten the join result: a join failure becomes the process error.
            let flattened = handle
                .map_err(anyhow::Error::from)
                .and_then(|outcome| async move { outcome });
            Box::pin(flattened)
        }
    }

    /// Asserts that the next announcements on `receiver` are exactly `expected`, in order.
    async fn assert_stop_order(
        receiver: &mut mpsc::Receiver<&'static str>,
        expected: &[&'static str],
    ) {
        for name in expected {
            assert_eq!(Some(*name), receiver.recv().await);
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let supervisor = Supervisor::builder()
            .add_proc(TestProc::new("1", 50, Ok(()), &sender))
            .add_proc(TestProc::new("2", 100, Ok(()), &sender))
            .build();
        let result = supervisor.start().await;

        assert_stop_order(&mut receiver, &["1", "2"]).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let supervisor = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(TestProc::new("3", 1000, Ok(()), &sender))
            .build();
        let result = supervisor.start().await;

        assert_stop_order(&mut receiver, &["2", "3", "1"]).await;
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let supervisor = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(TestProc::new(
                "3",
                200,
                Err(anyhow!("second error")),
                &sender,
            ))
            .build();
        let result = supervisor.start().await;

        assert_stop_order(&mut receiver, &["2", "3", "1"]).await;
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let first = TestProc::new("proc-1", 500, Ok(()), &sender);

        // Second group: its second member fails first and takes the group down.
        let failing_group = Supervisor::builder()
            .add_proc(TestProc::new("proc-2-1", 500, Ok(()), &sender))
            .add_proc(TestProc::new(
                "proc-2-2",
                100,
                Err(anyhow!("error")),
                &sender,
            ))
            .add_proc(TestProc::new("proc-2-3", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-4", 500, Ok(()), &sender))
            .build();

        let healthy_group = Supervisor::builder()
            .add_proc(TestProc::new("proc-3-1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-2", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-3", 1000, Ok(()), &sender))
            .build();

        let result = Supervisor::builder()
            .add_proc(first)
            .add_proc(failing_group)
            .add_proc(healthy_group)
            .build()
            .start()
            .await;

        assert_stop_order(
            &mut receiver,
            &[
                "proc-2-2", "proc-2-4", "proc-2-3", "proc-2-1", "proc-3-3", "proc-3-2",
                "proc-3-1", "proc-1",
            ],
        )
        .await;
        assert!(result.is_err());
    }
}