1use std::pin::Pin;
2
3use async_stream::stream;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use futures::Stream;
7use rexecutor::{
8 backend::{Backend, BackendError, EnqueuableJob, ExecutionError, Job, Query},
9 executor::ExecutorIdentifier,
10 job::JobId,
11 pruner::PruneSpec,
12};
13use tokio::sync::mpsc;
14use tracing::instrument;
15
16use crate::{stream::ReadyJobStream, RexecutorPgBackend};
17
18impl RexecutorPgBackend {
19 fn handle_update(result: sqlx::Result<u64>, job_id: JobId) -> Result<(), BackendError> {
20 match result {
21 Ok(0) => Err(BackendError::JobNotFound(job_id)),
22 Ok(1) => Ok(()),
23 Ok(_) => Err(BackendError::BadState),
24 Err(_error) => Err(BackendError::BadState),
26 }
27 }
28}
29
30#[async_trait]
31impl Backend for RexecutorPgBackend {
32 #[instrument(skip(self))]
33 async fn subscribe_ready_jobs(
34 &self,
35 executor_identifier: ExecutorIdentifier,
36 ) -> Pin<Box<dyn Stream<Item = Result<Job, BackendError>> + Send>> {
37 let (sender, receiver) = mpsc::unbounded_channel();
38 self.subscribers
39 .write()
40 .await
41 .entry(executor_identifier.as_str())
42 .or_default()
43 .push(sender);
44
45 let mut stream: ReadyJobStream = ReadyJobStream {
46 receiver,
47 backend: self.clone(),
48 executor_identifier,
49 };
50 Box::pin(stream! {
51 loop {
52 yield stream.next().await;
53 }
54 })
55 }
56 async fn enqueue<'a>(&self, job: EnqueuableJob<'a>) -> Result<JobId, BackendError> {
57 if job.uniqueness_criteria.is_some() {
58 self.insert_unique_job(job).await
59 } else {
60 self.insert_job(job).await
61 }
62 .map_err(|_| BackendError::BadState)
64 }
65 async fn mark_job_complete(&self, id: JobId) -> Result<(), BackendError> {
66 let result = self._mark_job_complete(id).await;
67 Self::handle_update(result, id)
68 }
69 async fn mark_job_retryable(
70 &self,
71 id: JobId,
72 next_scheduled_at: DateTime<Utc>,
73 error: ExecutionError,
74 ) -> Result<(), BackendError> {
75 let result = self._mark_job_retryable(id, next_scheduled_at, error).await;
76 Self::handle_update(result, id)
77 }
78 async fn mark_job_discarded(
79 &self,
80 id: JobId,
81 error: ExecutionError,
82 ) -> Result<(), BackendError> {
83 let result = self._mark_job_discarded(id, error).await;
84 Self::handle_update(result, id)
85 }
86 async fn mark_job_cancelled(
87 &self,
88 id: JobId,
89 error: ExecutionError,
90 ) -> Result<(), BackendError> {
91 let result = self._mark_job_cancelled(id, error).await;
92 Self::handle_update(result, id)
93 }
94 async fn mark_job_snoozed(
95 &self,
96 id: JobId,
97 next_scheduled_at: DateTime<Utc>,
98 ) -> Result<(), BackendError> {
99 let result = self._mark_job_snoozed(id, next_scheduled_at).await;
100 Self::handle_update(result, id)
101 }
102 async fn prune_jobs(&self, spec: &PruneSpec) -> Result<(), BackendError> {
103 self.delete_from_spec(spec)
104 .await
105 .map_err(|_| BackendError::BadState)
106 }
107 async fn rerun_job(&self, id: JobId) -> Result<(), BackendError> {
108 let result = self.rerun(id).await;
109 Self::handle_update(result, id)
110 }
111 async fn update_job(&self, job: Job) -> Result<(), BackendError> {
112 let id = job.id.into();
113 let result = self.update(job).await;
114 Self::handle_update(result, id)
115 }
116 async fn query<'a>(&self, query: Query<'a>) -> Result<Vec<Job>, BackendError> {
117 self.run_query(query)
118 .await
119 .map_err(|_| BackendError::BadState)?
120 .into_iter()
121 .map(TryFrom::try_from)
122 .collect()
123 }
124}
125
#[cfg(test)]
mod test {
    use crate::types::Job;
    use crate::JobStatus;

    use super::*;
    use chrono::TimeDelta;
    use rexecutor::job::ErrorType;
    use serde_json::Value;
    use sqlx::PgPool;

    // Builds a backend straight from a pool, with an empty subscriber map.
    impl From<PgPool> for RexecutorPgBackend {
        fn from(pool: PgPool) -> Self {
            Self {
                pool,
                subscribers: Default::default(),
            }
        }
    }

