use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::OutgoingMessage;
use crate::runtime::{
BrokerLifecycle, ErrorShutdown, LifecycleHook, RegisteredBroker, RustStream, RustStreamError,
Starter, TestParts,
};
use super::assertions::{PublishedAssertions, SubscriberAssertions};
use super::broker::{TestableBroker, TestableRegistration};
use super::coordinator::Coordinator;
/// Step budget handed to [`Coordinator::new`] for every harness instance.
///
/// Presumably this is the delivery limit behind [`TestError::NotQuiescent`] — confirm in
/// the coordinator.
const DEFAULT_MAX_STEPS: usize = 10 * 1_000;
/// Failures surfaced by the in-process test harness.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TestError {
    /// The state initializer or an after-startup hook returned an error.
    #[error("startup hook failed: {0}")]
    Startup(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// A subscriber could not be started.
    #[error("subscription failed: {0}")]
    Subscribe(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// The reaction to an injected message did not quiesce; `processed` counts the
    /// deliveries dispatched before giving up (presumably raised by the coordinator).
    #[error("the reaction did not settle within {processed} dispatched deliveries")]
    NotQuiescent { processed: usize },

    /// A message was published after the service's cancellation token fired.
    #[error("publish after the service shut down")]
    ShutDown,

    /// An app-level publish could not pick a broker because the app does not have
    /// exactly one registered (this includes having none).
    #[error("more than one broker is registered; address one with broker::<B>() or broker_named()")]
    Ambiguous,

    /// The addressed broker (identified by its display label) has no in-process
    /// test transport to inject into.
    #[error("broker {0} has no in-process test transport")]
    NoTransport(String),

    /// The default codec failed to encode the value being published.
    #[error("failed to encode the message: {0}")]
    Encode(String),
}
/// A registered broker, its optional label, and the test registration (if any)
/// that recognised its concrete type.
struct BrokerEntry {
    /// Label given at registration time, if any.
    label: Option<String>,
    /// The broker itself, type-erased.
    lifecycle: Arc<dyn BrokerLifecycle>,
    /// Registration able to resolve `lifecycle` into a test transport.
    registration: Option<&'static TestableRegistration>,
}

impl BrokerEntry {
    /// Resolves the broker's in-process test transport, or `None` when no
    /// registration recognised it.
    fn testable(&self) -> Option<&dyn TestableBroker> {
        let registration = self.registration?;
        registration.resolve(self.lifecycle.as_any())
    }

    /// Human-readable name: the label when present, otherwise the broker's own name.
    fn display(&self) -> String {
        match &self.label {
            Some(label) => label.clone(),
            None => self.lifecycle.name().to_owned(),
        }
    }
}
/// Finds the first inventory-submitted [`TestableRegistration`] that can resolve
/// `lifecycle`'s concrete broker type and installs `coordinator` into that broker.
///
/// Returns the matching registration, or `None` when no registration recognises
/// the broker (it then has no in-process test transport).
fn recover_testable(
    lifecycle: &Arc<dyn BrokerLifecycle>,
    coordinator: &Coordinator,
) -> Option<&'static TestableRegistration> {
    let erased = lifecycle.as_any();
    let (registration, broker) = inventory::iter::<TestableRegistration>
        .into_iter()
        .find_map(|candidate| candidate.resolve(erased).map(|broker| (candidate, broker)))?;
    broker.install_coordinator(coordinator.clone());
    Some(registration)
}
/// Read-only view of the registered brokers, handed to the state builder of
/// [`TestApp::with_state`] before any subscriber is started.
pub struct TestBrokers<'a> {
    entries: &'a [BrokerEntry],
}

impl std::fmt::Debug for TestBrokers<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TestBrokers");
        out.field("brokers", &self.entries.len());
        out.finish_non_exhaustive()
    }
}

