apalis_sql/
postgres.rs

//! # apalis-postgres
//!
//! Allows using Postgres as a backend
//!
//! ## Postgres Example
//! ```rust,no_run
//! use apalis::prelude::*;
//! # use apalis_sql::postgres::PostgresStorage;
//! # use apalis_sql::postgres::PgPool;
//!
//! use email_service::Email;
//!
//! #[tokio::main]
//! async fn main() -> std::io::Result<()> {
//!     std::env::set_var("RUST_LOG", "debug,sqlx::query=error");
//!     let database_url = std::env::var("DATABASE_URL").expect("Must specify url to db");
//!     let pool = PgPool::connect(&database_url).await.unwrap();
//!
//!     PostgresStorage::setup(&pool).await.unwrap();
//!     let pg: PostgresStorage<Email> = PostgresStorage::new(pool);
//!
//!     async fn send_email(job: Email, data: Data<usize>) -> Result<(), Error> {
//!         // execute the job
//!         Ok(())
//!     }
//!     // Jobs can also be pushed from another program or language:
//!     // let query = "Select apalis.push_job('apalis::Email', json_build_object('subject', 'Test apalis', 'to', 'test1@example.com', 'text', 'Lorem Ipsum'));";
//!     // pg.execute(query).await.unwrap();
//!
//!     Monitor::new()
//!         .register({
//!             WorkerBuilder::new("tasty-avocado")
//!                 .data(0usize)
//!                 .backend(pg)
//!                 .build_fn(send_email)
//!         })
//!         .run()
//!         .await
//! }
//! ```
use crate::context::SqlContext;
use crate::{calculate_status, Config, SqlError};
use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use chrono::{DateTime, Utc};
use futures::channel::mpsc;
use futures::StreamExt;
use futures::{select, stream, SinkExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::postgres::PgListener;
use sqlx::{Pool, Postgres, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};

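/// Unix timestamp in seconds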
type Timestamp = i64;

pub use sqlx::postgres::PgPool;

use crate::from_row::SqlRequest;

/// Represents a [Storage] that persists to Postgres
pub struct PostgresStorage<T, C = JsonCodec<serde_json::Value>>
where
    C: Codec,
{
    pool: PgPool,
    job_type: PhantomData<T>,
    codec: PhantomData<C>,
    config: Config,
    controller: Controller,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
    subscription: Option<PgSubscription>,
}

impl<T, C: Codec> Clone for PostgresStorage<T, C> {
    fn clone(&self) -> Self {
        PostgresStorage {
            pool: self.pool.clone(),
            job_type: PhantomData,
            codec: PhantomData,
            config: self.config.clone(),
            controller: self.controller.clone(),
            ack_notify: self.ack_notify.clone(),
            subscription: self.subscription.clone(),
        }
    }
}

impl<T, C: Codec> fmt::Debug for PostgresStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PostgresStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            // `ack_notify` is intentionally omitted from the Debug output
            .finish()
    }
}

/// Errors that can occur while polling a PostgreSQL database.
#[derive(thiserror::Error, Debug)]
pub enum PgPollError {
    /// Error during task acknowledgment.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// Error while fetching the next item.
    #[error("Encountered an error during FetchNext: `{0}`")]
    FetchNextError(apalis_core::error::Error),

    /// Error while listening to PostgreSQL notifications.
    #[error("Encountered an error during listening to PgNotification: {0}")]
    PgNotificationError(apalis_core::error::Error),

    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),

    /// Error during result encoding.
    #[error("Encountered an error during encoding the result: {0}")]
    CodecError(BoxDynError),
}

impl<T, C, Res> Backend<Request<T, SqlContext>, Res> for PostgresStorage<T, C>
where
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    C: Codec<Compact = serde_json::Value> + Send + 'static,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;

    type Layer = AckLayer<PostgresStorage<T, C>, T, SqlContext, Res>;

    fn poll<Svc>(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let subscription = self.subscription.clone();
        let config = self.config.clone();
        let controller = self.controller.clone();
        let (mut tx, rx) = mpsc::channel(self.config.buffer_size);
        let ack_notify = self.ack_notify.clone();
        let pool = self.pool.clone();
        let worker = worker.clone();
        let heartbeat = async move {
            let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse();
            let mut reenqueue_orphaned_stm =
                apalis_core::interval::interval(config.poll_interval).fuse();
            let mut ack_stream = ack_notify.clone().ready_chunks(config.buffer_size).fuse();

            let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse();

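            // Without a subscription, fall back to an empty stream so that
            // job fetching is driven purely by the poll interval.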
            let mut pg_notification = subscription
                .map(|stm| stm.notify.boxed().fuse())
                .unwrap_or(stream::iter(vec![]).boxed().fuse());

            async fn fetch_next_batch<
                T: Unpin + DeserializeOwned + Send + 'static,
                C: Codec<Compact = Value>,
            >(
                storage: &mut PostgresStorage<T, C>,
                worker: &WorkerId,
                tx: &mut mpsc::Sender<Result<Option<Request<T, SqlContext>>, Error>>,
            ) -> Result<(), Error> {
                let res = storage
                    .fetch_next(worker)
                    .await
                    .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
                for job in res {
                    tx.send(Ok(Some(job)))
                        .await
                        .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
                }
                Ok(())
            }

            if let Err(e) = self
                .keep_alive_at::<Self::Layer>(worker.id(), Utc::now().timestamp())
                .await
            {
                worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e))));
            }

            loop {
                select! {
                    _ = keep_alive_stm.next() => {
                        if let Err(e) = self.keep_alive_at::<Self::Layer>(worker.id(), Utc::now().timestamp()).await {
                            worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e))));
                        }
                    }
                    ids = ack_stream.next() => {
                        if let Some(ids) = ids {
                            let ack_ids: Vec<(String, String, String, String, u64)> = ids
                                .iter()
                                .map(|(_ctx, res)| {
                                    (
                                        res.task_id.to_string(),
                                        worker.id().to_string(),
                                        serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string()))
                                            .expect("Could not convert response to json"),
                                        calculate_status(&res.inner).to_string(),
                                        (res.attempt.current() + 1) as u64,
                                    )
                                })
                                .collect();
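                            // Batch all pending acks into one UPDATE by sending the
                            // tuples as a JSON array and unpacking them server-side
                            // with json_array_elements.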
                            let query =
                                "UPDATE apalis.jobs
                                    SET status = Q.status,
                                        done_at = now(),
                                        lock_by = Q.worker_id,
                                        last_error = Q.result,
                                        attempts = Q.attempts
                                    FROM (
                                        SELECT (value->>0)::text as id,
                                            (value->>1)::text as worker_id,
                                            (value->>2)::text as result,
                                            (value->>3)::text as status,
                                            (value->>4)::int as attempts
                                        FROM json_array_elements($1::json)
                                    ) Q
                                    WHERE apalis.jobs.id = Q.id;
                                    ";
                            let codec_res = C::encode(&ack_ids);
                            match codec_res {
                                Ok(val) => {
                                    if let Err(e) = sqlx::query(query)
                                        .bind(val)
                                        .execute(&pool)
                                        .await
                                    {
                                        worker.emit(Event::Error(Box::new(PgPollError::AckError(e))));
                                    }
                                }
                                Err(e) => {
                                    worker.emit(Event::Error(Box::new(PgPollError::CodecError(e.into()))));
                                }
                            }
                        }
                    }
                    _ = poll_next_stm.next() => {
                        if worker.is_ready() {
                            if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await {
                                worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e))));
                            }
                        }
                    }
                    _ = pg_notification.next() => {
                        if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await {
                            worker.emit(Event::Error(Box::new(PgPollError::PgNotificationError(e))));
                        }
                    }
                    _ = reenqueue_orphaned_stm.next() => {
                        let dead_since = Utc::now()
                            - chrono::Duration::from_std(config.reenqueue_orphaned_after).expect("could not build dead_since");
                        if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await {
                            worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(e))));
                        }
                    }
                };
            }
        };
        Poller::new_with_layer(BackendStream::new(rx.boxed(), controller), heartbeat, layer)
    }
}

impl PostgresStorage<()> {
    /// Get postgres migrations without running them
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/postgres")
    }

    /// Do migrations for Postgres
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Postgres>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> PostgresStorage<T> {
    /// New Storage from [PgPool]
    pub fn new(pool: PgPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

    /// New Storage from [PgPool] and custom config
    pub fn new_with_config(pool: PgPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            codec: PhantomData,
            config,
            controller: Controller::new(),
            ack_notify: Notify::new(),
            subscription: None,
        }
    }

    /// Expose the pool for other functionality, e.g. custom migrations
    pub fn pool(&self) -> &Pool<Postgres> {
        &self.pool
    }

    /// Expose the config
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<T, C: Codec> PostgresStorage<T, C> {
    /// Expose the codec
    pub fn codec(&self) -> &PhantomData<C> {
        &self.codec
    }

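    /// Registers the worker or refreshes its `last_seen` timestamp via an
    /// upsert into `apalis.workers`.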
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: Timestamp,
    ) -> Result<(), sqlx::Error> {
        let last_seen = DateTime::from_timestamp(last_seen, 0).ok_or(sqlx::Error::Io(
            io::Error::new(io::ErrorKind::InvalidInput, "Invalid Timestamp"),
        ))?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO
                   UPDATE SET last_seen = EXCLUDED.last_seen";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}

/// A listener that listens to Postgres notifications
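///
/// A minimal usage sketch (assuming a `tokio` runtime and an already
/// migrated database), sharing one listener with a storage:
///
/// ```rust,no_run
/// # use apalis_sql::postgres::{PgListen, PgPool, PostgresStorage};
/// # async fn example(pool: PgPool, mut storage: PostgresStorage<()>) -> Result<(), sqlx::Error> {
/// let mut listener = PgListen::new(pool).await?;
/// listener.subscribe_with(&mut storage);
/// // Drive the listener alongside the workers
/// tokio::spawn(listener.listen());
/// # Ok(())
/// # }
/// ```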
#[derive(Debug)]
pub struct PgListen {
    listener: PgListener,
    subscriptions: Vec<(String, PgSubscription)>,
}

/// A postgres subscription
#[derive(Debug, Clone)]
pub struct PgSubscription {
    notify: Notify<()>,
}

impl PgListen {
    /// Build a new listener.
    ///
    /// Maintaining a connection can be expensive; it's encouraged to create only one [PgListen] and share it with multiple [PostgresStorage] instances
    pub async fn new(pool: PgPool) -> Result<Self, sqlx::Error> {
        let listener = PgListener::connect_with(&pool).await?;
        Ok(Self {
            listener,
            subscriptions: Vec::new(),
        })
    }

    /// Add a new subscription with a storage
    pub fn subscribe_with<T>(&mut self, storage: &mut PostgresStorage<T>) {
        let sub = PgSubscription {
            notify: Notify::new(),
        };
        self.subscriptions
            .push((storage.config.namespace.to_owned(), sub.clone()));
        storage.subscription = Some(sub)
    }

    /// Add a new subscription
    pub fn subscribe(&mut self, namespace: &str) -> PgSubscription {
        let sub = PgSubscription {
            notify: Notify::new(),
        };
        self.subscriptions.push((namespace.to_owned(), sub.clone()));
        sub
    }

    /// Start listening to jobs
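    ///
    /// Listens on the `apalis::job` channel; each notification's payload is
    /// the namespace of the queue that just received a job.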
    pub async fn listen(mut self) -> Result<(), sqlx::Error> {
        self.listener.listen("apalis::job").await?;
        let mut notification = self.listener.into_stream();
        while let Some(Ok(res)) = notification.next().await {
            let _: Vec<_> = self
                .subscriptions
                .iter()
                .filter(|s| s.0 == res.payload())
                .map(|s| s.1.notify.notify(()))
                .collect();
        }
        Ok(())
    }
}

impl<T, C> PostgresStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin + 'static,
    C: Codec<Compact = serde_json::Value>,
{
    async fn fetch_next(
        &mut self,
        worker_id: &WorkerId,
    ) -> Result<Vec<Request<T, SqlContext>>, sqlx::Error> {
        let config = &self.config;
        let job_type = &config.namespace;
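        // apalis.get_jobs is a stored procedure installed by the migrations;
        // it locks up to $3 pending jobs for this worker and returns them.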
        let fetch_query = "Select * from apalis.get_jobs($1, $2, $3);";
        let jobs: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(worker_id.to_string())
            .bind(job_type)
            // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
            .bind(
                i32::try_from(config.buffer_size)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .fetch_all(&self.pool)
            .await?;
        let jobs = jobs
            .into_iter()
            .map(|job| {
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                Ok(req)
            })
            .collect::<Result<Vec<_>, sqlx::Error>>()?;
        Ok(jobs)
    }
}

impl<Req, C> Storage for PostgresStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send + 'static,
{
    type Job = Req;

    type Error = sqlx::Error;

    type Context = SqlContext;

    /// Push a job to Postgres [Storage]
    ///
    /// # SQL Example
    ///
    /// ```sql
    /// Select apalis.push_job(job_type::text, job::json);
    /// ```
    async fn push_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, NOW(), NULL, NULL, NULL, NULL)";

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(&job_type)
            .bind(req.parts.context.max_attempts())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

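    /// Schedule a job to run at a later time
    ///
    /// `on` is a Unix timestamp in seconds for when the job should run.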
    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: Timestamp,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query =
            "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, $5, NULL, NULL, NULL, NULL)";
        let task_id = req.parts.task_id.to_string();
        let parts = req.parts;
        let on = DateTime::from_timestamp(on, 0);
        let job = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(task_id)
            .bind(job_type)
            .bind(parts.context.max_attempts())
            .bind(on)
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1 LIMIT 1";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&self.pool)
            .await?;

        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let args = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

                let mut req: Request<Req, SqlContext> = Request::new_with_parts(args, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

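    /// Counts jobs that are still `Pending` (across all job types).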
    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let query = "Select Count(*) as count from apalis.jobs where status='Pending'";
        let record = sqlx::query(query).fetch_one(&self.pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<Req, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let job_id = job.parts.task_id;
        let on = Utc::now() + wait;
        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";

        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(on)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let ctx = job.parts.context;
        let job_id = job.parts.task_id;
        let status = ctx.status().to_string();
        let attempts: i32 = job
            .parts
            .attempt
            .current()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();

        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
        sqlx::query(query)
            .bind(status.to_owned())
            .bind(attempts)
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, sqlx::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let query = "Delete from apalis.jobs where status='Done'";
        let record = sqlx::query(query).execute(&self.pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

impl<T, Res, C> Ack<T, Res> for PostgresStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Sync + Clone,
    C: Codec<Compact = Value> + Send,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        let res = res.clone().map(|r| {
            C::encode(r)
                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))
                .expect("Could not encode result")
        });

        self.ack_notify
            .notify((ctx.clone(), res))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> PostgresStorage<T, C> {
    /// Kill a job
    pub async fn kill(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Puts the job back into the queue immediately;
    /// another worker may then consume it
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut tx = self.pool.acquire().await?;
        let query =
                "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Reenqueue jobs that have been abandoned by their workers
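    ///
    /// Jobs that are `Running` but locked by a worker whose `last_seen` is
    /// older than `dead_since` are reset to `Pending`; at most `count` jobs
    /// are reclaimed per call.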
    pub async fn reenqueue_orphaned(
        &mut self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = "UPDATE apalis.jobs
                            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned'
                            WHERE id IN
                                (SELECT jobs.id FROM apalis.jobs INNER JOIN apalis.workers ON lock_by = workers.id
                                    WHERE status = 'Running'
                                    AND workers.last_seen < ($3::timestamp)
                                    AND workers.worker_type = $1
                                    ORDER BY lock_at ASC
                                    LIMIT $2);";

        sqlx::query(query)
            .bind(job_type)
            .bind(count)
            .bind(dead_since)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for PostgresStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
                            COUNT(1) FILTER (WHERE status = 'Retry') AS retry,
                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
                        FROM apalis.jobs WHERE job_type = $1";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM apalis.jobs WHERE status = $1 AND job_type = $2 ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET $3";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.config().namespace())
            .bind(((page - 1) * 10) as i64)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req = JsonCodec::<Value>::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, last_seen FROM apalis.workers WHERE worker_type = $1 ORDER BY last_seen DESC LIMIT 20 OFFSET $2";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::test_utils::DummyService;
    use chrono::Utc;
    use email_service::Email;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, PostgresStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> PostgresStorage<T> {
        let db_url = &std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = PgPool::connect(&db_url).await.unwrap();
        // Because connections cannot be shared across async runtimes
        // (a new runtime is created for each test),
        // we don't share the storage, and tests must run sequentially.
        PostgresStorage::setup(&pool).await.unwrap();
        let config = Config::new("apalis-ci-tests").set_buffer_size(1);
        let mut storage = PostgresStorage::new_with_config(pool, config);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Roll back DB changes made by tests.
    ///
    /// Deletes the following rows:
    ///  - jobs of the current type
    ///  - the worker identified by `worker_id`
    ///
    /// Execute this function at the end of a test.
    async fn cleanup<T>(storage: &mut PostgresStorage<T>, worker_id: &WorkerId) {
        let mut tx = storage
            .pool
            .acquire()
            .await
            .expect("failed to get connection");
        sqlx::query("Delete from apalis.jobs where job_type = $1 OR lock_by = $2")
            .bind(storage.config.namespace())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await
            .expect("failed to delete jobs");
        sqlx::query("Delete from apalis.workers where id = $1")
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await
            .expect("failed to delete worker");
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@postgres".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn consume_one(
        storage: &mut PostgresStorage<Email>,
        worker_id: &WorkerId,
    ) -> Request<Email, SqlContext> {
        let req = storage.fetch_next(worker_id).await;
        req.unwrap()[0].clone()
    }

    async fn register_worker_at(
        storage: &mut PostgresStorage<Email>,
        last_seen: Timestamp,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        storage
            .keep_alive_at::<DummyService>(&worker_id, last_seen)
            .await
            .expect("failed to register worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        wrk
    }

    async fn register_worker(storage: &mut PostgresStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut PostgresStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut PostgresStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // add a slight delay to allow background actions like ack to complete
        apalis_core::sleep(Duration::from_secs(2)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let job_id = &job.parts.task_id;

        // Refresh our job
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(&worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job_id = &job.parts.task_id;
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 0); // TODO: update get_jobs to increase attempts
    }

    #[tokio::test]
    async fn test_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let ctx = &job.parts.context;

        assert_eq!(*ctx.status(), State::Running);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job_id = &job.parts.task_id;
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);
    }
}