apalis_rsmq/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(
3    missing_debug_implementations,
4    missing_docs,
5    rust_2018_idioms,
6    unreachable_pub,
7    bad_style,
8    dead_code,
9    improper_ctypes,
10    non_shorthand_field_patterns,
11    no_mangle_generic_items,
12    overflowing_literals,
13    path_statements,
14    patterns_in_fns_without_body,
15    unconditional_recursion,
16    unused,
17    unused_allocation,
18    unused_comparisons,
19    unused_parens,
20    while_true
21)]
22
23use std::{
24    fmt::Debug,
25    marker::PhantomData,
26    sync::{
27        Arc,
28        atomic::{AtomicUsize, Ordering},
29    },
30    time::Instant,
31};
32
33use apalis_core::{
34    backend::{
35        Backend, TaskStream,
36        codec::{Codec, json::JsonCodec},
37        poll_strategy::{PollContext, PollStrategyExt},
38    },
39    features_table,
40    task::{Task, attempt::Attempt, builder::TaskBuilder, task_id::TaskId},
41    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
42};
43use futures::{
44    StreamExt, future,
45    stream::{self, BoxStream},
46};
47use rsmq_async::{Rsmq, RsmqConnection, RsmqError};
48use serde::de::DeserializeOwned;
49use tracing::{error, trace, warn};
50
51use crate::sink::RsMqSink;
52
53mod ack;
54mod config;
55mod context;
56mod sink;
57
58pub use crate::config::Config;
59pub use crate::context::RedisMqContext;
60
61type RsMqTask<T> = Task<T, RedisMqContext, String>;
62
63pin_project_lite::pin_project! {
64    /// Redis-backed message queue
65    ///
66    #[doc = features_table! {
67        setup = {
68            use apalis_rsmq::Config;
69            use apalis_rsmq::RedisMq;
70            use rsmq_async::RsmqConnection;
71
72            let mut conn = rsmq_async::Rsmq::new(Default::default()).await.unwrap();
73            let _ = conn.create_queue("test", None, None, None).await;
74            let mut config = Config::default();
75            config.set_namespace("test".to_owned());
76            let mut mq = RedisMq::new(conn, config);
77            mq
78        };,
79        TaskSink => supported("Ability to push new tasks"),
80        Serialization => supported("Serialization support for arguments. Accepts any bytes codec", false),
81        FetchById => not_implemented("Allow fetching a task by its ID"),
82        RegisterWorker => not_supported("Allow registering a worker with the backend"),
83        PipeExt => supported("Allow other backends to pipe to this backend", false),
84        MakeShared => supported("Share the same JSON storage across multiple workers", false),
85        Workflow => supported("Flexible enough to support workflows", false),
86        WaitForCompletion => supported("Wait for tasks to complete without blocking", false),
87        ResumeById => not_implemented("Resume a task by its ID"),
88        ResumeAbandoned => not_implemented("Resume abandoned tasks"),
89        ListWorkers => not_supported("List all workers registered with the backend"),
90        ListTasks => not_implemented("List all tasks in the backend"),
91    }]
92    #[derive(Debug)]
93    pub struct RedisMq<T, C = JsonCodec<Vec<u8>>> {
94        conn: Rsmq,
95        config: Config,
96        msg_type: PhantomData<T>,
97        codec: PhantomData<C>,
98        #[pin]
99        sink: RsMqSink<T, C>,
100    }
101}
102
103impl<T> RedisMq<T, JsonCodec<Vec<u8>>> {
104    /// Creates a new RedisMq instance
105    pub fn new(conn: Rsmq, config: Config) -> RedisMq<T, JsonCodec<Vec<u8>>> {
106        RedisMq {
107            sink: RsMqSink::new(conn.clone(), config.clone()),
108            conn,
109            config,
110            msg_type: PhantomData,
111            codec: PhantomData,
112        }
113    }
114}
115
116impl<T, C> RedisMq<T, C> {
117    /// Gets the configuration
118    pub fn config(&self) -> &Config {
119        &self.config
120    }
121}
122
123// Implement Clone manually
124impl<T, C> Clone for RedisMq<T, C> {
125    fn clone(&self) -> Self {
126        RedisMq {
127            conn: self.conn.clone(),
128            msg_type: PhantomData,
129            config: self.config.clone(),
130            codec: PhantomData,
131            sink: RsMqSink::new(self.conn.clone(), self.config.clone()),
132        }
133    }
134}
135
136impl<Args, C> Backend<Args> for RedisMq<Args, C>
137where
138    Args: Send + DeserializeOwned + 'static,
139    C: Codec<PrimitiveMessage<Args>, Compact = Vec<u8>>,
140    C::Error: std::error::Error + Send,
141{
142    type Stream = TaskStream<RsMqTask<Args>, RsmqError>;
143    type Layer = AcknowledgeLayer<Self>;
144    type Codec = C;
145    type Context = RedisMqContext;
146    type Error = RsmqError;
147    type Beat = BoxStream<'static, Result<(), RsmqError>>;
148    type IdType = String;
149
150    fn heartbeat(&self, _: &WorkerContext) -> Self::Beat {
151        stream::once(future::ready(Ok(()))).boxed()
152    }
153
154    fn middleware(&self) -> Self::Layer {
155        AcknowledgeLayer::new(self.clone())
156    }
157
158    fn poll(self, worker: &WorkerContext) -> Self::Stream {
159        let poll_strategy = self.config.poll_strategy().clone();
160        let prev_count = Arc::new(AtomicUsize::new(0));
161        let ctx = PollContext::new(worker.clone(), prev_count.clone());
162        let throttle = poll_strategy.build_stream(&ctx);
163
164        let stream = futures::stream::unfold(throttle, move |mut throttle| {
165            let mut conn = self.conn.clone();
166            let namespace = self.config.namespace().to_string();
167            let prev_count = prev_count.clone();
168            async move {
169                let instant = Instant::now();
170                throttle.next().await;
171                trace!("Polling new messages after {:?}", instant.elapsed());
172                match conn.receive_message(&namespace, None).await {
173                    Ok(Some(r)) => match C::decode(&r.message) {
174                        Ok(msg) => {
175                            let task = TaskBuilder::new(msg.task)
176                                .with_ctx(msg.context)
177                                .with_task_id(TaskId::new(r.id))
178                                .with_attempt(Attempt::new_with_value(r.rc as usize))
179                                .build();
180                            prev_count.store(1, Ordering::SeqCst);
181                            Some((Ok(Some(task)), throttle))
182                        }
183                        Err(e) => {
184                            error!("Failed to decode message: {:?}", e);
185                            Some((Err(RsmqError::InvalidFormat(e.to_string())), throttle))
186                        }
187                    },
188                    Ok(None) => {
189                        prev_count.store(0, Ordering::SeqCst);
190                        Some((Ok(None), throttle))
191                    }
192                    Err(e) => {
193                        error!("Error receiving message: {:?}", e);
194                        Some((Err(e), throttle))
195                    }
196                }
197            }
198        });
199
200        stream.boxed()
201    }
202}
203
204#[derive(serde::Serialize, serde::Deserialize, Debug)]
205struct PrimitiveMessage<T> {
206    task: T,
207    context: RedisMqContext,
208}