use std::{
collections::BTreeMap, error::Error as StdError, future::Future, sync::Arc, time::Duration,
};
use crate::codec::Codec;
use crate::{Broker, DescribeServer, ServerSpec};
use tokio_util::task::TaskTracker;
use crate::runtime::dispatch::Delivery;
use crate::runtime::lifecycle::{BoxError, BrokerLifecycle};
use crate::runtime::metadata::HandlerMetadata;
use crate::runtime::middleware::{Identity, Stack};
use crate::runtime::publish::{PublishIdentity, PublishLayer, PublishStack};
use crate::runtime::router::RouterSink;
#[cfg(feature = "testing")]
use crate::testing::coordinator::TestHooks;
use super::scope::BrokerScope;
use super::{AppInfo, LifecycleHook, LifecyclePhase, Starter, StateInit};
/// Application builder: collects brokers, their handlers, server specs, middleware and
/// lifecycle hooks.
///
/// Type parameters:
/// - `L`: global handler middleware stack (starts as `Identity`, grown by `layer`).
/// - `St`: application state type (starts as `()`, changed by `on_startup`).
/// - `PP`: publish middleware pipeline (starts as `PublishIdentity`, grown by
///   `publish_layer`).
pub struct RustStream<L = Identity, St = (), PP = PublishIdentity> {
    /// Application metadata, exposed through `info()`.
    pub(super) info: AppInfo,
    /// Every registered broker, in registration order (with an optional label).
    pub(super) brokers: Vec<RegisteredBroker>,
    /// One starter closure per registered handler, built in `collect_scope`.
    pub(super) starters: Vec<Starter<St>>,
    /// Handler metadata; pushed in lockstep with `starters` in `collect_scope`.
    pub(super) handlers: Vec<HandlerMetadata>,
    /// Server specs keyed by name (from `server()` and labeled broker registration).
    pub(super) servers: BTreeMap<String, ServerSpec>,
    /// Publish middleware; cloned into each broker scope when it is created.
    pub(super) publish_pipeline: PP,
    /// Produces the application state; chained by `on_startup`.
    pub(super) state_init: StateInit<St>,
    /// Hooks registered for `LifecyclePhase::AfterStartup`.
    pub(super) after_startup: Vec<LifecycleHook<St>>,
    /// Hooks registered for `LifecyclePhase::OnShutdown`.
    pub(super) on_shutdown: Vec<LifecycleHook<St>>,
    /// Hooks registered for `LifecyclePhase::AfterShutdown`.
    pub(super) after_shutdown: Vec<LifecycleHook<St>>,
    /// Optional shutdown timeout set via `shutdown_timeout()`; `None` by default.
    pub(super) shutdown_timeout: Option<Duration>,
    /// Task tracker shared with every broker's `Delivery`.
    pub(super) continuations: TaskTracker,
    /// Test coordination hooks (only with the `testing` feature); handed to each
    /// broker's `Delivery` in `collect_scope`.
    #[cfg(feature = "testing")]
    pub(super) test_hooks: Arc<TestHooks>,
    /// Global handler middleware; cloned into each broker scope when it is created.
    pub(super) global: L,
}
/// A broker known to the application, type-erased to its lifecycle interface.
pub(crate) struct RegisteredBroker {
    /// The broker itself, viewed through [`BrokerLifecycle`].
    pub(crate) lifecycle: Arc<dyn BrokerLifecycle>,
    /// Label given via `with_broker_labeled*`; `None` for unlabeled registrations.
    pub(crate) label: Option<String>,
}
/// Subset of the builder moved out by `RustStream::into_test_parts` in `testing` builds.
///
/// There are no fields for `on_shutdown` / `after_shutdown` hooks; those are not carried
/// over.
#[cfg(feature = "testing")]
pub(crate) struct TestParts<St> {
    /// Registered brokers, in registration order.
    pub(crate) brokers: Vec<RegisteredBroker>,
    /// One starter per registered handler.
    pub(crate) starters: Vec<Starter<St>>,
    /// State initializer (including any `on_startup` chain).
    pub(crate) state_init: StateInit<St>,
    /// Hooks registered for the after-startup phase.
    pub(crate) after_startup: Vec<LifecycleHook<St>>,
    /// Optional shutdown timeout.
    pub(crate) shutdown_timeout: Option<Duration>,
    /// Task tracker shared with the brokers' deliveries.
    pub(crate) continuations: TaskTracker,
    /// Test coordination hooks shared with the brokers' deliveries.
    pub(crate) test_hooks: Arc<TestHooks>,
}
/// Summarizing `Debug` output.
///
/// The generic parameters carry no `Debug` bound, so only the app info and the
/// broker / handler counts are printed; the remaining fields are elided.
impl<L, St, PP> std::fmt::Debug for RustStream<L, St, PP> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let broker_count = self.brokers.len();
        let handler_count = self.handlers.len();
        let mut out = f.debug_struct("RustStream");
        out.field("info", &self.info);
        out.field("brokers", &broker_count);
        out.field("handlers", &handler_count);
        out.finish_non_exhaustive()
    }
}
impl RustStream<Identity, (), PublishIdentity> {
    /// Creates an empty application builder described by `info`.
    ///
    /// The result has:
    /// - no brokers, handlers, servers or lifecycle hooks;
    /// - identity handler and publish middleware;
    /// - no shutdown timeout;
    /// - a state initializer that resolves to `Ok(())`.
    #[must_use]
    pub fn new(info: AppInfo) -> Self {
        Self {
            info,
            brokers: Vec::default(),
            starters: Vec::default(),
            handlers: Vec::default(),
            servers: BTreeMap::default(),
            publish_pipeline: PublishIdentity,
            state_init: Box::new(|| Box::pin(async { Ok(()) })),
            after_startup: Vec::default(),
            on_shutdown: Vec::default(),
            after_shutdown: Vec::default(),
            shutdown_timeout: None,
            continuations: TaskTracker::new(),
            #[cfg(feature = "testing")]
            test_hooks: Arc::new(TestHooks::detached()),
            global: Identity,
        }
    }
}
impl<L, St, PP> RustStream<L, St, PP> {
    /// Wraps the global handler middleware in `layer`.
    ///
    /// The resulting stack is `Stack::new(layer, previous)`. The global stack is
    /// cloned into a broker scope when that scope is created, so the new layer only
    /// reaches handlers registered by `with_broker*` calls made after this one.
    #[must_use]
    pub fn layer<N>(self, layer: N) -> RustStream<Stack<N, L>, St, PP> {
        RustStream {
            info: self.info,
            brokers: self.brokers,
            starters: self.starters,
            handlers: self.handlers,
            servers: self.servers,
            publish_pipeline: self.publish_pipeline,
            state_init: self.state_init,
            after_startup: self.after_startup,
            on_shutdown: self.on_shutdown,
            after_shutdown: self.after_shutdown,
            shutdown_timeout: self.shutdown_timeout,
            continuations: self.continuations,
            #[cfg(feature = "testing")]
            test_hooks: self.test_hooks,
            global: Stack::new(layer, self.global),
        }
    }

