use std::{borrow::Cow, collections::HashMap, fmt::Display, future::Future, time::SystemTime};
use anyhow::Context as _;
use async_channel::Receiver;
use chrono::{DateTime, Utc};
use grpc_build_core::NamedMessage;
use lapin::{
acker::Acker,
options::BasicQosOptions,
types::{AMQPValue, FieldTable},
Channel,
};
use opentelemetry::{
trace::{FutureExt, SpanKind, TraceContextExt, Tracer},
KeyValue,
};
use prost::Message as _;
use tokio::task::JoinHandle;
use crate::{ext::timestamp_to_utc, protogen::amqpsy::message::AmqpMessageWrapper, telemetry};
use super::{new_ampq_channel, AmqpConfig};
/// A resource registered through the consumer builders: either an exchange, or a
/// queue together with the tasks that were spawned to consume from it.
pub enum ConsumerGroup {
/// An exchange entry; it carries no tasks.
Exchange {
/// Name of the exchange.
name: String,
},
/// A queue bound to an exchange, plus the tasks serving it.
Queue {
/// Name of the queue.
name: String,
/// Name of the exchange the queue is bound to.
exchange_name: String,
/// Routing key used for the queue binding.
routing_key: String,
/// Join handles of the tasks spawned for this queue.
handler: Vec<JoinHandle<anyhow::Result<()>>>,
},
}
pub struct ConsumerGroups(Vec<ConsumerGroup>);
impl ConsumerGroups {
pub async fn run_until_shutdown(self) -> anyhow::Result<()> {
let mut handlers = Vec::new();
for resource in self.0 {
match resource {
ConsumerGroup::Exchange { name: _ } => {}
ConsumerGroup::Queue { handler, .. } => {
handlers.extend(handler);
}
}
}
futures::future::join_all(handlers).await;
Ok(())
}
}
/// Entry-point builder that collects [`ConsumerGroup`] entries across exchanges.
pub struct ConsumerGroupBuilder<C: Clone + Send + 'static> {
/// Groups collected so far.
items: Vec<ConsumerGroup>,
/// Handler context; a clone is handed to each spawned worker.
context: C,
/// Settings used when opening AMQP channels.
config: AmqpConfig,
/// Application name; used as the prefix of generated queue names.
app_name: String,
}
impl ConsumerGroup {
    /// Starts a new [`ConsumerGroupBuilder`] with no registered groups.
    ///
    /// `app_name` is copied into the builder, while `config` and `context` are
    /// moved in as given.
    pub fn builder<C: Send + Clone + 'static>(
        app_name: &str,
        config: AmqpConfig,
        context: C,
    ) -> ConsumerGroupBuilder<C> {
        let app_name = app_name.to_owned();
        let items = Vec::new();
        ConsumerGroupBuilder {
            items,
            context,
            config,
            app_name,
        }
    }
}
/// An exchange description used while building consumers.
struct Exchange {
    /// Kind of the exchange; currently only stored, never read.
    _ty: ExchangeKind,
    /// Name of the exchange.
    name: String,
}

impl Exchange {
    /// Returns the exchange name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the name of the companion dead-letter exchange: the exchange
    /// name with a `.dlx` suffix.
    pub fn name_dead_letter(&self) -> String {
        let mut dead_letter = String::with_capacity(self.name.len() + ".dlx".len());
        dead_letter.push_str(&self.name);
        dead_letter.push_str(".dlx");
        dead_letter
    }
}

/// Supported exchange kinds.
enum ExchangeKind {
    /// A topic exchange.
    Topic,
}
/// Builder for consumers attached to a single topic exchange.
pub struct TopicExchangeBuilder<C: Send + Clone + 'static> {
/// Parent builder; supplies the AMQP config, application name and handler context.
factory: ConsumerGroupBuilder<C>,
/// The exchange that queues registered through this builder are bound to.
exchange: Exchange,
/// Queue groups registered so far on this exchange.
items: Vec<ConsumerGroup>,
}
/// Tuning knobs for the consumer of a single queue.
#[derive(Debug, Clone, Copy)]
pub struct ConsumerConfig {
    /// Value passed to `basic_qos`; also the capacity of the in-process
    /// channel between the consumer task and its workers.
    pub prefetch_count: u16,
    /// Number of worker tasks spawned for the queue.
    pub worker_count: u8,
}

