// strev 0.5.0
//
// Event-driven pub/sub messaging library with compile-time ack safety.
// Documentation
use std::time::Duration;

use tokio::select;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use tracing::error;

use crate::decorator::{PublisherDecorator, SubscriberDecorator};
use crate::error::RouterError;
use crate::handler::Handler;
use crate::message::{Message, Pending};
use crate::middleware::Middleware;
use crate::publisher::Publisher;
use crate::subscriber::Subscriber;
use crate::topic::Topic;

/// Selects how [`Router::run`] is told to begin shutting down.
#[derive(Debug, Clone)]
pub enum ShutdownSignal {
    /// Shut down when the supplied token is cancelled by the caller.
    Token(CancellationToken),
    /// Shut down when `tokio::signal::ctrl_c()` resolves; the router creates
    /// and cancels its own internal token in that case.
    CtrlC,
}

/// Tunable settings for a [`Router`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RouterConfig {
    /// Upper bound on how long `Router::run` waits for handler tasks to
    /// finish once shutdown has been signalled.
    pub close_timeout: Duration,
}

impl Default for RouterConfig {
    /// Returns a configuration with a 30 second close timeout.
    fn default() -> Self {
        const DEFAULT_CLOSE_TIMEOUT: Duration = Duration::from_secs(30);

        Self {
            close_timeout: DEFAULT_CLOSE_TIMEOUT,
        }
    }
}

/// Collects handler registrations, middlewares and decorators, and drives
/// them once [`Router::run`] is called.
pub struct Router {
    /// Settings read by `run` (currently only the close timeout).
    config: RouterConfig,
    /// Every handler/consumer registered so far, in registration order.
    handlers: Vec<HandlerRegistration>,
    /// Router-level middlewares; `run` wraps them around every handler.
    middlewares: Vec<Box<dyn Middleware>>,
    /// Decorators `run` applies, in order, to each handler's publisher.
    publisher_decorators: Vec<Box<dyn PublisherDecorator>>,
    /// Decorators `run` applies, in order, to each handler's subscriber.
    subscriber_decorators: Vec<Box<dyn SubscriberDecorator>>,
}

/// Everything the router stores for a single registered handler.
struct HandlerRegistration {
    /// Name of the handler; checked for uniqueness at registration and used
    /// in log and error output.
    name: String,
    /// Topic passed to `subscriber.subscribe` when the router runs.
    subscribe_topic: Topic,
    /// The user handler invoked for each incoming message.
    handler: Box<dyn Handler>,
    /// Source of incoming messages for this handler.
    subscriber: Box<dyn Subscriber>,
    /// Destination for messages the handler produces; `None` for
    /// registrations made through `add_consumer`.
    publisher: Option<Box<dyn Publisher>>,
    /// Middlewares that apply to this handler only.
    middlewares: Vec<Box<dyn Middleware>>,
}

/// Handle returned by `Router::add_handler` / `Router::add_consumer` that
/// allows further per-handler configuration.
pub struct HandlerBuilder<'r> {
    /// The router that owns the registration being configured.
    router: &'r mut Router,
    /// Position of that registration inside `router.handlers`.
    index: usize,
}

impl<'r> HandlerBuilder<'r> {
    /// Appends `middleware` to the middleware list of the handler this
    /// builder refers to and hands the builder back for further chaining.
    pub fn with_middleware(self, middleware: impl Middleware + 'static) -> Self {
        let Self { router, index } = self;
        let boxed: Box<dyn Middleware> = Box::new(middleware);
        router.handlers[index].middlewares.push(boxed);
        Self { router, index }
    }
}

impl Router {
    /// Creates an empty router using [`RouterConfig::default`].
    pub fn new() -> Self {
        let config = RouterConfig::default();
        Self::with_config(config)
    }

