use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::task::{JoinHandle, JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::{BatchSubscriber, Headers, IncomingMessage, Subscriber};
use super::batch::BatchHandler;
use super::context::{Context, State};
use super::handler::{Handler, HandlerResult};
use super::publish::PublishMiddleware;
use super::publisher_registry::ErasedPublisher;
/// Map of shared, type-erased publishers indexed by a `String` key
/// (presumably the publisher's registered name — confirm against the registry).
pub(crate) type Publishers = HashMap<String, Arc<dyn ErasedPublisher>>;
/// Concurrency configuration for a subscriber's dispatch loop.
///
/// Build one with [`Workers::sequential`], [`Workers::pool`] or
/// [`Workers::keyed`]; every constructor guarantees at least one worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Workers {
    /// Number of workers; the constructors never store `0` here.
    count: usize,
    /// `true` when messages are routed to workers by partition key.
    by_key: bool,
}

impl Workers {
    /// One worker, no key-based routing.
    #[must_use]
    pub const fn sequential() -> Self {
        Self {
            by_key: false,
            count: 1,
        }
    }

    /// `count` workers without key-based routing. A `count` of `0` is
    /// treated as `1`.
    #[must_use]
    pub const fn pool(count: usize) -> Self {
        let count = match count {
            0 => 1,
            n => n,
        };
        Self {
            by_key: false,
            count,
        }
    }

    /// `count` workers with key-based routing. A `count` of `0` is treated
    /// as `1`.
    #[must_use]
    pub const fn keyed(count: usize) -> Self {
        let count = match count {
            0 => 1,
            n => n,
        };
        Self {
            by_key: true,
            count,
        }
    }

    /// Returns `true` when at most one worker is configured, regardless of
    /// the routing flag.
    pub(crate) const fn is_sequential(&self) -> bool {
        self.count < 2
    }
}

impl Default for Workers {
    /// The default configuration is [`Workers::sequential`].
    fn default() -> Self {
        Self::sequential()
    }
}
/// Publishing resources handed to every handler `Context` built by the
/// dispatch loops in this module.
pub(crate) struct Delivery {
/// Type-erased publishers, looked up by their `String` key.
pub(crate) publishers: Publishers,
/// Shared, immutable list of publish middleware (reported as "layers" by
/// the `Debug` impl).
pub(crate) pipeline: Arc<[Arc<dyn PublishMiddleware>]>,
}
impl Delivery {
    /// Test-only constructor: no publishers and an empty middleware pipeline.
    #[cfg(test)]
    pub(crate) fn empty() -> Self {
        Self {
            pipeline: Arc::from([]),
            publishers: Publishers::new(),
        }
    }
}
/// Compact `Debug` output: prints only how many publishers and middleware
/// layers are present, since the trait objects themselves are opaque.
impl std::fmt::Debug for Delivery {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let publisher_count = self.publishers.len();
        let layer_count = self.pipeline.len();

