Skip to main content

pollen_scheduler/
lib.rs

//! Task scheduler for Pollen.
//!
//! Manages task definitions, computes next execution times,
//! and generates task instances.
5
6mod cron_parser;
7mod instance_generator;
8
9pub use cron_parser::parse_cron;
10pub use instance_generator::InstanceGenerator;
11
12use async_trait::async_trait;
13use chrono::{DateTime, Utc};
14use dashmap::DashMap;
15use pollen_clock::SharedClock;
16use pollen_crdt::{CrdtKv, LwwRegister};
17use pollen_executor::TaskHandler;
18use pollen_store::StoreBackend;
19use pollen_types::*;
20use std::sync::Arc;
21use std::time::Duration;
22use tracing::{info, warn};
23
24/// Scheduler service trait.
25#[async_trait]
26pub trait Scheduler: Send + Sync + 'static {
27    /// Register a new task.
28    async fn register(&self, def: TaskDef, handler: Arc<dyn TaskHandler>) -> Result<()>;
29
30    /// Unregister a task.
31    async fn unregister(&self, task_id: &TaskId) -> Result<()>;
32
33    /// Enable or disable a task.
34    async fn set_enabled(&self, task_id: &TaskId, enabled: bool) -> Result<()>;
35
36    /// Trigger immediate execution of a task.
37    async fn trigger(&self, task_id: &TaskId, payload: Option<bytes::Bytes>) -> Result<InstanceId>;
38
39    /// Get task info.
40    fn get_task(&self, task_id: &TaskId) -> Option<TaskDef>;
41
42    /// Get task by name.
43    fn get_task_by_name(&self, name: &str) -> Option<TaskDef>;
44
45    /// List all tasks.
46    fn list_tasks(&self) -> Vec<TaskDef>;
47
48    /// Get handler for a task.
49    fn get_handler(&self, task_id: &TaskId) -> Option<Arc<dyn TaskHandler>>;
50
51    /// Compute next execution time for a task.
52    fn next_execution(&self, task: &TaskDef) -> Option<DateTime<Utc>>;
53}
54
55/// Default scheduler implementation.
56pub struct DefaultScheduler {
57    /// Clock for timestamps.
58    clock: SharedClock,
59    /// Persistent storage.
60    store: Arc<StoreBackend>,
61    /// CRDT store for distributed state (optional).
62    crdt: Option<Arc<pollen_crdt::CrdtStore>>,
63    /// In-memory task cache.
64    tasks: DashMap<TaskId, TaskDef>,
65    /// Task name to ID mapping.
66    names: DashMap<String, TaskId>,
67    /// Registered handlers.
68    handlers: DashMap<TaskId, Arc<dyn TaskHandler>>,
69    /// Instance generator.
70    generator: Arc<InstanceGenerator>,
71}
72
73impl DefaultScheduler {
74    /// Create a new scheduler.
75    pub fn new(
76        clock: SharedClock,
77        store: Arc<StoreBackend>,
78        crdt: Option<Arc<pollen_crdt::CrdtStore>>,
79    ) -> Self {
80        
81
82        Self {
83            clock: clock.clone(),
84            store: Arc::clone(&store),
85            crdt,
86            tasks: DashMap::new(),
87            names: DashMap::new(),
88            handlers: DashMap::new(),
89            generator: Arc::new(InstanceGenerator::new(store)),
90        }
91    }
92
93    /// Start the scheduler background tasks with default poll interval (100ms).
94    pub fn start(self: Arc<Self>) {
95        self.start_with_interval(Duration::from_millis(100));
96    }
97
98    /// Start the scheduler background tasks with custom poll interval.
99    pub fn start_with_interval(self: Arc<Self>, poll_interval: Duration) {
100        // Clone Arc to move into the task
101        let scheduler = Arc::clone(&self);
102
103        tokio::spawn(async move {
104            let mut interval = tokio::time::interval(poll_interval);
105
106            loop {
107                interval.tick().await;
108
109                let now = Utc::now();
110                // Access tasks from the shared scheduler reference
111                for entry in scheduler.tasks.iter() {
112                    let task = entry.value();
113                    if !task.enabled {
114                        continue;
115                    }
116
117                    // Check if we need to generate a new instance
118                    if let Some(next) = compute_next_execution(&task.schedule, now) {
119                        if next <= now + chrono::Duration::seconds(5) {
120                            if let Err(e) = scheduler.generator.ensure_instance(task, next).await {
121                                warn!("Failed to generate instance for {}: {}", task.name, e);
122                            }
123                        }
124                    }
125                }
126            }
127        });
128
129        info!("Scheduler started");
130    }
131
132    /// Load tasks from storage.
133    pub async fn load(&self) -> Result<()> {
134        let tasks = self.store.read(|r| r.list_tasks()).await?;
135
136        for task in tasks {
137            self.tasks.insert(task.id.clone(), task.clone());
138            self.names.insert(task.name.clone(), task.id.clone());
139        }
140
141        info!("Loaded {} tasks from storage", self.tasks.len());
142        Ok(())
143    }
144
145    /// Sync tasks to/from CRDT store.
146    async fn sync_to_crdt(&self, task: &TaskDef) -> Result<()> {
147        if let Some(crdt) = &self.crdt {
148            let key = format!("task:{}", task.id);
149            let register = LwwRegister::new(task.clone(), task.hlc_timestamp);
150            crdt.set(&key, register).await?;
151        }
152        Ok(())
153    }
154}
155
156#[async_trait]
157impl Scheduler for DefaultScheduler {
158    async fn register(&self, mut def: TaskDef, handler: Arc<dyn TaskHandler>) -> Result<()> {
159        // Check for duplicate name
160        if self.names.contains_key(&def.name) {
161            return Err(PollenError::TaskAlreadyExists(def.name.clone()));
162        }
163
164        // Validate schedule
165        if let Schedule::Cron(ref expr) = def.schedule {
166            parse_cron(expr)?;
167        }
168
169        // Set timestamps
170        let ts = self.clock.now();
171        def.hlc_timestamp = ts.as_u128() as u64;
172        def.updated_at = Utc::now();
173
174        // Store in database
175        let def_clone = def.clone();
176        self.store.write(move |w| w.insert_task(&def_clone)).await?;
177
178        // Store in memory
179        self.tasks.insert(def.id.clone(), def.clone());
180        self.names.insert(def.name.clone(), def.id.clone());
181        self.handlers.insert(def.id.clone(), handler);
182
183        // Sync to CRDT
184        self.sync_to_crdt(&def).await?;
185
186        info!("Registered task: {} ({})", def.name, def.id);
187
188        Ok(())
189    }
190
191    async fn unregister(&self, task_id: &TaskId) -> Result<()> {
192        let task = self.tasks.remove(task_id);
193        if let Some((_, task)) = task {
194            self.names.remove(&task.name);
195            self.handlers.remove(task_id);
196
197            let id = task_id.clone();
198            self.store.write(move |w| w.delete_task(&id)).await?;
199
200            if let Some(crdt) = &self.crdt {
201                let key = format!("task:{}", task_id);
202                crdt.delete(&key).await?;
203            }
204
205            info!("Unregistered task: {}", task.name);
206        }
207
208        Ok(())
209    }
210
211    async fn set_enabled(&self, task_id: &TaskId, enabled: bool) -> Result<()> {
212        if let Some(mut task) = self.tasks.get_mut(task_id) {
213            task.enabled = enabled;
214            task.updated_at = Utc::now();
215            task.hlc_timestamp = self.clock.now().as_u128() as u64;
216
217            let task_clone = task.clone();
218            self.store.write(move |w| w.update_task(&task_clone)).await?;
219            self.sync_to_crdt(&task).await?;
220
221            info!("Task {} enabled={}", task.name, enabled);
222        }
223
224        Ok(())
225    }
226
227    async fn trigger(&self, task_id: &TaskId, payload: Option<bytes::Bytes>) -> Result<InstanceId> {
228        let task = self.tasks.get(task_id).ok_or(PollenError::TaskNotFound(task_id.clone()))?;
229
230        let instance = TaskInstance::new(task_id.clone(), Utc::now());
231        if let Some(_p) = payload {
232            // TODO: Update the instance with custom payload
233        }
234
235        let id = instance.id.clone();
236        self.store.write(move |w| w.insert_instance(&instance)).await?;
237
238        info!("Triggered task {} (instance {})", task.name, id);
239
240        Ok(id)
241    }
242
243    fn get_task(&self, task_id: &TaskId) -> Option<TaskDef> {
244        self.tasks.get(task_id).map(|t| t.clone())
245    }
246
247    fn get_task_by_name(&self, name: &str) -> Option<TaskDef> {
248        self.names.get(name).and_then(|id| self.tasks.get(&*id).map(|t| t.clone()))
249    }
250
251    fn list_tasks(&self) -> Vec<TaskDef> {
252        self.tasks.iter().map(|e| e.value().clone()).collect()
253    }
254
255    fn get_handler(&self, task_id: &TaskId) -> Option<Arc<dyn TaskHandler>> {
256        self.handlers.get(task_id).map(|h| h.clone())
257    }
258
259    fn next_execution(&self, task: &TaskDef) -> Option<DateTime<Utc>> {
260        compute_next_execution(&task.schedule, Utc::now())
261    }
262}
263
264/// Compute the next execution time for a schedule.
265pub fn compute_next_execution(schedule: &Schedule, after: DateTime<Utc>) -> Option<DateTime<Utc>> {
266    match schedule {
267        Schedule::Cron(expr) => {
268            parse_cron(expr)
269                .ok()
270                .and_then(|cron| cron.find_next_occurrence(&after, false).ok())
271        }
272        Schedule::Interval(duration) => {
273            Some(after + chrono::Duration::from_std(*duration).ok()?)
274        }
275        Schedule::Once(at) => {
276            if *at > after {
277                Some(*at)
278            } else {
279                None
280            }
281        }
282    }
283}
284
285/// Shared scheduler instance.
286pub type SharedScheduler = Arc<dyn Scheduler>;
287
#[cfg(test)]
mod tests {
    use super::*;
    use pollen_executor::simple_handler;
    use pollen_store::{MemoryStore, StoreBackend};