    /// Chains a startup hook that turns the current state `St` into a new state `St2`.
    ///
    /// The hook runs on the output of the previously configured state initializer. An
    /// error from either step is boxed into a [`BoxError`].
    ///
    /// Handler starters and lifecycle hooks are typed against the *old* state `St`
    /// and cannot be converted to `St2`. For that reason this must be called before
    /// any `with_broker*` call and before `after_startup` / `on_shutdown` /
    /// `after_shutdown`.
    ///
    /// # Panics
    ///
    /// Panics if a handler starter or lifecycle hook has already been registered.
    /// Discarding them silently would leave handler metadata for handlers that never
    /// run.
    #[must_use]
    pub fn on_startup<F, Fut, St2, E>(self, hook: F) -> RustStream<L, St2, PP>
    where
        F: FnOnce(St) -> Fut + Send + 'static,
        Fut: Future<Output = Result<St2, E>> + Send,
        St: Send + 'static,
        St2: Send + Sync + 'static,
        E: StdError + Send + Sync + 'static,
    {
        assert!(
            self.starters.is_empty()
                && self.after_startup.is_empty()
                && self.on_shutdown.is_empty()
                && self.after_shutdown.is_empty(),
            "RustStream::on_startup must be called before registering broker handlers or \
             lifecycle hooks; they are typed against the previous state and would be dropped"
        );
        let prev = self.state_init;
        RustStream {
            info: self.info,
            brokers: self.brokers,
            // Empty by the assertion above; the old `Starter<St>` type cannot be reused.
            starters: Vec::new(),
            handlers: self.handlers,
            servers: self.servers,
            publish_pipeline: self.publish_pipeline,
            state_init: Box::new(move || {
                Box::pin(async move {
                    let prev_state = prev().await?;
                    hook(prev_state).await.map_err(|e| Box::new(e) as BoxError)
                })
            }),
            after_startup: Vec::new(),
            on_shutdown: Vec::new(),
            after_shutdown: Vec::new(),
            shutdown_timeout: self.shutdown_timeout,
            continuations: self.continuations,
            #[cfg(feature = "testing")]
            test_hooks: self.test_hooks,
            global: self.global,
        }
    }

