use std::{future::Future, marker::PhantomData};
use serde::de::DeserializeOwned;
use tracing::{error, warn};
use crate::IncomingMessage;
use crate::codec::Codec;
use super::context::Context;
use super::dispatch::Workers;
use super::failure::{FailurePolicies, FailurePolicy};
use super::handler::{HandlerResult, Settle};
use super::metadata::HandlerMetadata;
/// What a batch handler decided to do with the messages it was given.
///
/// Either a single settlement applied to the whole batch, or one settlement
/// per element, matched to the batch's messages by position.
#[non_exhaustive]
#[derive(Debug)]
pub enum BatchResult {
    /// Settle every message of the batch with the same [`HandlerResult`].
    Uniform(HandlerResult),
    /// One [`Settle`] per message, in the same order as the batch.
    PerElement(Vec<Settle>),
}
/// Conversion from a batch handler's return value into a [`BatchResult`].
///
/// This module implements it for `BatchResult` itself, [`HandlerResult`], `()`,
/// `Result<(), E>`, `Result<HandlerResult, E>`, `Vec<Settle>` and
/// `Vec<HandlerResult>`, so handler closures may return whichever is handiest.
pub trait IntoBatchResult {
    /// Turns `self` into the batch outcome to apply.
    fn into_batch_result(self) -> BatchResult;
}
/// A `BatchResult` is already a batch outcome and is passed through unchanged.
impl IntoBatchResult for BatchResult {
    fn into_batch_result(self) -> Self {
        self
    }
}

/// A single [`HandlerResult`] settles the whole batch uniformly.
impl IntoBatchResult for HandlerResult {
    fn into_batch_result(self) -> BatchResult {
        let uniform = BatchResult::Uniform;
        uniform(self)
    }
}

/// Returning nothing acknowledges the whole batch.
impl IntoBatchResult for () {
    fn into_batch_result(self) -> BatchResult {
        HandlerResult::Ack.into_batch_result()
    }
}
/// `Ok(())` acknowledges the whole batch; `Err(_)` drops the whole batch.
///
/// The error value is discarded without being logged here.
impl<E> IntoBatchResult for Result<(), E> {
    fn into_batch_result(self) -> BatchResult {
        self.map(|()| HandlerResult::Ack).into_batch_result()
    }
}

/// `Ok(result)` settles the whole batch with `result`; `Err(_)` drops the
/// whole batch.
///
/// The error value is discarded without being logged here.
impl<E> IntoBatchResult for Result<HandlerResult, E> {
    fn into_batch_result(self) -> BatchResult {
        let outcome = match self {
            Ok(outcome) => outcome,
            Err(_) => HandlerResult::drop(),
        };
        BatchResult::Uniform(outcome)
    }
}
/// A vector of [`Settle`] values settles each message individually, by position.
impl IntoBatchResult for Vec<Settle> {
    fn into_batch_result(self) -> BatchResult {
        let per_element = BatchResult::PerElement;
        per_element(self)
    }
}

/// A vector of [`HandlerResult`]s settles each message individually, by
/// position; each result is first converted into a [`Settle`].
impl IntoBatchResult for Vec<HandlerResult> {
    fn into_batch_result(self) -> BatchResult {
        self.into_iter()
            .map(Into::into)
            .collect::<Vec<Settle>>()
            .into_batch_result()
    }
}
/// A handler that processes a whole batch of decoded values at once.
///
/// Any `Fn(&[T], &mut Context) -> Fut` whose future resolves to an
/// [`IntoBatchResult`] implements this trait through a blanket impl.
pub trait SliceHandler<T>: Send + Sync {
    /// Handles `batch` and resolves to the outcome to apply to its messages.
    fn handle_slice(&self, batch: &[T], ctx: &mut Context)
        -> impl Future<Output = BatchResult> + Send;
}
/// Lets plain closures and functions act as slice handlers.
///
/// The callable is invoked synchronously with the batch and context. The
/// future it returns is awaited, and its output is converted with
/// [`IntoBatchResult::into_batch_result`].
impl<T, F, Fut> SliceHandler<T> for F
where
    F: Fn(&[T], &mut Context) -> Fut + Send + Sync,
    Fut: Future + Send,
    Fut::Output: IntoBatchResult,
{
    fn handle_slice(
        &self,
        batch: &[T],
        ctx: &mut Context,
    ) -> impl Future<Output = BatchResult> + Send {
        // Call the handler now; only its returned future is moved into the
        // async block, so the borrows of `batch` and `ctx` end here.
        let pending = self(batch, ctx);
        async move {
            let output = pending.await;
            output.into_batch_result()
        }
    }
}
/// Declarative description of a batch subscription.
///
/// An implementor names the element type its handler receives, the
/// [`SliceHandler`] itself and a `Source` (opaque to this module — presumably
/// where batches are pulled from; confirm against the subscriber wiring).
/// Every other method is optional and has a default.
pub trait BatchDef: Sized {
    /// Element type the handler receives, as a slice.
    type Input;
    /// Handler invoked with slices of [`Self::Input`].
    type Handler: SliceHandler<Self::Input>;
    /// Opaque source value produced by [`BatchDef::source`].
    type Source;

