use std::{future::Future, sync::Arc};
use serde::{Serialize, de::DeserializeOwned};
use tracing::warn;
use crate::IncomingMessage;
use crate::codec::Codec;
use super::batch::{BatchHandler, decode_batch, settle};
use super::context::Context;
use super::dispatch::Workers;
use super::failure::{FailurePolicies, FailurePolicy};
use super::handler::HandlerResult;
use super::metadata::HandlerMetadata;
use super::publish::{PublishMiddleware, ReplyPublisher};
/// Definition of a batch handler whose successful output is a list of replies.
///
/// `BatchPublishingHandler` drives an implementation: it hands the decoded
/// batch to [`call`](Self::call) and publishes the returned replies under
/// [`reply_name`](Self::reply_name).
pub trait BatchPublishingDef: Send + Sync {
/// Element type of the batch slice handed to [`call`](Self::call).
type Input;
/// Element type of the reply vector returned by [`call`](Self::call).
type Reply;
/// Type returned by [`source`](Self::source).
type Source;
/// Returns the source of this definition.
///
/// NOTE(review): presumably identifies where batches are consumed from;
/// confirm against the code that registers definitions.
fn source(&self) -> Self::Source;
/// Name passed to the publisher when the replies are published.
fn reply_name(&self) -> &str;
/// Worker configuration for this definition.
///
/// Defaults to `Workers::sequential()`.
fn workers(&self) -> Workers {
Workers::sequential()
}
/// Failure policies for this definition.
///
/// Defaults to `FailurePolicies::default()`.
fn failure_policies(&self) -> FailurePolicies {
FailurePolicies::default()
}
/// Optional description, forwarded to the handler metadata. `None` by default.
fn description(&self) -> Option<&str> {
None
}
/// Optional input schema as a string, forwarded to the handler metadata.
/// `None` by default.
///
/// NOTE(review): the schema format is not visible from this file.
fn input_schema(&self) -> Option<String> {
None
}
/// Optional message name, forwarded to the handler metadata. `None` by default.
fn message_name(&self) -> Option<&'static str> {
None
}
/// Optional message description, forwarded to the handler metadata.
/// `None` by default.
fn message_description(&self) -> Option<&'static str> {
None
}
/// Processes one decoded batch.
///
/// On `Ok(replies)` the handler publishes the replies under
/// [`reply_name`](Self::reply_name); the batch is then settled with
/// `HandlerResult::Ack` if publishing succeeded, or with
/// `HandlerResult::retry()` if it failed. On `Err(result)` nothing is
/// published and `result` is used to settle every accepted message.
///
/// The returned future must be `Send`.
fn call(
&self,
batch: &[Self::Input],
ctx: &mut Context<'_>,
) -> impl Future<Output = Result<Vec<Self::Reply>, HandlerResult>> + Send;
}
/// Builds the [`HandlerMetadata`] for a batch publishing definition.
///
/// The metadata is typed on `D::Input`, records the type name of `D::Reply`
/// as its output type, and carries the optional descriptive details that
/// `def` reports (description, input schema, message name and description).
pub(crate) fn batch_publishing_metadata<D: BatchPublishingDef>(
    name: String,
    def: &D,
) -> HandlerMetadata {
    let reply_type = std::any::type_name::<D::Reply>();

    // Query the definition in the same order the builder receives the values.
    let description = def.description();
    let input_schema = def.input_schema();
    let message_name = def.message_name();
    let message_description = def.message_description();

    HandlerMetadata::typed::<D::Input>(name)
        .with_output_type(reply_type)
        .with_def_details(description, input_schema, message_name, message_description)
}
/// Batch handler that runs a `BatchPublishingDef` and publishes its replies.
pub struct BatchPublishingHandler<D, C, R> {
/// Definition whose `call` processes each decoded batch and whose
/// `reply_name` names the published replies.
pub(crate) def: D,
/// Codec handed to `decode_batch` for the incoming messages.
pub(crate) codec: C,
/// Publisher used to send all replies of a batch in one `publish_batch` call.
pub(crate) publisher: R,
/// Publish middleware passed along with every `publish_batch` call.
pub(crate) pipeline: Arc<[Arc<dyn PublishMiddleware>]>,
/// Failure policy handed to `decode_batch`.
///
/// NOTE(review): presumably applied to messages that fail to decode;
/// confirm in the `batch` module.
pub(crate) decode: FailurePolicy,
}
/// Opaque `Debug` output: only the type name followed by `..` is printed,
/// so no `Debug` bound is required on `D`, `C` or `R`.
impl<D, C, R> std::fmt::Debug for BatchPublishingHandler<D, C, R> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("BatchPublishingHandler");
        // No fields are listed; the struct is rendered as non-exhaustive.
        repr.finish_non_exhaustive()
    }
}
/// Batch handling for publishing definitions: decode the batch, run the
/// definition, publish its replies, then settle every accepted message with
/// one shared outcome.
impl<M, D, C, R> BatchHandler<M> for BatchPublishingHandler<D, C, R>
where
    M: IncomingMessage,
    D: BatchPublishingDef,
    D::Input: DeserializeOwned + Send + Sync,
    D::Reply: Serialize + Send + Sync,
    C: Codec,
    R: ReplyPublisher,
{
    /// Handles one incoming batch.
    ///
    /// * If no message is accepted by `decode_batch`, returns immediately.
    /// * If the definition returns `Err(result)`, nothing is published and
    ///   `result` becomes the outcome.
    /// * If the definition returns replies, they are published in a single
    ///   `publish_batch` call; success yields `HandlerResult::Ack`, failure is
    ///   logged and yields `HandlerResult::retry()`.
    ///
    /// Every accepted message is then settled with that outcome.
    async fn handle_batch(&self, batch: Vec<M>, ctx: &mut Context<'_>) {
        // Owned copy of the subscription name, used for logging and settling.
        let subscription_name = ctx.name().to_owned();

        let (inputs, deliveries) =
            decode_batch::<M, D::Input, C>(batch, &self.codec, self.decode, ctx).await;
        if deliveries.is_empty() {
            return;
        }

        let call_result = self.def.call(&inputs, ctx).await;
        let outcome = match call_result {
            // The definition already decided how the batch must be settled.
            Err(verdict) => verdict,
            Ok(replies) => {
                let reply_name = self.def.reply_name();
                let published = self
                    .publisher
                    .publish_batch(reply_name, &replies, &self.pipeline, ctx.extensions())
                    .await;
                if let Err(err) = published {
                    warn!(
                        target: "ruststream::dispatch",
                        subscription = %subscription_name,
                        reply = %reply_name,
                        reply_type = std::any::type_name::<D::Reply>(),
                        error = %err,
                        "batch reply publish failed",
                    );
                    HandlerResult::retry()
                } else {
                    HandlerResult::Ack
                }
            }
        };

        // One outcome for the whole batch, applied message by message in order.
        for delivery in deliveries {
            settle(delivery, outcome, &subscription_name).await;
        }
    }
}
// Tests for `BatchPublishingHandler`, run against the in-memory broker with
// the JSON codec (hence the `memory` and `json` feature gates).
#[cfg(all(test, feature = "memory", feature = "json"))]
mod tests {
use futures::StreamExt;
use super::super::context::State;
use super::super::dispatch::Delivery;
use super::super::publish::TypedPublisher;
use super::*;
use crate::codec::JsonCodec;
use crate::memory::{MemoryBroker, MemoryMessage, MemorySubscriber};
use crate::{BatchSubscriber, Headers, OutgoingMessage, Publisher, Subscriber};
// Test definition: replies with each input multiplied by 10, or fails with
// `fail_with` when it is set.
struct Confirm {
// Name returned by `reply_name`.
reply_to: &'static str,
// When `Some`, `call` returns this result as its error instead of replies.
fail_with: Option<HandlerResult>,
}
impl BatchPublishingDef for Confirm {
type Input = u32;
type Reply = u32;
type Source = crate::Name;
fn source(&self) -> Self::Source {
crate::Name::new("orders")
}
fn reply_name(&self) -> &str {
self.reply_to
}
async fn call(
&self,
batch: &[u32],
_ctx: &mut Context<'_>,
) -> Result<Vec<u32>, HandlerResult> {
// Forced failure takes precedence over producing replies.
if let Some(result) = self.fail_with {
return Err(result);
}
Ok(batch.iter().map(|n| n * 10).collect())
}
}
// Publishes every number to `name` as its own JSON-encoded message.
async fn publish_numbers(broker: &MemoryBroker, name: &str, numbers: &[u32]) {
let publisher = broker.publisher();
for n in numbers {
publisher
.publish(OutgoingMessage::new(name, &serde_json::to_vec(n).unwrap()))
.await
.unwrap();
}
}
// Pulls the next batch from `sub`; panics if the stream yields nothing or if
// unwrapping the yielded item fails.
async fn pull_batch(sub: &mut MemorySubscriber) -> Vec<MemoryMessage> {
let mut stream = std::pin::pin!(sub.batches());
stream.next().await.unwrap().unwrap()
}
// Success path: replies for the whole batch arrive on the reply name and the
// input batch is not delivered again.
#[tokio::test]
async fn transactional_replies_publish_atomically_then_ack() {
let broker = MemoryBroker::new();
// Both subscriptions are created before anything is published.
let mut input = broker.subscribe("orders");
let mut replies = broker.subscribe("confirmations");
let handler = BatchPublishingHandler {
def: Confirm {
reply_to: "confirmations",
fail_with: None,
},
codec: JsonCodec,
publisher: TypedPublisher::with_codec(broker.publisher(), JsonCodec).transactional(),
pipeline: Arc::from([]),
decode: FailurePolicy::Drop,
};
publish_numbers(&broker, "orders", &[1, 2]).await;
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("orders", &headers, &state, &delivery);
let batch = pull_batch(&mut input).await;
handler.handle_batch(batch, &mut ctx).await;
// Inputs 1 and 2 must come back as the JSON payloads "10" and "20".
let confirmed = pull_batch(&mut replies).await;
let payloads: Vec<&[u8]> = confirmed.iter().map(IncomingMessage::payload).collect();
assert_eq!(payloads, [b"10", b"20"]);
for msg in confirmed {
msg.ack().await.unwrap();
}
// Nothing is ready on the input subscription any more.
// NOTE(review): presumably because the handler acked the batch.
let mut stream = std::pin::pin!(input.stream());
assert!(futures::poll!(stream.next()).is_pending());
}
// Failure path: the definition returns `HandlerResult::retry()`, so no reply
// is published and both input messages show up on the input again.
#[tokio::test]
async fn handler_error_publishes_nothing_and_settles_the_batch() {
let broker = MemoryBroker::new();
let mut input = broker.subscribe("orders");
let mut replies = broker.subscribe("confirmations");
let handler = BatchPublishingHandler {
def: Confirm {
reply_to: "confirmations",
fail_with: Some(HandlerResult::retry()),
},
codec: JsonCodec,
publisher: TypedPublisher::with_codec(broker.publisher(), JsonCodec).transactional(),
pipeline: Arc::from([]),
decode: FailurePolicy::Drop,
};
publish_numbers(&broker, "orders", &[1, 2]).await;
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("orders", &headers, &state, &delivery);
let batch = pull_batch(&mut input).await;
handler.handle_batch(batch, &mut ctx).await;
// No reply may be ready on the reply subscription.
let mut reply_stream = std::pin::pin!(replies.stream());
assert!(futures::poll!(reply_stream.next()).is_pending());
// Both messages are available on the input again after the retry outcome.
let redelivered = pull_batch(&mut input).await;
assert_eq!(redelivered.len(), 2);
// Ack the redelivered messages so the test leaves nothing unsettled.
for msg in redelivered {
msg.ack().await.unwrap();
}
}
}