celery/broker/
amqp.rs

1//! AMQP broker.
2
3use async_trait::async_trait;
4use chrono::{DateTime, SecondsFormat, Utc};
5use futures::Stream;
6use lapin::message::Delivery;
7use lapin::options::{
8    BasicAckOptions, BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, BasicQosOptions,
9    QueueDeclareOptions,
10};
11use lapin::types::{AMQPValue, FieldArray, FieldTable};
12use lapin::uri::{self, AMQPUri};
13use lapin::{BasicProperties, Channel, Connection, ConnectionProperties, Queue};
14use log::debug;
15use std::collections::HashMap;
16use std::str::FromStr;
17use std::task::Poll;
18use tokio::sync::{Mutex, RwLock};
19
20use super::{Broker, BrokerBuilder, DeliveryError, DeliveryStream};
21use crate::error::{BrokerError, ProtocolError};
22use crate::protocol::{Message, MessageHeaders, MessageProperties, TryDeserializeMessage};
23use tokio_executor_trait::Tokio as TokioExecutor;
24
25#[cfg(test)]
26use std::any::Any;
27
28struct Consumer {
29    wrapped: lapin::Consumer,
30}
31impl DeliveryStream for Consumer {}
32impl DeliveryError for lapin::Error {}
33
34#[async_trait]
35impl super::Delivery for Delivery {
36    async fn resend(
37        &self,
38        broker: &dyn Broker,
39        eta: Option<DateTime<Utc>>,
40    ) -> Result<(), BrokerError> {
41        let mut message = self.try_deserialize_message()?;
42        message.headers.eta = eta;
43        // Increment the number of retries.
44        message.headers.retries = Some(message.headers.retries.map_or(1, |retry| retry + 1));
45        broker.send(&message, self.routing_key.as_str()).await
46    }
47    async fn remove(&self) -> Result<(), BrokerError> {
48        todo!()
49    }
50    async fn ack(&self) -> Result<(), BrokerError> {
51        lapin::acker::Acker::ack(self, BasicAckOptions::default()).await?;
52        Ok(())
53    }
54}
55
56impl Stream for Consumer {
57    type Item = Result<Box<dyn super::Delivery>, Box<dyn DeliveryError>>;
58
59    fn poll_next(
60        mut self: std::pin::Pin<&mut Self>,
61        cx: &mut std::task::Context<'_>,
62    ) -> std::task::Poll<std::option::Option<<Self as futures::Stream>::Item>> {
63        use futures_lite::stream::StreamExt;
64
65        if let Poll::Ready(ret) = self.wrapped.poll_next(cx) {
66            if let Some(result) = ret {
67                match result {
68                    Ok(x) => Poll::Ready(Some(Ok(Box::new(x)))),
69                    Err(x) => Poll::Ready(Some(Err(Box::new(x)))),
70                }
71            } else {
72                Poll::Ready(None)
73            }
74        } else {
75            Poll::Pending
76        }
77    }
78}
79
80struct Config {
81    broker_url: String,
82    prefetch_count: u16,
83    queues: HashMap<String, QueueDeclareOptions>,
84    heartbeat: Option<u16>,
85}
86
87/// Builds an [`AMQPBroker`] with a custom configuration.
88pub struct AMQPBrokerBuilder {
89    config: Config,
90}
91
92fn create_base_connection_properties() -> ConnectionProperties {
93    // See https://github.com/amqp-rs/reactor-trait/issues/1#issuecomment-1033473197
94    ConnectionProperties::default().with_executor(TokioExecutor::current())
95}
96
97#[cfg(unix)]
98fn create_connection_properties() -> ConnectionProperties {
99    create_base_connection_properties().with_reactor(tokio_reactor_trait::Tokio)
100}
101#[cfg(windows)]
102fn create_connection_properties() -> ConnectionProperties {
103    create_base_connection_properties()
104}
105
106#[async_trait]
107impl BrokerBuilder for AMQPBrokerBuilder {
108    /// Create a new `AMQPBrokerBuilder`.
109    fn new(broker_url: &str) -> Self {
110        Self {
111            config: Config {
112                broker_url: broker_url.into(),
113                prefetch_count: 10,
114                queues: HashMap::new(),
115                heartbeat: Some(60),
116            },
117        }
118    }
119
120    /// Set the worker [prefetch
121    /// count](https://www.rabbitmq.com/confirms.html#channel-qos-prefetch).
122    fn prefetch_count(mut self: Box<Self>, prefetch_count: u16) -> Box<dyn BrokerBuilder> {
123        self.config.prefetch_count = prefetch_count;
124        self
125    }
126
127    /// Declare a queue.
128    fn declare_queue(mut self: Box<Self>, name: &str) -> Box<dyn BrokerBuilder> {
129        self.config.queues.insert(
130            name.into(),
131            QueueDeclareOptions {
132                passive: false,
133                durable: true,
134                exclusive: false,
135                auto_delete: false,
136                nowait: false,
137            },
138        );
139        self
140    }
141
142    /// Set the heartbeat.
143    fn heartbeat(mut self: Box<Self>, heartbeat: Option<u16>) -> Box<dyn BrokerBuilder> {
144        self.config.heartbeat = heartbeat;
145        self
146    }
147
148    /// Build an `AMQPBroker`.
149    async fn build(&self, connection_timeout: u32) -> Result<Box<dyn Broker>, BrokerError> {
150        let mut uri = AMQPUri::from_str(&self.config.broker_url)
151            .map_err(|_| BrokerError::InvalidBrokerUrl(self.config.broker_url.clone()))?;
152        uri.query.heartbeat = self.config.heartbeat;
153        uri.query.connection_timeout = Some((connection_timeout as u64) * 1000);
154
155        let conn = Connection::connect_uri(uri.clone(), create_connection_properties()).await?;
156
157        let consume_channel = conn.create_channel().await?;
158        let produce_channel = conn.create_channel().await?;
159
160        let mut queues: HashMap<String, Queue> = HashMap::new();
161        for (queue_name, queue_options) in &self.config.queues {
162            let queue = consume_channel
163                .queue_declare(queue_name, *queue_options, FieldTable::default())
164                .await?;
165            queues.insert(queue_name.into(), queue);
166        }
167
168        let broker = AMQPBroker {
169            uri,
170            conn: Mutex::new(conn),
171            consume_channel: RwLock::new(consume_channel),
172            produce_channel: RwLock::new(produce_channel),
173            queues: RwLock::new(queues),
174            queue_declare_options: self.config.queues.clone(),
175            prefetch_count: Mutex::new(self.config.prefetch_count),
176        };
177        broker
178            .set_prefetch_count(self.config.prefetch_count)
179            .await?;
180        Ok(Box::new(broker))
181    }
182}
183
/// An AMQP broker.
///
/// It holds one connection with two channels. The consume channel is used for consuming,
/// declaring queues, and setting QoS. The produce channel is used for publishing. Mutable
/// state sits behind tokio locks so every operation works through `&self`.
pub struct AMQPBroker {
    /// Parsed broker URI, including the heartbeat and connection-timeout query values
    /// set by the builder. Reused by `safe_url` and as the template for `reconnect`.
    uri: AMQPUri,

