use std::collections::HashMap;
use std::sync::Arc;
use futures::StreamExt;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::{IncomingMessage, Subscriber};
use super::context::{Context, State};
use super::handler::{Handler, HandlerResult};
use super::publish::PublishMiddleware;
use super::publisher_registry::ErasedPublisher;
/// Type-erased publishers available while handling a message, keyed by `String`.
///
/// NOTE(review): the key is presumably the publisher / destination name used by
/// `publisher_registry` — confirm against the registry.
pub(crate) type Publishers = HashMap<String, Arc<dyn ErasedPublisher>>;
/// Outbound-publishing resources handed to every handler `Context`.
///
/// A single `Delivery` is shared (behind an `Arc`) by a dispatch task and passed
/// by reference into each `Context` that task creates for an incoming message.
pub(crate) struct Delivery {
    /// Type-erased publishers (see [`Publishers`]).
    pub(crate) publishers: Publishers,
    /// Publish middleware layers, stored as a shared, immutable slice.
    ///
    /// NOTE(review): presumably applied in slice order around each publish —
    /// confirm in the publish path.
    pub(crate) pipeline: Arc<[Arc<dyn PublishMiddleware>]>,
}
impl Delivery {
    /// Test-only constructor: a `Delivery` with no publishers and no middleware layers.
    #[cfg(test)]
    pub(crate) fn empty() -> Self {
        let publishers = Publishers::default();
        Self {
            publishers,
            pipeline: Vec::new().into(),
        }
    }
}
impl std::fmt::Debug for Delivery {
    /// Formats a summary of the delivery wiring: the number of publishers and the
    /// number of middleware layers, followed by `..` to mark the output as partial.
    ///
    /// NOTE(review): contents are not printed, presumably because the trait
    /// objects are not required to implement `Debug` — confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let publisher_count = self.publishers.len();
        let layer_count = self.pipeline.len();

        let mut summary = f.debug_struct("Delivery");
        summary.field("publishers", &publisher_count);
        summary.field("layers", &layer_count);
        summary.finish_non_exhaustive()
    }
}
/// Spawns a Tokio task that consumes `subscriber`'s message stream until shutdown.
///
/// Every message received is handed to `dispatch`, which runs `handler` and then
/// acks or nacks the message based on the handler's result. Messages are handled
/// sequentially: the next message is not pulled until the previous `dispatch`
/// call has completed.
///
/// The loop exits when `shutdown` is cancelled or when the stream ends (`None`).
/// Cancellation is only observed between messages, so an in-flight message is
/// always handled to completion. Stream errors are logged and polling continues.
///
/// # Parameters
/// - `subscriber`: message source; moved into the spawned task.
/// - `handler`: shared handler invoked for every message.
/// - `shutdown`: token whose cancellation stops the loop.
/// - `name`: subscriber name, attached to log events and passed to each `Context`.
/// - `state`, `delivery`: shared resources passed to each `Context`.
///
/// # Returns
/// The `JoinHandle` of the spawned task.
pub(crate) fn spawn_dispatch<S, H>(
    mut subscriber: S,
    handler: Arc<H>,
    shutdown: CancellationToken,
    name: Arc<str>,
    state: Arc<State>,
    delivery: Arc<Delivery>,
) -> JoinHandle<()>
where
    S: Subscriber + Send + 'static,
    H: Handler<S::Message> + 'static,
{
    tokio::spawn(async move {
        let mut stream = std::pin::pin!(subscriber.stream());
        loop {
            tokio::select! {
                // Poll shutdown first. An unbiased select picks randomly among
                // ready branches, so after cancellation it could keep pulling
                // (and acking) already-available messages. `biased` guarantees no
                // new message is taken once shutdown has been requested.
                biased;
                () = shutdown.cancelled() => break,
                next = stream.next() => match next {
                    Some(Ok(msg)) => dispatch(&*handler, msg, &name, &state, &delivery).await,
                    Some(Err(err)) => {
                        // Tag with the subscriber name so errors from concurrent
                        // subscribers can be told apart.
                        error!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            error = %err,
                            "subscriber stream error",
                        );
                    }
                    None => {
                        debug!(
                            target: "ruststream::dispatch",
                            subscriber = %name,
                            "subscriber stream ended",
                        );
                        break;
                    }
                }
            }
        }
    })
}
/// Runs `handler` on one message, then acks or nacks the message.
///
/// A fresh [`Context`] is built from the subscriber `name`, a clone of the
/// message headers, and the shared `state` / `delivery`. After the handler
/// returns, the message is acked on [`HandlerResult::Ack`], or nacked with the
/// handler-supplied `requeue` flag on [`HandlerResult::Nack`].
///
/// A failed ack/nack is logged at `warn` level, tagged with the subscriber name
/// and which operation failed, and is otherwise ignored. This function never
/// returns an error.
async fn dispatch<H, M>(handler: &H, msg: M, name: &str, state: &State, delivery: &Delivery)
where
    H: Handler<M>,
    M: IncomingMessage,
{
    let mut ctx = Context::new(name, msg.headers().clone(), state, delivery);
    let outcome = handler.handle(&msg, &mut ctx).await;
    // Remember which operation ran so a failure can be attributed precisely.
    let (action, settle_result) = match outcome {
        HandlerResult::Ack => ("ack", msg.ack().await),
        HandlerResult::Nack { requeue } => ("nack", msg.nack(requeue).await),
    };
    if let Err(err) = settle_result {
        warn!(
            target: "ruststream::dispatch",
            subscriber = %name,
            action,
            error = %err,
            "ack / nack failed",
        );
    }
}