1use async_trait::async_trait;
4use chrono::{DateTime, SecondsFormat, Utc};
5use futures::Stream;
6use lapin::message::Delivery;
7use lapin::options::{
8 BasicAckOptions, BasicCancelOptions, BasicConsumeOptions, BasicPublishOptions, BasicQosOptions,
9 QueueDeclareOptions,
10};
11use lapin::types::{AMQPValue, FieldArray, FieldTable};
12use lapin::uri::{self, AMQPUri};
13use lapin::{BasicProperties, Channel, Connection, ConnectionProperties, Queue};
14use log::debug;
15use std::collections::HashMap;
16use std::str::FromStr;
17use std::task::Poll;
18use tokio::sync::{Mutex, RwLock};
19
20use super::{Broker, BrokerBuilder, DeliveryError, DeliveryStream};
21use crate::error::{BrokerError, ProtocolError};
22use crate::protocol::{Message, MessageHeaders, MessageProperties, TryDeserializeMessage};
23use tokio_executor_trait::Tokio as TokioExecutor;
24use tokio_reactor_trait::Tokio as TokioReactor;
25
26#[cfg(test)]
27use std::any::Any;
28
/// Thin wrapper around a [`lapin::Consumer`].
///
/// The wrapper exists so this module can implement its own stream traits
/// (`DeliveryStream` and `Stream`) on top of the lapin consumer.
struct Consumer {
    /// The underlying lapin consumer that actually yields deliveries.
    wrapped: lapin::Consumer,
}
// Marker impl: lets a `Consumer` be returned as a `Box<dyn DeliveryStream>`.
impl DeliveryStream for Consumer {}
// Marker impl: lets a `lapin::Error` be boxed as a `Box<dyn DeliveryError>`.
impl DeliveryError for lapin::Error {}
34
35#[async_trait]
36impl super::Delivery for Delivery {
37 async fn resend(
38 &self,
39 broker: &dyn Broker,
40 eta: Option<DateTime<Utc>>,
41 ) -> Result<(), BrokerError> {
42 let mut message = self.try_deserialize_message()?;
43 message.headers.eta = eta;
44 message.headers.retries = Some(message.headers.retries.map_or(1, |retry| retry + 1));
46 broker.send(&message, self.routing_key.as_str()).await
47 }
48 async fn remove(&self) -> Result<(), BrokerError> {
49 todo!()
50 }
51 async fn ack(&self) -> Result<(), BrokerError> {
52 lapin::acker::Acker::ack(self, BasicAckOptions::default()).await?;
53 Ok(())
54 }
55}
56
57impl Stream for Consumer {
58 type Item = Result<Box<dyn super::Delivery>, Box<dyn DeliveryError>>;
59
60 fn poll_next(
61 mut self: std::pin::Pin<&mut Self>,
62 cx: &mut std::task::Context<'_>,
63 ) -> std::task::Poll<std::option::Option<<Self as futures::Stream>::Item>> {
64 use futures_lite::stream::StreamExt;
65
66 if let Poll::Ready(ret) = self.wrapped.poll_next(cx) {
67 if let Some(result) = ret {
68 match result {
69 Ok(x) => Poll::Ready(Some(Ok(Box::new(x)))),
70 Err(x) => Poll::Ready(Some(Err(Box::new(x)))),
71 }
72 } else {
73 Poll::Ready(None)
74 }
75 } else {
76 Poll::Pending
77 }
78 }
79}
80
/// Settings collected by the builder and consumed when the broker is built.
struct Config {
    /// Broker URL as given to the builder; parsed into an `AMQPUri` at build time.
    broker_url: String,
    /// Prefetch count applied to the consume channel when the broker is built.
    prefetch_count: u16,
    /// Queues to declare at build time, keyed by queue name.
    queues: HashMap<String, QueueDeclareOptions>,
    /// Heartbeat value copied onto the URI's query parameters at build time.
    heartbeat: Option<u16>,
}
87
/// Builder for [`AMQPBroker`]; implements the `BrokerBuilder` trait.
pub struct AMQPBrokerBuilder {
    /// Accumulated configuration used by `build`.
    config: Config,
}
92
93fn create_connection_properties() -> ConnectionProperties {
94 ConnectionProperties::default()
96 .with_executor(TokioExecutor::current())
97 .with_reactor(TokioReactor)
98}
99
100#[async_trait]
101impl BrokerBuilder for AMQPBrokerBuilder {
102 fn new(broker_url: &str) -> Self {
104 Self {
105 config: Config {
106 broker_url: broker_url.into(),
107 prefetch_count: 10,
108 queues: HashMap::new(),
109 heartbeat: Some(60),
110 },
111 }
112 }
113
114 fn prefetch_count(mut self: Box<Self>, prefetch_count: u16) -> Box<dyn BrokerBuilder> {
117 self.config.prefetch_count = prefetch_count;
118 self
119 }
120
121 fn declare_queue(mut self: Box<Self>, name: &str) -> Box<dyn BrokerBuilder> {
123 self.config.queues.insert(
124 name.into(),
125 QueueDeclareOptions {
126 passive: false,
127 durable: true,
128 exclusive: false,
129 auto_delete: false,
130 nowait: false,
131 },
132 );
133 self
134 }
135
136 fn heartbeat(mut self: Box<Self>, heartbeat: Option<u16>) -> Box<dyn BrokerBuilder> {
138 self.config.heartbeat = heartbeat;
139 self
140 }
141
142 async fn build(&self, connection_timeout: u32) -> Result<Box<dyn Broker>, BrokerError> {
144 let mut uri = AMQPUri::from_str(&self.config.broker_url)
145 .map_err(|_| BrokerError::InvalidBrokerUrl(self.config.broker_url.clone()))?;
146 uri.query.heartbeat = self.config.heartbeat;
147 uri.query.connection_timeout = Some((connection_timeout as u64) * 1000);
148
149 let conn = Connection::connect_uri(uri.clone(), create_connection_properties()).await?;
150
151 let consume_channel = conn.create_channel().await?;
152 let produce_channel = conn.create_channel().await?;
153
154 let mut queues: HashMap<String, Queue> = HashMap::new();
155 for (queue_name, queue_options) in &self.config.queues {
156 let queue = consume_channel
157 .queue_declare(queue_name, *queue_options, FieldTable::default())
158 .await?;
159 queues.insert(queue_name.into(), queue);
160 }
161
162 let broker = AMQPBroker {
163 uri,
164 conn: Mutex::new(conn),
165 consume_channel: RwLock::new(consume_channel),
166 produce_channel: RwLock::new(produce_channel),
167 queues: RwLock::new(queues),
168 queue_declare_options: self.config.queues.clone(),
169 prefetch_count: Mutex::new(self.config.prefetch_count),
170 };
171 broker
172 .set_prefetch_count(self.config.prefetch_count)
173 .await?;
174 Ok(Box::new(broker))
175 }
176}
177
/// AMQP-backed implementation of the `Broker` trait, built on lapin.
pub struct AMQPBroker {
    /// Parsed broker URI; used to render the safe URL and to reconnect.
    uri: AMQPUri,

