use std::{any::type_name, fmt::Display, marker::PhantomData, sync::Arc, task::Poll};
use async_trait::async_trait;
use dashmap::DashMap;
use futures::{Future, Stream, StreamExt};
use lapin::{options::BasicConsumeOptions, BasicProperties};
use serde::{Deserialize, Serialize};
use tokio::{
sync::mpsc,
task::{self, JoinHandle},
};
use tracing::{debug, error, warn};
use uuid::Uuid;
use crate::{delivery_uuid, error::Error, fmt_correlation_id, Bus, Connection, Delivery, Result};
use super::{direct::DirectBus, Channel, Consumer, Publisher};
mod comm;
pub use comm::*;
/// A [`DirectBus`] whose consumers can answer each received message with a reply.
///
/// Implementors choose the reply payload type and how it is encoded on the wire; the
/// `rpc_bus!` macro generates an implementation.
pub trait RpcBus: DirectBus {
    /// Type carried by replies sent back to the original publisher.
    type ReplyPayload;
    /// Encodes a reply payload into the bytes that are published.
    fn serialize_reply(payload: &Self::ReplyPayload) -> Result<Vec<u8>>;
    /// Decodes the bytes of a received reply.
    fn deserialize_reply(bytes: &[u8]) -> Result<Self::ReplyPayload>;
}
/// A lapin channel that can publish RPC requests and route the replies back to their callers.
///
/// Cloning is cheap: clones share the same lapin channel handle and the same
/// pending-reply table.
#[derive(Clone)]
pub struct RpcChannel {
    // Underlying AMQP channel used for publishing, consuming and the reply listener.
    inner: lapin::Channel,
    // Senders for every outstanding request, keyed by the request's correlation UUID.
    // Entries are inserted by `register_pending_reply` and removed when the matching
    // `ReplyReceiver` is dropped.
    pending_replies: Arc<DashMap<Uuid, mpsc::UnboundedSender<lapin::message::Delivery>>>,
}
/// Type-level marker for "the replies of bus `B`".
///
/// `Reply<B>` is itself a bus whose publish payload is `B::ReplyPayload`, so reply deliveries
/// can be decoded with the ordinary `Delivery` machinery. It holds no data.
#[derive(Debug)]
pub struct Reply<B> {
    _marker: PhantomData<B>,
}
/// Views the replies of `B` as a bus of their own: a "published" reply travels over
/// [`RpcChannel`] and is encoded with `B`'s reply codec.
impl<B: RpcBus> Bus for Reply<B> {
    type Chan = RpcChannel;
    type PublishPayload = <B as RpcBus>::ReplyPayload;

    /// Delegates to `B::serialize_reply`.
    fn serialize_payload(payload: &Self::PublishPayload) -> Result<Vec<u8>> {
        <B as RpcBus>::serialize_reply(payload)
    }

    /// Delegates to `B::deserialize_reply`.
    fn deserialize_payload(bytes: &[u8]) -> Result<Self::PublishPayload> {
        <B as RpcBus>::deserialize_reply(bytes)
    }
}
/// Replies reuse the argument type and queue naming of the request bus `B`.
impl<B: RpcBus> DirectBus for Reply<B> {
    type Args = <B as DirectBus>::Args;