    /// Registers a hook for [`LifecyclePhase::AfterStartup`].
    ///
    /// The hook receives the shared application state. Its error is boxed into a
    /// [`BoxError`].
    #[must_use]
    pub fn after_startup<F, Fut, E>(self, hook: F) -> Self
    where
        St: Send + Sync + 'static,
        F: FnOnce(Arc<St>) -> Fut + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send,
        E: StdError + Send + Sync + 'static,
    {
        self.push_lifecycle_hook(LifecyclePhase::AfterStartup, hook)
    }

    /// Registers a hook for [`LifecyclePhase::OnShutdown`].
    ///
    /// The hook receives the shared application state. Its error is boxed into a
    /// [`BoxError`].
    #[must_use]
    pub fn on_shutdown<F, Fut, E>(self, hook: F) -> Self
    where
        St: Send + Sync + 'static,
        F: FnOnce(Arc<St>) -> Fut + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send,
        E: StdError + Send + Sync + 'static,
    {
        self.push_lifecycle_hook(LifecyclePhase::OnShutdown, hook)
    }

    /// Registers a hook for [`LifecyclePhase::AfterShutdown`].
    ///
    /// The hook receives the shared application state. Its error is boxed into a
    /// [`BoxError`].
    #[must_use]
    pub fn after_shutdown<F, Fut, E>(self, hook: F) -> Self
    where
        St: Send + Sync + 'static,
        F: FnOnce(Arc<St>) -> Fut + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send,
        E: StdError + Send + Sync + 'static,
    {
        self.push_lifecycle_hook(LifecyclePhase::AfterShutdown, hook)
    }

    /// Boxes `hook` into a [`LifecycleHook`] and appends it to the list for `phase`.
    fn push_lifecycle_hook<F, Fut, E>(mut self, phase: LifecyclePhase, hook: F) -> Self
    where
        St: Send + Sync + 'static,
        F: FnOnce(Arc<St>) -> Fut + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send,
        E: StdError + Send + Sync + 'static,
    {
        let boxed: LifecycleHook<St> = Box::new(move |state| {
            Box::pin(async move { hook(state).await.map_err(|e| Box::new(e) as BoxError) })
        });
        match phase {
            LifecyclePhase::AfterStartup => self.after_startup.push(boxed),
            LifecyclePhase::OnShutdown => self.on_shutdown.push(boxed),
            LifecyclePhase::AfterShutdown => self.after_shutdown.push(boxed),
        }
        self
    }

