use std::{
collections::{BTreeMap, HashMap},
error::Error as StdError,
future::Future,
sync::Arc,
time::Duration,
};
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::codec::Codec;
use crate::{Broker, Publisher, ServerSpec, Subscriber, SubscriptionSource};
use super::context::State;
use super::dispatch::{Delivery, Publishers};
use super::handler::Handler;
use super::lifecycle::{BoxError, BoxFuture, BrokerLifecycle};
use super::metadata::HandlerMetadata;
use super::middleware::{Identity, Layer, Stack};
use super::publish::{PublishLayer, PublishMiddleware, TypedPublisher};
use super::publisher_registry::ErasedPublisher;
use super::publishing::{PublishingDef, PublishingHandler};
use super::router::Router;
use super::subscriber_def::SubscriberDef;
use super::typed::{Typed, typed};
/// Boxed, one-shot launcher for a single subscriber.
///
/// It is invoked with the shared [`State`] and a [`CancellationToken`]; the
/// returned future yields the [`JoinHandle`] of the started task, or a
/// [`BoxError`] when the subscriber could not be started.
type Starter = Box<
dyn FnOnce(
Arc<State>,
CancellationToken,
) -> BoxFuture<'static, Result<JoinHandle<()>, BoxError>>
+ Send,
>;
/// Boxed `on_startup` hook: takes the [`State`] by value and resolves to the
/// state that is carried forward, or to a [`BoxError`].
type StartupHook = Box<dyn FnOnce(State) -> BoxFuture<'static, Result<State, BoxError>> + Send>;
/// Boxed hook for the phases named by `LifecyclePhase`: receives the shared,
/// read-only `Arc<State>` and resolves to `()` or a [`BoxError`].
type LifecycleHook = Box<dyn FnOnce(Arc<State>) -> BoxFuture<'static, Result<(), BoxError>> + Send>;
/// Selects which hook list `push_lifecycle_hook` appends a hook to.
#[derive(Clone, Copy)]
enum LifecyclePhase {
/// Hook goes into the `after_startup` list.
AfterStartup,
/// Hook goes into the `on_shutdown` list.
OnShutdown,
/// Hook goes into the `after_shutdown` list.
AfterShutdown,
}
/// Descriptive metadata for a service: a title, a version and an optional
/// description.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AppInfo {
    /// Service title.
    pub title: String,
    /// Service version string.
    pub version: String,
    /// Free-form description; `None` until set with [`AppInfo::with_description`].
    pub description: Option<String>,
}

impl AppInfo {
    /// Creates the metadata from a title and a version, with no description.
    #[must_use]
    pub fn new(title: impl Into<String>, version: impl Into<String>) -> Self {
        let (title, version) = (title.into(), version.into());
        Self {
            title,
            version,
            description: None,
        }
    }

