apalis_redis/storage.rs

use apalis_core::codec::json::JsonCodec;
use apalis_core::error::{BoxDynError, Error};
use apalis_core::layers::{Ack, AckLayer};
use apalis_core::poller::controller::Controller;
use apalis_core::poller::stream::BackendStream;
use apalis_core::poller::Poller;
use apalis_core::request::{Parts, Request, RequestStream};
use apalis_core::response::Response;
use apalis_core::service_fn::FromRequest;
use apalis_core::storage::Storage;
use apalis_core::task::namespace::Namespace;
use apalis_core::task::task_id::TaskId;
use apalis_core::worker::{Event, Worker, WorkerId};
use apalis_core::{backend::Backend, codec::Codec};
use chrono::{DateTime, Utc};
use futures::channel::mpsc::{self, SendError, Sender};
use futures::{select, FutureExt, SinkExt, StreamExt, TryFutureExt};
use log::*;
use redis::aio::ConnectionLike;
use redis::ErrorKind;
use redis::{aio::ConnectionManager, Client, IntoConnectionInfo, RedisError, Script, Value};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::any::type_name;
use std::fmt::{self, Debug};
use std::io;
use std::num::TryFromIntError;
use std::time::SystemTime;
use std::{marker::PhantomData, time::Duration};

/// Shorthand to create a client and connect
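///
/// # Example
///
/// A minimal sketch, assuming a Redis server is reachable at the placeholder
/// URL below:
///
/// ```ignore
/// // The returned `ConnectionManager` can be passed to `RedisStorage::new`.
/// let conn = apalis_redis::connect("redis://127.0.0.1/").await?;
/// ```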
pub async fn connect<S: IntoConnectionInfo>(redis: S) -> Result<ConnectionManager, RedisError> {
    let client = Client::open(redis.into_connection_info()?)?;
    let conn = client.get_connection_manager().await?;
    Ok(conn)
}

const ACTIVE_JOBS_LIST: &str = "{queue}:active";
const CONSUMERS_SET: &str = "{queue}:consumers";
const DEAD_JOBS_SET: &str = "{queue}:dead";
const DONE_JOBS_SET: &str = "{queue}:done";
const FAILED_JOBS_SET: &str = "{queue}:failed";
const INFLIGHT_JOB_SET: &str = "{queue}:inflight";
const JOB_DATA_HASH: &str = "{queue}:data";
const SCHEDULED_JOBS_SET: &str = "{queue}:scheduled";
const SIGNAL_LIST: &str = "{queue}:signal";

/// Represents Redis key names for various components of the RedisStorage.
///
/// This struct defines keys used in Redis to manage jobs and their lifecycle in the storage.
#[derive(Clone, Debug)]
pub struct RedisQueueInfo {
    /// Key for the list of currently active jobs.
    pub active_jobs_list: String,

    /// Key for the set of active consumers.
    pub consumers_set: String,

    /// Key for the set of jobs that are no longer retryable.
    pub dead_jobs_set: String,

    /// Key for the set of jobs that have completed successfully.
    pub done_jobs_set: String,

    /// Key for the set of jobs that have failed.
    pub failed_jobs_set: String,

    /// Key for the set of jobs that are currently being processed.
    pub inflight_jobs_set: String,

    /// Key for the hash storing data for each job.
    pub job_data_hash: String,

    /// Key for the set of jobs scheduled for future execution.
    pub scheduled_jobs_set: String,

    /// Key for the list used for signaling and communication between consumers and producers.
    pub signal_list: String,
}

#[derive(Clone, Debug)]
pub(crate) struct RedisScript {
    done_job: Script,
    enqueue_scheduled: Script,
    get_jobs: Script,
    kill_job: Script,
    push_job: Script,
    reenqueue_active: Script,
    reenqueue_orphaned: Script,
    register_consumer: Script,
    retry_job: Script,
    schedule_job: Script,
    vacuum: Script,
    pub(crate) stats: Script,
}

/// The context for a Redis storage job
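///
/// # Example
///
/// `RedisContext` implements [`FromRequest`], so it can be taken as an extra
/// argument in a job handler. A sketch, assuming a worker built on top of
/// [`RedisStorage`] and a hypothetical `Email` job type:
///
/// ```ignore
/// async fn send_email(job: Email, ctx: RedisContext) {
///     // Inspect storage-level metadata attached to the request.
///     println!("redis context: {ctx:?}");
/// }
/// ```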
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RedisContext {
    max_attempts: usize,
    lock_by: Option<WorkerId>,
    run_at: Option<SystemTime>,
}

impl Default for RedisContext {
    fn default() -> Self {
        Self {
            max_attempts: 5,
            lock_by: None,
            run_at: None,
        }
    }
}

impl<Req> FromRequest<Request<Req, RedisContext>> for RedisContext {
    fn from_request(req: &Request<Req, RedisContext>) -> Result<Self, Error> {
        Ok(req.parts.context.clone())
    }
}

/// Errors that can occur while polling a Redis backend.
#[derive(thiserror::Error, Debug)]
pub enum RedisPollError {
    /// Error during a keep-alive heartbeat.
    #[error("KeepAlive heartbeat encountered an error: `{0}`")]
    KeepAliveError(RedisError),

    /// Error during enqueueing scheduled tasks.
    #[error("EnqueueScheduled heartbeat encountered an error: `{0}`")]
    EnqueueScheduledError(RedisError),

    /// Error during polling for the next task or message.
    #[error("PollNext heartbeat encountered an error: `{0}`")]
    PollNextError(RedisError),

    /// Error during enqueueing tasks for worker consumption.
    #[error("Enqueue for worker consumption encountered an error: `{0}`")]
    EnqueueError(SendError),

    /// Error during acknowledgment of tasks.
    #[error("Ack heartbeat encountered an error: `{0}`")]
    AckError(RedisError),

    /// Error during re-enqueuing orphaned tasks.
    #[error("ReenqueueOrphaned heartbeat encountered an error: `{0}`")]
    ReenqueueOrphanedError(RedisError),
}

/// Config for a [RedisStorage]
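///
/// # Example
///
/// A sketch of tuning the defaults with the builder-style setters (the
/// namespace and values below are placeholders):
///
/// ```ignore
/// use std::time::Duration;
///
/// let config = Config::default()
///     .set_namespace("email-queue")
///     .set_buffer_size(50)
///     .set_poll_interval(Duration::from_millis(50))
///     .set_reenqueue_orphaned_after(Duration::from_secs(120));
/// ```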
#[derive(Clone, Debug)]
pub struct Config {
    poll_interval: Duration,
    buffer_size: usize,
    keep_alive: Duration,
    enqueue_scheduled: Duration,
    reenqueue_orphaned_after: Duration,
    namespace: String,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            poll_interval: Duration::from_millis(100),
            buffer_size: 10,
            keep_alive: Duration::from_secs(30),
            enqueue_scheduled: Duration::from_secs(30),
            reenqueue_orphaned_after: Duration::from_secs(300),
            namespace: String::from("apalis_redis"),
        }
    }
}

impl Config {
    /// Get the polling interval
    pub fn get_poll_interval(&self) -> &Duration {
        &self.poll_interval
    }

    /// Get the number of jobs to fetch per poll (the buffer size)
    pub fn get_buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get the keep-alive interval
    pub fn get_keep_alive(&self) -> &Duration {
        &self.keep_alive
    }

    /// Get the enqueue-scheduled interval
    pub fn get_enqueue_scheduled(&self) -> &Duration {
        &self.enqueue_scheduled
    }

    /// Get the namespace
    pub fn get_namespace(&self) -> &String {
        &self.namespace
    }

    /// Set the polling interval
    pub fn set_poll_interval(mut self, poll_interval: Duration) -> Self {
        self.poll_interval = poll_interval;
        self
    }

    /// Set the buffer size
    pub fn set_buffer_size(mut self, buffer_size: usize) -> Self {
        self.buffer_size = buffer_size;
        self
    }

    /// Set the keep-alive interval
    pub fn set_keep_alive(mut self, keep_alive: Duration) -> Self {
        self.keep_alive = keep_alive;
        self
    }

    /// Set the enqueue-scheduled interval
    pub fn set_enqueue_scheduled(mut self, enqueue_scheduled: Duration) -> Self {
        self.enqueue_scheduled = enqueue_scheduled;
        self
    }

    /// Set the namespace for the storage
    pub fn set_namespace(mut self, namespace: &str) -> Self {
        self.namespace = namespace.to_string();
        self
    }

    /// Returns the Redis key for the list of pending jobs associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the pending jobs list.
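    ///
    /// # Example
    ///
    /// A sketch, assuming the default `apalis_redis` namespace:
    ///
    /// ```ignore
    /// assert_eq!(Config::default().active_jobs_list(), "apalis_redis:active");
    /// ```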
    pub fn active_jobs_list(&self) -> String {
        ACTIVE_JOBS_LIST.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the set of consumers associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the consumers set.
    pub fn consumers_set(&self) -> String {
        CONSUMERS_SET.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the set of dead jobs associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the dead jobs set.
    pub fn dead_jobs_set(&self) -> String {
        DEAD_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the set of done jobs associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the done jobs set.
    pub fn done_jobs_set(&self) -> String {
        DONE_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the set of failed jobs associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the failed jobs set.
    pub fn failed_jobs_set(&self) -> String {
        FAILED_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the set of inflight jobs associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the inflight jobs set.
    pub fn inflight_jobs_set(&self) -> String {
        INFLIGHT_JOB_SET.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the hash storing job data associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the job data hash.
    pub fn job_data_hash(&self) -> String {
        JOB_DATA_HASH.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the set of scheduled jobs associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the scheduled jobs set.
    pub fn scheduled_jobs_set(&self) -> String {
        SCHEDULED_JOBS_SET.replace("{queue}", &self.namespace)
    }

    /// Returns the Redis key for the list of signals associated with the queue.
    /// The key is dynamically generated using the namespace of the queue.
    ///
    /// # Returns
    /// A `String` representing the Redis key for the signal list.
    pub fn signal_list(&self) -> String {
        SIGNAL_LIST.replace("{queue}", &self.namespace)
    }

    /// Gets the reenqueue_orphaned_after duration.
    pub fn reenqueue_orphaned_after(&self) -> Duration {
        self.reenqueue_orphaned_after
    }

    /// Gets a mutable reference to reenqueue_orphaned_after.
    pub fn reenqueue_orphaned_after_mut(&mut self) -> &mut Duration {
        &mut self.reenqueue_orphaned_after
    }

    /// Occasionally some workers die or abandon jobs because of panics.
    /// This is how long a task may be held by a worker before it is returned to the queue.
    ///
    /// Defaults to 5 minutes
    pub fn set_reenqueue_orphaned_after(mut self, after: Duration) -> Self {
        self.reenqueue_orphaned_after = after;
        self
    }
}

/// Represents a [Storage] backed by Redis.
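///
/// # Example
///
/// A rough sketch of creating the storage and pushing a job, assuming a
/// reachable Redis instance and a hypothetical serde-serializable `Reminder`
/// job type:
///
/// ```ignore
/// use apalis_core::storage::Storage;
///
/// let conn = apalis_redis::connect("redis://127.0.0.1/").await?;
/// let mut storage: RedisStorage<Reminder> = RedisStorage::new(conn);
///
/// // Enqueue a job for immediate processing.
/// storage.push(Reminder { to: "user@example.com".into() }).await?;
/// assert_eq!(storage.len().await?, 1);
/// ```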
pub struct RedisStorage<T, Conn = ConnectionManager, C = JsonCodec<Vec<u8>>> {
    conn: Conn,
    job_type: PhantomData<T>,
    pub(super) scripts: RedisScript,
    controller: Controller,
    config: Config,
    codec: PhantomData<C>,
}

impl<T, Conn, C> fmt::Debug for RedisStorage<T, Conn, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RedisStorage")
            .field("conn", &"ConnectionManager")
            .field("job_type", &std::any::type_name::<T>())
            .field("scripts", &self.scripts)
            .field("config", &self.config)
            .finish()
    }
}

impl<T, Conn: Clone, C> Clone for RedisStorage<T, Conn, C> {
    fn clone(&self) -> Self {
        Self {
            conn: self.conn.clone(),
            job_type: PhantomData,
            scripts: self.scripts.clone(),
            controller: self.controller.clone(),
            config: self.config.clone(),
            codec: self.codec,
        }
    }
}

impl<T: Serialize + DeserializeOwned, Conn> RedisStorage<T, Conn, JsonCodec<Vec<u8>>> {
    /// Create a new storage instance from an existing connection
    pub fn new(conn: Conn) -> RedisStorage<T, Conn, JsonCodec<Vec<u8>>> {
        Self::new_with_codec::<JsonCodec<Vec<u8>>>(
            conn,
            Config::default().set_namespace(type_name::<T>()),
        )
    }

    /// Create a storage instance with a custom config
    pub fn new_with_config(
        conn: Conn,
        config: Config,
    ) -> RedisStorage<T, Conn, JsonCodec<Vec<u8>>> {
        Self::new_with_codec::<JsonCodec<Vec<u8>>>(conn, config)
    }

    /// Create a storage instance with a custom config and codec
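    ///
    /// A sketch, passing the built-in [`JsonCodec`] explicitly (any codec with
    /// `Compact = Vec<u8>` should work the same way; `Email` is a placeholder
    /// job type):
    ///
    /// ```ignore
    /// let storage: RedisStorage<Email, _, JsonCodec<Vec<u8>>> =
    ///     RedisStorage::new_with_codec(conn, Config::default());
    /// ```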
    pub fn new_with_codec<K>(conn: Conn, config: Config) -> RedisStorage<T, Conn, K>
    where
        K: Codec + Sync + Send + 'static,
    {
        RedisStorage {
            conn,
            job_type: PhantomData,
            controller: Controller::new(),
            config,
            codec: PhantomData::<K>,
            scripts: RedisScript {
                done_job: redis::Script::new(include_str!("../lua/done_job.lua")),
                push_job: redis::Script::new(include_str!("../lua/push_job.lua")),
                retry_job: redis::Script::new(include_str!("../lua/retry_job.lua")),
                enqueue_scheduled: redis::Script::new(include_str!(
                    "../lua/enqueue_scheduled_jobs.lua"
                )),
                get_jobs: redis::Script::new(include_str!("../lua/get_jobs.lua")),
                register_consumer: redis::Script::new(include_str!("../lua/register_consumer.lua")),
                kill_job: redis::Script::new(include_str!("../lua/kill_job.lua")),
                reenqueue_active: redis::Script::new(include_str!(
                    "../lua/reenqueue_active_jobs.lua"
                )),
                reenqueue_orphaned: redis::Script::new(include_str!(
                    "../lua/reenqueue_orphaned_jobs.lua"
                )),
                schedule_job: redis::Script::new(include_str!("../lua/schedule_job.lua")),
                vacuum: redis::Script::new(include_str!("../lua/vacuum.lua")),
                stats: redis::Script::new(include_str!("../lua/stats.lua")),
            },
        }
    }

    /// Get the current connection
    pub fn get_connection(&self) -> &Conn {
        &self.conn
    }

    /// Get the config used by the storage
    pub fn get_config(&self) -> &Config {
        &self.config
    }
}

impl<T, Conn, C> RedisStorage<T, Conn, C> {
    /// Get the underlying codec details
    pub fn get_codec(&self) -> &PhantomData<C> {
        &self.codec
    }
}

impl<T, Conn, C> Backend<Request<T, RedisContext>> for RedisStorage<T, Conn, C>
where
    T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
{
    type Stream = BackendStream<RequestStream<Request<T, RedisContext>>>;

    type Layer = AckLayer<Sender<(RedisContext, Response<Vec<u8>>)>, T, RedisContext, C>;

    type Codec = C;

    fn poll(
        mut self,
        worker: &Worker<apalis_core::worker::Context>,
    ) -> Poller<Self::Stream, Self::Layer> {
        let (mut tx, rx) = mpsc::channel(self.config.buffer_size);
        let (ack, ack_rx) =
            mpsc::channel::<(RedisContext, Response<Vec<u8>>)>(self.config.buffer_size);
        let layer = AckLayer::new(ack);
        let controller = self.controller.clone();
        let config = self.config.clone();
        let stream: RequestStream<Request<T, RedisContext>> = Box::pin(rx);
        let worker = worker.clone();
        let heartbeat = async move {
            // Re-enqueue any jobs that belonged to this worker in case it previously died
            if let Err(e) = self
                .reenqueue_orphaned((config.buffer_size * 10) as i32, Utc::now())
                .await
            {
                worker.emit(Event::Error(Box::new(
                    RedisPollError::ReenqueueOrphanedError(e),
                )));
            }

            let mut reenqueue_orphaned_stm =
                apalis_core::interval::interval(config.poll_interval).fuse();

            let mut keep_alive_stm = apalis_core::interval::interval(config.keep_alive).fuse();

            let mut enqueue_scheduled_stm =
                apalis_core::interval::interval(config.enqueue_scheduled).fuse();

            let mut poll_next_stm = apalis_core::interval::interval(config.poll_interval).fuse();

            let mut ack_stream = ack_rx.fuse();

            if let Err(e) = self.keep_alive(worker.id()).await {
                worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e))));
            }

            loop {
                select! {
                    _ = keep_alive_stm.next() => {
                        if let Err(e) = self.keep_alive(worker.id()).await {
                            worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e))));
                        }
                    }
                    _ = enqueue_scheduled_stm.next() => {
                        if let Err(e) = self.enqueue_scheduled(config.buffer_size).await {
                            worker.emit(Event::Error(Box::new(RedisPollError::EnqueueScheduledError(e))));
                        }
                    }
                    _ = poll_next_stm.next() => {
                        if worker.is_ready() {
                            let res = self.fetch_next(worker.id()).await;
                            match res {
                                Err(e) => {
                                    worker.emit(Event::Error(Box::new(RedisPollError::PollNextError(e))));
                                }
                                Ok(res) => {
                                    for job in res {
                                        if let Err(e) = tx.send(Ok(Some(job))).await {
                                            worker.emit(Event::Error(Box::new(RedisPollError::EnqueueError(e))));
                                        }
                                    }
                                }
                            }
                        } else {
                            continue;
                        }
                    }
                    id_to_ack = ack_stream.next() => {
                        if let Some((ctx, res)) = id_to_ack {
                            if let Err(e) = self.ack(&ctx, &res).await {
                                worker.emit(Event::Error(Box::new(RedisPollError::AckError(e))));
                            }
                        }
                    }
                    _ = reenqueue_orphaned_stm.next() => {
                        let dead_since = Utc::now()
                            - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap();
                        if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await {
                            worker.emit(Event::Error(Box::new(RedisPollError::ReenqueueOrphanedError(e))));
                        }
                    }
                };
            }
        };
        Poller::new_with_layer(
            BackendStream::new(stream, controller),
            heartbeat.boxed(),
            layer,
        )
    }
}

impl<T, Conn, C, Res> Ack<T, Res, C> for RedisStorage<T, Conn, C>
where
    T: Sync + Send + Serialize + DeserializeOwned + Unpin + 'static,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
    Res: Serialize + Sync + Send + 'static,
{
    type Context = RedisContext;
    type AckError = RedisError;
    async fn ack(&mut self, ctx: &Self::Context, res: &Response<Res>) -> Result<(), RedisError> {
        // Let's update the number of attempts
        // TODO: move attempts to its own key
        let mut task = self
            .fetch_by_id(&res.task_id)
            .await?
            .expect("must be a valid task");
        task.parts.attempt = res.attempt.clone();
        self.update(task).await?;
        // End of expensive update

        let inflight_set = format!(
            "{}:{}",
            self.config.inflight_jobs_set(),
            ctx.lock_by.clone().unwrap()
        );

        let now: i64 = Utc::now().timestamp();
        let task_id = res.task_id.to_string();
        match &res.inner {
            Ok(success_res) => {
                let done_job = self.scripts.done_job.clone();
                let done_jobs_set = &self.config.done_jobs_set();
                done_job
                    .key(inflight_set)
                    .key(done_jobs_set)
                    .key(self.config.job_data_hash())
                    .arg(task_id)
                    .arg(now)
                    .arg(C::encode(success_res).map_err(Into::into).unwrap())
                    .invoke_async(&mut self.conn)
                    .await
            }
            Err(e) => match e {
                Error::Abort(e) => {
                    let worker_id = ctx.lock_by.as_ref().unwrap();
                    self.kill(worker_id, &res.task_id, &e).await
                }
                _ => {
                    if ctx.max_attempts > res.attempt.current() {
                        let worker_id = ctx.lock_by.as_ref().unwrap();
                        self.retry(worker_id, &res.task_id).await.map(|_| ())
                    } else {
                        let worker_id = ctx.lock_by.as_ref().unwrap();

                        self.kill(
                            worker_id,
                            &res.task_id,
                            &(Box::new(io::Error::new(
                                io::ErrorKind::Interrupted,
                                format!("Max retries of {} exceeded", ctx.max_attempts),
                            )) as BoxDynError),
                        )
                        .await
                    }
                }
            },
        }
    }
}

impl<T, Conn, C> RedisStorage<T, Conn, C>
where
    T: DeserializeOwned + Send + Unpin + Sync + 'static,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>>,
{
    async fn fetch_next(
        &mut self,
        worker_id: &WorkerId,
    ) -> Result<Vec<Request<T, RedisContext>>, RedisError> {
        let fetch_jobs = self.scripts.get_jobs.clone();
        let consumers_set = self.config.consumers_set();
        let active_jobs_list = self.config.active_jobs_list();
        let job_data_hash = self.config.job_data_hash();
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let signal_list = self.config.signal_list();
        let namespace = &self.config.namespace;

        let result = fetch_jobs
            .key(&consumers_set)
            .key(&active_jobs_list)
            .key(&inflight_set)
            .key(&job_data_hash)
            .key(&signal_list)
            .arg(self.config.buffer_size) // Number of jobs to fetch
            .arg(&inflight_set)
            .invoke_async::<Vec<Value>>(&mut self.conn)
            .await;

        match result {
            Ok(jobs) => {
                let mut processed = vec![];
                for job in jobs {
                    let bytes = deserialize_job(&job)?;
                    let mut request: Request<T, RedisContext> =
                        C::decode(bytes.clone()).map_err(|e| build_error(&e.into().to_string()))?;
                    request.parts.context.lock_by = Some(worker_id.clone());
                    request.parts.namespace = Some(Namespace(namespace.clone()));
                    processed.push(request)
                }
                Ok(processed)
            }
            Err(e) => {
                warn!("An error occurred during streaming jobs: {e}");
                if matches!(e.kind(), ErrorKind::ResponseError)
                    && e.to_string().contains("consumer not registered script")
                {
                    self.keep_alive(worker_id).await?;
                }
                Err(e)
            }
        }
    }
}

fn build_error(message: &str) -> RedisError {
    RedisError::from(io::Error::new(io::ErrorKind::InvalidData, message))
}

fn deserialize_job(job: &Value) -> Result<&Vec<u8>, RedisError> {
    match job {
        Value::BulkString(bytes) => Ok(bytes),
        Value::Array(val) | Value::Set(val) => val
            .first()
            .and_then(|val| {
                if let Value::BulkString(bytes) = val {
                    Some(bytes)
                } else {
                    None
                }
            })
            .ok_or(build_error("Value::Bulk: Invalid data returned by storage")),
        _ => Err(build_error("unknown result type for next message")),
    }
}

impl<T, Conn: ConnectionLike, C> RedisStorage<T, Conn, C> {
    async fn keep_alive(&mut self, worker_id: &WorkerId) -> Result<(), RedisError> {
        let register_consumer = self.scripts.register_consumer.clone();
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let consumers_set = self.config.consumers_set();

        let now: i64 = Utc::now().timestamp();

        register_consumer
            .key(consumers_set)
            .arg(now)
            .arg(inflight_set)
            .invoke_async(&mut self.conn)
            .await
    }
}

impl<T, Conn, C> Storage for RedisStorage<T, Conn, C>
where
    T: Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
{
    type Job = T;
    type Error = RedisError;
    type Context = RedisContext;

    type Compact = Vec<u8>;

    async fn push_request(
        &mut self,
        req: Request<T, RedisContext>,
    ) -> Result<Parts<Self::Context>, RedisError> {
        let conn = &mut self.conn;
        let push_job = self.scripts.push_job.clone();
        let job_data_hash = self.config.job_data_hash();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        let job = C::encode(&req)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        push_job
            .key(job_data_hash)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(req.parts.task_id.to_string())
            .arg(job)
            .invoke_async(conn)
            .await?;
        Ok(req.parts)
    }

    async fn push_raw_request(
        &mut self,
        req: Request<Self::Compact, Self::Context>,
    ) -> Result<Parts<Self::Context>, Self::Error> {
        let conn = &mut self.conn;
        let push_job = self.scripts.push_job.clone();
        let job_data_hash = self.config.job_data_hash();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        let job = C::encode(&req)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        push_job
            .key(job_data_hash)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(req.parts.task_id.to_string())
            .arg(job)
            .invoke_async(conn)
            .await?;
        Ok(req.parts)
    }

    async fn schedule_request(
        &mut self,
        req: Request<Self::Job, RedisContext>,
        on: i64,
    ) -> Result<Parts<Self::Context>, RedisError> {
        let schedule_job = self.scripts.schedule_job.clone();
        let job_data_hash = self.config.job_data_hash();
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let job = C::encode(&req)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        schedule_job
            .key(job_data_hash)
            .key(scheduled_jobs_set)
            .arg(req.parts.task_id.to_string())
            .arg(job)
            .arg(on)
            .invoke_async(&mut self.conn)
            .await?;
        Ok(req.parts)
    }

    async fn len(&mut self) -> Result<i64, RedisError> {
        let pending_jobs: i64 = redis::cmd("LLEN")
            .arg(self.config.active_jobs_list())
            .query_async(&mut self.conn)
            .await?;

        Ok(pending_jobs)
    }

    async fn fetch_by_id(
        &mut self,
        job_id: &TaskId,
    ) -> Result<Option<Request<Self::Job, RedisContext>>, RedisError> {
        let data: Value = redis::cmd("HMGET")
            .arg(self.config.job_data_hash())
            .arg(job_id.to_string())
            .query_async(&mut self.conn)
            .await?;
        let bytes = deserialize_job(&data)?;

        let inner: Request<T, RedisContext> = C::decode(bytes.to_vec())
            .map_err(|e| (ErrorKind::IoError, "Decode error", e.into().to_string()))?;
        Ok(Some(inner))
    }

    async fn update(&mut self, job: Request<T, RedisContext>) -> Result<(), RedisError> {
        let task_id = job.parts.task_id.to_string();
        let bytes = C::encode(&job)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        let _: i64 = redis::cmd("HSET")
            .arg(self.config.job_data_hash())
            .arg(task_id)
            .arg(bytes)
            .query_async(&mut self.conn)
            .await?;
        Ok(())
    }

    async fn reschedule(
        &mut self,
        job: Request<T, RedisContext>,
        wait: Duration,
    ) -> Result<(), RedisError> {
        let schedule_job = self.scripts.schedule_job.clone();
        let job_id = &job.parts.task_id;
        let worker_id = &job.parts.context.lock_by.clone().unwrap();
        let job = C::encode(&job)
            .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;
        let job_data_hash = self.config.job_data_hash();
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let on: i64 = Utc::now().timestamp();
        let wait: i64 = wait
            .as_secs()
            .try_into()
            .map_err(|e: TryFromIntError| (ErrorKind::IoError, "Duration error", e.to_string()))?;
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let failed_jobs_set = self.config.failed_jobs_set();
        redis::cmd("SREM")
            .arg(inflight_set)
            .arg(job_id.to_string())
            .query_async(&mut self.conn)
            .await?;
        redis::cmd("ZADD")
            .arg(failed_jobs_set)
            .arg(on)
            .arg(job_id.to_string())
            .query_async(&mut self.conn)
            .await?;
        schedule_job
            .key(job_data_hash)
            .key(scheduled_jobs_set)
            .arg(job_id.to_string())
            .arg(job)
            .arg(on + wait)
            .invoke_async(&mut self.conn)
            .await
    }

    async fn is_empty(&mut self) -> Result<bool, RedisError> {
        self.len().map_ok(|res| res == 0).await
    }

    async fn vacuum(&mut self) -> Result<usize, RedisError> {
        let vacuum_script = self.scripts.vacuum.clone();
        vacuum_script
            .key(self.config.done_jobs_set())
            .key(self.config.job_data_hash())
            .invoke_async(&mut self.conn)
            .await
    }
}

impl<T, Conn, C> RedisStorage<T, Conn, C>
where
    Conn: ConnectionLike + Send + Sync + 'static,
    C: Codec<Compact = Vec<u8>> + Send + 'static,
{
    /// Attempt to retry a job
    pub async fn retry(&mut self, worker_id: &WorkerId, task_id: &TaskId) -> Result<i32, RedisError>
    where
        T: Send + DeserializeOwned + Serialize + Unpin + Sync + 'static,
    {
        let retry_job = self.scripts.retry_job.clone();
        let inflight_set = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let job_data_hash = self.config.job_data_hash();
        let job_fut = self.fetch_by_id(task_id);
        let now: i64 = Utc::now().timestamp();
        let res = job_fut.await?;
        let conn = &mut self.conn;
        match res {
            Some(job) => {
                let attempt = &job.parts.attempt;
                let max_attempts = &job.parts.context.max_attempts;
                if attempt.current() >= *max_attempts {
                    self.kill(
                        worker_id,
                        task_id,
                        &(Box::new(io::Error::new(
                            io::ErrorKind::Interrupted,
                            format!("Max retries of {} exceeded", max_attempts),
                        )) as BoxDynError),
                    )
                    .await?;
                    return Ok(1);
                }
                let job = C::encode(job)
                    .map_err(|e| (ErrorKind::IoError, "Encode error", e.into().to_string()))?;

                retry_job
                    .key(inflight_set)
                    .key(scheduled_jobs_set)
                    .key(job_data_hash)
                    .arg(task_id.to_string())
                    .arg(now)
                    .arg(job)
                    .invoke_async(conn)
                    .await
            }
            None => Err(RedisError::from((ErrorKind::ResponseError, "Id not found"))),
        }
    }

    /// Attempt to kill a job
    pub async fn kill(
        &mut self,
        worker_id: &WorkerId,
        task_id: &TaskId,
        error: &BoxDynError,
    ) -> Result<(), RedisError> {
        let kill_job = self.scripts.kill_job.clone();
        let current_worker_id = format!("{}:{}", self.config.inflight_jobs_set(), worker_id);
        let job_data_hash = self.config.job_data_hash();
        let dead_jobs_set = self.config.dead_jobs_set();
        let now: i64 = Utc::now().timestamp();
        kill_job
            .key(current_worker_id)
            .key(dead_jobs_set)
            .key(job_data_hash)
            .arg(task_id.to_string())
            .arg(now)
            .arg(error.to_string())
            .invoke_async(&mut self.conn)
            .await
    }

    /// Moves scheduled jobs that are due into the active set
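    ///
    /// A sketch of promoting up to 10 due jobs manually (the background poller
    /// normally calls this on the configured `enqueue_scheduled` interval):
    ///
    /// ```ignore
    /// let promoted = storage.enqueue_scheduled(10).await?;
    /// ```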
    pub async fn enqueue_scheduled(&mut self, count: usize) -> Result<usize, RedisError> {
        let enqueue_jobs = self.scripts.enqueue_scheduled.clone();
        let scheduled_jobs_set = self.config.scheduled_jobs_set();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();
        let now: i64 = Utc::now().timestamp();
        enqueue_jobs
            .key(scheduled_jobs_set)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(now)
            .arg(count)
            .invoke_async(&mut self.conn)
            .await
    }

    /// Re-enqueue some jobs that might have been abandoned.
    pub async fn reenqueue_active(&mut self, job_ids: Vec<&TaskId>) -> Result<(), RedisError> {
        let reenqueue_active = self.scripts.reenqueue_active.clone();
        let inflight_set: String = self.config.inflight_jobs_set().to_string();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        reenqueue_active
            .key(inflight_set)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(
                job_ids
                    .into_iter()
                    .map(|j| j.to_string())
                    .collect::<Vec<String>>(),
            )
            .invoke_async(&mut self.conn)
            .await
    }

    /// Re-enqueue orphaned jobs, i.e. jobs locked by workers that have not been seen since `dead_since`
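    ///
    /// A sketch of reclaiming jobs held by workers that have missed their
    /// keep-alive window, using the configured `reenqueue_orphaned_after`:
    ///
    /// ```ignore
    /// let dead_since = Utc::now()
    ///     - chrono::Duration::from_std(storage.get_config().reenqueue_orphaned_after()).unwrap();
    /// let reclaimed = storage.reenqueue_orphaned(100, dead_since).await?;
    /// ```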
    pub async fn reenqueue_orphaned(
        &mut self,
        count: i32,
        dead_since: DateTime<Utc>,
    ) -> Result<usize, RedisError> {
        let reenqueue_orphaned = self.scripts.reenqueue_orphaned.clone();
        let consumers_set = self.config.consumers_set();
        let active_jobs_list = self.config.active_jobs_list();
        let signal_list = self.config.signal_list();

        let dead_since = dead_since.timestamp();

        reenqueue_orphaned
            .key(consumers_set)
            .key(active_jobs_list)
            .key(signal_list)
            .arg(dead_since)
            .arg(count)
            .invoke_async(&mut self.conn)
            .await
    }
}

#[cfg(test)]
mod tests {
    use apalis_core::worker::Context;
    use apalis_core::{generic_storage_test, sleep};
    use email_service::Email;

    use apalis_core::test_utils::apalis_test_service_fn;
    use apalis_core::test_utils::TestWrapper;

    generic_storage_test!(setup);

    use super::*;

    /// Set up Redis and return a clean storage instance for tests.
    async fn setup<T: Serialize + DeserializeOwned>() -> RedisStorage<T> {
        let redis_url = std::env::var("REDIS_URL").expect("No REDIS_URL is specified");
        // Because connections cannot be shared across async runtimes
        // (a different runtime is created for each test),
        // we don't share the storage and tests must be run sequentially.
        let conn = connect(redis_url).await.unwrap();
        let config = Config::default()
            .set_namespace("apalis::test")
            .set_enqueue_scheduled(Duration::from_millis(500)); // Return scheduled jobs to the queue quickly
        let mut storage = RedisStorage::new_with_config(conn, config);
        cleanup(&mut storage, &WorkerId::new("test-worker")).await;
        storage
    }

    /// Roll back DB changes made by tests by flushing the test database.
    ///
    /// You should execute this function at the end of a test
    async fn cleanup<T>(storage: &mut RedisStorage<T>, _worker_id: &WorkerId) {
        let _resp: String = redis::cmd("FLUSHDB")
            .query_async(&mut storage.conn)
            .await
            .expect("failed to Flushdb");
    }

    fn example_email() -> Email {
        Email {
            subject: "Test Subject".to_string(),
            to: "example@postgres".to_string(),
            text: "Some Text".to_string(),
        }
    }

    async fn consume_one(
        storage: &mut RedisStorage<Email>,
        worker_id: &WorkerId,
    ) -> Request<Email, RedisContext> {
        let stream = storage.fetch_next(worker_id);
        stream
            .await
            .expect("failed to poll job")
            .first()
            .expect("stream is empty")
            .clone()
    }

    async fn register_worker_at(storage: &mut RedisStorage<Email>) -> Worker<Context> {
        let worker = Worker::new(WorkerId::new("test-worker"), Context::default());
        worker.start();
        storage
            .keep_alive(&worker.id())
            .await
            .expect("failed to register worker");
        worker
    }

    async fn register_worker(storage: &mut RedisStorage<Email>) -> Worker<Context> {
        register_worker_at(storage).await
    }

    async fn push_email(storage: &mut RedisStorage<Email>, email: Email) {
        storage.push(email).await.expect("failed to push a job");
    }

    async fn get_job(
        storage: &mut RedisStorage<Email>,
        job_id: &TaskId,
    ) -> Request<Email, RedisContext> {
        storage
            .fetch_by_id(job_id)
            .await
            .expect("failed to fetch job by id")
            .expect("no job found by id")
    }

    #[tokio::test]
    async fn test_consume_last_pushed_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let _job = consume_one(&mut storage, &worker.id()).await;
    }

    #[tokio::test]
    async fn test_acknowledge_job() {
        let mut storage = setup().await;
        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let ctx = &job.parts.context;
        let res = 42usize;
        storage
            .ack(
                ctx,
                &Response::success(res, job.parts.task_id.clone(), job.parts.attempt.clone()),
            )
            .await
            .expect("failed to acknowledge the job");

        let _job = get_job(&mut storage, &job.parts.task_id).await;
    }

    #[tokio::test]
    async fn test_kill_job() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        let job_id = &job.parts.task_id;

        storage
            .kill(
                &worker.id(),
                &job_id,
                &(Box::new(io::Error::new(
                    io::ErrorKind::Interrupted,
                    "Some unforeseen error occurred",
                )) as BoxDynError),
            )
            .await
            .expect("failed to kill job");

        let _job = get_job(&mut storage, &job_id).await;
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_1sec() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker_at(&mut storage).await;

        let job = consume_one(&mut storage, &worker.id()).await;
        sleep(Duration::from_millis(1000)).await;
        let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(1)).unwrap();
        let res = storage
            .reenqueue_orphaned(1, dead_since)
            .await
            .expect("failed to reenqueue_orphaned");
        // We expect 1 job to be re-enqueued
        assert_eq!(res, 1);
        let job = get_job(&mut storage, &job.parts.task_id).await;
        let ctx = &job.parts.context;
        // assert_eq!(*ctx.status(), State::Pending);
        // assert!(ctx.done_at().is_none());
        assert!(ctx.lock_by.is_none());
        // assert!(ctx.lock_at().is_none());
        // assert_eq!(*ctx.last_error(), Some("Job was abandoned".to_owned()));
        // TODO: Redis should store context aside
        // assert_eq!(job.parts.attempt.current(), 1);
    }

    #[tokio::test]
    async fn test_heartbeat_renqueueorphaned_pulse_last_seen_5sec() {
        let mut storage = setup().await;

        push_email(&mut storage, example_email()).await;

        let worker = register_worker_at(&mut storage).await;
        sleep(Duration::from_millis(1100)).await;
        let job = consume_one(&mut storage, &worker.id()).await;
        let dead_since = Utc::now() - chrono::Duration::from_std(Duration::from_secs(5)).unwrap();
        let res = storage
            .reenqueue_orphaned(1, dead_since)
            .await
            .expect("failed to reenqueue_orphaned");
        // We expect 0 jobs to be re-enqueued
        assert_eq!(res, 0);
        let job = get_job(&mut storage, &job.parts.task_id).await;
        let _ctx = &job.parts.context;
        // assert_eq!(*ctx.status(), State::Running);
        // TODO: update redis context
        // assert_eq!(ctx.lock_by, Some(worker_id));
        // assert!(ctx.lock_at().is_some());
        // assert_eq!(*ctx.last_error(), None);
        assert_eq!(job.parts.attempt.current(), 0);
    }

    #[tokio::test]
    async fn test_stats() {
        use apalis_core::backend::BackendExpose;

        let mut storage = setup().await;
        let stats = storage.stats().await.expect("failed to get stats");
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.running, 0);
        push_email(&mut storage, example_email()).await;
        let stats = storage.stats().await.expect("failed to get stats");
        assert_eq!(stats.pending, 1);

        let worker = register_worker(&mut storage).await;

        let _job = consume_one(&mut storage, &worker.id()).await;

        let stats = storage.stats().await.expect("failed to get stats");
        assert_eq!(stats.pending, 0);
        assert_eq!(stats.running, 1);
    }
}