impl TestBrokers<'_> {
    /// Returns the single registered broker whose concrete type is `B`.
    ///
    /// # Panics
    ///
    /// Panics if no broker of type `B` is registered, or if more than one is.
    #[must_use]
    pub fn broker<B: crate::Broker + 'static>(&self) -> &B {
        let type_name = std::any::type_name::<B>();
        let mut of_type = self
            .entries
            .iter()
            .filter_map(|entry| entry.lifecycle.as_any().downcast_ref::<B>());
        match (of_type.next(), of_type.next()) {
            (Some(broker), None) => broker,
            (None, _) => panic!("no registered broker of type {type_name}"),
            (Some(_), Some(_)) => {
                panic!("more than one broker of type {type_name} is registered")
            }
        }
    }
}
/// A service running in-process under the test harness.
///
/// Created by [`TestApp::start`] or [`TestApp::with_state`]; torn down by
/// [`TestApp::shutdown`].
pub struct TestApp<St> {
    /// Registered brokers, in registration order (the index is a broker's scope id).
    entries: Vec<BrokerEntry>,
    /// Coordinator shared with the test hooks and every testable broker.
    coordinator: Coordinator,
    // Never read by the harness; it is kept so the app holds its own reference to
    // the shared state for as long as it lives (subscribers and hooks get clones).
    #[allow(dead_code)]
    state: Arc<St>,
    /// Records the failure (if any) that triggered an error shutdown.
    error_shutdown: ErrorShutdown,
    /// Cancellation token handed to every subscriber.
    token: CancellationToken,
    /// One handle per started subscriber.
    handles: Vec<JoinHandle<()>>,
    /// Tracker for continuation tasks, drained on shutdown.
    continuations: TaskTracker,
    /// Grace period for subscribers during shutdown; `None` waits indefinitely.
    shutdown_timeout: Option<Duration>,
}

