use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream};
use apalis_core::response::Response;
use apalis_core::service_fn::FromRequest;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use chrono::{DateTime, Utc};
use futures::channel::mpsc::{self, SendError, Sender};
use futures::{select, FutureExt, SinkExt, StreamExt, TryFutureExt};
use log::*;
use redis::aio::ConnectionLike;
use redis::ErrorKind;
use redis::{aio::ConnectionManager, Client, IntoConnectionInfo, RedisError, Script, Value};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::any::type_name;
use std::fmt::{self, Debug};
use std::io;
use std::num::TryFromIntError;
use std::time::SystemTime;
use std::{marker::PhantomData, time::Duration};

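/// Connect to a Redis server and return a [`ConnectionManager`] suitable for
/// building a [`RedisStorage`].
///
/// A minimal usage sketch (assumes a reachable Redis instance and that this
/// crate is in scope as `apalis_redis`):
///
/// ```rust,no_run
/// # async fn example() -> Result<(), redis::RedisError> {
/// let conn = apalis_redis::connect("redis://127.0.0.1/").await?;
/// # Ok(())
/// # }
/// ```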
pub async fn connect<S: IntoConnectionInfo>(redis: S) -> Result<ConnectionManager, RedisError> {
    let client = Client::open(redis.into_connection_info()?)?;
    let conn = client.get_connection_manager().await?;
    Ok(conn)
}

const ACTIVE_JOBS_LIST: &str = "{queue}:active";
const CONSUMERS_SET: &str = "{queue}:consumers";
const DEAD_JOBS_SET: &str = "{queue}:dead";
const DONE_JOBS_SET: &str = "{queue}:done";
const FAILED_JOBS_SET: &str = "{queue}:failed";
const INFLIGHT_JOB_SET: &str = "{queue}:inflight";
const JOB_DATA_HASH: &str = "{queue}:data";
const SCHEDULED_JOBS_SET: &str = "{queue}:scheduled";
const SIGNAL_LIST: &str = "{queue}:signal";

/// The Redis key names used by a queue, derived from its namespace.
#[derive(Clone, Debug)]
pub struct RedisQueueInfo {
    /// Key of the list holding jobs ready to be picked up.
    pub active_jobs_list: String,

    /// Key of the set tracking registered consumers (workers).
    pub consumers_set: String,

    /// Key of the sorted set holding killed jobs.
    pub dead_jobs_set: String,

    /// Key of the sorted set holding completed jobs.
    pub done_jobs_set: String,

    /// Key of the sorted set holding failed jobs.
    pub failed_jobs_set: String,

    /// Key of the set holding jobs currently being processed.
    pub inflight_jobs_set: String,

    /// Key of the hash storing serialized job data.
    pub job_data_hash: String,

    /// Key of the sorted set holding scheduled jobs.
    pub scheduled_jobs_set: String,

    /// Key of the list used to signal waiting workers.
    pub signal_list: String,
}

#[derive(Clone, Debug)]
pub(crate) struct RedisScript {
    done_job: Script,
    enqueue_scheduled: Script,
    get_jobs: Script,
    kill_job: Script,
    push_job: Script,
    reenqueue_active: Script,
    reenqueue_orphaned: Script,
    register_consumer: Script,
    retry_job: Script,
    schedule_job: Script,
    vacuum: Script,
    pub(crate) stats: Script,
}

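/// The context produced by the Redis backend and attached to every request.
///
/// Because it implements [`FromRequest`], a job handler can take it as an
/// extra argument. A sketch (the `Email` job type and `send` handler are
/// illustrative, not part of this crate):
///
/// ```rust,ignore
/// async fn send(job: Email, ctx: RedisContext) {
///     // inspect or log the context here
/// }
/// ```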
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RedisContext {
    max_attempts: usize,
    lock_by: Option<WorkerId>,
    run_at: Option<SystemTime>,
}

impl Default for RedisContext {
    fn default() -> Self {
        Self {
            max_attempts: 5,
            lock_by: None,
            run_at: None,
        }
    }
}

impl<Req> FromRequest<Request<Req, RedisContext>> for RedisContext {
    fn from_request(req: &Request<Req, RedisContext>) -> Result<Self, Error> {
        Ok(req.parts.context.clone())
    }
}

/// Errors that can occur while polling a Redis backend.
#[derive(thiserror::Error, Debug)]
pub enum RedisPollError {
    /// Error during a keep-alive heartbeat.
    #[error("KeepAlive heartbeat encountered an error: `{0}`")]
    KeepAliveError(RedisError),

    /// Error while enqueuing scheduled jobs.
    #[error("EnqueueScheduled heartbeat encountered an error: `{0}`")]
    EnqueueScheduledError(RedisError),

    /// Error while polling for the next jobs.
    #[error("PollNext heartbeat encountered an error: `{0}`")]
    PollNextError(RedisError),

    /// Error while forwarding jobs to the worker.
    #[error("Enqueue for worker consumption encountered an error: `{0}`")]
    EnqueueError(SendError),

    /// Error during an acknowledgement heartbeat.
    #[error("Ack heartbeat encountered an error: `{0}`")]
    AckError(RedisError),

    /// Error while re-enqueuing orphaned jobs.
    #[error("ReenqueueOrphaned heartbeat encountered an error: `{0}`")]
    ReenqueueOrphanedError(RedisError),
}

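/// Configuration for a [`RedisStorage`].
///
/// A sketch of tuning a config with the builder-style setters (the values
/// are illustrative):
///
/// ```rust,ignore
/// let config = Config::default()
///     .set_namespace("email_jobs")
///     .set_buffer_size(20)
///     .set_poll_interval(Duration::from_millis(50));
/// ```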
#[derive(Clone, Debug)]
pub struct Config {
    poll_interval: Duration,
    buffer_size: usize,
    keep_alive: Duration,
    enqueue_scheduled: Duration,
    reenqueue_orphaned_after: Duration,
    namespace: String,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            poll_interval: Duration::from_millis(100),
            buffer_size: 10,
            keep_alive: Duration::from_secs(30),
            enqueue_scheduled: Duration::from_secs(30),
            reenqueue_orphaned_after: Duration::from_secs(300),
            namespace: String::from("apalis_redis"),
        }
    }
}

impl Config {
    /// Get the interval between polls of the queue.
    pub fn get_poll_interval(&self) -> &Duration {
        &self.poll_interval
    }

    /// Get the number of jobs to fetch per poll.
    pub fn get_buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get the interval between keep-alive heartbeats.
    pub fn get_keep_alive(&self) -> &Duration {
        &self.keep_alive
    }

    /// Get the interval between enqueues of scheduled jobs.
    pub fn get_enqueue_scheduled(&self) -> &Duration {
        &self.enqueue_scheduled
    }

    /// Get the namespace used to build this queue's Redis keys.
    pub fn get_namespace(&self) -> &String {
        &self.namespace
    }

    /// Set the interval between polls of the queue.
    pub fn set_poll_interval(mut self, poll_interval: Duration) -> Self {
        self.poll_interval = poll_interval;
        self
    }

    /// Set the number of jobs to fetch per poll.
    pub fn set_buffer_size(mut self, buffer_size: usize) -> Self {
        self.buffer_size = buffer_size;
        self
    }

    /// Set the interval between keep-alive heartbeats.
    pub fn set_keep_alive(mut self, keep_alive: Duration) -> Self {
        self.keep_alive = keep_alive;
        self
    }

    /// Set the interval between enqueues of scheduled jobs.
    pub fn set_enqueue_scheduled(mut self, enqueue_scheduled: Duration) -> Self {
        self.enqueue_scheduled = enqueue_scheduled;
        self
    }

    /// Set the namespace used to build this queue's Redis keys.
    pub fn set_namespace(mut self, namespace: &str) -> Self {
        self.namespace = namespace.to_string();
        self
    }

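    /// The Redis key for the list of pending jobs.
    ///
    /// The `{queue}` placeholder in the key template is replaced by the
    /// configured namespace; a sketch:
    ///
    /// ```rust,ignore
    /// let config = Config::default().set_namespace("email_jobs");
    /// assert_eq!(config.active_jobs_list(), "email_jobs:active");
    /// ```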
    pub fn active_jobs_list(&self) -> String {
        ACTIVE_JOBS_LIST.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the set of registered consumers.
    pub fn consumers_set(&self) -> String {
        CONSUMERS_SET.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the sorted set of killed jobs.
    pub fn dead_jobs_set(&self) -> String {
        DEAD_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the sorted set of completed jobs.
    pub fn done_jobs_set(&self) -> String {
        DONE_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the sorted set of failed jobs.
    pub fn failed_jobs_set(&self) -> String {
        FAILED_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// The Redis key prefix for the per-worker sets of in-flight jobs.
    pub fn inflight_jobs_set(&self) -> String {
        INFLIGHT_JOB_SET.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the hash of serialized job data.
    pub fn job_data_hash(&self) -> String {
        JOB_DATA_HASH.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the sorted set of scheduled jobs.
    pub fn scheduled_jobs_set(&self) -> String {
        SCHEDULED_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// The Redis key for the list used to signal workers.
    pub fn signal_list(&self) -> String {
        SIGNAL_LIST.replace("{queue}", &self.namespace)
    }

    /// Get the duration after which a job whose consumer has not been seen
    /// is considered orphaned.
    pub fn reenqueue_orphaned_after(&self) -> Duration {
        self.reenqueue_orphaned_after
    }

    /// Get a mutable reference to the orphaned-job timeout.
    pub fn reenqueue_orphaned_after_mut(&mut self) -> &mut Duration {
        &mut self.reenqueue_orphaned_after
    }

    /// Set the duration after which a job whose consumer has not been seen
    /// is considered orphaned and may be re-enqueued.
    pub fn set_reenqueue_orphaned_after(mut self, after: Duration) -> Self {
        self.reenqueue_orphaned_after = after;
        self
    }
}

/// A [`Storage`] backed by Redis, using Lua scripts for atomic queue
/// operations.
pub struct RedisStorage<T, Conn = ConnectionManager, C = JsonCodec<Vec<u8>>> {
    conn: Conn,
    job_type: PhantomData<T>,
    pub(super) scripts: RedisScript,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
}

impl<T, Conn, C> fmt::Debug for RedisStorage<T, Conn, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RedisStorage")
            .field("conn", &"ConnectionManager")
            .field("job_type", &std::any::type_name::<T>())
            .field("scripts", &self.scripts)
            .field("config", &self.config)
            .finish()
    }
}

impl<T, Conn: Clone, C> Clone for RedisStorage<T, Conn, C> {
    fn clone(&self) -> Self {
        Self {
            conn: self.conn.clone(),
            job_type: PhantomData,
            scripts: self.scripts.clone(),
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
        }
    }
}

impl<T: Serialize + DeserializeOwned, Conn> RedisStorage<T, Conn, JsonCodec<Vec<u8>>> {
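    /// Build a storage from an existing connection, using the default
    /// [`JsonCodec`] and a namespace derived from the job type's name.
    ///
    /// A construction sketch (assumes a reachable Redis instance; `Email`
    /// is an illustrative job type):
    ///
    /// ```rust,ignore
    /// let conn = apalis_redis::connect("redis://127.0.0.1/").await?;
    /// let storage: RedisStorage<Email> = RedisStorage::new(conn);
    /// ```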
    pub fn new(conn: Conn) -> RedisStorage<T, Conn, JsonCodec<Vec<u8>>> {
        Self::new_with_codec::<JsonCodec<Vec<u8>>>(
            conn,
            Config::default().set_namespace(type_name::<T>()),
        )
    }

    /// Build a storage with a custom [`Config`], using the default
    /// [`JsonCodec`].
    pub fn new_with_config(
        conn: Conn,
        config: Config,
    ) -> RedisStorage<T, Conn, JsonCodec<Vec<u8>>> {
        Self::new_with_codec::<JsonCodec<Vec<u8>>>(conn, config)
    }

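    /// Build a storage with a custom codec implementation.
    ///
    /// A sketch that simply plugs the default codec back in (any type
    /// implementing [`Codec`] would do):
    ///
    /// ```rust,ignore
    /// let storage =
    ///     RedisStorage::<Email, _, _>::new_with_codec::<JsonCodec<Vec<u8>>>(conn, Config::default());
    /// ```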
    pub fn new_with_codec<K>(conn: Conn, config: Config) -> RedisStorage<T, Conn, K>
    where
        K: Codec + Sync + Send + 'static,
    {
        RedisStorage {
            conn,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            codec: PhantomData::<K>,
            scripts: RedisScript {
                done_job: redis::Script::new(include_str!("../lua/done_job.lua")),
                push_job: redis::Script::new(include_str!("../lua/push_job.lua")),
                retry_job: redis::Script::new(include_str!("../lua/retry_job.lua")),
                enqueue_scheduled: redis::Script::new(include_str!(
                    "../lua/enqueue_scheduled_jobs.lua"
                )),
                get_jobs: redis::Script::new(include_str!("../lua/get_jobs.lua")),
                register_consumer: redis::Script::new(include_str!("../lua/register_consumer.lua")),
                kill_job: redis::Script::new(include_str!("../lua/kill_job.lua")),
                reenqueue_active: redis::Script::new(include_str!(
                    "../lua/reenqueue_active_jobs.lua"
                )),
                reenqueue_orphaned: redis::Script::new(include_str!(
                    "../lua/reenqueue_orphaned_jobs.lua"
                )),
                schedule_job: redis::Script::new(include_str!("../lua/schedule_job.lua")),
                vacuum: redis::Script::new(include_str!("../lua/vacuum.lua")),
                stats: redis::Script::new(include_str!("../lua/stats.lua")),
            },
        }
    }

    /// Get a reference to the underlying connection.
    pub fn get_connection(&self) -> &Conn {
        &self.conn
    }

    /// Get a reference to this storage's [`Config`].
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, Conn, C> RedisStorage<T, Conn, C> {
    /// Get the codec marker used by this storage.
    pub fn get_codec(&self) -> &PhantomData<C> {
        &self.codec
    }
}

impl<T, Conn, C> Backend<Request<T, RedisContext>> for RedisStorage<T, Conn, C>
where
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
{
    type Stream = BackendStream<RequestStream<Request<T, RedisContext>>>;

    type Layer = AckLayer<Sender<(RedisContext, Response<Vec<u8>>)>, T, RedisContext, C>;

    type Codec = C;

    fn poll(
        mut self,
        worker: &Worker<apalis_core::worker::Context>,
    ) -> Poller<Self::Stream, Self::Layer> {
        let (mut tx, rx) = mpsc::channel(self.config.buffer_size);
        let (ack, ack_rx) =
            mpsc::channel::<(RedisContext, Response<Vec<u8>>)>(self.config.buffer_size);
        let layer = AckLayer::new(ack);
        let controller = self.controller.clone();
        let config = self.config.clone();
        let stream: RequestStream<Request<T, RedisContext>> = Box::pin(rx);
        let worker = worker.clone();
        let heartbeat = async move {
            // Before polling, rescue any jobs orphaned by a previous run.
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                worker.emit(Event::Error(Box::new(
                    RedisPollError::ReenqueueOrphanedError(e),
                )));
            }

            let mut reenqueue_orphaned_stm =
                apalis_core::interval::interval(config.poll_interval).fuse();

            let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse();

            let mut enqueue_scheduled_stm =
                apalis_core::interval::interval(config.enqueue_scheduled).fuse();

            let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse();

            let mut ack_stream = ack_rx.fuse();

            if let Err(e) = self.keep_alive(worker.id()).await {
                worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e))));
            }

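            // The heartbeat multiplexes five concerns: worker keep-alive,
            // promoting scheduled jobs, fetching new jobs, acknowledging
            // finished ones, and rescuing orphans.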
            loop {
                select! {
                    _ = keep_alive_stm.next() => {
                        if let Err(e) = self.keep_alive(worker.id()).await {
                            worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e))));
                        }
                    }
                    _ = enqueue_scheduled_stm.next() => {
                        if let Err(e) = self.enqueue_scheduled(config.buffer_size).await {
                            worker.emit(Event::Error(Box::new(RedisPollError::EnqueueScheduledError(e))));
                        }
                    }
                    _ = poll_next_stm.next() => {
                        if worker.is_ready() {
                            let res = self.fetch_next(worker.id()).await;
                            match res {
                                Err(e) => {
                                    worker.emit(Event::Error(Box::new(RedisPollError::PollNextError(e))));
                                }
                                Ok(res) => {
                                    for job in res {
                                        if let Err(e) = tx.send(Ok(Some(job))).await {
                                            worker.emit(Event::Error(Box::new(RedisPollError::EnqueueError(e))));
                                        }
                                    }
                                }
                            }
                        } else {
                            continue;
                        }
                    }
                    id_to_ack = ack_stream.next() => {
                        if let Some((ctx, res)) = id_to_ack {
                            if let Err(e) = self.ack(&ctx, &res).await {
                                worker.emit(Event::Error(Box::new(RedisPollError::AckError(e))));
                            }
                        }
                    }
                    _ = reenqueue_orphaned_stm.next() => {
                        let dead_since = Utc::now()
                            - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
                        if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await {
                            worker.emit(Event::Error(Box::new(RedisPollError::ReenqueueOrphanedError(e))));
                        }
                    }
                };
            }
        };
        Poller::new_with_layer(
            BackendStream::new(stream, controller),
            heartbeat.boxed(),
            layer,
        )
    }
}

impl<T, Conn, C, Res> Ack<T, Res, C> for RedisStorage<T, Conn, C>
where
    T: Sync + Send + Serialize + DeserializeOwned + Unpin + 'static,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
    Res: Serialize + Sync + Send + 'static,
{
    type Context = RedisContext;
    type AckError = RedisError;
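    // Acknowledge a finished task: persist its final attempt count, then
    // mark it done on success, kill it on abort (or once retries are
    // exhausted), and otherwise schedule a retry.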
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), RedisError> {
        let mut task = self
            .fetch_by_id(&res.task_id)
            .await?
            .expect("must be a valid task");
        task.parts.attempt = res.attempt.clone();
        self.update(task).await?;

        let inflight_set = format!(
            "{}:{}",
            self.config.inflight_jobs_set(),
            ctx.lock_by.clone().unwrap()
        );

        let now: i64 = Utc::now().timestamp();
        let task_id = res.task_id.to_string();
        match &res.inner {
            Ok(success_res) => {
                let done_job = self.scripts.done_job.clone();
                let done_jobs_set = &self.config.done_jobs_set();
                done_job
                    .key(inflight_set)
                    .key(done_jobs_set)
                    .key(self.config.job_data_hash())
                    .arg(task_id)
                    .arg(now)
                    .arg(C::encode(success_res).map_err(Into::into).unwrap())
                    .invoke_async(&mut self.conn)
                    .await
            }
            Err(e) => match e {
                Error::Abort(e) => {
                    let worker_id = ctx.lock_by.as_ref().unwrap();
                    self.kill(worker_id, &res.task_id, &e).await
                }
                _ => {
                    if ctx.max_attempts > res.attempt.current() {
                        let worker_id = ctx.lock_by.as_ref().unwrap();
                        self.retry(worker_id, &res.task_id).await.map(|_| ())
                    } else {
                        let worker_id = ctx.lock_by.as_ref().unwrap();
                        self.kill(
                            worker_id,
                            &res.task_id,
                            &(Box::new(io::Error::new(
                                io::ErrorKind::Interrupted,
                                format!("Max retries of {} exceeded", ctx.max_attempts),
                            )) as BoxDynError),
                        )
                        .await
                    }
                }
            },
        }
    }
}

impl<T, Conn, C> RedisStorage<T, Conn, C>
where
    T: DeserializeOwned + Send + Unpin + Sync + 'static,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>>,
{
    async fn fetch_next(
        &mut self,
        worker_id: &WorkerId,
    ) -> Result<Vec<Request<T, RedisContext>>, RedisError> {
        let fetch_jobs = self.scripts.get_jobs.clone();
        let consumers_set = self.config.consumers_set();
        let active_jobs_list = self.config.active_jobs_list();
        let job_data_hash = self.config.job_data_hash();
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let signal_list = self.config.signal_list();
        let namespace = &self.config.namespace;

        let result = fetch_jobs
            .key(&consumers_set)
            .key(&active_jobs_list)
            .key(&inflight_set)
            .key(&job_data_hash)
            .key(&signal_list)
            .arg(self.config.buffer_size)
            .arg(&inflight_set)
            .invoke_async::<Vec<Value>>(&mut self.conn)
            .await;

        match result {
            Ok(jobs) => {
                let mut processed = vec![];
                for job in jobs {
                    let bytes = deserialize_job(&job)?;
                    let mut request: Request<T, RedisContext> =
                        C::decode(bytes.clone()).map_err(|e| build_error(&e.into().to_string()))?;
                    request.parts.context.lock_by = Some(worker_id.clone());
                    request.parts.namespace = Some(Namespace(namespace.clone()));
                    processed.push(request)
                }
                Ok(processed)
            }
            Err(e) => {
                warn!("An error occurred during streaming jobs: {e}");
                if matches!(e.kind(), ErrorKind::ResponseError)
                    && e.to_string().contains("consumer not registered script")
                {
                    self.keep_alive(worker_id).await?;
                }
                Err(e)
            }
        }
    }
}

fn build_error(message: &str) -> RedisError {
    RedisError::from(io::Error::new(io::ErrorKind::InvalidData, message))
}

fn deserialize_job(job: &Value) -> Result<&Vec<u8>, RedisError> {
    match job {
        Value::BulkString(bytes) => Ok(bytes),
        Value::Array(val) | Value::Set(val) => val
            .first()
            .and_then(|val| {
                if let Value::BulkString(bytes) = val {
                    Some(bytes)
                } else {
                    None
                }
            })
            .ok_or(build_error("Value::Bulk: Invalid data returned by storage")),
        _ => Err(build_error("unknown result type for next message")),
    }
}

impl<T, Conn: ConnectionLike, C> RedisStorage<T, Conn, C> {
    async fn keep_alive(&mut self, worker_id: &WorkerId) -> Result<(), RedisError> {
        let register_consumer = self.scripts.register_consumer.clone();
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let consumers_set = self.config.consumers_set();

        let now: i64 = Utc::now().timestamp();

        register_consumer
            .key(consumers_set)
            .arg(now)
            .arg(inflight_set)
            .invoke_async(&mut self.conn)
            .await
    }
}

impl<T, Conn, C> Storage for RedisStorage<T, Conn, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
{
    type Job = T;
    type Error = RedisError;
    type Context = RedisContext;

    type Compact = Vec<u8>;

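    // Push a request onto the active jobs list. End users typically go
    // through the `Storage::push` helper, as the tests below do; a sketch
    // (assumes a `RedisStorage<Email>` named `storage`):
    //
    //     storage.push(email).await?;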
    async fn push_request(
        &mut self,
        req: Request<T, RedisContext>,
    ) -> Result<Parts<Self::Context>, RedisError> {
        let conn = &mut self.conn;
        let push_job = self.scripts.push_job.clone();
        let job_data_hash = self.config.job_data_hash();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        let job = C::encode(&req)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        push_job
            .key(job_data_hash)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(req.parts.task_id.to_string())
            .arg(job)
            .invoke_async(conn)
            .await?;
        Ok(req.parts)
    }

    async fn push_raw_request(
        &mut self,
        req: Request<Self::Compact, Self::Context>,
    ) -> Result<Parts<Self::Context>, Self::Error> {
        let conn = &mut self.conn;
        let push_job = self.scripts.push_job.clone();
        let job_data_hash = self.config.job_data_hash();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        let job = C::encode(&req)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        push_job
            .key(job_data_hash)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(req.parts.task_id.to_string())
            .arg(job)
            .invoke_async(conn)
            .await?;
        Ok(req.parts)
    }

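    // Schedule a request for a later unix timestamp (in seconds). A sketch
    // via the `Storage::schedule` helper, assuming a `RedisStorage<Email>`:
    //
    //     let run_at = (Utc::now() + chrono::Duration::minutes(5)).timestamp();
    //     storage.schedule(email, run_at).await?;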
    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, RedisContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, RedisError> {
        let schedule_job = self.scripts.schedule_job.clone();
        let job_data_hash = self.config.job_data_hash();
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let job = C::encode(&req)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        schedule_job
            .key(job_data_hash)
            .key(scheduled_jobs_set)
            .arg(req.parts.task_id.to_string())
            .arg(job)
            .arg(on)
            .invoke_async(&mut self.conn)
            .await?;
        Ok(req.parts)
    }

    async fn len(&mut self) -> Result<i64, RedisError> {
        let pending_jobs: i64 = redis::cmd("LLEN")
            .arg(self.config.active_jobs_list())
            .query_async(&mut self.conn)
            .await?;

        Ok(pending_jobs)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, RedisContext>>, RedisError> {
        let data: Value = redis::cmd("HMGET")
            .arg(self.config.job_data_hash())
            .arg(job_id.to_string())
            .query_async(&mut self.conn)
            .await?;
        let bytes = deserialize_job(&data)?;

        let inner: Request<T, RedisContext> = C::decode(bytes.to_vec())
            .map_err(|e| (ErrorKind::IoError, "Decode error", e.into().to_string()))?;
        Ok(Some(inner))
    }

    async fn update(&mut self, job: Request<T, RedisContext>) -> Result<(), RedisError> {
        let task_id = job.parts.task_id.to_string();
        let bytes = C::encode(&job)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        let _: i64 = redis::cmd("HSET")
            .arg(self.config.job_data_hash())
            .arg(task_id)
            .arg(bytes)
            .query_async(&mut self.conn)
            .await?;
        Ok(())
    }

    async fn reschedule(
        &mut self,
        job: Request<T, RedisContext>,
        wait: Duration,
    ) -> Result<(), RedisError> {
        let schedule_job = self.scripts.schedule_job.clone();
        let job_id = &job.parts.task_id;
        let worker_id = &job.parts.context.lock_by.clone().unwrap();
        let job = C::encode(&job)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        let job_data_hash = self.config.job_data_hash();
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let on: i64 = Utc::now().timestamp();
        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e: TryFromIntError| (ErrorKind::IoError, "Duration error", e.to_string()))?;
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let failed_jobs_set = self.config.failed_jobs_set();
        let _: i64 = redis::cmd("SREM")
            .arg(inflight_set)
            .arg(job_id.to_string())
            .query_async(&mut self.conn)
            .await?;
        let _: i64 = redis::cmd("ZADD")
            .arg(failed_jobs_set)
            .arg(on)
            .arg(job_id.to_string())
            .query_async(&mut self.conn)
            .await?;
        schedule_job
            .key(job_data_hash)
            .key(scheduled_jobs_set)
            .arg(job_id.to_string())
            .arg(job)
            .arg(on + wait)
            .invoke_async(&mut self.conn)
            .await
    }

    async fn is_empty(&mut self) -> Result<bool, RedisError> {
        self.len().map_ok(|res| res == 0).await
    }

    async fn vacuum(&mut self) -> Result<usize, RedisError> {
        let vacuum_script = self.scripts.vacuum.clone();
        vacuum_script
            .key(self.config.done_jobs_set())
            .key(self.config.job_data_hash())
            .invoke_async(&mut self.conn)
            .await
    }
}

impl<T, Conn, C> RedisStorage<T, Conn, C>
where
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
{
    /// Attempt to retry a job, rescheduling it unless its retries are
    /// already exhausted (in which case it is killed).
    pub async fn retry(&mut self, worker_id: &WorkerId, task_id: &TaskId) -> Result<i32, RedisError>
    where
        T: Send + DeserializeOwned + Serialize + Unpin + Sync + 'static,
    {
        let retry_job = self.scripts.retry_job.clone();
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let job_data_hash = self.config.job_data_hash();
        let job_fut = self.fetch_by_id(task_id);
        let now: i64 = Utc::now().timestamp();
        let res = job_fut.await?;
        let conn = &mut self.conn;
        match res {
            Some(job) => {
                let attempt = &job.parts.attempt;
                let max_attempts = &job.parts.context.max_attempts;
                if attempt.current() >= *max_attempts {
                    self.kill(
                        worker_id,
                        task_id,
                        &(Box::new(io::Error::new(
                            io::ErrorKind::Interrupted,
                            format!("Max retries of {} exceeded", max_attempts),
                        )) as BoxDynError),
                    )
                    .await?;
                    return Ok(1);
                }
                let job = C::encode(job)
                    .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;

                retry_job
                    .key(inflight_set)
                    .key(scheduled_jobs_set)
                    .key(job_data_hash)
                    .arg(task_id.to_string())
                    .arg(now)
                    .arg(job)
                    .invoke_async(conn)
                    .await
            }
            None => Err(RedisError::from((ErrorKind::ResponseError, "Id not found"))),
        }
    }

    /// Kill a job, moving it to the dead jobs set along with the error that
    /// caused its demise.
    pub async fn kill(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
        error: &BoxDynError,
    ) -> Result<(), RedisError> {
        let kill_job = self.scripts.kill_job.clone();
        let current_worker_id = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let job_data_hash = self.config.job_data_hash();
        let dead_jobs_set = self.config.dead_jobs_set();
        let now: i64 = Utc::now().timestamp();
        kill_job
            .key(current_worker_id)
            .key(dead_jobs_set)
            .key(job_data_hash)
            .arg(task_id.to_string())
            .arg(now)
            .arg(error.to_string())
            .invoke_async(&mut self.conn)
            .await
    }

    /// Move jobs whose scheduled time has passed onto the active jobs list,
    /// returning the number of jobs enqueued.
    pub async fn enqueue_scheduled(&mut self, count: usize) -> Result<usize, RedisError> {
        let enqueue_jobs = self.scripts.enqueue_scheduled.clone();
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();
        let now: i64 = Utc::now().timestamp();
        enqueue_jobs
            .key(scheduled_jobs_set)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(now)
            .arg(count)
            .invoke_async(&mut self.conn)
            .await
    }

    /// Return the given in-flight jobs to the active jobs list.
    pub async fn reenqueue_active(&mut self, job_ids: Vec<&TaskId>) -> Result<(), RedisError> {
        let reenqueue_active = self.scripts.reenqueue_active.clone();
        let inflight_set: String = self.config.inflight_jobs_set();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        reenqueue_active
            .key(inflight_set)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(
                job_ids
                    .into_iter()
                    .map(|j| j.to_string())
                    .collect::<Vec<String>>(),
            )
            .invoke_async(&mut self.conn)
            .await
    }

    /// Re-enqueue jobs whose consumers have not been seen since `dead_since`,
    /// returning the number of jobs rescued.
    pub async fn reenqueue_orphaned(
        &mut self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<usize, RedisError> {
        let reenqueue_orphaned = self.scripts.reenqueue_orphaned.clone();
        let consumers_set = self.config.consumers_set();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        let dead_since = dead_since.timestamp();

        reenqueue_orphaned
            .key(consumers_set)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(dead_since)
            .arg(count)
            .invoke_async(&mut self.conn)
            .await
    }
}

#[cfg(test)]
mod tests {
    use apalis_core::worker::Context;
    use apalis_core::{generic_storage_test, sleep};
    use email_service::Email;

    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    use super::*;

    async fn setup<T: Serialize + DeserializeOwned>() -> RedisStorage<T> {
        let redis_url = std::env::var("REDIS_URL").expect("No REDIS_URL is specified");
        let conn = connect(redis_url).await.unwrap();
        let config = Config::default()
            .set_namespace("apalis::test")
            .set_enqueue_scheduled(Duration::from_millis(500));
        let mut storage = RedisStorage::new_with_config(conn, config);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Flush the test database so every test starts from a clean slate.
    async fn cleanup<T>(storage: &mut RedisStorage<T>, _worker_id: &WorkerId) {
        let _resp: String = redis::cmd("FLUSHDB")
            .query_async(&mut storage.conn)
            .await
            .expect("failed to Flushdb");
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@postgres".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn consume_one(
        storage: &mut RedisStorage<Email>,
        worker_id: &WorkerId,
    ) -> Request<Email, RedisContext> {
        let stream = storage.fetch_next(worker_id);
        stream
            .await
            .expect("failed to poll job")
            .first()
            .expect("stream is empty")
            .clone()
    }

    async fn register_worker_at(storage: &mut RedisStorage<Email>) -> Worker<Context> {
        let worker = Worker::new(WorkerId::new("test-worker"), Context::default());
        worker.start();
        storage
            .keep_alive(&worker.id())
            .await
            .expect("failed to register worker");
        worker
    }

    async fn register_worker(storage: &mut RedisStorage<Email>) -> Worker<Context> {
        register_worker_at(storage).await
    }

    async fn push_email(storage: &mut RedisStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut RedisStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, RedisContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let _job = consume_one(&mut storage, &worker.id()).await;
    }

    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let ctx = &job.parts.context;
        let res = 42usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job.parts.task_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let _job = get_job(&mut storage, &job.parts.task_id).await;
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(
                &worker.id(),
                job_id,
                &(Box::new(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "Some unforeseen error occurred",
                )) as BoxDynError),
            )
            .await
            .expect("failed to kill job");

        let _job = get_job(&mut storage, job_id).await;
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_1sec() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker_at(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        sleep(Duration::from_millis(1000)).await;
        let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(1)).unwrap();
        let res = storage
            .reenqueue_orphaned(1, dead_since)
            .await
            .expect("failed to reenqueue_orphaned");
        assert_eq!(res, 1);

        let job = get_job(&mut storage, &job.parts.task_id).await;
        let ctx = &job.parts.context;
        assert!(ctx.lock_by.is_none());
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_5sec() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker_at(&mut storage).await;
        sleep(Duration::from_millis(1100)).await;
        let job = consume_one(&mut storage, &worker.id()).await;
        let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(5)).unwrap();
        let res = storage
            .reenqueue_orphaned(1, dead_since)
            .await
            .expect("failed to reenqueue_orphaned");
        assert_eq!(res, 0);

        let job = get_job(&mut storage, &job.parts.task_id).await;
        let _ctx = &job.parts.context;
        assert_eq!(job.parts.attempt.current(), 0);
    }

    #[tokio::test]
    async fn test_stats() {
        use apalis_core::backend::BackendExpose;

        let mut storage = setup().await;
        let stats = storage.stats().await.expect("failed to get stats");
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.running, 0);
        push_email(&mut storage, example_email()).await;
        let stats = storage.stats().await.expect("failed to get stats");
        assert_eq!(stats.pending, 1);

        let worker = register_worker(&mut storage).await;

        let _job = consume_one(&mut storage, &worker.id()).await;

        let stats = storage.stats().await.expect("failed to get stats");
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.running, 1);
    }
}