    /// Scheduler over a fresh in-memory store with CRDT mirroring disabled.
    fn memory_scheduler() -> DefaultScheduler {
        let clock = pollen_clock::new_clock();
        let store = Arc::new(StoreBackend::Memory(MemoryStore::new()));
        DefaultScheduler::new(clock, store, None)
    }

    /// Handler that completes immediately with `Ok(())`.
    fn noop_handler() -> Arc<dyn TaskHandler> {
        simple_handler(|| async { Ok(()) })
    }

    /// Task definition with a one-minute interval schedule.
    fn minutely(name: &str) -> TaskDef {
        TaskDef::new(name, Schedule::interval(Duration::from_secs(60)))
    }

    #[tokio::test]
    async fn test_register_task() {
        let scheduler = memory_scheduler();

        scheduler.register(minutely("test"), noop_handler()).await.unwrap();

        let fetched = scheduler.get_task_by_name("test");
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().name, "test");
    }

    #[tokio::test]
    async fn test_duplicate_name() {
        let scheduler = memory_scheduler();
        let handler = noop_handler();
        let task1 = TaskDef::new("test", Schedule::interval(Duration::from_secs(60)));
        let task2 = TaskDef::new("test", Schedule::interval(Duration::from_secs(30)));

        scheduler.register(task1, handler.clone()).await.unwrap();
        let result = scheduler.register(task2, handler).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_register_stores_task_and_handler() {
        let scheduler = memory_scheduler();
        let task = minutely("with-handler");
        let id = task.id.clone();

        scheduler.register(task, noop_handler()).await.unwrap();

        assert!(scheduler.get_task(&id).is_some());
        assert!(scheduler.get_handler(&id).is_some());
        assert_eq!(scheduler.list_tasks().len(), 1);
    }

    #[tokio::test]
    async fn test_unregister_removes_task_and_frees_name() {
        let scheduler = memory_scheduler();
        let task = minutely("doomed");
        let id = task.id.clone();
        scheduler.register(task, noop_handler()).await.unwrap();

        scheduler.unregister(&id).await.unwrap();

        assert!(scheduler.get_task(&id).is_none());
        assert!(scheduler.get_task_by_name("doomed").is_none());
        assert!(scheduler.get_handler(&id).is_none());

        // The name is available again once the old task is gone.
        scheduler.register(minutely("doomed"), noop_handler()).await.unwrap();
        assert!(scheduler.get_task_by_name("doomed").is_some());
    }

    #[tokio::test]
    async fn test_set_enabled_toggles_flag() {
        let scheduler = memory_scheduler();
        let task = minutely("toggle");
        let id = task.id.clone();
        scheduler.register(task, noop_handler()).await.unwrap();

        scheduler.set_enabled(&id, false).await.unwrap();
        assert!(!scheduler.get_task(&id).unwrap().enabled);

        scheduler.set_enabled(&id, true).await.unwrap();
        assert!(scheduler.get_task(&id).unwrap().enabled);
    }

    #[tokio::test]
    async fn test_trigger_unknown_task_fails() {
        let scheduler = memory_scheduler();
        let ghost = minutely("ghost"); // never registered

        let result = scheduler.trigger(&ghost.id, None).await;

        assert!(result.is_err());
    }

    #[test]
    fn test_next_execution_interval() {
        let now = Utc::now();
        let schedule = Schedule::interval(Duration::from_secs(60));
        let next = compute_next_execution(&schedule, now);

        assert!(next.is_some());
        assert!(next.unwrap() > now);
    }

    #[test]
    fn test_next_execution_once_past() {
        let past = Utc::now() - chrono::Duration::hours(1);
        let schedule = Schedule::Once(past);
        let next = compute_next_execution(&schedule, Utc::now());

        assert!(next.is_none());
    }

    #[test]
    fn test_next_execution_once_future() {
        let at = Utc::now() + chrono::Duration::hours(1);
        let schedule = Schedule::Once(at);
        let next = compute_next_execution(&schedule, Utc::now());

        assert_eq!(next, Some(at));
    }
}