use std::sync::Arc;
use crate::{IncomingMessage, Subscriber};
use futures::StreamExt;
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{error, warn};
use super::{
handler::{Handler, HandlerResult},
metadata::HandlerMetadata,
};
/// Errors returned by [`Router::run`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RouterError {
/// Awaiting a spawned subscriber task produced a [`tokio::task::JoinError`]
/// (per the tokio docs: the task panicked or was cancelled).
#[error("subscriber task failed: {0}")]
Join(#[source] tokio::task::JoinError),
}
/// Owns the subscriber tasks spawned by [`Router::handle`], the metadata of
/// the handlers registered so far, and the token used to ask the tasks to stop.
pub struct Router {
/// Join handles of the per-subscriber tasks; consumed and awaited by `run`.
tasks: Vec<JoinHandle<()>>,
/// Metadata passed to `handle`, kept in registration order.
handlers: Vec<HandlerMetadata>,
/// Shared cancellation token; each subscriber loop watches a clone of it.
shutdown: CancellationToken,
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
/// Compact debug view: reports counts instead of dumping handles and metadata.
impl std::fmt::Debug for Router {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Number of handlers registered so far.
        let handler_count = self.handlers.len();
        // Only tasks that have not finished yet count as "running".
        let running = self
            .tasks
            .iter()
            .filter(|task| !task.is_finished())
            .count();

        let mut out = f.debug_struct("Router");
        out.field("handlers", &handler_count);
        out.field("tasks_running", &running);
        out.finish_non_exhaustive()
    }
}
impl Router {
#[must_use]
pub fn new() -> Self {
Self {
tasks: Vec::new(),
handlers: Vec::new(),
shutdown: CancellationToken::new(),
}
}
pub fn handle<S, H>(&mut self, mut subscriber: S, handler: H, metadata: HandlerMetadata)
where
S: Subscriber + Send + 'static,
H: Handler<S::Message> + 'static,
{
self.handlers.push(metadata);
let shutdown = self.shutdown.clone();
let handler = Arc::new(handler);
let task = tokio::spawn(async move {
let mut stream = std::pin::pin!(subscriber.stream());
loop {
tokio::select! {
() = shutdown.cancelled() => break,
next = stream.next() => match next {
Some(Ok(msg)) => dispatch(&*handler, msg).await,
Some(Err(err)) => {
error!(
target: "ruststream::dispatch",
error = %err,
"subscriber stream error",
);
}
None => break,
}
}
}
});
self.tasks.push(task);
}
#[must_use]
pub fn shutdown_handle(&self) -> CancellationToken {
self.shutdown.clone()
}
pub fn shutdown(&self) {
self.shutdown.cancel();
}
#[must_use]
pub fn handlers(&self) -> &[HandlerMetadata] {
&self.handlers
}
pub async fn run(self) -> Result<(), RouterError> {
let results = futures::future::join_all(self.tasks).await;
for result in results {
result.map_err(RouterError::Join)?;
}
Ok(())
}
}
/// Runs `handler` on `msg`, then settles the message according to the result.
///
/// `HandlerResult::Ack` acknowledges the message; `HandlerResult::Nack`
/// negatively acknowledges it, forwarding the `requeue` flag. A failure of the
/// ack / nack call itself is logged at warn level and otherwise ignored.
async fn dispatch<H, M>(handler: &H, msg: M)
where
    H: Handler<M>,
    M: IncomingMessage,
{
    let verdict = handler.handle(&msg).await;
    let settled = match verdict {
        HandlerResult::Ack => msg.ack().await,
        HandlerResult::Nack { requeue } => msg.nack(requeue).await,
    };

    // Nothing more to do when the message was settled successfully.
    let Err(err) = settled else {
        return;
    };
    warn!(
        target: "ruststream::dispatch",
        error = %err,
        "ack / nack failed",
    );
}