rexecutor_sqlx/
backend.rs

1use std::pin::Pin;
2
3use async_stream::stream;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use futures::Stream;
7use rexecutor::{
8    backend::{Backend, BackendError, EnqueuableJob, ExecutionError, Job, Query},
9    executor::ExecutorIdentifier,
10    job::JobId,
11    pruner::PruneSpec,
12};
13use tokio::sync::mpsc;
14use tracing::instrument;
15
16use crate::{RexecutorPgBackend, map_err, stream::ReadyJobStream};
17
18impl RexecutorPgBackend {
19    fn handle_update(result: sqlx::Result<u64>, job_id: JobId) -> Result<(), BackendError> {
20        match result {
21            Ok(0) => Err(BackendError::JobNotFound(job_id)),
22            Ok(1) => Ok(()),
23            Ok(_) => Err(BackendError::BadState),
24            Err(error) => Err(map_err(error)),
25        }
26    }
27}
28
29#[async_trait]
30impl Backend for RexecutorPgBackend {
31    #[instrument(skip(self))]
32    async fn subscribe_ready_jobs(
33        &self,
34        executor_identifier: ExecutorIdentifier,
35    ) -> Pin<Box<dyn Stream<Item = Result<Job, BackendError>> + Send>> {
36        let (sender, receiver) = mpsc::unbounded_channel();
37        self.subscribers
38            .write()
39            .await
40            .entry(executor_identifier.as_str())
41            .or_default()
42            .push(sender);
43
44        let mut stream: ReadyJobStream = ReadyJobStream {
45            receiver,
46            backend: self.clone(),
47            executor_identifier,
48        };
49        Box::pin(stream! {
50            loop {
51                yield stream.next().await;
52            }
53        })
54    }
55    async fn enqueue<'a>(&self, job: EnqueuableJob<'a>) -> Result<JobId, BackendError> {
56        if job.uniqueness_criteria.is_some() {
57            self.insert_unique_job(job).await
58        } else {
59            self.insert_job(job).await
60        }
61        .map_err(map_err)
62    }
63    async fn mark_job_complete(&self, id: JobId) -> Result<(), BackendError> {
64        let result = self._mark_job_complete(id).await;
65        Self::handle_update(result, id)
66    }
67    async fn mark_job_retryable(
68        &self,
69        id: JobId,
70        next_scheduled_at: DateTime<Utc>,
71        error: ExecutionError,
72    ) -> Result<(), BackendError> {
73        let result = self._mark_job_retryable(id, next_scheduled_at, error).await;
74        Self::handle_update(result, id)
75    }
76    async fn mark_job_discarded(
77        &self,
78        id: JobId,
79        error: ExecutionError,
80    ) -> Result<(), BackendError> {
81        let result = self._mark_job_discarded(id, error).await;
82        Self::handle_update(result, id)
83    }
84    async fn mark_job_cancelled(
85        &self,
86        id: JobId,
87        error: ExecutionError,
88    ) -> Result<(), BackendError> {
89        let result = self._mark_job_cancelled(id, error).await;
90        Self::handle_update(result, id)
91    }
92    async fn mark_job_snoozed(
93        &self,
94        id: JobId,
95        next_scheduled_at: DateTime<Utc>,
96    ) -> Result<(), BackendError> {
97        let result = self._mark_job_snoozed(id, next_scheduled_at).await;
98        Self::handle_update(result, id)
99    }
100    async fn prune_jobs(&self, spec: &PruneSpec) -> Result<(), BackendError> {
101        self.delete_from_spec(spec).await.map_err(map_err)
102    }
103    async fn rerun_job(&self, id: JobId) -> Result<(), BackendError> {
104        let result = self.rerun(id).await;
105        Self::handle_update(result, id)
106    }
107    async fn update_job(&self, job: Job) -> Result<(), BackendError> {
108        let id = job.id.into();
109        let result = self.update(job).await;
110        Self::handle_update(result, id)
111    }
112    async fn query<'a>(&self, query: Query<'a>) -> Result<Vec<Job>, BackendError> {
113        self.run_query(query)
114            .await
115            .map_err(map_err)?
116            .into_iter()
117            .map(TryFrom::try_from)
118            .collect()
119    }
120}
121
#[cfg(test)]
mod test {
    use crate::JobStatus;
    use crate::types::Job;