    /// Returns the metadata with its description set to `description`,
    /// replacing any previous one.
    #[must_use]
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
/// Errors returned from running a [`RustStream`] application.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RustStreamError {
/// A broker's `connect` call failed.
#[error("broker connect failed: {0}")]
Connect(#[source] BoxError),
/// An `on_startup` or `after_startup` hook returned an error.
#[error("startup hook failed: {0}")]
Startup(#[source] BoxError),
/// A subscriber could not be started.
#[error("subscription failed: {0}")]
Subscribe(#[source] BoxError),
/// A broker's `shutdown` call failed.
#[error("broker shutdown failed: {0}")]
Shutdown(#[source] BoxError),
/// Awaiting a dispatch task's `JoinHandle` returned a `JoinError`.
#[error("dispatch task failed: {0}")]
Join(#[source] tokio::task::JoinError),
}
/// Application builder and runner.
///
/// `L` is the type of the global handler middleware stack; it starts as
/// [`Identity`] and grows by one [`Stack`] level per call to `layer`.
pub struct RustStream<L = Identity> {
/// Service metadata, logged when the service starts.
info: AppInfo,
/// Brokers to connect at startup (in this order) and shut down at the end (in reverse).
brokers: Vec<Arc<dyn BrokerLifecycle>>,
/// One launcher per registered subscriber; kept index-aligned with `handlers`.
starters: Vec<Starter>,
/// Metadata for each registered handler; kept index-aligned with `starters`.
handlers: Vec<HandlerMetadata>,
/// Named server specifications.
servers: BTreeMap<String, ServerSpec>,
/// Named publishers, cloned into each broker scope when it is opened.
publishers: Publishers,
/// Publish middleware, copied into each broker scope's pipeline when it is opened.
publish_layers: Vec<Arc<dyn PublishMiddleware>>,
/// Application state handed to hooks and subscribers.
state: State,
/// Hooks run before brokers connect; each may replace the state.
on_startup: Vec<StartupHook>,
/// Hooks run once every subscriber has been started.
after_startup: Vec<LifecycleHook>,
/// Hooks run after the shutdown future resolves, before dispatch tasks are cancelled.
on_shutdown: Vec<LifecycleHook>,
/// Hooks run at the very end of shutdown.
after_shutdown: Vec<LifecycleHook>,
/// Upper bound on draining dispatch tasks; `None` waits for all of them.
shutdown_timeout: Option<Duration>,
/// Global handler middleware stack applied to handlers registered through a scope.
global: L,
}
impl<L> std::fmt::Debug for RustStream<L> {
    /// Prints the app info plus the broker and handler counts; every other
    /// field is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let broker_count = self.brokers.len();
        let handler_count = self.handlers.len();
        let mut out = f.debug_struct("RustStream");
        out.field("info", &self.info);
        out.field("brokers", &broker_count);
        out.field("handlers", &handler_count);
        out.finish_non_exhaustive()
    }
}
impl RustStream<Identity> {
#[must_use]
pub fn new(info: AppInfo) -> Self {
Self {
info,
brokers: Vec::new(),
starters: Vec::new(),
handlers: Vec::new(),
servers: BTreeMap::new(),
publishers: HashMap::new(),
publish_layers: Vec::new(),
state: State::default(),
on_startup: Vec::new(),
after_startup: Vec::new(),
on_shutdown: Vec::new(),
after_shutdown: Vec::new(),
shutdown_timeout: None,
global: Identity,
}
}
}
impl<L> RustStream<L> {
#[must_use]
pub fn layer<N>(self, layer: N) -> RustStream<Stack<N, L>> {
RustStream {
info: self.info,
brokers: self.brokers,
starters: self.starters,
handlers: self.handlers,
servers: self.servers,
publishers: self.publishers,
publish_layers: self.publish_layers,
state: self.state,
on_startup: self.on_startup,
after_startup: self.after_startup,
on_shutdown: self.on_shutdown,
after_shutdown: self.after_shutdown,
shutdown_timeout: self.shutdown_timeout,
global: Stack::new(layer, self.global),
}
}
/// Stores `value` in the application [`State`] and returns the builder.
#[must_use]
pub fn insert_state<T>(self, value: T) -> Self
where
    T: std::any::Any + Send + Sync,
{
    let mut app = self;
    app.state.insert(value);
    app
}
#[must_use]
pub fn on_startup<F, Fut, E>(mut self, hook: F) -> Self
where
F: FnOnce(State) -> Fut + Send + 'static,
Fut: Future<Output = Result<State, E>> + Send,
E: StdError + Send + Sync + 'static,
{
self.on_startup.push(Box::new(move |state| {
Box::pin(async move { hook(state).await.map_err(|e| Box::new(e) as BoxError) })
}));
self
}
/// Registers a hook that runs once every subscriber has been started.
///
/// An error from the hook aborts the run with [`RustStreamError::Startup`].
#[must_use]
pub fn after_startup<F, Fut, E>(self, hook: F) -> Self
where
    F: FnOnce(Arc<State>) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), E>> + Send,
    E: StdError + Send + Sync + 'static,
{
    let phase = LifecyclePhase::AfterStartup;
    self.push_lifecycle_hook(phase, hook)
}
/// Registers a hook that runs after the shutdown future resolves and before
/// the dispatch tasks are cancelled.
///
/// An error from the hook is logged at `warn` level and otherwise ignored.
#[must_use]
pub fn on_shutdown<F, Fut, E>(self, hook: F) -> Self
where
    F: FnOnce(Arc<State>) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), E>> + Send,
    E: StdError + Send + Sync + 'static,
{
    let phase = LifecyclePhase::OnShutdown;
    self.push_lifecycle_hook(phase, hook)
}
/// Registers a hook that runs at the end of shutdown, after the brokers have
/// been shut down.
///
/// An error from the hook is logged at `warn` level and otherwise ignored.
#[must_use]
pub fn after_shutdown<F, Fut, E>(self, hook: F) -> Self
where
    F: FnOnce(Arc<State>) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), E>> + Send,
    E: StdError + Send + Sync + 'static,
{
    let phase = LifecyclePhase::AfterShutdown;
    self.push_lifecycle_hook(phase, hook)
}
/// Type-erases `hook` and appends it to the hook list selected by `phase`.
fn push_lifecycle_hook<F, Fut, E>(mut self, phase: LifecyclePhase, hook: F) -> Self
where
    F: FnOnce(Arc<State>) -> Fut + Send + 'static,
    Fut: Future<Output = Result<(), E>> + Send,
    E: StdError + Send + Sync + 'static,
{
    let erased: LifecycleHook = Box::new(move |state| {
        Box::pin(async move {
            match hook(state).await {
                Ok(()) => Ok(()),
                Err(err) => Err(Box::new(err) as BoxError),
            }
        })
    });
    let hooks = match phase {
        LifecyclePhase::AfterStartup => &mut self.after_startup,
        LifecyclePhase::OnShutdown => &mut self.on_shutdown,
        LifecyclePhase::AfterShutdown => &mut self.after_shutdown,
    };
    hooks.push(erased);
    self
}
#[must_use]
pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
self.shutdown_timeout = Some(timeout);
self
}
#[must_use]
pub fn publisher<P>(mut self, name: impl Into<String>, publisher: P) -> Self
where
P: Publisher + 'static,
{
self.publishers.insert(name.into(), Arc::new(publisher));
self
}
/// Appends `middleware` to the publish middleware list.
///
/// Broker scopes copy the list when they are opened, so middleware added
/// here only reaches scopes opened afterwards.
#[must_use]
pub fn publish_layer<M>(mut self, middleware: M) -> Self
where
    M: PublishMiddleware + 'static,
{
    let shared: Arc<dyn PublishMiddleware> = Arc::new(middleware);
    self.publish_layers.push(shared);
    self
}
#[must_use]
pub fn register_broker<B>(mut self, broker: B) -> Self
where
B: Broker + 'static,
{
self.brokers.push(Arc::new(broker));
self
}
/// Records a server specification under `name`, replacing any existing
/// entry with the same name.
#[must_use]
pub fn server(mut self, name: impl Into<String>, spec: ServerSpec) -> Self {
    let key: String = name.into();
    self.servers.insert(key, spec);
    self
}
/// Opens a scope for `broker` with no scope codec (`()`), lets `build`
/// register subscribers in it, then folds the result into the application.
///
/// The broker is also added to the list of brokers to connect and shut down.
#[must_use]
pub fn with_broker<B, F>(mut self, broker: B, build: F) -> Self
where
    B: Broker + 'static,
    L: Clone,
    F: FnOnce(&mut BrokerScope<B, L>),
{
    let shared = Arc::new(broker);
    let scope = {
        let mut scope = self.new_scope(&shared, ());
        build(&mut scope);
        scope
    };
    self.collect_scope(&shared, scope);
    self
}
/// Like `with_broker`, but the scope carries `codec`, which the scope's
/// `include` and `include_publishing` methods hand to the handlers they build.
#[must_use]
pub fn with_broker_codec<B, C, F>(mut self, broker: B, codec: C, build: F) -> Self
where
    B: Broker + 'static,
    C: Codec + Clone + 'static,
    L: Clone,
    F: FnOnce(&mut BrokerScope<B, L, C>),
{
    let shared = Arc::new(broker);
    let scope = {
        let mut scope = self.new_scope(&shared, codec);
        build(&mut scope);
        scope
    };
    self.collect_scope(&shared, scope);
    self
}
fn new_scope<B, C>(&self, broker: &Arc<B>, codec: C) -> BrokerScope<B, L, C>
where
B: Broker + 'static,
L: Clone,
{
BrokerScope {
broker: broker.clone(),
router: Router::new(),
publishers: self.publishers.clone(),
pipeline: self.publish_layers.iter().cloned().collect(),
global: self.global.clone(),
codec,
}
}
fn collect_scope<B, C>(&mut self, broker: &Arc<B>, scope: BrokerScope<B, L, C>)
where
B: Broker + 'static,
{
let lifecycle: Arc<dyn BrokerLifecycle> = broker.clone();
let delivery = Arc::new(Delivery {
publishers: self.publishers.clone(),
pipeline: scope.pipeline.clone(),
});
let (starters, handlers) = scope.router.into_parts();
for (bound, meta) in starters.into_iter().zip(handlers) {
let broker = broker.clone();
let delivery = delivery.clone();
self.starters.push(Box::new(move |state, token| {
bound(broker, state, delivery, token)
}));
self.handlers.push(meta);
}
self.brokers.push(lifecycle);
}
/// Metadata of every handler registered so far, in registration order.
#[must_use]
pub fn handlers(&self) -> &[HandlerMetadata] {
    self.handlers.as_slice()
}
#[must_use]
pub fn info(&self) -> &AppInfo {
&self.info
}
#[must_use]
pub fn servers(&self) -> &BTreeMap<String, ServerSpec> {
&self.servers
}
/// Runs the service until a process termination signal arrives.
///
/// Equivalent to `run_until` with `wait_for_signal()` as the shutdown future.
///
/// # Errors
///
/// Returns whatever `run_until` returns.
pub async fn run(self) -> Result<(), RustStreamError> {
    let signal = wait_for_signal();
    self.run_until(signal).await
}
/// Runs the service until `shutdown` resolves, then tears it down.
///
/// Sequence:
/// 1. `on_startup` hooks run in registration order; each takes the [`State`]
///    by value and returns the state handed to the next step.
/// 2. Every broker is connected, in registration order.
/// 3. Every subscriber is started with the shared state and one shared
///    [`CancellationToken`].
/// 4. `after_startup` hooks run.
/// 5. `shutdown` is awaited.
/// 6. `on_shutdown` hooks run; failures are logged at `warn` and ignored.
/// 7. The token is cancelled and the dispatch tasks are drained, bounded by
///    the configured shutdown timeout if there is one.
/// 8. Brokers are shut down in reverse registration order.
/// 9. `after_shutdown` hooks run; failures are logged at `warn` and ignored.
///
/// A failure in step 7 or 8 does not cut the teardown short: every broker is
/// still asked to shut down and the `after_shutdown` hooks still run. The
/// first failure is returned; later broker shutdown failures are logged.
///
/// If a subscriber fails to start or an `after_startup` hook fails, the
/// token is cancelled before returning so that subscribers which did start
/// are told to stop. They are not awaited, and brokers are not explicitly
/// shut down on that path.
///
/// # Errors
///
/// * [`RustStreamError::Startup`] if an `on_startup` or `after_startup` hook fails.
/// * [`RustStreamError::Connect`] if a broker fails to connect.
/// * [`RustStreamError::Subscribe`] if a subscriber fails to start.
/// * [`RustStreamError::Join`] if draining the dispatch tasks reports a failure.
/// * [`RustStreamError::Shutdown`] if a broker fails to shut down.
pub async fn run_until<F>(self, shutdown: F) -> Result<(), RustStreamError>
where
    F: Future<Output = ()> + Send,
{
    let Self {
        info,
        brokers,
        starters,
        handlers,
        mut state,
        on_startup,
        after_startup,
        on_shutdown,
        after_shutdown,
        shutdown_timeout,
        ..
    } = self;
    info!(
        target: "ruststream::lifecycle",
        service = %info.title,
        version = %info.version,
        brokers = brokers.len(),
        subscribers = starters.len(),
        "starting service",
    );
    if !on_startup.is_empty() {
        debug!(target: "ruststream::lifecycle", count = on_startup.len(), "running on_startup hooks");
    }
    for hook in on_startup {
        state = hook(state).await.map_err(RustStreamError::Startup)?;
    }
    // From here on the state is shared read-only with hooks and subscribers.
    let state = Arc::new(state);
    for broker in &brokers {
        broker.connect().await.map_err(RustStreamError::Connect)?;
        info!(target: "ruststream::lifecycle", broker = broker.name(), "broker connected");
    }
    let token = CancellationToken::new();
    let mut handles = Vec::with_capacity(starters.len());
    for (starter, meta) in starters.into_iter().zip(handlers) {
        let handle = match starter(state.clone(), token.clone()).await {
            Ok(handle) => handle,
            Err(err) => {
                // Dropping the collected handles only detaches their tasks;
                // cancel the token so they stop instead of running on.
                token.cancel();
                return Err(RustStreamError::Subscribe(err));
            }
        };
        info!(
            target: "ruststream::dispatch",
            subscriber = %meta.name,
            input = meta.input_type,
            "subscriber started",
        );
        handles.push(handle);
    }
    if !after_startup.is_empty() {
        debug!(target: "ruststream::lifecycle", count = after_startup.len(), "running after_startup hooks");
    }
    for hook in after_startup {
        if let Err(err) = hook(Arc::clone(&state)).await {
            // Same reasoning as above: stop the running subscribers.
            token.cancel();
            return Err(RustStreamError::Startup(err));
        }
    }
    info!(target: "ruststream::lifecycle", subscribers = handles.len(), "service running");
    shutdown.await;
    info!(target: "ruststream::lifecycle", "shutdown signal received");
    for hook in on_shutdown {
        if let Err(err) = hook(Arc::clone(&state)).await {
            warn!(target: "ruststream::lifecycle", error = %err, "on_shutdown hook failed");
        }
    }
    token.cancel();
    debug!(target: "ruststream::lifecycle", "draining in-flight handlers");
    // Remember the first teardown failure but keep going, so one failed step
    // cannot prevent the remaining brokers from being shut down.
    let mut outcome = drain_handles(handles, shutdown_timeout).await;
    for broker in brokers.iter().rev() {
        match broker.shutdown().await {
            Ok(_) => {
                debug!(target: "ruststream::lifecycle", broker = broker.name(), "broker shut down");
            }
            Err(err) if outcome.is_ok() => {
                outcome = Err(RustStreamError::Shutdown(err));
            }
            Err(err) => {
                warn!(
                    target: "ruststream::lifecycle",
                    broker = broker.name(),
                    error = %err,
                    "broker shutdown failed",
                );
            }
        }
    }
    for hook in after_shutdown {
        if let Err(err) = hook(Arc::clone(&state)).await {
            warn!(target: "ruststream::lifecycle", error = %err, "after_shutdown hook failed");
        }
    }
    outcome?;
    info!(target: "ruststream::lifecycle", "service stopped");
    Ok(())
}
}
/// Registration scope for one broker, handed to the `with_broker` /
/// `with_broker_codec` build closure.
///
/// `L` is the global handler layer stack type and `C` the scope codec type
/// (`()` when the scope was opened without a codec).
pub struct BrokerScope<B, L = Identity, C = ()> {
/// The broker this scope registers subscribers for.
broker: Arc<B>,
/// Collects the subscribers registered through this scope.
router: Router<B>,
/// Snapshot of the application's named publishers taken when the scope was opened.
publishers: Publishers,
/// Snapshot of the application's publish middleware taken when the scope was opened.
pipeline: Arc<[Arc<dyn PublishMiddleware>]>,
/// Clone of the application's global handler layer stack.
global: L,
/// Codec supplied when the scope was opened.
codec: C,
}
impl<B: Broker + 'static, L, C> BrokerScope<B, L, C> {
    /// The broker this scope was opened for.
    #[must_use]
    pub fn broker(&self) -> &B {
        &*self.broker
    }

    /// Looks up a publisher by the name it was registered under on the
    /// application; `None` if there is no such name in this scope's snapshot.
    #[must_use]
    pub fn publisher(&self, name: &str) -> Option<Arc<dyn ErasedPublisher>> {
        self.publishers.get(name).map(Arc::clone)
    }

    /// Registers `handler` for an already-built `subscriber`, wrapping the
    /// handler in the global layer stack first.
    pub fn handle<S, H>(&mut self, subscriber: S, handler: H, meta: HandlerMetadata)
    where
        S: Subscriber + Send + 'static,
        H: Handler<S::Message> + 'static,
        L: Layer<H>,
        L::Handler: Handler<S::Message> + 'static,
    {
        let layered = self.global.layer(handler);
        self.router.handle(subscriber, layered, meta);
    }

    /// Registers `handler` for a subscription `source`, wrapping the handler
    /// in the global layer stack first.
    pub fn subscribe<S, H>(&mut self, source: S, handler: H, meta: HandlerMetadata)
    where
        S: SubscriptionSource<B> + Send + 'static,
        S::Subscriber: Send + 'static,
        H: Handler<<S::Subscriber as Subscriber>::Message> + 'static,
        L: Layer<H>,
        L::Handler: Handler<<S::Subscriber as Subscriber>::Message> + 'static,
    {
        let layered = self.global.layer(handler);
        self.router.subscribe(source, layered, meta);
    }

    /// Merges a pre-built `router` into this scope's router.
    ///
    /// The merged handlers are not passed through this scope's global layer
    /// stack by this method.
    pub fn include_router(&mut self, router: Router<B>) {
        let incoming = router;
        self.router.merge(incoming);
    }
}
impl<B: Broker + 'static, L> BrokerScope<B, L, ()> {
    /// Registers a typed subscriber definition on its own source, using
    /// `crate::codec::DefaultCodec::default()` as the codec.
    ///
    /// Only compiled when at least one of the `json`, `cbor` or `msgpack`
    /// features is enabled.
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    pub fn include<D>(&mut self, def: D)
    where
        D: SubscriberDef,
        D::Source: SubscriptionSource<B> + Send + 'static,
        <D::Source as SubscriptionSource<B>>::Subscriber: Send + 'static,
        <<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message: 'static,
        D::Input: DeserializeOwned + Send + Sync + 'static,
        D::Handler: 'static,
        L: Layer<
            Typed<
                <<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message,
                D::Input,
                crate::codec::DefaultCodec,
                D::Handler,
            >,
        >,
        L::Handler: Handler<<<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message>
            + 'static,
    {
        let source = def.source();
        let default_codec = crate::codec::DefaultCodec::default();
        self.include_on(source, def, default_codec);
    }

    /// Registers a publishing definition on its own source.
    ///
    /// The scope has no codec of its own, so a clone of the publisher's codec
    /// is passed as the handler codec.
    pub fn include_publishing<D, P, PC, PL>(&mut self, def: D, publisher: TypedPublisher<P, PC, PL>)
    where
        D: PublishingDef + 'static,
        D::Source: SubscriptionSource<B> + Send + 'static,
        <D::Source as SubscriptionSource<B>>::Subscriber: Send + 'static,
        <<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message: 'static,
        D::Input: DeserializeOwned + Send + Sync + 'static,
        D::Reply: Serialize + Send + Sync + 'static,
        P: Publisher + 'static,
        PC: Codec + Clone + 'static,
        PL: PublishLayer + 'static,
        L: Layer<PublishingHandler<D, PC, P, PC, PL>>,
        L::Handler: Handler<<<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message>
            + 'static,
    {
        let handler_codec = publisher.codec().clone();
        let source = def.source();
        self.include_publishing_on(source, def, handler_codec, publisher);
    }
}
impl<B: Broker + 'static, L, C: Codec + Clone + 'static> BrokerScope<B, L, C> {
    /// Registers a typed subscriber definition on its own source, using a
    /// clone of the scope codec.
    pub fn include<D>(&mut self, def: D)
    where
        D: SubscriberDef,
        D::Source: SubscriptionSource<B> + Send + 'static,
        <D::Source as SubscriptionSource<B>>::Subscriber: Send + 'static,
        <<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message: 'static,
        D::Input: DeserializeOwned + Send + Sync + 'static,
        D::Handler: 'static,
        L: Layer<
            Typed<
                <<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message,
                D::Input,
                C,
                D::Handler,
            >,
        >,
        L::Handler: Handler<<<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message>
            + 'static,
    {
        let scope_codec = self.codec.clone();
        let source = def.source();
        self.include_on(source, def, scope_codec);
    }

