1use crate::config::SchedulerConfig;
4use crate::error::{Result, SchedulerError};
5use crate::events::SchedulerEvent;
6use crate::execution::execute_task;
7use crate::projection::{ScheduleProjection, ScheduledTask, TaskFilter};
8use crate::schedule::Schedule;
9use crate::task_handler::{TaskHandler, TaskHandlerRegistry};
10use azoth::AzothDb;
11use azoth_core::traits::{CanonicalStore, CanonicalTxn};
12use chrono::Utc;
13use rusqlite::Connection;
14use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17use tracing::{debug, error, info};
18
/// Task scheduler: polls a projection for due tasks and dispatches each to
/// its registered handler on the tokio runtime.
///
/// All internal state is behind `Arc`, so a `Clone`d `Scheduler` shares the
/// same shutdown flag and in-flight counter as the original (see the manual
/// `Clone` impl below).
pub struct Scheduler {
    /// Canonical event store; scheduling/cancellation events are appended here.
    db: Arc<AzothDb>,
    /// SQLite-backed read model queried for due tasks and next wake times.
    projection: Arc<ScheduleProjection>,
    /// Maps `task_type` strings to the handler that validates/executes them.
    handler_registry: Arc<TaskHandlerRegistry>,
    /// Poll interval, concurrency limit, and retry/timeout defaults.
    config: SchedulerConfig,
    /// Set by `shutdown()`; the run loop checks it each iteration.
    shutdown: Arc<AtomicBool>,
    /// Count of spawned task executions currently in flight.
    concurrent_tasks: Arc<AtomicUsize>,
}
32
33impl Scheduler {
34 pub fn builder(db: Arc<AzothDb>) -> SchedulerBuilder {
36 SchedulerBuilder::new(db)
37 }
38
39 pub async fn run(&mut self) -> Result<()> {
47 info!("Scheduler starting");
48
49 while !self.shutdown.load(Ordering::SeqCst) {
50 let now = Utc::now().timestamp();
51
52 let due_tasks = match self.projection.get_due_tasks(now) {
54 Ok(tasks) => tasks,
55 Err(e) => {
56 error!("Failed to get due tasks: {}", e);
57 tokio::time::sleep(Duration::from_secs(1)).await;
58 continue;
59 }
60 };
61
62 if due_tasks.is_empty() {
63 let next_wake = self
65 .projection
66 .get_next_wake_time()
67 .ok()
68 .flatten()
69 .unwrap_or_else(|| {
70 Utc::now() + chrono::Duration::from_std(self.config.poll_interval).unwrap()
71 });
72
73 let sleep_duration = (next_wake.timestamp() - Utc::now().timestamp()).max(0) as u64;
74 let sleep_duration =
75 Duration::from_secs(sleep_duration).min(self.config.poll_interval);
76
77 debug!("No due tasks, sleeping for {:?}", sleep_duration);
78 tokio::time::sleep(sleep_duration).await;
79 continue;
80 }
81
82 debug!("Found {} due tasks", due_tasks.len());
83
84 for task in due_tasks {
86 while self.concurrent_tasks.load(Ordering::SeqCst)
88 >= self.config.max_concurrent_tasks
89 {
90 debug!("At concurrency limit, waiting");
91 tokio::time::sleep(Duration::from_millis(100)).await;
92 }
93
94 let db = self.db.clone();
96 let handler_registry = self.handler_registry.clone();
97 let concurrent_tasks = self.concurrent_tasks.clone();
98
99 concurrent_tasks.fetch_add(1, Ordering::SeqCst);
100
101 tokio::spawn(async move {
102 let result = execute_task(db, handler_registry, task).await;
103 concurrent_tasks.fetch_sub(1, Ordering::SeqCst);
104
105 if let Err(e) = result {
106 error!("Task execution failed: {}", e);
107 }
108 });
109 }
110
111 tokio::time::sleep(Duration::from_millis(100)).await;
113 }
114
115 info!("Scheduler shutting down, waiting for tasks to complete");
117 while self.concurrent_tasks.load(Ordering::SeqCst) > 0 {
118 tokio::time::sleep(Duration::from_millis(100)).await;
119 }
120
121 info!("Scheduler stopped");
122 Ok(())
123 }
124
125 pub fn shutdown(&self) {
127 info!("Shutdown signal received");
128 self.shutdown.store(true, Ordering::SeqCst);
129 }
130
131 pub fn schedule_task(&self, request: ScheduleTaskRequest) -> Result<()> {
136 request.schedule.validate()?;
138
139 if !self.handler_registry.has(&request.task_type) {
141 return Err(SchedulerError::HandlerNotFound(request.task_type));
142 }
143
144 let handler = self.handler_registry.get(&request.task_type)?;
146 handler.validate(&request.payload)?;
147
148 let event = SchedulerEvent::TaskScheduled {
150 task_id: request.task_id,
151 task_type: request.task_type,
152 schedule: request.schedule,
153 payload: request.payload,
154 max_retries: request.max_retries,
155 timeout_secs: request.timeout_secs,
156 };
157
158 let mut txn = self.db.canonical().write_txn()?;
159 txn.append_event(&serde_json::to_vec(&event)?)?;
160 txn.commit()?;
161
162 Ok(())
163 }
164
165 pub fn cancel_task(&self, task_id: &str, reason: &str) -> Result<()> {
169 let event = SchedulerEvent::TaskCancelled {
170 task_id: task_id.to_string(),
171 reason: reason.to_string(),
172 };
173
174 let mut txn = self.db.canonical().write_txn()?;
175 txn.append_event(&serde_json::to_vec(&event)?)?;
176 txn.commit()?;
177
178 Ok(())
179 }
180
181 pub fn get_task(&self, task_id: &str) -> Result<Option<ScheduledTask>> {
183 self.projection.get_task(task_id)
184 }
185
186 pub fn list_tasks(&self, filter: &TaskFilter) -> Result<Vec<ScheduledTask>> {
188 self.projection.list_tasks(filter)
189 }
190
191 pub fn concurrent_tasks(&self) -> usize {
193 self.concurrent_tasks.load(Ordering::SeqCst)
194 }
195}
196
197impl Clone for Scheduler {
198 fn clone(&self) -> Self {
199 Self {
200 db: self.db.clone(),
201 projection: self.projection.clone(),
202 handler_registry: self.handler_registry.clone(),
203 config: self.config.clone(),
204 shutdown: self.shutdown.clone(),
205 concurrent_tasks: self.concurrent_tasks.clone(),
206 }
207 }
208}
209
/// Fluent builder for [`Scheduler`]; created via [`Scheduler::builder`].
pub struct SchedulerBuilder {
    // Event store handed to the built scheduler.
    db: Arc<AzothDb>,
    // Handlers registered so far; moved into a TaskHandlerRegistry on build().
    handlers: Vec<Arc<dyn TaskHandler>>,
    // Starts at SchedulerConfig::default(); mutated by the with_* methods.
    config: SchedulerConfig,
}
216
217impl SchedulerBuilder {
218 pub fn new(db: Arc<AzothDb>) -> Self {
220 Self {
221 db,
222 handlers: Vec::new(),
223 config: SchedulerConfig::default(),
224 }
225 }
226
227 pub fn with_task_handler(mut self, handler: impl TaskHandler + 'static) -> Self {
229 self.handlers.push(Arc::new(handler));
230 self
231 }
232
233 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
235 self.config = self.config.with_poll_interval(interval);
236 self
237 }
238
239 pub fn with_max_concurrent_tasks(mut self, max: usize) -> Self {
241 self.config = self.config.with_max_concurrent_tasks(max);
242 self
243 }
244
245 pub fn with_default_max_retries(mut self, retries: u32) -> Self {
247 self.config = self.config.with_default_max_retries(retries);
248 self
249 }
250
251 pub fn with_default_timeout_secs(mut self, timeout: u64) -> Self {
253 self.config = self.config.with_default_timeout_secs(timeout);
254 self
255 }
256
257 #[allow(clippy::arc_with_non_send_sync)]
261 pub fn build(self, projection_conn: Arc<Connection>) -> Result<Scheduler> {
262 let projection = Arc::new(ScheduleProjection::new(projection_conn));
264 projection.init_schema()?;
265
266 let mut handler_registry = TaskHandlerRegistry::new();
268 for handler in self.handlers {
269 handler_registry.register(handler);
270 }
271
272 Ok(Scheduler {
273 db: self.db,
274 projection,
275 handler_registry: Arc::new(handler_registry),
276 config: self.config,
277 shutdown: Arc::new(AtomicBool::new(false)),
278 concurrent_tasks: Arc::new(AtomicUsize::new(0)),
279 })
280 }
281}
282
/// Parameters for scheduling a task; passed to `Scheduler::schedule_task`.
/// Construct via [`ScheduleTaskRequest::builder`].
#[derive(Debug, Clone)]
pub struct ScheduleTaskRequest {
    /// Caller-chosen identifier for the task.
    pub task_id: String,
    /// Must match a handler registered with the scheduler.
    pub task_type: String,
    /// When/how often the task runs (cron, interval, one-time, or immediate).
    pub schedule: Schedule,
    /// Opaque bytes handed to the handler; validated at scheduling time.
    pub payload: Vec<u8>,
    /// Retry budget recorded in the TaskScheduled event.
    pub max_retries: u32,
    /// Per-execution timeout in seconds, recorded in the TaskScheduled event.
    pub timeout_secs: u64,
}
299
300impl ScheduleTaskRequest {
301 pub fn builder(task_id: impl Into<String>) -> ScheduleTaskRequestBuilder {
303 ScheduleTaskRequestBuilder {
304 task_id: task_id.into(),
305 task_type: None,
306 schedule: None,
307 payload: Vec::new(),
308 max_retries: 3,
309 timeout_secs: 300,
310 }
311 }
312}
313
/// Builder for [`ScheduleTaskRequest`]; obtained from
/// [`ScheduleTaskRequest::builder`].
pub struct ScheduleTaskRequestBuilder {
    // Fixed at construction; never changed by the builder methods.
    task_id: String,
    // Required: build() fails if this is still None.
    task_type: Option<String>,
    // Required: build() fails if this is still None.
    schedule: Option<Schedule>,
    // Defaults to empty.
    payload: Vec<u8>,
    // Defaults to 3.
    max_retries: u32,
    // Defaults to 300 seconds.
    timeout_secs: u64,
}
323
324impl ScheduleTaskRequestBuilder {
325 pub fn task_type(mut self, task_type: impl Into<String>) -> Self {
327 self.task_type = Some(task_type.into());
328 self
329 }
330
331 pub fn cron(mut self, expression: impl Into<String>) -> Self {
333 self.schedule = Some(Schedule::Cron {
334 expression: expression.into(),
335 });
336 self
337 }
338
339 pub fn interval(mut self, seconds: u64) -> Self {
341 self.schedule = Some(Schedule::Interval { seconds });
342 self
343 }
344
345 pub fn one_time(mut self, run_at: i64) -> Self {
347 self.schedule = Some(Schedule::OneTime { run_at });
348 self
349 }
350
351 pub fn immediate(mut self) -> Self {
353 self.schedule = Some(Schedule::Immediate);
354 self
355 }
356
357 pub fn schedule(mut self, schedule: Schedule) -> Self {
359 self.schedule = Some(schedule);
360 self
361 }
362
363 pub fn payload(mut self, payload: Vec<u8>) -> Self {
365 self.payload = payload;
366 self
367 }
368
369 pub fn max_retries(mut self, retries: u32) -> Self {
371 self.max_retries = retries;
372 self
373 }
374
375 pub fn timeout_secs(mut self, timeout: u64) -> Self {
377 self.timeout_secs = timeout;
378 self
379 }
380
381 pub fn build(self) -> Result<ScheduleTaskRequest> {
383 let task_type = self
384 .task_type
385 .ok_or_else(|| SchedulerError::InvalidTask("task_type is required".into()))?;
386 let schedule = self
387 .schedule
388 .ok_or_else(|| SchedulerError::InvalidTask("schedule is required".into()))?;
389
390 Ok(ScheduleTaskRequest {
391 task_id: self.task_id,
392 task_type,
393 schedule,
394 payload: self.payload,
395 max_retries: self.max_retries,
396 timeout_secs: self.timeout_secs,
397 })
398 }
399}