1#![doc = include_str!("../README.md")]
2#![warn(
3 missing_debug_implementations,
4 missing_docs,
5 rust_2018_idioms,
6 unreachable_pub,
7 bad_style,
8 dead_code,
9 improper_ctypes,
10 non_shorthand_field_patterns,
11 no_mangle_generic_items,
12 overflowing_literals,
13 path_statements,
14 patterns_in_fns_without_body,
15 unconditional_recursion,
16 unused,
17 unused_allocation,
18 unused_comparisons,
19 unused_parens,
20 while_true
21)]
22
23use std::{
24 fmt::Debug,
25 marker::PhantomData,
26 sync::{
27 Arc,
28 atomic::{AtomicUsize, Ordering},
29 },
30 time::Instant,
31};
32
33use apalis_core::{
34 backend::{
35 Backend, TaskStream,
36 codec::{Codec, json::JsonCodec},
37 poll_strategy::{PollContext, PollStrategyExt},
38 },
39 features_table,
40 task::{Task, attempt::Attempt, builder::TaskBuilder, task_id::TaskId},
41 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
42};
43use futures::{
44 StreamExt, future,
45 stream::{self, BoxStream},
46};
47use rsmq_async::{Rsmq, RsmqConnection, RsmqError};
48use serde::de::DeserializeOwned;
49use tracing::{error, trace, warn};
50
51use crate::sink::RsMqSink;
52
53mod ack;
54mod config;
55mod context;
56mod sink;
57
58pub use crate::config::Config;
59pub use crate::context::RedisMqContext;
60
61type RsMqTask<T> = Task<T, RedisMqContext, String>;
62
pin_project_lite::pin_project! {
    /// An apalis [`Backend`] that pulls tasks from a Redis Simple Message
    /// Queue through [`rsmq_async`].
    ///
    /// `T` is the task argument type; `C` is the codec used to encode and
    /// decode queued messages (defaults to JSON).
    #[doc = features_table! {
        setup = {
            use apalis_rsmq::Config;
            use apalis_rsmq::RedisMq;
            use rsmq_async::RsmqConnection;

            let mut conn = rsmq_async::Rsmq::new(Default::default()).await.unwrap();
            let _ = conn.create_queue("test", None, None, None).await;
            let mut config = Config::default();
            config.set_namespace("test".to_owned());
            let mut mq = RedisMq::new(conn, config);
            mq
        };,
        TaskSink => supported("Ability to push new tasks"),
        Serialization => supported("Serialization support for arguments. Accepts any bytes codec", false),
        FetchById => not_implemented("Allow fetching a task by its ID"),
        RegisterWorker => not_supported("Allow registering a worker with the backend"),
        PipeExt => supported("Allow other backends to pipe to this backend", false),
        // NOTE(review): "JSON storage" looks copied from another backend's
        // table; this backend stores tasks in RSMQ — confirm the wording.
        MakeShared => supported("Share the same JSON storage across multiple workers", false),
        Workflow => supported("Flexible enough to support workflows", false),
        WaitForCompletion => supported("Wait for tasks to complete without blocking", false),
        ResumeById => not_implemented("Resume a task by its ID"),
        ResumeAbandoned => not_implemented("Resume abandoned tasks"),
        ListWorkers => not_supported("List all workers registered with the backend"),
        ListTasks => not_implemented("List all tasks in the backend"),
    }]
    #[derive(Debug)]
    pub struct RedisMq<T, C = JsonCodec<Vec<u8>>> {
        // RSMQ connection handle; cloned for each poll and for the sink.
        conn: Rsmq,
        // Backend configuration (queue namespace, poll strategy, ...).
        config: Config,
        msg_type: PhantomData<T>,
        codec: PhantomData<C>,
        // Sink backing the `TaskSink` support; structurally pinned (`#[pin]`)
        // so it can be projected as `Pin<&mut RsMqSink<T, C>>`.
        #[pin]
        sink: RsMqSink<T, C>,
    }
}
102
103impl<T> RedisMq<T, JsonCodec<Vec<u8>>> {
104 pub fn new(conn: Rsmq, config: Config) -> RedisMq<T, JsonCodec<Vec<u8>>> {
106 RedisMq {
107 sink: RsMqSink::new(conn.clone(), config.clone()),
108 conn,
109 config,
110 msg_type: PhantomData,
111 codec: PhantomData,
112 }
113 }
114}
115
116impl<T, C> RedisMq<T, C> {
117 pub fn config(&self) -> &Config {
119 &self.config
120 }
121}
122
123impl<T, C> Clone for RedisMq<T, C> {
125 fn clone(&self) -> Self {
126 RedisMq {
127 conn: self.conn.clone(),
128 msg_type: PhantomData,
129 config: self.config.clone(),
130 codec: PhantomData,
131 sink: RsMqSink::new(self.conn.clone(), self.config.clone()),
132 }
133 }
134}
135
136impl<Args, C> Backend<Args> for RedisMq<Args, C>
137where
138 Args: Send + DeserializeOwned + 'static,
139 C: Codec<PrimitiveMessage<Args>, Compact = Vec<u8>>,
140 C::Error: std::error::Error + Send,
141{
142 type Stream = TaskStream<RsMqTask<Args>, RsmqError>;
143 type Layer = AcknowledgeLayer<Self>;
144 type Codec = C;
145 type Context = RedisMqContext;
146 type Error = RsmqError;
147 type Beat = BoxStream<'static, Result<(), RsmqError>>;
148 type IdType = String;
149
150 fn heartbeat(&self, _: &WorkerContext) -> Self::Beat {
151 stream::once(future::ready(Ok(()))).boxed()
152 }
153
154 fn middleware(&self) -> Self::Layer {
155 AcknowledgeLayer::new(self.clone())
156 }
157
158 fn poll(self, worker: &WorkerContext) -> Self::Stream {
159 let poll_strategy = self.config.poll_strategy().clone();
160 let prev_count = Arc::new(AtomicUsize::new(0));
161 let ctx = PollContext::new(worker.clone(), prev_count.clone());
162 let throttle = poll_strategy.build_stream(&ctx);
163
164 let stream = futures::stream::unfold(throttle, move |mut throttle| {
165 let mut conn = self.conn.clone();
166 let namespace = self.config.namespace().to_string();
167 let prev_count = prev_count.clone();
168 async move {
169 let instant = Instant::now();
170 throttle.next().await;
171 trace!("Polling new messages after {:?}", instant.elapsed());
172 match conn.receive_message(&namespace, None).await {
173 Ok(Some(r)) => match C::decode(&r.message) {
174 Ok(msg) => {
175 let task = TaskBuilder::new(msg.task)
176 .with_ctx(msg.context)
177 .with_task_id(TaskId::new(r.id))
178 .with_attempt(Attempt::new_with_value(r.rc as usize))
179 .build();
180 prev_count.store(1, Ordering::SeqCst);
181 Some((Ok(Some(task)), throttle))
182 }
183 Err(e) => {
184 error!("Failed to decode message: {:?}", e);
185 Some((Err(RsmqError::InvalidFormat(e.to_string())), throttle))
186 }
187 },
188 Ok(None) => {
189 prev_count.store(0, Ordering::SeqCst);
190 Some((Ok(None), throttle))
191 }
192 Err(e) => {
193 error!("Error receiving message: {:?}", e);
194 Some((Err(e), throttle))
195 }
196 }
197 }
198 });
199
200 stream.boxed()
201 }
202}
203
/// Envelope decoded from the body of each RSMQ message: the task arguments
/// plus the [`RedisMqContext`] attached to them.
///
/// NOTE(review): presumably the sink encodes the same envelope when pushing —
/// confirm in the `sink` module.
///
/// The field names (and, for codecs that are not self-describing, their order)
/// are part of the encoded message format; changing them would break decoding
/// of messages already sitting in the queue.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
struct PrimitiveMessage<T> {
    // Task arguments; passed to `TaskBuilder::new` when a message is polled.
    task: T,
    // Per-message context; attached via `TaskBuilder::with_ctx` when polled.
    context: RedisMqContext,
}