mod cron_parser;
mod instance_generator;
pub use cron_parser::parse_cron;
pub use instance_generator::InstanceGenerator;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use pollen_clock::SharedClock;
use pollen_crdt::{CrdtKv, LwwRegister};
use pollen_executor::TaskHandler;
use pollen_store::StoreBackend;
use pollen_types::*;
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
/// Task-scheduling interface: registration, enable/disable, manual triggering
/// and lookup of task definitions and their handlers.
///
/// Implementors must be `Send + Sync + 'static`. [`DefaultScheduler`] is the
/// implementation provided by this module.
#[async_trait]
pub trait Scheduler: Send + Sync + 'static {
/// Registers the task definition `def` together with the `handler` that runs it.
///
/// [`DefaultScheduler`] rejects a definition whose name is already registered
/// with `PollenError::TaskAlreadyExists` and validates cron expressions first.
async fn register(&self, def: TaskDef, handler: Arc<dyn TaskHandler>) -> Result<()>;
/// Removes the task identified by `task_id`.
///
/// [`DefaultScheduler`] returns `Ok(())` when the id is unknown.
async fn unregister(&self, task_id: &TaskId) -> Result<()>;
/// Sets the `enabled` flag of the task identified by `task_id`.
///
/// [`DefaultScheduler`] returns `Ok(())` without doing anything when the id
/// is unknown.
async fn set_enabled(&self, task_id: &TaskId, enabled: bool) -> Result<()>;
/// Creates an instance of the task on demand and returns the new instance id.
///
/// `payload` is optional caller-supplied data; what is done with it is up to
/// the implementor.
async fn trigger(&self, task_id: &TaskId, payload: Option<bytes::Bytes>) -> Result<InstanceId>;
/// Returns a copy of the task definition with the given id, if any.
fn get_task(&self, task_id: &TaskId) -> Option<TaskDef>;
/// Returns a copy of the task definition with the given name, if any.
fn get_task_by_name(&self, name: &str) -> Option<TaskDef>;
/// Returns copies of all known task definitions.
fn list_tasks(&self) -> Vec<TaskDef>;
/// Returns the handler registered for `task_id`, if any.
fn get_handler(&self, task_id: &TaskId) -> Option<Arc<dyn TaskHandler>>;
/// Returns the next time `task` is due according to its schedule, or `None`
/// when no further execution can be computed.
fn next_execution(&self, task: &TaskDef) -> Option<DateTime<Utc>>;
}
/// [`Scheduler`] implementation that keeps task definitions in concurrent
/// in-memory maps, persists them through a `StoreBackend`, and optionally
/// mirrors them into a CRDT store.
pub struct DefaultScheduler {
/// Clock whose `now()` value is written into each task's `hlc_timestamp`.
clock: SharedClock,
/// Storage that task definitions and instances are written to and loaded from.
store: Arc<StoreBackend>,
/// Optional CRDT store; when `None`, task changes are not mirrored to it.
crdt: Option<Arc<pollen_crdt::CrdtStore>>,
/// Task definitions keyed by task id.
tasks: DashMap<TaskId, TaskDef>,
/// Index from task name to task id; used for duplicate detection and lookup by name.
names: DashMap<String, TaskId>,
/// Handlers keyed by task id. Only `register` fills this map; `load` does not.
handlers: DashMap<TaskId, Arc<dyn TaskHandler>>,
/// Creates task instances for tasks the polling loop finds to be due.
generator: Arc<InstanceGenerator>,
}
impl DefaultScheduler {
pub fn new(
clock: SharedClock,
store: Arc<StoreBackend>,
crdt: Option<Arc<pollen_crdt::CrdtStore>>,
) -> Self {
Self {
clock: clock.clone(),
store: Arc::clone(&store),
crdt,
tasks: DashMap::new(),
names: DashMap::new(),
handlers: DashMap::new(),
generator: Arc::new(InstanceGenerator::new(store)),
}
}
pub fn start(self: Arc<Self>) {
self.start_with_interval(Duration::from_millis(100));
}
pub fn start_with_interval(self: Arc<Self>, poll_interval: Duration) {
let scheduler = Arc::clone(&self);
tokio::spawn(async move {
let mut interval = tokio::time::interval(poll_interval);
loop {
interval.tick().await;
let now = Utc::now();
for entry in scheduler.tasks.iter() {
let task = entry.value();
if !task.enabled {
continue;
}
if let Some(next) = compute_next_execution(&task.schedule, now) {
if next <= now + chrono::Duration::seconds(5) {
if let Err(e) = scheduler.generator.ensure_instance(task, next).await {
warn!("Failed to generate instance for {}: {}", task.name, e);
}
}
}
}
}
});
info!("Scheduler started");
}
pub async fn load(&self) -> Result<()> {
let tasks = self.store.read(|r| r.list_tasks()).await?;
for task in tasks {
self.tasks.insert(task.id.clone(), task.clone());
self.names.insert(task.name.clone(), task.id.clone());
}
info!("Loaded {} tasks from storage", self.tasks.len());
Ok(())
}
async fn sync_to_crdt(&self, task: &TaskDef) -> Result<()> {
if let Some(crdt) = &self.crdt {
let key = format!("task:{}", task.id);
let register = LwwRegister::new(task.clone(), task.hlc_timestamp);
crdt.set(&key, register).await?;
}
Ok(())
}
}
#[async_trait]
impl Scheduler for DefaultScheduler {
async fn register(&self, mut def: TaskDef, handler: Arc<dyn TaskHandler>) -> Result<()> {
if self.names.contains_key(&def.name) {
return Err(PollenError::TaskAlreadyExists(def.name.clone()));
}
if let Schedule::Cron(ref expr) = def.schedule {
parse_cron(expr)?;
}
let ts = self.clock.now();
def.hlc_timestamp = ts.as_u128() as u64;
def.updated_at = Utc::now();
let def_clone = def.clone();
self.store.write(move |w| w.insert_task(&def_clone)).await?;
self.tasks.insert(def.id.clone(), def.clone());
self.names.insert(def.name.clone(), def.id.clone());
self.handlers.insert(def.id.clone(), handler);
self.sync_to_crdt(&def).await?;
info!("Registered task: {} ({})", def.name, def.id);
Ok(())
}
async fn unregister(&self, task_id: &TaskId) -> Result<()> {
let task = self.tasks.remove(task_id);
if let Some((_, task)) = task {
self.names.remove(&task.name);
self.handlers.remove(task_id);
let id = task_id.clone();
self.store.write(move |w| w.delete_task(&id)).await?;
if let Some(crdt) = &self.crdt {
let key = format!("task:{}", task_id);
crdt.delete(&key).await?;
}
info!("Unregistered task: {}", task.name);
}
Ok(())
}
async fn set_enabled(&self, task_id: &TaskId, enabled: bool) -> Result<()> {
if let Some(mut task) = self.tasks.get_mut(task_id) {
task.enabled = enabled;
task.updated_at = Utc::now();
task.hlc_timestamp = self.clock.now().as_u128() as u64;
let task_clone = task.clone();
self.store.write(move |w| w.update_task(&task_clone)).await?;
self.sync_to_crdt(&task).await?;
info!("Task {} enabled={}", task.name, enabled);
}
Ok(())
}
async fn trigger(&self, task_id: &TaskId, payload: Option<bytes::Bytes>) -> Result<InstanceId> {
let task = self.tasks.get(task_id).ok_or(PollenError::TaskNotFound(task_id.clone()))?;
let instance = TaskInstance::new(task_id.clone(), Utc::now());
if let Some(_p) = payload {
}
let id = instance.id.clone();
self.store.write(move |w| w.insert_instance(&instance)).await?;
info!("Triggered task {} (instance {})", task.name, id);
Ok(id)
}
fn get_task(&self, task_id: &TaskId) -> Option<TaskDef> {
self.tasks.get(task_id).map(|t| t.clone())
}
fn get_task_by_name(&self, name: &str) -> Option<TaskDef> {
self.names.get(name).and_then(|id| self.tasks.get(&*id).map(|t| t.clone()))
}
fn list_tasks(&self) -> Vec<TaskDef> {
self.tasks.iter().map(|e| e.value().clone()).collect()
}
fn get_handler(&self, task_id: &TaskId) -> Option<Arc<dyn TaskHandler>> {
self.handlers.get(task_id).map(|h| h.clone())
}
fn next_execution(&self, task: &TaskDef) -> Option<DateTime<Utc>> {
compute_next_execution(&task.schedule, Utc::now())
}
}
/// Computes the first execution time of `schedule` relative to `after`.
///
/// * `Schedule::Cron` — the next occurrence of the expression as reported by
///   the parsed cron value; `None` if the expression does not parse or no
///   occurrence can be found.
/// * `Schedule::Interval` — `after` plus the interval; `None` if the interval
///   cannot be represented as a `chrono::Duration` or the sum overflows.
/// * `Schedule::Once` — the configured instant if it is strictly later than
///   `after`, otherwise `None`.
pub fn compute_next_execution(schedule: &Schedule, after: DateTime<Utc>) -> Option<DateTime<Utc>> {
    match schedule {
        // NOTE(review): the `false` flag presumably excludes `after` itself
        // from the search — confirm against the cron library.
        Schedule::Cron(expr) => parse_cron(expr)
            .ok()
            .and_then(|cron| cron.find_next_occurrence(&after, false).ok()),
        Schedule::Interval(duration) => {
            let step = chrono::Duration::from_std(*duration).ok()?;
            // `after + step` panics when the result is out of range; the
            // checked form reports that case as "no next execution" instead.
            after.checked_add_signed(step)
        }
        Schedule::Once(at) => {
            if *at > after {
                Some(*at)
            } else {
                None
            }
        }
    }
}
/// Reference-counted, type-erased handle to any [`Scheduler`] implementation.
pub type SharedScheduler = Arc<dyn Scheduler>;
#[cfg(test)]
mod tests {
    use super::*;
    use pollen_executor::simple_handler;
    use pollen_store::{MemoryStore, StoreBackend};

