#![deny(missing_docs)]
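//! A PostgreSQL backend for [`rexecutor`], implemented on top of [`sqlx`].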

use std::{collections::HashMap, sync::Arc};

use chrono::{DateTime, Utc};
use rexecutor::{
    backend::{BackendError, EnqueuableJob, ExecutionError, Query},
    job::{uniqueness_criteria::Resolution, JobId},
    pruner::PruneSpec,
};
use serde::Deserialize;
use sqlx::{
    postgres::{PgListener, PgPoolOptions},
    types::Text,
    PgPool, QueryBuilder, Row,
};
use tokio::sync::{mpsc, RwLock};

use crate::{query::ToQuery, unique::Unique};

mod backend;
mod query;
mod stream;
mod types;
mod unique;

type Subscriber = mpsc::UnboundedSender<DateTime<Utc>>;

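/// A PostgreSQL implementation of the [`rexecutor`] backend, storing jobs in the
/// `rexecutor_jobs` table and using `LISTEN`/`NOTIFY` to wake subscribers when new
/// jobs are scheduled.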
#[derive(Clone, Debug)]
pub struct RexecutorPgBackend {
    pool: PgPool,
    subscribers: Arc<RwLock<HashMap<&'static str, Vec<Subscriber>>>>,
}

#[derive(Deserialize, Debug)]
struct Notification {
    executor: String,
    scheduled_at: DateTime<Utc>,
}

use types::*;

impl RexecutorPgBackend {
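    /// Creates a [`RexecutorPgBackend`] by connecting to the PostgreSQL database at `db_url`.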
    pub async fn from_db_url(db_url: &str) -> Result<Self, BackendError> {
        let pool = PgPoolOptions::new()
            .connect(db_url)
            .await
            .map_err(|_| BackendError::BadState)?;
        Self::from_pool(pool).await
    }
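
    /// Creates a [`RexecutorPgBackend`] from an existing [`PgPool`], spawning a background task
    /// that listens for scheduled-job notifications and forwards them to interested subscribers.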
    pub async fn from_pool(pool: PgPool) -> Result<Self, BackendError> {
        let this = Self {
            pool,
            subscribers: Default::default(),
        };
        let mut listener = PgListener::connect_with(&this.pool)
            .await
            .map_err(|_| BackendError::BadState)?;
        listener
            .listen("public.rexecutor_scheduled")
            .await
            .map_err(|_| BackendError::BadState)?;

        tokio::spawn({
            let subscribers = this.subscribers.clone();
            async move {
                while let Ok(notification) = listener.recv().await {
                    let notification =
                        serde_json::from_str::<Notification>(notification.payload()).unwrap();

                    match subscribers
                        .read()
                        .await
                        .get(notification.executor.as_str())
                    {
                        Some(subscribers) => subscribers.iter().for_each(|sender| {
                            let _ = sender.send(notification.scheduled_at);
                        }),
                        None => {
                            tracing::warn!("No executors running for {}", notification.executor)
                        }
                    }
                }
            }
        });

        Ok(this)
    }

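    /// Runs the database migrations bundled with this backend.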
    pub async fn run_migrations(&self) -> Result<(), BackendError> {
        tracing::info!("Running RexecutorPgBackend migrations");
        sqlx::migrate!()
            .run(&self.pool)
            .await
            .map_err(|_| BackendError::BadState)
    }

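    /// Atomically claims the next due job for `executor`: the row is selected with
    /// `FOR UPDATE SKIP LOCKED`, marked as `executing`, and its attempt counter incremented.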
    async fn load_job_mark_as_executing_for_executor(
        &self,
        executor: &str,
    ) -> sqlx::Result<Option<Job>> {
        sqlx::query_as!(
            Job,
            r#"UPDATE rexecutor_jobs
            SET
                status = 'executing',
                attempted_at = timezone('UTC'::text, now()),
                attempt = attempt + 1
            WHERE id IN (
                SELECT id FROM rexecutor_jobs
                WHERE scheduled_at - timezone('UTC'::text, now()) < '00:00:00.1'
                AND status IN ('scheduled', 'retryable')
                AND executor = $1
                ORDER BY priority, scheduled_at
                LIMIT 1
                FOR UPDATE SKIP LOCKED
            )
            RETURNING
                id,
                status AS "status: JobStatus",
                executor,
                data,
                metadata,
                attempt,
                max_attempts,
                priority,
                tags,
                errors,
                inserted_at,
                scheduled_at,
                attempted_at,
                completed_at,
                cancelled_at,
                discarded_at
            "#,
            executor
        )
        .fetch_optional(&self.pool)
        .await
    }

    async fn insert_job<'a>(&self, job: EnqueuableJob<'a>) -> sqlx::Result<JobId> {
        let data = sqlx::query!(
            r#"INSERT INTO rexecutor_jobs (
                executor,
                data,
                metadata,
                max_attempts,
                scheduled_at,
                priority,
                tags
            ) VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id
            "#,
            job.executor,
            job.data,
            job.metadata,
            job.max_attempts as i32,
            job.scheduled_at,
            job.priority as i32,
            &job.tags,
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(data.id.into())
    }

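    /// Inserts a job subject to its uniqueness criteria. A transaction-scoped advisory lock on
    /// the uniqueness key serialises concurrent inserts; if a conflicting job already exists,
    /// the conflict is resolved according to `uniqueness_criteria.on_conflict`.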
    async fn insert_unique_job<'a>(&self, job: EnqueuableJob<'a>) -> sqlx::Result<JobId> {
        let Some(uniqueness_criteria) = job.uniqueness_criteria else {
            panic!("insert_unique_job requires uniqueness criteria to be set");
        };
        let mut tx = self.pool.begin().await?;
        let unique_identifier = uniqueness_criteria.unique_identifier(job.executor.as_str());
        sqlx::query!("SELECT pg_advisory_xact_lock($1)", unique_identifier)
            .execute(&mut *tx)
            .await?;
        match uniqueness_criteria
            .query(job.executor.as_str(), job.scheduled_at)
            .build()
            .fetch_optional(&mut *tx)
            .await?
        {
            None => {
                let data = sqlx::query!(
                    r#"INSERT INTO rexecutor_jobs (
                        executor,
                        data,
                        metadata,
                        max_attempts,
                        scheduled_at,
                        priority,
                        tags,
                        uniqueness_key
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    RETURNING id
                    "#,
                    job.executor,
                    job.data,
                    job.metadata,
                    job.max_attempts as i32,
                    job.scheduled_at,
                    job.priority as i32,
                    &job.tags,
                    unique_identifier,
                )
                .fetch_one(&mut *tx)
                .await?;
                tx.commit().await?;
                Ok(data.id.into())
            }
            Some(val) => {
                let job_id = val.get::<i32, _>(0);
                let status = val.get::<JobStatus, _>(1);
                match uniqueness_criteria.on_conflict {
                    Resolution::Replace(replace)
                        if replace
                            .for_statuses
                            .iter()
                            .map(|js| JobStatus::from(*js))
                            .any(|js| js == status) =>
                    {
                        let mut builder = QueryBuilder::new("UPDATE rexecutor_jobs SET ");
                        let mut separated = builder.separated(", ");
                        if replace.scheduled_at {
                            separated.push("scheduled_at = ");
                            separated.push_bind_unseparated(job.scheduled_at);
                        }
                        if replace.data {
                            separated.push("data = ");
                            separated.push_bind_unseparated(job.data);
                        }
                        if replace.metadata {
                            separated.push("metadata = ");
                            separated.push_bind_unseparated(job.metadata);
                        }
                        if replace.priority {
                            separated.push("priority = ");
                            separated.push_bind_unseparated(job.priority as i32);
                        }
                        if replace.max_attempts {
                            separated.push("max_attempts = ");
                            separated.push_bind_unseparated(job.max_attempts as i32);
                        }
                        builder.push(" WHERE id = ");
                        builder.push_bind(job_id);
                        builder.build().execute(&mut *tx).await?;
                        tx.commit().await?;
                    }
                    _ => {
                        tx.rollback().await?;
                    }
                }
                Ok(job_id.into())
            }
        }
    }

    async fn _mark_job_complete(&self, id: JobId) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                status = 'complete',
                completed_at = timezone('UTC'::text, now())
            WHERE id = $1"#,
            i32::from(id),
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

    async fn _mark_job_retryable(
        &self,
        id: JobId,
        next_scheduled_at: DateTime<Utc>,
        error: ExecutionError,
    ) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                status = 'retryable',
                scheduled_at = $4,
                errors = ARRAY_APPEND(
                    errors,
                    jsonb_build_object(
                        'attempt', attempt,
                        'error_type', $2::text,
                        'details', $3::text,
                        'recorded_at', timezone('UTC'::text, now())::timestamptz
                    )
                )
            WHERE id = $1"#,
            i32::from(id),
            Text(ErrorType::from(error.error_type)) as _,
            error.message,
            next_scheduled_at,
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

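    /// Snoozing reschedules the job and decrements the attempt counter so the snooze does not
    /// count towards `max_attempts`.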
    async fn _mark_job_snoozed(
        &self,
        id: JobId,
        next_scheduled_at: DateTime<Utc>,
    ) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                status = (CASE WHEN attempt = 1 THEN 'scheduled' ELSE 'retryable' END)::rexecutor_job_state,
                scheduled_at = $2,
                attempt = attempt - 1
            WHERE id = $1"#,
            i32::from(id),
            next_scheduled_at,
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

    async fn _mark_job_discarded(&self, id: JobId, error: ExecutionError) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                status = 'discarded',
                discarded_at = timezone('UTC'::text, now()),
                errors = ARRAY_APPEND(
                    errors,
                    jsonb_build_object(
                        'attempt', attempt,
                        'error_type', $2::text,
                        'details', $3::text,
                        'recorded_at', timezone('UTC'::text, now())::timestamptz
                    )
                )
            WHERE id = $1"#,
            i32::from(id),
            Text(ErrorType::from(error.error_type)) as _,
            error.message,
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

    async fn _mark_job_cancelled(&self, id: JobId, error: ExecutionError) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                status = 'cancelled',
                cancelled_at = timezone('UTC'::text, now()),
                errors = ARRAY_APPEND(
                    errors,
                    jsonb_build_object(
                        'attempt', attempt,
                        'error_type', $2::text,
                        'details', $3::text,
                        'recorded_at', timezone('UTC'::text, now())::timestamptz
                    )
                )
            WHERE id = $1"#,
            i32::from(id),
            Text(ErrorType::from(error.error_type)) as _,
            error.message,
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

    async fn next_available_job_scheduled_at_for_executor(
        &self,
        executor: &'static str,
    ) -> sqlx::Result<Option<DateTime<Utc>>> {
        Ok(sqlx::query!(
            r#"SELECT scheduled_at
            FROM rexecutor_jobs
            WHERE status IN ('scheduled', 'retryable')
            AND executor = $1
            ORDER BY scheduled_at
            LIMIT 1
            "#,
            executor
        )
        .fetch_optional(&self.pool)
        .await?
        .map(|data| data.scheduled_at))
    }

    async fn delete_from_spec(&self, spec: &PruneSpec) -> sqlx::Result<()> {
        let result = spec.query().build().execute(&self.pool).await?;
        tracing::debug!(
            ?spec,
            "Clean up query completed: {} rows removed",
            result.rows_affected()
        );
        Ok(())
    }

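    /// Reruns a finished job: clears its completion, cancellation, and discard timestamps, bumps
    /// `max_attempts` to allow one more attempt, and schedules it to run immediately.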
    async fn rerun(&self, id: JobId) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                status = (CASE WHEN attempt = 1 THEN 'scheduled' ELSE 'retryable' END)::rexecutor_job_state,
                scheduled_at = $2,
                completed_at = null,
                cancelled_at = null,
                discarded_at = null,
                max_attempts = max_attempts + 1
            WHERE id = $1"#,
            i32::from(id),
            Utc::now(),
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

    async fn update(&self, job: rexecutor::backend::Job) -> sqlx::Result<u64> {
        Ok(sqlx::query!(
            r#"UPDATE rexecutor_jobs
            SET
                data = $2,
                metadata = $3,
                max_attempts = $4,
                scheduled_at = $5,
                priority = $6,
                tags = $7
            WHERE id = $1"#,
            job.id,
            job.data,
            job.metadata,
            job.max_attempts as i32,
            job.scheduled_at,
            job.priority as i32,
            &job.tags,
        )
        .execute(&self.pool)
        .await?
        .rows_affected())
    }

    async fn run_query<'a>(&self, query: Query<'a>) -> sqlx::Result<Vec<Job>> {
        query
            .query()
            .build_query_as::<Job>()
            .fetch_all(&self.pool)
            .await
    }
}