impl<St> std::fmt::Debug for TestApp<St> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TestApp");
        out.field("brokers", &self.entries.len());
        out.field("subscribers", &self.handles.len());
        out.finish_non_exhaustive()
    }
}
impl<St: Send + Sync + 'static> TestApp<St> {
    /// Starts `app` in-process, building its shared state with the app's own
    /// state initializer.
    ///
    /// Brokers that expose an in-process test transport are wired to a fresh
    /// [`Coordinator`]. Then every subscriber is started and the after-startup
    /// hooks run, in order.
    ///
    /// # Errors
    ///
    /// - [`TestError::Startup`] if the state initializer or an after-startup hook fails.
    /// - [`TestError::Subscribe`] if a subscriber fails to start.
    ///
    /// If startup fails after some subscribers were started, the cancellation
    /// token handed to them is cancelled.
    pub async fn start<L>(app: RustStream<L, St>) -> Result<Self, TestError> {
        let (coordinator, entries, parts) = Self::setup(app);
        let TestParts {
            starters,
            state_init,
            after_startup,
            shutdown_timeout,
            continuations,
            ..
        } = parts;
        let state = state_init().await.map_err(TestError::Startup)?;
        Self::spawn(SpawnArgs {
            coordinator,
            entries,
            starters,
            after_startup,
            continuations,
            shutdown_timeout,
            state: Arc::new(state),
        })
        .await
    }

    /// Starts `app` in-process with state produced by `build`.
    ///
    /// `build` can inspect the registered brokers through [`TestBrokers`]. The app's
    /// own state initializer is not run.
    ///
    /// # Errors
    ///
    /// - [`TestError::Subscribe`] if a subscriber fails to start.
    /// - [`TestError::Startup`] if an after-startup hook fails.
    ///
    /// Subscribers started before the failure have their cancellation token
    /// cancelled.
    pub async fn with_state<L, F>(app: RustStream<L, St>, build: F) -> Result<Self, TestError>
    where
        F: FnOnce(&TestBrokers<'_>) -> St,
    {
        let (coordinator, entries, parts) = Self::setup(app);
        let TestParts {
            starters,
            after_startup,
            shutdown_timeout,
            continuations,
            ..
        } = parts;
        let state = build(&TestBrokers { entries: &entries });
        Self::spawn(SpawnArgs {
            coordinator,
            entries,
            starters,
            after_startup,
            continuations,
            shutdown_timeout,
            state: Arc::new(state),
        })
        .await
    }

    /// Splits `app` into its test parts and installs a fresh coordinator into the
    /// test hooks. It also pairs each registered broker with the testable
    /// registration that recognises it, if any.
    fn setup<L>(app: RustStream<L, St>) -> (Coordinator, Vec<BrokerEntry>, TestParts<St>) {
        let mut parts = app.into_test_parts();
        let coordinator = Coordinator::new(DEFAULT_MAX_STEPS);
        parts.test_hooks.install(coordinator.clone());
        let entries = std::mem::take(&mut parts.brokers)
            .into_iter()
            .map(|RegisteredBroker { lifecycle, label }| {
                let registration = recover_testable(&lifecycle, &coordinator);
                BrokerEntry {
                    label,
                    lifecycle,
                    registration,
                }
            })
            .collect();
        (coordinator, entries, parts)
    }

    /// Starts every subscriber, runs the after-startup hooks, and assembles the app.
    async fn spawn(args: SpawnArgs<St>) -> Result<Self, TestError> {
        let SpawnArgs {
            coordinator,
            entries,
            starters,
            after_startup,
            continuations,
            shutdown_timeout,
            state,
        } = args;
        let token = CancellationToken::new();
        let error_shutdown = ErrorShutdown::new(token.clone());
        // Until startup completes, dropping this guard cancels the token. That
        // happens on an early `?` return, or when the caller drops this future.
        // Subscribers already started (each was handed this token) are then told to
        // stop, instead of running on detached once their handles are dropped.
        let startup_guard = token.clone().drop_guard();
        let mut handles = Vec::with_capacity(starters.len());
        for starter in starters {
            let handle = starter(state.clone(), error_shutdown.clone(), token.clone())
                .await
                .map_err(TestError::Subscribe)?;
            handles.push(handle);
        }
        for hook in after_startup {
            hook(state.clone()).await.map_err(TestError::Startup)?;
        }
        // Startup succeeded: keep the token live for the returned app.
        let _ = startup_guard.disarm();
        Ok(Self {
            entries,
            coordinator,
            state,
            error_shutdown,
            token,
            handles,
            continuations,
            shutdown_timeout,
        })
    }

    /// Returns a handle to the single registered broker of concrete type `B`.
    ///
    /// # Panics
    ///
    /// Panics if no broker of type `B` is registered, or if more than one is (use
    /// [`broker_named`](Self::broker_named) in that case).
    #[must_use]
    pub fn broker<B: crate::Broker + 'static>(&self) -> BrokerHandle<'_> {
        let mut matches = self
            .entries
            .iter()
            .filter(|e| e.lifecycle.as_any().downcast_ref::<B>().is_some());
        let first = matches.next().unwrap_or_else(|| {
            panic!(
                "no registered broker of type {}",
                std::any::type_name::<B>()
            )
        });
        assert!(
            matches.next().is_none(),
            "more than one broker of type {} is registered; address one with broker_named(label)",
            std::any::type_name::<B>(),
        );
        self.handle(first)
    }

    /// Returns a handle to the broker registered under `label`.
    ///
    /// If several brokers share the label, the first registered one is used.
    ///
    /// # Panics
    ///
    /// Panics if no broker carries `label`.
    #[must_use]
    pub fn broker_named(&self, label: &str) -> BrokerHandle<'_> {
        let entry = self
            .entries
            .iter()
            .find(|e| e.label.as_deref() == Some(label))
            .unwrap_or_else(|| panic!("no broker labeled {label:?}"));
        self.handle(entry)
    }

    /// Builds a [`BrokerHandle`] for `entry`, which must be an element of
    /// `self.entries`. Its index doubles as the handle's subscriber scope id.
    fn handle<'a>(&'a self, entry: &'a BrokerEntry) -> BrokerHandle<'a> {
        let scope_id = self
            .entries
            .iter()
            .position(|e| std::ptr::eq(e, entry))
            .expect("entry belongs to this app");
        BrokerHandle {
            scope_id,
            coordinator: &self.coordinator,
            testable: entry.testable(),
            token: &self.token,
            label: entry.display(),
        }
    }

    /// Encodes `value` with the default codec and publishes it as `name` on the
    /// app's only broker, then settles the reaction.
    ///
    /// # Errors
    ///
    /// - [`TestError::Ambiguous`] unless exactly one broker is registered (including
    ///   when there are none).
    /// - Otherwise, any error from [`BrokerHandle::publish`].
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    pub async fn publish<T: serde::Serialize + Sync>(
        &self,
        name: &str,
        value: &T,
    ) -> Result<(), TestError> {
        match self.entries.as_slice() {
            [only] => self.handle(only).publish(name, value).await,
            _ => Err(TestError::Ambiguous),
        }
    }

    /// Drives the coordinator until the current reaction settles.
    ///
    /// # Errors
    ///
    /// Propagates the coordinator's error. Presumably this is
    /// [`TestError::NotQuiescent`] when the step budget runs out.
    pub async fn settle(&self) -> Result<(), TestError> {
        self.coordinator.drive().await
    }

    /// Advances the Tokio clock by `by` and fires the timers that became due. It then
    /// drives the coordinator like [`settle`](Self::settle).
    ///
    /// # Errors
    ///
    /// Propagates the coordinator's error, as [`settle`](Self::settle) does.
    ///
    /// # Panics
    ///
    /// `tokio::time::advance` panics if the clock is not paused.
    pub async fn advance(&self, by: Duration) -> Result<(), TestError> {
        tokio::time::advance(by).await;
        self.coordinator.fire_due_timers().await;
        self.coordinator.drive().await
    }

    /// Yields to the runtime until no continuation tasks remain tracked.
    ///
    /// This polls rather than blocks, so it relies on the runtime making progress
    /// on those tasks between yields.
    pub async fn drain(&self) {
        while !self.continuations.is_empty() {
            tokio::task::yield_now().await;
        }
    }

    /// Returns the dispatch failure recorded so far, if any, as an `Err`.
    ///
    /// The failure is only peeked, not taken.
    pub fn run_result(&self) -> Result<(), RustStreamError> {
        self.error_shutdown
            .peek_failure()
            .map_or(Ok(()), |reason| Err(RustStreamError::Dispatch(reason)))
    }

    /// Asserts that the service's cancellation token has not fired.
    ///
    /// # Panics
    ///
    /// Panics, including any recorded failure in the message, if the service was
    /// shut down.
    pub fn assert_running(&self) {
        assert!(
            !self.token.is_cancelled(),
            "expected the service to be running, but it was shut down: {:?}",
            self.error_shutdown.peek_failure(),
        );
    }

    /// Asserts that the service's cancellation token has fired.
    ///
    /// # Panics
    ///
    /// Panics if the service is still running.
    pub fn assert_shut_down(&self) {
        assert!(
            self.token.is_cancelled(),
            "expected the service to be shut down, but it was still running",
        );
    }

    /// Cancels the service, waits for the subscribers and continuation tasks, and
    /// reports the recorded dispatch failure, if any.
    ///
    /// With a `shutdown_timeout`, one deadline bounds the wait for all subscribers
    /// together. Subscribers still running at the deadline are aborted.
    ///
    /// # Errors
    ///
    /// `RustStreamError::Dispatch` with the failure that triggered an error shutdown.
    pub async fn shutdown(self) -> Result<(), RustStreamError> {
        self.token.cancel();
        // A shared deadline keeps the total grace period at `timeout`. Timing out
        // each handle in turn would let `n` slow subscribers stretch it to
        // `n * timeout`. A deadline that overflows the clock means "no limit".
        let deadline = self
            .shutdown_timeout
            .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout));
        for mut handle in self.handles {
            match deadline {
                Some(deadline) => {
                    // Past the grace period: abort rather than leave the task detached.
                    if tokio::time::timeout_at(deadline, &mut handle).await.is_err() {
                        handle.abort();
                    }
                }
                None => {
                    let _ = handle.await;
                }
            }
        }
        self.continuations.close();
        self.continuations.wait().await;
        self.error_shutdown
            .taken_failure()
            .map_or(Ok(()), |reason| Err(RustStreamError::Dispatch(reason)))
    }
}
/// Inputs for `TestApp::spawn`, bundled so its signature stays short.
struct SpawnArgs<St> {
    /// Coordinator already installed into the test hooks and testable brokers.
    coordinator: Coordinator,
    /// Registered brokers, in registration order.
    entries: Vec<BrokerEntry>,
    /// Subscriber starters, invoked in order.
    starters: Vec<Starter<St>>,
    /// Hooks run after every subscriber has started.
    after_startup: Vec<LifecycleHook<St>>,
    /// Tracker for continuation tasks.
    continuations: TaskTracker,
    /// Grace period used by shutdown.
    shutdown_timeout: Option<Duration>,
    /// Shared state handed to starters and hooks.
    state: Arc<St>,
}
/// A view of one registered broker inside a [`TestApp`], used to inject messages
/// and to assert on what was delivered or published.
pub struct BrokerHandle<'a> {
    /// Scope id the subscriber assertions are keyed by.
    scope_id: usize,
    /// The app's coordinator, driven after every injection.
    coordinator: &'a Coordinator,
    /// The broker's in-process transport; `None` when it has none.
    testable: Option<&'a dyn TestableBroker>,
    /// The app's cancellation token, checked before publishing.
    token: &'a CancellationToken,
    /// Display name of the broker, used in errors and `Debug` output.
    label: String,
}

