durable_execution_sdk_testing/checkpoint_server/
scheduler.rs1use std::collections::VecDeque;
23use std::future::Future;
24use std::pin::Pin;
25use std::sync::atomic::{AtomicBool, Ordering};
26use std::sync::Arc;
27
28use chrono::{DateTime, Utc};
29use tokio::sync::Mutex;
30use tokio::task::JoinHandle;
31
32use crate::error::TestError;
33
/// A boxed, single-use, `Send` closure that takes no arguments and, when
/// called, returns a pinned, boxed, `Send` future resolving to `()`.
///
/// Both schedulers in this file call the closure once and await the returned
/// future to run an invocation.
pub type BoxedAsyncFn = Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
36
37pub type ErrorHandler = Box<dyn FnOnce(TestError) + Send>;
39
40pub type CheckpointUpdateFn =
42 Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<(), TestError>> + Send>> + Send>;
43
44pub struct ScheduledFunction {
46 pub start_invocation: BoxedAsyncFn,
48 pub on_error: ErrorHandler,
50 pub timestamp: Option<DateTime<Utc>>,
52 pub update_checkpoint: Option<CheckpointUpdateFn>,
54}
55
56impl std::fmt::Debug for ScheduledFunction {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("ScheduledFunction")
59 .field("timestamp", &self.timestamp)
60 .field("has_update_checkpoint", &self.update_checkpoint.is_some())
61 .finish()
62 }
63}
64
65pub trait Scheduler: Send {
70 fn schedule_function(
79 &mut self,
80 start_invocation: BoxedAsyncFn,
81 on_error: ErrorHandler,
82 timestamp: Option<DateTime<Utc>>,
83 update_checkpoint: Option<CheckpointUpdateFn>,
84 );
85
86 fn has_scheduled_function(&self) -> bool;
88
89 fn flush_timers(&mut self);
91
92 fn process_next(&mut self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;
96}
97
98pub struct QueueScheduler {
104 function_queue: VecDeque<ScheduledFunction>,
106 is_processing: Arc<AtomicBool>,
108}
109
110impl QueueScheduler {
111 pub fn new() -> Self {
113 Self {
114 function_queue: VecDeque::new(),
115 is_processing: Arc::new(AtomicBool::new(false)),
116 }
117 }
118
119 pub fn is_processing(&self) -> bool {
121 self.is_processing.load(Ordering::SeqCst)
122 }
123
124 pub fn queue_len(&self) -> usize {
126 self.function_queue.len()
127 }
128}
129
130impl Default for QueueScheduler {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl Scheduler for QueueScheduler {
137 fn schedule_function(
138 &mut self,
139 start_invocation: BoxedAsyncFn,
140 on_error: ErrorHandler,
141 timestamp: Option<DateTime<Utc>>,
142 update_checkpoint: Option<CheckpointUpdateFn>,
143 ) {
144 let scheduled = ScheduledFunction {
145 start_invocation,
146 on_error,
147 timestamp,
148 update_checkpoint,
149 };
150 self.function_queue.push_back(scheduled);
151 }
152
153 fn has_scheduled_function(&self) -> bool {
154 !self.function_queue.is_empty()
155 }
156
157 fn flush_timers(&mut self) {
158 self.function_queue.clear();
159 }
160
161 fn process_next(&mut self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
162 Box::pin(async move {
163 if let Some(scheduled) = self.function_queue.pop_front() {
164 self.is_processing.store(true, Ordering::SeqCst);
165
166 if let Some(update_fn) = scheduled.update_checkpoint {
168 let update_future = update_fn();
169 if let Err(e) = update_future.await {
170 (scheduled.on_error)(e);
171 self.is_processing.store(false, Ordering::SeqCst);
172 return true;
173 }
174 }
175
176 let invocation_future = (scheduled.start_invocation)();
178 invocation_future.await;
179
180 self.is_processing.store(false, Ordering::SeqCst);
181 true
182 } else {
183 false
184 }
185 })
186 }
187}
188
189pub struct TimerScheduler {
195 scheduled_tasks: Vec<JoinHandle<()>>,
197 pending_count: Arc<Mutex<usize>>,
199}
200
201impl TimerScheduler {
202 pub fn new() -> Self {
204 Self {
205 scheduled_tasks: Vec::new(),
206 pending_count: Arc::new(Mutex::new(0)),
207 }
208 }
209
210 pub async fn pending_count(&self) -> usize {
212 *self.pending_count.lock().await
213 }
214}
215
216impl Default for TimerScheduler {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
222impl Scheduler for TimerScheduler {
223 fn schedule_function(
224 &mut self,
225 start_invocation: BoxedAsyncFn,
226 on_error: ErrorHandler,
227 timestamp: Option<DateTime<Utc>>,
228 update_checkpoint: Option<CheckpointUpdateFn>,
229 ) {
230 let pending_count = Arc::clone(&self.pending_count);
231
232 let pending_count_clone = Arc::clone(&pending_count);
234 tokio::spawn(async move {
235 let mut count = pending_count_clone.lock().await;
236 *count += 1;
237 });
238
239 let handle = tokio::spawn(async move {
240 if let Some(ts) = timestamp {
242 let now = Utc::now();
243 if ts > now {
244 let duration = (ts - now).to_std().unwrap_or_default();
245 tokio::time::sleep(duration).await;
246 }
247 }
248
249 if let Some(update_fn) = update_checkpoint {
251 let update_future = update_fn();
252 if let Err(e) = update_future.await {
253 (on_error)(e);
254 let mut count = pending_count.lock().await;
256 *count = count.saturating_sub(1);
257 return;
258 }
259 }
260
261 let invocation_future = start_invocation();
263 invocation_future.await;
264
265 let mut count = pending_count.lock().await;
267 *count = count.saturating_sub(1);
268 });
269
270 self.scheduled_tasks.push(handle);
271 }
272
273 fn has_scheduled_function(&self) -> bool {
274 self.scheduled_tasks.iter().any(|h| !h.is_finished())
276 }
277
278 fn flush_timers(&mut self) {
279 for handle in self.scheduled_tasks.drain(..) {
281 handle.abort();
282 }
283 }
284
285 fn process_next(&mut self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
286 Box::pin(async move {
290 self.scheduled_tasks.retain(|h| !h.is_finished());
292 !self.scheduled_tasks.is_empty()
293 })
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Error handler that discards whatever error it is given.
    fn discard_error() -> ErrorHandler {
        Box::new(|_| {})
    }

    /// Invocation that adds one to `counter` when its future runs.
    fn bump(counter: &Arc<AtomicUsize>) -> BoxedAsyncFn {
        let counter = Arc::clone(counter);
        Box::new(move || {
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
            })
        })
    }

    #[tokio::test]
    async fn test_queue_scheduler_new() {
        let fresh = QueueScheduler::new();

        assert!(!fresh.has_scheduled_function());
        assert!(!fresh.is_processing());
        assert_eq!(fresh.queue_len(), 0);
    }

    #[tokio::test]
    async fn test_queue_scheduler_schedule_and_process() {
        let mut scheduler = QueueScheduler::new();
        let hits = Arc::new(AtomicUsize::new(0));

        scheduler.schedule_function(bump(&hits), discard_error(), None, None);

        assert!(scheduler.has_scheduled_function());
        assert_eq!(scheduler.queue_len(), 1);

        let processed = scheduler.process_next().await;

        assert!(processed);
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert!(!scheduler.has_scheduled_function());
    }

    #[tokio::test]
    async fn test_queue_scheduler_fifo_order() {
        let mut scheduler = QueueScheduler::new();
        let seen = Arc::new(Mutex::new(Vec::new()));

        for index in 0..3 {
            let sink = Arc::clone(&seen);
            scheduler.schedule_function(
                Box::new(move || {
                    Box::pin(async move {
                        sink.lock().await.push(index);
                    })
                }),
                discard_error(),
                None,
                None,
            );
        }

        // Drain the queue; `process_next` resolves to `false` once it is empty.
        loop {
            if !scheduler.process_next().await {
                break;
            }
        }

        let recorded = seen.lock().await;
        assert_eq!(*recorded, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn test_queue_scheduler_with_checkpoint_update() {
        let mut scheduler = QueueScheduler::new();
        let checkpoint_called = Arc::new(AtomicBool::new(false));
        let invocation_called = Arc::new(AtomicBool::new(false));

        let invocation_flag = Arc::clone(&invocation_called);
        let checkpoint_flag = Arc::clone(&checkpoint_called);

        scheduler.schedule_function(
            Box::new(move || {
                Box::pin(async move {
                    invocation_flag.store(true, Ordering::SeqCst);
                })
            }),
            discard_error(),
            None,
            Some(Box::new(move || {
                Box::pin(async move {
                    checkpoint_flag.store(true, Ordering::SeqCst);
                    Ok(())
                })
            })),
        );

        scheduler.process_next().await;

        assert!(checkpoint_called.load(Ordering::SeqCst));
        assert!(invocation_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_queue_scheduler_flush() {
        let mut scheduler = QueueScheduler::new();

        for _ in 0..5 {
            scheduler.schedule_function(
                Box::new(|| Box::pin(async {})),
                discard_error(),
                None,
                None,
            );
        }
        assert_eq!(scheduler.queue_len(), 5);

        scheduler.flush_timers();

        assert_eq!(scheduler.queue_len(), 0);
        assert!(!scheduler.has_scheduled_function());
    }

    #[tokio::test]
    async fn test_timer_scheduler_new() {
        let fresh = TimerScheduler::new();

        assert!(!fresh.has_scheduled_function());
    }

    #[tokio::test]
    async fn test_timer_scheduler_schedule_immediate() {
        let mut scheduler = TimerScheduler::new();
        let hits = Arc::new(AtomicUsize::new(0));

        // No timestamp: the spawned task should run without delay.
        scheduler.schedule_function(bump(&hits), discard_error(), None, None);

        // Give the spawned task time to run.
        let grace = tokio::time::Duration::from_millis(50);
        tokio::time::sleep(grace).await;

        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_timer_scheduler_flush() {
        let mut scheduler = TimerScheduler::new();
        let hits = Arc::new(AtomicUsize::new(0));

        // Far enough in the future that the task is still sleeping when flushed.
        let future_time = Utc::now() + chrono::Duration::seconds(10);

        scheduler.schedule_function(bump(&hits), discard_error(), Some(future_time), None);

        scheduler.flush_timers();

        let grace = tokio::time::Duration::from_millis(50);
        tokio::time::sleep(grace).await;

        // The aborted task must never have run its invocation.
        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }
}