Skip to main content

scheduler/model/task.rs

1use crate::model::{MissedRunPolicy, OverlapPolicy, Schedule};
2use chrono::{DateTime, Utc};
3use chrono_tz::Tz;
4use std::any::type_name;
5use std::future::{Future, ready};
6use std::panic::resume_unwind;
7use std::pin::Pin;
8use std::sync::Arc;
9
10/// The task return type used by scheduled jobs.
11pub type JobResult = Result<(), String>;
12/// The boxed future returned by a scheduled job.
13pub type JobFuture = Pin<Box<dyn Future<Output = JobResult> + Send>>;
14pub(crate) type TaskHandler<D> = Arc<dyn Fn(TaskContext<D>) -> JobFuture + Send + Sync>;
15
16#[derive(Clone)]
17pub struct Task<D> {
18    pub(crate) handler: TaskHandler<D>,
19}
20
21impl<D> Task<D>
22where
23    D: Send + Sync + 'static,
24{
25    fn from_handler(handler: TaskHandler<D>) -> Self {
26        Self { handler }
27    }
28
29    /// Create an async task from the full [`TaskContext`].
30    pub fn from_async<F, Fut>(task: F) -> Self
31    where
32        F: Fn(TaskContext<D>) -> Fut + Send + Sync + 'static,
33        Fut: Future<Output = JobResult> + Send + 'static,
34    {
35        Self::from_handler(wrap_async_handler(Arc::new(task)))
36    }
37
38    /// Create a lightweight synchronous task from the full [`TaskContext`].
39    pub fn from_sync<F>(task: F) -> Self
40    where
41        F: Fn(TaskContext<D>) -> JobResult + Send + Sync + 'static,
42    {
43        Self::from_handler(wrap_sync_handler(Arc::new(task)))
44    }
45
46    /// Create a blocking synchronous task from the full [`TaskContext`].
47    pub fn from_blocking<F>(task: F) -> Self
48    where
49        F: Fn(TaskContext<D>) -> JobResult + Send + Sync + 'static,
50    {
51        Self::from_handler(wrap_blocking_handler(Arc::new(task)))
52    }
53}
54
55#[derive(Clone)]
56pub struct Job<D = ()> {
57    pub job_id: String,
58    pub schedule: Schedule,
59    pub max_runs: Option<u32>,
60    pub missed_run_policy: MissedRunPolicy,
61    pub overlap_policy: OverlapPolicy,
62    pub(crate) task: TaskHandler<D>,
63    pub(crate) deps: Arc<D>,
64}
65
66impl Job<()> {
67    /// Create a job that uses no injected dependencies.
68    pub fn without_deps(job_id: impl Into<String>, schedule: Schedule, task: Task<()>) -> Self {
69        Self::from_parts(job_id.into(), schedule, Arc::new(()), task)
70    }
71}
72
73impl<D> Job<D>
74where
75    D: Send + Sync + 'static,
76{
77    /// Create a job from explicit dependencies and a task handler.
78    pub fn new(
79        job_id: impl Into<String>,
80        schedule: Schedule,
81        deps: impl Into<Arc<D>>,
82        task: Task<D>,
83    ) -> Self {
84        Self::from_parts(job_id.into(), schedule, deps.into(), task)
85    }
86}
87
88impl<D> Job<D> {
89    fn default_policies() -> (MissedRunPolicy, OverlapPolicy) {
90        (MissedRunPolicy::CatchUpOnce, OverlapPolicy::Forbid)
91    }
92
93    fn from_parts(job_id: String, schedule: Schedule, deps: Arc<D>, task: Task<D>) -> Self {
94        let (missed_run_policy, overlap_policy) = Self::default_policies();
95        Self {
96            job_id,
97            schedule,
98            max_runs: None,
99            missed_run_policy,
100            overlap_policy,
101            task: task.handler,
102            deps,
103        }
104    }
105
106    /// Limit how many triggers this job can consume before it exits.
107    ///
108    /// This applies to [`Schedule::Interval`], [`Schedule::AtTimes`], and
109    /// [`Schedule::Cron`].
110    /// A value of `0` makes the job exit immediately without running.
111    pub fn with_max_runs(mut self, max_runs: u32) -> Self {
112        self.max_runs = Some(max_runs);
113        self
114    }
115
116    pub fn with_missed_run_policy(mut self, policy: MissedRunPolicy) -> Self {
117        self.missed_run_policy = policy;
118        self
119    }
120
121    pub fn with_overlap_policy(mut self, policy: OverlapPolicy) -> Self {
122        self.overlap_policy = policy;
123        self
124    }
125}
126
127impl<D> std::fmt::Debug for Job<D> {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("Job")
130            .field("job_id", &self.job_id)
131            .field("schedule", &self.schedule)
132            .field("max_runs", &self.max_runs)
133            .field("missed_run_policy", &self.missed_run_policy)
134            .field("overlap_policy", &self.overlap_policy)
135            .field("deps", &type_name::<D>())
136            .finish_non_exhaustive()
137    }
138}
139
140fn wrap_async_handler<D, F, Fut>(task: Arc<F>) -> TaskHandler<D>
141where
142    D: Send + Sync + 'static,
143    F: Fn(TaskContext<D>) -> Fut + Send + Sync + 'static,
144    Fut: Future<Output = JobResult> + Send + 'static,
145{
146    Arc::new(move |context| Box::pin((*task)(context)))
147}
148
149fn wrap_sync_handler<D, F>(task: Arc<F>) -> TaskHandler<D>
150where
151    D: Send + Sync + 'static,
152    F: Fn(TaskContext<D>) -> JobResult + Send + Sync + 'static,
153{
154    Arc::new(move |context| Box::pin(ready((*task)(context))))
155}
156
157fn wrap_blocking_handler<D, F>(task: Arc<F>) -> TaskHandler<D>
158where
159    D: Send + Sync + 'static,
160    F: Fn(TaskContext<D>) -> JobResult + Send + Sync + 'static,
161{
162    Arc::new(move |context| {
163        let task = task.clone();
164        Box::pin(async move { await_blocking(move || (*task)(context)).await })
165    })
166}
167
168#[derive(Debug, Clone)]
169pub struct RunContext {
170    pub job_id: String,
171    pub scheduled_at: DateTime<Utc>,
172    pub catch_up: bool,
173    /// The scheduler-configured timezone for downstream task logic.
174    pub timezone: Tz,
175}
176
177#[derive(Clone)]
178pub struct TaskContext<D> {
179    pub run: RunContext,
180    pub deps: Arc<D>,
181}
182
183impl<D> std::fmt::Debug for TaskContext<D> {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.debug_struct("TaskContext")
186            .field("run", &self.run)
187            .field("deps", &type_name::<D>())
188            .finish()
189    }
190}
191
192async fn await_blocking<F>(task: F) -> JobResult
193where
194    F: FnOnce() -> JobResult + Send + 'static,
195{
196    match tokio::task::spawn_blocking(task).await {
197        Ok(result) => result,
198        Err(error) if error.is_panic() => resume_unwind(error.into_panic()),
199        Err(error) => panic!("blocking task failed to join: {error}"),
200    }
201}