    /// Returns this definition's source.
    fn source(&self) -> Self::Source;

    /// Consumes the definition and yields its handler.
    fn into_handler(self) -> Self::Handler;

    /// Worker configuration; [`Workers::sequential`] unless overridden.
    fn workers(&self) -> Workers {
        Workers::sequential()
    }

    /// Failure policies; the `Default` value unless overridden.
    fn failure_policies(&self) -> FailurePolicies {
        Default::default()
    }

    /// Description forwarded into the handler metadata; none by default.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Input schema forwarded into the handler metadata; none by default.
    fn input_schema(&self) -> Option<String> {
        None
    }

    /// Message name forwarded into the handler metadata; none by default.
    fn message_name(&self) -> Option<&'static str> {
        None
    }

    /// Message description forwarded into the handler metadata; none by default.
    fn message_description(&self) -> Option<&'static str> {
        None
    }
}
/// Builds the [`HandlerMetadata`] for a batch definition registered as `name`.
///
/// The metadata is typed by `D::Input` and carries the definition's optional
/// description, input schema, message name and message description.
pub(crate) fn batch_metadata<D: BatchDef>(name: String, def: &D) -> HandlerMetadata {
    let base = HandlerMetadata::typed::<D::Input>(name);
    let description = def.description();
    let schema = def.input_schema();
    let message_name = def.message_name();
    let message_description = def.message_description();
    base.with_def_details(description, schema, message_name, message_description)
}
/// Internal entry point that receives a raw batch of broker messages and is
/// responsible for settling every one of them.
pub(crate) trait BatchHandler<M>: Send + Sync {
    /// Processes `batch` to completion, settling each message.
    fn handle_batch(
        &self,
        batch: Vec<M>,
        ctx: &mut Context,
    ) -> impl Future<Output = ()> + Send;
}
/// Wraps `inner` so that it receives values decoded from raw messages with
/// `codec`.
///
/// Messages that fail to decode are dropped (`FailurePolicy::Drop`) unless a
/// different policy is set with `TypedBatch::with_decode`.
pub(crate) fn typed_batch<M, T, C, H>(codec: C, inner: H) -> TypedBatch<M, T, C, H>
where
    M: IncomingMessage,
    T: DeserializeOwned + Send + Sync,
    C: Codec,
    H: SliceHandler<T>,
{
    let decode = FailurePolicy::Drop;
    TypedBatch {
        _phantom: PhantomData,
        decode,
        inner,
        codec,
    }
}
/// Batch adapter that decodes raw messages of type `M` into `T` with a codec
/// `C`, then hands the decoded slice to the [`SliceHandler`] `H`.
pub struct TypedBatch<M, T, C, H> {
    /// Codec used to decode each message payload.
    codec: C,
    /// Handler that receives the decoded values.
    inner: H,
    /// What to do with a message whose payload fails to decode.
    decode: FailurePolicy,
    // `fn(M, T)` ties the type parameters to the struct without storing them,
    // and since fn pointers are always `Send + Sync`, the marker does not
    // make the struct's auto traits depend on `M` or `T`.
    _phantom: PhantomData<fn(M, T)>,
}
impl<M, T, C, H> TypedBatch<M, T, C, H> {
    /// Returns the adapter with its decode-failure policy replaced by `decode`.
    #[must_use]
    pub(crate) fn with_decode(self, decode: FailurePolicy) -> Self {
        Self { decode, ..self }
    }
}
impl<M, T, C, H> std::fmt::Debug for TypedBatch<M, T, C, H> {
    /// Shows only the decode policy. The codec and handler are elided
    /// (rendered as `..`) because this impl places no `Debug` bound on them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TypedBatch");
        builder.field("decode", &self.decode);
        builder.finish_non_exhaustive()
    }
}
impl<M, T, C, H> BatchHandler<M> for TypedBatch<M, T, C, H>
where
    M: IncomingMessage,
    T: DeserializeOwned + Send + Sync,
    C: Codec,
    H: SliceHandler<T>,
{
    /// Decodes `batch`, passes the decoded values to the inner slice handler,
    /// and settles every accepted message with the outcome it returns.
    ///
    /// Messages that fail to decode are settled by [`decode_batch`] according
    /// to the configured decode policy and never reach the handler. If no
    /// message decodes, the handler is not called at all.
    ///
    /// A [`BatchResult::Uniform`] outcome is applied to every accepted message.
    /// A [`BatchResult::PerElement`] outcome is matched to the accepted messages
    /// by position. Each message is settled before its optional continuation is
    /// spawned via `ctx.tasks()`.
    ///
    /// Count mismatches are logged as errors:
    /// - too few outcomes: the unmatched messages are retried;
    /// - too many outcomes: the surplus outcomes, and any continuations attached
    ///   to them, are discarded.
    async fn handle_batch(&self, batch: Vec<M>, ctx: &mut Context<'_>) {
        let subscription = ctx.name().to_owned();
        let (values, accepted) = decode_batch(batch, &self.codec, self.decode, ctx).await;
        if accepted.is_empty() {
            return;
        }
        // Take the task handle before the handler call, because `ctx` is lent
        // mutably to the handler for the whole call below.
        let tasks = ctx.tasks().clone();
        match self.inner.handle_slice(&values, ctx).await {
            BatchResult::Uniform(result) => {
                for msg in accepted {
                    settle(msg, result, &subscription).await;
                }
            }
            BatchResult::PerElement(results) => {
                // Outcomes pair with messages by position, so a length mismatch
                // is a handler bug. Report which way it went: a shortfall is
                // retried below, while a surplus is simply dropped together
                // with `results`.
                match results.len().cmp(&accepted.len()) {
                    std::cmp::Ordering::Less => {
                        error!(
                            target: "ruststream::dispatch",
                            subscription = %subscription,
                            expected = accepted.len(),
                            returned = results.len(),
                            "per-element outcome count does not match the batch; \
                             retrying the unmatched remainder",
                        );
                    }
                    std::cmp::Ordering::Greater => {
                        error!(
                            target: "ruststream::dispatch",
                            subscription = %subscription,
                            expected = accepted.len(),
                            returned = results.len(),
                            "per-element outcome count does not match the batch; \
                             discarding the surplus outcomes",
                        );
                    }
                    std::cmp::Ordering::Equal => {}
                }
                let mut results = results.into_iter();
                for msg in accepted {
                    // A missing outcome means the handler never decided on this
                    // message; retrying is the safe default.
                    let mut result = results
                        .next()
                        .unwrap_or_else(|| HandlerResult::retry().into());
                    let after = result.take_after();
                    settle(msg, result.outcome(), &subscription).await;
                    // Continuations only start once their message is settled.
                    if let Some(after) = after {
                        tasks.spawn(after);
                    }
                }
            }
        }
    }
}
/// Decodes every message of `batch` with `codec`, splitting it into the
/// decoded values and the messages they came from.
///
/// The two returned vectors are index-aligned: `values[i]` was decoded from
/// `accepted[i]`. A message whose payload fails to decode is logged, settled
/// immediately, and appears in neither vector. How it is settled depends on
/// `decode`:
/// - [`FailurePolicy::FailFast`]: the context is asked to fail fast and the
///   message is dropped;
/// - any other policy: its `settlement()` is used, falling back to a drop when
///   it yields none.
// `ctx` is only reached through the calls below; the lint is silenced so the
// `&mut` signature can stay as it is.
#[allow(clippy::needless_pass_by_ref_mut)]
pub(crate) async fn decode_batch<M, T, C>(
    batch: Vec<M>,
    codec: &C,
    decode: FailurePolicy,
    ctx: &mut Context<'_>,
) -> (Vec<T>, Vec<M>)
where
    M: IncomingMessage,
    T: DeserializeOwned,
    C: Codec,
{
    let subscription = ctx.name().to_owned();
    let capacity = batch.len();
    let (mut values, mut accepted) = (Vec::with_capacity(capacity), Vec::with_capacity(capacity));
    for message in batch {
        let err = match codec.decode::<T>(message.payload()) {
            Ok(decoded) => {
                values.push(decoded);
                accepted.push(message);
                continue;
            }
            Err(err) => err,
        };
        warn!(
            target: "ruststream::dispatch",
            subscription = %subscription,
            message_type = std::any::type_name::<T>(),
            error = %err,
            "codec decode failed",
        );
        let outcome = if matches!(decode, FailurePolicy::FailFast) {
            ctx.fail_fast(&format!("batch decode failed: {err}"));
            HandlerResult::drop()
        } else {
            decode.settlement().unwrap_or_else(HandlerResult::drop)
        };
        settle(message, outcome, &subscription).await;
    }
    (values, accepted)
}
/// Acknowledges or negatively acknowledges `msg` according to `result`:
/// - `Ack` acks the message;
/// - `Nack { requeue }` nacks it with that requeue flag;
/// - `NackAfter { delay }` nacks it with that delay.
///
/// If the broker call fails, a warning tagged with `subscription` is logged.
/// The error is not returned to the caller.
pub(crate) async fn settle<M: IncomingMessage>(msg: M, result: HandlerResult, subscription: &str) {
    let settled = match result {
        HandlerResult::NackAfter { delay } => msg.nack_after(delay).await,
        HandlerResult::Nack { requeue } => msg.nack(requeue).await,
        HandlerResult::Ack => msg.ack().await,
    };
    let Err(err) = settled else {
        return;
    };
    warn!(
        target: "ruststream::dispatch",
        subscription = %subscription,
        error = %err,
        "ack / nack failed",
    );
}
// End-to-end tests for the batch adapter, run against the in-memory broker
// with the JSON codec. Compiled only when both the `memory` and `json`
// features are enabled.
#[cfg(all(test, feature = "memory", feature = "json"))]
mod tests {
    use futures::StreamExt;
    use super::super::context::State;
    use super::super::dispatch::Delivery;
    use super::*;
    use crate::codec::JsonCodec;
    use crate::memory::{MemoryBroker, MemoryMessage, MemorySubscriber};
    use crate::{BatchSubscriber, Headers, OutgoingMessage, Publisher, Subscriber};

