use crate::TaskExecutor;
use chrono::{DateTime, Utc};
use cron::Schedule as CronParser;
use std::str::FromStr;
use std::sync::Arc;
/// A [`Schedule`] driven by a cron expression.
///
/// The expression is stored as text and parsed each time the next run time is
/// requested, so an invalid expression is not detected at construction.
#[derive(Clone, Debug)]
pub struct CronSchedule {
    /// The cron expression, in the syntax accepted by the `cron` crate.
    pub expression: String,
}
impl CronSchedule {
pub fn new(expression: String) -> Self {
Self { expression }
}
pub fn next_run(&self) -> Option<DateTime<Utc>> {
let schedule = CronParser::from_str(&self.expression).ok()?;
schedule.upcoming(Utc).next()
}
}
/// Source of run times for a task registered with a [`Scheduler`].
///
/// The scheduler stores implementations as `Box<dyn Schedule>`; they must be
/// `Send + Sync`.
pub trait Schedule: Sync + Send {
    /// Returns the next time the task should run, or `None` when there is
    /// currently no upcoming run. A time at or before "now" is treated as due.
    fn next_run(&self) -> Option<DateTime<Utc>>;
}
/// Lets a [`CronSchedule`] be registered with a [`Scheduler`].
impl Schedule for CronSchedule {
    fn next_run(&self) -> Option<DateTime<Utc>> {
        // Path call to the inherent method of the same name; spelled out so it
        // is obvious this delegates rather than recursing into the trait method.
        CronSchedule::next_run(self)
    }
}
/// Runs registered [`TaskExecutor`]s at the times reported by their [`Schedule`]s.
///
/// Tasks are registered with [`Scheduler::add_task`] and driven by
/// [`Scheduler::run`] until [`Scheduler::shutdown`] is called.
pub struct Scheduler {
    /// Registered tasks, each paired with the schedule that decides when it runs.
    tasks: Vec<(Arc<dyn TaskExecutor>, Box<dyn Schedule>)>,
    /// Wakes a `run` loop that is currently sleeping when shutdown is requested.
    shutdown_tx: tokio::sync::broadcast::Sender<()>,
    /// Sticky shutdown flag. A broadcast is only delivered to receivers that
    /// already exist, so this flag lets a `run` call that subscribes after
    /// `shutdown()` still observe the request.
    shutdown_requested: std::sync::atomic::AtomicBool,
}

impl Scheduler {
    /// Lower bound on the pause between loop iterations; schedules that keep
    /// reporting a time in the past would otherwise make `run` spin.
    const MIN_SLEEP: std::time::Duration = std::time::Duration::from_millis(100);

    /// Pause used when no schedule reports an upcoming run; afterwards the
    /// loop asks those schedules again.
    const IDLE_SLEEP: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a scheduler with no tasks.
    pub fn new() -> Self {
        // Capacity 1 suffices: the channel only carries the shutdown signal, and
        // a receiver that lagged behind repeated `shutdown` calls still wakes up.
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
        Self {
            tasks: Vec::new(),
            shutdown_tx,
            shutdown_requested: std::sync::atomic::AtomicBool::new(false),
        }
    }

    /// Registers `task` to be executed whenever `schedule` reports it is due.
    pub fn add_task(&mut self, task: Arc<dyn TaskExecutor>, schedule: Box<dyn Schedule>) {
        self.tasks.push((task, schedule));
    }

    /// Asks [`Scheduler::run`] to stop.
    ///
    /// The request is sticky: it takes effect even if `run` has not started
    /// (or not yet been polled), and any later `run` call returns immediately.
    /// Executions that were already spawned are neither cancelled nor awaited.
    pub fn shutdown(&self) {
        use std::sync::atomic::Ordering;

        // Set the flag before broadcasting. `run` subscribes before reading the
        // flag, so every `run` either sees the flag or receives the message.
        self.shutdown_requested.store(true, Ordering::SeqCst);
        // `send` only fails when no `run` is subscribed; the flag covers that.
        let _ = self.shutdown_tx.send(());
    }

