use anyhow::{Context, Result};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, Instant};
use tracing::{debug, error, info, warn};
use crate::utils::parse_duration;
/// Identifier assigned to each task registered with a [`Scheduler`].
pub type TaskId = u64;

/// Pinned, boxed, `Send` future produced by a task callback; resolves to the
/// outcome of one task run.
pub type TaskFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>;

/// Shared, thread-safe factory that is invoked once per run to produce a fresh
/// [`TaskFuture`].
pub type TaskCallback = Arc<dyn Fn() -> TaskFuture + Send + Sync>;
/// A task registered with a [`Scheduler`], as returned by `list_tasks` / `get_task`.
#[derive(Clone)]
pub struct ScheduledTask {
    /// Id assigned by the scheduler when the task was registered.
    pub id: TaskId,
    /// Human-readable name used in log messages and execution history.
    pub name: String,
    /// Period between runs of a recurring task, or the delay before a one-shot
    /// task fires.
    pub interval: Duration,
    /// Invoked once per run to produce the future that performs the work.
    pub callback: TaskCallback,
    /// `true` for tasks that re-arm after each run, `false` for one-shot tasks.
    pub recurring: bool,
    /// Disabled tasks stay registered but are skipped by the scheduler loop.
    pub enabled: bool,
}

// `callback` is a `dyn Fn` trait object with no `Debug` impl, so `Debug` cannot be
// derived; this impl prints every other field and marks the output non-exhaustive.
impl std::fmt::Debug for ScheduledTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScheduledTask")
            .field("id", &self.id)
            .field("name", &self.name)
            .field("interval", &self.interval)
            .field("recurring", &self.recurring)
            .field("enabled", &self.enabled)
            .finish_non_exhaustive()
    }
}
/// Record of a single task run, as stored in the scheduler's execution history.
#[derive(Clone, Debug)]
pub struct TaskExecution {
    /// Id of the task that ran.
    pub task_id: TaskId,
    /// Name of the task at the time the run was dispatched.
    pub task_name: String,
    /// Wall-clock (UTC) time at which the run was dispatched.
    pub started_at: chrono::DateTime<chrono::Utc>,
    /// Wall-clock (UTC) time at which the callback's future resolved; `None` until then.
    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
    /// `true` when the callback's future resolved to `Ok`.
    pub success: bool,
    /// Error message of a run whose callback resolved to `Err`.
    pub error: Option<String>,
}
pub struct Scheduler {
tasks: Arc<RwLock<Vec<ScheduledTask>>>,
next_task_id: Arc<Mutex<TaskId>>,
is_running: Arc<RwLock<bool>>,
is_paused: Arc<RwLock<bool>>,
history: Arc<RwLock<Vec<TaskExecution>>>,
history_limit: usize,
}
impl Scheduler {
pub fn new() -> Self {
Self::with_history_limit(100)
}
pub fn with_history_limit(limit: usize) -> Self {
Self {
tasks: Arc::new(RwLock::new(Vec::new())),
next_task_id: Arc::new(Mutex::new(1)),
is_running: Arc::new(RwLock::new(false)),
is_paused: Arc::new(RwLock::new(false)),
history: Arc::new(RwLock::new(Vec::new())),
history_limit: limit,
}
}
pub async fn add_recurring_task<F, Fut>(
&self,
name: impl Into<String>,
interval_str: &str,
callback: F,
) -> Result<TaskId>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<()>> + Send + 'static,
{
let interval_secs = parse_duration(interval_str)
.context("Failed to parse interval duration")?;
let interval = Duration::from_secs(interval_secs);
let name = name.into();
let task_id = {
let mut next_id = self.next_task_id.lock().await;
let id = *next_id;
*next_id += 1;
id
};
let callback_arc: TaskCallback = Arc::new(move || Box::pin(callback()));
let task = ScheduledTask {
id: task_id,
name: name.clone(),
interval,
callback: callback_arc,
recurring: true,
enabled: true,
};
self.tasks.write().await.push(task);
info!(
"Added recurring task '{}' (ID: {}) with interval: {}",
name, task_id, interval_str
);
Ok(task_id)
}
pub async fn add_oneshot_task<F, Fut>(
&self,
name: impl Into<String>,
delay_str: &str,
callback: F,
) -> Result<TaskId>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<()>> + Send + 'static,
{
let delay_secs = parse_duration(delay_str)
.context("Failed to parse delay duration")?;
let delay = Duration::from_secs(delay_secs);
let name = name.into();
let task_id = {
let mut next_id = self.next_task_id.lock().await;
let id = *next_id;
*next_id += 1;
id
};
let callback_arc: TaskCallback = Arc::new(move || Box::pin(callback()));
let task = ScheduledTask {
id: task_id,
name: name.clone(),
interval: delay,
callback: callback_arc,
recurring: false,
enabled: true,
};
self.tasks.write().await.push(task);
info!(
"Added one-shot task '{}' (ID: {}) with delay: {}",
name, task_id, delay_str
);
Ok(task_id)
}
pub async fn start(&self) -> Result<()> {
let is_running = *self.is_running.read().await;
if is_running {
warn!("Scheduler is already running");
return Ok(());
}
*self.is_running.write().await = true;
info!("Starting scheduler");
let tasks = Arc::clone(&self.tasks);
let is_running = Arc::clone(&self.is_running);
let is_paused = Arc::clone(&self.is_paused);
let history = Arc::clone(&self.history);
let history_limit = self.history_limit;
tokio::spawn(async move {
Self::run_scheduler_loop(tasks, is_running, is_paused, history, history_limit).await;
});
Ok(())
}
async fn run_scheduler_loop(
tasks: Arc<RwLock<Vec<ScheduledTask>>>,
is_running: Arc<RwLock<bool>>,
is_paused: Arc<RwLock<bool>>,
history: Arc<RwLock<Vec<TaskExecution>>>,
history_limit: usize,
) {
let mut task_timers: std::collections::HashMap<TaskId, Instant> =
std::collections::HashMap::new();
loop {
if !*is_running.read().await {
info!("Scheduler stopped");
break;
}
if *is_paused.read().await {
sleep(Duration::from_millis(500)).await;
continue;
}
let now = Instant::now();
let tasks_snapshot = tasks.read().await.clone();
for task in tasks_snapshot.iter() {
if !task.enabled {
continue;
}
let next_run = task_timers
.entry(task.id)
.or_insert_with(|| now + task.interval);
if now >= *next_run {
debug!("Executing task: {} (ID: {})", task.name, task.id);
let execution = TaskExecution {
task_id: task.id,
task_name: task.name.clone(),
started_at: chrono::Utc::now(),
completed_at: None,
success: false,
error: None,
};
let callback = Arc::clone(&task.callback);
let task_name = task.name.clone();
let task_id = task.id;
let is_recurring = task.recurring;
let interval = task.interval;
let history_clone = Arc::clone(&history);
let tasks_clone = Arc::clone(&tasks);
tokio::spawn(async move {
let result = (callback)().await;
let mut exec = execution;
exec.completed_at = Some(chrono::Utc::now());
match result {
Ok(_) => {
info!("Task '{}' (ID: {}) completed successfully", task_name, task_id);
exec.success = true;
}
Err(e) => {
error!("Task '{}' (ID: {}) failed: {}", task_name, task_id, e);
exec.error = Some(e.to_string());
}
}
let mut hist = history_clone.write().await;
hist.push(exec);
let hist_len = hist.len();
if hist_len > history_limit {
hist.drain(0..hist_len - history_limit);
}
if !is_recurring {
let mut tasks_write = tasks_clone.write().await;
tasks_write.retain(|t| t.id != task_id);
debug!("Removed one-shot task '{}' (ID: {}) after execution", task_name, task_id);
}
});
if is_recurring {
*next_run = now + interval;
} else {
task_timers.remove(&task.id);
}
}
}
sleep(Duration::from_millis(100)).await;
}
}
pub async fn stop(&self) {
info!("Stopping scheduler");
*self.is_running.write().await = false;
}
pub async fn pause(&self) {
info!("Pausing scheduler");
*self.is_paused.write().await = true;
}
pub async fn resume(&self) {
info!("Resuming scheduler");
*self.is_paused.write().await = false;
}
pub async fn is_running(&self) -> bool {
*self.is_running.read().await
}
pub async fn is_paused(&self) -> bool {
*self.is_paused.read().await
}
pub async fn enable_task(&self, task_id: TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
if let Some(task) = tasks.iter_mut().find(|t| t.id == task_id) {
task.enabled = true;
info!("Enabled task '{}' (ID: {})", task.name, task_id);
Ok(())
} else {
Err(anyhow::anyhow!("Task {} not found", task_id))
}
}
pub async fn disable_task(&self, task_id: TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
if let Some(task) = tasks.iter_mut().find(|t| t.id == task_id) {
task.enabled = false;
info!("Disabled task '{}' (ID: {})", task.name, task_id);
Ok(())
} else {
Err(anyhow::anyhow!("Task {} not found", task_id))
}
}
pub async fn remove_task(&self, task_id: TaskId) -> Result<()> {
let mut tasks = self.tasks.write().await;
let initial_len = tasks.len();
tasks.retain(|t| t.id != task_id);
if tasks.len() < initial_len {
info!("Removed task ID: {}", task_id);
Ok(())
} else {
Err(anyhow::anyhow!("Task {} not found", task_id))
}
}
pub async fn list_tasks(&self) -> Vec<ScheduledTask> {
self.tasks.read().await.clone()
}
pub async fn get_task(&self, task_id: TaskId) -> Option<ScheduledTask> {
self.tasks
.read()
.await
.iter()
.find(|t| t.id == task_id)
.cloned()
}
pub async fn get_history(&self) -> Vec<TaskExecution> {
self.history.read().await.clone()
}
pub async fn get_task_history(&self, task_id: TaskId) -> Vec<TaskExecution> {
self.history
.read()
.await
.iter()
.filter(|e| e.task_id == task_id)
.cloned()
.collect()
}
pub async fn clear_history(&self) {
self.history.write().await.clear();
info!("Cleared execution history");
}
pub async fn active_task_count(&self) -> usize {
self.tasks.read().await.iter().filter(|t| t.enabled).count()
}
pub async fn total_task_count(&self) -> usize {
self.tasks.read().await.len()
}
}
impl Default for Scheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Registers a recurring no-op task named "test" with a one-second interval.
    async fn add_noop_recurring(sched: &Scheduler) -> TaskId {
        sched
            .add_recurring_task("test", "1s", || async { Ok(()) })
            .await
            .expect("registering a no-op recurring task should succeed")
    }

    #[tokio::test]
    async fn test_scheduler_creation() {
        let sched = Scheduler::new();
        assert!(!sched.is_running().await);
        assert!(!sched.is_paused().await);
    }

    #[tokio::test]
    async fn test_add_recurring_task() {
        let sched = Scheduler::new();
        let hits = Arc::new(AtomicUsize::new(0));
        let task_hits = Arc::clone(&hits);
        let id = sched
            .add_recurring_task("test_task", "1s", move || {
                let hits = Arc::clone(&task_hits);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await
            .expect("\"1s\" is a valid interval");
        // The first id handed out is 1, and the task is registered immediately.
        assert_eq!(id, 1);
        assert_eq!(sched.total_task_count().await, 1);
    }

    #[tokio::test]
    async fn test_add_oneshot_task() {
        let sched = Scheduler::new();
        let ran = Arc::new(AtomicUsize::new(0));
        let task_ran = Arc::clone(&ran);
        let id = sched
            .add_oneshot_task("oneshot_task", "1s", move || {
                let ran = Arc::clone(&task_ran);
                async move {
                    ran.store(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await
            .expect("\"1s\" is a valid delay");
        assert_eq!(id, 1);
        assert_eq!(sched.total_task_count().await, 1);
    }

    #[tokio::test]
    async fn test_pause_resume() {
        let sched = Scheduler::new();
        sched.start().await.expect("start should succeed");
        assert!(sched.is_running().await);
        assert!(!sched.is_paused().await);

        sched.pause().await;
        assert!(sched.is_paused().await);

        sched.resume().await;
        assert!(!sched.is_paused().await);

        sched.stop().await;
    }

    #[tokio::test]
    async fn test_enable_disable_task() {
        let sched = Scheduler::new();
        let id = add_noop_recurring(&sched).await;

        sched.disable_task(id).await.expect("task is registered");
        let disabled = sched.get_task(id).await.expect("task is registered");
        assert!(!disabled.enabled);

        sched.enable_task(id).await.expect("task is registered");
        let enabled = sched.get_task(id).await.expect("task is registered");
        assert!(enabled.enabled);
    }

    #[tokio::test]
    async fn test_remove_task() {
        let sched = Scheduler::new();
        let id = add_noop_recurring(&sched).await;
        assert_eq!(sched.total_task_count().await, 1);

        sched.remove_task(id).await.expect("task is registered");
        assert_eq!(sched.total_task_count().await, 0);
    }

    #[tokio::test]
    async fn test_task_execution() {
        let sched = Scheduler::new();
        let runs = Arc::new(AtomicUsize::new(0));
        let task_runs = Arc::clone(&runs);
        sched
            .add_oneshot_task("test", "1s", move || {
                let runs = Arc::clone(&task_runs);
                async move {
                    runs.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await
            .expect("\"1s\" is a valid delay");

        sched.start().await.expect("start should succeed");
        // The one-shot is due ~1s after the loop first sees it; 2s leaves slack.
        tokio::time::sleep(Duration::from_secs(2)).await;
        assert_eq!(runs.load(Ordering::SeqCst), 1);
        sched.stop().await;
    }

    #[tokio::test]
    async fn test_history_tracking() {
        let sched = Scheduler::new();
        sched
            .add_oneshot_task("test", "1s", || async { Ok(()) })
            .await
            .expect("\"1s\" is a valid delay");

        sched.start().await.expect("start should succeed");
        tokio::time::sleep(Duration::from_secs(2)).await;
        let recorded = sched.get_history().await;
        assert!(!recorded.is_empty());
        sched.stop().await;
    }
}