    /// Publishes each number in `numbers` to `name` as its own JSON-encoded
    /// message.
    async fn publish_numbers(broker: &MemoryBroker, name: &str, numbers: &[u32]) {
        let publisher = broker.publisher();
        for n in numbers {
            publisher
                .publish(OutgoingMessage::new(name, &serde_json::to_vec(n).unwrap()))
                .await
                .unwrap();
        }
    }

    /// Pulls the next batch from `sub`; panics if the stream ends or yields
    /// an error.
    async fn pull_batch(sub: &mut MemorySubscriber) -> Vec<MemoryMessage> {
        let mut stream = std::pin::pin!(sub.batches());
        stream.next().await.unwrap().unwrap()
    }

    /// Element 0 is acked, 1 retried and 2 dropped. Only `1` should come back,
    /// and once it is acked nothing else is pending.
    #[tokio::test]
    async fn per_element_outcomes_settle_individually() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("selective");
        publish_numbers(&broker, "selective", &[0, 1, 2]).await;
        let handler = typed_batch(JsonCodec, |batch: &[u32], _ctx: &mut Context| {
            let outcomes: Vec<HandlerResult> = batch
                .iter()
                .map(|n| match n {
                    1 => HandlerResult::retry(),
                    2 => HandlerResult::drop(),
                    _ => HandlerResult::Ack,
                })
                .collect();
            async move { outcomes }
        });
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("selective", &headers, &state, &delivery);
        let batch = pull_batch(&mut sub).await;
        assert_eq!(batch.len(), 3);
        handler.handle_batch(batch, &mut ctx).await;
        let redelivered = pull_batch(&mut sub).await;
        let payloads: Vec<&[u8]> = redelivered.iter().map(IncomingMessage::payload).collect();
        assert_eq!(payloads, [b"1"]);
        for msg in redelivered {
            msg.ack().await.unwrap();
        }
        // The subscription must now be empty: the dropped element (`2`) did not
        // come back.
        let mut stream = std::pin::pin!(sub.stream());
        assert!(futures::poll!(stream.next()).is_pending());
    }

    /// A continuation attached to an acked element runs (signalled through
    /// `Notify`), while the retried element is redelivered. Runs on a
    /// two-worker multi-threaded runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn per_element_continuations_run_after_settle() {
        use std::sync::Arc;
        use tokio::sync::Notify;
        use tokio_util::task::TaskTracker;
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("after-batch");
        publish_numbers(&broker, "after-batch", &[0, 1]).await;
        let ran = Arc::new(Notify::new());
        let signal = Arc::clone(&ran);
        let handler = typed_batch(JsonCodec, move |batch: &[u32], _ctx: &mut Context| {
            let signal = Arc::clone(&signal);
            let outcomes: Vec<Settle> = batch
                .iter()
                .map(|n| {
                    if *n == 0 {
                        let signal = Arc::clone(&signal);
                        HandlerResult::ack().and_after(async move { signal.notify_one() })
                    } else {
                        HandlerResult::retry().into()
                    }
                })
                .collect();
            async move { outcomes }
        });
        // Continuations are spawned on the delivery's task tracker, which the
        // test owns so it can wait for them.
        let tasks = TaskTracker::new();
        let state = State::default();
        let delivery = Delivery::with_tasks(tasks.clone());
        let headers = Headers::new();
        let mut ctx = Context::new("after-batch", &headers, &state, &delivery);
        let batch = pull_batch(&mut sub).await;
        handler.handle_batch(batch, &mut ctx).await;
        // Wait for the continuation's signal, then for every tracked task to finish.
        ran.notified().await;
        tasks.close();
        tasks.wait().await;
        let redelivered = pull_batch(&mut sub).await;
        let payloads: Vec<&[u8]> = redelivered.iter().map(IncomingMessage::payload).collect();
        assert_eq!(payloads, [b"1"]);
        for msg in redelivered {
            msg.ack().await.unwrap();
        }
    }

    /// The handler returns a single outcome for a three-message batch. The first
    /// message takes it; the two unmatched messages are retried.
    #[tokio::test]
    async fn unmatched_remainder_is_retried() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("short");
        publish_numbers(&broker, "short", &[0, 1, 2]).await;
        let handler = typed_batch(JsonCodec, |_batch: &[u32], _ctx: &mut Context| async {
            vec![HandlerResult::Ack]
        });
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("short", &headers, &state, &delivery);
        let batch = pull_batch(&mut sub).await;
        assert_eq!(batch.len(), 3);
        handler.handle_batch(batch, &mut ctx).await;
        let redelivered = pull_batch(&mut sub).await;
        let payloads: Vec<&[u8]> = redelivered.iter().map(IncomingMessage::payload).collect();
        assert_eq!(payloads, [b"1", b"2"]);
        for msg in redelivered {
            msg.ack().await.unwrap();
        }
    }

    /// A per-element `retry_after(5s)` keeps element `1` invisible until the
    /// (paused) clock has advanced by five seconds.
    #[tokio::test(start_paused = true)]
    async fn per_element_outcomes_carry_delays() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("delayed");
        publish_numbers(&broker, "delayed", &[0, 1]).await;
        let handler = typed_batch(JsonCodec, |batch: &[u32], _ctx: &mut Context| {
            let outcomes: Vec<HandlerResult> = batch
                .iter()
                .map(|n| match n {
                    1 => HandlerResult::retry_after(std::time::Duration::from_secs(5)),
                    _ => HandlerResult::Ack,
                })
                .collect();
            async move { outcomes }
        });
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("delayed", &headers, &state, &delivery);
        let batch = pull_batch(&mut sub).await;
        handler.handle_batch(batch, &mut ctx).await;
        // Before the delay elapses, nothing is available.
        let mut stream = std::pin::pin!(sub.stream());
        assert!(futures::poll!(stream.next()).is_pending());
        tokio::time::advance(std::time::Duration::from_secs(5)).await;
        // NOTE(review): presumably gives the timer-driven redelivery a chance to
        // run after the clock advance — confirm against the memory broker.
        tokio::task::yield_now().await;
        let redelivered = stream.next().await.unwrap().unwrap();
        assert_eq!(redelivered.payload(), b"1");
        redelivered.ack().await.unwrap();
    }

    /// A single `HandlerResult::retry()` applies to the whole batch, so both
    /// messages are redelivered.
    #[tokio::test]
    async fn uniform_outcome_settles_the_whole_batch() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("uniform");
        publish_numbers(&broker, "uniform", &[0, 1]).await;
        let handler = typed_batch(JsonCodec, |_batch: &[u32], _ctx: &mut Context| async {
            HandlerResult::retry()
        });
        let state = State::default();
        let delivery = Delivery::empty();
        let headers = Headers::new();
        let mut ctx = Context::new("uniform", &headers, &state, &delivery);
        let batch = pull_batch(&mut sub).await;
        assert_eq!(batch.len(), 2);
        handler.handle_batch(batch, &mut ctx).await;
        let redelivered = pull_batch(&mut sub).await;
        assert_eq!(redelivered.len(), 2);
        for msg in redelivered {
            msg.ack().await.unwrap();
        }
    }
}