ora_scheduler/
scheduler.rs

1//! Scheduler implementation.
2
3use core::pin::pin;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5
6use ahash::{AHashMap, HashSet};
7use futures::TryStreamExt;
8use ora_common::{schedule::NewTask, timeout::TimeoutPolicy, UnixNanos};
9use ora_timer::{resolution::Milliseconds, Timer, TimerHandle};
10use ora_util::schedule::next_schedule_task;
11use thiserror::Error;
12use tokio::select;
13use uuid::Uuid;
14
15use crate::store::{
16    schedule::{SchedulerScheduleStore, SchedulerScheduleStoreEvent},
17    task::{ActiveTask, SchedulerTaskStore, SchedulerTaskStoreEvent},
18};
19
20/// A scheduler that has a purpose of marking
21/// tasks as ready when their target timestamp is reached.
22///
23/// It also manages spawning new tasks of schedules if needed.
24pub struct Scheduler<S> {
25    store: S,
26    default_timeout: Option<TimeoutPolicy>,
27}
28
29impl<S> Scheduler<S>
30where
31    S: SchedulerTaskStore + SchedulerScheduleStore,
32{
33    /// Create a new scheduler with the given backing store.
34    #[must_use]
35    pub fn new(store: S) -> Self {
36        Self {
37            store,
38            default_timeout: None,
39        }
40    }
41
42    /// Set the default timeout policy for all tasks
43    /// that do not have a timeout policy set.
44    ///
45    /// This is useful as a safety measure to prevent
46    /// tasks from running indefinitely or getting
47    /// stuck in a running state by unclean shutdowns of workers.
48    ///
49    /// By default, no timeout is set.
50    #[must_use]
51    pub fn with_default_timeout(mut self, timeout: Option<TimeoutPolicy>) -> Self {
52        self.default_timeout = timeout;
53        self
54    }
55
56    /// Run the scheduler indefinitely or until a store error occurs.
57    #[tracing::instrument(level = "debug", skip_all)]
58    pub async fn run(self) -> Result<(), Error> {
59        let schedule_manager = ScheduleManager::new(&self.store);
60        let mut schedule_manager_task = pin!(schedule_manager.run());
61
62        let mut events = pin!(SchedulerTaskStore::events(&self.store)
63            .await
64            .map_err(store_error)?);
65        let pending_tasks = self.store.pending_tasks().await.map_err(store_error)?;
66
67        // Used for cancellations and to prevent accidental duplications.
68        let mut scheduled_tasks: AHashMap<Uuid, ScheduledTask> = AHashMap::new();
69
70        let (timer, mut ready_entries) = Timer::<TimerEntry, Milliseconds>::new();
71        let timer_handle = timer.handle();
72
73        let mut timer_fut = pin!(timer.run());
74
75        for task in pending_tasks {
76            handle_event(
77                SchedulerTaskStoreEvent::TaskAdded(task),
78                &timer_handle,
79                &mut scheduled_tasks,
80                self.default_timeout,
81            );
82        }
83
84        // Schedule timeouts for existing tasks,
85        // this is done once at the beginning so no deduplication
86        // is done.
87        let active_tasks = self.store.active_tasks().await.map_err(store_error)?;
88
89        for task in active_tasks {
90            schedule_timeout(task, &timer_handle, self.default_timeout);
91        }
92
93        loop {
94            select! {
95                _ = &mut timer_fut => {
96                    panic!("unexpected end of the timer loop");
97                }
98                event = events.try_next() => {
99                    match event {
100                        Ok(event) => {
101                            match event {
102                                Some(event) => {
103                                    handle_event(
104                                        event,
105                                        &timer_handle,
106                                        &mut scheduled_tasks,
107                                        self.default_timeout,
108                                    );
109                                }
110                                None => {
111                                    return Err(Error::UnexpectedEventStreamEnd);
112                                }
113                            }
114                        }
115                        Err(error) => {
116                            return Err(store_error(error));
117                        }
118                    }
119                }
120                timer_entry = ready_entries.recv() => {
121                    let timer_entry = timer_entry.unwrap();
122                    match timer_entry {
123                        TimerEntry::TaskReady(task_id) => {
124                            let state = scheduled_tasks.remove(&task_id).unwrap();
125                            tracing::trace!(%task_id, "task ready");
126                            if !state.cancelled {
127                                self.store.task_ready(task_id).await.map_err(store_error)?;
128                            }
129                        }
130                        TimerEntry::TaskTimeout(task_id) => {
131                            self.store.task_timed_out(task_id).await.map_err(store_error)?;
132                        }
133                    }
134                }
135                manager_result = &mut schedule_manager_task => {
136                    manager_result?;
137                    unreachable!()
138                }
139            }
140        }
141    }
142}
143
144#[tracing::instrument(level = "trace", skip_all)]
145fn handle_event(
146    event: SchedulerTaskStoreEvent,
147    timer: &TimerHandle<TimerEntry>,
148    scheduled_tasks: &mut AHashMap<Uuid, ScheduledTask>,
149    default_timeout: Option<TimeoutPolicy>,
150) {
151    match event {
152        SchedulerTaskStoreEvent::TaskAdded(task) => {
153            if scheduled_tasks.contains_key(&task.id) {
154                tracing::debug!(task_id = %task.id, "task already scheduled");
155                return;
156            }
157            let task_unix = Duration::from_nanos(task.target.0);
158
159            let now = SystemTime::now()
160                .duration_since(UNIX_EPOCH)
161                .expect("system time cannot be before unix epoch");
162
163            let task_delay = task_unix.saturating_sub(now);
164
165            tracing::trace!(task_id = %task.id, "task scheduled");
166            scheduled_tasks.insert(task.id, ScheduledTask::default());
167            timer.schedule(TimerEntry::TaskReady(task.id), task_delay);
168            schedule_timeout(task.into(), timer, default_timeout);
169        }
170        SchedulerTaskStoreEvent::TaskCancelled(task_id) => {
171            if let Some(task) = scheduled_tasks.get_mut(&task_id) {
172                if task.cancelled {
173                    tracing::debug!(%task_id, "task already cancelled");
174                }
175                tracing::trace!(%task_id, "task cancelled");
176                task.cancelled = true;
177            } else {
178                tracing::debug!(%task_id, "task was cancelled but it was not scheduled");
179            }
180        }
181    }
182}
183
184#[tracing::instrument(level = "trace", skip_all)]
185fn schedule_timeout(
186    task: ActiveTask,
187    timer: &TimerHandle<TimerEntry>,
188    default_timeout: Option<TimeoutPolicy>,
189) {
190    let mut timeout = task.timeout;
191
192    if let Some(default_timeout) = default_timeout {
193        if matches!(timeout, TimeoutPolicy::Never) {
194            timeout = default_timeout;
195        }
196    }
197
198    match timeout {
199        TimeoutPolicy::Never => {}
200        TimeoutPolicy::FromTarget { timeout } => {
201            let task_unix = Duration::from_nanos(task.target.0);
202
203            let timeout_unix: Duration = match Duration::try_from(timeout) {
204                Ok(t) => t + task_unix,
205                Err(error) => {
206                    tracing::warn!(%error, "timeout out of range");
207                    return;
208                }
209            };
210
211            let now = SystemTime::now()
212                .duration_since(UNIX_EPOCH)
213                .expect("system time cannot be before unix epoch");
214
215            let timeout_delay = timeout_unix.saturating_sub(now);
216
217            timer.schedule(TimerEntry::TaskTimeout(task.id), timeout_delay);
218        }
219    }
220}
221
222/// A scheduler error.
223#[derive(Debug, Error)]
224pub enum Error {
225    /// The store event stream ended unexpectedly.
226    #[error("unexpected end of event stream")]
227    UnexpectedEventStreamEnd,
228    /// A store error ocurred.
229    #[error("store error: {0:?}")]
230    Store(Box<dyn std::error::Error + Send + Sync>),
231}
232
/// Tracking state for a task whose ready timer is still pending.
///
/// A freshly created entry is not cancelled.
#[derive(Default)]
struct ScheduledTask {
    /// Set when the task is cancelled; a cancelled task is not marked
    /// ready once its timer fires.
    cancelled: bool,
}
237
238#[derive(Debug)]
239enum TimerEntry {
240    TaskReady(Uuid),
241    TaskTimeout(Uuid),
242}
243
244struct ScheduleManager<'s, S>
245where
246    S: SchedulerScheduleStore,
247{
248    store: &'s S,
249    active_schedules: HashSet<Uuid>,
250}
251
252impl<'s, S> ScheduleManager<'s, S>
253where
254    S: SchedulerScheduleStore,
255{
256    fn new(store: &'s S) -> Self {
257        Self {
258            store,
259            active_schedules: HashSet::default(),
260        }
261    }
262
263    async fn run(mut self) -> Result<(), Error> {
264        let mut events = pin!(self.store.events().await.map_err(store_error)?);
265
266        let pending_schedules = self.store.pending_schedules().await.map_err(store_error)?;
267
268        for schedule in pending_schedules {
269            self.handle_event(SchedulerScheduleStoreEvent::ScheduleAdded(Box::new(
270                schedule,
271            )))
272            .await?;
273        }
274
275        while let Some(event) = events.try_next().await.map_err(store_error)? {
276            self.handle_event(event).await?;
277        }
278
279        Err(Error::UnexpectedEventStreamEnd)
280    }
281
282    async fn handle_event(&mut self, event: SchedulerScheduleStoreEvent) -> Result<(), Error> {
283        match event {
284            SchedulerScheduleStoreEvent::ScheduleAdded(schedule) => {
285                if self.active_schedules.contains(&schedule.id) {
286                    tracing::debug!("active schedule already exists");
287                    return Ok(());
288                }
289                self.active_schedules.insert(schedule.id);
290
291                let next_target = next_schedule_task(&schedule.definition, None, UnixNanos::now());
292
293                if let Some(next_target) = next_target {
294                    match &schedule.definition.new_task {
295                        NewTask::Repeat { task } => {
296                            self.store
297                                .add_task(schedule.id, task.clone().at_unix(next_target))
298                                .await
299                                .map_err(store_error)?;
300                        }
301                    }
302                }
303            }
304            SchedulerScheduleStoreEvent::TaskFinished(task_id) => {
305                if let Some(schedule) = self
306                    .store
307                    .pending_schedule_of_task(task_id)
308                    .await
309                    .map_err(store_error)?
310                {
311                    self.active_schedules.insert(schedule.id);
312                    let prev_target = self.store.task_target(task_id).await.map_err(store_error)?;
313                    let next_target = next_schedule_task(
314                        &schedule.definition,
315                        Some(prev_target),
316                        UnixNanos::now(),
317                    );
318
319                    if let Some(next_target) = next_target {
320                        match &schedule.definition.new_task {
321                            NewTask::Repeat { task } => {
322                                self.store
323                                    .add_task(schedule.id, task.clone().at_unix(next_target))
324                                    .await
325                                    .map_err(store_error)?;
326                            }
327                        }
328                    }
329                }
330            }
331            SchedulerScheduleStoreEvent::ScheduleCancelled(schedule_id) => {
332                self.active_schedules.remove(&schedule_id);
333            }
334        }
335
336        Ok(())
337    }
338}
339
340fn store_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
341    Error::Store(Box::new(error))
342}