impl std::fmt::Debug for BrokerHandle<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_transport = self.testable.is_some();
        let mut out = f.debug_struct("BrokerHandle");
        out.field("broker", &self.label);
        out.field("testable", &has_transport);
        out.finish_non_exhaustive()
    }
}
impl BrokerHandle<'_> {
    /// Encodes `value` with the default codec and publishes it as `name` through
    /// [`publish_raw`](Self::publish_raw).
    ///
    /// # Errors
    ///
    /// - [`TestError::Encode`] if encoding fails.
    /// - Otherwise, any error from [`publish_raw`](Self::publish_raw).
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    pub async fn publish<T: serde::Serialize + Sync>(
        &self,
        name: &str,
        value: &T,
    ) -> Result<(), TestError> {
        use crate::codec::Codec;
        // The codec is a temporary of the `match`, so it is dropped before the await.
        let encoded = match crate::codec::DefaultCodec::default().encode(value) {
            Ok(encoded) => encoded,
            Err(err) => return Err(TestError::Encode(err.to_string())),
        };
        self.publish_raw(name, &encoded).await
    }

    /// Injects `payload` as message `name` into this broker's test transport, then
    /// drives the coordinator until the reaction settles.
    ///
    /// # Errors
    ///
    /// - [`TestError::ShutDown`] if the service's token has been cancelled.
    /// - [`TestError::NoTransport`] if the broker has no in-process transport.
    /// - Otherwise, the coordinator's error.
    pub async fn publish_raw(&self, name: &str, payload: &[u8]) -> Result<(), TestError> {
        if self.token.is_cancelled() {
            return Err(TestError::ShutDown);
        }
        match self.testable {
            Some(transport) => {
                transport.inject(OutgoingMessage::new(name, payload));
                self.coordinator.drive().await
            }
            None => Err(TestError::NoTransport(self.label.clone())),
        }
    }

    /// Assertions about subscriber `name`, scoped to this broker.
    #[must_use]
    pub fn subscriber(&self, name: &str) -> SubscriberAssertions<'_> {
        let subscriber = String::from(name);
        SubscriberAssertions::new(self.coordinator, self.scope_id, subscriber)
    }

    /// Assertions about the messages this broker's transport reports as published
    /// under `name`. These are empty when the broker has no test transport.
    #[must_use]
    pub fn published<T>(&self, name: &str) -> PublishedAssertions<T> {
        let messages = match self.testable {
            Some(transport) => transport.published(name),
            None => Default::default(),
        };
        PublishedAssertions::new(name.to_owned(), messages)
    }
}