use crate::context::SqlContext;
use crate::{calculate_status, Config, SqlError};
use apalis_core::backend::{BackendExpose, Stat, WorkerState};
use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::notify::Notify;
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream, State};
use apalis_core::response::Response;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Context, Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use chrono::{DateTime, Utc};
use futures::channel::mpsc;
use futures::StreamExt;
use futures::{select, stream, SinkExt};
use log::error;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use sqlx::postgres::PgListener;
use sqlx::{Pool, Postgres, Row};
use std::any::type_name;
use std::convert::TryInto;
use std::fmt::Debug;
use std::sync::Arc;
use std::{fmt, io};
use std::{marker::PhantomData, time::Duration};

type Timestamp = i64;

pub use sqlx::postgres::PgPool;

use crate::from_row::SqlRequest;
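
/// A [`Storage`] for the apalis framework that persists jobs in PostgreSQL.
///
/// Jobs live in the `apalis.jobs` table and are serialized with the codec
/// `C` (JSON by default). Workers poll on an interval and can additionally
/// be woken up via `LISTEN/NOTIFY` using [`PgListen`].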
pub struct PostgresStorage<T, C = JsonCodec<serde_json::Value>>
where
    C: Codec,
{
    pool: PgPool,
    job_type: PhantomData<T>,
    codec: PhantomData<C>,
    config: Config,
    controller: Controller,
    ack_notify: Notify<(SqlContext, Response<C::Compact>)>,
    subscription: Option<PgSubscription>,
}

impl<T, C: Codec> Clone for PostgresStorage<T, C> {
    fn clone(&self) -> Self {
        PostgresStorage {
            pool: self.pool.clone(),
            job_type: PhantomData,
            codec: PhantomData,
            config: self.config.clone(),
            controller: self.controller.clone(),
            ack_notify: self.ack_notify.clone(),
            subscription: self.subscription.clone(),
        }
    }
}

impl<T, C: Codec> fmt::Debug for PostgresStorage<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PostgresStorage")
            .field("pool", &self.pool)
            .field("job_type", &"PhantomData<T>")
            .field("controller", &self.controller)
            .field("config", &self.config)
            .field("codec", &std::any::type_name::<C>())
            .finish()
    }
}
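
/// Errors that can surface while a worker is polling Postgres.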
#[derive(thiserror::Error, Debug)]
pub enum PgPollError {
    #[error("Encountered an error during ACK: `{0}`")]
    AckError(sqlx::Error),

    #[error("Encountered an error during FetchNext: `{0}`")]
    FetchNextError(apalis_core::error::Error),

    #[error("Encountered an error while listening for a PgNotification: {0}")]
    PgNotificationError(apalis_core::error::Error),

    #[error("Encountered an error during the KeepAlive heartbeat: `{0}`")]
    KeepAliveError(sqlx::Error),

    #[error("Encountered an error during the ReenqueueOrphaned heartbeat: `{0}`")]
    ReenqueueOrphanedError(sqlx::Error),

    #[error("Encountered an error while encoding the result: {0}")]
    CodecError(BoxDynError),
}

impl<T, C> Backend<Request<T, SqlContext>> for PostgresStorage<T, C>
where
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    C: Codec<Compact = Value> + Send + 'static,
    C::Error: std::error::Error + 'static + Send + Sync,
{
    type Stream = BackendStream<RequestStream<Request<T, SqlContext>>>;

    type Layer = AckLayer<PostgresStorage<T, C>, T, SqlContext, C>;

    type Codec = C;

    fn poll(mut self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        let layer = AckLayer::new(self.clone());
        let subscription = self.subscription.clone();
        let config = self.config.clone();
        let controller = self.controller.clone();
        let (mut tx, rx) = mpsc::channel(self.config.buffer_size);
        let ack_notify = self.ack_notify.clone();
        let pool = self.pool.clone();
        let worker = worker.clone();
        let heartbeat = async move {
            // Before polling, reclaim any jobs that were orphaned while no
            // worker was alive.
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(
                    e,
                ))));
            }

            let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse();
            let mut reenqueue_orphaned_stm =
                apalis_core::interval::interval(config.poll_interval).fuse();

            let mut ack_stream = ack_notify.clone().ready_chunks(config.buffer_size).fuse();

            let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse();

            // Fall back to an empty stream when no LISTEN/NOTIFY subscription
            // was configured.
            let mut pg_notification = subscription
                .map(|stm| stm.notify.boxed().fuse())
                .unwrap_or(stream::iter(vec![]).boxed().fuse());

            async fn fetch_next_batch<
                T: Unpin + DeserializeOwned + Send + 'static,
                C: Codec<Compact = Value>,
            >(
                storage: &mut PostgresStorage<T, C>,
                worker: &WorkerId,
                tx: &mut mpsc::Sender<Result<Option<Request<T, SqlContext>>, Error>>,
            ) -> Result<(), Error> {
                let res = storage
                    .fetch_next(worker)
                    .await
                    .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
                for job in res {
                    tx.send(Ok(Some(job)))
                        .await
                        .map_err(|e| Error::SourceError(Arc::new(Box::new(e))))?;
                }
                Ok(())
            }

            if let Err(e) = self
                .keep_alive_at::<Self::Layer>(worker.id(), Utc::now().timestamp())
                .await
            {
                worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e))));
            }

            loop {
                select! {
                    // Periodically refresh this worker's last_seen timestamp.
                    _ = keep_alive_stm.next() => {
                        if let Err(e) = self.keep_alive_at::<Self::Layer>(worker.id(), Utc::now().timestamp()).await {
                            worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e))));
                        }
                    }
                    // Flush buffered acknowledgements in a single batched UPDATE.
                    ids = ack_stream.next() => {
                        if let Some(ids) = ids {
                            let ack_ids: Vec<(String, String, String, String, u64)> = ids.iter().map(|(ctx, res)| {
                                (res.task_id.to_string(), worker.id().to_string(), serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string())).expect("Could not convert response to json"), calculate_status(ctx, res).to_string(), res.attempt.current() as u64)
                            }).collect();
                            let query =
                                "UPDATE apalis.jobs
                                    SET status = Q.status,
                                        done_at = now(),
                                        lock_by = Q.worker_id,
                                        last_error = Q.result,
                                        attempts = Q.attempts
                                    FROM (
                                        SELECT (value->>0)::text as id,
                                               (value->>1)::text as worker_id,
                                               (value->>2)::text as result,
                                               (value->>3)::text as status,
                                               (value->>4)::int as attempts
                                        FROM json_array_elements($1::json)
                                    ) Q
                                    WHERE apalis.jobs.id = Q.id;";
                            match C::encode(&ack_ids) {
                                Ok(val) => {
                                    if let Err(e) = sqlx::query(query).bind(val).execute(&pool).await {
                                        worker.emit(Event::Error(Box::new(PgPollError::AckError(e))));
                                    }
                                }
                                Err(e) => {
                                    worker.emit(Event::Error(Box::new(PgPollError::CodecError(e.into()))));
                                }
                            }
                        }
                    }
                    // Poll for new jobs on the configured interval.
                    _ = poll_next_stm.next() => {
                        if worker.is_ready() {
                            if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await {
                                worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e))));
                            }
                        }
                    }
                    // Fetch eagerly whenever a Postgres notification arrives.
                    _ = pg_notification.next() => {
                        if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await {
                            worker.emit(Event::Error(Box::new(PgPollError::PgNotificationError(e))));
                        }
                    }
                    // Periodically reclaim jobs abandoned by dead workers.
                    _ = reenqueue_orphaned_stm.next() => {
                        let dead_since = Utc::now()
                            - chrono::Duration::from_std(config.reenqueue_orphaned_after).expect("could not build dead_since");
                        if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await {
                            worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(e))));
                        }
                    }
                };
            }
        };
        Poller::new_with_layer(BackendStream::new(rx.boxed(), controller), heartbeat, layer)
    }
}

impl PostgresStorage<()> {
    /// Returns the migrations bundled with this backend.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("migrations/postgres")
    }

    /// Runs the bundled migrations against the given pool.
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &Pool<Postgres>) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }
}

impl<T> PostgresStorage<T> {
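    /// Creates a new storage from a pool, using the job type's `type_name`
    /// as the namespace.
    ///
    /// A minimal sketch, assuming a reachable database and an `Email` job
    /// type defined elsewhere:
    ///
    /// ```rust,ignore
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;
    /// // Run the bundled migrations (requires the `migrate` feature).
    /// PostgresStorage::setup(&pool).await?;
    /// let storage: PostgresStorage<Email> = PostgresStorage::new(pool);
    /// ```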
    pub fn new(pool: PgPool) -> Self {
        Self::new_with_config(pool, Config::new(type_name::<T>()))
    }
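
    /// Creates a new storage with a custom [`Config`], e.g. to override the
    /// namespace or buffer size.
    ///
    /// A minimal sketch; `set_buffer_size` is used the same way in the tests
    /// below:
    ///
    /// ```rust,ignore
    /// let config = Config::new("apalis::email").set_buffer_size(10);
    /// let storage: PostgresStorage<Email> = PostgresStorage::new_with_config(pool, config);
    /// ```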
    pub fn new_with_config(pool: PgPool, config: Config) -> Self {
        Self {
            pool,
            job_type: PhantomData,
            codec: PhantomData,
            config,
            controller: Controller::new(),
            ack_notify: Notify::new(),
            subscription: None,
        }
    }

    /// Returns a reference to the underlying connection pool.
    pub fn pool(&self) -> &Pool<Postgres> {
        &self.pool
    }

    /// Returns a reference to the storage's [`Config`].
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<T, C: Codec> PostgresStorage<T, C> {
    /// Exposes the phantom codec marker.
    pub fn codec(&self) -> &PhantomData<C> {
        &self.codec
    }

    /// Registers the worker in `apalis.workers`, or refreshes its
    /// `last_seen` timestamp if it is already registered.
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: Timestamp,
    ) -> Result<(), sqlx::Error> {
        let last_seen = DateTime::from_timestamp(last_seen, 0).ok_or(sqlx::Error::Io(
            io::Error::new(io::ErrorKind::InvalidInput, "Invalid Timestamp"),
        ))?;
        let worker_type = self.config.namespace.clone();
        let storage_name = std::any::type_name::<Self>();
        let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (id) DO
                    UPDATE SET last_seen = EXCLUDED.last_seen";
        sqlx::query(query)
            .bind(worker_id.to_string())
            .bind(worker_type)
            .bind(storage_name)
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}

/// A listener that fans out Postgres `NOTIFY` events on the `apalis::job`
/// channel to per-namespace subscriptions.
#[derive(Debug)]
pub struct PgListen {
    listener: PgListener,
    subscriptions: Vec<(String, PgSubscription)>,
}

/// A subscription handle that is notified whenever its namespace receives
/// a Postgres notification.
#[derive(Debug, Clone)]
pub struct PgSubscription {
    notify: Notify<()>,
}

impl PgListen {
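    /// Connects a dedicated [`PgListener`] to the database.
    ///
    /// A minimal sketch, assuming a worker is consuming `storage` elsewhere
    /// (e.g. on a tokio runtime) and that pushes into the namespace trigger
    /// a `NOTIFY` on the `apalis::job` channel:
    ///
    /// ```rust,ignore
    /// let mut listen = PgListen::new(pool).await?;
    /// listen.subscribe_with(&mut storage);
    /// tokio::spawn(listen.listen());
    /// ```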
    pub async fn new(pool: PgPool) -> Result<Self, sqlx::Error> {
        let listener = PgListener::connect_with(&pool).await?;
        Ok(Self {
            listener,
            subscriptions: Vec::new(),
        })
    }

    /// Subscribes a storage to notifications for its namespace.
    pub fn subscribe_with<T>(&mut self, storage: &mut PostgresStorage<T>) {
        let sub = PgSubscription {
            notify: Notify::new(),
        };
        self.subscriptions
            .push((storage.config.namespace.to_owned(), sub.clone()));
        storage.subscription = Some(sub)
    }

    /// Returns a new subscription for the given namespace.
    pub fn subscribe(&mut self, namespace: &str) -> PgSubscription {
        let sub = PgSubscription {
            notify: Notify::new(),
        };
        self.subscriptions.push((namespace.to_owned(), sub.clone()));
        sub
    }

    /// Starts listening on the `apalis::job` channel, forwarding each
    /// notification to the subscriptions whose namespace matches its payload.
    pub async fn listen(mut self) -> Result<(), sqlx::Error> {
        self.listener.listen("apalis::job").await?;
        let mut notification = self.listener.into_stream();
        while let Some(Ok(res)) = notification.next().await {
            let _: Vec<_> = self
                .subscriptions
                .iter()
                .filter(|s| s.0 == res.payload())
                .map(|s| s.1.notify.notify(()))
                .collect();
        }
        Ok(())
    }
}

impl<T, C> PostgresStorage<T, C>
where
    T: DeserializeOwned + Send + Unpin + 'static,
    C: Codec<Compact = Value>,
{
    async fn fetch_next(
        &mut self,
        worker_id: &WorkerId,
    ) -> Result<Vec<Request<T, SqlContext>>, sqlx::Error> {
        let config = &self.config;
        let job_type = &config.namespace;
        let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
        let jobs: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(worker_id.to_string())
            .bind(job_type)
            .bind(
                i32::try_from(config.buffer_size)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
            )
            .fetch_all(&self.pool)
            .await?;
        // Propagate decode failures as errors rather than panicking mid-poll.
        jobs.into_iter()
            .map(|job| {
                let (req, parts) = job.req.take_parts();
                let req = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
                let mut req = Request::new_with_parts(req, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                Ok(req)
            })
            .collect()
    }
}

impl<Req, C> Storage for PostgresStorage<Req, C>
where
    Req: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    C: Codec<Compact = Value> + Send + 'static,
    C::Error: Send + std::error::Error + Sync + 'static,
{
    type Job = Req;

    type Error = sqlx::Error;

    type Context = SqlContext;

    type Compact = Value;

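    /// Inserts a new `Pending` row into `apalis.jobs`.
    ///
    /// A minimal sketch via the `Storage::push` convenience, assuming an
    /// `Email` job type:
    ///
    /// ```rust,ignore
    /// let parts = storage.push(Email { /* ... */ }).await?;
    /// println!("queued task {}", parts.task_id);
    /// ```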
    async fn push_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, NOW(), NULL, NULL, NULL, NULL, $5)";

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(&job_type)
            .bind(req.parts.context.max_attempts())
            .bind(req.parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

    /// Same as [`push_request`](Storage::push_request), but takes an
    /// already-encoded payload.
    async fn push_raw_request(
        &mut self,
        req: Request<Self::Compact, SqlContext>,
    ) -> Result<Parts<SqlContext>, sqlx::Error> {
        let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, NOW(), NULL, NULL, NULL, NULL, $5)";

        let args = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(args)
            .bind(req.parts.task_id.to_string())
            .bind(&job_type)
            .bind(req.parts.context.max_attempts())
            .bind(req.parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(req.parts)
    }

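    /// Inserts a job whose `run_at` is set to the given Unix timestamp
    /// (in seconds), so it is not fetched before that time.
    ///
    /// A minimal sketch, mirroring the scheduling test below:
    ///
    /// ```rust,ignore
    /// let run_at = Utc::now().timestamp() + 300; // five minutes from now
    /// storage.schedule_request(Request::new(email), run_at).await?;
    /// ```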
    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, SqlContext>,
        on: Timestamp,
    ) -> Result<Parts<Self::Context>, sqlx::Error> {
        let query =
            "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, $4, $5, NULL, NULL, NULL, NULL, $6)";
        let task_id = req.parts.task_id.to_string();
        let parts = req.parts;
        let on = DateTime::from_timestamp(on, 0);
        let job = C::encode(&req.args)
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))?;
        let job_type = self.config.namespace.clone();
        sqlx::query(query)
            .bind(job)
            .bind(task_id)
            .bind(job_type)
            .bind(parts.context.max_attempts())
            .bind(on)
            .bind(parts.context.priority())
            .execute(&self.pool)
            .await?;
        Ok(parts)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, SqlContext>>, sqlx::Error> {
        let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1 LIMIT 1";
        let res: Option<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(job_id.to_string())
            .fetch_optional(&self.pool)
            .await?;

        match res {
            None => Ok(None),
            Some(job) => Ok(Some({
                let (req, parts) = job.req.take_parts();
                let args = C::decode(req)
                    .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;

                let mut req: Request<Req, SqlContext> = Request::new_with_parts(args, parts);
                req.parts.namespace = Some(Namespace(self.config.namespace.clone()));
                req
            })),
        }
    }

    async fn len(&mut self) -> Result<i64, sqlx::Error> {
        let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending' OR (status = 'Failed' AND attempts < max_attempts)";
        let record = sqlx::query(query).fetch_one(&self.pool).await?;
        record.try_get("count")
    }

    async fn reschedule(
        &mut self,
        job: Request<Req, SqlContext>,
        wait: Duration,
    ) -> Result<(), sqlx::Error> {
        let job_id = job.parts.task_id;
        let on = Utc::now() + wait;
        let mut tx = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";

        sqlx::query(query)
            .bind(job_id.to_string())
            .bind(on)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn update(&mut self, job: Request<Self::Job, SqlContext>) -> Result<(), sqlx::Error> {
        let ctx = job.parts.context;
        let job_id = job.parts.task_id;
        let status = ctx.status().to_string();
        let attempts: i32 = job
            .parts
            .attempt
            .current()
            .try_into()
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?;
        let done_at = *ctx.done_at();
        let lock_by = ctx.lock_by().clone();
        let lock_at = *ctx.lock_at();
        let last_error = ctx.last_error().clone();
        let priority = *ctx.priority();

        let mut tx = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = to_timestamp($3), lock_by = $4, lock_at = to_timestamp($5), last_error = $6, priority = $7 WHERE id = $8";
        sqlx::query(query)
            .bind(status)
            .bind(attempts)
            .bind(done_at)
            .bind(lock_by.map(|w| w.name().to_string()))
            .bind(lock_at)
            .bind(last_error)
            .bind(priority)
            .bind(job_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    async fn is_empty(&mut self) -> Result<bool, sqlx::Error> {
        Ok(self.len().await? == 0)
    }

    async fn vacuum(&mut self) -> Result<usize, sqlx::Error> {
        let query = "DELETE FROM apalis.jobs WHERE status = 'Done'";
        let record = sqlx::query(query).execute(&self.pool).await?;
        Ok(record.rows_affected().try_into().unwrap_or_default())
    }
}

impl<T, Res, C> Ack<T, Res, C> for PostgresStorage<T, C>
where
    T: Sync + Send,
    Res: Serialize + Sync + Clone,
    C: Codec<Compact = Value> + Send,
{
    type Context = SqlContext;
    type AckError = sqlx::Error;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), sqlx::Error> {
        // Encode the result and hand it to the poller's ack stream; the
        // actual UPDATE is batched in `poll`.
        let res = res.clone().map(|r| {
            C::encode(r)
                .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))
                .expect("Could not encode result")
        });

        self.ack_notify
            .notify((ctx.clone(), res))
            .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e)))?;

        Ok(())
    }
}

impl<T, C: Codec> PostgresStorage<T, C> {
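    /// Marks a job locked by `worker_id` as `Killed`.
    ///
    /// A minimal sketch, with ids shaped like those in the tests below:
    ///
    /// ```rust,ignore
    /// storage.kill(worker.id(), &job.parts.task_id).await?;
    /// ```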
    pub async fn kill(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut tx = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Puts a job owned by the worker back into the `Pending` state so it
    /// can be picked up again.
    pub async fn retry(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
    ) -> Result<(), sqlx::Error> {
        let mut tx = self.pool.acquire().await?;
        let query =
            "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
        sqlx::query(query)
            .bind(task_id.to_string())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await?;
        Ok(())
    }

    /// Re-enqueues up to `count` running jobs whose workers have not been
    /// seen since `dead_since`, marking them as abandoned.
    pub async fn reenqueue_orphaned(
        &mut self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<(), sqlx::Error> {
        let job_type = self.config.namespace.clone();
        let mut tx = self.pool.acquire().await?;
        let query = "UPDATE apalis.jobs
            SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error = 'Job was abandoned'
            WHERE id IN
                (SELECT jobs.id FROM apalis.jobs INNER JOIN apalis.workers ON lock_by = workers.id
                    WHERE status = 'Running'
                    AND workers.last_seen < ($3::timestamp)
                    AND workers.worker_type = $1
                    ORDER BY lock_at ASC
                    LIMIT $2);";

        sqlx::query(query)
            .bind(job_type)
            .bind(count)
            .bind(dead_since)
            .execute(&mut *tx)
            .await?;
        Ok(())
    }
}

impl<J: 'static + Serialize + DeserializeOwned + Unpin + Send + Sync> BackendExpose<J>
    for PostgresStorage<J>
{
    type Request = Request<J, Parts<SqlContext>>;
    type Error = SqlError;
    async fn stats(&self) -> Result<Stat, Self::Error> {
        let fetch_query = "SELECT
            COUNT(1) FILTER (WHERE status = 'Pending') AS pending,
            COUNT(1) FILTER (WHERE status = 'Running') AS running,
            COUNT(1) FILTER (WHERE status = 'Done') AS done,
            COUNT(1) FILTER (WHERE status = 'Retry') AS retry,
            COUNT(1) FILTER (WHERE status = 'Failed') AS failed,
            COUNT(1) FILTER (WHERE status = 'Killed') AS killed
        FROM apalis.jobs WHERE job_type = $1";

        let res: (i64, i64, i64, i64, i64, i64) = sqlx::query_as(fetch_query)
            .bind(self.config().namespace())
            .fetch_one(self.pool())
            .await?;

        Ok(Stat {
            pending: res.0.try_into()?,
            running: res.1.try_into()?,
            // 'Failed' rows are surfaced as dead and 'Retry' rows as failed;
            // 'Killed' rows are not included in these stats.
            dead: res.4.try_into()?,
            failed: res.3.try_into()?,
            success: res.2.try_into()?,
        })
    }

    async fn list_jobs(
        &self,
        status: &State,
        page: i32,
    ) -> Result<Vec<Self::Request>, Self::Error> {
        let status = status.to_string();
        let fetch_query = "SELECT * FROM apalis.jobs WHERE status = $1 AND job_type = $2 ORDER BY done_at DESC, run_at DESC LIMIT 10 OFFSET $3";
        let res: Vec<SqlRequest<serde_json::Value>> = sqlx::query_as(fetch_query)
            .bind(status)
            .bind(self.config().namespace())
            .bind(((page - 1) * 10) as i64)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|j| {
                let (req, ctx) = j.req.take_parts();
                let req = JsonCodec::<Value>::decode(req).unwrap();
                Request::new_with_ctx(req, ctx)
            })
            .collect())
    }

    async fn list_workers(&self) -> Result<Vec<Worker<WorkerState>>, Self::Error> {
        let fetch_query =
            "SELECT id, layers, cast(extract(epoch from last_seen) as bigint) FROM apalis.workers WHERE worker_type = $1 ORDER BY last_seen DESC LIMIT 20 OFFSET $2";
        let res: Vec<(String, String, i64)> = sqlx::query_as(fetch_query)
            .bind(self.config().namespace())
            .bind(0)
            .fetch_all(self.pool())
            .await?;
        Ok(res
            .into_iter()
            .map(|w| Worker::new(WorkerId::new(w.0), WorkerState::new::<Self>(w.1)))
            .collect())
    }
}

#[cfg(test)]
mod tests {

    use crate::sql_storage_tests;

    use super::*;
    use apalis_core::test_utils::DummyService;
    use chrono::Utc;
    use email_service::Email;

    use apalis_core::generic_storage_test;
    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    sql_storage_tests!(setup::<Email>, PostgresStorage<Email>, Email);

    async fn setup<T: Serialize + DeserializeOwned>() -> PostgresStorage<T> {
        let db_url = std::env::var("DATABASE_URL").expect("No DATABASE_URL is specified");
        let pool = PgPool::connect(&db_url).await.unwrap();
        PostgresStorage::setup(&pool).await.unwrap();
        let config = Config::new("apalis-tests").set_buffer_size(1);
        let mut storage = PostgresStorage::new_with_config(pool, config);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Removes jobs and workers left over from previous test runs so each
    /// test starts from a clean slate.
    async fn cleanup<T>(storage: &mut PostgresStorage<T>, worker_id: &WorkerId) {
        let mut tx = storage
            .pool
            .acquire()
            .await
            .expect("failed to get connection");
        sqlx::query("DELETE FROM apalis.jobs WHERE job_type = $1 OR lock_by = $2")
            .bind(storage.config.namespace())
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await
            .expect("failed to delete jobs");
        sqlx::query("DELETE FROM apalis.workers WHERE id = $1")
            .bind(worker_id.to_string())
            .execute(&mut *tx)
            .await
            .expect("failed to delete worker");
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@postgres".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn consume_one(
        storage: &mut PostgresStorage<Email>,
        worker_id: &WorkerId,
    ) -> Request<Email, SqlContext> {
        let req = storage.fetch_next(worker_id).await;
        req.unwrap()[0].clone()
    }

    async fn register_worker_at(
        storage: &mut PostgresStorage<Email>,
        last_seen: Timestamp,
    ) -> Worker<Context> {
        let worker_id = WorkerId::new("test-worker");

        storage
            .keep_alive_at::<DummyService>(&worker_id, last_seen)
            .await
            .expect("failed to register worker");
        let wrk = Worker::new(worker_id, Context::default());
        wrk.start();
        wrk
    }

    async fn register_worker(storage: &mut PostgresStorage<Email>) -> Worker<Context> {
        register_worker_at(storage, Utc::now().timestamp()).await
    }

    async fn push_email(storage: &mut PostgresStorage<Email>, email: Email) -> TaskId {
        storage
            .push(email)
            .await
            .expect("failed to push a job")
            .task_id
    }

    async fn get_job(
        storage: &mut PostgresStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, SqlContext> {
        // Give the database a moment to settle before fetching.
        apalis_core::sleep(Duration::from_secs(2)).await;
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, worker.id()).await;
        let job_id = &job.parts.task_id;

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, worker.id()).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(worker.id(), job_id)
            .await
            .expect("failed to kill job");

        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Killed);
        assert!(ctx.done_at().is_some());
    }

    #[tokio::test]
    async fn test_heartbeat_reenqueueorphaned_pulse_last_seen_6min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);
        let five_minutes_ago = Utc::now() - Duration::from_secs(5 * 60);

        let worker = register_worker_at(&mut storage, six_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, worker.id()).await;
        storage
            .reenqueue_orphaned(1, five_minutes_ago)
            .await
            .expect("failed to heartbeat");
        let job_id = &job.parts.task_id;
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;

        assert_eq!(*ctx.status(), State::Pending);
        assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by().is_none());
        assert!(ctx.lock_at().is_none());
        assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        assert_eq!(job.parts.attempt.current(), 0);
    }

    #[tokio::test]
    async fn test_heartbeat_reenqueueorphaned_pulse_last_seen_4min() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let four_minutes_ago = Utc::now() - Duration::from_secs(4 * 60);
        let six_minutes_ago = Utc::now() - Duration::from_secs(6 * 60);

        let worker = register_worker_at(&mut storage, four_minutes_ago.timestamp()).await;

        let job = consume_one(&mut storage, worker.id()).await;
        let ctx = &job.parts.context;

        assert_eq!(*ctx.status(), State::Running);
        storage
            .reenqueue_orphaned(1, six_minutes_ago)
            .await
            .expect("failed to heartbeat");

        let job_id = &job.parts.task_id;
        let job = get_job(&mut storage, job_id).await;
        let ctx = job.parts.context;
        assert_eq!(*ctx.status(), State::Running);
        assert_eq!(*ctx.lock_by(), Some(worker.id().clone()));
        assert!(ctx.lock_at().is_some());
        assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);
    }

    #[tokio::test]
    async fn test_scheduled_request_not_fetched() {
        let mut storage = setup().await;

        // Schedule the job five minutes into the future.
        let run_at = Utc::now().timestamp() + 300;
        let scheduled_req = Request::new(example_email());

        storage
            .schedule_request(scheduled_req, run_at)
            .await
            .expect("failed to schedule request");

        let worker = register_worker(&mut storage).await;
        let jobs = storage
            .fetch_next(worker.id())
            .await
            .expect("failed to fetch next jobs");
        assert!(
            jobs.is_empty(),
            "Scheduled job should not be fetched before its scheduled time"
        );

        let jobs = storage
            .list_jobs(&State::Pending, 1)
            .await
            .expect("failed to list jobs");
        assert_eq!(jobs.len(), 1, "Expected one job to be listed");
    }

    #[tokio::test]
    async fn test_fetch_with_different_job_type_returns_empty() {
        let mut storage_email = setup().await;

        let pool = storage_email.pool().clone();
        let sms_config = Config::new("sms-test").set_buffer_size(1);
        let mut storage_sms: PostgresStorage<Email> =
            PostgresStorage::new_with_config(pool, sms_config);

        push_email(&mut storage_email, example_email()).await;

        let worker_id = WorkerId::new("sms-worker");
        let worker = Worker::new(worker_id, Context::default());
        worker.start();

        let jobs = storage_sms
            .fetch_next(worker.id())
            .await
            .expect("failed to fetch next jobs");
        assert!(
            jobs.is_empty(),
            "A worker with a different job_type should not fetch jobs"
        );

        let worker = register_worker(&mut storage_email).await;
        let jobs = storage_email
            .fetch_next(worker.id())
            .await
            .expect("failed to fetch next jobs");
        assert!(!jobs.is_empty(), "Worker should fetch the job");
    }
}