use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use futures::{FutureExt, StreamExt};
use tokio::sync::mpsc;
use tokio::task::{JoinHandle, JoinSet};
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{debug, error, warn};
use crate::{AckError, BatchSubscriber, Headers, IncomingMessage, Subscriber};
use super::batch::BatchHandler;
use super::context::{Context, State};
use super::failure::{DispatchFailure, FailurePolicy, panic_reason};
use super::handler::{Handler, HandlerResult};
use super::publish::PublishMiddleware;
use super::publisher_registry::ErasedPublisher;
/// Publishers available to a dispatch, keyed by `String` (presumably the
/// publisher's registered name — NOTE(review): confirm against the registry).
pub(crate) type Publishers = HashMap<String, Arc<dyn ErasedPublisher>>;
/// Header stamped on messages re-published by the `retry_after` fallback.
///
/// Each fallback re-publish increments it, treating a missing or unparseable
/// value as `0`, so the first re-publish carries `1`.
pub const RETRY_COUNT_HEADER: &str = "x-ruststream-retry-count";
/// Reads the retry count stored under [`RETRY_COUNT_HEADER`].
///
/// A missing header, or a value that does not parse as a `u64`, counts as
/// zero retries.
fn current_retry_count(headers: &Headers) -> u64 {
    match headers.get_str(RETRY_COUNT_HEADER) {
        Some(raw) => raw.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}
/// Concurrency strategy for a subscription's dispatch loop.
///
/// `count` is the number of workers and is never zero. When `by_key` is set,
/// messages are routed to fixed per-key lanes instead of a shared pool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Workers {
    count: usize,
    by_key: bool,
}

impl Workers {
    /// Clamps a requested worker count so that at least one worker exists.
    const fn at_least_one(count: usize) -> usize {
        match count {
            0 => 1,
            n => n,
        }
    }

    /// One worker: messages are handled strictly one after another.
    #[must_use]
    pub const fn sequential() -> Self {
        Self::pool(1)
    }

    /// A shared pool of `count` workers (a `count` of `0` is treated as `1`).
    #[must_use]
    pub const fn pool(count: usize) -> Self {
        Self {
            count: Self::at_least_one(count),
            by_key: false,
        }
    }

    /// `count` lanes selected by partition key (a `count` of `0` is treated
    /// as `1`).
    #[must_use]
    pub const fn keyed(count: usize) -> Self {
        Self {
            count: Self::at_least_one(count),
            by_key: true,
        }
    }

    /// True when only a single worker is configured.
    pub(crate) const fn is_sequential(&self) -> bool {
        matches!(self.count, 0 | 1)
    }
}

impl Default for Workers {
    /// Defaults to [`Workers::sequential`].
    fn default() -> Self {
        Self::sequential()
    }
}
/// Shared outbound-delivery resources handed to every dispatch and to the
/// handler's `Context`.
pub(crate) struct Delivery {
    /// Publishers keyed by `String` (see [`Publishers`]).
    pub(crate) publishers: Publishers,
    /// Publish middleware layers.
    pub(crate) pipeline: Arc<[Arc<dyn PublishMiddleware>]>,
    /// Publisher used by `settle_nack_after` to emulate delayed redelivery on
    /// brokers without native support; `None` makes that path requeue at once.
    pub(crate) retry_publisher: Option<Arc<dyn ErasedPublisher>>,
    /// Tracker that `dispatch` spawns post-settlement `after` futures onto.
    pub(crate) tasks: TaskTracker,
}

impl Delivery {
    /// Test helper: no publishers, no middleware, no retry publisher, and a
    /// fresh task tracker.
    #[cfg(test)]
    pub(crate) fn empty() -> Self {
        let tracker = TaskTracker::new();
        Self::with_tasks(tracker)
    }

    /// Test helper: like [`Delivery::empty`], but uses the caller's tracker.
    #[cfg(test)]
    pub(crate) fn with_tasks(tasks: TaskTracker) -> Self {
        let publishers: Publishers = HashMap::new();
        let pipeline: Arc<[Arc<dyn PublishMiddleware>]> = Arc::from([]);
        Self {
            tasks,
            retry_publisher: None,
            pipeline,
            publishers,
        }
    }
}