    /// Builds a scheduler backed by a fresh in-memory store and no CRDT store.
    fn memory_scheduler() -> DefaultScheduler {
        let clock = pollen_clock::new_clock();
        let backend = Arc::new(StoreBackend::Memory(MemoryStore::new()));
        DefaultScheduler::new(clock, backend, None)
    }

    /// A registered task can be looked up again by its name.
    #[tokio::test]
    async fn test_register_task() {
        let scheduler = memory_scheduler();
        let definition = TaskDef::new("test", Schedule::interval(Duration::from_secs(60)));
        let noop = simple_handler(|| async { Ok(()) });

        scheduler.register(definition, noop).await.unwrap();

        let found = scheduler.get_task_by_name("test");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "test");
    }

    /// Registering a second task under an existing name is rejected.
    #[tokio::test]
    async fn test_duplicate_name() {
        let scheduler = memory_scheduler();
        let first = TaskDef::new("test", Schedule::interval(Duration::from_secs(60)));
        let second = TaskDef::new("test", Schedule::interval(Duration::from_secs(30)));
        let noop = simple_handler(|| async { Ok(()) });

        scheduler.register(first, noop.clone()).await.unwrap();

        let outcome = scheduler.register(second, noop).await;
        assert!(outcome.is_err());
    }

    /// An interval schedule always yields a time after the reference point.
    #[test]
    fn test_next_execution_interval() {
        let reference = Utc::now();
        let every_minute = Schedule::interval(Duration::from_secs(60));

        let upcoming = compute_next_execution(&every_minute, reference);

        assert!(upcoming.is_some());
        assert!(upcoming.unwrap() > reference);
    }

    /// A one-shot schedule whose instant has passed yields no execution.
    #[test]
    fn test_next_execution_once_past() {
        let an_hour_ago = Utc::now() - chrono::Duration::hours(1);
        let one_shot = Schedule::Once(an_hour_ago);

        let upcoming = compute_next_execution(&one_shot, Utc::now());

        assert!(upcoming.is_none());
    }
}