    pub fn with_config(config: RouterConfig) -> Self {
        Self {
            config,
            handlers: Vec::new(),
            middlewares: Vec::new(),
            publisher_decorators: Vec::new(),
            subscriber_decorators: Vec::new(),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }

    /// Returns the names of all registered handlers, in registration order.
    pub fn handler_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.handlers.len());
        for registration in &self.handlers {
            names.push(registration.name.as_str());
        }
        names
    }

    /// Registers a router-level middleware and returns the router for
    /// chaining.
    pub fn add_middleware(&mut self, middleware: impl Middleware + 'static) -> &mut Self {
        let boxed: Box<dyn Middleware> = Box::new(middleware);
        self.middlewares.push(boxed);
        self
    }

    /// Registers a publisher decorator and returns the router for chaining.
    pub fn add_publisher_decorator(
        &mut self,
        decorator: impl PublisherDecorator + 'static,
    ) -> &mut Self {
        let boxed: Box<dyn PublisherDecorator> = Box::new(decorator);
        self.publisher_decorators.push(boxed);
        self
    }

    /// Registers a subscriber decorator and returns the router for chaining.
    pub fn add_subscriber_decorator(
        &mut self,
        decorator: impl SubscriberDecorator + 'static,
    ) -> &mut Self {
        let boxed: Box<dyn SubscriberDecorator> = Box::new(decorator);
        self.subscriber_decorators.push(boxed);
        self
    }

    /// Panics if a handler called `name` has already been registered.
    fn assert_unique_name(&self, name: &str) {
        let already_registered = self.handlers.iter().any(|existing| existing.name == name);
        assert!(!already_registered, "duplicate handler name: {name}");
    }

    /// Registers a handler that consumes from `subscribe_topic` via
    /// `subscriber` and has `publisher` available for the messages it
    /// produces.
    ///
    /// Returns a [`HandlerBuilder`] pointing at the new registration.
    ///
    /// # Panics
    ///
    /// Panics if a handler with the same `name` is already registered.
    pub fn add_handler(
        &mut self,
        name: impl Into<String>,
        subscribe_topic: Topic,
        subscriber: impl Subscriber + 'static,
        publisher: impl Publisher + 'static,
        handler: impl Handler + 'static,
    ) -> HandlerBuilder<'_> {
        let handler_name: String = name.into();
        self.assert_unique_name(&handler_name);

        let registration = HandlerRegistration {
            name: handler_name,
            subscribe_topic,
            handler: Box::new(handler),
            subscriber: Box::new(subscriber),
            publisher: Some(Box::new(publisher)),
            middlewares: Vec::new(),
        };
        self.handlers.push(registration);

        // The registration just pushed is the last element.
        HandlerBuilder {
            index: self.handlers.len() - 1,
            router: self,
        }
    }

    /// Registers a handler that consumes from `subscribe_topic` via
    /// `subscriber` and has no publisher attached.
    ///
    /// Returns a [`HandlerBuilder`] pointing at the new registration.
    ///
    /// # Panics
    ///
    /// Panics if a handler with the same `name` is already registered.
    pub fn add_consumer(
        &mut self,
        name: impl Into<String>,
        subscribe_topic: Topic,
        subscriber: impl Subscriber + 'static,
        handler: impl Handler + 'static,
    ) -> HandlerBuilder<'_> {
        let consumer_name: String = name.into();
        self.assert_unique_name(&consumer_name);

        let registration = HandlerRegistration {
            name: consumer_name,
            subscribe_topic,
            handler: Box::new(handler),
            subscriber: Box::new(subscriber),
            publisher: None,
            middlewares: Vec::new(),
        };
        self.handlers.push(registration);

        // The registration just pushed is the last element.
        HandlerBuilder {
            index: self.handlers.len() - 1,
            router: self,
        }
    }

    pub async fn run(self, shutdown: ShutdownSignal) -> Result<(), RouterError> {
        let token = match shutdown {
            ShutdownSignal::Token(t) => t,
            ShutdownSignal::CtrlC => {
                let t = CancellationToken::new();
                let t2 = t.clone();
                tokio::spawn(async move {
                    let _ = tokio::signal::ctrl_c().await;
                    t2.cancel();
                });
                t
            }
        };

        let Self {
            config,
            handlers,
            middlewares,
            publisher_decorators,
            subscriber_decorators,
        } = self;
        let mut tasks = Vec::new();

        for mut reg in handlers {
            for dec in &subscriber_decorators {
                reg.subscriber = dec.decorate(reg.subscriber);
            }
            if let Some(pub_) = reg.publisher.take() {
                let mut decorated = pub_;
                for dec in &publisher_decorators {
                    decorated = dec.decorate(decorated);
                }
                reg.publisher = Some(decorated);
            }

            let mut stream = reg
                .subscriber
                .subscribe(&reg.subscribe_topic)
                .await
                .map_err(|source| RouterError::Subscribe {
                    handler: reg.name.clone(),
                    source,
                })?;

            let handler = Self::build_handler_chain(reg.handler, &middlewares, reg.middlewares);

            let name = reg.name;
            let publisher = reg.publisher;
            let cancel = token.clone();

            tasks.push(tokio::spawn(async move {
                loop {
                    select! {
                        _ = cancel.cancelled() => break,
                        maybe_msg = StreamExt::next(&mut stream) => {
                            match maybe_msg {
                                Some(msg) => {
                                    Self::process_message(
                                        &name,
                                        &*handler,
                                        msg,
                                        publisher.as_deref(),
                                    ).await;
                                }
                                None => break,
                            }
                        }
                    }
                }
            }));
        }

        let close_timeout = config.close_timeout;
        token.cancelled().await;

        let shutdown_result =
            tokio::time::timeout(close_timeout, futures::future::join_all(tasks)).await;

        if shutdown_result.is_err() {
            error!("router close timeout exceeded");
        }

        Ok(())
    }

    /// Wraps `handler` in its middlewares and returns the outermost handler.
    ///
    /// Handler-level middlewares are applied first (closest to the handler),
    /// router-level middlewares after them. Each list is walked in reverse,
    /// so within a list the first-registered middleware ends up outermost.
    fn build_handler_chain(
        handler: Box<dyn Handler>,
        router_middlewares: &[Box<dyn Middleware>],
        handler_middlewares: Vec<Box<dyn Middleware>>,
    ) -> Box<dyn Handler> {
        let with_handler_level = handler_middlewares
            .into_iter()
            .rev()
            .fold(handler, |inner, mw| mw.wrap(inner));

        router_middlewares
            .iter()
            .rev()
            .fold(with_handler_level, |inner, mw| mw.wrap(inner))
    }

    /// Runs `handler` on `msg` and publishes whatever it produced.
    ///
    /// * A handler error is logged and nothing is published.
    /// * If there is no `publisher`, or nothing was produced, the function
    ///   returns after the handler call.
    /// * Otherwise produced messages are batched per topic and each batch is
    ///   published with one `publish` call. Batches are published in the
    ///   order their topic first appears in the handler's output, and
    ///   messages keep their relative order inside a batch. A failed publish
    ///   is logged and does not prevent the remaining batches from being
    ///   published.
    async fn process_message(
        handler_name: &str,
        handler: &dyn Handler,
        msg: Message<Pending>,
        publisher: Option<&dyn Publisher>,
    ) {
        let result = match handler.handle(msg).await {
            Ok(result) => result,
            Err(e) => {
                error!(handler = handler_name, error = %e, "handler error");
                return;
            }
        };

        let Some(pub_) = publisher else {
            return;
        };
        if result.produced().is_empty() {
            return;
        }

        // Group by topic while remembering first-seen order: the map only
        // stores each topic's slot in `batches`, so iteration order never
        // depends on hashing.
        let batches = {
            let mut slot_by_topic: std::collections::HashMap<&Topic, usize> =
                std::collections::HashMap::new();
            let mut batches: Vec<(&Topic, Vec<Message<Pending>>)> = Vec::new();

            for pm in result.produced() {
                let slot = *slot_by_topic.entry(&pm.topic).or_insert_with(|| {
                    batches.push((&pm.topic, Vec::new()));
                    batches.len() - 1
                });
                batches[slot].1.push(Message::with_metadata(
                    pm.payload.clone(),
                    pm.metadata.clone(),
                ));
            }

            batches
        };

        for (topic, messages) in batches {
            if let Err(e) = pub_.publish(topic, messages).await {
                error!(handler = handler_name, topic = %topic, error = %e, "failed to publish produced messages");
            }
        }
    }
}

impl Default for Router {
    fn default() -> Self {
        Self::new()
    }
}