restrepo 0.5.12

A collection of components for building RESTful web services with actix-web
Documentation
use std::{fmt::Debug, sync::Arc, time::Duration};

use anyhow::Context;
use async_trait::async_trait;
use cron::Schedule;
use futures::StreamExt;
use tokio_util::time::DelayQueue;
use tracing::{error, info};

/// A unit of asynchronous work that a [`TaskScheduler`] can run on a cron schedule.
///
/// Implementors must be `Send + Sync + 'static`: the scheduler keeps them behind an
/// [`Arc`] and calls them from spawned tokio tasks.
#[async_trait]
pub trait Task: 'static + Send + Sync {
    /// Performs a single run of the task.
    ///
    /// When the scheduler runs the task, a returned `Err` is logged. It does not remove the
    /// task from the schedule.
    async fn execute(&self) -> anyhow::Result<()>;
}

/// A named [`Task`] paired with the cron [`Schedule`] that decides when it runs.
///
/// Cron expressions are parsed by the `cron` crate, whose syntax starts with a seconds
/// field. For example, `"0 * * * * *"` fires at second zero of every minute.
///
/// Cloning a `ScheduledTask` does not duplicate the task itself: all clones share it
/// through an [`Arc`].
///
/// # Example
/// ```
/// use anyhow::Result;
/// use async_trait::async_trait;
/// use restrepo::tasks::{ScheduledTask, Task};
///
/// struct PrintTask;
///
/// #[async_trait]
/// impl Task for PrintTask {
///     async fn execute(&self) -> Result<()> {
///         println!("Task executed!");
///         Ok(())
///     }
/// }
///
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// // Run at second 0 of every minute.
/// let scheduled = ScheduledTask::new("print_task", PrintTask, "0 * * * * *")?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct ScheduledTask {
    name: String,
    schedule: Schedule,
    task: Arc<dyn Task>,
}

// `Task` does not require `Debug`, so only the name and the schedule are printed.
// `finish_non_exhaustive` renders a trailing `..` so readers can see a field was left out.
impl Debug for ScheduledTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScheduledTask")
            .field("name", &self.name)
            // Label matches the actual field name (previously printed as "scheduled").
            .field("schedule", &self.schedule)
            .finish_non_exhaustive()
    }
}

impl ScheduledTask {
    pub fn new<S: Into<String>>(name: S, task: impl Task, schedule: S) -> anyhow::Result<Self> {
        Ok(Self {
            name: name.into(),
            task: Arc::new(task),
            schedule: schedule.into().parse()?,
        })
    }
}

/// Drives a set of [`ScheduledTask`]s, running each one whenever its cron schedule says so.
///
/// Every due task is executed on its own spawned tokio task. A slow task therefore does not
/// hold up the other tasks or the scheduling loop itself.
///
/// # Example
///
/// ```
/// use std::sync::{
///     atomic::{AtomicU32, Ordering},
///     Arc,
/// };
///
/// use anyhow::Result;
/// use async_trait::async_trait;
/// use restrepo::tasks::{ScheduledTask, Task, TaskScheduler};
///
/// struct CounterTask(Arc<AtomicU32>);
///
/// #[async_trait]
/// impl Task for CounterTask {
///     async fn execute(&self) -> Result<()> {
///         self.0.fetch_add(1, Ordering::SeqCst);
///         Ok(())
///     }
/// }
///
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let counter = Arc::new(AtomicU32::new(0));
///
/// // Fire every 5 seconds (fields: sec min hour day-of-month month day-of-week year).
/// let scheduled = ScheduledTask::new(
///     "counter",
///     CounterTask(Arc::clone(&counter)),
///     "*/5 * * * * * *",
/// )?;
///
/// let scheduler = TaskScheduler::default().with_task(scheduled);
///
/// // In an application the scheduler would be driven in the background:
/// // tokio::spawn(async move { scheduler.run().await });
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Default, Clone)]
pub struct TaskScheduler {
    tasks: Vec<ScheduledTask>,
}

impl TaskScheduler {
    pub fn with_task(mut self, task: ScheduledTask) -> Self {
        self.tasks.push(task);
        self
    }

    pub async fn run(self) -> anyhow::Result<()> {
        let mut dq =
            self.tasks
                .iter()
                .enumerate()
                .try_fold(DelayQueue::new(), |mut acc, (idx, task)| {
                    if let Some(delay) = Self::get_duration_until_next(&task.schedule)? {
                        acc.insert(idx, delay);
                    }
                    Ok::<_, anyhow::Error>(acc)
                })?;

        while let Some(expired) = dq.next().await {
            if let Some(scheduled) = self.tasks.get(*expired.get_ref()) {
                let name = scheduled.name.clone();
                info!("Executing task: {name}");
                let task = scheduled.task.clone();
                tokio::task::spawn(async move {
                    if let Err(e) = task.execute().await {
                        error!("Error while executing task `{name}`: {e}");
                    }
                });

                if let Some(delay) = Self::get_duration_until_next(&scheduled.schedule)? {
                    dq.insert(*expired.get_ref(), delay);
                }
            }
        }

        Ok(())
    }

    fn get_duration_until_next(schedule: &Schedule) -> anyhow::Result<Option<Duration>> {
        if let Some(next) = schedule.upcoming(chrono::Local).next() {
            let delay = next
                .signed_duration_since(chrono::Local::now())
                .to_std()
                .context("Failed to convert chrono::Duration to std::time::Duration")?;
            Ok(Some(delay))
        } else {
            Ok(None)
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Test double that records one entry per execution.
    struct TestTask(Arc<Mutex<Vec<u32>>>);

    #[async_trait]
    impl Task for TestTask {
        async fn execute(&self) -> anyhow::Result<()> {
            self.0.lock().unwrap().push(1);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_task_scheduler() {
        let shared_vec = Arc::new(Mutex::new(Vec::new()));
        let scheduled_task =
            ScheduledTask::new("Test Task", TestTask(shared_vec.clone()), "*/1 * * * * * *")
                .unwrap();
        let scheduler = TaskScheduler::default().with_task(scheduled_task);
        let handle = tokio::spawn(async move {
            scheduler.run().await.unwrap();
        });
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        // Stop the scheduler loop so it does not outlive the check.
        handle.abort();
        assert!(shared_vec.lock().unwrap().iter().sum::<u32>() >= 4);
    }

    #[tokio::test]
    async fn test_empty_scheduler_finishes() {
        // With nothing queued the scheduler has no future occurrences and must return.
        let result =
            tokio::time::timeout(Duration::from_secs(1), TaskScheduler::default().run()).await;
        assert!(matches!(result, Ok(Ok(()))));
    }

    #[test]
    fn test_invalid_cron_expression_is_rejected() {
        let shared_vec = Arc::new(Mutex::new(Vec::new()));
        assert!(
            ScheduledTask::new("Broken", TestTask(shared_vec), "not a cron expression").is_err()
        );
    }
}