// apalis_sql/postgres.rs

//! # apalis-postgres
//!
//! Allows using Postgres as a backend
//!
//! ## Postgres Example
//! ```rust,no_run
//! use apalis::prelude::*;
//! # use apalis_sql::postgres::PostgresStorage;
//! # use apalis_sql::postgres::PgPool;
//! use email_service::Email;
//!
//! #[tokio::main]
//! async fn main() -> std::io::Result<()> {
//!     std::env::set_var("RUST_LOG", "debug,sqlx::query=error");
//!     let database_url = std::env::var("DATABASE_URL").expect("Must specify url to db");
//!     let pool = PgPool::connect(&database_url).await.unwrap();
//!
//!     PostgresStorage::setup(&pool).await.unwrap();
//!     let pg: PostgresStorage<Email> = PostgresStorage::new(pool);
//!
//!     async fn send_email(job: Email, data: Data<usize>) -> Result<(), Error> {
//!         // execute the job
//!         Ok(())
//!     }
//!     // Jobs can also be pushed from another program or language:
//!     // let query = "Select apalis.push_job('apalis::Email', json_build_object('subject', 'Test apalis', 'to', 'test1@example.com', 'text', 'Lorem Ipsum'));";
//!     // pg.execute(query).await.unwrap();
//!
//!     Monitor::new()
//!         .register({
//!             WorkerBuilder::new("tasty-avocado")
//!                 .data(0usize)
//!                 .backend(pg)
//!                 .build_fn(send_email)
//!         })
//!         .run()
//!         .await
//! }
//! ```
use crate::context::SqlContext;
use crate::{calculate_status, Config, SqlError};
use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use chrono::{DateTime, Utc};
use futures::channel::mpsc;
use futures::StreamExt;
use futures::{select, stream, SinkExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::postgres::PgListener;
use sqlx::{Pool, Postgres, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};

type Timestamp = i64;

pub use sqlx::postgres::PgPool;

use crate::from_row::SqlRequest;
/// Represents a [Storage] that persists to Postgres
pub struct PostgresStorage<T, C = JsonCodec<serde_json::Value>>
where
    C: Codec,
{
    pool: PgPool,
    job_type: PhantomData<T>,
    codec: PhantomData<C>,
    config: Config,
    controller: Controller,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
    subscription: Option<PgSubscription>,
}

impl<T, C: Codec> Clone for PostgresStorage<T, C> {
    fn clone(&self) -> Self {
        PostgresStorage {
            pool: self.pool.clone(),
            job_type: PhantomData,
            codec: PhantomData,
            config: self.config.clone(),
            controller: self.controller.clone(),
            ack_notify: self.ack_notify.clone(),
            subscription: self.subscription.clone(),
        }
    }
}

impl<T, C: Codec> fmt::Debug for PostgresStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PostgresStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            // `ack_notify` and `subscription` are intentionally omitted
            .finish_non_exhaustive()
    }
}

/// Errors that can occur while polling a PostgreSQL database.
#[derive(thiserror::Error, Debug)]
pub enum PgPollError {
    /// Error during task acknowledgment.
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    /// Error while fetching the next item.
    #[error("Encountered an error during FetchNext: `{0}`")]
    FetchNextError(apalis_core::error::Error),

    /// Error while listening to PostgreSQL notifications.
    #[error("Encountered an error during listening to PgNotification: {0}")]
    PgNotificationError(apalis_core::error::Error),

    /// Error during a keep-alive heartbeat.
    #[error("Encountered an error during KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    /// Error during re-enqueuing orphaned tasks.
    #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),

    /// Error during result encoding.
    #[error("Encountered an error during encoding the result: {0}")]
    CodecError(BoxDynError),
}

impl<T, C> Backend<Request<T, SqlContext>> for PostgresStorage<T, C>
where
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    C: Codec<Compact = Value> + Send + 'static,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;

    type Layer = AckLayer<PostgresStorage<T, C>, T, SqlContext, C>;

    type Codec = C;

    fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let subscription = self.subscription.clone();
        let config = self.config.clone();
        let controller = self.controller.clone();
        let (mut tx, rx) = mpsc::channel(self.config.buffer_size);
        let ack_notify = self.ack_notify.clone();
        let pool = self.pool.clone();
        let worker = worker.clone();
        let heartbeat = async move {
            // Let's re-enqueue any jobs that belonged to this worker in case it previously died
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(
                    e,
                ))));
            }

            let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse();
            let mut reenqueue_orphaned_stm =
                apalis_core::interval::interval(config.poll_interval).fuse();

            let mut ack_stream = ack_notify.clone().ready_chunks(config.buffer_size).fuse();

            let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse();

            let mut pg_notification = subscription
                .map(|stm| stm.notify.boxed().fuse())
                .unwrap_or(stream::iter(vec![]).boxed().fuse());

            async fn fetch_next_batch<
                T: Unpin + DeserializeOwned + Send + 'static,
                C: Codec<Compact = Value>,
            >(
                storage: &mut PostgresStorage<T, C>,
                worker: &WorkerId,
                tx: &mut mpsc::Sender<Result<Option<Request<T, SqlContext>>, Error>>,
            ) -> Result<(), Error> {
                let res = storage
                    .fetch_next(worker)
                    .await
                    .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
                for job in res {
                    tx.send(Ok(Some(job)))
                        .await
                        .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
                }
                Ok(())
            }

            if let Err(e) = self
                .keep_alive_at::<Self::Layer>(worker.id(), Utc::now().timestamp())
                .await
            {
                worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e))));
            }

            loop {
                select! {
                    _ = keep_alive_stm.next() => {
                        if let Err(e) = self.keep_alive_at::<Self::Layer>(worker.id(), Utc::now().timestamp()).await {
                            worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e))));
                        }
                    }
                    ids = ack_stream.next() => {
                        if let Some(ids) = ids {
                            let ack_ids: Vec<(String, String, String, String, u64)> = ids.iter().map(|(ctx, res)| {
                                (res.task_id.to_string(), worker.id().to_string(), serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string())).expect("Could not convert response to json"), calculate_status(ctx, res).to_string(), res.attempt.current() as u64)
                            }).collect();
                            let query =
                                "UPDATE apalis.jobs
                                    SET status = Q.status,
                                        done_at = now(),
                                        lock_by = Q.worker_id,
                                        last_error = Q.result,
                                        attempts = Q.attempts
                                    FROM (
                                        SELECT (value->>0)::text as id,
                                            (value->>1)::text as worker_id,
                                            (value->>2)::text as result,
                                            (value->>3)::text as status,
                                            (value->>4)::int as attempts
                                        FROM json_array_elements($1::json)
                                    ) Q
                                    WHERE apalis.jobs.id = Q.id;
                                    ";
                            let codec_res = C::encode(&ack_ids);
                            match codec_res {
                                Ok(val) => {
                                    if let Err(e) = sqlx::query(query)
                                        .bind(val)
                                        .execute(&pool)
                                        .await
                                    {
                                        worker.emit(Event::Error(Box::new(PgPollError::AckError(e))));
                                    }
                                }
                                Err(e) => {
                                    worker.emit(Event::Error(Box::new(PgPollError::CodecError(e.into()))));
                                }
                            }
                        }
                    }
                    _ = poll_next_stm.next() => {
                        if worker.is_ready() {
                            if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await {
                                worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e))));
                            }
                        }
                    }
                    _ = pg_notification.next() => {
                        if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await {
                            worker.emit(Event::Error(Box::new(PgPollError::PgNotificationError(e))));
                        }
                    }
                    _ = reenqueue_orphaned_stm.next() => {
                        let dead_since = Utc::now()
                            - chrono::Duration::from_std(config.reenqueue_orphaned_after).expect("could not build dead_since");
                        if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await {
                            worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(e))));
                        }
                    }
                };
            }
        };
        Poller::new_with_layer(BackendStream::new(rx.boxed(), controller), heartbeat, layer)
    }
}

impl PostgresStorage<()> {
    /// Get postgres migrations without running them
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/postgres")
    }

    /// Do migrations for Postgres
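    ///
    /// A minimal sketch of running the bundled migrations (assumes `DATABASE_URL`
    /// points at a reachable Postgres instance):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::postgres::{PgPool, PostgresStorage};
    /// # async fn doc() -> Result<(), sqlx::Error> {
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap()).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```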
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Postgres>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> PostgresStorage<T> {
    /// New Storage from [PgPool]
    pub fn new(pool: PgPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }

    /// New Storage from [PgPool] and custom config
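    ///
    /// A short sketch with a custom [`Config`] (the `"email-queue"` namespace and
    /// buffer size are illustrative):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::postgres::{PgPool, PostgresStorage};
    /// # use apalis_sql::Config;
    /// # fn doc(pool: PgPool) {
    /// let config = Config::new("email-queue").set_buffer_size(10);
    /// let storage: PostgresStorage<String> = PostgresStorage::new_with_config(pool, config);
    /// # }
    /// ```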
    pub fn new_with_config(pool: PgPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            codec: PhantomData,
            config,
            controller: Controller::new(),
            ack_notify: Notify::new(),
            subscription: None,
        }
    }

    /// Expose the pool for other functionality, e.g. custom migrations
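    ///
    /// Sketch: reuse the same pool for an ad-hoc query (the query is illustrative):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::postgres::PostgresStorage;
    /// # async fn doc(storage: PostgresStorage<String>) -> Result<(), sqlx::Error> {
    /// let pending: i64 = sqlx::query_scalar("SELECT count(*) FROM apalis.jobs WHERE status = 'Pending'")
    ///     .fetch_one(storage.pool())
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```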
    pub fn pool(&self) -> &Pool<Postgres> {
        &self.pool
    }

    /// Expose the config
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<T, C: Codec> PostgresStorage<T, C> {
    /// Expose the codec
    pub fn codec(&self) -> &PhantomData<C> {
        &self.codec
    }

    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: Timestamp,
    ) -> Result<(), sqlx::Error> {
        let last_seen = DateTime::from_timestamp(last_seen, 0).ok_or(sqlx::Error::Io(
            io::Error::new(io::ErrorKind::InvalidInput, "Invalid Timestamp"),
        ))?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO
                   UPDATE SET last_seen = EXCLUDED.last_seen";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}

/// A listener that listens to Postgres notifications
#[derive(Debug)]
pub struct PgListen {
    listener: PgListener,
    subscriptions: Vec<(String, PgSubscription)>,
}

/// A postgres subscription
#[derive(Debug, Clone)]
pub struct PgSubscription {
    notify: Notify<()>,
}

impl PgListen {
    /// Build a new listener.
    ///
    /// Maintaining a connection can be expensive, so it's encouraged that you create only one [PgListen] and share it among multiple [PostgresStorage] instances.
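    ///
    /// A sketch of sharing one listener across storages (assumes a Tokio runtime
    /// and an already-connected pool):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::postgres::{PgListen, PgPool, PostgresStorage};
    /// # async fn doc(pool: PgPool, mut storage: PostgresStorage<String>) -> Result<(), sqlx::Error> {
    /// let mut listener = PgListen::new(pool).await?;
    /// listener.subscribe_with(&mut storage);
    /// tokio::spawn(listener.listen());
    /// # Ok(())
    /// # }
    /// ```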
    pub async fn new(pool: PgPool) -> Result<Self, sqlx::Error> {
        let listener = PgListener::connect_with(&pool).await?;
        Ok(Self {
            listener,
            subscriptions: Vec::new(),
        })
    }

    /// Add a new subscription with a storage
    pub fn subscribe_with<T>(&mut self, storage: &mut PostgresStorage<T>) {
        let sub = PgSubscription {
            notify: Notify::new(),
        };
        self.subscriptions
            .push((storage.config.namespace.to_owned(), sub.clone()));
        storage.subscription = Some(sub)
    }

    /// Add a new subscription
    pub fn subscribe(&mut self, namespace: &str) -> PgSubscription {
        let sub = PgSubscription {
            notify: Notify::new(),
        };
        self.subscriptions.push((namespace.to_owned(), sub.clone()));
        sub
    }

    /// Start listening to jobs
    pub async fn listen(mut self) -> Result<(), sqlx::Error> {
        self.listener.listen("apalis::job").await?;
        let mut notification = self.listener.into_stream();
        while let Some(Ok(res)) = notification.next().await {
            let _: Vec<_> = self
                .subscriptions
                .iter()
                .filter(|s| s.0 == res.payload())
                .map(|s| s.1.notify.notify(()))
                .collect();
        }
        Ok(())
    }
}

impl<T, C> PostgresStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin + 'static,
    C: Codec<Compact = Value>,
{
    async fn fetch_next(
        &mut self,
        worker_id: &WorkerId,
    ) -> Result<Vec<Request<T, SqlContext>>, sqlx::Error> {
        let config = &self.config;
        let job_type = &config.namespace;
        let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
        let jobs: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(worker_id.to_string())
            .bind(job_type)
            // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
            .bind(
                i32::try_from(config.buffer_size)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .fetch_all(&self.pool)
            .await?;
        // Propagate decode failures instead of panicking mid-stream
        jobs.into_iter()
            .map(|job| {
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                Ok(req)
            })
            .collect()
    }
}

impl<Req, C> Storage for PostgresStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send + 'static,
    C::Error: Send + std::error::Error + Sync + 'static,
{
    type Job = Req;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = Value;

    /// Push a job to Postgres [Storage]
    ///
    /// # SQL Example
    ///
    /// ```sql
    /// Select apalis.push_job(job_type::text, job::json);
    /// ```
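    ///
    /// From Rust, the same push goes through the [Storage] trait (a sketch; the
    /// payload type is illustrative):
    ///
    /// ```rust,no_run
    /// # use apalis_core::storage::Storage;
    /// # use apalis_sql::postgres::PostgresStorage;
    /// # async fn doc(mut storage: PostgresStorage<serde_json::Value>) {
    /// storage
    ///     .push(serde_json::json!({ "subject": "Test apalis" }))
    ///     .await
    ///     .unwrap();
    /// # }
    /// ```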
    async fn push_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, NOW(), NULL, NULL, NULL, NULL, $5)";

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(&job_type)
            .bind(req.parts.context.max_attempts())
            .bind(req.parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

    async fn push_raw_request(
        &mut self,
        req: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, NOW(), NULL, NULL, NULL, NULL, $5)";

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(&job_type)
            .bind(req.parts.context.max_attempts())
            .bind(req.parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: Timestamp,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query =
            "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, $5, NULL, NULL, NULL, NULL, $6)";
        let task_id = req.parts.task_id.to_string();
        let parts = req.parts;
        // Reject invalid timestamps instead of silently binding NULL for run_at
        let on = DateTime::from_timestamp(on, 0).ok_or(sqlx::Error::Io(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Invalid Timestamp",
        )))?;
        let job = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(task_id)
            .bind(job_type)
            .bind(parts.context.max_attempts())
            .bind(on)
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1 LIMIT 1";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&self.pool)
            .await?;

        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let args = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

                let mut req: Request<Req, SqlContext> = Request::new_with_parts(args, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)";
        let record = sqlx::query(query).fetch_one(&self.pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<Req, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let job_id = job.parts.task_id;
        let on = Utc::now() + wait;
        let mut conn = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";

        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(on)
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let ctx = job.parts.context;
        let job_id = job.parts.task_id;
        let status = ctx.status().to_string();
        let attempts: i32 = job
            .parts
            .attempt
            .current()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();

        let mut conn = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = to_timestamp($3), lock_by = $4, lock_at = to_timestamp($5), last_error = $6, priority = $7 WHERE id = $8";
        sqlx::query(query)
            .bind(status.to_owned())
            .bind(attempts)
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, sqlx::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let query = "DELETE FROM apalis.jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&self.pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

impl<T, Res, C> Ack<T, Res, C> for PostgresStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Sync + Clone,
    C: Codec<Compact = Value> + Send,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        let res = res.clone().map(|r| {
            C::encode(r)
                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))
                .expect("Could not encode result")
        });

        self.ack_notify
            .notify((ctx.clone(), res))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> PostgresStorage<T, C> {
    /// Kill a job
    pub async fn kill(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut conn = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

    /// Puts the job instantly back into the queue so that another worker may consume it
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut conn = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *conn)
            .await?;
        Ok(())
    }

    /// Re-enqueue jobs that have been abandoned by their workers
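    ///
    /// Sketch of a manual sweep (the batch size and five-minute cutoff are
    /// illustrative; the poller also runs this periodically):
    ///
    /// ```rust,no_run
    /// # use apalis_sql::postgres::PostgresStorage;
    /// # use chrono::Utc;
    /// # async fn doc(mut storage: PostgresStorage<String>) -> Result<(), sqlx::Error> {
    /// let dead_since = Utc::now() - chrono::Duration::minutes(5);
    /// storage.reenqueue_orphaned(10, dead_since).await?;
    /// # Ok(())
    /// # }
    /// ```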
    pub async fn reenqueue_orphaned(
        &mut self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut conn = self.pool.acquire().await?;
        let query = "UPDATE apalis.jobs
                            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned'
                            WHERE id IN
                                (SELECT jobs.id FROM apalis.jobs INNER JOIN apalis.workers ON lock_by = workers.id
                                    WHERE status = 'Running'
                                    AND workers.last_seen < ($3::timestamp)
                                    AND workers.worker_type = $1
                                    ORDER BY lock_at ASC
                                    LIMIT $2);";

        sqlx::query(query)
            .bind(job_type)
            .bind(count)
            .bind(dead_since)
            .execute(&mut *conn)
            .await?;
        Ok(())
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for PostgresStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
                            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
                            COUNT(1) FILTER (WHERE status = 'Running') AS running,
                            COUNT(1) FILTER (WHERE status = 'Done') AS done,
                            COUNT(1) FILTER (WHERE status = 'Retry') AS retry,
                            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
                            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
                        FROM apalis.jobs WHERE job_type = $1";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM apalis.jobs WHERE status = $1 AND job_type = $2 ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET $3";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.config().namespace())
            .bind(((page - 1) * 10) as i64)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req = JsonCodec::<Value>::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, cast(extract(epoch from last_seen) as bigint) FROM apalis.workers WHERE worker_type = $1 ORDER BY last_seen DESC LIMIT 20 OFFSET $2";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::test_utils::DummyService;
    use chrono::Utc;
    use email_service::Email;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, PostgresStorage<Email>, Email);

    /// Migrate the DB and return a storage instance.
    async fn setup<T: Serialize + DeserializeOwned>() -> PostgresStorage<T> {
        let db_url = std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = PgPool::connect(&db_url).await.unwrap();
        // Because connections cannot be shared across async runtimes
        // (a separate runtime is created for each test),
        // we don't share the storage and tests must be run sequentially.
        PostgresStorage::setup(&pool).await.unwrap();
        let config = Config::new("apalis-tests").set_buffer_size(1);
        let mut storage = PostgresStorage::new_with_config(pool, config);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Roll back DB changes made by tests.
    /// Deletes the following rows:
    ///  - jobs of the current type
    ///  - the worker identified by `worker_id`
    ///
    /// You should execute this function at the end of a test
    async fn cleanup<T>(storage: &mut PostgresStorage<T>, worker_id: &WorkerId) {
        let mut conn = storage
            .pool
            .acquire()
            .await
            .expect("failed to get connection");
        sqlx::query("DELETE FROM apalis.jobs WHERE job_type = $1 OR lock_by = $2")
            .bind(storage.config.namespace())
            .bind(worker_id.to_string())
            .execute(&mut *conn)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM apalis.workers WHERE id = $1")
            .bind(worker_id.to_string())
            .execute(&mut *conn)
            .await
            .expect("failed to delete worker");
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@postgres".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn consume_one(
        storage: &mut PostgresStorage<Email>,
        worker_id: &WorkerId,
    ) -> Request<Email, SqlContext> {
        let req = storage.fetch_next(worker_id).await;
        req.unwrap()[0].clone()
    }

    async fn register_worker_at(
        storage: &mut PostgresStorage<Email>,
        last_seen: Timestamp,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        storage
            .keep_alive_at::<DummyService>(&worker_id, last_seen)
            .await
            .expect("failed to register worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        wrk
    }

    async fn register_worker(storage: &mut PostgresStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut PostgresStorage<Email>, email: Email) -> TaskId {
        storage
            .push(email)
            .await
            .expect("failed to push a job")
            .task_id
    }

    async fn get_job(
        storage: &mut PostgresStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // add a slight delay to allow background actions like ack to complete
        apalis_core::sleep(Duration::from_secs(2)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let job_id = &job.parts.task_id;

        // Refresh our job
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(&worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_heartbeat_reenqueue_orphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job_id = &job.parts.task_id;
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 0); // TODO: update get_jobs to increase attempts
    }

    #[tokio::test]
    async fn test_heartbeat_reenqueue_orphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let ctx = &job.parts.context;

        assert_eq!(*ctx.status(), State::Running);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job_id = &job.parts.task_id;
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);
    }

    // This test pushes a scheduled request (scheduled 5 minutes in the future)
    // and then asserts that fetch_next returns nothing.
    #[tokio::test]
    async fn test_scheduled_request_not_fetched() {
        // Setup storage using the provided helper; scheduled jobs use the same table as regular ones.
        let mut storage = setup().await;

        // Schedule a request 5 minutes in the future.
        let run_at = Utc::now().timestamp() + 300; // 5 minutes = 300 secs
        let scheduled_req = Request::new(example_email());

        storage
            .schedule_request(scheduled_req, run_at)
            .await
            .expect("failed to schedule request");

        // Fetch the next jobs for a worker; expect empty since the job is scheduled for the future.
        let worker = register_worker(&mut storage).await;
        let jobs = storage
            .fetch_next(worker.id())
            .await
            .expect("failed to fetch next jobs");
        assert!(
            jobs.is_empty(),
            "Scheduled job should not be fetched before its scheduled time"
        );

        // List jobs with status 'Pending' and expect the scheduled job to be there.
        let jobs = storage
            .list_jobs(&State::Pending, 1)
            .await
            .expect("failed to list jobs");
        assert_eq!(jobs.len(), 1, "Expected one job to be listed");
    }

    // This test pushes a request using one job_type, then uses a worker with a different job_type
    // to fetch jobs and asserts that it returns nothing.
    #[tokio::test]
    async fn test_fetch_with_different_job_type_returns_empty() {
        // Setup one storage with its config namespace (job_type)
        let mut storage_email = setup().await;

        // Create a second storage using the same pool but with a different namespace.
        let pool = storage_email.pool().clone();
        let sms_config = Config::new("sms-test").set_buffer_size(1);
        let mut storage_sms: PostgresStorage<Email> =
            PostgresStorage::new_with_config(pool, sms_config);

        // Push a job using the first storage (job_type = storage_email.config.namespace)
        push_email(&mut storage_email, example_email()).await;

        // Attempt to fetch the job with a worker associated with the different job_type.
        let worker_id = WorkerId::new("sms-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();

        let jobs = storage_sms
            .fetch_next(worker.id())
            .await
            .expect("failed to fetch next jobs");
        assert!(
            jobs.is_empty(),
            "A worker with a different job_type should not fetch jobs"
        );

        // Fetch the job with a worker associated with the correct job_type.
        let worker = register_worker(&mut storage_email).await;
        let jobs = storage_email
            .fetch_next(worker.id())
            .await
            .expect("failed to fetch next jobs");
        assert!(!jobs.is_empty(), "Worker should fetch the job");
    }
}