1#![deny(missing_docs)]
3
4use std::{collections::HashMap, sync::Arc};
5
6use chrono::{DateTime, Utc};
7use rexecutor::{
8 backend::{BackendError, EnqueuableJob, ExecutionError, Query},
9 job::{JobId, uniqueness_criteria::Resolution},
10 pruner::PruneSpec,
11};
12use serde::Deserialize;
13use sqlx::{
14 PgPool, QueryBuilder, Row,
15 postgres::{PgListener, PgPoolOptions},
16 types::Text,
17};
18use tokio::sync::{RwLock, mpsc};
19
20use crate::{query::ToQuery, unique::Unique};
21
22mod backend;
23mod query;
24mod stream;
25mod types;
26mod unique;
27
28type Subscriber = mpsc::UnboundedSender<DateTime<Utc>>;
29
30#[derive(Clone, Debug)]
32pub struct RexecutorPgBackend {
33 pool: PgPool,
34 subscribers: Arc<RwLock<HashMap<&'static str, Vec<Subscriber>>>>,
35}
36
37#[derive(Deserialize, Debug)]
38struct Notification {
39 executor: String,
40 scheduled_at: DateTime<Utc>,
41}
42
43use types::*;
44
45fn map_err(error: sqlx::Error) -> BackendError {
46 match error {
47 sqlx::Error::Io(err) => BackendError::Io(err),
48 sqlx::Error::Tls(err) => BackendError::Io(std::io::Error::other(err)),
49 sqlx::Error::Protocol(err) => BackendError::Io(std::io::Error::other(err)),
50 sqlx::Error::AnyDriverError(err) => BackendError::Io(std::io::Error::other(err)),
51 sqlx::Error::PoolTimedOut => BackendError::Io(std::io::Error::other(error)),
52 sqlx::Error::PoolClosed => BackendError::Io(std::io::Error::other(error)),
53 _ => BackendError::BadState,
54 }
55}
56
57impl RexecutorPgBackend {
58 pub async fn from_db_url(db_url: &str) -> Result<Self, BackendError> {
60 let pool = PgPoolOptions::new()
61 .connect(db_url)
62 .await
63 .map_err(map_err)?;
64 Self::from_pool(pool).await
65 }
66 pub async fn from_pool(pool: PgPool) -> Result<Self, BackendError> {
68 let this = Self {
69 pool,
70 subscribers: Default::default(),
71 };
72 let mut listener = PgListener::connect_with(&this.pool)
73 .await
74 .map_err(map_err)?;
75 listener
76 .listen("public.rexecutor_scheduled")
77 .await
78 .map_err(map_err)?;
79
80 tokio::spawn({
81 let subscribers = this.subscribers.clone();
82 async move {
83 while let Ok(notification) = listener.recv().await {
84 let notification =
85 serde_json::from_str::<Notification>(notification.payload()).unwrap();
86
87 match subscribers
88 .read()
89 .await
90 .get(¬ification.executor.as_str())
91 {
92 Some(subscribers) => subscribers.iter().for_each(|sender| {
93 let _ = sender.send(notification.scheduled_at);
94 }),
95 None => {
96 tracing::warn!("No executors running for {}", notification.executor)
97 }
98 }
99 }
100 }
101 });
102
103 Ok(this)
104 }
105
106 pub async fn run_migrations(&self) -> Result<(), BackendError> {
108 tracing::info!("Running RexecutorPgBackend migrations");
109 sqlx::migrate!()
110 .run(&self.pool)
111 .await
112 .map_err(|err| BackendError::Io(std::io::Error::other(err)))
113 }
114
115 async fn load_job_mark_as_executing_for_executor(
116 &self,
117 executor: &str,
118 ) -> sqlx::Result<Option<Job>> {
119 sqlx::query_as!(
120 Job,
121 r#"UPDATE rexecutor_jobs
122 SET
123 status = 'executing',
124 attempted_at = timezone('UTC'::text, now()),
125 attempt = attempt + 1
126 WHERE id IN (
127 SELECT id from rexecutor_jobs
128 WHERE scheduled_at - timezone('UTC'::text, now()) < '00:00:00.1'
129 AND status in ('scheduled', 'retryable')
130 AND executor = $1
131 ORDER BY priority, scheduled_at
132 LIMIT 1
133 FOR UPDATE SKIP LOCKED
134 )
135 RETURNING
136 id,
137 status AS "status: JobStatus",
138 executor,
139 data,
140 metadata,
141 attempt,
142 max_attempts,
143 priority,
144 tags,
145 errors,
146 inserted_at,
147 scheduled_at,
148 attempted_at,
149 completed_at,
150 cancelled_at,
151 discarded_at
152 "#,
153 executor
154 )
155 .fetch_optional(&self.pool)
156 .await
157 }
158
159 async fn insert_job<'a>(&self, job: EnqueuableJob<'a>) -> sqlx::Result<JobId> {
160 let data = sqlx::query!(
161 r#"INSERT INTO rexecutor_jobs (
162 executor,
163 data,
164 metadata,
165 max_attempts,
166 scheduled_at,
167 priority,
168 tags
169 ) VALUES ($1, $2, $3, $4, $5, $6, $7)
170 RETURNING id
171 "#,
172 job.executor,
173 job.data,
174 job.metadata,
175 job.max_attempts as i32,
176 job.scheduled_at,
177 job.priority as i32,
178 &job.tags,
179 )
180 .fetch_one(&self.pool)
181 .await?;
182 Ok(data.id.into())
183 }
184
185 async fn insert_unique_job<'a>(&self, job: EnqueuableJob<'a>) -> sqlx::Result<JobId> {
186 let Some(uniqueness_criteria) = job.uniqueness_criteria else {
187 panic!();
188 };
189 let mut tx = self.pool.begin().await?;
190 let unique_identifier = uniqueness_criteria.unique_identifier(job.executor.as_str());
191 sqlx::query!("SELECT pg_advisory_xact_lock($1)", unique_identifier)
192 .execute(&mut *tx)
193 .await?;
194 match uniqueness_criteria
195 .query(job.executor.as_str(), job.scheduled_at)
196 .build()
197 .fetch_optional(&mut *tx)
198 .await?
199 {
200 None => {
201 let data = sqlx::query!(
202 r#"INSERT INTO rexecutor_jobs (
203 executor,
204 data,
205 metadata,
206 max_attempts,
207 scheduled_at,
208 priority,
209 tags,
210 uniqueness_key
211 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
212 RETURNING id
213 "#,
214 job.executor,
215 job.data,
216 job.metadata,
217 job.max_attempts as i32,
218 job.scheduled_at,
219 job.priority as i32,
220 &job.tags,
221 unique_identifier,
222 )
223 .fetch_one(&mut *tx)
224 .await?;
225 tx.commit().await?;
226 Ok(data.id.into())
227 }
228 Some(val) => {
229 let job_id = val.get::<i32, _>(0);
230 let status = val.get::<JobStatus, _>(1);
231 match uniqueness_criteria.on_conflict {
232 Resolution::Replace(replace)
233 if replace
234 .for_statuses
235 .iter()
236 .map(|js| JobStatus::from(*js))
237 .any(|js| js == status) =>
238 {
239 let mut builder = QueryBuilder::new("UPDATE rexecutor_jobs SET ");
240 let mut seperated = builder.separated(", ");
241 if replace.scheduled_at {
242 seperated.push("scheduled_at = ");
243 seperated.push_bind_unseparated(job.scheduled_at);
244 }
245 if replace.data {
246 seperated.push("data = ");
247 seperated.push_bind_unseparated(job.data);
248 }
249 if replace.metadata {
250 seperated.push("metadata = ");
251 seperated.push_bind_unseparated(job.metadata);
252 }
253 if replace.priority {
254 seperated.push("priority = ");
255 seperated.push_bind_unseparated(job.priority as i32);
256 }
257 if replace.max_attempts {
258 seperated.push("max_attempts = ");
259 seperated.push_bind_unseparated(job.max_attempts as i32);
260 }
261 builder.push(" WHERE id = ");
262 builder.push_bind(job_id);
263 builder.build().execute(&mut *tx).await?;
264 tx.commit().await?;
265 }
266 _ => {
267 tx.rollback().await?;
268 }
269 }
270 Ok(job_id.into())
271 }
272 }
273 }
274
275 async fn _mark_job_complete(&self, id: JobId) -> sqlx::Result<u64> {
276 Ok(sqlx::query!(
277 r#"UPDATE rexecutor_jobs
278 SET
279 status = 'complete',
280 completed_at = timezone('UTC'::text, now())
281 WHERE id = $1"#,
282 i32::from(id),
283 )
284 .execute(&self.pool)
285 .await?
286 .rows_affected())
287 }
288
289 async fn _mark_job_retryable(
290 &self,
291 id: JobId,
292 next_scheduled_at: DateTime<Utc>,
293 error: ExecutionError,
294 ) -> sqlx::Result<u64> {
295 Ok(sqlx::query!(
296 r#"UPDATE rexecutor_jobs
297 SET
298 status = 'retryable',
299 scheduled_at = $4,
300 errors = ARRAY_APPEND(
301 errors,
302 jsonb_build_object(
303 'attempt', attempt,
304 'error_type', $2::text,
305 'details', $3::text,
306 'recorded_at', timezone('UTC'::text, now())::timestamptz
307 )
308 )
309 WHERE id = $1"#,
310 i32::from(id),
311 Text(ErrorType::from(error.error_type)) as _,
312 error.message,
313 next_scheduled_at,
314 )
315 .execute(&self.pool)
316 .await?
317 .rows_affected())
318 }
319
320 async fn _mark_job_snoozed(
321 &self,
322 id: JobId,
323 next_scheduled_at: DateTime<Utc>,
324 ) -> sqlx::Result<u64> {
325 Ok(sqlx::query!(
326 r#"UPDATE rexecutor_jobs
327 SET
328 status = (CASE WHEN attempt = 1 THEN 'scheduled' ELSE 'retryable' END)::rexecutor_job_state,
329 scheduled_at = $2,
330 attempt = attempt - 1
331 WHERE id = $1"#,
332 i32::from(id),
333 next_scheduled_at,
334 )
335 .execute(&self.pool)
336 .await?
337 .rows_affected())
338 }
339
340 async fn _mark_job_discarded(&self, id: JobId, error: ExecutionError) -> sqlx::Result<u64> {
341 Ok(sqlx::query!(
342 r#"UPDATE rexecutor_jobs
343 SET
344 status = 'discarded',
345 discarded_at = timezone('UTC'::text, now()),
346 errors = ARRAY_APPEND(
347 errors,
348 jsonb_build_object(
349 'attempt', attempt,
350 'error_type', $2::text,
351 'details', $3::text,
352 'recorded_at', timezone('UTC'::text, now())::timestamptz
353 )
354 )
355 WHERE id = $1"#,
356 i32::from(id),
357 Text(ErrorType::from(error.error_type)) as _,
358 error.message,
359 )
360 .execute(&self.pool)
361 .await?
362 .rows_affected())
363 }
364
365 async fn _mark_job_cancelled(&self, id: JobId, error: ExecutionError) -> sqlx::Result<u64> {
366 Ok(sqlx::query!(
367 r#"UPDATE rexecutor_jobs
368 SET
369 status = 'cancelled',
370 cancelled_at = timezone('UTC'::text, now()),
371 errors = ARRAY_APPEND(
372 errors,
373 jsonb_build_object(
374 'attempt', attempt,
375 'error_type', $2::text,
376 'details', $3::text,
377 'recorded_at', timezone('UTC'::text, now())::timestamptz
378 )
379 )
380 WHERE id = $1"#,
381 i32::from(id),
382 Text(ErrorType::from(error.error_type)) as _,
383 error.message,
384 )
385 .execute(&self.pool)
386 .await?
387 .rows_affected())
388 }
389
390 async fn next_available_job_scheduled_at_for_executor(
391 &self,
392 executor: &'static str,
393 ) -> sqlx::Result<Option<DateTime<Utc>>> {
394 Ok(sqlx::query!(
395 r#"SELECT scheduled_at
396 FROM rexecutor_jobs
397 WHERE status in ('scheduled', 'retryable')
398 AND executor = $1
399 ORDER BY scheduled_at
400 LIMIT 1
401 "#,
402 executor
403 )
404 .fetch_optional(&self.pool)
405 .await?
406 .map(|data| data.scheduled_at))
407 }
408
409 async fn delete_from_spec(&self, spec: &PruneSpec) -> sqlx::Result<()> {
410 let result = spec.query().build().execute(&self.pool).await?;
411 tracing::debug!(
412 ?spec,
413 "Clean up query completed {} rows removed",
414 result.rows_affected()
415 );
416 Ok(())
417 }
418
419 async fn rerun(&self, id: JobId) -> sqlx::Result<u64> {
420 Ok(sqlx::query!(
422 r#"UPDATE rexecutor_jobs
423 SET
424 status = (CASE WHEN attempt = 1 THEN 'scheduled' ELSE 'retryable' END)::rexecutor_job_state,
425 scheduled_at = $2,
426 completed_at = null,
427 cancelled_at = null,
428 discarded_at = null,
429 max_attempts = max_attempts + 1
430 WHERE id = $1"#,
431 i32::from(id),
432 Utc::now(),
433 )
434 .execute(&self.pool)
435 .await?
436 .rows_affected())
437 }
438
439 async fn update(&self, job: rexecutor::backend::Job) -> sqlx::Result<u64> {
440 Ok(sqlx::query!(
441 r#"UPDATE rexecutor_jobs
442 SET
443 data = $2,
444 metadata = $3,
445 max_attempts = $4,
446 scheduled_at = $5,
447 priority = $6,
448 tags = $7
449 WHERE id = $1"#,
450 job.id,
451 job.data,
452 job.metadata,
453 job.max_attempts as i32,
454 job.scheduled_at,
455 job.priority as i32,
456 &job.tags,
457 )
458 .execute(&self.pool)
459 .await?
460 .rows_affected())
461 }
462
463 async fn run_query<'a>(&self, query: Query<'a>) -> sqlx::Result<Vec<Job>> {
464 query
465 .query()
466 .build_query_as::<Job>()
467 .fetch_all(&self.pool)
468 .await
469 }
470}