    /// Same queue name that `B` computes for `args`.
    fn queue(args: Self::Args) -> String {
        <B as DirectBus>::queue(args)
    }
}
impl<B: RpcBus> RpcBus for Reply<B> {
type ReplyPayload = B::PublishPayload;
fn serialize_reply(payload: &Self::ReplyPayload) -> Result<Vec<u8>> {
B::serialize_payload(payload)
}
fn deserialize_reply(bytes: &[u8]) -> Result<Self::ReplyPayload> {
B::deserialize_payload(bytes)
}
}
impl RpcChannel {
pub async fn new(connection: &Connection) -> Result<RpcChannel> {
let chan = connection.inner.create_channel().await?;
let pending_replies: DashMap<Uuid, mpsc::UnboundedSender<lapin::message::Delivery>> =
DashMap::new();
let pending_replies = Arc::new(pending_replies);
let reply_consumer = chan
.basic_consume(
"amq.rabbitmq.reply-to",
&Uuid::new_v4().to_string(),
BasicConsumeOptions {
no_ack: true,
..Default::default()
},
Default::default(),
)
.await?;
let handle_replies: JoinHandle<Result<()>> = task::spawn({
let mut reply_consumer = reply_consumer;
let pending_replies = pending_replies.clone();
async move {
while let Some(msg_res) = reply_consumer.next().await {
match msg_res {
Ok(msg) => {
let forward_reply: JoinHandle<()> = task::spawn_blocking({
let pending_replies = pending_replies.clone();
move || {
let reply_id = match delivery_uuid(&msg, 1) {
Some(Ok(i)) => i,
Some(Err(e)) => {
error!("Error parsing reply message correlation UUID: {e:?}. Dropping message.");
return;
}
None => {
error!("Received reply with nog correlation ID. Dropping message.");
return;
}
};
let forwarding_success =
if let Some(tx) = pending_replies.get(&reply_id) {
tx.send(msg).is_ok()
} else {
false
};
if !forwarding_success {
warn!("Received reply cannot be forwarded due to dropped Receiver. UUID: {}", reply_id);
}
}
});
drop(forward_reply);
}
Err(e) => error!("Error receiving reply message: {e:?}"),
}
}
panic!("Task handle_replies ended");
}
});
drop(handle_replies);
Ok(RpcChannel {
inner: chan,
pending_replies,
})
}
fn register_pending_reply<B: DirectBus>(
&self,
correlation_uuid: Uuid,
) -> impl Stream<Item = Delivery<B>> {
let (tx, rx) = mpsc::unbounded_channel();
debug!("Registering pending reply for correlation UUID {correlation_uuid}");
let rx = ReplyReceiver {
correlation_uuid,
inner: rx,
chan: Some(self.clone()),
_marker: PhantomData,
};
self.pending_replies.insert(correlation_uuid, tx);
rx
}
fn remove_pending_reply(&self, correlation_uuid: &Uuid) {
self.pending_replies.remove(correlation_uuid);
}
pub async fn consumer<B: RpcBus>(
&self,
args: B::Args,
consumer_tag: &str,
) -> Result<Consumer<B>> {
let queue = B::queue(args);
self.inner
.queue_declare(&queue, Default::default(), Default::default())
.await?;
let consumer = self
.inner
.basic_consume(&queue, consumer_tag, Default::default(), Default::default())
.await?;
debug!(
"Created consumer for RPC bus {} for queue {queue} with consumer tag {consumer_tag}",
type_name::<B>()
);
Ok(Consumer {
inner: consumer,
_marker: PhantomData,
})
}
pub fn publisher<B: RpcBus<Chan = Self>>(&self) -> Publisher<B> {
debug!("Created publisher for RPC bus {}", type_name::<B>());
Publisher { chan: self.clone() }
}
}
#[async_trait]
impl Channel for RpcChannel {
    /// Publishes `payload_bytes` on the default exchange (`""`) with `routing_key`.
    ///
    /// The correlation ID is built by `fmt_correlation_id` from `correlation_uuid` and the
    /// optional `reply_uuid`, and overwrites any correlation ID already set in `properties`.
    /// The value yielded by `basic_publish` is discarded; only its error is propagated.
    async fn publish_with_properties(
        &self,
        payload_bytes: &[u8],
        routing_key: &str,
        properties: lapin::BasicProperties,
        correlation_uuid: Uuid,
        reply_uuid: Option<Uuid>,
    ) -> Result<()> {
        let correlation_id = fmt_correlation_id(correlation_uuid, reply_uuid);
        debug!("Publishing message with correlation ID {correlation_id} on an RPC channel with routing key {routing_key}");
        let properties = properties.with_correlation_id(correlation_id.into());
        self.inner
            .basic_publish(
                "",
                routing_key,
                Default::default(),
                payload_bytes,
                properties,
            )
            .await?;
        Ok(())
    }
}
impl<'r, 'p, B> Publisher<B>
where
    B: RpcBus<Chan = RpcChannel>,
    B::PublishPayload: Deserialize<'p> + Serialize,
    B::ReplyPayload: Deserialize<'r> + Serialize,
{
    /// Publishes `payload` to the queue selected by `args` and returns a stream of every reply
    /// sent back for it.
    ///
    /// The request gets a fresh correlation UUID and a `reply-to` of
    /// `amq.rabbitmq.reply-to`. The receiver is registered *before* publishing so that a fast
    /// responder cannot reply to an unknown UUID. Dropping the stream unregisters it; this
    /// also happens when publishing fails. The stream does not end by itself while the
    /// channel's reply listener is running.
    ///
    /// # Errors
    ///
    /// Returns an error if publishing the request fails.
    pub async fn publish_recv_many(
        &self,
        args: B::Args,
        payload: &B::PublishPayload,
    ) -> Result<impl Stream<Item = Delivery<Reply<B>>>> {
        let correlation_uuid = Uuid::new_v4();
        let replies = self
            .chan
            .register_pending_reply::<Reply<B>>(correlation_uuid);
        let properties = BasicProperties::default().with_reply_to("amq.rabbitmq.reply-to".into());
        debug!("Publishing message with correlation UUID {correlation_uuid}, expecting one or more replies");
        let routing_key = B::queue(args);
        self.publish_with_properties(&routing_key, payload, properties, correlation_uuid, None)
            .await?;
        Ok(replies)
    }

    /// Like `publish_recv_many`, but only the first reply is kept.
    ///
    /// The outer `Result` reports publish failures. The returned future yields `None` if the
    /// reply stream ends before any reply arrives. Once the future completes or is dropped,
    /// the receiver is unregistered and later replies are discarded.
    pub async fn publish_recv_one(
        &'r self,
        args: B::Args,
        payload: &B::PublishPayload,
    ) -> Result<impl Future<Output = Option<Delivery<Reply<B>>>>> {
        let mut replies = self.publish_recv_many(args, payload).await?;
        Ok(async move { replies.next().await })
    }
}
impl<'p, 'r, B> Delivery<B>
where
B: RpcBus,
B::PublishPayload: Deserialize<'p> + Serialize,
B::ReplyPayload: Deserialize<'r> + Serialize,
{
pub async fn reply(&self, reply_payload: &B::ReplyPayload, chan: &impl Channel) -> Result<()> {
let Some(correlation_uuid) = self.get_uuid() else {
return Err(Error::Reply(ReplyError::NoCorrelationUuid));
};
let Some(reply_to) = self.inner.properties.reply_to().as_ref().map(|r | r.as_str()) else {
return Err(Error::Reply(ReplyError::NoReplyToConfigured))
};
let reply_uuid = correlation_uuid?;
let bytes = B::serialize_reply(reply_payload)?;
debug!("Replying to message with correlation UUID {reply_uuid}");
let correlation_uuid = Uuid::new_v4();
chan.publish_with_properties(
&bytes,
reply_to,
Default::default(),
correlation_uuid,
Some(reply_uuid),
)
.await
}
}
/// Stream of replies for a single correlation UUID, typed as `Delivery<B>`.
///
/// Created by `RpcChannel::register_pending_reply`. The matching sender stays in the
/// channel's pending-reply table until this receiver is dropped.
struct ReplyReceiver<B> {
    // Key under which the sender half is registered in `pending_replies`.
    correlation_uuid: Uuid,
    // Receives the raw deliveries forwarded by the channel's reply listener.
    inner: mpsc::UnboundedReceiver<lapin::message::Delivery>,
    // Channel to unregister from on drop; set to `Some` on creation and taken in `Drop`.
    chan: Option<RpcChannel>,
    _marker: PhantomData<B>,
}
impl<B: Unpin> Stream for ReplyReceiver<B> {
type Item = Delivery<B>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.inner.poll_recv(cx).map(|msg| msg.map(|m| m.into()))
}
}
impl<B> Drop for ReplyReceiver<B> {
fn drop(&mut self) {
let chan = self.chan.take().unwrap();
let correlation_uuid = self.correlation_uuid;
debug!(
"Closed reply receiver for correlation UUID {correlation_uuid} and RPC bus {}",
type_name::<B>()
);
task::spawn_blocking(move || chan.remove_pending_reply(&correlation_uuid));
}
}
/// Reasons why `Delivery::reply` cannot send a reply for a received message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReplyError {
    /// The received message has no correlation ID, so a reply cannot be matched to it.
    NoCorrelationUuid,
    /// The received message has no `reply-to` property, so there is nowhere to send a reply.
    NoReplyToConfigured,
}

