// celery/broker/redis.rs

1//! Redis broker using deadpool-redis for better multithreading support.
2#![allow(dead_code)]
3use super::{Broker, BrokerBuilder, DeliveryError, DeliveryStream};
4use crate::error::{BrokerError, ProtocolError};
5use crate::protocol::Delivery;
6use crate::protocol::DeliveryInfo;
7use crate::protocol::Message;
8use crate::protocol::TryDeserializeMessage;
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use deadpool_redis::{Config as PoolConfig, Pool, Runtime};
12use futures::{Future, Stream};
13use log::{debug, error, warn};
14use redis::RedisError;
15use std::clone::Clone;
16use std::collections::HashSet;
17use std::fmt;
18use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
19use std::sync::Arc;
20use std::task::{Poll, Waker};
21use tokio::sync::mpsc::{channel, Receiver, Sender};
22use uuid::Uuid;
23
24#[cfg(test)]
25use std::any::Any;
26
/// Internal configuration accumulated by the builder before the broker is
/// constructed.
struct Config {
    /// Redis connection URL, e.g. `redis://localhost:6379/0`.
    broker_url: String,
    /// Maximum number of delivered-but-unacked tasks (0 means unlimited).
    prefetch_count: u16,
    /// Names of the queues the broker will consume from.
    queues: HashSet<String>,
    /// Kept for API parity with other brokers; has no effect on Redis.
    heartbeat: Option<u16>,
}
33
/// Builder that accumulates [`Config`] options and produces a `RedisBroker`.
pub struct RedisBrokerBuilder {
    config: Config,
}
37
38#[async_trait]
39impl BrokerBuilder for RedisBrokerBuilder {
40    /// Create a new `BrokerBuilder`.
41    fn new(broker_url: &str) -> Self {
42        RedisBrokerBuilder {
43            config: Config {
44                broker_url: broker_url.into(),
45                prefetch_count: 10,
46                queues: HashSet::new(),
47                heartbeat: Some(60),
48            },
49        }
50    }
51
52    /// Set the prefetch count.
53    fn prefetch_count(mut self: Box<Self>, prefetch_count: u16) -> Box<dyn BrokerBuilder> {
54        self.config.prefetch_count = prefetch_count;
55        self
56    }
57
58    /// Declare a queue.
59    fn declare_queue(mut self: Box<Self>, name: &str) -> Box<dyn BrokerBuilder> {
60        self.config.queues.insert(name.into());
61        self
62    }
63
64    /// Set the heartbeat.
65    fn heartbeat(mut self: Box<Self>, heartbeat: Option<u16>) -> Box<dyn BrokerBuilder> {
66        if heartbeat.is_some() {
67            warn!("Setting heartbeat on redis broker has no effect on anything");
68        }
69        self.config.heartbeat = heartbeat;
70        self
71    }
72
73    /// Construct the `Broker` with the given configuration using deadpool-redis.
74    async fn build(&self, _connection_timeout: u32) -> Result<Box<dyn Broker>, BrokerError> {
75        let mut queues: HashSet<String> = HashSet::new();
76        for queue_name in &self.config.queues {
77            queues.insert(queue_name.into());
78        }
79
80        log::info!("Creating deadpool-redis pool");
81        let pool_config = PoolConfig::from_url(&self.config.broker_url);
82        let pool = pool_config
83            .create_pool(Some(Runtime::Tokio1))
84            .map_err(|e| BrokerError::InvalidBrokerUrl(format!("Pool creation failed: {}", e)))?;
85
86        log::info!("Creating mpsc channel");
87        let (tx, rx) = channel(1);
88        log::info!("Creating broker with connection pool");
89        Ok(Box::new(RedisBroker {
90            uri: self.config.broker_url.clone(),
91            queues,
92            pool,
93            prefetch_count: Arc::new(AtomicU16::new(self.config.prefetch_count)),
94            pending_tasks: Arc::new(AtomicU16::new(0)),
95            waker_rx: tokio::sync::Mutex::new(rx),
96            waker_tx: tx,
97            is_closed: Arc::new(AtomicBool::new(false)),
98            delivery_info: DeliveryInfo::for_redis_default(),
99        }))
100    }
101}
102
/// Redis broker backed by a deadpool-redis connection pool, allowing
/// concurrent use from multiple tasks/threads.
pub struct RedisBroker {
    /// Original broker URL; kept for `safe_url` reporting.
    uri: String,
    /// Redis connection pool for multithreading
    pool: Pool,
    /// Names of the declared queues.
    queues: HashSet<String>,

