rexecutor_sqlx/
backend.rs

1use std::pin::Pin;
2
3use async_stream::stream;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use futures::Stream;
7use rexecutor::{
8    backend::{Backend, BackendError, EnqueuableJob, ExecutionError, Job, Query},
9    executor::ExecutorIdentifier,
10    job::JobId,
11    pruner::PruneSpec,
12};
13use tokio::sync::mpsc;
14use tracing::instrument;
15
16use crate::{stream::ReadyJobStream, RexecutorPgBackend};
17
18impl RexecutorPgBackend {
19    fn handle_update(result: sqlx::Result<u64>, job_id: JobId) -> Result<(), BackendError> {
20        match result {
21            Ok(0) => Err(BackendError::JobNotFound(job_id)),
22            Ok(1) => Ok(()),
23            Ok(_) => Err(BackendError::BadState),
24            // TODO fix this
25            Err(_error) => Err(BackendError::BadState),
26        }
27    }
28}
29
30#[async_trait]
31impl Backend for RexecutorPgBackend {
32    #[instrument(skip(self))]
33    async fn subscribe_ready_jobs(
34        &self,
35        executor_identifier: ExecutorIdentifier,
36    ) -> Pin<Box<dyn Stream<Item = Result<Job, BackendError>> + Send>> {
37        let (sender, receiver) = mpsc::unbounded_channel();
38        self.subscribers
39            .write()
40            .await
41            .entry(executor_identifier.as_str())
42            .or_default()
43            .push(sender);
44
45        let mut stream: ReadyJobStream = ReadyJobStream {
46            receiver,
47            backend: self.clone(),
48            executor_identifier,
49        };
50        Box::pin(stream! {
51            loop {
52                yield stream.next().await;
53            }
54        })
55    }
56    async fn enqueue<'a>(&self, job: EnqueuableJob<'a>) -> Result<JobId, BackendError> {
57        if job.uniqueness_criteria.is_some() {
58            self.insert_unique_job(job).await
59        } else {
60            self.insert_job(job).await
61        }
62        // TODO error handling
63        .map_err(|_| BackendError::BadState)
64    }
65    async fn mark_job_complete(&self, id: JobId) -> Result<(), BackendError> {
66        let result = self._mark_job_complete(id).await;
67        Self::handle_update(result, id)
68    }
69    async fn mark_job_retryable(
70        &self,
71        id: JobId,
72        next_scheduled_at: DateTime<Utc>,
73        error: ExecutionError,
74    ) -> Result<(), BackendError> {
75        let result = self._mark_job_retryable(id, next_scheduled_at, error).await;
76        Self::handle_update(result, id)
77    }
78    async fn mark_job_discarded(
79        &self,
80        id: JobId,
81        error: ExecutionError,
82    ) -> Result<(), BackendError> {
83        let result = self._mark_job_discarded(id, error).await;
84        Self::handle_update(result, id)
85    }
86    async fn mark_job_cancelled(
87        &self,
88        id: JobId,
89        error: ExecutionError,
90    ) -> Result<(), BackendError> {
91        let result = self._mark_job_cancelled(id, error).await;
92        Self::handle_update(result, id)
93    }
94    async fn mark_job_snoozed(
95        &self,
96        id: JobId,
97        next_scheduled_at: DateTime<Utc>,
98    ) -> Result<(), BackendError> {
99        let result = self._mark_job_snoozed(id, next_scheduled_at).await;
100        Self::handle_update(result, id)
101    }
102    async fn prune_jobs(&self, spec: &PruneSpec) -> Result<(), BackendError> {
103        self.delete_from_spec(spec)
104            .await
105            .map_err(|_| BackendError::BadState)
106    }
107    async fn rerun_job(&self, id: JobId) -> Result<(), BackendError> {
108        let result = self.rerun(id).await;
109        Self::handle_update(result, id)
110    }
111    async fn update_job(&self, job: Job) -> Result<(), BackendError> {
112        let id = job.id.into();
113        let result = self.update(job).await;
114        Self::handle_update(result, id)
115    }
116    async fn query<'a>(&self, query: Query<'a>) -> Result<Vec<Job>, BackendError> {
117        self.run_query(query)
118            .await
119            .map_err(|_| BackendError::BadState)?
120            .into_iter()
121            .map(TryFrom::try_from)
122            .collect()
123    }
124}
125
#[cfg(test)]
mod test {
    use crate::types::Job;
    use crate::JobStatus;

