use std::{future::Future, marker::PhantomData};
use serde::de::DeserializeOwned;
use tracing::{error, warn};
use crate::IncomingMessage;
use crate::codec::Codec;
use super::context::Context;
use super::dispatch::Workers;
use super::handler::HandlerResult;
use super::metadata::HandlerMetadata;
use super::typed::DecodeFailure;
/// Outcome a batch handler reports for the messages it was handed.
///
/// Only messages that decoded successfully are covered: undecodable messages
/// are nacked before the handler ever runs.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchResult {
    /// A single outcome applied to every decoded message in the batch.
    Uniform(HandlerResult),
    /// One outcome per decoded message, in the same order as the slice the
    /// handler received. Messages left without an outcome (vector too short)
    /// are retried; surplus entries are ignored.
    PerElement(Vec<HandlerResult>),
}
/// Conversion from a slice handler's return value into a [`BatchResult`].
///
/// Implemented for [`BatchResult`] itself, a single [`HandlerResult`], `()`,
/// `Result<(), E>`, `Result<HandlerResult, E>` and `Vec<HandlerResult>`, so a
/// handler can return whichever shape is most convenient.
pub trait IntoBatchResult {
    /// Consumes the value and yields the equivalent [`BatchResult`].
    fn into_batch_result(self) -> BatchResult;
}
/// Identity conversion: an explicit [`BatchResult`] passes through unchanged.
impl IntoBatchResult for BatchResult {
    fn into_batch_result(self) -> BatchResult {
        self
    }
}
/// A single [`HandlerResult`] is applied to the whole batch.
impl IntoBatchResult for HandlerResult {
    fn into_batch_result(self) -> BatchResult {
        BatchResult::Uniform(self)
    }
}
/// A handler returning `()` acknowledges every message in the batch.
impl IntoBatchResult for () {
    fn into_batch_result(self) -> BatchResult {
        HandlerResult::Ack.into_batch_result()
    }
}
/// `Ok(())` acknowledges the whole batch; any `Err` settles the whole batch
/// with `HandlerResult::drop()`. The error value itself is discarded.
impl<E> IntoBatchResult for Result<(), E> {
    fn into_batch_result(self) -> BatchResult {
        self.map(|()| HandlerResult::Ack).into_batch_result()
    }
}
/// `Ok(result)` applies `result` to the whole batch; any `Err` settles the
/// whole batch with `HandlerResult::drop()`. The error value is discarded.
impl<E> IntoBatchResult for Result<HandlerResult, E> {
    fn into_batch_result(self) -> BatchResult {
        let outcome = match self {
            Ok(result) => result,
            Err(_) => HandlerResult::drop(),
        };
        BatchResult::Uniform(outcome)
    }
}
/// A vector of outcomes settles each decoded message individually, by position.
impl IntoBatchResult for Vec<HandlerResult> {
    fn into_batch_result(self) -> BatchResult {
        BatchResult::PerElement(self)
    }
}
/// A handler that processes a whole batch of decoded values in one call.
///
/// Blanket-implemented for any `Fn(&[T], &mut Context) -> Fut` whose future
/// is `Send` and whose output implements [`IntoBatchResult`].
pub trait SliceHandler<T>: Send + Sync {
    /// Handles `batch` and reports how its messages should be settled.
    fn handle_slice(
        &self,
        batch: &[T],
        ctx: &mut Context,
    ) -> impl Future<Output = BatchResult> + Send;
}
impl<T, F, Fut> SliceHandler<T> for F
where
F: Fn(&[T], &mut Context) -> Fut + Send + Sync,
Fut: Future + Send,
Fut::Output: IntoBatchResult,
{
fn handle_slice(
&self,
batch: &[T],
ctx: &mut Context,
) -> impl Future<Output = BatchResult> + Send {
let fut = (self)(batch, ctx);
async move { fut.await.into_batch_result() }
}
}
/// Declarative description of a batch subscription: its source, the element
/// type its handler consumes, and the handler itself, plus optional metadata.
pub trait BatchDef: Sized {
    /// Element type the handler receives (presumably decoded from each
    /// message payload — confirm at the registration site).
    type Input;
    /// Handler invoked with slices of [`Self::Input`] values.
    type Handler: SliceHandler<Self::Input>;
    /// Where the batches come from; interpreted by whoever registers the definition.
    type Source;

    /// Returns the source to subscribe to.
    fn source(&self) -> Self::Source;

    /// Consumes the definition and yields its handler.
    fn into_handler(self) -> Self::Handler;