impl std::fmt::Debug for Delivery {
    /// Summarises the contents as counts and flags; the publishers and
    /// middleware themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Delivery");
        out.field("publishers", &self.publishers.len());
        out.field("layers", &self.pipeline.len());
        out.field("retry_publisher", &self.retry_publisher.is_some());
        out.field("pending_continuations", &self.tasks.len());
        out.finish_non_exhaustive()
    }
}
/// Spawns the sequential dispatch loop for `subscriber`.
///
/// Messages are handled one at a time, in stream order. Each is fully
/// dispatched (handled, settled, its hooks queued) before the next one is
/// pulled. The loop stops when `shutdown` is cancelled or the stream ends.
/// Stream errors are logged and skipped. Before the returned task finishes, it
/// waits for every hook spawned by the loop.
pub(crate) fn spawn_dispatch<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    failure: DispatchFailure,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    H: Handler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let hooks = TaskTracker::new();
        let mut stream = std::pin::pin!(subscriber.stream());
        loop {
            tokio::select! {
                // Check shutdown first. By default `select!` picks randomly
                // among ready branches, so a ready stream could still win after
                // cancellation and pull another message.
                biased;
                () = shutdown.cancelled() => break,
                next = stream.next() => match next {
                    // Dispatch runs in the arm body, not in a select future, so
                    // shutdown never cancels a message mid-handling.
                    Some(Ok(msg)) => {
                        dispatch(&*handler, msg, &name, &state, &delivery, &hooks, &failure).await;
                    }
                    Some(Err(err)) => {
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
        drain_hooks(hooks).await;
    })
}
/// Waits for every task spawned on `hooks` to finish.
///
/// The tracker is closed first because `TaskTracker::wait` only resolves once
/// the tracker is both closed and empty.
async fn drain_hooks(hooks: TaskTracker) {
    hooks.close();
    let all_finished = hooks.wait();
    all_finished.await;
}
/// Spawns the dispatch loop for `subscriber` using the strategy in `workers`.
///
/// A single worker uses the plain sequential loop. Otherwise, keyed workers get
/// per-key lanes and non-keyed workers get a bounded pool.
#[allow(clippy::too_many_arguments)]
pub(crate) fn spawn_dispatch_workers<S, H>(
    subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    failure: DispatchFailure,
    workers: Workers,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    S::Message: Send + Sync + 'static,
    H: Handler<S::Message> + 'static,
{
    match (workers.is_sequential(), workers.by_key) {
        (true, _) => {
            spawn_dispatch(subscriber, handler, shutdown, name, state, delivery, failure)
        }
        (false, true) => spawn_dispatch_lanes(
            subscriber, handler, shutdown, name, state, delivery, failure, workers,
        ),
        (false, false) => spawn_dispatch_pool(
            subscriber, handler, shutdown, name, state, delivery, failure, workers,
        ),
    }
}
/// Spawns a dispatch loop that handles up to `workers.count` messages
/// concurrently, with no ordering guarantee between them.
///
/// Each message is dispatched on its own task in a `JoinSet`. While the set
/// holds `workers.count` tasks, the loop stops reading the stream and waits for
/// one task to finish, which bounds concurrency. On shutdown or end of stream,
/// every outstanding task is awaited and then the hook tracker is drained.
#[allow(clippy::too_many_arguments)]
fn spawn_dispatch_pool<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    failure: DispatchFailure,
    workers: Workers,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    S::Message: Send + Sync + 'static,
    H: Handler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let hooks = TaskTracker::new();
        let mut stream = std::pin::pin!(subscriber.stream());
        let mut tasks = JoinSet::new();
        loop {
            tokio::select! {
                // Check shutdown first so a ready stream cannot win the random
                // branch pick after cancellation and pull another message.
                biased;
                () = shutdown.cancelled() => break,
                // At capacity: reap a task before accepting more work. Note that
                // `len()` also counts finished tasks that have not been joined yet.
                Some(joined) = tasks.join_next(), if tasks.len() >= workers.count => {
                    log_worker_exit(joined);
                }
                next = stream.next(), if tasks.len() < workers.count => match next {
                    Some(Ok(msg)) => {
                        let handler = Arc::clone(&handler);
                        let name = Arc::clone(&name);
                        let state = Arc::clone(&state);
                        let delivery = Arc::clone(&delivery);
                        let hooks = hooks.clone();
                        let failure = failure.clone();
                        tasks.spawn(async move {
                            dispatch(&*handler, msg, &name, &state, &delivery, &hooks, &failure)
                                .await;
                        });
                    }
                    Some(Err(err)) => {
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
        // Let in-flight messages finish before draining their hooks.
        while let Some(joined) = tasks.join_next().await {
            log_worker_exit(joined);
        }
        drain_hooks(hooks).await;
    })
}
/// Spawns a dispatch loop with `workers.count` lanes.
///
/// Messages that share a partition key always go to the same lane (via
/// `lane_of`), so they are handled in order. Messages without a key are spread
/// round-robin across lanes. Each lane is a worker task draining a bounded
/// channel of capacity 1, so a busy lane makes the reader wait (backpressure).
/// If a lane's worker has gone away, the loop stops. On exit, the lane senders
/// are dropped so each worker drains what it has and returns. The workers are
/// then joined and the hook tracker drained.
#[allow(clippy::too_many_arguments)]
fn spawn_dispatch_lanes<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    failure: DispatchFailure,
    workers: Workers,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    S::Message: Send + Sync + 'static,
    H: Handler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let hooks = TaskTracker::new();
        let mut lanes = Vec::with_capacity(workers.count);
        let mut tasks = JoinSet::new();
        for _ in 0..workers.count {
            let (tx, mut rx) = mpsc::channel::<S::Message>(1);
            let handler = Arc::clone(&handler);
            let name = Arc::clone(&name);
            let state = Arc::clone(&state);
            let delivery = Arc::clone(&delivery);
            let hooks = hooks.clone();
            let failure = failure.clone();
            // One worker per lane processes its channel strictly in order.
            tasks.spawn(async move {
                while let Some(msg) = rx.recv().await {
                    dispatch(&*handler, msg, &name, &state, &delivery, &hooks, &failure).await;
                }
            });
            lanes.push(tx);
        }
        let mut stream = std::pin::pin!(subscriber.stream());
        let mut unkeyed_rotation = 0usize;
        loop {
            tokio::select! {
                // Check shutdown first so a ready stream cannot win the random
                // branch pick after cancellation and pull another message.
                biased;
                () = shutdown.cancelled() => break,
                next = stream.next() => match next {
                    Some(Ok(msg)) => {
                        let lane = msg.partition_key().map_or_else(
                            || {
                                unkeyed_rotation = (unkeyed_rotation + 1) % workers.count;
                                unkeyed_rotation
                            },
                            |key| lane_of(key, workers.count),
                        );
                        // Waits while the lane's single slot is occupied.
                        if lanes[lane].send(msg).await.is_err() {
                            error!(
                                target: "ruststream::dispatch",
                                subscriber = %name,
                                lane,
                                "worker lane terminated; stopping dispatch",
                            );
                            break;
                        }
                    }
                    Some(Err(err)) => {
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
        // Closing every sender lets each lane worker finish its backlog and exit.
        drop(lanes);
        while let Some(joined) = tasks.join_next().await {
            log_worker_exit(joined);
        }
        drain_hooks(hooks).await;
    })
}
/// Maps a partition key onto one of `lanes` worker lanes.
///
/// Within a given build, the same key always lands on the same lane for a
/// given `lanes`, which keeps per-key ordering intact. `lanes` must be
/// non-zero, or the modulo panics.
#[allow(clippy::cast_possible_truncation)] // the remainder is < `lanes`, so it fits in usize
fn lane_of(key: &[u8], lanes: usize) -> usize {
    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    let digest = state.finish();
    (digest % lanes as u64) as usize
}
/// Logs a worker task that ended with a `JoinError` (it panicked or was
/// cancelled). A clean exit is silent.
fn log_worker_exit(joined: Result<(), tokio::task::JoinError>) {
    let Err(err) = joined else {
        return;
    };
    error!(target: "ruststream::dispatch", error = %err, "worker task failed");
}
/// Spawns the batch dispatch loop for `subscriber`.
///
/// With a single worker, each batch is handled inline before the next is
/// pulled. With more workers, up to `workers.count` batches run concurrently on
/// a `JoinSet`, and the loop waits for one to finish before reading further. On
/// shutdown or end of stream, outstanding batch tasks are awaited and the hook
/// tracker is drained.
#[allow(clippy::too_many_arguments)]
pub(crate) fn spawn_batch_dispatch<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    failure: DispatchFailure,
    workers: Workers,
) -> JoinHandle<()>
where
    S: BatchSubscriber + Send + 'static,
    S::Message: Send + 'static,
    H: BatchHandler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let hooks = TaskTracker::new();
        let mut stream = std::pin::pin!(subscriber.batches());
        let mut tasks = JoinSet::new();
        loop {
            tokio::select! {
                // Check shutdown first so a ready stream cannot win the random
                // branch pick after cancellation and pull another batch.
                biased;
                () = shutdown.cancelled() => break,
                Some(joined) = tasks.join_next(), if tasks.len() >= workers.count => {
                    log_worker_exit(joined);
                }
                next = stream.next(), if tasks.len() < workers.count => match next {
                    Some(Ok(batch)) => {
                        let batch: Vec<S::Message> = batch.into_iter().collect();
                        if workers.is_sequential() {
                            run_batch(&*handler, batch, &name, &state, &delivery, &hooks, &failure)
                                .await;
                        } else {
                            let handler = Arc::clone(&handler);
                            let name = Arc::clone(&name);
                            let state = Arc::clone(&state);
                            let delivery = Arc::clone(&delivery);
                            let hooks = hooks.clone();
                            let failure = failure.clone();
                            tasks.spawn(async move {
                                run_batch(
                                    &*handler, batch, &name, &state, &delivery, &hooks, &failure,
                                )
                                .await;
                            });
                        }
                    }
                    Some(Err(err)) => {
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
        while let Some(joined) = tasks.join_next().await {
            log_worker_exit(joined);
        }
        drain_hooks(hooks).await;
    })
}
/// Runs `handler` on one message, then settles the message and queues its hooks.
///
/// The handler future is wrapped in `catch_unwind`, so a handler panic is
/// caught here instead of unwinding the calling loop. On a panic, the action
/// comes from `failure.policies.panic`:
/// * `FailFast` signals `failure.shutdown` and leaves the message unsettled.
///   No ack or nack is sent and no hooks run.
/// * Any other policy settles with that policy's `settlement()`, falling back
///   to `HandlerResult::drop` when it yields none.
///
/// Once a settlement exists, the message is settled via `settle_outcome`. Any
/// `after` future is spawned on `delivery.tasks`, and the hooks the context
/// returns for that outcome are spawned on `hooks`, only after settlement.
#[allow(clippy::too_many_arguments)] async fn dispatch<H, M>(
    handler: &H,
    msg: M,
    name: &str,
    state: &State,
    delivery: &Delivery,
    hooks: &TaskTracker,
    failure: &DispatchFailure,
) where
    H: Handler<M>,
    M: IncomingMessage,
{
    let extensions = msg.extensions();
    // The context borrows the message's headers, so it must be dropped before
    // `msg` is moved into `settle_outcome` below.
    let mut ctx = Context::with_extensions(name, msg.headers(), state, extensions, delivery)
        .with_failfast(&failure.shutdown);
    let result = AssertUnwindSafe(handler.handle(&msg, &mut ctx))
        .catch_unwind()
        .await;
    // `None` means "do not settle" (fail-fast panic).
    let settle = match result {
        Ok(s) => Some(s),
        Err(payload) => {
            let reason = panic_reason(payload.as_ref());
            error!(
                target: "ruststream::dispatch",
                subscription = %name,
                panic = %reason,
                "handler panicked",
            );
            match failure.policies.panic {
                FailurePolicy::FailFast => {
                    failure
                        .shutdown
                        .signal(name, &format!("handler panicked: {reason}"));
                    None
                }
                other => Some(
                    other
                        .settlement()
                        .unwrap_or_else(HandlerResult::drop)
                        .into(),
                ),
            }
        }
    };
    // Take the outcome-specific hooks while the context is still alive; with
    // no settlement, no hooks are taken.
    let continuations = settle
        .as_ref()
        .map_or_else(Vec::new, |s| ctx.take_hooks_for(s.outcome()));
    drop(ctx);
    if let Some(mut s) = settle {
        settle_outcome(msg, s.outcome(), name, delivery).await;
        if let Some(after) = s.take_after() {
            delivery.tasks.spawn(after);
        }
    }
    // Hooks are spawned only after the message has been settled.
    for fut in continuations {
        hooks.spawn(fut);
    }
}
/// Runs `handler` over one batch of messages.
///
/// The batch is moved into the handler, and this function never acks or nacks
/// any of its messages. Batches carry no headers of their own, so the context
/// is built over an empty `Headers`.
///
/// On a clean return, the context's settle hooks are spawned on `hooks`. On a
/// panic, those hooks are discarded, the panic is logged, and under
/// `FailurePolicy::FailFast` the dispatcher's shutdown is signalled. Other
/// panic policies have no effect here.
#[allow(clippy::too_many_arguments)]
async fn run_batch<H, M>(
    handler: &H,
    batch: Vec<M>,
    name: &str,
    state: &State,
    delivery: &Delivery,
    hooks: &TaskTracker,
    failure: &DispatchFailure,
) where
    H: BatchHandler<M>,
    M: IncomingMessage,
{
    let no_headers = Headers::new();
    let mut ctx =
        Context::new(name, &no_headers, state, delivery).with_failfast(&failure.shutdown);
    let outcome = AssertUnwindSafe(handler.handle_batch(batch, &mut ctx))
        .catch_unwind()
        .await;
    let Err(payload) = outcome else {
        // Clean return: hand the registered settle hooks to the tracker.
        for settle_hook in ctx.take_settle_hooks() {
            hooks.spawn(settle_hook);
        }
        return;
    };
    let reason = panic_reason(payload.as_ref());
    error!(
        target: "ruststream::dispatch",
        subscription = %name,
        panic = %reason,
        "batch handler panicked",
    );
    if failure.policies.panic == FailurePolicy::FailFast {
        failure
            .shutdown
            .signal(name, &format!("batch handler panicked: {reason}"));
    }
}
/// Applies a handler's settlement decision to `msg`.
///
/// `Ack` and `Nack` map straight onto the message. `NackAfter` goes through
/// `settle_nack_after`. A failed settlement is logged as a warning and is
/// otherwise ignored.
async fn settle_outcome<M: IncomingMessage>(
    msg: M,
    outcome: HandlerResult,
    name: &str,
    delivery: &Delivery,
) {
    let settled = match outcome {
        HandlerResult::NackAfter { delay } => {
            settle_nack_after(msg, name, delay, delivery).await
        }
        HandlerResult::Nack { requeue } => msg.nack(requeue).await,
        HandlerResult::Ack => msg.ack().await,
    };
    let Err(err) = settled else {
        return;
    };
    warn!(
        target: "ruststream::dispatch",
        subscription = %name,
        error = %err,
        "ack / nack failed",
    );
}
/// Settles a `NackAfter { delay }` outcome.
///
/// There are three paths:
/// * If the message supports delayed redelivery natively, this defers to
///   `msg.nack_after(delay)`.
/// * Otherwise, if there is no `delivery.retry_publisher`, the delay cannot be
///   honoured. A warning is logged and the message is requeued immediately
///   with `nack(true)`.
/// * Otherwise the payload and headers are copied and [`RETRY_COUNT_HEADER`]
///   is incremented. The original is nacked without requeue, and a detached
///   task re-publishes the copy to subject `name` once `delay` has elapsed.
///
/// # Errors
/// Returns the error from whichever `nack_after` or `nack` call settled the
/// message. A failure of the deferred re-publish happens after this returns,
/// so it is only logged.
async fn settle_nack_after<M>(
    msg: M,
    name: &str,
    delay: Duration,
    delivery: &Delivery,
) -> Result<(), AckError>
where
    M: IncomingMessage,
{
    if msg.supports_nack_after() {
        return msg.nack_after(delay).await;
    }
    let Some(publisher) = delivery.retry_publisher.clone() else {
        warn!(
            target: "ruststream::dispatch",
            subscription = %name,
            "retry_after on a broker without native delayed redelivery and no retry publisher \
             configured; requeuing immediately (the delay is dropped)",
        );
        return msg.nack(true).await;
    };
    // Copy everything the re-publish needs before `nack` consumes `msg`.
    let payload = Bytes::copy_from_slice(msg.payload());
    let mut headers = msg.headers().clone();
    // The count comes off the wire. Saturate so a value of u64::MAX cannot
    // overflow, which would panic this dispatch task in debug builds or wrap
    // the counter back to 0 in release builds.
    let next_count = current_retry_count(&headers).saturating_add(1);
    headers.insert(RETRY_COUNT_HEADER, next_count.to_string());
    let subject = name.to_owned();
    msg.nack(false).await?;
    // NOTE(review): this task is detached with `tokio::spawn` and is not tracked
    // by `delivery.tasks`. If the runtime shuts down before `delay` elapses, the
    // re-publish never happens, and the original has already been nacked.
    tokio::spawn(async move {
        tokio::time::sleep(delay).await;
        if let Err(err) = publisher
            .publish_message(&subject, &payload, &headers)
            .await
        {
            warn!(
                target: "ruststream::dispatch",
                subscription = %subject,
                error = %err,
                "deferred retry_after re-publish failed; message lost",
            );
        }
    });
    Ok(())
}
/// Tests for the `retry_after` settlement paths of `settle_nack_after`.
/// They require the in-memory broker (`memory` feature).
#[cfg(all(test, feature = "memory"))]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicU8, Ordering},
    };
    use futures::StreamExt;
    use super::*;
    use crate::memory::MemoryBroker;
    use crate::{AckError, Headers, IncomingMessage, OutgoingMessage, Publisher};
    /// Minimal message without native delayed redelivery.
    /// `settled` records nacks only: 0 means untouched, 1 means nacked without
    /// requeue, 2 means nacked with requeue. `ack` records nothing.
    struct PlainMessage {
        payload: Bytes,
        headers: Headers,
        settled: Arc<AtomicU8>,
    }
    impl IncomingMessage for PlainMessage {
        fn payload(&self) -> &[u8] {
            &self.payload
        }
        fn headers(&self) -> &Headers {
            &self.headers
        }
        async fn ack(self) -> Result<(), AckError> {
            Ok(())
        }
        async fn nack(self, requeue: bool) -> Result<(), AckError> {
            self.settled
                .store(if requeue { 2 } else { 1 }, Ordering::SeqCst);
            Ok(())
        }
    }
    /// Builds a `PlainMessage` with body `b"body"` and the given header pairs,
    /// sharing `settled` with the caller.
    fn plain(name_headers: &[(&str, &str)], settled: &Arc<AtomicU8>) -> PlainMessage {
        let mut headers = Headers::new();
        for (k, v) in name_headers {
            headers.insert((*k).to_owned(), Bytes::copy_from_slice(v.as_bytes()));
        }
        PlainMessage {
            payload: Bytes::from_static(b"body"),
            headers,
            settled: Arc::clone(settled),
        }
    }
    // Fallback path: the original is nacked without requeue, and nothing is
    // re-published until the delay elapses. The copy then carries retry-count 1.
    #[tokio::test(start_paused = true)]
    async fn fallback_defers_republish_to_source_with_incremented_retry_count() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("orders");
        let delivery = Delivery {
            publishers: HashMap::new(),
            pipeline: Arc::from([]),
            retry_publisher: Some(Arc::new(broker.publisher())),
            tasks: TaskTracker::new(),
        };
        let settled = Arc::new(AtomicU8::new(0));
        let msg = plain(&[], &settled);
        settle_nack_after(msg, "orders", Duration::from_secs(30), &delivery)
            .await
            .unwrap();
        assert_eq!(settled.load(Ordering::SeqCst), 1);
        let mut stream = std::pin::pin!(sub.stream());
        assert!(futures::poll!(stream.next()).is_pending());
        // Time is paused: advance past the delay, then yield so the spawned
        // re-publish task gets to run.
        tokio::time::advance(Duration::from_secs(30)).await;
        tokio::task::yield_now().await;
        let redelivered = stream.next().await.unwrap().unwrap();
        assert_eq!(redelivered.payload(), b"body");
        assert_eq!(
            redelivered.headers().get_str(RETRY_COUNT_HEADER),
            Some("1"),
            "the first deferred republish must carry retry-count 1",
        );
    }
    // An existing retry count is incremented rather than reset.
    #[tokio::test(start_paused = true)]
    async fn fallback_increments_an_existing_retry_count() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("orders");
        let delivery = Delivery {
            publishers: HashMap::new(),
            pipeline: Arc::from([]),
            retry_publisher: Some(Arc::new(broker.publisher())),
            tasks: TaskTracker::new(),
        };
        let settled = Arc::new(AtomicU8::new(0));
        let msg = plain(&[(RETRY_COUNT_HEADER, "4")], &settled);
        settle_nack_after(msg, "orders", Duration::from_secs(1), &delivery)
            .await
            .unwrap();
        tokio::time::advance(Duration::from_secs(1)).await;
        tokio::task::yield_now().await;
        let mut stream = std::pin::pin!(sub.stream());
        let redelivered = stream.next().await.unwrap().unwrap();
        assert_eq!(redelivered.headers().get_str(RETRY_COUNT_HEADER), Some("5"));
    }
    // With no retry publisher configured, the delay is dropped and the message
    // is requeued at once (`nack(true)`).
    #[tokio::test]
    async fn without_a_retry_publisher_the_fallback_requeues_immediately() {
        let delivery = Delivery::empty();
        let settled = Arc::new(AtomicU8::new(0));
        let msg = plain(&[], &settled);
        settle_nack_after(msg, "orders", Duration::from_secs(30), &delivery)
            .await
            .unwrap();
        assert_eq!(settled.load(Ordering::SeqCst), 2);
    }
    // A message with native `nack_after` support bypasses the fallback. The
    // retry publisher points at an unrelated broker, and the redelivered
    // message has no retry-count header.
    #[tokio::test(start_paused = true)]
    async fn native_support_defers_to_the_broker_nack_after() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("orders");
        let publisher = broker.publisher();
        publisher
            .publish(OutgoingMessage::new("orders", b"native".as_slice()))
            .await
            .unwrap();
        let other = MemoryBroker::new();
        let delivery = Delivery {
            publishers: HashMap::new(),
            pipeline: Arc::from([]),
            retry_publisher: Some(Arc::new(other.publisher())),
            tasks: TaskTracker::new(),
        };
        let msg = {
            let mut stream = std::pin::pin!(sub.stream());
            stream.next().await.unwrap().unwrap()
        };
        assert!(msg.supports_nack_after());
        settle_nack_after(msg, "orders", Duration::from_secs(5), &delivery)
            .await
            .unwrap();
        tokio::time::advance(Duration::from_secs(5)).await;
        tokio::task::yield_now().await;
        let mut stream = std::pin::pin!(sub.stream());
        let redelivered = stream.next().await.unwrap().unwrap();
        assert_eq!(redelivered.payload(), b"native");
        assert_eq!(redelivered.headers().get_str(RETRY_COUNT_HEADER), None);
    }
}