impl Display for ReplyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ReplyError::NoCorrelationUuid => "No correlation Uuid configured for the message",
            ReplyError::NoReplyToConfigured => "No value configured for the reply-to field",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ReplyError {}
#[cfg(test)]
pub use tests::*;
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use tokio::time::timeout;
    use uuid::Uuid;
    use crate::{
        chan::tests::{FramePayload, RABBIT_MQ_URL},
        rpc_bus, setup_test_logging, Connection, Consumer, Publisher, RpcChannel,
    };

    /// Error half of the reply payload used by the test bus.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    pub enum FrameSendError {
        ClientDisconnected,
        Other,
    }

    rpc_bus!(FrameBus, FramePayload, Result<(), FrameSendError>, u32, |args| format!(
        "frame_{}",
        args,
    ),
    serde_json::to_vec,
    serde_json::from_slice
    );

    /// One request answered three times must yield three replies with the expected payload.
    #[tokio::test]
    async fn publish_recv_many() -> crate::Result<()> {
        let _log = setup_test_logging();
        let connection: Connection = Connection::connect(RABBIT_MQ_URL).await?;
        let uuid = Uuid::new_v4();
        let responder = tokio::task::spawn({
            let channel = RpcChannel::new(&connection).await?;
            let mut consumer: Consumer<FrameBus> =
                channel.consumer(3, &Uuid::new_v4().to_string()).await?;
            async move {
                let msg = consumer.next().await.unwrap().unwrap();
                msg.ack(false).await.unwrap();
                let payload = msg.get_payload().unwrap();
                assert_eq!(payload.message, uuid.to_string());
                for _ in 0..3 {
                    msg.reply(&Err(FrameSendError::ClientDisconnected), &channel)
                        .await
                        .unwrap();
                }
            }
        });
        let channel = RpcChannel::new(&connection).await?;
        let publisher: Publisher<FrameBus> = channel.publisher();
        let mut rx = publisher
            .publish_recv_many(
                3,
                &FramePayload {
                    message: uuid.to_string(),
                },
            )
            .await?;
        for _ in 0..3 {
            // Require an actual reply: a closed stream (`None`) must fail the test.
            let reply = timeout(Duration::from_secs(1), rx.next())
                .await
                .expect("timed out waiting for a reply")
                .expect("reply stream ended before all replies arrived");
            let payload = reply.get_payload().unwrap();
            assert!(matches!(payload, Err(FrameSendError::ClientDisconnected)));
        }
        // Surface assertion failures from the responder task instead of silently losing them.
        responder.await.expect("responder task panicked");
        Ok(())
    }

    /// `publish_recv_one` must resolve to the single reply sent by the responder.
    #[tokio::test]
    async fn publish_recv_one() -> Result<(), crate::Error> {
        let _log = setup_test_logging();
        let connection = Connection::connect(RABBIT_MQ_URL).await.unwrap();
        let uuid = Uuid::new_v4();
        let responder = tokio::task::spawn({
            let channel = RpcChannel::new(&connection).await.unwrap();
            let mut consumer: Consumer<FrameBus> =
                channel.consumer(4, &Uuid::new_v4().to_string()).await?;
            async move {
                let msg = consumer.next().await.unwrap().unwrap();
                msg.ack(false).await.unwrap();
                let payload = msg.get_payload().unwrap();
                assert_eq!(payload.message, uuid.to_string());
                msg.reply(&Err(FrameSendError::ClientDisconnected), &channel)
                    .await
                    .unwrap();
            }
        });
        let channel = RpcChannel::new(&connection).await.unwrap();
        let publisher: Publisher<FrameBus> = channel.publisher();
        let fut = publisher
            .publish_recv_one(
                4,
                &FramePayload {
                    message: uuid.to_string(),
                },
            )
            .await
            .unwrap();
        let reply = timeout(Duration::from_secs(1), fut)
            .await
            .expect("timed out waiting for the reply")
            .expect("no reply received");
        let payload = reply.get_payload().unwrap();
        assert!(matches!(payload, Err(FrameSendError::ClientDisconnected)));
        responder.await.expect("responder task panicked");
        Ok(())
    }
}
/// Declares an RPC bus type and implements the `Bus`, `DirectBus` and `RpcBus` traits for it.
///
/// The bus type itself is declared through the crate's `bus!` macro. The generated bus uses
/// `$crate::RpcChannel` as its channel, and `$queue` (applied to the bus arguments) computes
/// the queue name. The *same* `$serialize` / `$deserialize` expressions are used for both the
/// publish payload and the reply payload, so they must accept both types (generic functions
/// such as `serde_json::to_vec` / `serde_json::from_slice` do).
///
/// Four forms are accepted: positional or `key = value`, each with or without a leading doc
/// string.
///
/// ```ignore
/// rpc_bus!(FrameBus, FramePayload, Result<(), FrameSendError>, u32,
///     |args| format!("frame_{}", args), serde_json::to_vec, serde_json::from_slice);
///
/// rpc_bus!(doc = "Frame RPC bus", bus = FrameBus, publish = FramePayload,
///     reply = Result<(), FrameSendError>, args = u32,
///     queue = |args| format!("frame_{}", args),
///     serialize = serde_json::to_vec, deserialize = serde_json::from_slice);
/// ```
#[macro_export]
macro_rules! rpc_bus {
    // Canonical form: doc string followed by the positional arguments. Every other arm
    // forwards here.
    ($doc:literal, $bus:ident, $publish_payload:ty, $reply_payload:ty, $args:ty, $queue:expr, $serialize:expr, $deserialize:expr) => {
        $crate::bus!($doc, $bus);
        $crate::bus_impl!(
            $bus,
            $crate::RpcChannel,
            $publish_payload,
            $serialize,
            $deserialize
        );
        $crate::direct_bus_impl!($bus, $args, $queue);
        // The reply codec reuses the request codec expressions.
        $crate::rpc_bus_impl!(
            $bus,
            $reply_payload,
            $serialize,
            $deserialize
        );
    };
    // Named-argument form with a doc string.
    (doc = $doc:literal, bus = $bus:ident, publish = $publish_payload:ty, reply = $reply_payload:ty, args = $args:ty, queue = $queue:expr, serialize = $serialize:expr, deserialize = $deserialize:expr) => {
        $crate::rpc_bus!(
            $doc,
            $bus,
            $publish_payload,
            $reply_payload,
            $args,
            $queue,
            $serialize,
            $deserialize
        );
    };
    // Positional form without a doc string: uses an empty doc string.
    ($bus:ident, $publish_payload:ty, $reply_payload:ty, $args:ty, $queue:expr, $serialize:expr, $deserialize:expr) => {
        $crate::rpc_bus!(
            "",
            $bus,
            $publish_payload,
            $reply_payload,
            $args,
            $queue,
            $serialize,
            $deserialize
        );
    };
    // Named-argument form without a doc string: forwards to the positional form above.
    (bus = $bus:ident, publish = $publish_payload:ty, reply = $reply_payload:ty, args = $args:ty, queue = $queue:expr, serialize = $serialize:expr, deserialize = $deserialize:expr) => {
        $crate::rpc_bus!(
            $bus,
            $publish_payload,
            $reply_payload,
            $args,
            $queue,
            $serialize,
            $deserialize
        );
    };
}
// Implementation detail of `rpc_bus!`: implements `$crate::RpcBus` for `$bus` with reply
// payload `$reply_payload`. The errors returned by `$serialize` / `$deserialize` are boxed
// into `$crate::Error::Serde`.
#[doc(hidden)]
#[macro_export]
macro_rules! rpc_bus_impl {
    ($bus:ident, $reply_payload:ty, $serialize:expr, $deserialize:expr) => {
        impl $crate::RpcBus for $bus {
            type ReplyPayload = $reply_payload;
            fn serialize_reply(payload: &Self::ReplyPayload) -> $crate::Result<Vec<u8>> {
                // `$serialize` may be a closure literal; calling it in place triggers
                // clippy's redundant_closure_call lint, which is expected for macro input.
                #[allow(clippy::redundant_closure_call)]
                ($serialize)(payload).map_err(|e| $crate::Error::Serde(Box::new(e)))
            }
            fn deserialize_reply(bytes: &[u8]) -> $crate::Result<Self::ReplyPayload> {
                // Same lint suppression as above, for `$deserialize`.
                #[allow(clippy::redundant_closure_call)]
                ($deserialize)(bytes).map_err(|e| $crate::Error::Serde(Box::new(e)))
            }
        }
    };
}