ora_scheduler/
scheduler.rs1use core::pin::pin;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5
6use ahash::{AHashMap, HashSet};
7use futures::TryStreamExt;
8use ora_common::{schedule::NewTask, timeout::TimeoutPolicy, UnixNanos};
9use ora_timer::{resolution::Milliseconds, Timer, TimerHandle};
10use ora_util::schedule::next_schedule_task;
11use thiserror::Error;
12use tokio::select;
13use uuid::Uuid;
14
15use crate::store::{
16 schedule::{SchedulerScheduleStore, SchedulerScheduleStoreEvent},
17 task::{ActiveTask, SchedulerTaskStore, SchedulerTaskStoreEvent},
18};
19
20pub struct Scheduler<S> {
25 store: S,
26 default_timeout: Option<TimeoutPolicy>,
27}
28
29impl<S> Scheduler<S>
30where
31 S: SchedulerTaskStore + SchedulerScheduleStore,
32{
33 #[must_use]
35 pub fn new(store: S) -> Self {
36 Self {
37 store,
38 default_timeout: None,
39 }
40 }
41
42 #[must_use]
51 pub fn with_default_timeout(mut self, timeout: Option<TimeoutPolicy>) -> Self {
52 self.default_timeout = timeout;
53 self
54 }
55
56 #[tracing::instrument(level = "debug", skip_all)]
58 pub async fn run(self) -> Result<(), Error> {
59 let schedule_manager = ScheduleManager::new(&self.store);
60 let mut schedule_manager_task = pin!(schedule_manager.run());
61
62 let mut events = pin!(SchedulerTaskStore::events(&self.store)
63 .await
64 .map_err(store_error)?);
65 let pending_tasks = self.store.pending_tasks().await.map_err(store_error)?;
66
67 let mut scheduled_tasks: AHashMap<Uuid, ScheduledTask> = AHashMap::new();
69
70 let (timer, mut ready_entries) = Timer::<TimerEntry, Milliseconds>::new();
71 let timer_handle = timer.handle();
72
73 let mut timer_fut = pin!(timer.run());
74
75 for task in pending_tasks {
76 handle_event(
77 SchedulerTaskStoreEvent::TaskAdded(task),
78 &timer_handle,
79 &mut scheduled_tasks,
80 self.default_timeout,
81 );
82 }
83
84 let active_tasks = self.store.active_tasks().await.map_err(store_error)?;
88
89 for task in active_tasks {
90 schedule_timeout(task, &timer_handle, self.default_timeout);
91 }
92
93 loop {
94 select! {
95 _ = &mut timer_fut => {
96 panic!("unexpected end of the timer loop");
97 }
98 event = events.try_next() => {
99 match event {
100 Ok(event) => {
101 match event {
102 Some(event) => {
103 handle_event(
104 event,
105 &timer_handle,
106 &mut scheduled_tasks,
107 self.default_timeout,
108 );
109 }
110 None => {
111 return Err(Error::UnexpectedEventStreamEnd);
112 }
113 }
114 }
115 Err(error) => {
116 return Err(store_error(error));
117 }
118 }
119 }
120 timer_entry = ready_entries.recv() => {
121 let timer_entry = timer_entry.unwrap();
122 match timer_entry {
123 TimerEntry::TaskReady(task_id) => {
124 let state = scheduled_tasks.remove(&task_id).unwrap();
125 tracing::trace!(%task_id, "task ready");
126 if !state.cancelled {
127 self.store.task_ready(task_id).await.map_err(store_error)?;
128 }
129 }
130 TimerEntry::TaskTimeout(task_id) => {
131 self.store.task_timed_out(task_id).await.map_err(store_error)?;
132 }
133 }
134 }
135 manager_result = &mut schedule_manager_task => {
136 manager_result?;
137 unreachable!()
138 }
139 }
140 }
141 }
142}
143
144#[tracing::instrument(level = "trace", skip_all)]
145fn handle_event(
146 event: SchedulerTaskStoreEvent,
147 timer: &TimerHandle<TimerEntry>,
148 scheduled_tasks: &mut AHashMap<Uuid, ScheduledTask>,
149 default_timeout: Option<TimeoutPolicy>,
150) {
151 match event {
152 SchedulerTaskStoreEvent::TaskAdded(task) => {
153 if scheduled_tasks.contains_key(&task.id) {
154 tracing::debug!(task_id = %task.id, "task already scheduled");
155 return;
156 }
157 let task_unix = Duration::from_nanos(task.target.0);
158
159 let now = SystemTime::now()
160 .duration_since(UNIX_EPOCH)
161 .expect("system time cannot be before unix epoch");
162
163 let task_delay = task_unix.saturating_sub(now);
164
165 tracing::trace!(task_id = %task.id, "task scheduled");
166 scheduled_tasks.insert(task.id, ScheduledTask::default());
167 timer.schedule(TimerEntry::TaskReady(task.id), task_delay);
168 schedule_timeout(task.into(), timer, default_timeout);
169 }
170 SchedulerTaskStoreEvent::TaskCancelled(task_id) => {
171 if let Some(task) = scheduled_tasks.get_mut(&task_id) {
172 if task.cancelled {
173 tracing::debug!(%task_id, "task already cancelled");
174 }
175 tracing::trace!(%task_id, "task cancelled");
176 task.cancelled = true;
177 } else {
178 tracing::debug!(%task_id, "task was cancelled but it was not scheduled");
179 }
180 }
181 }
182}
183
184#[tracing::instrument(level = "trace", skip_all)]
185fn schedule_timeout(
186 task: ActiveTask,
187 timer: &TimerHandle<TimerEntry>,
188 default_timeout: Option<TimeoutPolicy>,
189) {
190 let mut timeout = task.timeout;
191
192 if let Some(default_timeout) = default_timeout {
193 if matches!(timeout, TimeoutPolicy::Never) {
194 timeout = default_timeout;
195 }
196 }
197
198 match timeout {
199 TimeoutPolicy::Never => {}
200 TimeoutPolicy::FromTarget { timeout } => {
201 let task_unix = Duration::from_nanos(task.target.0);
202
203 let timeout_unix: Duration = match Duration::try_from(timeout) {
204 Ok(t) => t + task_unix,
205 Err(error) => {
206 tracing::warn!(%error, "timeout out of range");
207 return;
208 }
209 };
210
211 let now = SystemTime::now()
212 .duration_since(UNIX_EPOCH)
213 .expect("system time cannot be before unix epoch");
214
215 let timeout_delay = timeout_unix.saturating_sub(now);
216
217 timer.schedule(TimerEntry::TaskTimeout(task.id), timeout_delay);
218 }
219 }
220}
221
222#[derive(Debug, Error)]
224pub enum Error {
225 #[error("unexpected end of event stream")]
227 UnexpectedEventStreamEnd,
228 #[error("store error: {0:?}")]
230 Store(Box<dyn std::error::Error + Send + Sync>),
231}
232
/// Scheduler-side bookkeeping for a task that is waiting for its target time.
#[derive(Default)]
struct ScheduledTask {
    /// Set once the task is cancelled. When its ready timer fires, a
    /// cancelled task is not reported to the store.
    cancelled: bool,
}
237
238#[derive(Debug)]
239enum TimerEntry {
240 TaskReady(Uuid),
241 TaskTimeout(Uuid),
242}
243
244struct ScheduleManager<'s, S>
245where
246 S: SchedulerScheduleStore,
247{
248 store: &'s S,
249 active_schedules: HashSet<Uuid>,
250}
251
252impl<'s, S> ScheduleManager<'s, S>
253where
254 S: SchedulerScheduleStore,
255{
256 fn new(store: &'s S) -> Self {
257 Self {
258 store,
259 active_schedules: HashSet::default(),
260 }
261 }
262
263 async fn run(mut self) -> Result<(), Error> {
264 let mut events = pin!(self.store.events().await.map_err(store_error)?);
265
266 let pending_schedules = self.store.pending_schedules().await.map_err(store_error)?;
267
268 for schedule in pending_schedules {
269 self.handle_event(SchedulerScheduleStoreEvent::ScheduleAdded(Box::new(
270 schedule,
271 )))
272 .await?;
273 }
274
275 while let Some(event) = events.try_next().await.map_err(store_error)? {
276 self.handle_event(event).await?;
277 }
278
279 Err(Error::UnexpectedEventStreamEnd)
280 }
281
282 async fn handle_event(&mut self, event: SchedulerScheduleStoreEvent) -> Result<(), Error> {
283 match event {
284 SchedulerScheduleStoreEvent::ScheduleAdded(schedule) => {
285 if self.active_schedules.contains(&schedule.id) {
286 tracing::debug!("active schedule already exists");
287 return Ok(());
288 }
289 self.active_schedules.insert(schedule.id);
290
291 let next_target = next_schedule_task(&schedule.definition, None, UnixNanos::now());
292
293 if let Some(next_target) = next_target {
294 match &schedule.definition.new_task {
295 NewTask::Repeat { task } => {
296 self.store
297 .add_task(schedule.id, task.clone().at_unix(next_target))
298 .await
299 .map_err(store_error)?;
300 }
301 }
302 }
303 }
304 SchedulerScheduleStoreEvent::TaskFinished(task_id) => {
305 if let Some(schedule) = self
306 .store
307 .pending_schedule_of_task(task_id)
308 .await
309 .map_err(store_error)?
310 {
311 self.active_schedules.insert(schedule.id);
312 let prev_target = self.store.task_target(task_id).await.map_err(store_error)?;
313 let next_target = next_schedule_task(
314 &schedule.definition,
315 Some(prev_target),
316 UnixNanos::now(),
317 );
318
319 if let Some(next_target) = next_target {
320 match &schedule.definition.new_task {
321 NewTask::Repeat { task } => {
322 self.store
323 .add_task(schedule.id, task.clone().at_unix(next_target))
324 .await
325 .map_err(store_error)?;
326 }
327 }
328 }
329 }
330 }
331 SchedulerScheduleStoreEvent::ScheduleCancelled(schedule_id) => {
332 self.active_schedules.remove(&schedule_id);
333 }
334 }
335
336 Ok(())
337 }
338}
339
340fn store_error<E: std::error::Error + Send + Sync + 'static>(error: E) -> Error {
341 Error::Store(Box::new(error))
342}