use std::pin::Pin;

use async_stream::stream;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::Stream;
use rexecutor::{
    backend::{Backend, BackendError, EnqueuableJob, ExecutionError, Job, Query},
    executor::ExecutorIdentifier,
    job::JobId,
    pruner::PruneSpec,
};
use tokio::sync::mpsc;
use tracing::instrument;

use crate::{RexecutorPgBackend, map_err, stream::ReadyJobStream};

impl RexecutorPgBackend {
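    /// Maps the affected-row count of a single-job UPDATE to a backend result:
    /// zero rows means the job was not found, one row is success, and any other
    /// count indicates an inconsistent state.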
    fn handle_update(result: sqlx::Result<u64>, job_id: JobId) -> Result<(), BackendError> {
        match result {
            Ok(0) => Err(BackendError::JobNotFound(job_id)),
            Ok(1) => Ok(()),
            Ok(_) => Err(BackendError::BadState),
            Err(error) => Err(map_err(error)),
        }
    }
}

#[async_trait]
impl Backend for RexecutorPgBackend {
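    /// Registers a new subscriber channel for the given executor and returns an
    /// endless stream that repeatedly yields the next ready job for it.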
    #[instrument(skip(self))]
    async fn subscribe_ready_jobs(
        &self,
        executor_identifier: ExecutorIdentifier,
    ) -> Pin<Box<dyn Stream<Item = Result<Job, BackendError>> + Send>> {
        let (sender, receiver) = mpsc::unbounded_channel();
        self.subscribers
            .write()
            .await
            .entry(executor_identifier.as_str())
            .or_default()
            .push(sender);

        let mut stream: ReadyJobStream = ReadyJobStream {
            receiver,
            backend: self.clone(),
            executor_identifier,
        };
        Box::pin(stream! {
            loop {
                yield stream.next().await;
            }
        })
    }
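    /// Inserts the job, dispatching to the uniqueness-aware insert when the job
    /// carries uniqueness criteria.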
    async fn enqueue<'a>(&self, job: EnqueuableJob<'a>) -> Result<JobId, BackendError> {
        if job.uniqueness_criteria.is_some() {
            self.insert_unique_job(job).await
        } else {
            self.insert_job(job).await
        }
        .map_err(map_err)
    }
    async fn mark_job_complete(&self, id: JobId) -> Result<(), BackendError> {
        let result = self._mark_job_complete(id).await;
        Self::handle_update(result, id)
    }
    async fn mark_job_retryable(
        &self,
        id: JobId,
        next_scheduled_at: DateTime<Utc>,
        error: ExecutionError,
    ) -> Result<(), BackendError> {
        let result = self._mark_job_retryable(id, next_scheduled_at, error).await;
        Self::handle_update(result, id)
    }
    async fn mark_job_discarded(
        &self,
        id: JobId,
        error: ExecutionError,
    ) -> Result<(), BackendError> {
        let result = self._mark_job_discarded(id, error).await;
        Self::handle_update(result, id)
    }
    async fn mark_job_cancelled(
        &self,
        id: JobId,
        error: ExecutionError,
    ) -> Result<(), BackendError> {
        let result = self._mark_job_cancelled(id, error).await;
        Self::handle_update(result, id)
    }
    async fn mark_job_snoozed(
        &self,
        id: JobId,
        next_scheduled_at: DateTime<Utc>,
    ) -> Result<(), BackendError> {
        let result = self._mark_job_snoozed(id, next_scheduled_at).await;
        Self::handle_update(result, id)
    }
    async fn prune_jobs(&self, spec: &PruneSpec) -> Result<(), BackendError> {
        self.delete_from_spec(spec).await.map_err(map_err)
    }
    async fn rerun_job(&self, id: JobId) -> Result<(), BackendError> {
        let result = self.rerun(id).await;
        Self::handle_update(result, id)
    }
    async fn update_job(&self, job: Job) -> Result<(), BackendError> {
        let id = job.id.into();
        let result = self.update(job).await;
        Self::handle_update(result, id)
    }
    async fn query<'a>(&self, query: Query<'a>) -> Result<Vec<Job>, BackendError> {
        self.run_query(query)
            .await
            .map_err(map_err)?
            .into_iter()
            .map(TryFrom::try_from)
            .collect()
    }
}

#[cfg(test)]
mod test {
    use crate::JobStatus;
    use crate::types::Job;

    use super::*;
    use chrono::TimeDelta;
    use rexecutor::job::ErrorType;
    use serde_json::Value;
    use sqlx::PgPool;

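    // Test-only conversion so a `PgPool` supplied by `#[sqlx::test]` can be
    // turned straight into a backend with default subscriber state.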
    impl From<PgPool> for RexecutorPgBackend {
        fn from(pool: PgPool) -> Self {
            Self {
                pool,
                subscribers: Default::default(),
            }
        }
    }

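    // Thin wrapper around `EnqueuableJob` used to build test jobs with sensible
    // defaults.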
    struct MockJob<'a>(EnqueuableJob<'a>);

    impl<'a> From<MockJob<'a>> for EnqueuableJob<'a> {
        fn from(value: MockJob<'a>) -> Self {
            value.0
        }
    }

    impl<'a> Default for MockJob<'a> {
        fn default() -> Self {
            Self(EnqueuableJob {
                executor: "executor".to_owned(),
                data: Value::String("data".to_owned()),
                metadata: Value::String("metadata".to_owned()),
                max_attempts: 5,
                scheduled_at: Utc::now(),
                tags: Default::default(),
                priority: 0,
                uniqueness_criteria: None,
            })
        }
    }

    impl<'a> MockJob<'a> {
        const EXECUTOR: &'static str = "executor";

        async fn enqueue(self, backend: impl Backend) -> JobId {
            backend.enqueue(self.0).await.unwrap()
        }

        fn with_scheduled_at(self, scheduled_at: DateTime<Utc>) -> Self {
            Self(EnqueuableJob {
                scheduled_at,
                ..self.0
            })
        }
    }

    impl RexecutorPgBackend {
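        /// Test helper that loads every row from `rexecutor_jobs`, regardless of
        /// status, so assertions can inspect the full table contents.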
        async fn all_jobs(&self) -> sqlx::Result<Vec<Job>> {
            sqlx::query_as!(
                Job,
                r#"SELECT
                    id,
                    status AS "status: JobStatus",
                    executor,
                    data,
                    metadata,
                    attempt,
                    max_attempts,
                    priority,
                    tags,
                    errors,
                    inserted_at,
                    scheduled_at,
                    attempted_at,
                    completed_at,
                    cancelled_at,
                    discarded_at
                FROM rexecutor_jobs
                "#
            )
            .fetch_all(&self.pool)
            .await
        }
    }

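    // Run the shared backend test suite provided by rexecutor against this
    // Postgres implementation, with `#[sqlx::test]` supplying a fresh pool per test.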
    rexecutor::backend::testing::test_suite!(
        attr: sqlx::test,
        args: (pool: PgPool),
        backend: RexecutorPgBackend::from_pool(pool).await.unwrap()
    );

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_none_when_db_empty(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_ready_for_execution(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_does_not_return_executing_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        MockJob::default().enqueue(&backend).await;

        let _ = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_retryable_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_retryable(
                job_id,
                Utc::now(),
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_job_scheduled_in_past(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        let job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_oldest_scheduled_at_executable_job(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        let expected_job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        let _ = MockJob::default().enqueue(&backend).await;

        let job_id = MockJob::default().enqueue(&backend).await;
        backend.mark_job_complete(job_id).await.unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_discarded(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_cancelled(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Cancelled,
                    message: "Not needed".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, expected_job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn enqueue_test(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();
        let job = MockJob::default();

        let result = backend.enqueue(job.into()).await;

        assert!(result.is_ok());

        let all_jobs = backend.all_jobs().await.unwrap();

        assert_eq!(all_jobs.len(), 1);
    }
}