    /// Need to keep track of prefetch count. Stored as an atomic for interior
    /// mutability (it is adjusted at runtime for ETA tasks).
    prefetch_count: Arc<AtomicU16>,
    /// Number of delivered-but-unacked tasks, compared against `prefetch_count`.
    pending_tasks: Arc<AtomicU16>,
    /// Receiving half of the consumer-waker channel, drained in `ack`.
    waker_rx: tokio::sync::Mutex<Receiver<Waker>>,
    /// Sending half handed to consumers so they can park their waker.
    waker_tx: Sender<Waker>,
    /// Track whether the broker is closed
    is_closed: Arc<AtomicBool>,

    /// Default delivery metadata attached to outgoing messages.
    delivery_info: DeliveryInfo,
}
121
/// Cheaply cloneable handle for talking to a single Redis queue through the
/// shared connection pool.
#[derive(Clone)]
pub struct Channel {
    pool: Pool,
    queue_name: String,
    delivery_info: DeliveryInfo,
}
128
129impl fmt::Debug for Channel {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        write!(f, "Channel {{ {} }}", self.queue_name)
132    }
133}
134
135impl Channel {
136    fn new(pool: Pool, queue_name: String, delivery_info: DeliveryInfo) -> Self {
137        Self {
138            pool,
139            queue_name,
140            delivery_info,
141        }
142    }
143
144    fn process_map_name(&self) -> String {
145        format!("_celery.{}_process_map", self.queue_name)
146    }
147
148    async fn fetch_task(
149        self,
150        send_waker: Option<(Sender<Waker>, Waker)>,
151    ) -> Result<Delivery, BrokerError> {
152        if let Some((sender, waker)) = send_waker {
153            sender.send(waker).await.unwrap();
154            futures::pending!();
155        }
156
157        let mut conn = self.pool.get().await.map_err(|e| {
158            BrokerError::IoError(std::io::Error::new(
159                std::io::ErrorKind::ConnectionRefused,
160                e,
161            ))
162        })?;
163
164        loop {
165            let rez: Result<Option<String>, RedisError> = redis::cmd("RPOP")
166                .arg(&self.queue_name)
167                .query_async(&mut *conn)
168                .await;
169            match rez {
170                Ok(None) => tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await,
171                Ok(Some(rez)) => {
172                    let delivery: Delivery = serde_json::from_str(&rez[..])?;
173                    debug!(
174                        "Received msg: {} / {}",
175                        delivery.properties.delivery_tag, delivery.headers.task
176                    );
177                    let _set_rez: u32 = redis::cmd("HSET")
178                        .arg(self.process_map_name())
179                        .arg(&delivery.properties.correlation_id)
180                        .arg(&rez)
181                        .query_async(&mut *conn)
182                        .await?;
183                    break Ok(delivery);
184                }
185                Err(err) => break Err(err.into()),
186            }
187        }
188    }
189
190    async fn send_task(self, message: &Message) -> Result<(), BrokerError> {
191        let mut conn = self.pool.get().await.map_err(|e| {
192            BrokerError::IoError(std::io::Error::new(
193                std::io::ErrorKind::ConnectionRefused,
194                e,
195            ))
196        })?;
197        Ok(redis::cmd("LPUSH")
198            .arg(&self.queue_name)
199            .arg(message.json_serialized(Some(self.delivery_info))?)
200            .query_async(&mut *conn)
201            .await?)
202    }
203
204    async fn resend_task(&self, delivery: &Delivery) -> Result<(), BrokerError> {
205        let mut message = delivery.clone().try_deserialize_message()?;
206        let retries = message.headers.retries.unwrap_or_default();
207        message.headers.retries = Some(retries + 1);
208        self.clone().send_task(&message).await?;
209        Ok(())
210    }
211
212    async fn remove_task(&self, delivery: &Delivery) -> Result<(), BrokerError> {
213        let mut conn = self.pool.get().await.map_err(|e| {
214            BrokerError::IoError(std::io::Error::new(
215                std::io::ErrorKind::ConnectionRefused,
216                e,
217            ))
218        })?;
219        redis::cmd("HDEL")
220            .arg(self.process_map_name())
221            .arg(&delivery.properties.correlation_id)
222            .query_async::<()>(&mut *conn)
223            .await?;
224        Ok(())
225    }
226}
227
/// Result of a single fetch attempt by a consumer.
type ConsumerOutput = Result<Delivery, BrokerError>;
/// Boxed fetch future stored by `Consumer` across `poll_next` calls.
type ConsumerOutputFuture = Box<dyn Future<Output = ConsumerOutput> + Send>;
230
/// A `Stream` of task deliveries from a single Redis queue, honoring the
/// broker's prefetch limit.
pub struct Consumer {
    /// Channel used to fetch tasks from the queue.
    channel: Channel,
    /// Called with any error from a failed fetch before the stream retries.
    error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
    /// In-progress fetch future, saved between `poll_next` calls so the same
    /// fetch is resumed rather than restarted.
    polled_pop: Option<std::pin::Pin<ConsumerOutputFuture>>,
    /// Shared count of delivered-but-unacked tasks.
    pending_tasks: Arc<AtomicU16>,
    /// Sending half of the broker's waker channel (`RedisBroker::ack` drains
    /// the receiving half).
    waker_tx: Sender<Waker>,
    /// Shared prefetch limit (0 means unlimited).
    prefetch_count: Arc<AtomicU16>,
}