    /// Worker configuration; defaults to `Workers::sequential()`.
    fn workers(&self) -> Workers {
        Workers::sequential()
    }

    /// Optional handler description, forwarded into the handler metadata.
    /// Defaults to `None`.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Optional schema for [`Self::Input`], forwarded into the handler
    /// metadata. Defaults to `None`.
    fn input_schema(&self) -> Option<String> {
        None
    }

    /// Optional message name, forwarded into the handler metadata.
    /// Defaults to `None`.
    fn message_name(&self) -> Option<&'static str> {
        None
    }

    /// Optional message description, forwarded into the handler metadata.
    /// Defaults to `None`.
    fn message_description(&self) -> Option<&'static str> {
        None
    }
}
/// Builds the [`HandlerMetadata`] for a batch definition registered as `name`.
///
/// The metadata is typed by `D::Input` and enriched with the definition's
/// optional description, input schema, message name and message description.
pub(crate) fn batch_metadata<D: BatchDef>(name: String, def: &D) -> HandlerMetadata {
    let base = HandlerMetadata::typed::<D::Input>(name);
    let description = def.description();
    let schema = def.input_schema();
    let message_name = def.message_name();
    let message_description = def.message_description();
    base.with_def_details(description, schema, message_name, message_description)
}
/// Crate-internal entry point for dispatching a batch of messages of type `M`.
///
/// The `TypedBatch` implementation in this module settles every message it
/// is given (ack, nack, or delayed nack).
pub(crate) trait BatchHandler<M>: Send + Sync {
    /// Processes `batch` using `ctx`.
    fn handle_batch(
        &self,
        batch: Vec<M>,
        ctx: &mut Context,
    ) -> impl Future<Output = ()> + Send;
}
/// Wraps `inner` into a [`TypedBatch`] that decodes message payloads with
/// `codec` before handing them over as a slice.
///
/// Messages that fail to decode are handled with the default
/// [`DecodeFailure`] policy.
pub(crate) fn typed_batch<M, T, C, H>(codec: C, inner: H) -> TypedBatch<M, T, C, H>
where
    M: IncomingMessage,
    T: DeserializeOwned + Send + Sync,
    C: Codec,
    H: SliceHandler<T>,
{
    let on_decode_failure = DecodeFailure::default();
    TypedBatch {
        _phantom: PhantomData,
        on_decode_failure,
        inner,
        codec,
    }
}
/// Batch adapter that decodes each message with a [`Codec`] and passes the
/// decoded values to a [`SliceHandler`] as a single slice.
///
/// Constructed via `typed_batch`.
pub struct TypedBatch<M, T, C, H> {
    /// Codec used to decode each message payload.
    codec: C,
    /// Slice handler that receives the decoded values.
    inner: H,
    /// Policy for nacking messages that fail to decode (requeue or not).
    on_decode_failure: DecodeFailure,
    // `fn(M, T)` ties `M` and `T` to the type without storing values of them,
    // so the struct's auto traits do not depend on `M` or `T`.
    _phantom: PhantomData<fn(M, T)>,
}
impl<M, T, C, H> TypedBatch<M, T, C, H> {
    /// Returns the adapter with `mode` as the policy for undecodable messages.
    #[must_use]
    // NOTE(review): `dead_code` is presumably allowed because nothing in the
    // crate calls this setter yet — confirm, and drop the allow once used.
    #[allow(dead_code)]
    pub(crate) fn on_decode_failure(self, mode: DecodeFailure) -> Self {
        Self {
            on_decode_failure: mode,
            ..self
        }
    }
}
impl<M, T, C, H> std::fmt::Debug for TypedBatch<M, T, C, H> {
    /// Prints only the decode-failure policy. The codec and handler carry no
    /// `Debug` bound, so they are elided and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TypedBatch");
        out.field("on_decode_failure", &self.on_decode_failure);
        out.finish_non_exhaustive()
    }
}
impl<M, T, C, H> BatchHandler<M> for TypedBatch<M, T, C, H>
where
    M: IncomingMessage,
    T: DeserializeOwned + Send + Sync,
    C: Codec,
    H: SliceHandler<T>,
{
    /// Decodes `batch`, runs the slice handler over the decoded values, and
    /// settles every decoded message with the outcome the handler reported.
    ///
    /// Messages that fail to decode are nacked by `decode_batch` and never
    /// reach the handler; if none decode, the handler is not invoked at all.
    ///
    /// For [`BatchResult::PerElement`], outcomes are paired with messages by
    /// position. Messages left without an outcome are retried. Outcomes left
    /// without a message are ignored. Each kind of mismatch is logged at
    /// `error` level with a message that states which of the two happened.
    async fn handle_batch(&self, batch: Vec<M>, ctx: &mut Context<'_>) {
        let (values, accepted) = decode_batch(batch, &self.codec, self.on_decode_failure).await;
        if accepted.is_empty() {
            return;
        }
        match self.inner.handle_slice(&values, ctx).await {
            BatchResult::Uniform(result) => {
                for msg in accepted {
                    settle(msg, result).await;
                }
            }
            BatchResult::PerElement(results) => {
                let expected = accepted.len();
                let returned = results.len();
                // Report what actually happens on a mismatch: a short vector
                // retries the remainder, a long one drops the extra outcomes.
                match returned.cmp(&expected) {
                    std::cmp::Ordering::Equal => {}
                    std::cmp::Ordering::Less => {
                        error!(
                            target: "ruststream::dispatch",
                            expected,
                            returned,
                            "per-element outcome count does not match the batch; \
                             retrying the unmatched remainder",
                        );
                    }
                    std::cmp::Ordering::Greater => {
                        error!(
                            target: "ruststream::dispatch",
                            expected,
                            returned,
                            "per-element outcome count does not match the batch; \
                             ignoring the surplus outcomes",
                        );
                    }
                }
                // Any message without a matching outcome is retried rather
                // than silently acknowledged.
                let mut outcomes = results.into_iter();
                for msg in accepted {
                    let result = outcomes.next().unwrap_or_else(HandlerResult::retry);
                    settle(msg, result).await;
                }
            }
        }
    }
}
pub(crate) async fn decode_batch<M, T, C>(
batch: Vec<M>,
codec: &C,
on_decode_failure: DecodeFailure,
) -> (Vec<T>, Vec<M>)
where
M: IncomingMessage,
T: DeserializeOwned,
C: Codec,
{
let mut values = Vec::with_capacity(batch.len());
let mut accepted = Vec::with_capacity(batch.len());
for msg in batch {
match codec.decode::<T>(msg.payload()) {
Ok(value) => {
values.push(value);
accepted.push(msg);
}
Err(err) => {
warn!(
target: "ruststream::dispatch",
error = %err,
"codec decode failed",
);
let requeue = matches!(on_decode_failure, DecodeFailure::Requeue);
if let Err(err) = msg.nack(requeue).await {
warn!(target: "ruststream::dispatch", error = %err, "nack failed");
}
}
}
}
(values, accepted)
}
/// Applies `result` to `msg`. Depending on the outcome this is an ack, a nack
/// (with or without requeue), or a delayed nack.
///
/// A settlement failure is logged as a warning and otherwise ignored.
pub(crate) async fn settle<M: IncomingMessage>(msg: M, result: HandlerResult) {
    let Err(err) = (match result {
        HandlerResult::Nack { requeue } => msg.nack(requeue).await,
        HandlerResult::NackAfter { delay } => msg.nack_after(delay).await,
        HandlerResult::Ack => msg.ack().await,
    }) else {
        return;
    };
    warn!(target: "ruststream::dispatch", error = %err, "ack / nack failed");
}
/// End-to-end tests of `TypedBatch` against the in-memory broker and the JSON
/// codec; compiled only when both the `memory` and `json` features are on.
#[cfg(all(test, feature = "memory", feature = "json"))]
mod tests {
use futures::StreamExt;
use super::super::context::State;
use super::super::dispatch::Delivery;
use super::*;
use crate::codec::JsonCodec;
use crate::memory::{MemoryBroker, MemoryMessage, MemorySubscriber};
use crate::{BatchSubscriber, Headers, OutgoingMessage, Publisher, Subscriber};
/// Publishes each number in `numbers`, JSON-encoded, to `name`.
async fn publish_numbers(broker: &MemoryBroker, name: &str, numbers: &[u32]) {
let publisher = broker.publisher();
for n in numbers {
publisher
.publish(OutgoingMessage::new(name, &serde_json::to_vec(n).unwrap()))
.await
.unwrap();
}
}
/// Waits for the next batch on `sub`; panics if the stream ends or yields a failure.
async fn pull_batch(sub: &mut MemorySubscriber) -> Vec<MemoryMessage> {
let mut stream = std::pin::pin!(sub.batches());
stream.next().await.unwrap().unwrap()
}
/// A per-element result settles each message on its own: 0 is acked,
/// 1 is retried, 2 is dropped.
#[tokio::test]
async fn per_element_outcomes_settle_individually() {
let broker = MemoryBroker::new();
let mut sub = broker.subscribe("selective");
publish_numbers(&broker, "selective", &[0, 1, 2]).await;
// The outcome vector is built before the future, because the returned
// future cannot borrow `batch`.
let handler = typed_batch(JsonCodec, |batch: &[u32], _ctx: &mut Context| {
let outcomes: Vec<HandlerResult> = batch
.iter()
.map(|n| match n {
1 => HandlerResult::retry(),
2 => HandlerResult::drop(),
_ => HandlerResult::Ack,
})
.collect();
async move { outcomes }
});
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("selective", &headers, &state, &delivery);
let batch = pull_batch(&mut sub).await;
assert_eq!(batch.len(), 3);
handler.handle_batch(batch, &mut ctx).await;
// Only the retried message is redelivered.
let redelivered = pull_batch(&mut sub).await;
let payloads: Vec<&[u8]> = redelivered.iter().map(IncomingMessage::payload).collect();
assert_eq!(payloads, [b"1"]);
for msg in redelivered {
msg.ack().await.unwrap();
}
// Nothing else is pending: the dropped message did not come back.
let mut stream = std::pin::pin!(sub.stream());
assert!(futures::poll!(stream.next()).is_pending());
}
/// A per-element result shorter than the batch retries the messages
/// left without an outcome.
#[tokio::test]
async fn unmatched_remainder_is_retried() {
let broker = MemoryBroker::new();
let mut sub = broker.subscribe("short");
publish_numbers(&broker, "short", &[0, 1, 2]).await;
// One outcome for three messages: only message 0 is acked.
let handler = typed_batch(JsonCodec, |_batch: &[u32], _ctx: &mut Context| async {
vec![HandlerResult::Ack]
});
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("short", &headers, &state, &delivery);
let batch = pull_batch(&mut sub).await;
assert_eq!(batch.len(), 3);
handler.handle_batch(batch, &mut ctx).await;
let redelivered = pull_batch(&mut sub).await;
let payloads: Vec<&[u8]> = redelivered.iter().map(IncomingMessage::payload).collect();
assert_eq!(payloads, [b"1", b"2"]);
for msg in redelivered {
msg.ack().await.unwrap();
}
}
/// A per-element `retry_after` delays the redelivery of that message
/// until the (paused) clock has advanced by the delay.
#[tokio::test(start_paused = true)]
async fn per_element_outcomes_carry_delays() {
let broker = MemoryBroker::new();
let mut sub = broker.subscribe("delayed");
publish_numbers(&broker, "delayed", &[0, 1]).await;
let handler = typed_batch(JsonCodec, |batch: &[u32], _ctx: &mut Context| {
let outcomes: Vec<HandlerResult> = batch
.iter()
.map(|n| match n {
1 => HandlerResult::retry_after(std::time::Duration::from_secs(5)),
_ => HandlerResult::Ack,
})
.collect();
async move { outcomes }
});
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("delayed", &headers, &state, &delivery);
let batch = pull_batch(&mut sub).await;
handler.handle_batch(batch, &mut ctx).await;
// Before the delay elapses nothing is redelivered.
let mut stream = std::pin::pin!(sub.stream());
assert!(futures::poll!(stream.next()).is_pending());
// Advance past the delay, then yield so the pending redelivery can run
// (presumably a timer-driven task on the runtime — confirm in MemoryBroker).
tokio::time::advance(std::time::Duration::from_secs(5)).await;
tokio::task::yield_now().await;
let redelivered = stream.next().await.unwrap().unwrap();
assert_eq!(redelivered.payload(), b"1");
redelivered.ack().await.unwrap();
}
/// A uniform result applies to every message: one retry outcome brings
/// the whole batch back.
#[tokio::test]
async fn uniform_outcome_settles_the_whole_batch() {
let broker = MemoryBroker::new();
let mut sub = broker.subscribe("uniform");
publish_numbers(&broker, "uniform", &[0, 1]).await;
let handler = typed_batch(JsonCodec, |_batch: &[u32], _ctx: &mut Context| async {
HandlerResult::retry()
});
let state = State::default();
let delivery = Delivery::empty();
let headers = Headers::new();
let mut ctx = Context::new("uniform", &headers, &state, &delivery);
let batch = pull_batch(&mut sub).await;
assert_eq!(batch.len(), 2);
handler.handle_batch(batch, &mut ctx).await;
let redelivered = pull_batch(&mut sub).await;
assert_eq!(redelivered.len(), 2);
for msg in redelivered {
msg.ack().await.unwrap();
}
}
}