use super::connection::AMQPConnection;
use super::{AMQPHandler, AMQPResult, Context};
use crate::amqp::{Error, create_delivery_context};
use amqprs::channel::{
BasicAckArguments, BasicConsumeArguments, BasicNackArguments, BasicQosArguments,
ConsumerMessage,
};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tracing::Instrument;
/// AMQP consumer that feeds each delivery to a handler of type `H`, running
/// every delivery on its own tokio task (see `handle_delivery` / `consume`).
///
/// `P`, `T` and `Err` are not stored; they are tied to the handler through the
/// bounds on the `impl` block and tracked here only via `PhantomData`.
pub struct Consumer<H, P, T, Err> {
    /// Shared context passed to `create_delivery_context` for every delivery.
    context: Arc<Context>,
    /// Connection used to open (and re-open) the consumer's channel.
    connection: AMQPConnection,
    /// Arguments passed to `basic_consume_rx`; its `no_ack` flag also disables acking.
    consume_args: BasicConsumeArguments,
    /// QoS settings applied with `basic_qos` on every newly opened channel.
    qos_args: BasicQosArguments,
    /// Handler invoked per delivery; cloned into each delivery task.
    handler: H,
    /// Shutdown signal: once notified, `consume` stops reading deliveries.
    notifier: Arc<Notify>,
    /// Marker for the type parameters that are not stored in a field.
    _phantom: std::marker::PhantomData<(P, T, Err)>,
}
impl<H, P, T, Err> Consumer<H, P, T, Err>
where
H: AMQPHandler<P, T, Err>,
T: AMQPResult,
{
pub fn new(
context: Arc<Context>,
connection: AMQPConnection,
options: BasicConsumeArguments,
qos_args: BasicQosArguments,
handler: H,
notifier: Arc<Notify>,
) -> Self {
Self {
context,
connection,
consume_args: options,
handler,
qos_args,
notifier,
_phantom: std::marker::PhantomData,
}
}
fn handle_delivery(
&self,
delivery: ConsumerMessage,
channel: &amqprs::channel::Channel,
skip_ack: bool,
tasks: &mut JoinSet<()>,
) {
let (delivery_context, payload) = create_delivery_context(delivery, &self.context, channel);
let handler = self.handler.clone();
#[cfg(feature = "telemetry")]
let span = super::telemetry::make_span_from_delivery_context(&delivery_context);
#[cfg(not(feature = "telemetry"))]
let span = {
let delivery_tag = delivery_context.delivery_tag;
tracing::span!(tracing::Level::INFO, "streameroo::consumer", delivery_tag)
};
let fut = async move {
match handler.call(payload, &delivery_context).await {
Ok(ret) => match ret.handle_result(&delivery_context).await {
Ok(_) => {
if skip_ack {
return;
}
if let Err(e) = delivery_context
.channel
.basic_ack(BasicAckArguments {
delivery_tag: delivery_context.delivery_tag,
multiple: false,
})
.await
{
tracing::error!(?e, "Error acking delivery");
}
}
Err(e) => {
tracing::error!(?e, "Error processing AMQPResult. Nacking delivery");
let nack_args = BasicNackArguments {
delivery_tag: delivery_context.delivery_tag,
multiple: false,
requeue: true,
};
if let Err(e) = delivery_context.channel.basic_nack(nack_args).await {
tracing::error!(?e, "Error nacking delivery");
}
}
},
Err(Error::Event(e)) => {
tracing::error!(?e, "Error decoding event. Nacking without requeue");
let nack_args = BasicNackArguments {
delivery_tag: delivery_context.delivery_tag,
multiple: false,
requeue: false,
};
if let Err(e) = delivery_context.channel.basic_nack(nack_args).await {
tracing::error!(?e, "Error acking delivery");
}
}
Err(e) => {
tracing::error!(?e, "Error calling AMQP handler");
let nack_args = BasicNackArguments {
delivery_tag: delivery_context.delivery_tag,
multiple: false,
requeue: true,
};
tracing::info!(?nack_args, "Nacking delivery");
if let Err(e) = delivery_context.channel.basic_nack(nack_args).await {
tracing::error!(?e, "Error nacking delivery");
}
}
}
};
tasks.spawn(fut.instrument(span));
}
pub async fn consume(self) {
let notified = self.notifier.notified();
tokio::pin!(notified);
let mut tasks = JoinSet::new();
let skip_ack = T::manual() || self.consume_args.no_ack;
let mut channel;
'outer: loop {
tracing::info!("Creating channel for consumer");
channel = match self.connection.open_channel().await {
Ok(channel) => channel,
Err(e) => {
tracing::error!(?e, "Failed to create channel for consumer");
continue;
}
};
if let Err(e) = channel.basic_qos(self.qos_args.clone()).await {
tracing::error!(?e, "Failed to set qos for consumer");
continue;
}
let (_, mut consumer_rx) =
match channel.basic_consume_rx(self.consume_args.clone()).await {
Ok(consume) => consume,
Err(e) => {
tracing::error!(?e, "Failed to start consuming");
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
loop {
tokio::select! {
biased;
_ = &mut notified => {
break 'outer;
},
delivery = consumer_rx.recv() => {
if let Some(delivery) = delivery {
self.handle_delivery(delivery, &channel, skip_ack, &mut tasks);
} else {
tracing::warn!("Consumer closed unexpectedly");
break
}
}
}
while tasks.try_join_next().is_some() {}
}
}
tasks.join_all().await;
if let Err(e) = channel.close().await {
tracing::error!(?e, "Failed to close channel in consumer");
}
}
}