task-tracker 0.2.0

Start and stop tasks declaratively, based on which tasks should run at any given moment.
Documentation
use tokio::select;
use tokio::sync::oneshot;
use tokio::task::{JoinError, JoinHandle};
use tracing::debug;
use uuid::Uuid;

/// Handle to one spawned task, identified by a caller-supplied key.
///
/// Owns the sending half of the task's one-shot stop channel together with
/// the Tokio [`JoinHandle`] that yields the task's [`TaskResult`].
pub struct InnerTask<K: Clone + Send + std::fmt::Debug + 'static, R: Send + 'static> {
    /// Random (v4) UUID assigned when the task is created.
    pub id: Uuid,
    /// Caller-supplied key the task was created with.
    pub key: K,
    /// Sending half of the stop-signal channel; the spawned task holds the receiver.
    stop_tx: oneshot::Sender<()>,
    /// Handle for awaiting the spawned task's outcome.
    join_handle: JoinHandle<TaskResult<R>>,
}

impl<K, R> InnerTask<K, R>
where
    K: Clone + Send + std::fmt::Debug + 'static,
    R: Send + 'static,
{
    /// Spawns `fut` on the Tokio runtime and returns a handle for cancelling
    /// or awaiting it.
    ///
    /// The spawned task races `fut` against a one-shot stop signal:
    /// - if `fut` completes first, the task yields [`TaskResult::Done`];
    /// - if the stop signal resolves first, `fut` is dropped and the task
    ///   yields [`TaskResult::Cancelled`].
    ///
    /// Dropping the returned `InnerTask` drops the stop sender, which also
    /// resolves the receiver (with an error). That means discarding the
    /// handle cancels the task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (`tokio::task::spawn` requires one).
    pub fn new<Fut>(key: K, fut: Fut) -> Self
    where
        Fut: std::future::Future<Output = R> + Send + 'static,
    {
        let id = Uuid::new_v4();
        let (stop_tx, stop_rx) = oneshot::channel::<()>();
        debug!(?key, %id, "creating new task");
        // The spawned task needs its own copy of the key for logging.
        let key_ = key.clone();
        let task = async move {
            debug!(key = ?key_, task_id = %id, "task started");
            select! {
                task_result = fut => {
                    debug!(key = ?key_, task_id = %id, "task finished");
                    TaskResult::Done(task_result)
                },
                // `_` matches both `Ok(())` (explicit stop) and `Err(_)`
                // (sender dropped), so either one cancels the task.
                _ = stop_rx => {
                    debug!(key = ?key_, task_id = %id, "task cancelled");
                    TaskResult::Cancelled
                },
            }
        };
        Self {
            id,
            key,
            stop_tx,
            join_handle: tokio::task::spawn(task),
        }
    }

    /// Signals the task to stop and waits for it to end.
    ///
    /// Returns [`TaskResult::Cancelled`] if the stop signal was observed
    /// first. Returns [`TaskResult::Done`] if the future finished first,
    /// including when it had already finished before this call. A failed
    /// join is reported as [`TaskResult::JoinError`].
    pub async fn cancel_and_wait(self) -> TaskResult<R> {
        debug!(key = ?self.key, task_id = %self.id, "waiting for task to finish");
        // `send` fails only when the receiver is gone, i.e. the spawned task
        // has already completed (or panicked) and dropped it. That is not an
        // error here: the join below still reports how the task ended, so the
        // result is deliberately ignored instead of unwrapped.
        let _ = self.stop_tx.send(());
        self.join_handle.await.unwrap_or_else(TaskResult::JoinError)
    }

    /// Returns `true` once the spawned task has ended, without waiting.
    pub fn is_finished(&self) -> bool {
        self.join_handle.is_finished()
    }

    /// Waits for the task to end on its own, without signalling it to stop.
    ///
    /// Only `join_handle` is moved out of `self`. The stop sender stays alive
    /// until this function returns, so the task is not cancelled while waiting.
    pub async fn wait(self) -> TaskResult<R> {
        self.join_handle.await.unwrap_or_else(TaskResult::JoinError)
    }
}

/// Final outcome of a task spawned through `InnerTask::new`.
pub enum TaskResult<R: Send> {
    /// The task's future ran to completion and produced this value.
    Done(R),
    /// The stop signal won the race, so the future was dropped before completing.
    Cancelled,
    /// Awaiting the Tokio `JoinHandle` failed (see Tokio's `JoinError`, e.g. panic or abort).
    JoinError(JoinError),
}

impl<R> std::fmt::Debug for TaskResult<R>
where
    R: Send,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                Self::Done(_) => "Done(..)",
                Self::Cancelled => "Cancelled",
                Self::JoinError(_) => "JoinError",
            }
        )
    }
}