impl ConsumerConfig {
    /// Returns a copy of the config with `prefetch_count` replaced.
    pub fn with_prefetch_count(mut self, prefetch_count: u16) -> Self {
        self.prefetch_count = prefetch_count;
        self
    }

    /// Returns a copy of the config with `worker_count` replaced.
    pub fn with_worker_count(mut self, worker_count: u8) -> Self {
        self.worker_count = worker_count;
        self
    }
}

impl Default for ConsumerConfig {
    /// Defaults: a prefetch count of 50 and 3 workers.
    fn default() -> Self {
        ConsumerConfig {
            worker_count: 3,
            prefetch_count: 50,
        }
    }
}
impl<C: Send + Clone + 'static> ConsumerGroupBuilder<C> {
    /// Switches to building consumers for the topic exchange called `name`.
    ///
    /// The returned builder keeps this builder as its parent and starts with
    /// no registered queues.
    pub fn for_topic_exchange(self, name: &str) -> TopicExchangeBuilder<C> {
        let exchange = Exchange {
            _ty: ExchangeKind::Topic,
            name: name.to_owned(),
        };
        TopicExchangeBuilder {
            factory: self,
            exchange,
            items: Vec::new(),
        }
    }

    /// Finishes building and returns everything collected so far.
    pub fn done(self) -> ConsumerGroups {
        let Self { items, .. } = self;
        ConsumerGroups(items)
    }
}
impl<C: Send + Clone + 'static> TopicExchangeBuilder<C> {
    /// Registers `handler` as a consumer of `Msg` using [`ConsumerConfig::default`].
    ///
    /// # Errors
    ///
    /// Propagates every error of [`Self::consume_with_config`].
    pub async fn consume<Msg: prost::Message + NamedMessage + Default + 'static>(
        self,
        handler: impl AmqpHandler<Message = Msg, Context = C> + Send + Clone + 'static,
    ) -> anyhow::Result<Self> {
        self.consume_with_config(handler, ConsumerConfig::default())
            .await
    }

    /// Declares the queue topology for `Msg` and spawns the tasks consuming it.
    ///
    /// Steps, in order:
    /// 1. ensure the main exchange and its `.dlx` dead-letter exchange exist;
    /// 2. declare `<queue>-deadletter` and bind it to the dead-letter exchange;
    /// 3. declare the main queue with dead-letter arguments pointing at that
    ///    exchange, and bind it to the main exchange;
    /// 4. spawn the consumer and worker tasks and record them in this builder.
    ///
    /// The routing key is `Msg::NAME`; the queue name is derived from the
    /// application name, the exchange name and the routing key.
    ///
    /// # Errors
    ///
    /// Returns an error when any declare/bind call fails or when spawning the
    /// consumer fails.
    pub async fn consume_with_config<Msg: prost::Message + NamedMessage + Default + 'static>(
        mut self,
        handler: impl AmqpHandler<Message = Msg, Context = C> + Send + Clone + 'static,
        consumer_config: ConsumerConfig,
    ) -> anyhow::Result<Self> {
        let exchange_name = self.exchange.name.clone();
        let routing_key = Msg::NAME.to_string();
        let channel = self
            .create_or_ensure_exchange()
            .await
            .context("create or ensure exchange")?;

        let queue_name =
            generate_queue_name(&self.factory.app_name, &exchange_name, &routing_key);
        let queue_name_dead_letter = format!("{}-deadletter", queue_name);
        let dlx_name = self.exchange.name_dead_letter();

        // Dead-letter side first: its queue takes no extra arguments.
        queue_declare(&channel, &queue_name_dead_letter, Default::default()).await?;
        channel
            .queue_bind(
                &queue_name_dead_letter,
                &dlx_name,
                &routing_key,
                <_>::default(),
                <_>::default(),
            )
            .await
            .context("queue bind to dead letter exchange")?;

        // Main queue: messages rejected without requeue are routed to the
        // dead-letter exchange under the same routing key.
        let mut args = FieldTable::default();
        args.insert(
            "x-dead-letter-exchange".into(),
            AMQPValue::LongString(dlx_name.into()),
        );
        args.insert(
            "x-dead-letter-routing-key".into(),
            AMQPValue::LongString(routing_key.clone().into()),
        );
        queue_declare(&channel, &queue_name, args).await?;
        tracing::info!(
            "Queue binding: {} ({}) -> {}",
            exchange_name,
            routing_key,
            queue_name,
        );
        channel
            .queue_bind(
                &queue_name,
                &exchange_name,
                &routing_key,
                <_>::default(),
                <_>::default(),
            )
            .await
            .context("queue bind")?;

        let tasks = spawn_consumer(
            self.factory.config.clone(),
            self.factory.context.clone(),
            handler,
            queue_name.clone(),
            consumer_config,
        )
        .await?;
        self.items.push(ConsumerGroup::Queue {
            name: queue_name,
            exchange_name,
            routing_key,
            handler: tasks,
        });
        Ok(self)
    }

    /// Finishes building and waits for all spawned tasks; shorthand for
    /// `self.done().run_until_shutdown()`.
    pub async fn run_until_shutdown(self) -> anyhow::Result<()> {
        self.done().run_until_shutdown().await
    }

    /// Finishes building and returns all collected consumer groups.
    pub fn done(self) -> ConsumerGroups {
        self.then().done()
    }

    /// Returns to the parent builder so another exchange can be configured.
    ///
    /// The exchange entry is appended to the parent first, followed by the
    /// queue groups registered on it.
    pub fn then(mut self) -> ConsumerGroupBuilder<C> {
        self.factory.items.push(ConsumerGroup::Exchange {
            name: self.exchange.name,
        });
        self.factory.items.extend(self.items);
        self.factory
    }

    /// Options shared by the main and the dead-letter exchange declarations:
    /// durable, not passive, not auto-deleted, not internal, and waiting for
    /// the broker's reply.
    fn durable_exchange_options() -> lapin::options::ExchangeDeclareOptions {
        lapin::options::ExchangeDeclareOptions {
            passive: false,
            durable: true,
            auto_delete: false,
            internal: false,
            nowait: false,
        }
    }

    /// Opens a new channel and declares both the topic exchange and its
    /// direct dead-letter exchange on it, returning the channel for the
    /// subsequent queue setup.
    async fn create_or_ensure_exchange(&self) -> anyhow::Result<Channel> {
        tracing::debug!("Ensuring topic exchange: {}", self.exchange.name);
        let channel = new_ampq_channel(&self.factory.config).await?;
        channel
            .exchange_declare(
                self.exchange.name(),
                lapin::ExchangeKind::Topic,
                Self::durable_exchange_options(),
                <_>::default(),
            )
            .await
            .context("declare main exchange")?;

        let dead_letter_exchange = self.exchange.name_dead_letter();
        channel
            .exchange_declare(
                &dead_letter_exchange,
                lapin::ExchangeKind::Direct,
                Self::durable_exchange_options(),
                <_>::default(),
            )
            .await
            .context("declare dead letter exchange")?;
        Ok(channel)
    }
}
/// Declares `queue_name` on `channel` as a durable, non-exclusive queue that is
/// not auto-deleted, passing `args` as the queue arguments.
///
/// # Errors
///
/// Returns the declare error wrapped with the context `"declare queue"`.
async fn queue_declare(
    channel: &Channel,
    queue_name: &str,
    args: FieldTable,
) -> Result<(), anyhow::Error> {
    let options = lapin::options::QueueDeclareOptions {
        passive: false,
        durable: true,
        exclusive: false,
        auto_delete: false,
        nowait: false,
    };
    // The declared queue handle itself is not needed by callers.
    let _declared = channel
        .queue_declare(queue_name, options, args)
        .await
        .context("declare queue")?;
    Ok(())
}
/// Builds a queue name by joining `prefix`, `exchange_name` and `routing_key`
/// with `-`.
fn generate_queue_name(prefix: &str, exchange_name: &str, routing_key: &str) -> String {
    [prefix, exchange_name, routing_key].join("-")
}
/// A decoded message handed from the consumer task to a worker task.
pub struct Message<Msg> {
/// Identifier taken from the `AmqpMessageWrapper` envelope.
message_id: String,
/// Exchange name as recorded in the envelope.
exchange: String,
/// Routing key as recorded in the envelope.
routing_key: String,
/// Envelope creation time converted to UTC; the `DateTime` default when the envelope has none.
published_at: DateTime<Utc>,
/// Handle used to ack, nack or reject the underlying delivery.
delivery: Acker,
/// Envelope headers; the worker extracts the parent telemetry context from them.
headers: HashMap<String, String>,
/// The payload decoded as `Msg`.
msg: Msg,
}
async fn spawn_consumer<
Msg: prost::Message + NamedMessage + Default + 'static,
Ctx: Send + Clone + 'static,
>(
config: AmqpConfig,
context: Ctx,
handler: impl AmqpHandler<Message = Msg, Context = Ctx> + Send + Clone + 'static,
queue_name: String,
consumer_config: ConsumerConfig,
) -> anyhow::Result<Vec<tokio::task::JoinHandle<anyhow::Result<()>>>> {
let (tx, rx) = async_channel::bounded::<Message<Msg>>(consumer_config.prefetch_count as _);
let mut handles = Vec::with_capacity(consumer_config.worker_count as usize + 1);
let consumer_tag = format!("cmr_{}", epoch());
for idx in 0..consumer_config.worker_count {
let handler = handler.clone();
let rx = rx.clone();
let queue_name = queue_name.clone();
let worker = spawn_worker(
rx,
context.clone(),
queue_name,
consumer_tag.clone(),
idx,
handler,
);
handles.push(worker);
}
let consumer = tokio::task::spawn(async move {
create_consumer(consumer_tag, &queue_name, &config, consumer_config, tx).await
});
handles.push(consumer);
Ok(handles)
}
fn spawn_worker<Msg, Ctx>(
rx: Receiver<Message<Msg>>,
context: Ctx,
queue_name: String,
consumer_tag: String,
worker_seq: u8,
handler: impl AmqpHandler<Message = Msg, Context = Ctx> + Send + 'static,
) -> tokio::task::JoinHandle<anyhow::Result<()>>
where
Msg: prost::Message + NamedMessage + Default + Send + 'static,
Ctx: Clone + Send + 'static,
{
tokio::task::spawn(async move {
tracing::debug!("Started worker for {}", Msg::NAME);
while let Ok(message) = rx.recv().await {
let now = chrono::Utc::now();
let handler_result = {
let parent_ctx = telemetry::extract_context_from_hash_map(&message.headers);
let tracer = opentelemetry::global::tracer(file!());
let span = tracer
.span_builder(Cow::Owned(format!(
"consume_{} (Queue: {})",
Msg::NAME,
queue_name.clone()
)))
.with_kind(SpanKind::Server)
.with_attributes(vec![
KeyValue::new("amqpsy.amqp.consumer.tag", consumer_tag.clone()),
KeyValue::new("amqpsy.amqp.consumer.worker_seq", worker_seq.to_string()),
KeyValue::new("amqpsy.amqp.message.id", message.message_id),
KeyValue::new("amqpsy.amqp.message.routing_key", message.routing_key),
KeyValue::new("amqpsy.amqp.message.exchange", message.exchange),
KeyValue::new("amqpsy.amqp.message.queue_name", queue_name.clone()),
KeyValue::new(
"amqpsy.amqp.message.first_published_at",
message.published_at.to_rfc3339(),
),
KeyValue::new("amqpsy.amqp.message.last_consumed_at", now.to_rfc3339()),
])
.start_with_context(&tracer, &parent_ctx);
let otel_context = opentelemetry::Context::default().with_span(span);
handler
.handle(context.clone(), message.msg)
.with_context(otel_context)
.await
};
match handler_result {
Ok(()) | Err(ConsumerError::Invalid(_)) => {
message
.delivery
.ack(lapin::options::BasicAckOptions { multiple: false })
.await?;
}
Err(ConsumerError::Fatal(e)) => {
tracing::error!({error.raw = ?e}, "Consumer fatal error: {}", e);
message
.delivery
.reject(lapin::options::BasicRejectOptions { requeue: false })
.await?;
}
Err(ConsumerError::Retry(e)) => {
tracing::error!({error.raw = ?e}, "Consumer transient error: {}", e);
message
.delivery
.nack(lapin::options::BasicNackOptions {
multiple: false,
requeue: true,
})
.await?;
}
}
}
tracing::info!("Worker finished for {}", Msg::NAME);
Ok(())
})
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of panicking.
fn epoch() -> u128 {
    SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis())
}
async fn create_consumer<Msg: prost::Message + NamedMessage + Default + 'static>(
tag: String,
queue_name: &str,
config: &AmqpConfig,
consumer_config: ConsumerConfig,
sender: async_channel::Sender<Message<Msg>>,
) -> anyhow::Result<()> {
let channel = new_ampq_channel(config).await?;
channel
.basic_qos(consumer_config.prefetch_count, BasicQosOptions::default())
.await?;
let mut consumer = channel
.basic_consume(
queue_name,
&tag,
lapin::options::BasicConsumeOptions::default(),
<_>::default(),
)
.await?;
use futures_lite::StreamExt;
loop {
let item = consumer.next().await;
match item {
Some(Ok(delivery)) => match AmqpMessageWrapper::decode(delivery.data.as_slice()) {
Ok(wrapper) => {
let headers = wrapper.headers;
match Msg::decode(wrapper.payload.as_slice()) {
Ok(decoded_message) => {
tracing::info!(
"[{}] Consuming: {} ({}) -> {}",
Msg::NAME,
delivery.exchange,
delivery.routing_key,
queue_name
);
sender
.send(Message {
delivery: delivery.acker,
message_id: wrapper.id,
exchange: wrapper.exchange,
routing_key: wrapper.routing_key,
published_at: wrapper
.created_at
.map(timestamp_to_utc)
.unwrap_or_default(),
headers,
msg: decoded_message,
})
.await?;
}
Err(decode_error) => {
tracing::warn!({error.raw = ?decode_error}, "Failed to decode message from queue {}", queue_name);
}
}
}
Err(decoding_error) => {
tracing::warn!({error.raw = ?decoding_error}, "Invalid message in the queue {} - not a message of `AmqpMessageWrapper` (message published by other pubsub library?)", queue_name);
if let Err(e) = delivery
.reject(lapin::options::BasicRejectOptions { requeue: false })
.await
{
tracing::warn!("Error rejecting invalid message: {}", e);
}
}
},
Some(Err(e)) => {
tracing::warn!({error.raw = ?e}, "Failed to fetch message from queue {}", queue_name);
}
None => {
tracing::warn!("Consumer cancelled for queue {}", queue_name);
break;
}
}
}
Ok(())
}
/// Outcome classes a handler can report for a message it failed to process.
///
/// The worker task settles the delivery differently for each variant.
#[derive(Debug, thiserror::Error)]
pub enum ConsumerError {
/// Unrecoverable failure; the worker rejects the delivery without requeue.
#[error("Fatal error: {0}")]
Fatal(anyhow::Error),
/// The message itself is invalid; the worker acks the delivery.
#[error("Invalid request: {0}")]
Invalid(anyhow::Error),
/// Transient failure; the worker nacks the delivery with requeue.
#[error("Retry: {0}")]
Retry(anyhow::Error),
}
/// Handler invoked by a worker task for each consumed message.
pub trait AmqpHandler {
/// Message type passed to [`AmqpHandler::handle`].
type Message;
/// Context type passed by value to every [`AmqpHandler::handle`] call.
type Context;
/// Processes one message.
///
/// Returning `Err` selects, via the [`ConsumerError`] variant, how the worker
/// settles the delivery. The returned future must be `Send`.
fn handle(
&self,
context: Self::Context,
message: Self::Message,
) -> impl Future<Output = Result<(), ConsumerError>> + Send;
}
/// Converts the error of a `Result` into a specific [`ConsumerError`] variant.
pub trait ConsumerErrorExt<T> {
    /// Maps the error to [`ConsumerError::Retry`].
    fn or_transient_error(self) -> Result<T, ConsumerError>;
    /// Maps the error to [`ConsumerError::Fatal`].
    fn or_fatal_error(self) -> Result<T, ConsumerError>;
    /// Maps the error to [`ConsumerError::Invalid`].
    fn or_invalid_request(self) -> Result<T, ConsumerError>;
}

/// Blanket implementation for any `Result` whose error implements `Display`;
/// the error is carried over as its formatted message.
impl<T, E: Display> ConsumerErrorExt<T> for Result<T, E> {
    fn or_transient_error(self) -> Result<T, ConsumerError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(ConsumerError::Retry(anyhow::anyhow!("{}", cause))),
        }
    }

    fn or_fatal_error(self) -> Result<T, ConsumerError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(ConsumerError::Fatal(anyhow::anyhow!("{}", cause))),
        }
    }

    fn or_invalid_request(self) -> Result<T, ConsumerError> {
        match self {
            Ok(value) => Ok(value),
            Err(cause) => Err(ConsumerError::Invalid(anyhow::anyhow!("{}", cause))),
        }
    }
}