    /// Registers a publishing definition on its own source.
    ///
    /// A clone of the scope codec is passed as the handler codec; the
    /// publisher keeps its own codec type `PC`.
    pub fn include_publishing<D, P, PC, PL>(&mut self, def: D, publisher: TypedPublisher<P, PC, PL>)
    where
        D: PublishingDef + 'static,
        D::Source: SubscriptionSource<B> + Send + 'static,
        <D::Source as SubscriptionSource<B>>::Subscriber: Send + 'static,
        <<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message: 'static,
        D::Input: DeserializeOwned + Send + Sync + 'static,
        D::Reply: Serialize + Send + Sync + 'static,
        P: Publisher + 'static,
        PC: Codec + 'static,
        PL: PublishLayer + 'static,
        L: Layer<PublishingHandler<D, C, P, PC, PL>>,
        L::Handler: Handler<<<D::Source as SubscriptionSource<B>>::Subscriber as Subscriber>::Message>
            + 'static,
    {
        let scope_codec = self.codec.clone();
        let source = def.source();
        self.include_publishing_on(source, def, scope_codec, publisher);
    }
}
impl<B: Broker + 'static, L, SC> BrokerScope<B, L, SC> {
    /// Registers a typed subscriber definition on an explicit `source` with an
    /// explicit `codec`.
    ///
    /// The handler metadata is named after the source and typed by
    /// `D::Input`; the definition's description and input schema are attached
    /// when it provides them.
    pub fn include_on<S, D, C>(&mut self, source: S, def: D, codec: C)
    where
        S: SubscriptionSource<B> + Send + 'static,
        S::Subscriber: Send + 'static,
        <S::Subscriber as Subscriber>::Message: 'static,
        D: SubscriberDef,
        D::Input: DeserializeOwned + Send + Sync + 'static,
        D::Handler: 'static,
        C: Codec + 'static,
        L: Layer<Typed<<S::Subscriber as Subscriber>::Message, D::Input, C, D::Handler>>,
        L::Handler: Handler<<S::Subscriber as Subscriber>::Message> + 'static,
    {
        let meta = HandlerMetadata::typed::<D::Input>(source.name().to_owned());
        let meta = match def.description() {
            Some(text) => meta.with_description(text.to_owned()),
            None => meta,
        };
        let meta = match def.input_schema() {
            Some(schema) => meta.with_payload_schema(schema),
            None => meta,
        };
        let typed_handler = typed(codec, def.into_handler());
        self.subscribe(source, typed_handler, meta);
    }

    /// Registers a publishing definition on an explicit `source` with an
    /// explicit handler `codec`.
    ///
    /// The metadata additionally records the type name of `D::Reply` as the
    /// output type. The handler gets a clone of this scope's publish pipeline.
    pub fn include_publishing_on<S, D, C, P, PC, PL>(
        &mut self,
        source: S,
        def: D,
        codec: C,
        publisher: TypedPublisher<P, PC, PL>,
    ) where
        S: SubscriptionSource<B> + Send + 'static,
        S::Subscriber: Send + 'static,
        <S::Subscriber as Subscriber>::Message: 'static,
        D: PublishingDef + 'static,
        D::Input: DeserializeOwned + Send + Sync + 'static,
        D::Reply: Serialize + Send + Sync + 'static,
        C: Codec + 'static,
        P: Publisher + 'static,
        PC: Codec + 'static,
        PL: PublishLayer + 'static,
        L: Layer<PublishingHandler<D, C, P, PC, PL>>,
        L::Handler: Handler<<S::Subscriber as Subscriber>::Message> + 'static,
    {
        // Copy the description out first; `def` is moved into the handler below.
        let description = def.description().map(str::to_owned);
        let meta = HandlerMetadata::typed::<D::Input>(source.name().to_owned())
            .with_output_type(std::any::type_name::<D::Reply>());
        let meta = match description {
            Some(text) => meta.with_description(text),
            None => meta,
        };
        let meta = match def.input_schema() {
            Some(schema) => meta.with_payload_schema(schema),
            None => meta,
        };
        let pipeline = self.pipeline.clone();
        let publishing_handler = PublishingHandler {
            def,
            codec,
            publisher,
            pipeline,
        };
        self.subscribe(source, publishing_handler, meta);
    }
}
impl<B, L, C> std::fmt::Debug for BrokerScope<B, L, C> {
    /// Prints only the router; every other field is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BrokerScope");
        out.field("router", &self.router);
        out.finish_non_exhaustive()
    }
}
/// Waits for the dispatch tasks behind `handles` to finish.
///
/// * `timeout == None`: every handle is awaited, even after one of them
///   reports a failure, so no task is left detached. The first `JoinError`
///   seen is returned as [`RustStreamError::Join`].
/// * `timeout == Some(_)`: all handles are awaited together for at most that
///   long. If the deadline passes, the tasks are aborted and `Ok(())` is
///   returned. Join failures of tasks that finished in time are logged at
///   `warn` level but not returned.
///
/// NOTE(review): the timeout path stays best-effort (always `Ok`), as it was
/// before; confirm whether join failures there should be returned like they
/// are in the no-timeout path.
///
/// # Errors
///
/// Returns [`RustStreamError::Join`] only on the no-timeout path, when
/// awaiting a handle yields a `JoinError`.
async fn drain_handles(
    handles: Vec<JoinHandle<()>>,
    timeout: Option<Duration>,
) -> Result<(), RustStreamError> {
    let Some(timeout) = timeout else {
        let mut first_failure = None;
        for handle in handles {
            if let Err(err) = handle.await {
                if first_failure.is_none() {
                    first_failure = Some(err);
                }
            }
        }
        return match first_failure {
            Some(err) => Err(RustStreamError::Join(err)),
            None => Ok(()),
        };
    };
    // `join_all` consumes the handles, so take abort handles up front to be
    // able to stop the tasks once the deadline has passed.
    let aborts: Vec<_> = handles.iter().map(JoinHandle::abort_handle).collect();
    match tokio::time::timeout(timeout, futures::future::join_all(handles)).await {
        Ok(results) => {
            for err in results.into_iter().filter_map(Result::err) {
                warn!(
                    target: "ruststream::lifecycle",
                    error = %err,
                    "dispatch task failed while draining",
                );
            }
        }
        Err(_elapsed) => {
            warn!(
                target: "ruststream::lifecycle",
                "graceful shutdown timed out; aborting in-flight handlers",
            );
            for abort in aborts {
                abort.abort();
            }
        }
    }
    Ok(())
}
/// Resolves when the process is asked to stop.
///
/// On Unix this is Ctrl-C or the `terminate` signal, whichever comes first;
/// if the `terminate` listener cannot be installed, only Ctrl-C is awaited.
/// On other platforms only Ctrl-C is awaited. An error from the Ctrl-C
/// listener is ignored, so the future also resolves in that case.
async fn wait_for_signal() {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{SignalKind, signal};
        match signal(SignalKind::terminate()) {
            Ok(mut terminate) => {
                tokio::select! {
                    _ = tokio::signal::ctrl_c() => {}
                    _ = terminate.recv() => {}
                }
            }
            Err(_) => {
                let _ = tokio::signal::ctrl_c().await;
            }
        }
    }
    #[cfg(not(unix))]
    {
        let _ = tokio::signal::ctrl_c().await;
    }
}