    /// Broker connection.
    ///
    /// This is only wrapped in a Mutex for interior mutability (`reconnect` replaces it).
    conn: Mutex<Connection>,

    /// Channel to consume messages from.
    ///
    /// Also used to declare queues and set the prefetch QoS. `cancel`, `close` and
    /// `reconnect` take the write lock; consuming and QoS changes take the read lock.
    consume_channel: RwLock<Channel>,

    /// Channel to produce messages from.
    ///
    /// This is only wrapped in RwLock for interior mutability.
    produce_channel: RwLock<Channel>,

    /// Mapping of queue name to Queue struct.
    ///
    /// This is only wrapped in RwLock for interior mutability.
    queues: RwLock<HashMap<String, Queue>>,

    /// Options each configured queue was declared with. `reconnect` replays these to
    /// re-declare the queues on the new consume channel.
    queue_declare_options: HashMap<String, QueueDeclareOptions>,

    /// Need to keep track of prefetch count. We put this behind a mutex to get interior
    /// mutability.
    prefetch_count: Mutex<u16>,
}
212
213impl AMQPBroker {
214    async fn set_prefetch_count(&self, prefetch_count: u16) -> Result<(), BrokerError> {
215        debug!("Setting prefetch count to {}", prefetch_count);
216        self.consume_channel
217            .read()
218            .await
219            .basic_qos(prefetch_count, BasicQosOptions { global: true })
220            .await?;
221        Ok(())
222    }
223}
224
225#[async_trait]
226impl Broker for AMQPBroker {
227    fn safe_url(&self) -> String {
228        format!(
229            "{}://{}:***@{}:{}/{}",
230            match self.uri.scheme {
231                uri::AMQPScheme::AMQP => "amqp",
232                _ => "amqps",
233            },
234            self.uri.authority.userinfo.username,
235            self.uri.authority.host,
236            self.uri.authority.port,
237            self.uri.vhost,
238        )
239    }
240
241    async fn consume(
242        &self,
243        queue: &str,
244        error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
245    ) -> Result<(String, Box<dyn DeliveryStream>), BrokerError> {
246        self.conn
247            .lock()
248            .await
249            .on_error(move |e| error_handler(BrokerError::from(e)));
250        let queues = self.queues.read().await;
251        let queue = queues
252            .get(queue)
253            .ok_or_else::<BrokerError, _>(|| BrokerError::UnknownQueue(queue.into()))?;
254        let consumer = Consumer {
255            wrapped: self
256                .consume_channel
257                .read()
258                .await
259                .basic_consume(
260                    queue.name().as_str(),
261                    "",
262                    BasicConsumeOptions::default(),
263                    FieldTable::default(),
264                )
265                .await?,
266        };
267        Ok((consumer.wrapped.tag().to_string(), Box::new(consumer)))
268    }
269
270    async fn cancel(&self, consumer_tag: &str) -> Result<(), BrokerError> {
271        let consume_channel = self.consume_channel.write().await;
272        consume_channel
273            .basic_cancel(consumer_tag, BasicCancelOptions::default())
274            .await?;
275        Ok(())
276    }
277
278    async fn ack(&self, delivery: &dyn super::Delivery) -> Result<(), BrokerError> {
279        delivery.ack().await
280    }
281
282    async fn retry(
283        &self,
284        delivery: &dyn super::Delivery,
285        eta: Option<DateTime<Utc>>,
286    ) -> Result<(), BrokerError> {
287        delivery.resend(self, eta).await?;
288        Ok(())
289    }
290
291    async fn send(&self, message: &Message, queue: &str) -> Result<(), BrokerError> {
292        let properties = message.delivery_properties();
293        debug!("Sending AMQP message with: {:?}", properties);
294        self.produce_channel
295            .read()
296            .await
297            .basic_publish(
298                "",
299                queue,
300                BasicPublishOptions::default(),
301                &message.raw_body.clone()[..],
302                properties,
303            )
304            .await?;
305        Ok(())
306    }
307
308    async fn increase_prefetch_count(&self) -> Result<(), BrokerError> {
309        let new_count = {
310            let mut prefetch_count = self.prefetch_count.lock().await;
311            if *prefetch_count < std::u16::MAX {
312                let new_count = *prefetch_count + 1;
313                *prefetch_count = new_count;
314                new_count
315            } else {
316                std::u16::MAX
317            }
318        };
319        self.set_prefetch_count(new_count).await?;
320        Ok(())
321    }
322
323    async fn decrease_prefetch_count(&self) -> Result<(), BrokerError> {
324        let new_count = {
325            let mut prefetch_count = self.prefetch_count.lock().await;
326            if *prefetch_count > 1 {
327                let new_count = *prefetch_count - 1;
328                *prefetch_count = new_count;
329                new_count
330            } else {
331                0u16
332            }
333        };
334        if new_count > 0 {
335            self.set_prefetch_count(new_count).await?;
336        }
337        Ok(())
338    }
339
340    async fn close(&self) -> Result<(), BrokerError> {
341        let consume_channel = self.consume_channel.write().await;
342        let produce_channel = self.produce_channel.write().await;
343        let conn = self.conn.lock().await;
344
345        if consume_channel.status().connected() {
346            debug!("Closing consumer channel...");
347            consume_channel.close(200, "OK").await?;
348        }
349
350        if produce_channel.status().connected() {
351            debug!("Closing producer channel...");
352            produce_channel.close(200, "OK").await?;
353        }
354
355        if conn.status().connected() {
356            debug!("Closing connection...");
357            conn.close(200, "OK").await?;
358        }
359
360        Ok(())
361    }
362
363    /// Try reconnecting in the event of some sort of connection error.
364    async fn reconnect(&self, connection_timeout: u32) -> Result<(), BrokerError> {
365        let mut conn = self.conn.lock().await;
366        if !conn.status().connected() {
367            debug!("Attempting to reconnect to broker");
368            let mut uri = self.uri.clone();
369            uri.query.connection_timeout = Some(connection_timeout as u64);
370            *conn = Connection::connect_uri(uri, create_connection_properties()).await?;
371
372            let mut consume_channel = self.consume_channel.write().await;
373            let mut produce_channel = self.produce_channel.write().await;
374            let mut queues = self.queues.write().await;
375
376            *consume_channel = conn.create_channel().await?;
377            *produce_channel = conn.create_channel().await?;
378
379            queues.clear();
380            for (queue_name, queue_options) in &self.queue_declare_options {
381                let queue = consume_channel
382                    .queue_declare(queue_name, *queue_options, FieldTable::default())
383                    .await?;
384                queues.insert(queue_name.into(), queue);
385            }
386        }
387
388        Ok(())
389    }
390
391    #[cfg(test)]
392    fn into_any(self: Box<Self>) -> Box<dyn Any> {
393        self
394    }
395}
396
397impl Message {
398    fn delivery_properties(&self) -> BasicProperties {
399        let mut properties = BasicProperties::default()
400            .with_correlation_id(self.properties.correlation_id.clone().into())
401            .with_content_type(self.properties.content_type.clone().into())
402            .with_content_encoding(self.properties.content_encoding.clone().into())
403            .with_headers(self.delivery_headers())
404            .with_priority(0)
405            .with_delivery_mode(2);
406        if let Some(ref reply_to) = self.properties.reply_to {
407            properties = properties.with_reply_to(reply_to.clone().into());
408        }
409        properties
410    }
411
412    fn delivery_headers(&self) -> FieldTable {
413        let mut headers = FieldTable::default();
414        headers.insert(
415            "id".into(),
416            AMQPValue::LongString(self.headers.id.clone().into()),
417        );
418        headers.insert(
419            "task".into(),
420            AMQPValue::LongString(self.headers.task.clone().into()),
421        );
422        if let Some(ref lang) = self.headers.lang {
423            headers.insert("lang".into(), AMQPValue::LongString(lang.clone().into()));
424        }
425        if let Some(ref root_id) = self.headers.root_id {
426            headers.insert(
427                "root_id".into(),
428                AMQPValue::LongString(root_id.clone().into()),
429            );
430        }
431        if let Some(ref parent_id) = self.headers.parent_id {
432            headers.insert(
433                "parent_id".into(),
434                AMQPValue::LongString(parent_id.clone().into()),
435            );
436        }
437        if let Some(ref group) = self.headers.group {
438            headers.insert("group".into(), AMQPValue::LongString(group.clone().into()));
439        }
440        if let Some(ref meth) = self.headers.meth {
441            headers.insert("meth".into(), AMQPValue::LongString(meth.clone().into()));
442        }
443        if let Some(ref shadow) = self.headers.shadow {
444            headers.insert(
445                "shadow".into(),
446                AMQPValue::LongString(shadow.clone().into()),
447            );
448        }
449        if let Some(ref eta) = self.headers.eta {
450            headers.insert(
451                "eta".into(),
452                AMQPValue::LongString(eta.to_rfc3339_opts(SecondsFormat::Millis, false).into()),
453            );
454        }
455        if let Some(ref expires) = self.headers.expires {
456            headers.insert(
457                "expires".into(),
458                AMQPValue::LongString(expires.to_rfc3339_opts(SecondsFormat::Millis, false).into()),
459            );
460        }
461        if let Some(retries) = self.headers.retries {
462            headers.insert("retries".into(), AMQPValue::LongUInt(retries));
463        }
464        let mut timelimit = FieldArray::default();
465        if let Some(t) = self.headers.timelimit.0 {
466            timelimit.push(AMQPValue::LongUInt(t));
467        } else {
468            timelimit.push(AMQPValue::Void);
469        }
470        if let Some(t) = self.headers.timelimit.1 {
471            timelimit.push(AMQPValue::LongUInt(t));
472        } else {
473            timelimit.push(AMQPValue::Void);
474        }
475        headers.insert("timelimit".into(), AMQPValue::FieldArray(timelimit));
476        if let Some(ref argsrepr) = self.headers.argsrepr {
477            headers.insert(
478                "argsrepr".into(),
479                AMQPValue::LongString(argsrepr.clone().into()),
480            );
481        }
482        if let Some(ref kwargsrepr) = self.headers.kwargsrepr {
483            headers.insert(
484                "kwargsrepr".into(),
485                AMQPValue::LongString(kwargsrepr.clone().into()),
486            );
487        }
488        if let Some(ref origin) = self.headers.origin {
489            headers.insert(
490                "origin".into(),
491                AMQPValue::LongString(origin.clone().into()),
492            );
493        }
494        headers
495    }
496}
497
498impl TryDeserializeMessage for (Channel, Delivery) {
499    fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
500        self.1.try_deserialize_message()
501    }
502}
503
504impl TryDeserializeMessage for Delivery {
505    fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
506        let headers = self
507            .properties
508            .headers()
509            .as_ref()
510            .ok_or(ProtocolError::MissingHeaders)?;
511        Ok(Message {
512            properties: MessageProperties {
513                correlation_id: self
514                    .properties
515                    .correlation_id()
516                    .as_ref()
517                    .map(|v| v.to_string())
518                    .ok_or_else(|| {
519                        ProtocolError::MissingRequiredProperty("correlation_id".into())
520                    })?,
521                content_type: self
522                    .properties
523                    .content_type()
524                    .as_ref()
525                    .map(|v| v.to_string())
526                    .ok_or_else(|| ProtocolError::MissingRequiredProperty("content_type".into()))?,
527                content_encoding: self
528                    .properties
529                    .content_encoding()
530                    .as_ref()
531                    .map(|v| v.to_string())
532                    .ok_or_else(|| {
533                        ProtocolError::MissingRequiredProperty("content_encoding".into())
534                    })?,
535                reply_to: self.properties.reply_to().as_ref().map(|v| v.to_string()),
536            },
537            headers: MessageHeaders {
538                id: get_header_str_required(headers, "id")?,
539                task: get_header_str_required(headers, "task")?,
540                lang: get_header_str(headers, "lang"),
541                root_id: get_header_str(headers, "root_id"),
542                parent_id: get_header_str(headers, "parent_id"),
543                group: get_header_str(headers, "group"),
544                meth: get_header_str(headers, "meth"),
545                shadow: get_header_str(headers, "shadow"),
546                eta: get_header_dt(headers, "eta"),
547                expires: get_header_dt(headers, "expires"),
548                retries: get_header_u32(headers, "retries"),
549                timelimit: headers
550                    .inner()
551                    .get("timelimit")
552                    .and_then(|v| match v {
553                        AMQPValue::FieldArray(a) => {
554                            let a = a.as_slice().to_vec();
555                            if a.len() == 2 {
556                                let soft = amqp_value_to_u32(&a[0]);
557                                let hard = amqp_value_to_u32(&a[1]);
558                                Some((soft, hard))
559                            } else {
560                                None
561                            }
562                        }
563                        _ => None,
564                    })
565                    .unwrap_or((None, None)),
566                argsrepr: get_header_str(headers, "argsrepr"),
567                kwargsrepr: get_header_str(headers, "kwargsrepr"),
568                origin: get_header_str(headers, "origin"),
569            },
570            raw_body: self.data.clone(),
571        })
572    }
573}
574
575fn get_header_str(headers: &FieldTable, key: &str) -> Option<String> {
576    headers.inner().get(key).and_then(|v| match v {
577        AMQPValue::ShortString(s) => Some(s.to_string()),
578        AMQPValue::LongString(s) => Some(s.to_string()),
579        _ => None,
580    })
581}
582
583fn get_header_str_required(headers: &FieldTable, key: &str) -> Result<String, ProtocolError> {
584    get_header_str(headers, key).ok_or_else(|| ProtocolError::MissingRequiredHeader(key.into()))
585}
586
587fn get_header_dt(headers: &FieldTable, key: &str) -> Option<DateTime<Utc>> {
588    if let Some(s) = get_header_str(headers, key) {
589        match DateTime::parse_from_rfc3339(&s) {
590            Ok(dt) => Some(DateTime::<Utc>::from(dt)),
591            _ => None,
592        }
593    } else {
594        None
595    }
596}
597
598fn get_header_u32(headers: &FieldTable, key: &str) -> Option<u32> {
599    headers.inner().get(key).and_then(amqp_value_to_u32)
600}
601
602fn amqp_value_to_u32(v: &AMQPValue) -> Option<u32> {
603    match v {
604        AMQPValue::ShortShortInt(n) => Some(*n as u32),
605        AMQPValue::ShortShortUInt(n) => Some(*n as u32),
606        AMQPValue::ShortInt(n) => Some(*n as u32),
607        AMQPValue::ShortUInt(n) => Some(*n as u32),
608        AMQPValue::LongInt(n) => Some(*n as u32),
609        AMQPValue::LongUInt(n) => Some(*n),
610        AMQPValue::LongLongInt(n) => Some(*n as u32),
611        _ => None,
612    }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use lapin::types::ShortString;
    use std::time::SystemTime;

    /// Round-trips a fully populated `Message` through `delivery_properties` and
    /// `try_deserialize_message` and checks that every field survives.
    #[test]
    fn test_conversion() {
        // Timestamps are serialized with millisecond precision, so truncate `now` the
        // same way up front; otherwise the round-tripped value would not compare equal.
        let now = {
            let exact = DateTime::<Utc>::from(SystemTime::now());
            let millis = exact.to_rfc3339_opts(SecondsFormat::Millis, false);
            DateTime::<Utc>::from(DateTime::parse_from_rfc3339(&millis).unwrap())
        };

        let properties = MessageProperties {
            correlation_id: "aaa".into(),
            content_type: "application/json".into(),
            content_encoding: "utf-8".into(),
            reply_to: Some("bbb".into()),
        };
        let headers = MessageHeaders {
            id: "aaa".into(),
            task: "add".into(),
            lang: Some("rust".into()),
            root_id: Some("aaa".into()),
            parent_id: Some("000".into()),
            group: Some("A".into()),
            meth: Some("method_name".into()),
            shadow: Some("add-these".into()),
            eta: Some(now),
            expires: Some(now),
            retries: Some(1),
            timelimit: (Some(30), Some(60)),
            argsrepr: Some("(1)".into()),
            kwargsrepr: Some("{'y': 2}".into()),
            origin: Some("gen123@piper".into()),
        };
        let message = Message {
            properties,
            headers,
            raw_body: vec![],
        };

        let delivery = Delivery {
            delivery_tag: 0,
            exchange: ShortString::from(""),
            routing_key: ShortString::from("celery"),
            redelivered: false,
            properties: message.delivery_properties(),
            data: vec![],
            acker: Default::default(),
        };

        let roundtripped = delivery.try_deserialize_message();
        assert!(roundtripped.is_ok());
        assert_eq!(message, roundtripped.unwrap());
    }
}