    /// Drives the registered tasks until [`Scheduler::shutdown`] is called.
    ///
    /// Each task's pending run time is remembered between iterations. When it
    /// is reached, the task is spawned with `tokio::spawn` and its schedule is
    /// asked for the following run time. The spawned task is not awaited, so a
    /// slow task may overlap its next execution. Remembering the pending time
    /// matters for schedules such as [`CronSchedule`], which always answer with
    /// a time strictly after "now". If the schedule were re-asked after waking
    /// up, the run that was just reached would be skipped.
    ///
    /// Schedules that report `None` are asked again on every iteration.
    /// Errors returned by a task are logged, not propagated.
    ///
    /// # Panics
    ///
    /// Panics if polled outside a Tokio runtime (`tokio::spawn` and
    /// `tokio::time::sleep` require one).
    pub async fn run(&self) {
        use std::sync::atomic::Ordering;
        use tokio::time::sleep;

        // Subscribe first, then check the flag (see `shutdown` for why).
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        if self.shutdown_requested.load(Ordering::SeqCst) {
            return;
        }

        // Pending run time per task, index-aligned with `self.tasks`.
        let mut pending = vec![None; self.tasks.len()];

        loop {
            let now = Utc::now();

            for ((task, schedule), slot) in self.tasks.iter().zip(pending.iter_mut()) {
                if slot.is_none() {
                    *slot = schedule.next_run();
                }
                if matches!(*slot, Some(due) if due <= now) {
                    Self::spawn_execution(task);
                    *slot = schedule.next_run();
                }
            }

            let sleep_duration = match pending.iter().flatten().min() {
                // A time already in the past makes `to_std` fail; fall back to,
                // and never go below, MIN_SLEEP.
                Some(&next) => (next - now)
                    .to_std()
                    .unwrap_or(Self::MIN_SLEEP)
                    .max(Self::MIN_SLEEP),
                None => Self::IDLE_SLEEP,
            };

            tokio::select! {
                _ = sleep(sleep_duration) => {}
                _ = shutdown_rx.recv() => break,
            }
        }
    }

    /// Spawns one execution of `task` onto the Tokio runtime without awaiting
    /// it, logging any error the execution returns.
    fn spawn_execution(task: &Arc<dyn TaskExecutor>) {
        let task = Arc::clone(task);
        tokio::spawn(async move {
            if let Err(e) = task.execute().await {
                tracing::error!(error = %e, "Task execution failed");
            }
        });
    }
}
impl Default for Scheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TaskId, TaskResult};
    use async_trait::async_trait;
    use rstest::rstest;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::Duration;

    /// How long a test waits for `run` to return after `shutdown()`.
    const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(2);

    /// Schedule that is permanently due: it always reports a run time one hour
    /// in the past.
    struct PastSchedule;

    impl Schedule for PastSchedule {
        fn next_run(&self) -> Option<DateTime<Utc>> {
            Some(Utc::now() - chrono::Duration::hours(1))
        }
    }

    /// Task that increments a shared counter each time it is executed.
    struct CountingTask {
        id: TaskId,
        count: Arc<AtomicU64>,
    }

    impl std::fmt::Debug for CountingTask {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("CountingTask")
                .field("id", &self.id)
                .finish()
        }
    }

    impl crate::Task for CountingTask {
        fn id(&self) -> TaskId {
            self.id
        }

        fn name(&self) -> &str {
            "counting"
        }
    }

    #[async_trait]
    impl TaskExecutor for CountingTask {
        async fn execute(&self) -> TaskResult<()> {
            self.count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Runs `scheduler.run()` on a spawned task for `run_for`, then requests
    /// shutdown and asserts that `run` returns within [`SHUTDOWN_TIMEOUT`]
    /// without panicking.
    async fn run_then_shutdown(scheduler: Scheduler, run_for: Duration) {
        let scheduler = Arc::new(scheduler);
        let runner = Arc::clone(&scheduler);
        let handle = tokio::spawn(async move {
            runner.run().await;
        });
        tokio::time::sleep(run_for).await;
        scheduler.shutdown();
        tokio::time::timeout(SHUTDOWN_TIMEOUT, handle)
            .await
            .expect("scheduler should shut down within timeout")
            .expect("scheduler task should not panic");
    }

    /// Runs a scheduler holding a single permanently-due [`CountingTask`] for
    /// `run_for` and returns how many times the task executed.
    async fn executions_of_always_due_task(run_for: Duration) -> u64 {
        let count = Arc::new(AtomicU64::new(0));
        let task = Arc::new(CountingTask {
            id: TaskId::new(),
            count: Arc::clone(&count),
        });
        let mut scheduler = Scheduler::new();
        scheduler.add_task(task, Box::new(PastSchedule));
        run_then_shutdown(scheduler, run_for).await;
        count.load(Ordering::SeqCst)
    }

    #[rstest]
    #[tokio::test]
    async fn test_scheduler_shutdown() {
        run_then_shutdown(Scheduler::new(), Duration::from_millis(10)).await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_scheduler_does_not_busy_loop_when_next_run_is_in_the_past() {
        let execution_count = executions_of_always_due_task(Duration::from_millis(500)).await;
        assert!(
            execution_count >= 1,
            "A permanently due task should run at least once, got {}",
            execution_count
        );
        assert!(
            execution_count <= 10,
            "Expected at most ~10 executions in 500ms with min sleep guard, got {}",
            execution_count
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_min_sleep_enforced_prevents_busy_loop_regression() {
        let execution_count = executions_of_always_due_task(Duration::from_millis(300)).await;
        assert!(
            execution_count <= 6,
            "Regression #754: busy-loop guard must cap executions at ~3 in 300ms, got {}",
            execution_count
        );
    }
}