1#![forbid(unsafe_code)]
60#![warn(
61 clippy::await_holding_lock,
62 clippy::cargo_common_metadata,
63 clippy::dbg_macro,
64 clippy::empty_enum,
65 clippy::enum_glob_use,
66 clippy::inefficient_to_string,
67 clippy::mem_forget,
68 clippy::mutex_integer,
69 clippy::needless_continue,
70 clippy::todo,
71 clippy::unimplemented,
72 clippy::wildcard_imports,
73 future_incompatible,
74 missing_docs,
75 missing_debug_implementations,
76 unreachable_pub
77)]
78
79mod ack;
80use apalis_core::{
81 backend::Backend,
82 layers::AckLayer,
83 mq::MessageQueue,
84 poller::Poller,
85 request::{Parts, Request, RequestStream},
86 worker::{Context, Worker},
87};
88use deadpool_lapin::{Manager, Pool};
89use futures::StreamExt;
90use lapin::{
91 options::{BasicConsumeOptions, BasicPublishOptions, QueueDeclareOptions},
92 types::FieldTable,
93 BasicProperties, Channel, ConnectionProperties, Error, Queue,
94};
95use serde::{de::DeserializeOwned, Serialize};
96use std::{
97 fmt::Debug,
98 io::{self, ErrorKind},
99 marker::PhantomData,
100 sync::Arc,
101};
102use utils::{AmqpContext, AmqpMessage, Config, DeliveryTag};
103
104pub mod utils;
106
107#[derive(Debug)]
108pub struct AmqpBackend<M> {
110 channel: Channel,
111 queue: Queue,
112 message_type: PhantomData<M>,
113 config: Config,
114}
115
116impl<M> Clone for AmqpBackend<M> {
117 fn clone(&self) -> Self {
118 Self {
119 channel: self.channel.clone(),
120 queue: self.queue.clone(),
121 message_type: PhantomData,
122 config: self.config.clone(),
123 }
124 }
125}
126
127impl<M: Serialize + DeserializeOwned + Send + Sync + 'static> MessageQueue<M> for AmqpBackend<M> {
128 type Error = Error;
129 async fn enqueue(&mut self, message: M) -> Result<(), Self::Error> {
134 let _confirmation = self
135 .channel
136 .basic_publish(
137 "",
138 self.config.namespace().as_str(),
139 BasicPublishOptions::default(),
140 &serde_json::to_vec(&AmqpMessage {
141 inner: message,
142 task_id: Default::default(),
143 attempt: Default::default(),
144 })
145 .map_err(|e| Error::IOError(Arc::new(io::Error::new(ErrorKind::InvalidData, e))))?,
146 BasicProperties::default(),
147 )
148 .await?
149 .await?;
150 Ok(())
151 }
152
153 async fn size(&mut self) -> Result<usize, Self::Error> {
154 Ok(self.queue.message_count() as usize)
155 }
156
157 async fn dequeue(&mut self) -> Result<Option<M>, Self::Error> {
158 Ok(None)
159 }
160}
161
/// Worker-side integration: turns broker deliveries on the namespace queue into
/// apalis [`Request`]s.
impl<M: DeserializeOwned + Send + 'static, Res> Backend<Request<M, AmqpContext>, Res>
    for AmqpBackend<M>
{
    type Layer = AckLayer<Self, M, AmqpContext, Res>;
    type Stream = RequestStream<Request<M, AmqpContext>>;

    /// Starts a broker consumer on the queue named by the configured namespace and maps
    /// each delivery to a [`Request`].
    ///
    /// * The consumer tag is the worker's id (as a string).
    /// * Each delivery body is decoded as JSON into an `AmqpMessage<M>`.
    /// * Its task id and attempt are copied into the request [`Parts`].
    /// * The delivery tag is stored in an [`AmqpContext`].
    /// * The namespace is recorded in `parts.namespace`.
    ///
    /// `self` is moved into the returned [`AckLayer`]. NOTE(review): presumably that layer
    /// uses the stored delivery tag to acknowledge the message; confirm in `mod ack`.
    fn poll<Svc>(self, worker: &Worker<Context>) -> Poller<Self::Stream, Self::Layer> {
        // Clone what the stream body needs, because `self` is consumed by `AckLayer::new` below.
        let channel = self.channel.clone();
        let worker = worker.clone();
        let config = self.config.clone();
        let stream = async_stream::stream! {
            // This body runs only once the stream is polled, so the consumer is registered
            // lazily rather than when `poll` is called.
            let mut consumer = channel
                .basic_consume(
                    config.namespace().as_str(),
                    &worker.id().to_string(),
                    BasicConsumeOptions::default(),
                    FieldTable::default(),
                )
                .await
                .map_err(|e| apalis_core::error::Error::SourceError(Arc::new(e.into())))?;

            // NOTE(review): this loop stops at the first `Some(Err(_))` from the consumer as
            // well as at end-of-stream, and the error itself is never yielded. Confirm that is
            // intended.
            while let Some(Ok(item)) = consumer.next().await {
                let bytes = item.data;
                let tag = item.delivery_tag;
                // NOTE(review): a body that fails to deserialize propagates through `?` below.
                // Inside `stream!` that appears to end the whole stream and leave this delivery
                // unacknowledged. Confirm against async-stream's handling of `?`.
                let msg = serde_json::from_slice(&bytes)
                    .map_err(|e| apalis_core::error::Error::SourceError(Arc::new(e.into())))
                    .map(|req: AmqpMessage<M>| {
                        let mut parts = Parts::default();
                        parts.task_id = req.task_id;
                        parts.context = AmqpContext::new(DeliveryTag::new(tag));
                        parts.attempt = req.attempt;
                        parts.namespace = Some(config.namespace().to_owned());
                        parts.data = Default::default();
                        Request::new_with_parts(req.inner, parts)
                    })?;
                yield Ok(Some(msg));

            }
        };
        // The second argument is `std::future::pending()`, a future that never completes.
        // NOTE(review): confirm its role in `Poller::new_with_layer` (presumably a heartbeat
        // that this backend does not need).
        Poller::new_with_layer(stream.boxed(), std::future::pending(), AckLayer::new(self))
    }
}
203
204impl<M: Serialize + DeserializeOwned + Send + 'static> AmqpBackend<M> {
205 pub fn new(channel: Channel, queue: Queue) -> Self {
207 Self {
208 channel,
209 message_type: PhantomData,
210 queue,
211 config: Config::new(std::any::type_name::<M>()),
212 }
213 }
214
215 pub fn new_with_config(channel: Channel, queue: Queue, config: Config) -> Self {
217 Self {
218 channel,
219 message_type: PhantomData,
220 queue,
221 config,
222 }
223 }
224
225 pub fn channel(&self) -> &Channel {
227 &self.channel
228 }
229
230 pub fn queue(&self) -> &Queue {
232 &self.queue
233 }
234
235 pub fn config(&self) -> &Config {
237 &self.config
238 }
239
240 pub async fn new_from_addr<S: AsRef<str>>(addr: S) -> Result<Self, lapin::Error> {
245 let manager = Manager::new(addr.as_ref(), ConnectionProperties::default());
246 let pool: Pool = deadpool::managed::Pool::builder(manager)
247 .max_size(10)
248 .build()
249 .map_err(|error| {
250 lapin::Error::IOError(Arc::new(io::Error::new(
251 io::ErrorKind::ConnectionAborted,
252 error,
253 )))
254 })?;
255 let amqp_conn = pool.get().await.map_err(|error| {
256 lapin::Error::IOError(Arc::new(io::Error::new(
257 io::ErrorKind::ConnectionRefused,
258 error,
259 )))
260 })?;
261 let channel = amqp_conn.create_channel().await?;
262 let queue = channel
263 .queue_declare(
264 std::any::type_name::<M>(),
265 QueueDeclareOptions::default(),
266 FieldTable::default(),
267 )
268 .await?;
269 Ok(Self::new(channel, queue))
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use apalis_core::builder::{WorkerBuilder, WorkerFactoryFn};
    use serde::Deserialize;

    /// Unit payload used only to exercise the backend's type plumbing.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestMessage;

    /// No-op job handler.
    async fn test_job(_job: TestMessage) {}

    /// Connects to a live broker and checks that a worker can be built on the backend.
    ///
    /// Requires a reachable AMQP broker whose address is in the `AMQP_ADDR` environment
    /// variable.
    #[tokio::test]
    async fn it_works() {
        let env = std::env::var("AMQP_ADDR")
            .expect("set AMQP_ADDR to the address of a running AMQP broker to run this test");
        let amqp_backend = AmqpBackend::new_from_addr(&env)
            .await
            .expect("failed to connect to the AMQP broker at AMQP_ADDR");
        let _worker = WorkerBuilder::new("rango-amigo")
            .backend(amqp_backend)
            .build_fn(test_job);
    }
}