1#![allow(dead_code)]
3use super::{Broker, BrokerBuilder, DeliveryError, DeliveryStream};
4use crate::error::{BrokerError, ProtocolError};
5use crate::protocol::Delivery;
6use crate::protocol::DeliveryInfo;
7use crate::protocol::Message;
8use crate::protocol::TryDeserializeMessage;
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use deadpool_redis::{Config as PoolConfig, Pool, Runtime};
12use futures::{Future, Stream};
13use log::{debug, error, warn};
14use redis::RedisError;
15use std::clone::Clone;
16use std::collections::HashSet;
17use std::fmt;
18use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
19use std::sync::Arc;
20use std::task::{Poll, Waker};
21use tokio::sync::mpsc::{channel, Receiver, Sender};
22use uuid::Uuid;
23
24#[cfg(test)]
25use std::any::Any;
26
/// Settings accumulated by `RedisBrokerBuilder` until the broker is built.
struct Config {
    /// Redis connection URL handed to the connection pool.
    broker_url: String,
    /// Upper bound on deliveries handed out but not yet acked (0 means no limit).
    prefetch_count: u16,
    /// Names of the queues declared through the builder.
    queues: HashSet<String>,
    /// Heartbeat setting; accepted for API parity but has no effect on redis.
    heartbeat: Option<u16>,
}
33
34pub struct RedisBrokerBuilder {
35 config: Config,
36}
37
38#[async_trait]
39impl BrokerBuilder for RedisBrokerBuilder {
40 fn new(broker_url: &str) -> Self {
42 RedisBrokerBuilder {
43 config: Config {
44 broker_url: broker_url.into(),
45 prefetch_count: 10,
46 queues: HashSet::new(),
47 heartbeat: Some(60),
48 },
49 }
50 }
51
52 fn prefetch_count(mut self: Box<Self>, prefetch_count: u16) -> Box<dyn BrokerBuilder> {
54 self.config.prefetch_count = prefetch_count;
55 self
56 }
57
58 fn declare_queue(mut self: Box<Self>, name: &str) -> Box<dyn BrokerBuilder> {
60 self.config.queues.insert(name.into());
61 self
62 }
63
64 fn heartbeat(mut self: Box<Self>, heartbeat: Option<u16>) -> Box<dyn BrokerBuilder> {
66 if heartbeat.is_some() {
67 warn!("Setting heartbeat on redis broker has no effect on anything");
68 }
69 self.config.heartbeat = heartbeat;
70 self
71 }
72
73 async fn build(&self, _connection_timeout: u32) -> Result<Box<dyn Broker>, BrokerError> {
75 let mut queues: HashSet<String> = HashSet::new();
76 for queue_name in &self.config.queues {
77 queues.insert(queue_name.into());
78 }
79
80 log::info!("Creating deadpool-redis pool");
81 let pool_config = PoolConfig::from_url(&self.config.broker_url);
82 let pool = pool_config
83 .create_pool(Some(Runtime::Tokio1))
84 .map_err(|e| BrokerError::InvalidBrokerUrl(format!("Pool creation failed: {}", e)))?;
85
86 log::info!("Creating mpsc channel");
87 let (tx, rx) = channel(1);
88 log::info!("Creating broker with connection pool");
89 Ok(Box::new(RedisBroker {
90 uri: self.config.broker_url.clone(),
91 queues,
92 pool,
93 prefetch_count: Arc::new(AtomicU16::new(self.config.prefetch_count)),
94 pending_tasks: Arc::new(AtomicU16::new(0)),
95 waker_rx: tokio::sync::Mutex::new(rx),
96 waker_tx: tx,
97 is_closed: Arc::new(AtomicBool::new(false)),
98 delivery_info: DeliveryInfo::for_redis_default(),
99 }))
100 }
101}
102
103pub struct RedisBroker {
104 uri: String,
105 pool: Pool,
107 queues: HashSet<String>,
109
110 prefetch_count: Arc<AtomicU16>,
113 pending_tasks: Arc<AtomicU16>,
114 waker_rx: tokio::sync::Mutex<Receiver<Waker>>,
115 waker_tx: Sender<Waker>,
116 is_closed: Arc<AtomicBool>,
118
119 delivery_info: DeliveryInfo,
120}
121
122#[derive(Clone)]
123pub struct Channel {
124 pool: Pool,
125 queue_name: String,
126 delivery_info: DeliveryInfo,
127}
128
129impl fmt::Debug for Channel {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 write!(f, "Channel {{ {} }}", self.queue_name)
132 }
133}
134
135impl Channel {
136 fn new(pool: Pool, queue_name: String, delivery_info: DeliveryInfo) -> Self {
137 Self {
138 pool,
139 queue_name,
140 delivery_info,
141 }
142 }
143
144 fn process_map_name(&self) -> String {
145 format!("_celery.{}_process_map", self.queue_name)
146 }
147
148 async fn fetch_task(
149 self,
150 send_waker: Option<(Sender<Waker>, Waker)>,
151 ) -> Result<Delivery, BrokerError> {
152 if let Some((sender, waker)) = send_waker {
153 sender.send(waker).await.unwrap();
154 futures::pending!();
155 }
156
157 let mut conn = self.pool.get().await.map_err(|e| {
158 BrokerError::IoError(std::io::Error::new(
159 std::io::ErrorKind::ConnectionRefused,
160 e,
161 ))
162 })?;
163
164 loop {
165 let rez: Result<Option<String>, RedisError> = redis::cmd("RPOP")
166 .arg(&self.queue_name)
167 .query_async(&mut *conn)
168 .await;
169 match rez {
170 Ok(None) => tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await,
171 Ok(Some(rez)) => {
172 let delivery: Delivery = serde_json::from_str(&rez[..])?;
173 debug!(
174 "Received msg: {} / {}",
175 delivery.properties.delivery_tag, delivery.headers.task
176 );
177 let _set_rez: u32 = redis::cmd("HSET")
178 .arg(self.process_map_name())
179 .arg(&delivery.properties.correlation_id)
180 .arg(&rez)
181 .query_async(&mut *conn)
182 .await?;
183 break Ok(delivery);
184 }
185 Err(err) => break Err(err.into()),
186 }
187 }
188 }
189
190 async fn send_task(self, message: &Message) -> Result<(), BrokerError> {
191 let mut conn = self.pool.get().await.map_err(|e| {
192 BrokerError::IoError(std::io::Error::new(
193 std::io::ErrorKind::ConnectionRefused,
194 e,
195 ))
196 })?;
197 Ok(redis::cmd("LPUSH")
198 .arg(&self.queue_name)
199 .arg(message.json_serialized(Some(self.delivery_info))?)
200 .query_async(&mut *conn)
201 .await?)
202 }
203
204 async fn resend_task(&self, delivery: &Delivery) -> Result<(), BrokerError> {
205 let mut message = delivery.clone().try_deserialize_message()?;
206 let retries = message.headers.retries.unwrap_or_default();
207 message.headers.retries = Some(retries + 1);
208 self.clone().send_task(&message).await?;
209 Ok(())
210 }
211
212 async fn remove_task(&self, delivery: &Delivery) -> Result<(), BrokerError> {
213 let mut conn = self.pool.get().await.map_err(|e| {
214 BrokerError::IoError(std::io::Error::new(
215 std::io::ErrorKind::ConnectionRefused,
216 e,
217 ))
218 })?;
219 redis::cmd("HDEL")
220 .arg(self.process_map_name())
221 .arg(&delivery.properties.correlation_id)
222 .query_async::<()>(&mut *conn)
223 .await?;
224 Ok(())
225 }
226}
227
228type ConsumerOutput = Result<Delivery, BrokerError>;
229type ConsumerOutputFuture = Box<dyn Future<Output = ConsumerOutput> + Send>;
230
231pub struct Consumer {
232 channel: Channel,
233 error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
234 polled_pop: Option<std::pin::Pin<ConsumerOutputFuture>>,
235 pending_tasks: Arc<AtomicU16>,
236 waker_tx: Sender<Waker>,
237 prefetch_count: Arc<AtomicU16>,
238}
239
240impl DeliveryStream for Consumer {}
241
242#[async_trait]
243impl super::Delivery for (Channel, Delivery) {
244 async fn resend(
245 &self,
246 _broker: &dyn Broker,
247 _eta: Option<DateTime<Utc>>,
248 ) -> Result<(), BrokerError> {
249 self.0.resend_task(&self.1).await?;
250 Ok(())
251 }
252
253 async fn remove(&self) -> Result<(), BrokerError> {
254 self.0.remove_task(&self.1).await?;
255 Ok(())
256 }
257
258 async fn ack(&self) -> Result<(), BrokerError> {
259 todo!()
260 }
261}
262
263impl TryDeserializeMessage for (Channel, Delivery) {
264 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
265 self.1.try_deserialize_message()
266 }
267}
268
269impl Stream for Consumer {
270 type Item = Result<Box<dyn super::Delivery>, Box<dyn DeliveryError>>;
271 fn poll_next(
272 mut self: std::pin::Pin<&mut Self>,
273 cx: &mut std::task::Context<'_>,
274 ) -> Poll<std::option::Option<<Self as futures::Stream>::Item>> {
275 if self.pending_tasks.load(Ordering::SeqCst) >= self.prefetch_count.load(Ordering::SeqCst)
280 && self.prefetch_count.load(Ordering::SeqCst) > 0
281 {
282 debug!("Pending tasks limit reached");
283 return Poll::Pending;
284 }
285 let mut polled_pop = if self.polled_pop.is_none() {
286 Box::pin(self.channel.clone().fetch_task(None))
287 } else {
288 self.polled_pop.take().unwrap()
289 };
290 if let Poll::Ready(item) = Future::poll(polled_pop.as_mut(), cx) {
291 match item {
292 Ok(item) => {
293 self.pending_tasks.fetch_add(1, Ordering::SeqCst);
294 Poll::Ready(Some(Ok(Box::new((self.channel.clone(), item)))))
295 }
296 Err(err) => {
297 (self.error_handler)(err);
298 cx.waker().wake_by_ref();
299 Poll::Pending
300 }
301 }
302 } else {
303 self.polled_pop = Some(polled_pop);
304 Poll::Pending
305 }
306 }
307}
308
309#[async_trait]
310impl Broker for RedisBroker {
311 async fn consume(
319 &self,
320 queue: &str,
321 error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
322 ) -> Result<(String, Box<dyn DeliveryStream>), BrokerError> {
323 let consumer = Consumer {
324 channel: Channel {
325 pool: self.pool.clone(),
326 queue_name: queue.to_string(),
327 delivery_info: self.delivery_info.clone(),
328 },
329 error_handler,
330 polled_pop: None,
331 prefetch_count: Arc::clone(&self.prefetch_count),
332 pending_tasks: Arc::clone(&self.pending_tasks),
333 waker_tx: self.waker_tx.clone(),
334 };
335
336 let mut buffer = Uuid::encode_buffer();
338 let uuid = Uuid::new_v4().hyphenated().encode_lower(&mut buffer);
339 let consumer_tag = uuid.to_owned();
340
341 Ok((consumer_tag, Box::new(consumer)))
342 }
343
344 async fn cancel(&self, _consumer_tag: &str) -> Result<(), BrokerError> {
345 Ok(())
346 }
347
348 async fn ack(&self, delivery: &dyn super::Delivery) -> Result<(), BrokerError> {
350 self.pending_tasks.fetch_sub(1, Ordering::SeqCst);
351 delivery.remove().await?;
352 let mut waker_rx = self.waker_rx.lock().await;
353 let dummy_waker = futures::task::noop_waker_ref();
355 let mut dummy_ctx = std::task::Context::from_waker(dummy_waker);
356 if let Poll::Ready(Some(waker)) = waker_rx.poll_recv(&mut dummy_ctx) {
357 waker.wake();
358 }
359 Ok(())
360 }
361
362 async fn retry(
364 &self,
365 delivery: &dyn super::Delivery,
366 eta: Option<DateTime<Utc>>,
367 ) -> Result<(), BrokerError> {
368 delivery.resend(self, eta).await?;
369 Ok(())
371 }
372
373 async fn send(&self, message: &Message, queue: &str) -> Result<(), BrokerError> {
375 if self.is_closed.load(Ordering::SeqCst) {
377 return Err(BrokerError::NotConnected);
378 }
379
380 Channel::new(
381 self.pool.clone(),
382 queue.to_string(),
383 self.delivery_info.clone(),
384 )
385 .send_task(message)
386 .await?;
387 Ok(())
388 }
389
390 async fn increase_prefetch_count(&self) -> Result<(), BrokerError> {
393 self.prefetch_count.fetch_add(1, Ordering::SeqCst);
394 Ok(())
395 }
396
397 async fn decrease_prefetch_count(&self) -> Result<(), BrokerError> {
400 self.prefetch_count.fetch_sub(1, Ordering::SeqCst);
401 Ok(())
402 }
403
404 async fn close(&self) -> Result<(), BrokerError> {
406 self.is_closed.store(true, Ordering::SeqCst);
408 Ok(())
409 }
410
411 fn safe_url(&self) -> String {
412 let parsed_url = redis::parse_redis_url(&self.uri[..]);
413 match parsed_url {
414 Some(url) => format!(
415 "{}://{}:***@{}:{}/{}",
416 url.scheme(),
417 url.username(),
418 url.host_str().unwrap(),
419 url.port().unwrap(),
420 url.path(),
421 ),
422 None => {
423 error!("Invalid redis url.");
424 String::from("")
425 }
426 }
427 }
428
429 async fn reconnect(&self, _connection_timeout: u32) -> Result<(), BrokerError> {
430 let mut conn = self.pool.get().await.map_err(|e| {
433 BrokerError::IoError(std::io::Error::new(
434 std::io::ErrorKind::ConnectionRefused,
435 e,
436 ))
437 })?;
438
439 let _: String = redis::cmd("PING").query_async(&mut *conn).await?;
441
442 self.is_closed.store(false, Ordering::SeqCst);
444 Ok(())
445 }
446
447 #[cfg(test)]
448 fn into_any(self: Box<Self>) -> Box<dyn Any> {
449 self
450 }
451}