    /// The broker connection; replaced in place when reconnecting.
    conn: Mutex<Connection>,

    /// Channel used for queue declaration, QoS, consuming and cancelling.
    consume_channel: RwLock<Channel>,

    /// Channel used for publishing messages.
    produce_channel: RwLock<Channel>,

    /// Queues that have been declared on the consume channel, keyed by name.
    queues: RwLock<HashMap<String, Queue>>,

    /// Declare options for each configured queue; used to redeclare them
    /// after a reconnect.
    queue_declare_options: HashMap<String, QueueDeclareOptions>,

    /// The prefetch count as tracked by the increase/decrease methods.
    prefetch_count: Mutex<u16>,
}
206
207impl AMQPBroker {
208 async fn set_prefetch_count(&self, prefetch_count: u16) -> Result<(), BrokerError> {
209 debug!("Setting prefetch count to {}", prefetch_count);
210 self.consume_channel
211 .read()
212 .await
213 .basic_qos(prefetch_count, BasicQosOptions { global: true })
214 .await?;
215 Ok(())
216 }
217}
218
219#[async_trait]
220impl Broker for AMQPBroker {
221 fn safe_url(&self) -> String {
222 format!(
223 "{}://{}:***@{}:{}/{}",
224 match self.uri.scheme {
225 uri::AMQPScheme::AMQP => "amqp",
226 _ => "amqps",
227 },
228 self.uri.authority.userinfo.username,
229 self.uri.authority.host,
230 self.uri.authority.port,
231 self.uri.vhost,
232 )
233 }
234
235 async fn consume(
236 &self,
237 queue: &str,
238 _error_handler: Box<dyn Fn(BrokerError) + Send + Sync + 'static>,
239 ) -> Result<(String, Box<dyn DeliveryStream>), BrokerError> {
240 let queues = self.queues.read().await;
246 let queue = queues
247 .get(queue)
248 .ok_or_else::<BrokerError, _>(|| BrokerError::UnknownQueue(queue.into()))?;
249 let consumer = Consumer {
250 wrapped: self
251 .consume_channel
252 .read()
253 .await
254 .basic_consume(
255 queue.name().as_str(),
256 "",
257 BasicConsumeOptions::default(),
258 FieldTable::default(),
259 )
260 .await?,
261 };
262 Ok((consumer.wrapped.tag().to_string(), Box::new(consumer)))
263 }
264
265 async fn cancel(&self, consumer_tag: &str) -> Result<(), BrokerError> {
266 let consume_channel = self.consume_channel.write().await;
267 consume_channel
268 .basic_cancel(consumer_tag, BasicCancelOptions::default())
269 .await?;
270 Ok(())
271 }
272
273 async fn ack(&self, delivery: &dyn super::Delivery) -> Result<(), BrokerError> {
274 delivery.ack().await
275 }
276
277 async fn retry(
278 &self,
279 delivery: &dyn super::Delivery,
280 eta: Option<DateTime<Utc>>,
281 ) -> Result<(), BrokerError> {
282 delivery.resend(self, eta).await?;
283 Ok(())
284 }
285
286 async fn send(&self, message: &Message, queue: &str) -> Result<(), BrokerError> {
287 let properties = message.delivery_properties();
288 debug!("Sending AMQP message with: {:?}", properties);
289 self.produce_channel
290 .read()
291 .await
292 .basic_publish(
293 "",
294 queue,
295 BasicPublishOptions::default(),
296 &message.raw_body.clone()[..],
297 properties,
298 )
299 .await?;
300 Ok(())
301 }
302
303 async fn increase_prefetch_count(&self) -> Result<(), BrokerError> {
304 let new_count = {
305 let mut prefetch_count = self.prefetch_count.lock().await;
306 if *prefetch_count < u16::MAX {
307 let new_count = *prefetch_count + 1;
308 *prefetch_count = new_count;
309 new_count
310 } else {
311 u16::MAX
312 }
313 };
314 self.set_prefetch_count(new_count).await?;
315 Ok(())
316 }
317
318 async fn decrease_prefetch_count(&self) -> Result<(), BrokerError> {
319 let new_count = {
320 let mut prefetch_count = self.prefetch_count.lock().await;
321 if *prefetch_count > 1 {
322 let new_count = *prefetch_count - 1;
323 *prefetch_count = new_count;
324 new_count
325 } else {
326 0u16
327 }
328 };
329 if new_count > 0 {
330 self.set_prefetch_count(new_count).await?;
331 }
332 Ok(())
333 }
334
335 async fn close(&self) -> Result<(), BrokerError> {
336 let consume_channel = self.consume_channel.write().await;
337 let produce_channel = self.produce_channel.write().await;
338 let conn = self.conn.lock().await;
339
340 if consume_channel.status().connected() {
341 debug!("Closing consumer channel...");
342 consume_channel.close(200, "OK").await?;
343 }
344
345 if produce_channel.status().connected() {
346 debug!("Closing producer channel...");
347 produce_channel.close(200, "OK").await?;
348 }
349
350 if conn.status().connected() {
351 debug!("Closing connection...");
352 conn.close(200, "OK").await?;
353 }
354
355 Ok(())
356 }
357
358 async fn reconnect(&self, connection_timeout: u32) -> Result<(), BrokerError> {
360 let mut conn = self.conn.lock().await;
361 if !conn.status().connected() {
362 debug!("Attempting to reconnect to broker");
363 let mut uri = self.uri.clone();
364 uri.query.connection_timeout = Some(connection_timeout as u64);
365 *conn = Connection::connect_uri(uri, create_connection_properties()).await?;
366
367 let mut consume_channel = self.consume_channel.write().await;
368 let mut produce_channel = self.produce_channel.write().await;
369 let mut queues = self.queues.write().await;
370
371 *consume_channel = conn.create_channel().await?;
372 *produce_channel = conn.create_channel().await?;
373
374 queues.clear();
375 for (queue_name, queue_options) in &self.queue_declare_options {
376 let queue = consume_channel
377 .queue_declare(queue_name, *queue_options, FieldTable::default())
378 .await?;
379 queues.insert(queue_name.into(), queue);
380 }
381 }
382
383 Ok(())
384 }
385
386 #[cfg(test)]
387 fn into_any(self: Box<Self>) -> Box<dyn Any> {
388 self
389 }
390}
391
392impl Message {
393 fn delivery_properties(&self) -> BasicProperties {
394 let mut properties = BasicProperties::default()
395 .with_correlation_id(self.properties.correlation_id.clone().into())
396 .with_content_type(self.properties.content_type.clone().into())
397 .with_content_encoding(self.properties.content_encoding.clone().into())
398 .with_headers(self.delivery_headers())
399 .with_priority(0)
400 .with_delivery_mode(2);
401 if let Some(ref reply_to) = self.properties.reply_to {
402 properties = properties.with_reply_to(reply_to.clone().into());
403 }
404 properties
405 }
406
407 fn delivery_headers(&self) -> FieldTable {
408 let mut headers = FieldTable::default();
409 headers.insert(
410 "id".into(),
411 AMQPValue::LongString(self.headers.id.clone().into()),
412 );
413 headers.insert(
414 "task".into(),
415 AMQPValue::LongString(self.headers.task.clone().into()),
416 );
417 if let Some(ref lang) = self.headers.lang {
418 headers.insert("lang".into(), AMQPValue::LongString(lang.clone().into()));
419 }
420 if let Some(ref root_id) = self.headers.root_id {
421 headers.insert(
422 "root_id".into(),
423 AMQPValue::LongString(root_id.clone().into()),
424 );
425 }
426 if let Some(ref parent_id) = self.headers.parent_id {
427 headers.insert(
428 "parent_id".into(),
429 AMQPValue::LongString(parent_id.clone().into()),
430 );
431 }
432 if let Some(ref group) = self.headers.group {
433 headers.insert("group".into(), AMQPValue::LongString(group.clone().into()));
434 }
435 if let Some(ref meth) = self.headers.meth {
436 headers.insert("meth".into(), AMQPValue::LongString(meth.clone().into()));
437 }
438 if let Some(ref shadow) = self.headers.shadow {
439 headers.insert(
440 "shadow".into(),
441 AMQPValue::LongString(shadow.clone().into()),
442 );
443 }
444 if let Some(ref eta) = self.headers.eta {
445 headers.insert(
446 "eta".into(),
447 AMQPValue::LongString(eta.to_rfc3339_opts(SecondsFormat::Millis, false).into()),
448 );
449 }
450 if let Some(ref expires) = self.headers.expires {
451 headers.insert(
452 "expires".into(),
453 AMQPValue::LongString(expires.to_rfc3339_opts(SecondsFormat::Millis, false).into()),
454 );
455 }
456 if let Some(retries) = self.headers.retries {
457 headers.insert("retries".into(), AMQPValue::LongUInt(retries));
458 }
459 let mut timelimit = FieldArray::default();
460 if let Some(t) = self.headers.timelimit.0 {
461 timelimit.push(AMQPValue::LongUInt(t));
462 } else {
463 timelimit.push(AMQPValue::Void);
464 }
465 if let Some(t) = self.headers.timelimit.1 {
466 timelimit.push(AMQPValue::LongUInt(t));
467 } else {
468 timelimit.push(AMQPValue::Void);
469 }
470 headers.insert("timelimit".into(), AMQPValue::FieldArray(timelimit));
471 if let Some(ref argsrepr) = self.headers.argsrepr {
472 headers.insert(
473 "argsrepr".into(),
474 AMQPValue::LongString(argsrepr.clone().into()),
475 );
476 }
477 if let Some(ref kwargsrepr) = self.headers.kwargsrepr {
478 headers.insert(
479 "kwargsrepr".into(),
480 AMQPValue::LongString(kwargsrepr.clone().into()),
481 );
482 }
483 if let Some(ref origin) = self.headers.origin {
484 headers.insert(
485 "origin".into(),
486 AMQPValue::LongString(origin.clone().into()),
487 );
488 }
489 headers
490 }
491}
492
493impl TryDeserializeMessage for (Channel, Delivery) {
494 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
495 self.1.try_deserialize_message()
496 }
497}
498
499impl TryDeserializeMessage for Delivery {
500 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
501 let headers = self
502 .properties
503 .headers()
504 .as_ref()
505 .ok_or(ProtocolError::MissingHeaders)?;
506 Ok(Message {
507 properties: MessageProperties {
508 correlation_id: self
509 .properties
510 .correlation_id()
511 .as_ref()
512 .map(|v| v.to_string())
513 .ok_or_else(|| {
514 ProtocolError::MissingRequiredProperty("correlation_id".into())
515 })?,
516 content_type: self
517 .properties
518 .content_type()
519 .as_ref()
520 .map(|v| v.to_string())
521 .ok_or_else(|| ProtocolError::MissingRequiredProperty("content_type".into()))?,
522 content_encoding: self
523 .properties
524 .content_encoding()
525 .as_ref()
526 .map(|v| v.to_string())
527 .ok_or_else(|| {
528 ProtocolError::MissingRequiredProperty("content_encoding".into())
529 })?,
530 reply_to: self.properties.reply_to().as_ref().map(|v| v.to_string()),
531 delivery_info: None,
532 },
533 headers: MessageHeaders {
534 id: get_header_str_required(headers, "id")?,
535 task: get_header_str_required(headers, "task")?,
536 lang: get_header_str(headers, "lang"),
537 root_id: get_header_str(headers, "root_id"),
538 parent_id: get_header_str(headers, "parent_id"),
539 group: get_header_str(headers, "group"),
540 meth: get_header_str(headers, "meth"),
541 shadow: get_header_str(headers, "shadow"),
542 eta: get_header_dt(headers, "eta"),
543 expires: get_header_dt(headers, "expires"),
544 retries: get_header_u32(headers, "retries"),
545 timelimit: headers
546 .inner()
547 .get("timelimit")
548 .and_then(|v| match v {
549 AMQPValue::FieldArray(a) => {
550 let a = a.as_slice().to_vec();
551 if a.len() == 2 {
552 let soft = amqp_value_to_u32(&a[0]);
553 let hard = amqp_value_to_u32(&a[1]);
554 Some((soft, hard))
555 } else {
556 None
557 }
558 }
559 _ => None,
560 })
561 .unwrap_or((None, None)),
562 argsrepr: get_header_str(headers, "argsrepr"),
563 kwargsrepr: get_header_str(headers, "kwargsrepr"),
564 origin: get_header_str(headers, "origin"),
565 },
566 raw_body: self.data.clone(),
567 })
568 }
569}
570
571fn get_header_str(headers: &FieldTable, key: &str) -> Option<String> {
572 headers.inner().get(key).and_then(|v| match v {
573 AMQPValue::ShortString(s) => Some(s.to_string()),
574 AMQPValue::LongString(s) => Some(s.to_string()),
575 _ => None,
576 })
577}
578
579fn get_header_str_required(headers: &FieldTable, key: &str) -> Result<String, ProtocolError> {
580 get_header_str(headers, key).ok_or_else(|| ProtocolError::MissingRequiredHeader(key.into()))
581}
582
583fn get_header_dt(headers: &FieldTable, key: &str) -> Option<DateTime<Utc>> {
584 if let Some(s) = get_header_str(headers, key) {
585 match DateTime::parse_from_rfc3339(&s) {
586 Ok(dt) => Some(DateTime::<Utc>::from(dt)),
587 _ => None,
588 }
589 } else {
590 None
591 }
592}
593
594fn get_header_u32(headers: &FieldTable, key: &str) -> Option<u32> {
595 headers.inner().get(key).and_then(amqp_value_to_u32)
596}
597
598fn amqp_value_to_u32(v: &AMQPValue) -> Option<u32> {
599 match v {
600 AMQPValue::ShortShortInt(n) => Some(*n as u32),
601 AMQPValue::ShortShortUInt(n) => Some(*n as u32),
602 AMQPValue::ShortInt(n) => Some(*n as u32),
603 AMQPValue::ShortUInt(n) => Some(*n as u32),
604 AMQPValue::LongInt(n) => Some(*n as u32),
605 AMQPValue::LongUInt(n) => Some(*n),
606 AMQPValue::LongLongInt(n) => Some(*n as u32),
607 _ => None,
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::SystemTime;

    /// Serializes a fully populated `Message` to JSON, deserializes it back
    /// through `crate::protocol::Delivery`, and checks the result equals the
    /// original message.
    #[test]
    fn test_conversion() {
        let now = DateTime::<Utc>::from(SystemTime::now());

        // Round-trip `now` through an RFC 3339 string with millisecond
        // precision so the value stored in the message has nothing finer
        // than milliseconds.
        let now_str = now.to_rfc3339_opts(SecondsFormat::Millis, false);
        let now = DateTime::<Utc>::from(DateTime::parse_from_rfc3339(&now_str).unwrap());

        // Every optional property and header is populated.
        let message = Message {
            properties: MessageProperties {
                correlation_id: "aaa".into(),
                content_type: "application/json".into(),
                content_encoding: "utf-8".into(),
                reply_to: Some("bbb".into()),
                delivery_info: None,
            },
            headers: MessageHeaders {
                id: "aaa".into(),
                task: "add".into(),
                lang: Some("rust".into()),
                root_id: Some("aaa".into()),
                parent_id: Some("000".into()),
                group: Some("A".into()),
                meth: Some("method_name".into()),
                shadow: Some("add-these".into()),
                eta: Some(now),
                expires: Some(now),
                retries: Some(1),
                timelimit: (Some(30), Some(60)),
                argsrepr: Some("(1)".into()),
                kwargsrepr: Some("{'y': 2}".into()),
                origin: Some("gen123@piper".into()),
            },
            raw_body: vec![],
        };

        // Serialize to JSON bytes, then parse them back into a delivery.
        let serialized = message.json_serialized(None).unwrap();
        let serialized_str = std::str::from_utf8(&serialized).unwrap();
        let deserialized: crate::protocol::Delivery = serde_json::from_str(serialized_str).unwrap();
        let message2 = deserialized.try_deserialize_message();
        assert!(message2.is_ok());

        // The round-tripped message must equal the original.
        let message2 = message2.unwrap();
        assert_eq!(message, message2);
    }
}