use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio::time::{interval, MissedTickBehavior};
/// Boxed, `Send` future produced by a single invocation of a job.
type JobFuture = Pin<Box<dyn Future<Output = ()> + Send>>;

/// Shared, thread-safe factory that builds a fresh [`JobFuture`] for each run
/// of a scheduled job.
type JobFactory = Arc<dyn Fn() -> JobFuture + Send + Sync>;
/// A periodic job registered via `Scheduler::every`, waiting to be started.
struct Task {
    /// Label attached to log output for this job.
    name: String,
    /// Time between successive runs; the first run also waits one full period.
    period: Duration,
    /// Produces a new future for every run.
    factory: JobFactory,
}
/// Collects periodic jobs and, once started, drives each one from its own
/// Tokio task.
///
/// Register jobs with [`Scheduler::every`], then call [`Scheduler::start`].
pub struct Scheduler {
    /// Jobs registered so far; kept behind a `Mutex` so `every` can take `&self`.
    tasks: Mutex<Vec<Task>>,
}
impl Default for Scheduler {
fn default() -> Self {
Self::new()
}
}
impl Scheduler {
#[must_use]
pub fn new() -> Self {
Self {
tasks: Mutex::new(Vec::new()),
}
}
pub fn every<F, Fut>(&self, name: &str, period: Duration, job: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let factory: JobFactory = Arc::new(move || Box::pin(job()));
self.tasks
.lock()
.expect("scheduler tasks poisoned")
.push(Task {
name: name.to_owned(),
period,
factory,
});
}
#[must_use]
pub fn task_count(&self) -> usize {
self.tasks.lock().expect("scheduler tasks poisoned").len()
}
pub fn start(self) -> Handle {
let tasks = self.tasks.into_inner().expect("scheduler tasks poisoned");
let mut handles = Vec::with_capacity(tasks.len());
for t in tasks {
handles.push(spawn_task_loop(t));
}
Handle { handles }
}
}
fn spawn_task_loop(task: Task) -> JoinHandle<()> {
tokio::spawn(async move {
let mut tick = interval(task.period);
tick.set_missed_tick_behavior(MissedTickBehavior::Skip);
tick.tick().await;
loop {
tick.tick().await;
let factory = task.factory.clone();
let name = task.name.clone();
let job_handle = tokio::spawn(async move {
let fut = (factory)();
fut.await;
});
if let Err(e) = job_handle.await {
if e.is_panic() {
tracing::error!(task = %name, "scheduled job panicked");
}
}
}
})
}
/// Owns the background loops started by `Scheduler::start`.
///
/// Dropping the handle aborts every loop without waiting. Call
/// [`Handle::shutdown`] to abort them and wait until they have stopped.
#[must_use = "dropping the Handle immediately aborts every scheduled task"]
pub struct Handle {
    /// One join handle per task loop, in registration order.
    handles: Vec<JoinHandle<()>>,
}
impl Handle {
#[must_use]
pub fn running_count(&self) -> usize {
self.handles.len()
}
pub async fn shutdown(mut self) {
let handles = std::mem::take(&mut self.handles);
for h in handles {
h.abort();
let _ = h.await; }
}
}
impl Drop for Handle {
    /// Requests cancellation of every loop this handle still owns.
    ///
    /// `Drop` cannot await, so this only aborts and does not wait for the loops
    /// to stop. After [`Handle::shutdown`] has taken the handles, the list is
    /// empty and this does nothing.
    fn drop(&mut self) {
        let Self { handles } = self;
        for loop_task in handles {
            loop_task.abort();
        }
    }
}
// Tests for `Scheduler` and `Handle`. The timing-based tests run on the real
// Tokio clock (plain `#[tokio::test]`, no paused time). Their assertions use
// lower bounds below the nominal number of runs to leave room for scheduling
// jitter. They may still be sensitive to a heavily loaded machine.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Registering jobs (without starting) is reflected by `task_count`.
#[tokio::test]
async fn task_count_tracks_registrations() {
let s = Scheduler::new();
s.every("a", Duration::from_secs(1), || async {});
s.every("b", Duration::from_secs(1), || async {});
s.every("c", Duration::from_secs(1), || async {});
assert_eq!(s.task_count(), 3);
}
/// A 20 ms job observed for 70 ms should have run at least twice.
#[tokio::test]
async fn job_fires_after_one_period() {
let counter = Arc::new(AtomicUsize::new(0));
let s = Scheduler::new();
let c = counter.clone();
s.every("count", Duration::from_millis(20), move || {
// Each run gets its own clone so the returned future is `'static`.
let c = c.clone();
async move {
c.fetch_add(1, Ordering::SeqCst);
}
});
let handle = s.start();
// The first run happens one period after start; roughly 3 runs fit in 70 ms,
// and the assertion only requires 2.
tokio::time::sleep(Duration::from_millis(70)).await;
let count = counter.load(Ordering::SeqCst);
assert!(count >= 2, "expected at least 2 fires, got {count}");
handle.shutdown().await;
}
/// Once `shutdown` has returned, the counter must not advance any further.
#[tokio::test]
async fn shutdown_stops_further_fires() {
let counter = Arc::new(AtomicUsize::new(0));
let s = Scheduler::new();
let c = counter.clone();
s.every("stop", Duration::from_millis(15), move || {
let c = c.clone();
async move {
c.fetch_add(1, Ordering::SeqCst);
}
});
let handle = s.start();
tokio::time::sleep(Duration::from_millis(50)).await;
handle.shutdown().await;
// Snapshot after shutdown, then wait several more periods: the count must be unchanged.
let after_shutdown = counter.load(Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(50)).await;
assert_eq!(counter.load(Ordering::SeqCst), after_shutdown);
}
/// The first run panics; later runs must still execute.
#[tokio::test]
async fn panicking_job_does_not_kill_loop() {
let counter = Arc::new(AtomicUsize::new(0));
let s = Scheduler::new();
let c = counter.clone();
// Counts invocations so that only the very first one panics.
let panic_on_first = Arc::new(AtomicUsize::new(0));
let p = panic_on_first.clone();
s.every("flaky", Duration::from_millis(20), move || {
let c = c.clone();
let p = p.clone();
async move {
let n = p.fetch_add(1, Ordering::SeqCst);
if n == 0 {
// Deliberate; a panic message on stderr from this test is expected.
panic!("simulated job failure");
}
c.fetch_add(1, Ordering::SeqCst);
}
});
let handle = s.start();
tokio::time::sleep(Duration::from_millis(80)).await;
// `counter` only advances on non-panicking runs, so >= 1 proves the loop survived.
let count = counter.load(Ordering::SeqCst);
assert!(
count >= 1,
"loop must keep running after a panic, got count={count}"
);
handle.shutdown().await;
}
/// Two registered jobs each get their own loop and both make progress.
#[tokio::test]
async fn multiple_tasks_run_independently() {
let a_count = Arc::new(AtomicUsize::new(0));
let b_count = Arc::new(AtomicUsize::new(0));
let s = Scheduler::new();
let a = a_count.clone();
let b = b_count.clone();
s.every("a", Duration::from_millis(15), move || {
let a = a.clone();
async move {
a.fetch_add(1, Ordering::SeqCst);
}
});
s.every("b", Duration::from_millis(15), move || {
let b = b.clone();
async move {
b.fetch_add(1, Ordering::SeqCst);
}
});
let handle = s.start();
// One loop per registered job.
assert_eq!(handle.running_count(), 2);
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(a_count.load(Ordering::SeqCst) >= 1);
assert!(b_count.load(Ordering::SeqCst) >= 1);
handle.shutdown().await;
}
}