rexecutor_sqlx/
lib.rs

1//! A postgres backend for Rexecutor built on [`sqlx`].
2#![deny(missing_docs)]
3
4use std::{collections::HashMap, sync::Arc};
5
6use chrono::{DateTime, Utc};
7use rexecutor::{
8    backend::{BackendError, EnqueuableJob, ExecutionError, Query},
9    job::{JobId, uniqueness_criteria::Resolution},
10    pruner::PruneSpec,
11};
12use serde::Deserialize;
13use sqlx::{
14    PgPool, QueryBuilder, Row,
15    postgres::{PgListener, PgPoolOptions},
16    types::Text,
17};
18use tokio::sync::{RwLock, mpsc};
19
20use crate::{query::ToQuery, unique::Unique};
21
22mod backend;
23mod query;
24mod stream;
25mod types;
26mod unique;
27
28type Subscriber = mpsc::UnboundedSender<DateTime<Utc>>;
29
30/// A postgres implementation of a [`rexecutor::backend::Backend`].
31#[derive(Clone, Debug)]
32pub struct RexecutorPgBackend {
33    pool: PgPool,
34    subscribers: Arc<RwLock<HashMap<&'static str, Vec<Subscriber>>>>,
35}
36
37#[derive(Deserialize, Debug)]
38struct Notification {
39    executor: String,
40    scheduled_at: DateTime<Utc>,
41}
42
43use types::*;
44
45fn map_err(error: sqlx::Error) -> BackendError {
46    match error {
47        sqlx::Error::Io(err) => BackendError::Io(err),
48        sqlx::Error::Tls(err) => BackendError::Io(std::io::Error::other(err)),
49        sqlx::Error::Protocol(err) => BackendError::Io(std::io::Error::other(err)),
50        sqlx::Error::AnyDriverError(err) => BackendError::Io(std::io::Error::other(err)),
51        sqlx::Error::PoolTimedOut => BackendError::Io(std::io::Error::other(error)),
52        sqlx::Error::PoolClosed => BackendError::Io(std::io::Error::other(error)),
53        _ => BackendError::BadState,
54    }
55}
56
57impl RexecutorPgBackend {
58    /// Creates a new [`RexecutorPgBackend`] from a db connection string.
59    pub async fn from_db_url(db_url: &str) -> Result<Self, BackendError> {
60        let pool = PgPoolOptions::new()
61            .connect(db_url)
62            .await
63            .map_err(map_err)?;
64        Self::from_pool(pool).await
65    }
66    /// Create a new [`RexecutorPgBackend`] from an existing [`PgPool`].
67    pub async fn from_pool(pool: PgPool) -> Result<Self, BackendError> {
68        let this = Self {
69            pool,
70            subscribers: Default::default(),
71        };
72        let mut listener = PgListener::connect_with(&this.pool)
73            .await
74            .map_err(map_err)?;
75        listener
76            .listen("public.rexecutor_scheduled")
77            .await
78            .map_err(map_err)?;
79
80        tokio::spawn({
81            let subscribers = this.subscribers.clone();
82            async move {
83                while let Ok(notification) = listener.recv().await {
84                    let notification =
85                        serde_json::from_str::<Notification>(notification.payload()).unwrap();
86
87                    match subscribers
88                        .read()
89                        .await
90                        .get(&notification.executor.as_str())
91                    {
92                        Some(subscribers) => subscribers.iter().for_each(|sender| {
93                            let _ = sender.send(notification.scheduled_at);
94                        }),
95                        None => {
96                            tracing::warn!("No executors running for {}", notification.executor)
97                        }
98                    }
99                }
100            }
101        });
102
103        Ok(this)
104    }
105
106    /// This can be used to run the [`RexecutorPgBackend`]'s migrations.
107    pub async fn run_migrations(&self) -> Result<(), BackendError> {
108        tracing::info!("Running RexecutorPgBackend migrations");
109        sqlx::migrate!()
110            .run(&self.pool)
111            .await
112            .map_err(|err| BackendError::Io(std::io::Error::other(err)))
113    }
114
115    async fn load_job_mark_as_executing_for_executor(
116        &self,
117        executor: &str,
118    ) -> sqlx::Result<Option<Job>> {
119        sqlx::query_as!(
120            Job,
121            r#"UPDATE rexecutor_jobs
122            SET
123                status = 'executing',
124                attempted_at = timezone('UTC'::text, now()),
125                attempt = attempt + 1
126            WHERE id IN (
127                SELECT id from rexecutor_jobs
128                WHERE scheduled_at - timezone('UTC'::text, now()) < '00:00:00.1'
129                AND status in ('scheduled', 'retryable')
130                AND executor = $1
131                ORDER BY priority, scheduled_at
132                LIMIT 1
133                FOR UPDATE SKIP LOCKED
134            )
135            RETURNING
136                id,
137                status AS "status: JobStatus",
138                executor,
139                data,
140                metadata,
141                attempt,
142                max_attempts,
143                priority,
144                tags,
145                errors,
146                inserted_at,
147                scheduled_at,
148                attempted_at,
149                completed_at,
150                cancelled_at,
151                discarded_at
152            "#,
153            executor
154        )
155        .fetch_optional(&self.pool)
156        .await
157    }
158
159    async fn insert_job<'a>(&self, job: EnqueuableJob<'a>) -> sqlx::Result<JobId> {
160        let data = sqlx::query!(
161            r#"INSERT INTO rexecutor_jobs (
162                executor,
163                data,
164                metadata,
165                max_attempts,
166                scheduled_at,
167                priority,
168                tags
169            ) VALUES ($1, $2, $3, $4, $5, $6, $7)
170            RETURNING id
171            "#,
172            job.executor,
173            job.data,
174            job.metadata,
175            job.max_attempts as i32,
176            job.scheduled_at,
177            job.priority as i32,
178            &job.tags,
179        )
180        .fetch_one(&self.pool)
181        .await?;
182        Ok(data.id.into())
183    }
184
185    async fn insert_unique_job<'a>(&self, job: EnqueuableJob<'a>) -> sqlx::Result<JobId> {
186        let Some(uniqueness_criteria) = job.uniqueness_criteria else {
187            panic!();
188        };
189        let mut tx = self.pool.begin().await?;
190        let unique_identifier = uniqueness_criteria.unique_identifier(job.executor.as_str());
191        sqlx::query!("SELECT pg_advisory_xact_lock($1)", unique_identifier)
192            .execute(&mut *tx)
193            .await?;
194        match uniqueness_criteria
195            .query(job.executor.as_str(), job.scheduled_at)
196            .build()
197            .fetch_optional(&mut *tx)
198            .await?
199        {
200            None => {
201                let data = sqlx::query!(
202                    r#"INSERT INTO rexecutor_jobs (
203                        executor,
204                        data,
205                        metadata,
206                        max_attempts,
207                        scheduled_at,
208                        priority,
209                        tags,
210                        uniqueness_key
211                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
212                    RETURNING id
213                    "#,
214                    job.executor,
215                    job.data,
216                    job.metadata,
217                    job.max_attempts as i32,
218                    job.scheduled_at,
219                    job.priority as i32,
220                    &job.tags,
221                    unique_identifier,
222                )
223                .fetch_one(&mut *tx)
224                .await?;
225                tx.commit().await?;
226                Ok(data.id.into())
227            }
228            Some(val) => {
229                let job_id = val.get::<i32, _>(0);
230                let status = val.get::<JobStatus, _>(1);
231                match uniqueness_criteria.on_conflict {
232                    Resolution::Replace(replace)
233                        if replace
234                            .for_statuses
235                            .iter()
236                            .map(|js| JobStatus::from(*js))
237                            .any(|js| js == status) =>
238                    {
239                        let mut builder = QueryBuilder::new("UPDATE rexecutor_jobs SET ");
240                        let mut seperated = builder.separated(", ");
241                        if replace.scheduled_at {
242                            seperated.push("scheduled_at = ");
243                            seperated.push_bind_unseparated(job.scheduled_at);
244                        }
245                        if replace.data {
246                            seperated.push("data = ");
247                            seperated.push_bind_unseparated(job.data);
248                        }
249                        if replace.metadata {
250                            seperated.push("metadata = ");
251                            seperated.push_bind_unseparated(job.metadata);
252                        }
253                        if replace.priority {
254                            seperated.push("priority = ");
255                            seperated.push_bind_unseparated(job.priority as i32);
256                        }
257                        if replace.max_attempts {
258                            seperated.push("max_attempts = ");
259                            seperated.push_bind_unseparated(job.max_attempts as i32);
260                        }
261                        builder.push(" WHERE id  = ");
262                        builder.push_bind(job_id);
263                        builder.build().execute(&mut *tx).await?;
264                        tx.commit().await?;
265                    }
266                    _ => {
267                        tx.rollback().await?;
268                    }
269                }
270                Ok(job_id.into())
271            }
272        }
273    }
274
275    async fn _mark_job_complete(&self, id: JobId) -> sqlx::Result<u64> {
276        Ok(sqlx::query!(
277            r#"UPDATE rexecutor_jobs
278            SET
279                status = 'complete',
280                completed_at = timezone('UTC'::text, now())
281            WHERE id = $1"#,
282            i32::from(id),
283        )
284        .execute(&self.pool)
285        .await?
286        .rows_affected())
287    }
288
289    async fn _mark_job_retryable(
290        &self,
291        id: JobId,
292        next_scheduled_at: DateTime<Utc>,
293        error: ExecutionError,
294    ) -> sqlx::Result<u64> {
295        Ok(sqlx::query!(
296            r#"UPDATE rexecutor_jobs
297            SET
298                status = 'retryable',
299                scheduled_at = $4,
300                errors = ARRAY_APPEND(
301                    errors,
302                    jsonb_build_object(
303                        'attempt', attempt,
304                        'error_type', $2::text,
305                        'details', $3::text,
306                        'recorded_at', timezone('UTC'::text, now())::timestamptz
307                    )
308                )
309            WHERE id = $1"#,
310            i32::from(id),
311            Text(ErrorType::from(error.error_type)) as _,
312            error.message,
313            next_scheduled_at,
314        )
315        .execute(&self.pool)
316        .await?
317        .rows_affected())
318    }
319
320    async fn _mark_job_snoozed(
321        &self,
322        id: JobId,
323        next_scheduled_at: DateTime<Utc>,
324    ) -> sqlx::Result<u64> {
325        Ok(sqlx::query!(
326            r#"UPDATE rexecutor_jobs
327            SET
328                status = (CASE WHEN attempt = 1 THEN 'scheduled' ELSE 'retryable' END)::rexecutor_job_state,
329                scheduled_at = $2,
330                attempt = attempt - 1
331            WHERE id = $1"#,
332            i32::from(id),
333            next_scheduled_at,
334        )
335        .execute(&self.pool)
336        .await?
337        .rows_affected())
338    }
339
340    async fn _mark_job_discarded(&self, id: JobId, error: ExecutionError) -> sqlx::Result<u64> {
341        Ok(sqlx::query!(
342            r#"UPDATE rexecutor_jobs
343            SET
344                status = 'discarded',
345                discarded_at = timezone('UTC'::text, now()),
346                errors = ARRAY_APPEND(
347                    errors,
348                    jsonb_build_object(
349                        'attempt', attempt,
350                        'error_type', $2::text,
351                        'details', $3::text,
352                        'recorded_at', timezone('UTC'::text, now())::timestamptz
353                    )
354                )
355            WHERE id = $1"#,
356            i32::from(id),
357            Text(ErrorType::from(error.error_type)) as _,
358            error.message,
359        )
360        .execute(&self.pool)
361        .await?
362        .rows_affected())
363    }
364
365    async fn _mark_job_cancelled(&self, id: JobId, error: ExecutionError) -> sqlx::Result<u64> {
366        Ok(sqlx::query!(
367            r#"UPDATE rexecutor_jobs
368            SET
369                status = 'cancelled',
370                cancelled_at = timezone('UTC'::text, now()),
371                errors = ARRAY_APPEND(
372                    errors,
373                    jsonb_build_object(
374                        'attempt', attempt,
375                        'error_type', $2::text,
376                        'details', $3::text,
377                        'recorded_at', timezone('UTC'::text, now())::timestamptz
378                    )
379                )
380            WHERE id = $1"#,
381            i32::from(id),
382            Text(ErrorType::from(error.error_type)) as _,
383            error.message,
384        )
385        .execute(&self.pool)
386        .await?
387        .rows_affected())
388    }
389
390    async fn next_available_job_scheduled_at_for_executor(
391        &self,
392        executor: &'static str,
393    ) -> sqlx::Result<Option<DateTime<Utc>>> {
394        Ok(sqlx::query!(
395            r#"SELECT scheduled_at
396            FROM rexecutor_jobs
397            WHERE status in ('scheduled', 'retryable')
398            AND executor = $1
399            ORDER BY scheduled_at
400            LIMIT 1
401            "#,
402            executor
403        )
404        .fetch_optional(&self.pool)
405        .await?
406        .map(|data| data.scheduled_at))
407    }
408
409    async fn delete_from_spec(&self, spec: &PruneSpec) -> sqlx::Result<()> {
410        let result = spec.query().build().execute(&self.pool).await?;
411        tracing::debug!(
412            ?spec,
413            "Clean up query completed {} rows removed",
414            result.rows_affected()
415        );
416        Ok(())
417    }
418
419    async fn rerun(&self, id: JobId) -> sqlx::Result<u64> {
420        // Currently this increments the max attempts and reschedules the job.
421        Ok(sqlx::query!(
422            r#"UPDATE rexecutor_jobs
423            SET
424                status = (CASE WHEN attempt = 1 THEN 'scheduled' ELSE 'retryable' END)::rexecutor_job_state,
425                scheduled_at = $2,
426                completed_at = null,
427                cancelled_at = null,
428                discarded_at = null,
429                max_attempts = max_attempts + 1
430            WHERE id = $1"#,
431            i32::from(id),
432            Utc::now(),
433        )
434        .execute(&self.pool)
435        .await?
436        .rows_affected())
437    }
438
439    async fn update(&self, job: rexecutor::backend::Job) -> sqlx::Result<u64> {
440        Ok(sqlx::query!(
441            r#"UPDATE rexecutor_jobs
442            SET
443                data = $2,
444                metadata = $3,
445                max_attempts = $4,
446                scheduled_at = $5,
447                priority = $6,
448                tags = $7
449            WHERE id = $1"#,
450            job.id,
451            job.data,
452            job.metadata,
453            job.max_attempts as i32,
454            job.scheduled_at,
455            job.priority as i32,
456            &job.tags,
457        )
458        .execute(&self.pool)
459        .await?
460        .rows_affected())
461    }
462
463    async fn run_query<'a>(&self, query: Query<'a>) -> sqlx::Result<Vec<Job>> {
464        query
465            .query()
466            .build_query_as::<Job>()
467            .fetch_all(&self.pool)
468            .await
469    }
470}