    use super::*;
    use chrono::TimeDelta;
    use rexecutor::job::ErrorType;
    use serde_json::Value;
    use sqlx::PgPool;

    /// Builds a backend over `pool` with an empty (default) subscriber map.
    impl From<PgPool> for RexecutorPgBackend {
        fn from(pool: PgPool) -> Self {
            Self {
                pool,
                subscribers: Default::default(),
            }
        }
    }

    /// Test builder wrapping an [`EnqueuableJob`] with fixed defaults.
    struct MockJob<'a>(EnqueuableJob<'a>);

    impl<'a> From<MockJob<'a>> for EnqueuableJob<'a> {
        fn from(value: MockJob<'a>) -> Self {
            value.0
        }
    }

    impl<'a> Default for MockJob<'a> {
        /// A job for [`MockJob::EXECUTOR`] scheduled now, with 5 max attempts, priority 0,
        /// default tags and no uniqueness criteria.
        fn default() -> Self {
            Self(EnqueuableJob {
                // Derived from `EXECUTOR` so tests that load jobs by `MockJob::EXECUTOR` always
                // match default jobs, even if the executor name is changed.
                executor: Self::EXECUTOR.to_owned(),
                data: Value::String("data".to_owned()),
                metadata: Value::String("metadata".to_owned()),
                max_attempts: 5,
                scheduled_at: Utc::now(),
                tags: Default::default(),
                priority: 0,
                uniqueness_criteria: None,
            })
        }
    }

    impl<'a> MockJob<'a> {
        /// Executor name assigned to every default mock job.
        const EXECUTOR: &'static str = "executor";

        /// Enqueues the wrapped job on `backend`, panicking if enqueueing fails.
        async fn enqueue(self, backend: impl Backend) -> JobId {
            backend.enqueue(self.0).await.unwrap()
        }

        /// Returns the same job with `scheduled_at` overridden.
        fn with_scheduled_at(self, scheduled_at: DateTime<Utc>) -> Self {
            Self(EnqueuableJob {
                scheduled_at,
                ..self.0
            })
        }
    }

    impl RexecutorPgBackend {
        /// Reads every row of `rexecutor_jobs`, with no filtering by status or executor.
        async fn all_jobs(&self) -> sqlx::Result<Vec<Job>> {
            sqlx::query_as!(
                Job,
                r#"SELECT
                    id,
                    status AS "status: JobStatus",
                    executor,
                    data,
                    metadata,
                    attempt,
                    max_attempts,
                    priority,
                    tags,
                    errors,
                    inserted_at,
                    scheduled_at,
                    attempted_at,
                    completed_at,
                    cancelled_at,
                    discarded_at
                FROM rexecutor_jobs
                "#
            )
            .fetch_all(&self.pool)
            .await
        }
    }

    // Presumably instantiates rexecutor's shared backend test suite for this backend, with each
    // generated test run under `sqlx::test` — see `rexecutor::backend::testing`.
    rexecutor::backend::testing::test_suite!(
        attr: sqlx::test,
        args: (pool: PgPool),
        backend: RexecutorPgBackend::from_pool(pool).await.unwrap()
    );

    // TODO: add dedicated tests that running, cancelled, complete, and discarded jobs are ignored.

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_none_when_db_empty(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_ready_for_execution(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_does_not_return_executing_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        MockJob::default().enqueue(&backend).await;

        // The first load moves the only job into the executing state...
        let _ = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        // ...so a second load must find nothing.
        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_retryable_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_retryable(
                job_id,
                Utc::now(),
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_job_scheduled_in_past(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        let job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_oldest_scheduled_at_executable_job(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        // The oldest executable job: scheduled three hours ago.
        let expected_job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        // A newer executable job that must lose to the older one.
        let _ = MockJob::default().enqueue(&backend).await;

        // Jobs in terminal states (complete, discarded, cancelled).
        let job_id = MockJob::default().enqueue(&backend).await;
        backend.mark_job_complete(job_id).await.unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_discarded(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_cancelled(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Cancelled,
                    message: "Not needed".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, expected_job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn enqueue_test(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();
        let job = MockJob::default();

        let result = backend.enqueue(job.into()).await;

        assert!(result.is_ok());

        let all_jobs = backend.all_jobs().await.unwrap();

        assert_eq!(all_jobs.len(), 1);
    }
}