impl DeliveryStream for Consumer {}
241
#[async_trait]
impl super::Delivery for (Channel, Delivery) {
    /// Put the task back on its queue; `Channel::resend_task` bumps the
    /// retry count. The broker reference and ETA are ignored here.
    async fn resend(
        &self,
        _broker: &dyn Broker,
        _eta: Option<DateTime<Utc>>,
    ) -> Result<(), BrokerError> {
        self.0.resend_task(&self.1).await?;
        Ok(())
    }

    /// Delete this delivery from the queue's in-flight process map.
    async fn remove(&self) -> Result<(), BrokerError> {
        self.0.remove_task(&self.1).await?;
        Ok(())
    }

    /// Unimplemented — panics if called. NOTE(review): `RedisBroker::ack`
    /// calls `remove()` directly, so this path looks unused for the Redis
    /// broker; confirm before depending on it.
    async fn ack(&self) -> Result<(), BrokerError> {
        todo!()
    }
}
262
263impl TryDeserializeMessage for (Channel, Delivery) {
264    fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
265        self.1.try_deserialize_message()
266    }
267}
268
269impl Stream for Consumer {
270    type Item = Result<Box<dyn super::Delivery>, Box<dyn DeliveryError>>;
271    fn poll_next(
272        mut self: std::pin::Pin<&mut Self>,
273        cx: &mut std::task::Context<'_>,
274    ) -> Poll<std::option::Option<<Self as futures::Stream>::Item>> {
275        // execute pipeline
276        // - get from queue
277        // - add delivery tag in processing unacked_index_key sortedlist
278        // - add delivery tag, msg in processing hashset unacked_key
279        if self.pending_tasks.load(Ordering::SeqCst) >= self.prefetch_count.load(Ordering::SeqCst)
280            && self.prefetch_count.load(Ordering::SeqCst) > 0
281        {
282            debug!("Pending tasks limit reached");
283            return Poll::Pending;
284        }
285        let mut polled_pop = if self.polled_pop.is_none() {
286            Box::pin(self.channel.clone().fetch_task(None))
287        } else {
288            self.polled_pop.take().unwrap()
289        };
290        if let Poll::Ready(item) = Future::poll(polled_pop.as_mut(), cx) {
291            match item {
292                Ok(item) => {
293                    self.pending_tasks.fetch_add(1, Ordering::SeqCst);
294                    Poll::Ready(Some(Ok(Box::new((self.channel.clone(), item)))))
295                }
296                Err(err) => {
297                    (self.error_handler)(err);
298                    cx.waker().wake_by_ref();
299                    Poll::Pending
300                }
301            }
302        } else {
303            self.polled_pop = Some(polled_pop);
304            Poll::Pending
305        }
306    }
307}
308
309#[async_trait]
310impl Broker for RedisBroker {
311    /// Consume messages from a queue.
312    ///
313    /// If the connection is successful, this should return a future stream of `Result`s where an `Ok`
314    /// value is a [`Self::Delivery`](trait.Broker.html#associatedtype.Delivery)
315    /// type that can be coerced into a [`Message`](protocol/struct.Message.html)
316    /// and an `Err` value is a
317    /// [`Self::DeliveryError`](trait.Broker.html#associatedtype.DeliveryError) type.
318    async fn consume(
319        &self,
320        queue: &str,
321        error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
322    ) -> Result<(String, Box<dyn DeliveryStream>), BrokerError> {
323        let consumer = Consumer {
324            channel: Channel {
325                pool: self.pool.clone(),
326                queue_name: queue.to_string(),
327                delivery_info: self.delivery_info.clone(),
328            },
329            error_handler,
330            polled_pop: None,
331            prefetch_count: Arc::clone(&self.prefetch_count),
332            pending_tasks: Arc::clone(&self.pending_tasks),
333            waker_tx: self.waker_tx.clone(),
334        };
335
336        // Create unique consumer tag.
337        let mut buffer = Uuid::encode_buffer();
338        let uuid = Uuid::new_v4().hyphenated().encode_lower(&mut buffer);
339        let consumer_tag = uuid.to_owned();
340
341        Ok((consumer_tag, Box::new(consumer)))
342    }
343
344    async fn cancel(&self, _consumer_tag: &str) -> Result<(), BrokerError> {
345        Ok(())
346    }
347
348    /// Acknowledge a [`Delivery`](trait.Broker.html#associatedtype.Delivery) for deletion.
349    async fn ack(&self, delivery: &dyn super::Delivery) -> Result<(), BrokerError> {
350        self.pending_tasks.fetch_sub(1, Ordering::SeqCst);
351        delivery.remove().await?;
352        let mut waker_rx = self.waker_rx.lock().await;
353        // work around for try_recv. We do not care if a waker is available after this check.
354        let dummy_waker = futures::task::noop_waker_ref();
355        let mut dummy_ctx = std::task::Context::from_waker(dummy_waker);
356        if let Poll::Ready(Some(waker)) = waker_rx.poll_recv(&mut dummy_ctx) {
357            waker.wake();
358        }
359        Ok(())
360    }
361
362    /// Retry a delivery.
363    async fn retry(
364        &self,
365        delivery: &dyn super::Delivery,
366        eta: Option<DateTime<Utc>>,
367    ) -> Result<(), BrokerError> {
368        delivery.resend(self, eta).await?;
369        // self.ack(delivery).await?;
370        Ok(())
371    }
372
373    /// Send a [`Message`](protocol/struct.Message.html) into a queue.
374    async fn send(&self, message: &Message, queue: &str) -> Result<(), BrokerError> {
375        // Check if the broker is closed
376        if self.is_closed.load(Ordering::SeqCst) {
377            return Err(BrokerError::NotConnected);
378        }
379
380        Channel::new(
381            self.pool.clone(),
382            queue.to_string(),
383            self.delivery_info.clone(),
384        )
385        .send_task(message)
386        .await?;
387        Ok(())
388    }
389
390    /// Increase the `prefetch_count`. This has to be done when a task with a future
391    /// ETA is consumed.
392    async fn increase_prefetch_count(&self) -> Result<(), BrokerError> {
393        self.prefetch_count.fetch_add(1, Ordering::SeqCst);
394        Ok(())
395    }
396
397    /// Decrease the `prefetch_count`. This has to be done after a task with a future
398    /// ETA is executed.
399    async fn decrease_prefetch_count(&self) -> Result<(), BrokerError> {
400        self.prefetch_count.fetch_sub(1, Ordering::SeqCst);
401        Ok(())
402    }
403
404    /// Close connection pool.
405    async fn close(&self) -> Result<(), BrokerError> {
406        // Mark the broker as closed
407        self.is_closed.store(true, Ordering::SeqCst);
408        Ok(())
409    }
410
411    fn safe_url(&self) -> String {
412        let parsed_url = redis::parse_redis_url(&self.uri[..]);
413        match parsed_url {
414            Some(url) => format!(
415                "{}://{}:***@{}:{}/{}",
416                url.scheme(),
417                url.username(),
418                url.host_str().unwrap(),
419                url.port().unwrap(),
420                url.path(),
421            ),
422            None => {
423                error!("Invalid redis url.");
424                String::from("")
425            }
426        }
427    }
428
429    async fn reconnect(&self, _connection_timeout: u32) -> Result<(), BrokerError> {
430        // With deadpool-redis, reconnection is handled automatically by the pool
431        // Just test if we can get a connection
432        let mut conn = self.pool.get().await.map_err(|e| {
433            BrokerError::IoError(std::io::Error::new(
434                std::io::ErrorKind::ConnectionRefused,
435                e,
436            ))
437        })?;
438
439        // Test the connection with a PING
440        let _: String = redis::cmd("PING").query_async(&mut *conn).await?;
441
442        // Mark the broker as reconnected
443        self.is_closed.store(false, Ordering::SeqCst);
444        Ok(())
445    }
446
447    #[cfg(test)]
448    fn into_any(self: Box<Self>) -> Box<dyn Any> {
449        self
450    }
451}