        let mut out = f.debug_struct("Delivery");
        out.field("publishers", &publisher_count);
        out.field("layers", &layer_count);
        out.finish_non_exhaustive()
    }
}
pub(crate) fn spawn_dispatch<S, H>(
mut subscriber: S,
handler: Arc<H>,
shutdown: CancellationToken,
name: Arc<str>,
state: Arc<State>,
delivery: Arc<Delivery>,
) -> JoinHandle<()>
where
S: Subscriber + Send + 'static,
H: Handler<S::Message> + 'static,
{
tokio::spawn(async move {
let mut stream = std::pin::pin!(subscriber.stream());
loop {
tokio::select! {
() = shutdown.cancelled() => break,
next = stream.next() => match next {
Some(Ok(msg)) => dispatch(&*handler, msg, &name, &state, &delivery).await,
Some(Err(err)) => {
error!(
target: "ruststream::dispatch",
error = %err,
"subscriber stream error",
);
}
None => {
debug!(
target: "ruststream::dispatch",
subscriber = %name,
"subscriber stream ended",
);
break;
}
}
}
}
})
}
/// Picks the dispatch strategy that matches `workers` and spawns it.
///
/// * at most one worker: the sequential loop ([`spawn_dispatch`]), whatever
///   the routing flag says;
/// * several workers with key routing: per-key lanes
///   ([`spawn_dispatch_lanes`]);
/// * several workers without key routing: the bounded task pool
///   ([`spawn_dispatch_pool`]).
///
/// Returns the `JoinHandle` of the spawned dispatch task.
pub(crate) fn spawn_dispatch_workers<S, H>(
    subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    workers: Workers,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    S::Message: Send + Sync + 'static,
    H: Handler<S::Message> + 'static,
{
    match (workers.is_sequential(), workers.by_key) {
        (true, _) => spawn_dispatch(subscriber, handler, shutdown, name, state, delivery),
        (false, true) => spawn_dispatch_lanes(
            subscriber, handler, shutdown, name, state, delivery, workers,
        ),
        (false, false) => spawn_dispatch_pool(
            subscriber, handler, shutdown, name, state, delivery, workers,
        ),
    }
}
/// Spawns a dispatch loop that handles up to `workers.count` messages
/// concurrently, each in its own task inside a `JoinSet`.
///
/// The stream is only polled while fewer than `workers.count` tasks are
/// tracked, which bounds the number of in-flight messages. Finished tasks are
/// joined as soon as they complete so that a failed worker is logged promptly
/// rather than only once the pool is saturated. The loop ends when `shutdown`
/// is cancelled or the stream yields `None`; the remaining tasks are then
/// joined before the spawned task returns.
///
/// Returns the `JoinHandle` of the spawned dispatch task.
fn spawn_dispatch_pool<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    workers: Workers,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    S::Message: Send + Sync + 'static,
    H: Handler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let mut stream = std::pin::pin!(subscriber.stream());
        let mut tasks = JoinSet::new();
        loop {
            tokio::select! {
                () = shutdown.cancelled() => break,
                // Always reap. On an empty set `join_next` resolves to `None`,
                // the `Some(..)` pattern does not match, and `select!` simply
                // disables this branch for the current iteration.
                Some(joined) = tasks.join_next() => log_worker_exit(joined),
                // Backpressure: stop pulling while the pool is full.
                next = stream.next(), if tasks.len() < workers.count => match next {
                    Some(Ok(msg)) => {
                        let handler = Arc::clone(&handler);
                        let name = Arc::clone(&name);
                        let state = Arc::clone(&state);
                        let delivery = Arc::clone(&delivery);
                        tasks.spawn(async move {
                            dispatch(&*handler, msg, &name, &state, &delivery).await;
                        });
                    }
                    Some(Err(err)) => {
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
        // Let in-flight handlers finish before the dispatch task exits.
        while let Some(joined) = tasks.join_next().await {
            log_worker_exit(joined);
        }
    })
}
fn spawn_dispatch_lanes<S, H>(
mut subscriber: S,
handler: Arc<H>,
shutdown: CancellationToken,
name: Arc<str>,
state: Arc<State>,
delivery: Arc<Delivery>,
workers: Workers,
) -> JoinHandle<()>
where
S: Subscriber + Send + 'static,
S::Message: Send + Sync + 'static,
H: Handler<S::Message> + 'static,
{
tokio::spawn(async move {
let mut lanes = Vec::with_capacity(workers.count);
let mut tasks = JoinSet::new();
for _ in 0..workers.count {
let (tx, mut rx) = mpsc::channel::<S::Message>(1);
let handler = Arc::clone(&handler);
let name = Arc::clone(&name);
let state = Arc::clone(&state);
let delivery = Arc::clone(&delivery);
tasks.spawn(async move {
while let Some(msg) = rx.recv().await {
dispatch(&*handler, msg, &name, &state, &delivery).await;
}
});
lanes.push(tx);
}
let mut stream = std::pin::pin!(subscriber.stream());
let mut unkeyed_rotation = 0usize;
loop {
tokio::select! {
() = shutdown.cancelled() => break,
next = stream.next() => match next {
Some(Ok(msg)) => {
let lane = msg.partition_key().map_or_else(
|| {
unkeyed_rotation = (unkeyed_rotation + 1) % workers.count;
unkeyed_rotation
},
|key| lane_of(key, workers.count),
);
if lanes[lane].send(msg).await.is_err() {
error!(
target: "ruststream::dispatch",
subscriber = %name,
lane,
"worker lane terminated; stopping dispatch",
);
break;
}
}
Some(Err(err)) => {
error!(
target: "ruststream::dispatch",
error = %err,
"subscriber stream error",
);
}
None => {
debug!(
target: "ruststream::dispatch",
subscriber = %name,
"subscriber stream ended",
);
break;
}
}
}
}
drop(lanes);
while let Some(joined) = tasks.join_next().await {
log_worker_exit(joined);
}
})
}
/// Maps a partition key to a lane index in `0..lanes`.
///
/// The key is hashed with a freshly created `DefaultHasher` and reduced
/// modulo `lanes`, so equal keys always map to the same lane for a given lane
/// count within one build of the program.
///
/// # Panics
///
/// Panics if `lanes` is `0` (remainder by zero).
fn lane_of(key: &[u8], lanes: usize) -> usize {
    let digest = {
        let mut state = DefaultHasher::new();
        Hash::hash(key, &mut state);
        state.finish()
    };
    let slot = digest % lanes as u64;
    // `slot` is strictly less than `lanes`, which is itself a `usize`, so the
    // narrowing cast cannot lose information.
    #[allow(clippy::cast_possible_truncation)]
    {
        slot as usize
    }
}
/// Logs an error when a joined worker task did not finish cleanly; a
/// successful join is silent.
fn log_worker_exit(joined: Result<(), tokio::task::JoinError>) {
    match joined {
        Ok(()) => {}
        Err(err) => {
            error!(target: "ruststream::dispatch", error = %err, "worker task failed");
        }
    }
}
/// Spawns the dispatch loop for a batch subscriber.
///
/// Each item from `subscriber.batches()` is collected into a `Vec` and passed
/// to `handler.handle_batch` together with a `Context` built from empty
/// headers. When `workers` is sequential the batch is handled inline on the
/// dispatch task; otherwise it runs in its own task, with at most
/// `workers.count` tasks tracked at once. The key-routing flag of `workers`
/// is not consulted here.
///
/// Finished tasks are joined as soon as they complete so a failed worker is
/// logged promptly rather than only once the pool is saturated. The loop ends
/// when `shutdown` is cancelled or the stream yields `None`; remaining tasks
/// are joined before the spawned task returns. Stream errors are logged with
/// the subscriber `name` and do not stop the loop.
///
/// Returns the `JoinHandle` of the spawned dispatch task.
pub(crate) fn spawn_batch_dispatch<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
    workers: Workers,
) -> JoinHandle<()>
where
    S: BatchSubscriber + Send + 'static,
    S::Message: Send + 'static,
    H: BatchHandler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let mut stream = std::pin::pin!(subscriber.batches());
        let mut tasks = JoinSet::new();
        loop {
            tokio::select! {
                () = shutdown.cancelled() => break,
                // Always reap. On an empty set (always the case in sequential
                // mode) `join_next` resolves to `None`, the pattern does not
                // match, and `select!` disables this branch for the iteration.
                Some(joined) = tasks.join_next() => log_worker_exit(joined),
                // Backpressure: stop pulling while the pool is full.
                next = stream.next(), if tasks.len() < workers.count => match next {
                    Some(Ok(batch)) => {
                        let batch: Vec<S::Message> = batch.into_iter().collect();
                        if workers.is_sequential() {
                            // Handle inline so batches are processed strictly
                            // one after another.
                            let empty = Headers::new();
                            let mut ctx = Context::new(&name, &empty, &state, &delivery);
                            handler.handle_batch(batch, &mut ctx).await;
                        } else {
                            let handler = Arc::clone(&handler);
                            let name = Arc::clone(&name);
                            let state = Arc::clone(&state);
                            let delivery = Arc::clone(&delivery);
                            tasks.spawn(async move {
                                let empty = Headers::new();
                                let mut ctx = Context::new(&name, &empty, &state, &delivery);
                                handler.handle_batch(batch, &mut ctx).await;
                            });
                        }
                    }
                    Some(Err(err)) => {
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
        // Let in-flight batch handlers finish before the dispatch task exits.
        while let Some(joined) = tasks.join_next().await {
            log_worker_exit(joined);
        }
    })
}
/// Runs `handler` on a single message and settles the message according to
/// the handler's verdict.
///
/// A `Context` is built from the subscriber `name`, the message headers,
/// `state` and `delivery`. The returned `HandlerResult` selects `ack`,
/// `nack(requeue)` or `nack_after(delay)` on the message. A failure of that
/// settle call is logged as a warning and otherwise ignored.
async fn dispatch<H, M>(handler: &H, msg: M, name: &str, state: &State, delivery: &Delivery)
where
    H: Handler<M>,
    M: IncomingMessage,
{
    let mut ctx = Context::new(name, msg.headers(), state, delivery);
    let verdict = handler.handle(&msg, &mut ctx).await;

    let settled = match verdict {
        HandlerResult::Ack => msg.ack().await,
        HandlerResult::Nack { requeue } => msg.nack(requeue).await,
        HandlerResult::NackAfter { delay } => msg.nack_after(delay).await,
    };

    let Err(err) = settled else {
        return;
    };
    warn!(
        target: "ruststream::dispatch",
        error = %err,
        "ack / nack failed",
    );
}