    use super::*;
    use chrono::TimeDelta;
    use rexecutor::job::ErrorType;
    use serde_json::Value;
    use sqlx::PgPool;

    /// Test-only constructor: wraps an existing pool with an empty subscriber registry.
    impl From<PgPool> for RexecutorPgBackend {
        fn from(pool: PgPool) -> Self {
            Self {
                pool,
                subscribers: Default::default(),
            }
        }
    }

    /// Builder for [`EnqueuableJob`] fixtures with sensible defaults.
    struct MockJob<'a>(EnqueuableJob<'a>);

    impl<'a> From<MockJob<'a>> for EnqueuableJob<'a> {
        fn from(value: MockJob<'a>) -> Self {
            value.0
        }
    }

    impl<'a> Default for MockJob<'a> {
        /// A non-unique job for [`MockJob::EXECUTOR`] that is scheduled now, with priority 0
        /// and up to 5 attempts.
        fn default() -> Self {
            Self(EnqueuableJob {
                // Derived from the constant so the fixtures and the tests that look jobs up
                // by `MockJob::EXECUTOR` cannot drift apart.
                executor: Self::EXECUTOR.to_owned(),
                data: Value::String("data".to_owned()),
                metadata: Value::String("metadata".to_owned()),
                max_attempts: 5,
                scheduled_at: Utc::now(),
                tags: Default::default(),
                priority: 0,
                uniqueness_criteria: None,
            })
        }
    }

    impl<'a> MockJob<'a> {
        /// Executor name assigned to every default mock job.
        const EXECUTOR: &'static str = "executor";

        /// Enqueues this job on `backend`, panicking if the enqueue fails.
        async fn enqueue(self, backend: impl Backend) -> JobId {
            backend.enqueue(self.0).await.unwrap()
        }

        /// Returns a copy of this job scheduled at `scheduled_at`.
        fn with_scheduled_at(self, scheduled_at: DateTime<Utc>) -> Self {
            Self(EnqueuableJob {
                scheduled_at,
                ..self.0
            })
        }
    }

    impl RexecutorPgBackend {
        /// Fetches every row of `rexecutor_jobs` with no filtering, so assertions can
        /// inspect the raw table state.
        async fn all_jobs(&self) -> sqlx::Result<Vec<Job>> {
            sqlx::query_as!(
                Job,
                r#"SELECT
                    id,
                    status AS "status: JobStatus",
                    executor,
                    data,
                    metadata,
                    attempt,
                    max_attempts,
                    priority,
                    tags,
                    errors,
                    inserted_at,
                    scheduled_at,
                    attempted_at,
                    completed_at,
                    cancelled_at,
                    discarded_at
                FROM rexecutor_jobs
                "#
            )
            .fetch_all(&self.pool)
            .await
        }
    }

    rexecutor::backend::testing::test_suite!(
        attr: sqlx::test,
        args: (pool: PgPool),
        backend: RexecutorPgBackend::from_pool(pool).await.unwrap()
    );

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_none_when_db_empty(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_ready_for_execution(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_does_not_return_executing_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        MockJob::default().enqueue(&backend).await;

        let _ = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_retryable_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_retryable(
                job_id,
                Utc::now(),
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_job_scheduled_in_past(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        let job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_oldest_scheduled_at_executable_job(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        let expected_job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        let _ = MockJob::default().enqueue(&backend).await;

        let job_id = MockJob::default().enqueue(&backend).await;
        backend.mark_job_complete(job_id).await.unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_discarded(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_cancelled(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Cancelled,
                    message: "Not needed".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, expected_job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn enqueue_test(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();
        let job = MockJob::default();

        let result = backend.enqueue(job.into()).await;

        assert!(result.is_ok());

        let all_jobs = backend.all_jobs().await.unwrap();

        assert_eq!(all_jobs.len(), 1);
    }
}
374}