durable_execution_sdk_testing/checkpoint_server/
scheduler.rs1use std::collections::VecDeque;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::atomic::{AtomicBool, Ordering};
26use std::sync::Arc;
27
28use chrono::{DateTime, Utc};
29use tokio::sync::Mutex;
30use tokio::task::JoinHandle;
31
32use crate::error::TestError;
33
34pub type BoxedAsyncFn = Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
36
37pub type ErrorHandler = Box<dyn FnOnce(TestError) + Send>;
39
40pub type CheckpointUpdateFn =
42 Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<(), TestError>> + Send>> + Send>;
43
44pub struct ScheduledFunction {
46 pub start_invocation: BoxedAsyncFn,
48 pub on_error: ErrorHandler,
50 pub timestamp: Option<DateTime<Utc>>,
52 pub update_checkpoint: Option<CheckpointUpdateFn>,
54}
55
56impl std::fmt::Debug for ScheduledFunction {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("ScheduledFunction")
59 .field("timestamp", &self.timestamp)
60 .field("has_update_checkpoint", &self.update_checkpoint.is_some())
61 .finish()
62 }
63}
64
65pub trait Scheduler: Send {
70 fn schedule_function(
79 &mut self,
80 start_invocation: BoxedAsyncFn,
81 on_error: ErrorHandler,
82 timestamp: Option<DateTime<Utc>>,
83 update_checkpoint: Option<CheckpointUpdateFn>,
84 );
85
86 fn has_scheduled_function(&self) -> bool;
88
89 fn flush_timers(&mut self);
91
92 fn process_next(&mut self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;
96}
97
98pub struct QueueScheduler {
109 function_queue: VecDeque<ScheduledFunction>,
111 is_processing: Arc<AtomicBool>,
113}
114
115impl QueueScheduler {
116 pub fn new() -> Self {
118 Self {
119 function_queue: VecDeque::new(),
120 is_processing: Arc::new(AtomicBool::new(false)),
121 }
122 }
123
124 pub fn is_processing(&self) -> bool {
126 self.is_processing.load(Ordering::SeqCst)
127 }
128
129 pub fn queue_len(&self) -> usize {
131 self.function_queue.len()
132 }
133}
134
135impl Default for QueueScheduler {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl Scheduler for QueueScheduler {
142 fn schedule_function(
143 &mut self,
144 start_invocation: BoxedAsyncFn,
145 on_error: ErrorHandler,
146 timestamp: Option<DateTime<Utc>>,
147 update_checkpoint: Option<CheckpointUpdateFn>,
148 ) {
149 let scheduled = ScheduledFunction {
150 start_invocation,
151 on_error,
152 timestamp,
153 update_checkpoint,
154 };
155 self.function_queue.push_back(scheduled);
156 }
157
158 fn has_scheduled_function(&self) -> bool {
159 !self.function_queue.is_empty()
160 }
161
162 fn flush_timers(&mut self) {
163 self.function_queue.clear();
164 }
165
166 fn process_next(&mut self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
167 Box::pin(async move {
168 if let Some(scheduled) = self.function_queue.pop_front() {
169 self.is_processing.store(true, Ordering::SeqCst);
170
171 if let Some(update_fn) = scheduled.update_checkpoint {
173 let update_future = update_fn();
174 if let Err(e) = update_future.await {
175 (scheduled.on_error)(e);
176 self.is_processing.store(false, Ordering::SeqCst);
177 return true;
178 }
179 }
180
181 let invocation_future = (scheduled.start_invocation)();
183 invocation_future.await;
184
185 self.is_processing.store(false, Ordering::SeqCst);
186 true
187 } else {
188 false
189 }
190 })
191 }
192}
193
194pub struct TimerScheduler {
205 scheduled_tasks: Vec<JoinHandle<()>>,
207 pending_count: Arc<Mutex<usize>>,
209}
210
211impl TimerScheduler {
212 pub fn new() -> Self {
214 Self {
215 scheduled_tasks: Vec::new(),
216 pending_count: Arc::new(Mutex::new(0)),
217 }
218 }
219
220 pub async fn pending_count(&self) -> usize {
222 *self.pending_count.lock().await
223 }
224}
225
226impl Default for TimerScheduler {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
232impl Scheduler for TimerScheduler {
233 fn schedule_function(
234 &mut self,
235 start_invocation: BoxedAsyncFn,
236 on_error: ErrorHandler,
237 timestamp: Option<DateTime<Utc>>,
238 update_checkpoint: Option<CheckpointUpdateFn>,
239 ) {
240 let pending_count = Arc::clone(&self.pending_count);
241
242 let pending_count_clone = Arc::clone(&pending_count);
244 tokio::spawn(async move {
245 let mut count = pending_count_clone.lock().await;
246 *count += 1;
247 });
248
249 let handle = tokio::spawn(async move {
250 if let Some(ts) = timestamp {
252 let now = Utc::now();
253 if ts > now {
254 let duration = (ts - now).to_std().unwrap_or_default();
255 tokio::time::sleep(duration).await;
256 }
257 }
258
259 if let Some(update_fn) = update_checkpoint {
261 let update_future = update_fn();
262 if let Err(e) = update_future.await {
263 (on_error)(e);
264 let mut count = pending_count.lock().await;
266 *count = count.saturating_sub(1);
267 return;
268 }
269 }
270
271 let invocation_future = start_invocation();
273 invocation_future.await;
274
275 let mut count = pending_count.lock().await;
277 *count = count.saturating_sub(1);
278 });
279
280 self.scheduled_tasks.push(handle);
281 }
282
283 fn has_scheduled_function(&self) -> bool {
284 self.scheduled_tasks.iter().any(|h| !h.is_finished())
286 }
287
288 fn flush_timers(&mut self) {
289 for handle in self.scheduled_tasks.drain(..) {
291 handle.abort();
292 }
293 }
294
295 fn process_next(&mut self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
296 Box::pin(async move {
300 self.scheduled_tasks.retain(|h| !h.is_finished());
302 !self.scheduled_tasks.is_empty()
303 })
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Upper bound on how long a test waits for spawned tasks before failing.
    const TASK_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(5);

    /// Yields to the runtime until every task spawned by `scheduler` has
    /// finished. Fails the test if that takes longer than [`TASK_TIMEOUT`].
    async fn wait_until_idle(scheduler: &TimerScheduler) {
        tokio::time::timeout(TASK_TIMEOUT, async {
            while scheduler.has_scheduled_function() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("scheduled tasks did not finish in time");
    }

    #[tokio::test]
    async fn test_queue_scheduler_new() {
        let scheduler = QueueScheduler::new();
        assert!(!scheduler.has_scheduled_function());
        assert!(!scheduler.is_processing());
        assert_eq!(scheduler.queue_len(), 0);
    }

    #[tokio::test]
    async fn test_queue_scheduler_process_next_on_empty_queue() {
        let mut scheduler = QueueScheduler::new();
        assert!(!scheduler.process_next().await);
        assert!(!scheduler.is_processing());
    }

    #[tokio::test]
    async fn test_queue_scheduler_schedule_and_process() {
        let mut scheduler = QueueScheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        scheduler.schedule_function(
            Box::new(move || {
                let counter = Arc::clone(&counter_clone);
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }),
            Box::new(|_| {}),
            None,
            None,
        );

        assert!(scheduler.has_scheduled_function());
        assert_eq!(scheduler.queue_len(), 1);

        let processed = scheduler.process_next().await;
        assert!(processed);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(!scheduler.has_scheduled_function());
        assert!(!scheduler.is_processing());
    }

    #[tokio::test]
    async fn test_queue_scheduler_fifo_order() {
        let mut scheduler = QueueScheduler::new();
        let order = Arc::new(Mutex::new(Vec::new()));

        for i in 0..3 {
            let order_clone = Arc::clone(&order);
            scheduler.schedule_function(
                Box::new(move || {
                    Box::pin(async move {
                        order_clone.lock().await.push(i);
                    })
                }),
                Box::new(|_| {}),
                None,
                None,
            );
        }

        while scheduler.process_next().await {}

        let result = order.lock().await;
        assert_eq!(*result, vec![0, 1, 2]);
        assert!(!scheduler.has_scheduled_function());
    }

    #[tokio::test]
    async fn test_queue_scheduler_with_checkpoint_update() {
        let mut scheduler = QueueScheduler::new();
        let checkpoint_called = Arc::new(AtomicBool::new(false));
        let invocation_called = Arc::new(AtomicBool::new(false));

        let checkpoint_clone = Arc::clone(&checkpoint_called);
        let invocation_clone = Arc::clone(&invocation_called);

        scheduler.schedule_function(
            Box::new(move || {
                let invocation = Arc::clone(&invocation_clone);
                Box::pin(async move {
                    invocation.store(true, Ordering::SeqCst);
                })
            }),
            Box::new(|_| {}),
            None,
            Some(Box::new(move || {
                let checkpoint = Arc::clone(&checkpoint_clone);
                Box::pin(async move {
                    checkpoint.store(true, Ordering::SeqCst);
                    Ok(())
                })
            })),
        );

        assert!(scheduler.process_next().await);

        assert!(checkpoint_called.load(Ordering::SeqCst));
        assert!(invocation_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_queue_scheduler_flush() {
        let mut scheduler = QueueScheduler::new();

        for _ in 0..5 {
            scheduler.schedule_function(
                Box::new(|| Box::pin(async {})),
                Box::new(|_| {}),
                None,
                None,
            );
        }

        assert_eq!(scheduler.queue_len(), 5);
        scheduler.flush_timers();
        assert_eq!(scheduler.queue_len(), 0);
        assert!(!scheduler.has_scheduled_function());
    }

    #[tokio::test]
    async fn test_timer_scheduler_new() {
        let scheduler = TimerScheduler::new();
        assert!(!scheduler.has_scheduled_function());
    }

    #[tokio::test]
    async fn test_timer_scheduler_schedule_immediate() {
        let mut scheduler = TimerScheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        scheduler.schedule_function(
            Box::new(move || {
                let counter = Arc::clone(&counter_clone);
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }),
            Box::new(|_| {}),
            None,
            None,
        );

        // Wait for the spawned task to complete instead of sleeping a fixed amount.
        wait_until_idle(&scheduler).await;

        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(!scheduler.process_next().await);
    }

    #[tokio::test]
    async fn test_timer_scheduler_flush() {
        let mut scheduler = TimerScheduler::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let counter_clone = Arc::clone(&counter);
        let future_time = Utc::now() + chrono::Duration::seconds(10);

        scheduler.schedule_function(
            Box::new(move || {
                let counter = Arc::clone(&counter_clone);
                Box::pin(async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                })
            }),
            Box::new(|_| {}),
            Some(future_time),
            None,
        );

        assert!(scheduler.has_scheduled_function());

        scheduler.flush_timers();

        assert!(!scheduler.has_scheduled_function());
        assert!(!scheduler.process_next().await);

        // Give an (incorrectly) surviving task a chance to run before checking.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}