    /// Sets the shutdown timeout (unset by default).
    ///
    /// How the timeout is enforced is decided by the code that consumes this builder.
    #[must_use]
    pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
        self.shutdown_timeout = Some(timeout);
        self
    }

    /// Wraps the publish middleware pipeline in `middleware`.
    ///
    /// The resulting pipeline is `PublishStack::new(middleware, previous)`. The
    /// pipeline is cloned into a broker scope when that scope is created, so the new
    /// middleware only reaches scopes opened by `with_broker*` calls made after this
    /// one.
    #[must_use]
    pub fn publish_layer<M>(self, middleware: M) -> RustStream<L, St, PublishStack<M, PP>>
    where
        M: PublishLayer + Clone + 'static,
    {
        RustStream {
            info: self.info,
            brokers: self.brokers,
            starters: self.starters,
            handlers: self.handlers,
            servers: self.servers,
            publish_pipeline: PublishStack::new(middleware, self.publish_pipeline),
            state_init: self.state_init,
            after_startup: self.after_startup,
            on_shutdown: self.on_shutdown,
            after_shutdown: self.after_shutdown,
            shutdown_timeout: self.shutdown_timeout,
            continuations: self.continuations,
            #[cfg(feature = "testing")]
            test_hooks: self.test_hooks,
            global: self.global,
        }
    }

    /// Registers a broker that has no handlers and no label.
    ///
    /// It is tracked through its [`BrokerLifecycle`], exactly like brokers added via
    /// `with_broker*`.
    #[must_use]
    pub fn register_broker<B>(mut self, broker: B) -> Self
    where
        B: Broker + 'static,
    {
        self.brokers.push(RegisteredBroker {
            lifecycle: Arc::new(broker),
            label: None,
        });
        self
    }

    /// Adds the server spec stored under `name`, replacing any existing entry.
    #[must_use]
    pub fn server(mut self, name: impl Into<String>, spec: ServerSpec) -> Self {
        self.servers.insert(name.into(), spec);
        self
    }

    /// Registers `broker` and lets `build` add handlers to it, using the unit codec `()`.
    #[must_use]
    pub fn with_broker<B, F>(mut self, broker: B, build: F) -> Self
    where
        B: Broker + 'static,
        L: Clone,
        PP: Clone,
        St: Send + Sync + 'static,
        F: FnOnce(&mut BrokerScope<B, L, (), St, PP>),
    {
        self.attach_broker(broker, (), None, build);
        self
    }

    /// Registers `broker` and lets `build` add handlers to it, using `codec`.
    #[must_use]
    pub fn with_broker_codec<B, C, F>(mut self, broker: B, codec: C, build: F) -> Self
    where
        B: Broker + 'static,
        C: Codec + Clone + 'static,
        L: Clone,
        PP: Clone,
        St: Send + Sync + 'static,
        F: FnOnce(&mut BrokerScope<B, L, C, St, PP>),
    {
        self.attach_broker(broker, codec, None, build);
        self
    }

    /// Like [`Self::with_broker`], but registers the broker under `label`.
    ///
    /// `broker.describe_server()` is recorded as the server spec for `label`, unless a
    /// spec with that name already exists.
    #[must_use]
    pub fn with_broker_labeled<B, F>(
        mut self,
        label: impl Into<String>,
        broker: B,
        build: F,
    ) -> Self
    where
        B: DescribeServer + 'static,
        L: Clone,
        PP: Clone,
        St: Send + Sync + 'static,
        F: FnOnce(&mut BrokerScope<B, L, (), St, PP>),
    {
        let label = self.record_server(label, &broker);
        self.attach_broker(broker, (), Some(label), build);
        self
    }

    /// Like [`Self::with_broker_codec`], but registers the broker under `label`.
    ///
    /// `broker.describe_server()` is recorded as the server spec for `label`, unless a
    /// spec with that name already exists.
    #[must_use]
    pub fn with_broker_labeled_codec<B, C, F>(
        mut self,
        label: impl Into<String>,
        broker: B,
        codec: C,
        build: F,
    ) -> Self
    where
        B: DescribeServer + 'static,
        C: Codec + Clone + 'static,
        L: Clone,
        PP: Clone,
        St: Send + Sync + 'static,
        F: FnOnce(&mut BrokerScope<B, L, C, St, PP>),
    {
        let label = self.record_server(label, &broker);
        self.attach_broker(broker, codec, Some(label), build);
        self
    }

    /// Records `broker.describe_server()` under `label` if no spec with that name exists.
    ///
    /// An entry added earlier through `server()` wins. Returns the label as an owned
    /// `String`.
    fn record_server<B: DescribeServer>(&mut self, label: impl Into<String>, broker: &B) -> String {
        let label = label.into();
        self.servers
            .entry(label.clone())
            .or_insert_with(|| broker.describe_server());
        label
    }

    /// Shared body of the `with_broker*` methods.
    ///
    /// Opens a scope for `broker`, lets `build` populate it, then folds the scope's
    /// handlers into the application under `label`.
    fn attach_broker<B, C, F>(&mut self, broker: B, codec: C, label: Option<String>, build: F)
    where
        B: Broker + 'static,
        L: Clone,
        PP: Clone,
        St: Send + Sync + 'static,
        F: FnOnce(&mut BrokerScope<B, L, C, St, PP>),
    {
        let broker = Arc::new(broker);
        let mut scope = self.new_scope(&broker, codec);
        build(&mut scope);
        self.collect_scope(&broker, scope, label);
    }

    /// Creates a broker scope with the following contents:
    /// - clones of the *current* publish pipeline and global middleware;
    /// - an empty router sink;
    /// - no retry publisher.
    fn new_scope<B, C>(&self, broker: &Arc<B>, codec: C) -> BrokerScope<B, L, C, St, PP>
    where
        B: Broker + 'static,
        L: Clone,
        PP: Clone,
        St: Send + Sync + 'static,
    {
        BrokerScope {
            broker: broker.clone(),
            sink: RouterSink::new(),
            pipeline: self.publish_pipeline.clone(),
            retry_publisher: None,
            global: self.global.clone(),
            codec,
        }
    }

    /// Turns the handlers collected in `scope` into starters, then registers `broker`.
    ///
    /// Each starter is bound to `broker` and to one [`Delivery`] shared by all of this
    /// broker's handlers. The broker is registered after its starters have been added.
    fn collect_scope<B, C>(
        &mut self,
        broker: &Arc<B>,
        scope: BrokerScope<B, L, C, St, PP>,
        label: Option<String>,
    ) where
        B: Broker + 'static,
        St: Send + Sync + 'static,
    {
        let lifecycle: Arc<dyn BrokerLifecycle> = broker.clone();
        // `self.brokers.len()` is read before this broker is pushed below, so it is the
        // index this broker will occupy in `self.brokers`.
        #[cfg(feature = "testing")]
        let delivery = Arc::new(Delivery::instrumented(
            scope.retry_publisher.clone(),
            self.continuations.clone(),
            self.test_hooks.clone(),
            self.brokers.len(),
        ));
        #[cfg(not(feature = "testing"))]
        let delivery = Arc::new(Delivery::detached(
            scope.retry_publisher.clone(),
            self.continuations.clone(),
        ));
        let (starters, handlers) = scope.sink.into_parts();
        // `starters` and `handlers` are pushed in lockstep so the two lists stay parallel.
        for (bound, meta) in starters.into_iter().zip(handlers) {
            let broker = broker.clone();
            let delivery = delivery.clone();
            self.starters.push(Box::new(move |state, shutdown, token| {
                bound(broker, state, delivery, shutdown, token)
            }));
            self.handlers.push(meta);
        }
        self.brokers.push(RegisteredBroker { lifecycle, label });
    }

    /// Metadata for every handler registered so far, in registration order.
    #[must_use]
    pub fn handlers(&self) -> &[HandlerMetadata] {
        &self.handlers
    }

    /// The application info passed to `new`.
    #[must_use]
    pub fn info(&self) -> &AppInfo {
        &self.info
    }

    /// Server specs keyed by name, from `server()` and labeled broker registration.
    #[must_use]
    pub fn servers(&self) -> &BTreeMap<String, ServerSpec> {
        &self.servers
    }

    /// Moves the pieces needed in `testing` builds into a [`TestParts`].
    ///
    /// `on_shutdown` / `after_shutdown` hooks, server specs, handler metadata, and both
    /// middleware stacks are not part of [`TestParts`], so they are dropped here.
    #[cfg(feature = "testing")]
    pub(crate) fn into_test_parts(self) -> TestParts<St> {
        TestParts {
            brokers: self.brokers,
            starters: self.starters,
            state_init: self.state_init,
            after_startup: self.after_startup,
            shutdown_timeout: self.shutdown_timeout,
            continuations: self.continuations,
            test_hooks: self.test_hooks,
        }
    }
}