use std::{fmt::Debug, sync::Arc, time::Duration};
use anyhow::Context;
use async_trait::async_trait;
use cron::Schedule;
use futures::StreamExt;
use tokio_util::time::DelayQueue;
use tracing::{error, info};
/// A unit of work that the scheduler can run.
///
/// Implementors must be `Send + Sync + 'static`: the scheduler stores them as
/// `Arc<dyn Task>` and executes them on tasks spawned with `tokio::task::spawn`.
#[async_trait]
pub trait Task: Send + Sync + 'static {
    /// Runs the task once.
    ///
    /// An `Err` returned here is logged by the scheduler; it does not stop
    /// subsequent executions.
    async fn execute(&self) -> anyhow::Result<()>;
}
/// A [`Task`] paired with a human-readable name and the cron [`Schedule`]
/// that decides when it runs.
///
/// The task object is shared through an [`Arc`], so clones of a
/// `ScheduledTask` refer to the same underlying task.
#[derive(Clone)]
pub struct ScheduledTask {
    /// Name used when logging executions and failures.
    name: String,
    /// Cron schedule that determines the execution times.
    schedule: Schedule,
    /// The work to perform on each occurrence.
    task: Arc<dyn Task>,
}
impl Debug for ScheduledTask {
    /// Formats the task's schedule and name.
    ///
    /// The wrapped `dyn Task` has no `Debug` bound, so it is elided and the
    /// output is marked non-exhaustive (rendered with a trailing `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScheduledTask")
            .field("schedule", &self.schedule)
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}
impl ScheduledTask {
pub fn new<S: Into<String>>(name: S, task: impl Task, schedule: S) -> anyhow::Result<Self> {
Ok(Self {
name: name.into(),
task: Arc::new(task),
schedule: schedule.into().parse()?,
})
}
}
/// Runs a collection of [`ScheduledTask`]s according to their cron schedules.
///
/// Build one with [`TaskScheduler::default`] and [`TaskScheduler::with_task`],
/// then drive it with [`TaskScheduler::run`].
#[derive(Debug, Default, Clone)]
pub struct TaskScheduler {
    /// Registered tasks; a task's position in this list identifies it while
    /// the scheduler is running.
    tasks: Vec<ScheduledTask>,
}
impl TaskScheduler {
pub fn with_task(mut self, task: ScheduledTask) -> Self {
self.tasks.push(task);
self
}
pub async fn run(self) -> anyhow::Result<()> {
let mut dq =
self.tasks
.iter()
.enumerate()
.try_fold(DelayQueue::new(), |mut acc, (idx, task)| {
if let Some(delay) = Self::get_duration_until_next(&task.schedule)? {
acc.insert(idx, delay);
}
Ok::<_, anyhow::Error>(acc)
})?;
while let Some(expired) = dq.next().await {
if let Some(scheduled) = self.tasks.get(*expired.get_ref()) {
let name = scheduled.name.clone();
info!("Executing task: {name}");
let task = scheduled.task.clone();
tokio::task::spawn(async move {
if let Err(e) = task.execute().await {
error!("Error while executing task `{name}`: {e}");
}
});
if let Some(delay) = Self::get_duration_until_next(&scheduled.schedule)? {
dq.insert(*expired.get_ref(), delay);
}
}
}
Ok(())
}
fn get_duration_until_next(schedule: &Schedule) -> anyhow::Result<Option<Duration>> {
if let Some(next) = schedule.upcoming(chrono::Local).next() {
let delay = next
.signed_duration_since(chrono::Local::now())
.to_std()
.context("Failed to convert chrono::Duration to std::time::Duration")?;
Ok(Some(delay))
} else {
Ok(None)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test task that appends one entry to the shared log each time it runs.
    struct TestTask(Arc<Mutex<Vec<u32>>>);

    #[async_trait]
    impl Task for TestTask {
        async fn execute(&self) -> anyhow::Result<()> {
            self.0.lock().unwrap().push(1);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_task_scheduler() {
        // Fires every second; within a five-second window at least four runs
        // must have been recorded.
        const EVERY_SECOND: &str = "*/1 * * * * * *";

        let runs: Arc<Mutex<Vec<u32>>> = Arc::default();
        let task = ScheduledTask::new("Test Task", TestTask(Arc::clone(&runs)), EVERY_SECOND)
            .unwrap();
        let scheduler = TaskScheduler::default().with_task(task);

        tokio::spawn(async move { scheduler.run().await.unwrap() });
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        // Every entry is 1, so the entry count equals the sum of the log.
        let executions = runs.lock().unwrap().len();
        assert!(executions >= 4);
    }
}