    /// Test helper wrapping an [`EnqueuableJob`] with fixed defaults.
    struct MockJob<'a>(EnqueuableJob<'a>);

    impl<'a> From<MockJob<'a>> for EnqueuableJob<'a> {
        fn from(value: MockJob<'a>) -> Self {
            value.0
        }
    }

    impl<'a> Default for MockJob<'a> {
        /// A job for [`MockJob::EXECUTOR`] scheduled at `Utc::now()`, with 5 max
        /// attempts, priority 0, no tags and no uniqueness criteria.
        fn default() -> Self {
            Self(EnqueuableJob {
                // Use the shared constant so the executor jobs are enqueued under
                // can never drift from the one the tests load jobs for.
                executor: Self::EXECUTOR.to_owned(),
                data: Value::String("data".to_owned()),
                metadata: Value::String("metadata".to_owned()),
                max_attempts: 5,
                scheduled_at: Utc::now(),
                tags: Default::default(),
                priority: 0,
                uniqueness_criteria: None,
            })
        }
    }

    impl<'a> MockJob<'a> {
        /// Executor name every mock job is enqueued under.
        const EXECUTOR: &'static str = "executor";

        /// Enqueues this job on `backend` and returns its id; panics on failure.
        async fn enqueue(self, backend: impl Backend) -> JobId {
            backend.enqueue(self.0).await.unwrap()
        }

        /// Returns this job with its `scheduled_at` replaced.
        fn with_scheduled_at(self, scheduled_at: DateTime<Utc>) -> Self {
            Self(EnqueuableJob {
                scheduled_at,
                ..self.0
            })
        }
    }

    impl RexecutorPgBackend {
        /// Fetches every row of `rexecutor_jobs`, unfiltered and unordered.
        async fn all_jobs(&self) -> sqlx::Result<Vec<Job>> {
            sqlx::query_as!(
                Job,
                r#"SELECT
                    id,
                    status AS "status: JobStatus",
                    executor,
                    data,
                    metadata,
                    attempt,
                    max_attempts,
                    priority,
                    tags,
                    errors,
                    inserted_at,
                    scheduled_at,
                    attempted_at,
                    completed_at,
                    cancelled_at,
                    discarded_at
                FROM rexecutor_jobs
                "#
            )
            .fetch_all(&self.pool)
            .await
        }
    }

    // Presumably the generic `Backend` test suite shipped by `rexecutor`,
    // instantiated here against a backend built with `from_pool`.
    rexecutor::backend::testing::test_suite!(
        attr: sqlx::test,
        args: (pool: PgPool),
        backend: RexecutorPgBackend::from_pool(pool).await.unwrap()
    );

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_none_when_db_empty(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_ready_for_execution(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_does_not_return_executing_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        MockJob::default().enqueue(&backend).await;

        // The first load moves the only job to `Executing`...
        let _ = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        // ...so a second load must find nothing.
        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap();

        assert!(job.is_none());
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_retryable_jobs(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();

        let job_id = MockJob::default().enqueue(&backend).await;
        // Retry is scheduled for "now", so the job should be loadable immediately.
        backend
            .mark_job_retryable(
                job_id,
                Utc::now(),
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_job_when_job_scheduled_in_past(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        let job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn load_job_mark_as_executing_for_executor_returns_oldest_scheduled_at_executable_job(
        pool: PgPool,
    ) {
        let backend: RexecutorPgBackend = pool.into();
        // Scheduled three hours ago: the oldest executable job, expected first.
        let expected_job_id = MockJob::default()
            .with_scheduled_at(Utc::now() - TimeDelta::hours(3))
            .enqueue(&backend)
            .await;

        // A second executable job, scheduled now.
        let _ = MockJob::default().enqueue(&backend).await;

        // Three further jobs moved to complete / discarded / cancelled, which
        // must not be picked.
        let job_id = MockJob::default().enqueue(&backend).await;
        backend.mark_job_complete(job_id).await.unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_discarded(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Panic,
                    message: "Oh dear".to_owned(),
                },
            )
            .await
            .unwrap();

        let job_id = MockJob::default().enqueue(&backend).await;
        backend
            .mark_job_cancelled(
                job_id,
                ExecutionError {
                    error_type: ErrorType::Cancelled,
                    message: "Not needed".to_owned(),
                },
            )
            .await
            .unwrap();

        let job = backend
            .load_job_mark_as_executing_for_executor(MockJob::EXECUTOR)
            .await
            .unwrap()
            .expect("Should return a job");

        assert_eq!(job.id, expected_job_id);
        assert_eq!(job.status, JobStatus::Executing);
    }

    #[sqlx::test]
    async fn enqueue_test(pool: PgPool) {
        let backend: RexecutorPgBackend = pool.into();
        let job = MockJob::default();

        let result = backend.enqueue(job.into()).await;

        assert!(result.is_ok());

        let all_jobs = backend.all_jobs().await.unwrap();

        assert_eq!(all_jobs.len(), 1);
    }
}
380}