1use async_trait::async_trait;
4use chrono::{DateTime, SecondsFormat, Utc};
5use futures::Stream;
6use lapin::message::Delivery;
7use lapin::options::{
8 BasicAckOptions, BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, BasicQosOptions,
9 QueueDeclareOptions,
10};
11use lapin::types::{AMQPValue, FieldArray, FieldTable};
12use lapin::uri::{self, AMQPUri};
13use lapin::{BasicProperties, Channel, Connection, ConnectionProperties, Queue};
14use log::debug;
15use std::collections::HashMap;
16use std::str::FromStr;
17use std::task::Poll;
18use tokio::sync::{Mutex, RwLock};
19
20use super::{Broker, BrokerBuilder, DeliveryError, DeliveryStream};
21use crate::error::{BrokerError, ProtocolError};
22use crate::protocol::{Message, MessageHeaders, MessageProperties, TryDeserializeMessage};
23use tokio_executor_trait::Tokio as TokioExecutor;
24
25#[cfg(test)]
26use std::any::Any;
27
28struct Consumer {
29 wrapped: lapin::Consumer,
30}
31impl DeliveryStream for Consumer {}
32impl DeliveryError for lapin::Error {}
33
34#[async_trait]
35impl super::Delivery for Delivery {
36 async fn resend(
37 &self,
38 broker: &dyn Broker,
39 eta: Option<DateTime<Utc>>,
40 ) -> Result<(), BrokerError> {
41 let mut message = self.try_deserialize_message()?;
42 message.headers.eta = eta;
43 message.headers.retries = Some(message.headers.retries.map_or(1, |retry| retry + 1));
45 broker.send(&message, self.routing_key.as_str()).await
46 }
47 async fn remove(&self) -> Result<(), BrokerError> {
48 todo!()
49 }
50 async fn ack(&self) -> Result<(), BrokerError> {
51 lapin::acker::Acker::ack(self, BasicAckOptions::default()).await?;
52 Ok(())
53 }
54}
55
56impl Stream for Consumer {
57 type Item = Result<Box<dyn super::Delivery>, Box<dyn DeliveryError>>;
58
59 fn poll_next(
60 mut self: std::pin::Pin<&mut Self>,
61 cx: &mut std::task::Context<'_>,
62 ) -> std::task::Poll<std::option::Option<<Self as futures::Stream>::Item>> {
63 use futures_lite::stream::StreamExt;
64
65 if let Poll::Ready(ret) = self.wrapped.poll_next(cx) {
66 if let Some(result) = ret {
67 match result {
68 Ok(x) => Poll::Ready(Some(Ok(Box::new(x)))),
69 Err(x) => Poll::Ready(Some(Err(Box::new(x)))),
70 }
71 } else {
72 Poll::Ready(None)
73 }
74 } else {
75 Poll::Pending
76 }
77 }
78}
79
80struct Config {
81 broker_url: String,
82 prefetch_count: u16,
83 queues: HashMap<String, QueueDeclareOptions>,
84 heartbeat: Option<u16>,
85}
86
87pub struct AMQPBrokerBuilder {
89 config: Config,
90}
91
92fn create_base_connection_properties() -> ConnectionProperties {
93 ConnectionProperties::default().with_executor(TokioExecutor::current())
95}
96
97#[cfg(unix)]
98fn create_connection_properties() -> ConnectionProperties {
99 create_base_connection_properties().with_reactor(tokio_reactor_trait::Tokio)
100}
101#[cfg(windows)]
102fn create_connection_properties() -> ConnectionProperties {
103 create_base_connection_properties()
104}
105
106#[async_trait]
107impl BrokerBuilder for AMQPBrokerBuilder {
108 fn new(broker_url: &str) -> Self {
110 Self {
111 config: Config {
112 broker_url: broker_url.into(),
113 prefetch_count: 10,
114 queues: HashMap::new(),
115 heartbeat: Some(60),
116 },
117 }
118 }
119
120 fn prefetch_count(mut self: Box<Self>, prefetch_count: u16) -> Box<dyn BrokerBuilder> {
123 self.config.prefetch_count = prefetch_count;
124 self
125 }
126
127 fn declare_queue(mut self: Box<Self>, name: &str) -> Box<dyn BrokerBuilder> {
129 self.config.queues.insert(
130 name.into(),
131 QueueDeclareOptions {
132 passive: false,
133 durable: true,
134 exclusive: false,
135 auto_delete: false,
136 nowait: false,
137 },
138 );
139 self
140 }
141
142 fn heartbeat(mut self: Box<Self>, heartbeat: Option<u16>) -> Box<dyn BrokerBuilder> {
144 self.config.heartbeat = heartbeat;
145 self
146 }
147
148 async fn build(&self, connection_timeout: u32) -> Result<Box<dyn Broker>, BrokerError> {
150 let mut uri = AMQPUri::from_str(&self.config.broker_url)
151 .map_err(|_| BrokerError::InvalidBrokerUrl(self.config.broker_url.clone()))?;
152 uri.query.heartbeat = self.config.heartbeat;
153 uri.query.connection_timeout = Some((connection_timeout as u64) * 1000);
154
155 let conn = Connection::connect_uri(uri.clone(), create_connection_properties()).await?;
156
157 let consume_channel = conn.create_channel().await?;
158 let produce_channel = conn.create_channel().await?;
159
160 let mut queues: HashMap<String, Queue> = HashMap::new();
161 for (queue_name, queue_options) in &self.config.queues {
162 let queue = consume_channel
163 .queue_declare(queue_name, *queue_options, FieldTable::default())
164 .await?;
165 queues.insert(queue_name.into(), queue);
166 }
167
168 let broker = AMQPBroker {
169 uri,
170 conn: Mutex::new(conn),
171 consume_channel: RwLock::new(consume_channel),
172 produce_channel: RwLock::new(produce_channel),
173 queues: RwLock::new(queues),
174 queue_declare_options: self.config.queues.clone(),
175 prefetch_count: Mutex::new(self.config.prefetch_count),
176 };
177 broker
178 .set_prefetch_count(self.config.prefetch_count)
179 .await?;
180 Ok(Box::new(broker))
181 }
182}
183
184pub struct AMQPBroker {
186 uri: AMQPUri,
187
188 conn: Mutex<Connection>,
192
193 consume_channel: RwLock<Channel>,
195
196 produce_channel: RwLock<Channel>,
200
201 queues: RwLock<HashMap<String, Queue>>,
205
206 queue_declare_options: HashMap<String, QueueDeclareOptions>,
207
208 prefetch_count: Mutex<u16>,
211}
212
213impl AMQPBroker {
214 async fn set_prefetch_count(&self, prefetch_count: u16) -> Result<(), BrokerError> {
215 debug!("Setting prefetch count to {}", prefetch_count);
216 self.consume_channel
217 .read()
218 .await
219 .basic_qos(prefetch_count, BasicQosOptions { global: true })
220 .await?;
221 Ok(())
222 }
223}
224
225#[async_trait]
226impl Broker for AMQPBroker {
227 fn safe_url(&self) -> String {
228 format!(
229 "{}://{}:***@{}:{}/{}",
230 match self.uri.scheme {
231 uri::AMQPScheme::AMQP => "amqp",
232 _ => "amqps",
233 },
234 self.uri.authority.userinfo.username,
235 self.uri.authority.host,
236 self.uri.authority.port,
237 self.uri.vhost,
238 )
239 }
240
241 async fn consume(
242 &self,
243 queue: &str,
244 error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
245 ) -> Result<(String, Box<dyn DeliveryStream>), BrokerError> {
246 self.conn
247 .lock()
248 .await
249 .on_error(move |e| error_handler(BrokerError::from(e)));
250 let queues = self.queues.read().await;
251 let queue = queues
252 .get(queue)
253 .ok_or_else::<BrokerError, _>(|| BrokerError::UnknownQueue(queue.into()))?;
254 let consumer = Consumer {
255 wrapped: self
256 .consume_channel
257 .read()
258 .await
259 .basic_consume(
260 queue.name().as_str(),
261 "",
262 BasicConsumeOptions::default(),
263 FieldTable::default(),
264 )
265 .await?,
266 };
267 Ok((consumer.wrapped.tag().to_string(), Box::new(consumer)))
268 }
269
270 async fn cancel(&self, consumer_tag: &str) -> Result<(), BrokerError> {
271 let consume_channel = self.consume_channel.write().await;
272 consume_channel
273 .basic_cancel(consumer_tag, BasicCancelOptions::default())
274 .await?;
275 Ok(())
276 }
277
278 async fn ack(&self, delivery: &dyn super::Delivery) -> Result<(), BrokerError> {
279 delivery.ack().await
280 }
281
282 async fn retry(
283 &self,
284 delivery: &dyn super::Delivery,
285 eta: Option<DateTime<Utc>>,
286 ) -> Result<(), BrokerError> {
287 delivery.resend(self, eta).await?;
288 Ok(())
289 }
290
291 async fn send(&self, message: &Message, queue: &str) -> Result<(), BrokerError> {
292 let properties = message.delivery_properties();
293 debug!("Sending AMQP message with: {:?}", properties);
294 self.produce_channel
295 .read()
296 .await
297 .basic_publish(
298 "",
299 queue,
300 BasicPublishOptions::default(),
301 &message.raw_body.clone()[..],
302 properties,
303 )
304 .await?;
305 Ok(())
306 }
307
308 async fn increase_prefetch_count(&self) -> Result<(), BrokerError> {
309 let new_count = {
310 let mut prefetch_count = self.prefetch_count.lock().await;
311 if *prefetch_count < std::u16::MAX {
312 let new_count = *prefetch_count + 1;
313 *prefetch_count = new_count;
314 new_count
315 } else {
316 std::u16::MAX
317 }
318 };
319 self.set_prefetch_count(new_count).await?;
320 Ok(())
321 }
322
323 async fn decrease_prefetch_count(&self) -> Result<(), BrokerError> {
324 let new_count = {
325 let mut prefetch_count = self.prefetch_count.lock().await;
326 if *prefetch_count > 1 {
327 let new_count = *prefetch_count - 1;
328 *prefetch_count = new_count;
329 new_count
330 } else {
331 0u16
332 }
333 };
334 if new_count > 0 {
335 self.set_prefetch_count(new_count).await?;
336 }
337 Ok(())
338 }
339
340 async fn close(&self) -> Result<(), BrokerError> {
341 let consume_channel = self.consume_channel.write().await;
342 let produce_channel = self.produce_channel.write().await;
343 let conn = self.conn.lock().await;
344
345 if consume_channel.status().connected() {
346 debug!("Closing consumer channel...");
347 consume_channel.close(200, "OK").await?;
348 }
349
350 if produce_channel.status().connected() {
351 debug!("Closing producer channel...");
352 produce_channel.close(200, "OK").await?;
353 }
354
355 if conn.status().connected() {
356 debug!("Closing connection...");
357 conn.close(200, "OK").await?;
358 }
359
360 Ok(())
361 }
362
363 async fn reconnect(&self, connection_timeout: u32) -> Result<(), BrokerError> {
365 let mut conn = self.conn.lock().await;
366 if !conn.status().connected() {
367 debug!("Attempting to reconnect to broker");
368 let mut uri = self.uri.clone();
369 uri.query.connection_timeout = Some(connection_timeout as u64);
370 *conn = Connection::connect_uri(uri, create_connection_properties()).await?;
371
372 let mut consume_channel = self.consume_channel.write().await;
373 let mut produce_channel = self.produce_channel.write().await;
374 let mut queues = self.queues.write().await;
375
376 *consume_channel = conn.create_channel().await?;
377 *produce_channel = conn.create_channel().await?;
378
379 queues.clear();
380 for (queue_name, queue_options) in &self.queue_declare_options {
381 let queue = consume_channel
382 .queue_declare(queue_name, *queue_options, FieldTable::default())
383 .await?;
384 queues.insert(queue_name.into(), queue);
385 }
386 }
387
388 Ok(())
389 }
390
391 #[cfg(test)]
392 fn into_any(self: Box<Self>) -> Box<dyn Any> {
393 self
394 }
395}
396
397impl Message {
398 fn delivery_properties(&self) -> BasicProperties {
399 let mut properties = BasicProperties::default()
400 .with_correlation_id(self.properties.correlation_id.clone().into())
401 .with_content_type(self.properties.content_type.clone().into())
402 .with_content_encoding(self.properties.content_encoding.clone().into())
403 .with_headers(self.delivery_headers())
404 .with_priority(0)
405 .with_delivery_mode(2);
406 if let Some(ref reply_to) = self.properties.reply_to {
407 properties = properties.with_reply_to(reply_to.clone().into());
408 }
409 properties
410 }
411
412 fn delivery_headers(&self) -> FieldTable {
413 let mut headers = FieldTable::default();
414 headers.insert(
415 "id".into(),
416 AMQPValue::LongString(self.headers.id.clone().into()),
417 );
418 headers.insert(
419 "task".into(),
420 AMQPValue::LongString(self.headers.task.clone().into()),
421 );
422 if let Some(ref lang) = self.headers.lang {
423 headers.insert("lang".into(), AMQPValue::LongString(lang.clone().into()));
424 }
425 if let Some(ref root_id) = self.headers.root_id {
426 headers.insert(
427 "root_id".into(),
428 AMQPValue::LongString(root_id.clone().into()),
429 );
430 }
431 if let Some(ref parent_id) = self.headers.parent_id {
432 headers.insert(
433 "parent_id".into(),
434 AMQPValue::LongString(parent_id.clone().into()),
435 );
436 }
437 if let Some(ref group) = self.headers.group {
438 headers.insert("group".into(), AMQPValue::LongString(group.clone().into()));
439 }
440 if let Some(ref meth) = self.headers.meth {
441 headers.insert("meth".into(), AMQPValue::LongString(meth.clone().into()));
442 }
443 if let Some(ref shadow) = self.headers.shadow {
444 headers.insert(
445 "shadow".into(),
446 AMQPValue::LongString(shadow.clone().into()),
447 );
448 }
449 if let Some(ref eta) = self.headers.eta {
450 headers.insert(
451 "eta".into(),
452 AMQPValue::LongString(eta.to_rfc3339_opts(SecondsFormat::Millis, false).into()),
453 );
454 }
455 if let Some(ref expires) = self.headers.expires {
456 headers.insert(
457 "expires".into(),
458 AMQPValue::LongString(expires.to_rfc3339_opts(SecondsFormat::Millis, false).into()),
459 );
460 }
461 if let Some(retries) = self.headers.retries {
462 headers.insert("retries".into(), AMQPValue::LongUInt(retries));
463 }
464 let mut timelimit = FieldArray::default();
465 if let Some(t) = self.headers.timelimit.0 {
466 timelimit.push(AMQPValue::LongUInt(t));
467 } else {
468 timelimit.push(AMQPValue::Void);
469 }
470 if let Some(t) = self.headers.timelimit.1 {
471 timelimit.push(AMQPValue::LongUInt(t));
472 } else {
473 timelimit.push(AMQPValue::Void);
474 }
475 headers.insert("timelimit".into(), AMQPValue::FieldArray(timelimit));
476 if let Some(ref argsrepr) = self.headers.argsrepr {
477 headers.insert(
478 "argsrepr".into(),
479 AMQPValue::LongString(argsrepr.clone().into()),
480 );
481 }
482 if let Some(ref kwargsrepr) = self.headers.kwargsrepr {
483 headers.insert(
484 "kwargsrepr".into(),
485 AMQPValue::LongString(kwargsrepr.clone().into()),
486 );
487 }
488 if let Some(ref origin) = self.headers.origin {
489 headers.insert(
490 "origin".into(),
491 AMQPValue::LongString(origin.clone().into()),
492 );
493 }
494 headers
495 }
496}
497
498impl TryDeserializeMessage for (Channel, Delivery) {
499 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
500 self.1.try_deserialize_message()
501 }
502}
503
504impl TryDeserializeMessage for Delivery {
505 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
506 let headers = self
507 .properties
508 .headers()
509 .as_ref()
510 .ok_or(ProtocolError::MissingHeaders)?;
511 Ok(Message {
512 properties: MessageProperties {
513 correlation_id: self
514 .properties
515 .correlation_id()
516 .as_ref()
517 .map(|v| v.to_string())
518 .ok_or_else(|| {
519 ProtocolError::MissingRequiredProperty("correlation_id".into())
520 })?,
521 content_type: self
522 .properties
523 .content_type()
524 .as_ref()
525 .map(|v| v.to_string())
526 .ok_or_else(|| ProtocolError::MissingRequiredProperty("content_type".into()))?,
527 content_encoding: self
528 .properties
529 .content_encoding()
530 .as_ref()
531 .map(|v| v.to_string())
532 .ok_or_else(|| {
533 ProtocolError::MissingRequiredProperty("content_encoding".into())
534 })?,
535 reply_to: self.properties.reply_to().as_ref().map(|v| v.to_string()),
536 },
537 headers: MessageHeaders {
538 id: get_header_str_required(headers, "id")?,
539 task: get_header_str_required(headers, "task")?,
540 lang: get_header_str(headers, "lang"),
541 root_id: get_header_str(headers, "root_id"),
542 parent_id: get_header_str(headers, "parent_id"),
543 group: get_header_str(headers, "group"),
544 meth: get_header_str(headers, "meth"),
545 shadow: get_header_str(headers, "shadow"),
546 eta: get_header_dt(headers, "eta"),
547 expires: get_header_dt(headers, "expires"),
548 retries: get_header_u32(headers, "retries"),
549 timelimit: headers
550 .inner()
551 .get("timelimit")
552 .and_then(|v| match v {
553 AMQPValue::FieldArray(a) => {
554 let a = a.as_slice().to_vec();
555 if a.len() == 2 {
556 let soft = amqp_value_to_u32(&a[0]);
557 let hard = amqp_value_to_u32(&a[1]);
558 Some((soft, hard))
559 } else {
560 None
561 }
562 }
563 _ => None,
564 })
565 .unwrap_or((None, None)),
566 argsrepr: get_header_str(headers, "argsrepr"),
567 kwargsrepr: get_header_str(headers, "kwargsrepr"),
568 origin: get_header_str(headers, "origin"),
569 },
570 raw_body: self.data.clone(),
571 })
572 }
573}
574
575fn get_header_str(headers: &FieldTable, key: &str) -> Option<String> {
576 headers.inner().get(key).and_then(|v| match v {
577 AMQPValue::ShortString(s) => Some(s.to_string()),
578 AMQPValue::LongString(s) => Some(s.to_string()),
579 _ => None,
580 })
581}
582
583fn get_header_str_required(headers: &FieldTable, key: &str) -> Result<String, ProtocolError> {
584 get_header_str(headers, key).ok_or_else(|| ProtocolError::MissingRequiredHeader(key.into()))
585}
586
587fn get_header_dt(headers: &FieldTable, key: &str) -> Option<DateTime<Utc>> {
588 if let Some(s) = get_header_str(headers, key) {
589 match DateTime::parse_from_rfc3339(&s) {
590 Ok(dt) => Some(DateTime::<Utc>::from(dt)),
591 _ => None,
592 }
593 } else {
594 None
595 }
596}
597
598fn get_header_u32(headers: &FieldTable, key: &str) -> Option<u32> {
599 headers.inner().get(key).and_then(amqp_value_to_u32)
600}
601
602fn amqp_value_to_u32(v: &AMQPValue) -> Option<u32> {
603 match v {
604 AMQPValue::ShortShortInt(n) => Some(*n as u32),
605 AMQPValue::ShortShortUInt(n) => Some(*n as u32),
606 AMQPValue::ShortInt(n) => Some(*n as u32),
607 AMQPValue::ShortUInt(n) => Some(*n as u32),
608 AMQPValue::LongInt(n) => Some(*n as u32),
609 AMQPValue::LongUInt(n) => Some(*n),
610 AMQPValue::LongLongInt(n) => Some(*n as u32),
611 _ => None,
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use lapin::types::ShortString;
    use std::time::SystemTime;

    /// A fully populated message must survive a round trip through AMQP
    /// delivery properties unchanged.
    #[test]
    fn test_conversion() {
        // Timestamps are serialized with millisecond precision, so start from a value
        // that has already been through that formatting; otherwise the sub-millisecond
        // part of the current time would make the comparison fail.
        let wall_clock = DateTime::<Utc>::from(SystemTime::now());
        let serialized = wall_clock.to_rfc3339_opts(SecondsFormat::Millis, false);
        let timestamp = DateTime::<Utc>::from(DateTime::parse_from_rfc3339(&serialized).unwrap());

        let properties = MessageProperties {
            correlation_id: "aaa".into(),
            content_type: "application/json".into(),
            content_encoding: "utf-8".into(),
            reply_to: Some("bbb".into()),
        };
        let headers = MessageHeaders {
            id: "aaa".into(),
            task: "add".into(),
            lang: Some("rust".into()),
            root_id: Some("aaa".into()),
            parent_id: Some("000".into()),
            group: Some("A".into()),
            meth: Some("method_name".into()),
            shadow: Some("add-these".into()),
            eta: Some(timestamp),
            expires: Some(timestamp),
            retries: Some(1),
            timelimit: (Some(30), Some(60)),
            argsrepr: Some("(1)".into()),
            kwargsrepr: Some("{'y': 2}".into()),
            origin: Some("gen123@piper".into()),
        };
        let message = Message {
            properties,
            headers,
            raw_body: vec![],
        };

        let delivery = Delivery {
            delivery_tag: 0,
            exchange: ShortString::from(""),
            routing_key: ShortString::from("celery"),
            redelivered: false,
            properties: message.delivery_properties(),
            data: vec![],
            acker: Default::default(),
        };

        let round_tripped = delivery.try_deserialize_message();
        assert!(round_tripped.is_ok());
        assert_eq!(message, round_tripped.unwrap());
    }
}