use crate::endpoints::{create_consumer_from_route, create_publisher_from_route};
use crate::errors::ProcessingError;
pub use crate::models::Route;
use crate::models::{Endpoint, EndpointType, RouteOptions};
use crate::traits::{
BatchCommitFunc, ConsumerError, Handler, HandlerError, MessageConsumer, MessageDisposition,
MessagePublisher, PublisherError, SentBatch,
};
use async_channel::{bounded, Sender};
use serde::de::DeserializeOwned;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, OnceLock, RwLock};
use tokio::{
select,
task::{JoinHandle, JoinSet},
};
use tracing::{debug, error, info, trace, warn};
pub use crate::extensions::{
get_endpoint_factory, get_middleware_factory, register_endpoint_factory,
register_middleware_factory,
};
/// Handle to a route started with [`Route::run`].
///
/// Bundles the supervisor task's [`JoinHandle`] with the sending half of the
/// route's shutdown channel.
#[derive(Debug)]
pub struct RouteHandle((JoinHandle<()>, Sender<()>));
impl RouteHandle {
    /// Requests shutdown of the route.
    ///
    /// Sends a single shutdown signal (a failed send is ignored) and then closes
    /// the channel. This does not wait for the task; use [`RouteHandle::join`].
    pub async fn stop(&self) {
        let (_, shutdown_tx) = &self.0;
        let _ = shutdown_tx.send(()).await;
        shutdown_tx.close();
    }
    /// Waits for the supervisor task to finish.
    ///
    /// # Errors
    /// Returns the task's [`tokio::task::JoinError`] if it panicked or was aborted.
    pub async fn join(self) -> Result<(), tokio::task::JoinError> {
        // Keep the sender bound until the task completes; dropping it early could
        // close the shutdown channel while we are still waiting.
        let (task, _shutdown_tx) = self.0;
        task.await
    }
}
/// Awaits the publisher's optional `on_connect` hook.
///
/// # Errors
/// Returns an error mentioning `route_name` if the hook fails.
async fn run_publisher_connect_hook(
    route_name: &str,
    publisher: &Arc<dyn MessagePublisher>,
) -> anyhow::Result<()> {
    let Some(hook) = publisher.on_connect_hook() else {
        return Ok(());
    };
    match hook.await {
        Ok(_) => Ok(()),
        Err(err) => Err(anyhow::anyhow!(
            "Publisher on_connect hook failed for route '{}': {}",
            route_name,
            err
        )),
    }
}
/// Awaits the consumer's optional `on_connect` hook.
///
/// # Errors
/// Returns an error mentioning `route_name` if the hook fails.
async fn run_consumer_connect_hook(
    route_name: &str,
    consumer: &dyn MessageConsumer,
) -> anyhow::Result<()> {
    let Some(hook) = consumer.on_connect_hook() else {
        return Ok(());
    };
    match hook.await {
        Ok(_) => Ok(()),
        Err(err) => Err(anyhow::anyhow!(
            "Consumer on_connect hook failed for route '{}': {}",
            route_name,
            err
        )),
    }
}
/// Awaits the publisher's optional `on_disconnect` hook; failures are only logged.
async fn run_publisher_disconnect_hook(route_name: &str, publisher: &Arc<dyn MessagePublisher>) {
    let outcome = match publisher.on_disconnect_hook() {
        Some(hook) => hook.await,
        None => return,
    };
    if let Err(err) = outcome {
        warn!(
            "Publisher on_disconnect hook failed for route '{}': {}",
            route_name, err
        );
    }
}
/// Awaits the consumer's optional `on_disconnect` hook; failures are only logged.
async fn run_consumer_disconnect_hook(route_name: &str, consumer: &dyn MessageConsumer) {
    let outcome = match consumer.on_disconnect_hook() {
        Some(hook) => hook.await,
        None => return,
    };
    if let Err(err) = outcome {
        warn!(
            "Consumer on_disconnect hook failed for route '{}': {}",
            route_name, err
        );
    }
}
impl From<(JoinHandle<()>, Sender<()>)> for RouteHandle {
fn from(tuple: (JoinHandle<()>, Sender<()>)) -> Self {
RouteHandle(tuple)
}
}
/// Registry entry for a deployed route: its configuration plus the handle used to stop it.
struct ActiveRoute {
    route: Route,
    handle: RouteHandle,
}
/// Routes deployed via `Route::deploy`, keyed by route name; created lazily on first access.
static ROUTE_REGISTRY: OnceLock<RwLock<HashMap<String, ActiveRoute>>> = OnceLock::new();
/// Named endpoints registered via `register_endpoint`; created lazily on first access.
static ENDPOINT_REF_REGISTRY: OnceLock<RwLock<HashMap<String, Endpoint>>> = OnceLock::new();
/// Stores `endpoint` under `name` in the process-wide named-endpoint registry.
///
/// An existing entry with the same name is replaced, and the replacement is
/// logged at debug level.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn register_endpoint(name: &str, endpoint: Endpoint) {
    let mut endpoints = ENDPOINT_REF_REGISTRY
        .get_or_init(|| RwLock::new(HashMap::new()))
        .write()
        .expect("Named endpoint registry lock poisoned");
    let replaced = endpoints.insert(name.to_owned(), endpoint).is_some();
    if replaced {
        debug!("Overwriting a registered endpoint named '{}'", name);
    }
}
/// Returns a clone of the endpoint registered under `name`, if there is one.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn get_endpoint(name: &str) -> Option<Endpoint> {
    ENDPOINT_REF_REGISTRY
        .get_or_init(|| RwLock::new(HashMap::new()))
        .read()
        .expect("Named endpoint registry lock poisoned")
        .get(name)
        .cloned()
}
impl Route {
pub fn new(input: Endpoint, output: Endpoint) -> Self {
Self {
input,
output,
..Default::default()
}
}
/// Returns a clone of the configuration of the route deployed under `name`.
///
/// # Panics
/// Panics if the route registry lock is poisoned.
pub fn get(name: &str) -> Option<Self> {
    let routes = ROUTE_REGISTRY
        .get_or_init(|| RwLock::new(HashMap::new()))
        .read()
        .expect("Route registry lock poisoned");
    let active = routes.get(name)?;
    Some(active.route.clone())
}
/// Returns the names of all currently deployed routes, sorted.
///
/// The registry is a `HashMap`, so its iteration order is arbitrary. Sorting
/// makes the result deterministic. The sort runs after the read lock is released.
///
/// # Panics
/// Panics if the route registry lock is poisoned.
pub fn list() -> Vec<String> {
    let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
    let mut names: Vec<String> = {
        let map = registry.read().expect("Route registry lock poisoned");
        map.keys().cloned().collect()
    };
    names.sort_unstable();
    names
}
/// True when the input endpoint is a `Ref` endpoint and the output endpoint is not.
pub fn is_ref(&self) -> bool {
    let reads_from_ref = matches!(self.input.endpoint_type, EndpointType::Ref(_));
    let writes_to_ref = matches!(self.output.endpoint_type, EndpointType::Ref(_));
    reads_from_ref && !writes_to_ref
}
/// Registers a clone of this route's output endpoint in the named-endpoint registry.
///
/// The registration name is `name` if given. Otherwise it is the name referenced by
/// the input endpoint when that input is a `Ref` endpoint.
///
/// # Errors
/// Fails when `name` is `None` and the input is not a `Ref` endpoint.
pub fn register_output_endpoint(&self, name: Option<&str>) -> Result<(), anyhow::Error> {
    if let Some(explicit) = name {
        register_endpoint(explicit, self.output.clone());
        return Ok(());
    }
    let EndpointType::Ref(ref_name) = &self.input.endpoint_type else {
        return Err(anyhow::anyhow!(
            "No name and input is not a reference endpoint"
        ));
    };
    register_endpoint(ref_name, self.output.clone());
    Ok(())
}
/// Deploys this route under `name`, replacing any route deployed there already.
///
/// Any existing route with that name is first stopped via [`Route::stop`]. The new
/// route is then started with [`Route::run`] and recorded in the route registry.
///
/// # Errors
/// Propagates errors from [`Route::run`]. In that case nothing new is registered,
/// and the previous route (if any) has already been stopped.
///
/// # Panics
/// Panics if the route registry lock is poisoned.
pub async fn deploy(&self, name: &str) -> anyhow::Result<()> {
    Self::stop(name).await;
    let handle = self.run(name).await?;
    let entry = ActiveRoute {
        route: self.clone(),
        handle,
    };
    ROUTE_REGISTRY
        .get_or_init(|| RwLock::new(HashMap::new()))
        .write()
        .expect("Route registry lock poisoned")
        .insert(name.to_string(), entry);
    Ok(())
}
pub async fn stop(name: &str) -> bool {
let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
let active_opt = {
let mut map = registry.write().expect("Route registry lock poisoned");
map.remove(name)
};
if let Some(active) = active_opt {
let handle = active.handle;
let _ = handle.0 .1.send(()).await;
handle.0 .1.close();
let mut join_handle = handle.0 .0;
tokio::select! {
res = &mut join_handle => {
let _ = res;
}
_ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
join_handle.abort();
let _ = join_handle.await;
}
}
true
} else {
false
}
}
/// Creates a standalone [`crate::Publisher`] targeting this route's output endpoint.
pub async fn create_publisher(&self) -> anyhow::Result<crate::Publisher> {
    let endpoint = self.output.clone();
    crate::Publisher::new(endpoint).await
}
/// Creates a consumer, named `name`, that reads from this route's output endpoint.
pub async fn connect_to_output(
    &self,
    name: &str,
) -> anyhow::Result<Box<dyn crate::traits::MessageConsumer>> {
    let output_endpoint = &self.output;
    create_consumer_from_route(name, output_endpoint).await
}
/// Validates this route's configuration.
///
/// The input is checked as a consumer and the output as a publisher, optionally
/// restricted to `allowed_endpoints`. Returns the combined warnings, consumer
/// warnings first.
///
/// # Errors
/// Propagates the first check that fails.
pub fn check(
    &self,
    name: &str,
    allowed_endpoints: Option<&[&str]>,
) -> anyhow::Result<Vec<String>> {
    let consumer_warnings =
        crate::endpoints::check_consumer(name, &self.input, allowed_endpoints)?;
    let publisher_warnings =
        crate::endpoints::check_publisher(name, &self.output, allowed_endpoints)?;
    Ok(consumer_warnings
        .into_iter()
        .chain(publisher_warnings)
        .collect())
}
/// Starts the route on a background supervisor task and returns a handle to it.
///
/// Configuration warnings from [`Route::check`] are logged first. The supervisor
/// then calls [`Route::run_until_err`] repeatedly:
/// - it stops when the handle's shutdown channel yields (a signal, or the channel
///   closing), including while it is waiting to reconnect;
/// - it stops when a run returns `Ok(false)`;
/// - it restarts immediately when a run returns `Ok(true)`;
/// - it stops on `ProcessingError::NonRetryable` or `ConsumerError::EndOfStream`;
/// - otherwise (any other error, or a panicked run task) it waits 5 seconds and reconnects.
///
/// # Errors
/// Returns an error if `check` fails. Also returns an error if the route does not
/// signal readiness within 5 seconds; in that case the supervisor task is aborted.
pub async fn run(&self, name_str: &str) -> anyhow::Result<RouteHandle> {
    let warnings = self.check(name_str, None)?;
    for warning in warnings {
        tracing::warn!(route = name_str, "Configuration warning: {}", warning);
    }
    let (shutdown_tx, shutdown_rx) = bounded(1);
    let (ready_tx, ready_rx) = bounded(1);
    let route = Arc::new(self.clone());
    let name = Arc::new(name_str.to_string());
    let handle = tokio::spawn(async move {
        loop {
            let route_arc = Arc::clone(&route);
            let name_arc = Arc::clone(&name);
            let (internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
            let ready_tx_clone = ready_tx.clone();
            let mut run_task = tokio::spawn(async move {
                route_arc
                    .run_until_err(&name_arc, Some(internal_shutdown_rx), Some(ready_tx_clone))
                    .await
            });
            select! {
                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received for route '{}'.", name);
                    let _ = internal_shutdown_tx.send(()).await;
                    let _ = run_task.await;
                    break;
                }
                res = &mut run_task => {
                    let back_off = match res {
                        Ok(Ok(false)) => {
                            info!("Route '{}' completed gracefully. Shutting down.", name);
                            break;
                        }
                        // The run asked to be restarted: reconnect right away.
                        Ok(Ok(true)) => false,
                        Ok(Err(e)) => {
                            let is_permanent =
                                e.downcast_ref::<ProcessingError>().is_some_and(|pe| matches!(pe, ProcessingError::NonRetryable(_)))
                                || e.downcast_ref::<ConsumerError>().is_some_and(|ce| matches!(ce, ConsumerError::EndOfStream));
                            if is_permanent {
                                error!("Route '{}' failed with a permanent error: {}. Shutting down.", name, e);
                                break;
                            }
                            warn!("Route '{}' failed: {}. Reconnecting in 5 seconds...", name, e);
                            true
                        }
                        Err(e) => {
                            error!("Route '{}' task panicked: {}. Reconnecting in 5 seconds...", name, e);
                            true
                        }
                    };
                    if back_off {
                        // Wait before reconnecting, but stay responsive to shutdown so a
                        // stop request does not have to outlast (or abort) the back-off.
                        select! {
                            _ = shutdown_rx.recv() => {
                                info!("Shutdown signal received for route '{}'.", name);
                                break;
                            }
                            _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => {}
                        }
                    }
                }
            }
        }
    });
    match tokio::time::timeout(std::time::Duration::from_secs(5), ready_rx.recv()).await {
        Ok(Ok(_)) => Ok(RouteHandle((handle, shutdown_tx))),
        _ => {
            handle.abort();
            Err(anyhow::anyhow!(
                "Route '{}' failed to start within 5 seconds or encountered an error",
                name_str
            ))
        }
    }
}
/// Runs the route once, until shutdown, end of input, or an error.
///
/// * `shutdown_rx` — stops the run when it yields. If `None`, a private channel
///   is used whose sender is kept alive, so it never fires.
/// * `ready_tx` — signalled once the publisher and consumer are connected.
///
/// [`Route::run`] reads the returned `bool`: `false` means "finished, do not
/// restart". The sequential runner is used when `options.concurrency` is 0 or 1.
/// This matches the clamping in [`Route::with_concurrency`], and keeps a
/// configured 0 out of the concurrent runner, which would otherwise build
/// zero-capacity channels.
///
/// # Errors
/// Propagates errors from the selected runner.
pub async fn run_until_err(
    &self,
    name: &str,
    shutdown_rx: Option<async_channel::Receiver<()>>,
    ready_tx: Option<Sender<()>>,
) -> anyhow::Result<bool> {
    // The sender must stay bound (not `_`) for the whole run. If it were dropped, the
    // fallback channel would close and look like a shutdown request.
    let (_internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
    let shutdown_rx = shutdown_rx.unwrap_or(internal_shutdown_rx);
    if self.options.concurrency <= 1 {
        self.run_sequentially(name, shutdown_rx, ready_tx).await
    } else {
        self.run_concurrently(name, shutdown_rx, ready_tx).await
    }
}
/// Sequential runner: receives one batch, publishes it, commits it, then repeats.
///
/// Publisher and consumer are created and their `on_connect` hooks are run. If a
/// hook fails, the already-connected sides' disconnect hooks are run and the error
/// is returned. `ready_tx` is then signalled. Every batch's commit is wrapped by
/// `wrap_commit`, which routes it through the sequencer, so broker commits happen
/// in receive order. Ack and partial-success commits run on spawned tasks so the
/// loop can keep receiving. Their failures come back through a capacity-1 error
/// channel.
///
/// Returns `Ok(true)` after a shutdown signal and `Ok(false)` when the consumer
/// reports end of stream.
///
/// # Errors
/// - a consumer connection or gap error;
/// - a commit error reported by a commit task;
/// - a publisher error (the batch is Nacked first);
/// - a transient partial failure when the output has no retry middleware.
async fn run_sequentially(
&self,
name: &str,
shutdown_rx: async_channel::Receiver<()>,
ready_tx: Option<Sender<()>>,
) -> anyhow::Result<bool> {
let publisher = create_publisher_from_route(name, &self.output).await?;
let mut consumer = create_consumer_from_route(name, &self.input).await?;
if let Err(err) = run_publisher_connect_hook(name, &publisher).await {
run_publisher_disconnect_hook(name, &publisher).await;
return Err(err);
}
if let Err(err) = run_consumer_connect_hook(name, consumer.as_ref()).await {
run_consumer_disconnect_hook(name, consumer.as_ref()).await;
run_publisher_disconnect_hook(name, &publisher).await;
return Err(err);
}
// Commit tasks report failures here. With capacity 1, further errors are only logged (try_send).
let (err_tx, err_rx) = bounded(1);
let mut commit_tasks = JoinSet::new();
let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
// Sequence number assigned to each received batch, in receive order.
let mut seq_counter = 0u64;
if let Some(tx) = ready_tx {
let _ = tx.send(()).await;
}
// Reused per batch to hold the ids of the messages being published.
let mut message_ids = Vec::with_capacity(self.options.batch_size);
let has_retry_middleware = self.output.has_retry_middleware();
// The loop evaluates to this run's result; cleanup below runs on every exit path.
let run_result = loop {
select! {
Ok(err) = err_rx.recv() => break Err(err),
_ = shutdown_rx.recv() => {
info!("Shutdown signal received in sequential runner for route '{}'.", name);
break Ok(true); }
res = consumer.receive_batch(self.options.batch_size) => {
let received_batch = match res {
Ok(batch) => {
if batch.messages.is_empty() {
continue; }
batch
}
Err(ConsumerError::EndOfStream) => {
info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
break Ok(false); }
Err(ConsumerError::Connection(e)) => {
break Err(e);
},
Err(ConsumerError::Gap { requested, base }) => {
break Err(anyhow::anyhow!("Consumer gap: requested offset {requested} but earliest available is {base}"));
}
};
debug!("Received a batch of {} messages sequentially", received_batch.messages.len());
let seq = seq_counter;
seq_counter += 1;
// Route this batch's commit through the sequencer so commits complete in receive order.
let mut commit_opt = Some(wrap_commit(received_batch.commit, seq, seq_tx.clone()));
let batch_len = received_batch.messages.len();
message_ids.clear();
message_ids.extend(received_batch.messages.iter().map(|m| m.message_id));
// Ids of messages carrying a `reply_to` key, i.e. messages that expect a response.
let request_ids: std::collections::HashSet<u128> = received_batch
.messages
.iter()
.filter(|m| m.metadata.contains_key("reply_to"))
.map(|m| m.message_id)
.collect();
match publisher.send_batch(received_batch.messages).await {
Ok(SentBatch::Ack) => {
for id in &message_ids {
if request_ids.contains(id) {
warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop broken.", id);
}
}
let commit = commit_opt.take().expect("Commit already used");
let err_tx = err_tx.clone();
commit_tasks.spawn(async move {
if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
error!("Commit failed: {}", e);
match err_tx.try_send(e) {
Ok(_) => trace!("Reported commit error to main task"),
Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
}
}
});
}
Ok(SentBatch::Partial { responses, failed }) => {
let has_transient = failed.iter().any(|(_, e)| {
matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_))
});
if has_transient {
let (_, first_err) = failed
.iter()
.find(|(_, e)| matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_)))
.expect("has_transient is true");
let err = anyhow::anyhow!(
"Transient error in batch send ({} messages failed). First error: {}",
failed.len(),
first_err
);
// Commit inline: failed messages get Nack dispositions. Then either end the run
// or, with retry middleware configured, keep going and let redelivery retry them.
let commit = commit_opt.take().expect("Commit already used");
let dispositions =
map_responses_to_dispositions(&message_ids, responses, &failed, &request_ids);
if let Err(commit_err) = commit(dispositions).await {
warn!("Commit after transient failure also failed: {}", commit_err);
}
if !has_retry_middleware {
break Err(err);
}
warn!("Transient error in batch, message(s) Nack'ed for re-delivery: {}", err);
tokio::task::yield_now().await;
continue;
}
for (msg, e) in &failed {
error!("Dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
}
let commit = commit_opt.take().expect("Commit already used");
let err_tx = err_tx.clone();
// The spawned task needs owned data: take the id buffer instead of cloning it
// (the next batch refills it).
let ids = std::mem::take(&mut message_ids);
let req_ids = request_ids;
commit_tasks.spawn(async move {
let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &req_ids);
if let Err(e) = commit(dispositions).await {
error!("Commit failed: {}", e);
match err_tx.try_send(e) {
Ok(_) => trace!("Reported commit error to main task"),
Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
}
}
});
}
Err(e) => {
warn!("Publisher error, sending {} Nacks to commit", batch_len);
let commit = commit_opt.take().expect("Commit already used");
let nack_result = commit(vec![MessageDisposition::Nack; batch_len]).await;
debug!("Nack commit result: {:?}", nack_result);
break Err(e.into());
}
}
tokio::task::yield_now().await;
}
}
};
// Release our sequencer sender, then wait for in-flight commit tasks. Errors they
// report during this drain are logged, not returned.
drop(seq_tx);
loop {
select! {
res = err_rx.recv() => {
if let Ok(err) = res {
error!("Error reported during shutdown: {}", err);
}
}
res = commit_tasks.join_next() => {
if res.is_none() {
break;
}
}
}
}
drop(err_rx);
// The sequencer exits once every sender (including those captured by commit wrappers) is gone.
let _ = sequencer_handle.await;
run_consumer_disconnect_hook(name, consumer.as_ref()).await;
run_publisher_disconnect_hook(name, &publisher).await;
run_result
}
async fn run_concurrently(
&self,
name: &str,
shutdown_rx: async_channel::Receiver<()>,
ready_tx: Option<Sender<()>>,
) -> anyhow::Result<bool> {
let publisher = create_publisher_from_route(name, &self.output).await?;
let mut consumer = create_consumer_from_route(name, &self.input).await?;
if let Err(err) = run_publisher_connect_hook(name, &publisher).await {
run_publisher_disconnect_hook(name, &publisher).await;
return Err(err);
}
if let Err(err) = run_consumer_connect_hook(name, consumer.as_ref()).await {
run_consumer_disconnect_hook(name, consumer.as_ref()).await;
run_publisher_disconnect_hook(name, &publisher).await;
return Err(err);
}
if let Some(tx) = ready_tx {
let _ = tx.send(()).await;
}
let (err_tx, err_rx) = bounded(1); let work_capacity = self
.options
.concurrency
.saturating_mul(self.options.batch_size);
let (work_tx, work_rx) =
bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(work_capacity);
let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
let batch_size = self.options.batch_size;
let mut join_set = JoinSet::new();
for i in 0..self.options.concurrency {
let work_rx_clone = work_rx.clone();
let publisher = Arc::clone(&publisher);
let err_tx = err_tx.clone();
let mut commit_tasks = JoinSet::new();
let has_retry_middleware = self.output.has_retry_middleware();
join_set.spawn(async move {
debug!("Starting worker {}", i);
let mut message_ids = Vec::with_capacity(batch_size);
while let Ok((messages, commit_func)) = work_rx_clone.recv().await {
let mut commit_opt = Some(commit_func);
let batch_len = messages.len();
message_ids.clear();
message_ids.extend(messages.iter().map(|m| m.message_id));
let request_ids: std::collections::HashSet<u128> = messages
.iter()
.filter(|m| m.metadata.contains_key("reply_to"))
.map(|m| m.message_id)
.collect();
match publisher.send_batch(messages).await {
Ok(SentBatch::Ack) => {
for id in &message_ids {
if request_ids.contains(id) {
warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop broken.", id);
}
}
let commit = commit_opt.take().expect("Commit already used");
let err_tx = err_tx.clone();
commit_tasks.spawn(async move {
if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
error!("Commit failed: {}", e);
match err_tx.try_send(e) {
Ok(_) => trace!("Reported commit error to main task"),
Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
}
}
});
}
Ok(SentBatch::Partial { responses, failed }) => {
let has_transient = failed.iter().any(|(_, e)| {
matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_))
});
if has_transient {
let (_, first_err) = failed
.iter()
.find(|(_, e)| matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_)))
.expect("has_transient is true");
let e = anyhow::anyhow!(
"Transient error in batch send ({} messages failed). First error: {}",
failed.len(),
first_err
);
let commit = commit_opt.take().expect("Commit already used");
let dispositions =
map_responses_to_dispositions(&message_ids, responses, &failed, &request_ids);
if let Err(commit_err) = commit(dispositions).await {
warn!("Commit after transient failure also failed: {}", commit_err);
}
if !has_retry_middleware {
match err_tx.try_send(e) {
Ok(_) => trace!("Reported error to main task"),
Err(err_send) => warn!(error=?err_send, "Could not send error to main task, it might be down or busy."),
}
break;
}
warn!("Transient error in batch, message(s) Nack'ed for re-delivery: {}", e);
tokio::task::yield_now().await;
continue;
}
for (msg, e) in &failed {
error!("Worker dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
}
let commit = commit_opt.take().expect("Commit already used");
let err_tx = err_tx.clone();
let ids = std::mem::take(&mut message_ids);
let req_ids = request_ids;
commit_tasks.spawn(async move {
let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &req_ids);
if let Err(e) = commit(dispositions).await {
error!("Commit failed: {}", e);
match err_tx.try_send(e) {
Ok(_) => trace!("Reported commit error to main task"),
Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
}
}
});
}
Err(e) => {
error!("Worker failed to send message batch: {}", e);
let commit = commit_opt.take().expect("Commit already used");
let nack_result = commit(vec![MessageDisposition::Nack; batch_len]).await;
debug!("Nack commit result: {:?}", nack_result);
match err_tx.try_send(e.into()) {
Ok(_) => trace!("Reported error to main task"),
Err(err_send) => warn!(error=?err_send, "Could not send error to main task, it might be down or busy."),
}
break;
}
}
}
while commit_tasks.join_next().await.is_some() {}
});
}
let mut seq_counter = 0u64;
let mut loop_error: Option<anyhow::Error> = None;
loop {
select! {
biased;
Ok(err) = err_rx.recv() => {
error!("A worker reported a critical error. Shutting down route.");
loop_error = Some(err);
break;
}
Some(res) = join_set.join_next() => {
match res {
Ok(_) => {
error!("A worker task finished unexpectedly. Shutting down route.");
loop_error = Some(anyhow::anyhow!("Worker task finished unexpectedly"));
}
Err(e) => {
error!("A worker task panicked: {}. Shutting down route.", e);
loop_error = Some(e.into());
}
}
break;
}
_ = shutdown_rx.recv() => {
info!("Shutdown signal received in concurrent runner for route '{}'.", name);
break;
}
res = consumer.receive_batch(self.options.batch_size) => {
let (messages, commit) = match res {
Ok(batch) => {
if batch.messages.is_empty() {
continue; }
(batch.messages, batch.commit)
}
Err(ConsumerError::EndOfStream) => {
info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
break; }
Err(ConsumerError::Connection(e)) => {
loop_error = Some(e);
break;
}
Err(ConsumerError::Gap { requested, base }) => {
loop_error = Some(ConsumerError::Gap { requested, base }.into());
break;
}
};
debug!("Received a batch of {} messages concurrently", messages.len());
let seq = seq_counter;
let wrapped_commit = wrap_commit(commit, seq, seq_tx.clone());
match work_tx.send((messages, wrapped_commit)).await {
Ok(()) => {
seq_counter += 1;
}
Err(e) => {
warn!("Work channel closed, cannot process more messages concurrently. Shutting down.");
let (msgs_back, wrapped_commit_back) = e.into_inner();
let _ = (wrapped_commit_back)(vec![crate::traits::MessageDisposition::Nack; msgs_back.len()]).await;
break;
}
}
tokio::task::yield_now().await;
}
}
}
drop(work_tx);
while join_set.join_next().await.is_some() {}
drop(seq_tx);
let _ = sequencer_handle.await;
run_consumer_disconnect_hook(name, consumer.as_ref()).await;
run_publisher_disconnect_hook(name, &publisher).await;
if let Some(err) = loop_error {
return Err(err);
}
if let Ok(err) = err_rx.try_recv() {
return Err(err);
}
Ok(shutdown_rx.is_empty())
}
pub fn with_options(mut self, options: RouteOptions) -> Self {
self.options = options;
self
}
/// Sets the number of concurrent publisher workers; 0 is treated as 1.
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
    self.options.concurrency = std::cmp::max(concurrency, 1);
    self
}
/// Sets the maximum number of messages fetched per batch; 0 is treated as 1.
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
    self.options.batch_size = usize::max(batch_size, 1);
    self
}
/// Sets the commit sequencer's queue capacity; 0 is treated as 1.
pub fn with_commit_concurrency_limit(mut self, limit: usize) -> Self {
    self.options.commit_concurrency_limit = if limit == 0 { 1 } else { limit };
    self
}
/// Installs `handler` on the output endpoint, replacing any handler already set.
pub fn with_handler(mut self, handler: impl Handler + 'static) -> Self {
    let shared = Arc::new(handler);
    self.output.handler = Some(shared);
    self
}
/// Adds a typed handler for messages of type `type_name` to the output's handler.
///
/// Each message is parsed into `T` with `CanonicalMessage::parse`. A parse failure
/// becomes `HandlerError::NonRetryable`. On success, `handler` is called with the
/// value and a `MessageContext` built from the message.
///
/// If a handler is already installed and accepts the registration via
/// `register_handler`, the handler it returns is used. Otherwise a new
/// `TypeHandler` is created, with the existing handler (if any) as its fallback.
pub fn add_handler<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
where
    T: DeserializeOwned + Send + Sync + 'static,
    H: crate::type_handler::IntoTypedHandler<T, Args>,
    Args: Send + Sync + 'static,
{
    let typed = Arc::new(handler);
    let wrapper = Arc::new(move |msg: crate::CanonicalMessage| {
        let typed = Arc::clone(&typed);
        async move {
            let data = msg.parse::<T>().map_err(|e| {
                HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
            })?;
            let ctx = crate::MessageContext::from(msg);
            typed.call(data, ctx).await
        }
    });
    let combined = match self.output.handler.take() {
        Some(existing) => match existing.register_handler(type_name, wrapper.clone()) {
            Some(extended) => extended,
            None => Arc::new(
                crate::type_handler::TypeHandler::new()
                    .with_fallback(existing)
                    .add_handler(type_name, wrapper),
            ),
        },
        None => Arc::new(crate::type_handler::TypeHandler::new().add_handler(type_name, wrapper)),
    };
    self.output.handler = Some(combined);
    self
}
/// Calls [`Route::add_handler`] once per `(type_name, handler)` entry.
///
/// `HashMap` iteration order is unspecified, so the registration order is too.
pub fn add_handlers<T, H, Args>(self, handlers: HashMap<&str, H>) -> Self
where
    T: DeserializeOwned + Send + Sync + 'static,
    H: crate::type_handler::IntoTypedHandler<T, Args>,
    Args: Send + Sync + 'static,
{
    handlers
        .into_iter()
        .fold(self, |route, (type_name, handler)| {
            route.add_handler::<T, H, Args>(type_name, handler)
        })
}
}
/// Work item handled by the commit sequencer. It holds:
/// - the dispositions to commit;
/// - the underlying (unwrapped) commit function;
/// - a oneshot sender that receives the commit's result.
type SequencerItem = (
Vec<MessageDisposition>,
BatchCommitFunc,
tokio::sync::oneshot::Sender<anyhow::Result<()>>,
);
/// Spawns the commit sequencer task and returns its input sender and join handle.
///
/// Items arrive as `(seq, item)`. The task runs commits one at a time, in strictly
/// increasing `seq` order starting at 0, and buffers items that arrive early. Each
/// commit's result is sent on the item's oneshot. An item whose `seq` was already
/// passed gets an error. Once every sender is dropped, still-buffered items are
/// failed with "Sequencer is shutting down" and the task exits.
///
/// `buffer_size` is the input channel's capacity. It is clamped to at least 1,
/// because `async_channel::bounded` panics on zero (and a 0 limit is reachable
/// through `RouteOptions`).
fn spawn_sequencer(buffer_size: usize) -> (Sender<(u64, SequencerItem)>, JoinHandle<()>) {
    let (seq_tx, seq_rx) = bounded::<(u64, SequencerItem)>(buffer_size.max(1));
    let sequencer_handle = tokio::spawn(async move {
        let mut buffer: BTreeMap<u64, SequencerItem> = BTreeMap::new();
        let mut next_seq = 0u64;
        loop {
            // Run the next commit in order as soon as it is available.
            if let Some((dispositions, commit_func, notify)) = buffer.remove(&next_seq) {
                let result = commit_func(dispositions).await;
                let _ = notify.send(result);
                next_seq += 1;
                tokio::task::yield_now().await;
                continue;
            }
            match seq_rx.recv().await {
                Ok((seq, item)) => {
                    if seq < next_seq {
                        let (_, _, notify) = item;
                        trace!(
                            seq,
                            next_seq,
                            "Sequencer received late item (seq < next_seq)"
                        );
                        let _ = notify.send(Err(anyhow::anyhow!(
                            "Sequencer received late item (seq {} < next_seq {})",
                            seq,
                            next_seq
                        )));
                    } else {
                        buffer.insert(seq, item);
                    }
                }
                Err(_) => {
                    // All senders are gone: fail whatever is still waiting for an earlier seq.
                    for (_, (_, _, notify)) in buffer {
                        let _ = notify.send(Err(anyhow::anyhow!("Sequencer is shutting down")));
                    }
                    break;
                }
            }
        }
    });
    (seq_tx, sequencer_handle)
}
/// Wraps `commit` so that calling it enqueues the commit on the sequencer as item `seq`.
///
/// The returned future resolves to the underlying commit's result. It fails if the
/// sequencer's channel is closed, or if the sequencer drops the item without
/// replying.
fn wrap_commit(
    commit: BatchCommitFunc,
    seq: u64,
    seq_tx: Sender<(u64, SequencerItem)>,
) -> BatchCommitFunc {
    Box::new(move |dispositions| {
        Box::pin(async move {
            let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
            if seq_tx
                .send((seq, (dispositions, commit, notify_tx)))
                .await
                .is_err()
            {
                return Err(anyhow::anyhow!(
                    "Failed to send commit to sequencer, route is likely shutting down"
                ));
            }
            notify_rx.await.unwrap_or_else(|_| {
                Err(anyhow::anyhow!(
                    "Sequencer dropped the commit channel unexpectedly"
                ))
            })
        })
    })
}
fn map_responses_to_dispositions(
message_ids: &[u128],
responses: Option<Vec<crate::CanonicalMessage>>,
failed: &[(crate::CanonicalMessage, PublisherError)],
request_ids: &std::collections::HashSet<u128>,
) -> Vec<MessageDisposition> {
let mut dispositions = Vec::with_capacity(message_ids.len());
let failed_ids: std::collections::HashSet<u128> =
failed.iter().map(|(m, _)| m.message_id).collect();
let mut response_map: std::collections::HashMap<u128, crate::CanonicalMessage> = responses
.unwrap_or_default()
.into_iter()
.map(|r| (r.message_id, r))
.collect();
for id in message_ids {
if failed_ids.contains(id) {
dispositions.push(MessageDisposition::Nack);
} else if let Some(resp) = response_map.remove(id) {
dispositions.push(MessageDisposition::Reply(resp));
} else if request_ids.contains(id) {
error!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Nacking to avoid committing a lost response.", id);
dispositions.push(MessageDisposition::Nack);
} else {
dispositions.push(MessageDisposition::Ack);
}
}
dispositions
}
/// Checks the disposition rules of `map_responses_to_dispositions`.
///
/// Previously this function had only `#[cfg(test)]`, with no `#[test]`, so it was
/// never executed.
#[cfg(test)]
#[test]
fn test_map_responses_to_dispositions_logic() {
    use crate::{traits::PublisherError, CanonicalMessage};
    use anyhow::anyhow;
    // Messages 1 and 4 get replies, 2 fails permanently, and 3 expected a reply but got none.
    let ids: Vec<u128> = vec![1, 2, 3, 4];
    let mut resp1 = CanonicalMessage::from("resp1");
    resp1.message_id = 1;
    let mut resp4 = CanonicalMessage::from("resp4");
    resp4.message_id = 4;
    let responses = Some(vec![resp1, resp4]);
    let mut msg2 = CanonicalMessage::from("msg2");
    msg2.message_id = 2;
    let failed = vec![(msg2, PublisherError::NonRetryable(anyhow!("failed")))];
    let mut request_ids = std::collections::HashSet::new();
    request_ids.insert(3);
    let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &request_ids);
    assert_eq!(dispositions.len(), 4);
    assert!(matches!(dispositions[0], MessageDisposition::Reply(_)));
    assert!(matches!(dispositions[1], MessageDisposition::Nack));
    assert!(matches!(dispositions[2], MessageDisposition::Nack));
    assert!(matches!(dispositions[3], MessageDisposition::Reply(_)));
}
pub fn get_route(name: &str) -> Option<Route> {
Route::get(name)
}
pub fn list_routes() -> Vec<String> {
Route::list()
}
pub async fn stop_route(name: &str) -> bool {
Route::stop(name).await
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::{Endpoint, EndpointType, FaultMode, Middleware, RandomPanicMiddleware};
use crate::traits::{
CustomMiddlewareFactory, MessageConsumer, MessagePublisher, ReceivedBatch,
};
use crate::CanonicalMessage;
use std::any::Any;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
#[derive(Debug, Default)]
struct CommitObservation {
completed: Mutex<Vec<u64>>,
active: std::sync::atomic::AtomicUsize,
max_active: std::sync::atomic::AtomicUsize,
}
#[derive(Debug)]
struct CommitTrackingMiddlewareFactory {
observation: Arc<CommitObservation>,
}
#[derive(Debug)]
struct ReorderingPublisherMiddlewareFactory;
struct CommitTrackingConsumer {
inner: Box<dyn MessageConsumer>,
observation: Arc<CommitObservation>,
}
struct ReorderingPublisher {
inner: Box<dyn MessagePublisher>,
}
#[async_trait::async_trait]
impl CustomMiddlewareFactory for CommitTrackingMiddlewareFactory {
async fn apply_consumer(
&self,
consumer: Box<dyn MessageConsumer>,
_route_name: &str,
_config: &serde_json::Value,
) -> anyhow::Result<Box<dyn MessageConsumer>> {
Ok(Box::new(CommitTrackingConsumer {
inner: consumer,
observation: Arc::clone(&self.observation),
}))
}
}
#[async_trait::async_trait]
impl CustomMiddlewareFactory for ReorderingPublisherMiddlewareFactory {
async fn apply_publisher(
&self,
publisher: Box<dyn MessagePublisher>,
_route_name: &str,
_config: &serde_json::Value,
) -> anyhow::Result<Box<dyn MessagePublisher>> {
Ok(Box::new(ReorderingPublisher { inner: publisher }))
}
}
#[async_trait::async_trait]
impl MessageConsumer for CommitTrackingConsumer {
async fn receive_batch(
&mut self,
max_messages: usize,
) -> Result<ReceivedBatch, ConsumerError> {
let mut batch = self.inner.receive_batch(max_messages).await?;
let seq = batch
.messages
.first()
.and_then(|message| message.get_payload_str().parse::<u64>().ok())
.expect("tracking test expects numeric payloads");
let original_commit = batch.commit;
let observation = Arc::clone(&self.observation);
batch.commit = Box::new(move |dispositions| {
let observation = Arc::clone(&observation);
Box::pin(async move {
let active_now = observation.active.fetch_add(1, Ordering::SeqCst) + 1;
let _ = observation.max_active.fetch_update(
Ordering::SeqCst,
Ordering::SeqCst,
|current| (active_now > current).then_some(active_now),
);
tokio::time::sleep(Duration::from_millis(20)).await;
let result = original_commit(dispositions).await;
observation.completed.lock().unwrap().push(seq);
observation.active.fetch_sub(1, Ordering::SeqCst);
result
})
});
Ok(batch)
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[async_trait::async_trait]
impl MessagePublisher for ReorderingPublisher {
async fn send_batch(
&self,
messages: Vec<crate::CanonicalMessage>,
) -> Result<SentBatch, PublisherError> {
let seq = messages
.first()
.and_then(|message| message.get_payload_str().parse::<u64>().ok())
.expect("tracking test expects numeric payloads");
let delay_ms = 10 * (6u64.saturating_sub(seq.min(6)));
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
self.inner.send_batch(messages).await
}
async fn send(&self, msg: crate::CanonicalMessage) -> Result<Sent, PublisherError> {
self.inner.send(msg).await
}
async fn flush(&self) -> anyhow::Result<()> {
self.inner.flush().await
}
fn as_any(&self) -> &dyn Any {
self
}
}
async fn assert_route_commits_are_ordered_and_non_overlapping(concurrency: usize) {
let unique_id = fast_uuid_v7::gen_id().to_string();
let tracking_name = format!("track_commit_{}", unique_id);
let reorder_name = format!("reorder_publish_{}", unique_id);
let in_topic = format!("ordered_commit_in_{}", unique_id);
let observation = Arc::new(CommitObservation::default());
register_middleware_factory(
&tracking_name,
Arc::new(CommitTrackingMiddlewareFactory {
observation: Arc::clone(&observation),
}),
);
register_middleware_factory(
&reorder_name,
Arc::new(ReorderingPublisherMiddlewareFactory),
);
let input = Endpoint::new_memory(&in_topic, 32).add_middleware(Middleware::Custom {
name: tracking_name,
config: serde_json::Value::Null,
});
let output = Endpoint::new(EndpointType::Null).add_middleware(Middleware::Custom {
name: reorder_name,
config: serde_json::Value::Null,
});
let route = Route::new(input.clone(), output)
.with_concurrency(concurrency)
.with_batch_size(1)
.with_commit_concurrency_limit(1);
let input_channel = input.channel().unwrap();
let messages = (0..6)
.map(|seq| crate::CanonicalMessage::from(seq.to_string()))
.collect();
input_channel.fill_messages(messages).await.unwrap();
input_channel.close();
tokio::time::timeout(
std::time::Duration::from_secs(5),
route.run_until_err("ordered_commit_regression", None, None),
)
.await
.expect("Route should not hang while draining finite input")
.expect("Route should complete without commit errors");
assert_eq!(
*observation.completed.lock().unwrap(),
vec![0, 1, 2, 3, 4, 5],
"Commit execution must follow receive order",
);
assert_eq!(
observation.max_active.load(Ordering::SeqCst),
1,
"Broker-facing commit functions must never overlap",
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sequential_route_commits_are_ordered_and_non_overlapping() {
assert_route_commits_are_ordered_and_non_overlapping(1).await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_concurrent_route_commits_are_ordered_and_non_overlapping() {
assert_route_commits_are_ordered_and_non_overlapping(4).await;
}
/// Deploys a memory-to-memory route whose *input* endpoint carries a
/// `RandomPanic` middleware (`trigger_on_message: Some(1)`), sends one
/// message through it and checks what finally arrives at the output.
///
/// * `mode` – fault injected by the middleware.
/// * `expected_payload` – payload the verifier must observe on the output.
/// * `route_should_restart` – when `true`, waits 6 s (instead of 500 ms)
///   before connecting the verifier, leaving time for the route to recover.
/// * `concurrency` – route concurrency; also part of the topic/route names.
async fn run_consumer_fault_test(
    mode: FaultMode,
    expected_payload: &str,
    route_should_restart: bool,
    concurrency: usize,
) {
    let suffix = fast_uuid_v7::gen_id().to_string();
    let route_name = format!("fault_test_{}_{}", mode, concurrency);
    let in_topic = format!("fault_in_{}_{}_{}", mode, concurrency, suffix);
    let out_topic = format!("fault_out_{}_{}_{}", mode, concurrency, suffix);

    let input = Endpoint::new_memory(&in_topic, 10).add_middleware(Middleware::RandomPanic(
        RandomPanicMiddleware {
            mode,
            trigger_on_message: Some(1),
            enabled: true,
            ..Default::default()
        },
    ));
    let output = Endpoint::new_memory(&out_topic, 10);
    let route = Route::new(input.clone(), output.clone()).with_concurrency(concurrency);
    route
        .deploy(&route_name)
        .await
        .expect("Failed to deploy route");

    // Keep the channel handle alive for the rest of the test.
    let input_ch = input.channel().unwrap();
    input_ch
        .send_message("persistent_msg".into())
        .await
        .unwrap();

    let settle_time = if route_should_restart {
        std::time::Duration::from_secs(6)
    } else {
        std::time::Duration::from_millis(500)
    };
    tokio::time::sleep(settle_time).await;

    let mut verifier = route.connect_to_output("verifier").await.unwrap();
    let delivery = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
        .await
        .expect("Timed out waiting for message after fault")
        .expect("Stream closed while waiting for message");
    assert_eq!(delivery.message.get_payload_str(), expected_payload);
    (delivery.commit)(MessageDisposition::Ack).await.unwrap();
    Route::stop(&route_name).await;
}
/// Publisher-side fault test: the `RandomPanic` middleware
/// (`trigger_on_message: Some(1)`) is attached to the *output* endpoint. One
/// message carrying `expected_payload` is sent, and the test checks that the
/// same payload eventually reaches the output.
///
/// * `mode` – fault injected by the middleware.
/// * `route_should_restart` – when `true`, waits 6 s (instead of 500 ms)
///   before connecting the verifier, leaving time for the route to recover.
async fn run_publisher_fault_test(
    mode: FaultMode,
    expected_payload: &str,
    route_should_restart: bool,
) {
    let suffix = fast_uuid_v7::gen_id().to_string();
    let route_name = format!("pub_fault_test_{}", mode);
    let in_topic = format!("pub_fault_in_{}_{}", mode, suffix);
    let out_topic = format!("pub_fault_out_{}_{}", mode, suffix);

    let mut input = Endpoint::new_memory(&in_topic, 10);
    // Enable nack support on the memory input (presumably so a message whose
    // publish failed can be nacked and redelivered).
    if let EndpointType::Memory(cfg) = &mut input.endpoint_type {
        cfg.enable_nack = true;
    }
    let output = Endpoint::new_memory(&out_topic, 10).add_middleware(Middleware::RandomPanic(
        RandomPanicMiddleware {
            mode,
            trigger_on_message: Some(1),
            enabled: true,
            ..Default::default()
        },
    ));
    let route = Route::new(input.clone(), output.clone());
    route
        .deploy(&route_name)
        .await
        .expect("Failed to deploy route");

    // Keep the channel handle alive for the rest of the test.
    let input_ch = input.channel().unwrap();
    input_ch
        .send_message(expected_payload.into())
        .await
        .unwrap();

    let settle_time = if route_should_restart {
        std::time::Duration::from_secs(6)
    } else {
        std::time::Duration::from_millis(500)
    };
    tokio::time::sleep(settle_time).await;

    let mut verifier = route.connect_to_output("verifier").await.unwrap();
    let delivery = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
        .await
        .expect("Timed out waiting for message after publisher fault")
        .expect("Stream closed");
    assert_eq!(delivery.message.get_payload_str(), expected_payload);
    (delivery.commit)(MessageDisposition::Ack).await.unwrap();
    Route::stop(&route_name).await;
}
/// Runs every consumer-side fault mode against a route with concurrency 2.
///
/// All modes except `JsonFormatError` are expected to deliver the original
/// payload after a restart; `JsonFormatError` delivers `{invalid json}`
/// without a restart.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "Takes too much time for regular tests"]
async fn test_route_recovery_from_faults() {
    let original_payload = "persistent_msg";
    // (fault, payload expected at the output, whether the route must restart)
    let cases = [
        (FaultMode::Panic, original_payload, true),
        (FaultMode::Disconnect, original_payload, true),
        (FaultMode::Timeout, original_payload, true),
        (FaultMode::Nack, original_payload, true),
        (FaultMode::JsonFormatError, "{invalid json}", false),
    ];
    for (mode, payload, restart) in cases {
        run_consumer_fault_test(mode, payload, restart, 2).await;
    }
}
/// Panic and disconnect recovery for a route running with concurrency 1.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "Takes too much time for regular tests"]
async fn test_route_recovery_from_faults_sequential() {
    let original_payload = "persistent_msg";
    for mode in [FaultMode::Panic, FaultMode::Disconnect] {
        run_consumer_fault_test(mode, original_payload, true, 1).await;
    }
}
/// Disconnect and timeout faults injected on the publisher side must be
/// recovered from, with the original payload eventually delivered.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore = "Takes too much time for regular tests"]
async fn test_publisher_recovery_from_faults() {
    let original_payload = "persistent_msg";
    for mode in [FaultMode::Disconnect, FaultMode::Timeout] {
        run_publisher_fault_test(mode, original_payload, true).await;
    }
}
/// Regression test: a retryable failure on one published batch must make the
/// route return `Err` instead of hanging.
///
/// A custom middleware wraps the output publisher so that exactly one
/// `send_batch` call fails (the `fail_flag` is consumed atomically). The route
/// runs with concurrency 2 and batch size 1, and must finish within 5 s.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_route_sequencer_deadlock_fix() {
    /// Publisher wrapper whose first `send_batch` after the flag is set fails.
    struct FailingPublisher {
        inner: Box<dyn MessagePublisher>,
        fail_flag: Arc<AtomicBool>,
    }
    #[async_trait::async_trait]
    impl MessagePublisher for FailingPublisher {
        async fn send_batch(
            &self,
            messages: Vec<crate::CanonicalMessage>,
        ) -> Result<SentBatch, PublisherError> {
            // Flip true -> false atomically so only a single call fails.
            let should_fail = self
                .fail_flag
                .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok();
            if should_fail {
                return Err(PublisherError::Retryable(anyhow::anyhow!(
                    "Simulated failure"
                )));
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            self.inner.send_batch(messages).await
        }
        async fn send(
            &self,
            msg: crate::CanonicalMessage,
        ) -> Result<crate::traits::Sent, PublisherError> {
            self.inner.send(msg).await
        }
        async fn flush(&self) -> anyhow::Result<()> {
            self.inner.flush().await
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Middleware factory that wraps publishers in `FailingPublisher` and
    /// leaves consumers untouched.
    #[derive(Debug)]
    struct FailingMiddlewareFactory {
        fail_flag: Arc<AtomicBool>,
    }
    #[async_trait::async_trait]
    impl CustomMiddlewareFactory for FailingMiddlewareFactory {
        async fn apply_publisher(
            &self,
            publisher: Box<dyn MessagePublisher>,
            _route_name: &str,
            _config: &serde_json::Value,
        ) -> anyhow::Result<Box<dyn MessagePublisher>> {
            let wrapped = FailingPublisher {
                inner: publisher,
                fail_flag: Arc::clone(&self.fail_flag),
            };
            Ok(Box::new(wrapped))
        }
        async fn apply_consumer(
            &self,
            consumer: Box<dyn MessageConsumer>,
            _route_name: &str,
            _config: &serde_json::Value,
        ) -> anyhow::Result<Box<dyn MessageConsumer>> {
            Ok(consumer)
        }
    }

    let id = fast_uuid_v7::gen_id().to_string();
    let factory_name = format!("fail_factory_{}", id);
    let in_topic = format!("deadlock_in_{}", id);
    let out_topic = format!("deadlock_out_{}", id);

    let fail_flag = Arc::new(AtomicBool::new(true));
    register_middleware_factory(
        &factory_name,
        Arc::new(FailingMiddlewareFactory {
            fail_flag: Arc::clone(&fail_flag),
        }),
    );

    let input = Endpoint::new_memory(&in_topic, 100);
    let output = Endpoint::new_memory(&out_topic, 100).add_middleware(Middleware::Custom {
        name: factory_name,
        config: serde_json::Value::Null,
    });
    let route = Route::new(input.clone(), output.clone())
        .with_concurrency(2)
        .with_batch_size(1);

    let input_ch = input.channel().unwrap();
    for payload in ["msg1", "msg2", "msg3"] {
        input_ch.send_message(payload.into()).await.unwrap();
    }

    let outcome = tokio::time::timeout(std::time::Duration::from_secs(5), async {
        // The sender is kept alive for the whole run, so this channel is never
        // closed or signalled during the test.
        let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1);
        route
            .run_until_err("deadlock_test", Some(shutdown_rx), None)
            .await
    })
    .await;

    let run_result = match outcome {
        Ok(res) => res,
        Err(_) => {
            panic!("Route deadlocked! The sequencer likely didn't receive the Nack for the failed batch.");
        }
    };
    assert!(
        run_result.is_err(),
        "Route should have failed with simulated error"
    );
}
/// Batches handed to the sequencer out of order (2, 0, 1, 3) must have their
/// commit functions executed in ascending sequence order.
///
/// Each commit sleeps `10 * seq` ms before recording itself. Without ordering,
/// lower sequence numbers would tend to finish first regardless of submission
/// order, so recording in order is evidence that the sequencer serialises
/// commits by sequence number.
#[tokio::test]
async fn test_sequencer_ordered_commits() {
    use std::time::Duration;
    use tokio::time::timeout;
    let (seq_tx, sequencer_handle) = spawn_sequencer(16);
    let processed: Arc<Mutex<Vec<u64>>> = Arc::new(Mutex::new(Vec::new()));
    // Deliberately submitted out of order.
    let seqs = [2u64, 0u64, 1u64, 3u64];
    let mut receivers = Vec::with_capacity(seqs.len());
    for seq in seqs.iter().copied() {
        let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
        let processed_clone = processed.clone();
        let commit: BatchCommitFunc = Box::new(move |_dispositions| {
            let processed = processed_clone.clone();
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(10 * seq)).await;
                processed.lock().unwrap().push(seq);
                Ok(())
            })
        });
        seq_tx
            .send((seq, (Vec::new(), commit, notify_tx)))
            .await
            .unwrap();
        receivers.push(notify_rx);
    }
    for rx in receivers {
        // Outer `Err` = the sequencer dropped the oneshot sender without
        // replying; inner `Err` = the commit itself reported a failure.
        let reply = timeout(Duration::from_secs(2), rx)
            .await
            .expect("Sequencer notify timed out")
            .expect("Sequencer dropped the notify channel without replying");
        assert!(reply.is_ok(), "Commit returned an error");
    }
    drop(seq_tx);
    let _ = sequencer_handle.await;
    let result = processed.lock().unwrap().clone();
    assert_eq!(
        result,
        vec![0u64, 1u64, 2u64, 3u64],
        "Sequencer must process commits in order"
    );
}
/// Batches still waiting on an earlier sequence number must each be answered
/// with `Err`, and never executed, once the sequencer's input is closed.
#[tokio::test]
async fn test_sequencer_shutdown_notifies_pending() {
    use std::time::Duration;
    use tokio::time::timeout;

    /// Commit function that panics if the sequencer ever runs it.
    fn must_not_run() -> BatchCommitFunc {
        Box::new(|_dispositions| {
            Box::pin(async move {
                panic!("Commit should not be executed during shutdown drain");
                #[allow(unreachable_code)]
                Ok(())
            })
        })
    }

    let (seq_tx, sequencer_handle) = spawn_sequencer(8);
    let (notify_tx1, notify_rx1) = tokio::sync::oneshot::channel();
    let (notify_tx2, notify_rx2) = tokio::sync::oneshot::channel();
    // No batch with sequence 0 is ever submitted, so these two presumably stay
    // parked until the input channel closes.
    seq_tx
        .send((1u64, (Vec::new(), must_not_run(), notify_tx1)))
        .await
        .unwrap();
    seq_tx
        .send((2u64, (Vec::new(), must_not_run(), notify_tx2)))
        .await
        .unwrap();
    drop(seq_tx);

    let pending = [
        (notify_rx1, "Timeout waiting for notify_rx1"),
        (notify_rx2, "Timeout waiting for notify_rx2"),
    ];
    for (rx, timeout_msg) in pending {
        let reply = timeout(Duration::from_secs(1), rx)
            .await
            .expect(timeout_msg)
            .expect("Sequencer closed notify channel");
        assert!(
            reply.is_err(),
            "Pending commit should receive Err on sequencer shutdown"
        );
    }
    let _ = sequencer_handle.await;
}
use crate::traits::{BoxFuture, CustomEndpointFactory, Sent};
use std::sync::Mutex;
/// Shared, lockable closure that `MockEndpointFactory::create_consumer` calls
/// to build each consumer (a test can swap it out to script consumers).
type ConsumerBehavior =
    Arc<Mutex<dyn FnMut() -> Result<Box<dyn MessageConsumer>, anyhow::Error> + Send + Sync>>;
/// Shared, lockable closure that `MockEndpointFactory::create_publisher` calls
/// to build each publisher.
type PublisherBehavior =
    Arc<Mutex<dyn FnMut() -> Result<Box<dyn MessagePublisher>, anyhow::Error> + Send + Sync>>;
/// Test endpoint factory whose consumers/publishers come from replaceable closures.
struct MockEndpointFactory {
    // When true, `create_consumer` returns "Endpoint unavailable" without
    // invoking `consumer_behavior`.
    create_consumer_fail: bool,
    consumer_behavior: ConsumerBehavior,
    publisher_behavior: PublisherBehavior,
}
/// Debug output shows only `create_consumer_fail`.
///
/// The behavior closures are trait objects without a `Debug` impl. The output
/// is therefore explicitly marked non-exhaustive (`..`) rather than silently
/// looking complete.
impl std::fmt::Debug for MockEndpointFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MockEndpointFactory")
            .field("create_consumer_fail", &self.create_consumer_fail)
            .finish_non_exhaustive()
    }
}
impl MockEndpointFactory {
fn new() -> Self {
Self {
create_consumer_fail: false,
consumer_behavior: Arc::new(Mutex::new(|| Err(anyhow::anyhow!("Not implemented")))),
publisher_behavior: Arc::new(Mutex::new(|| {
Ok(Box::new(NoOpPublisher) as Box<dyn MessagePublisher>)
})),
}
}
}
/// Publisher stub that acknowledges every send without doing any work.
#[derive(Clone)]
struct NoOpPublisher;
#[async_trait::async_trait]
impl MessagePublisher for NoOpPublisher {
    async fn send_batch(
        &self,
        _batch: Vec<crate::CanonicalMessage>,
    ) -> Result<SentBatch, PublisherError> {
        Ok(SentBatch::Ack)
    }
    async fn send(&self, _message: crate::CanonicalMessage) -> Result<Sent, PublisherError> {
        Ok(Sent::Ack)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[async_trait::async_trait]
impl CustomEndpointFactory for MockEndpointFactory {
    /// Builds a consumer via `consumer_behavior`, unless `create_consumer_fail`
    /// is set, in which case it errors with "Endpoint unavailable".
    async fn create_consumer(
        &self,
        _route_name: &str,
        _config: &serde_json::Value,
    ) -> anyhow::Result<Box<dyn MessageConsumer>> {
        if !self.create_consumer_fail {
            return (self.consumer_behavior.lock().unwrap())();
        }
        Err(anyhow::anyhow!("Endpoint unavailable"))
    }
    /// Builds a publisher via `publisher_behavior`.
    async fn create_publisher(
        &self,
        _route_name: &str,
        _config: &serde_json::Value,
    ) -> anyhow::Result<Box<dyn MessagePublisher>> {
        let publisher = (self.publisher_behavior.lock().unwrap())();
        publisher
    }
}
/// Counters and failure switches shared (via `Arc`) by `HookConsumer` and
/// `HookPublisher`, so tests can observe and steer lifecycle-hook calls.
#[derive(Clone, Default)]
struct HookState {
    // Number of consumer `on_connect_hook` invocations.
    consumer_connects: Arc<AtomicUsize>,
    // Number of consumer `on_disconnect_hook` invocations.
    consumer_disconnects: Arc<AtomicUsize>,
    // Number of publisher `on_connect_hook` invocations.
    publisher_connects: Arc<AtomicUsize>,
    // Number of publisher `on_disconnect_hook` invocations.
    publisher_disconnects: Arc<AtomicUsize>,
    // Bumped by both the consumer and the publisher connect hooks.
    shared_mutations: Arc<AtomicUsize>,
    // When set, the consumer connect hook returns an error (after counting).
    fail_consumer_connect: Arc<AtomicBool>,
    // When set, the consumer disconnect hook returns an error (after counting).
    fail_consumer_disconnect: Arc<AtomicBool>,
    // When set, the publisher disconnect hook returns an error (after counting).
    fail_publisher_disconnect: Arc<AtomicBool>,
}
/// Consumer whose lifecycle hooks record into the shared `HookState`.
struct HookConsumer {
    state: HookState,
}
/// Publisher whose lifecycle hooks record into the shared `HookState`.
struct HookPublisher {
    state: HookState,
}
#[async_trait::async_trait]
impl MessageConsumer for HookConsumer {
    /// Counts the connect (also bumping `shared_mutations`), then fails when
    /// `fail_consumer_connect` is set.
    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        let state = &self.state;
        Some(Box::pin(async move {
            state.consumer_connects.fetch_add(1, Ordering::SeqCst);
            state.shared_mutations.fetch_add(1, Ordering::SeqCst);
            if state.fail_consumer_connect.load(Ordering::SeqCst) {
                Err(anyhow::anyhow!("consumer hook failed"))
            } else {
                Ok(())
            }
        }))
    }
    /// Counts the disconnect, then fails when `fail_consumer_disconnect` is set.
    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        let state = &self.state;
        Some(Box::pin(async move {
            state.consumer_disconnects.fetch_add(1, Ordering::SeqCst);
            if state.fail_consumer_disconnect.load(Ordering::SeqCst) {
                Err(anyhow::anyhow!("consumer disconnect hook failed"))
            } else {
                Ok(())
            }
        }))
    }
    /// Has no messages: every call reports end of stream.
    async fn receive_batch(&mut self, _limit: usize) -> Result<ReceivedBatch, ConsumerError> {
        Err(ConsumerError::EndOfStream)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[async_trait::async_trait]
impl MessagePublisher for HookPublisher {
    /// Counts the connect and bumps `shared_mutations`; never fails.
    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        let state = &self.state;
        Some(Box::pin(async move {
            state.publisher_connects.fetch_add(1, Ordering::SeqCst);
            state.shared_mutations.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }))
    }
    /// Counts the disconnect, then fails when `fail_publisher_disconnect` is set.
    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        let state = &self.state;
        Some(Box::pin(async move {
            state.publisher_disconnects.fetch_add(1, Ordering::SeqCst);
            if state.fail_publisher_disconnect.load(Ordering::SeqCst) {
                Err(anyhow::anyhow!("publisher disconnect hook failed"))
            } else {
                Ok(())
            }
        }))
    }
    /// Acknowledges every batch without sending anything.
    async fn send_batch(
        &self,
        _batch: Vec<crate::CanonicalMessage>,
    ) -> Result<SentBatch, PublisherError> {
        Ok(SentBatch::Ack)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Builds a route whose input and output both come from a freshly registered
/// custom endpoint factory. That factory produces `HookConsumer`/`HookPublisher`
/// instances sharing `state`.
fn hook_route(state: HookState, concurrency: usize) -> Route {
    let unique_id = fast_uuid_v7::gen_id().to_string();
    let factory_name = format!("hooks_{}", unique_id);

    let consumer_state = state.clone();
    let publisher_state = state;
    let mut factory = MockEndpointFactory::new();
    factory.consumer_behavior = Arc::new(Mutex::new(move || {
        Ok(Box::new(HookConsumer {
            state: consumer_state.clone(),
        }) as Box<dyn MessageConsumer>)
    }));
    factory.publisher_behavior = Arc::new(Mutex::new(move || {
        Ok(Box::new(HookPublisher {
            state: publisher_state.clone(),
        }) as Box<dyn MessagePublisher>)
    }));
    register_endpoint_factory(&factory_name, Arc::new(factory));

    let custom_endpoint = |name: String| Endpoint {
        endpoint_type: EndpointType::Custom {
            name,
            config: serde_json::Value::Null,
        },
        middlewares: vec![],
        handler: None,
    };
    let input = custom_endpoint(factory_name.clone());
    let output = custom_endpoint(factory_name);
    Route::new(input, output).with_concurrency(concurrency)
}
#[tokio::test]
async fn test_lifecycle_hooks_called_once_sequentially() {
let state = HookState::default();
let route = hook_route(state.clone(), 1);
let stopped_by_shutdown = route
.run_until_err("test_lifecycle_sequential", None, None)
.await
.unwrap();
assert!(!stopped_by_shutdown);
assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
assert_eq!(state.shared_mutations.load(Ordering::SeqCst), 2);
}
/// Even with concurrency 4, every lifecycle hook fires exactly once.
#[tokio::test]
async fn test_lifecycle_hooks_called_once_concurrently() {
    let state = HookState::default();
    let route = hook_route(state.clone(), 4);
    route
        .run_until_err("test_lifecycle_concurrent", None, None)
        .await
        .unwrap();
    let counters = [
        &state.consumer_connects,
        &state.consumer_disconnects,
        &state.publisher_connects,
        &state.publisher_disconnects,
    ];
    for counter in counters {
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
/// A failing consumer connect hook must abort the route with an
/// "on_connect hook failed" error. Both connect hooks are still called once.
#[tokio::test]
async fn test_lifecycle_on_connect_failure_stops_route() {
    let state = HookState::default();
    state.fail_consumer_connect.store(true, Ordering::SeqCst);
    let route = hook_route(state.clone(), 1);
    let message = route
        .run_until_err("test_lifecycle_connect_failure", None, None)
        .await
        .unwrap_err()
        .to_string();
    assert!(message.contains("on_connect hook failed"));
    for counter in [&state.publisher_connects, &state.consumer_connects] {
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
/// Failing disconnect hooks are tolerated. The route still finishes normally
/// (not by shutdown), and each disconnect hook runs once.
#[tokio::test]
async fn test_lifecycle_on_disconnect_failure_does_not_stop_route() {
    let state = HookState::default();
    for flag in [&state.fail_consumer_disconnect, &state.fail_publisher_disconnect] {
        flag.store(true, Ordering::SeqCst);
    }
    let route = hook_route(state.clone(), 1);
    let stopped_by_shutdown = route
        .run_until_err("test_lifecycle_disconnect_failure", None, None)
        .await
        .unwrap();
    assert!(!stopped_by_shutdown);
    for counter in [&state.consumer_disconnects, &state.publisher_disconnects] {
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
/// When the input factory cannot create a consumer, `Route::run` must fail
/// with an error mentioning "failed to start".
#[tokio::test]
async fn test_start_fails_on_unavailable_endpoint() {
    let unique_id = fast_uuid_v7::gen_id().to_string();
    let factory_name = format!("unavailable_{}", unique_id);
    register_endpoint_factory(
        &factory_name,
        Arc::new(MockEndpointFactory {
            create_consumer_fail: true,
            ..MockEndpointFactory::new()
        }),
    );
    let unavailable_input = Endpoint {
        endpoint_type: EndpointType::Custom {
            name: factory_name,
            config: serde_json::Value::Null,
        },
        middlewares: vec![],
        handler: None,
    };
    let route = Route::new(unavailable_input, Endpoint::new_memory("out", 10));
    let outcome = route.run("test_start_fail").await;
    assert!(outcome.is_err());
    assert!(outcome.unwrap_err().to_string().contains("failed to start"));
}
#[tokio::test]
async fn test_reconnect_on_consumer_error() {
let unique_id = fast_uuid_v7::gen_id().to_string();
let factory_name = format!("reconnect_{}", unique_id);
let connection_attempts = Arc::new(AtomicUsize::new(0));
let attempts_clone = connection_attempts.clone();
let consumer_logic = move || -> Result<Box<dyn MessageConsumer>, anyhow::Error> {
let attempt = attempts_clone.fetch_add(1, Ordering::SeqCst);
struct FlakyConsumer {
attempt: usize,
}
#[async_trait::async_trait]
impl MessageConsumer for FlakyConsumer {
async fn receive_batch(
&mut self,
_max: usize,
) -> Result<ReceivedBatch, ConsumerError> {
if self.attempt == 0 {
self.attempt = 999; Ok(ReceivedBatch {
messages: vec![crate::CanonicalMessage::from("msg1")],
commit: Box::new(|_| Box::pin(async { Ok(()) })),
})
} else if self.attempt == 999 {
Err(ConsumerError::Connection(anyhow::anyhow!(
"Connection dropped"
)))
} else {
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(ReceivedBatch {
messages: vec![crate::CanonicalMessage::from("msg2")],
commit: Box::new(|_| Box::pin(async { Ok(()) })),
})
}
}
fn as_any(&self) -> &dyn Any {
self
}
}
Ok(Box::new(FlakyConsumer { attempt }))
};
let mut factory = MockEndpointFactory::new();
factory.consumer_behavior = Arc::new(Mutex::new(consumer_logic));
register_endpoint_factory(&factory_name, Arc::new(factory));
let input = Endpoint {
endpoint_type: EndpointType::Custom {
name: factory_name,
config: serde_json::Value::Null,
},
middlewares: vec![],
handler: None,
};
let output = Endpoint::new_memory(&format!("out_{}", unique_id), 10);
let route = Route::new(input, output.clone());
route.deploy("test_reconnect").await.unwrap();
let mut verifier = create_consumer_from_route("verifier", &output)
.await
.unwrap();
let msg1 = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
.await
.expect("Timed out waiting for msg1")
.unwrap();
assert_eq!(msg1.message.get_payload_str(), "msg1");
let msg2 = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
.await
.expect("Timed out waiting for msg2")
.unwrap();
assert_eq!(msg2.message.get_payload_str(), "msg2");
assert!(connection_attempts.load(Ordering::SeqCst) >= 2);
Route::stop("test_reconnect").await;
}
/// A handler returning `NonRetryable` for one message ("poison") must not stop
/// the route: the following valid message still reaches the output.
#[tokio::test]
async fn test_non_retryable_handler_error_does_not_crash_route() {
    let unique_id = fast_uuid_v7::gen_id().to_string();
    let input = Endpoint::new_memory(&format!("bad_input_in_{}", unique_id), 10);
    let output = Endpoint::new_memory(&format!("bad_input_out_{}", unique_id), 10);
    let handler = |message: crate::CanonicalMessage| async move {
        if message.get_payload_str() != "poison" {
            Ok(crate::Handled::Publish(message))
        } else {
            Err(HandlerError::NonRetryable(anyhow::anyhow!("Invalid input")))
        }
    };
    let route = Route::new(input.clone(), output).with_handler(handler);
    route.deploy("test_invalid_input").await.unwrap();

    let input_ch = input.channel().unwrap();
    let out_channel = route.output.channel().unwrap();
    for payload in ["poison", "valid"] {
        input_ch.send_message(payload.into()).await.unwrap();
    }

    // Poll the output until something shows up; keep the newest message.
    let poll_output = async {
        loop {
            match out_channel.drain_messages().pop() {
                Some(message) => break message,
                None => tokio::time::sleep(std::time::Duration::from_millis(10)).await,
            }
        }
    };
    let received = tokio::time::timeout(std::time::Duration::from_secs(5), poll_output)
        .await
        .expect("Timed out waiting for valid message to be processed");
    assert_eq!(received.get_payload_str(), "valid");
    Route::stop("test_invalid_input").await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_dlq_and_retry_batch_integration() {
use crate::models::{DeadLetterQueueMiddleware, Middleware, RetryMiddleware};
use crate::traits::{MessagePublisher, PublisherError, SentBatch};
use std::collections::HashMap;
use std::sync::Mutex;
#[derive(Clone)]
struct PartialFailPublisher {
attempts: Arc<Mutex<HashMap<u128, usize>>>,
}
#[async_trait::async_trait]
impl MessagePublisher for PartialFailPublisher {
async fn send_batch(
&self,
messages: Vec<CanonicalMessage>,
) -> Result<SentBatch, PublisherError> {
let mut failed = Vec::new();
let mut attempts = self.attempts.lock().unwrap();
for msg in messages {
let msg_num: u32 = serde_json::from_slice::<serde_json::Value>(&msg.payload)
.unwrap()["id"]
.as_u64()
.unwrap() as u32;
let attempt_count = attempts.entry(msg.message_id).or_insert(0);
*attempt_count += 1;
if msg_num % 2 == 0 {
failed.push((
msg,
PublisherError::Retryable(anyhow::anyhow!("simulated failure")),
));
}
}
if failed.is_empty() {
Ok(SentBatch::Ack)
} else {
Ok(SentBatch::Partial {
responses: None,
failed,
})
}
}
async fn send(
&self,
_msg: CanonicalMessage,
) -> Result<crate::traits::Sent, PublisherError> {
unimplemented!()
}
fn as_any(&self) -> &dyn Any {
self
}
}
let in_topic = "batch_retry_dlq_in";
let out_topic = "batch_retry_dlq_out";
let dlq_topic = "batch_retry_dlq_dlq";
let input = Endpoint::new_memory(in_topic, 10);
let dlq_endpoint = Endpoint::new_memory(dlq_topic, 10);
let mock_publisher = PartialFailPublisher {
attempts: Arc::new(Mutex::new(HashMap::new())),
};
let mut output_with_middlewares = Endpoint::new_memory(out_topic, 10);
output_with_middlewares.middlewares = vec![
Middleware::Retry(RetryMiddleware {
max_attempts: 2,
initial_interval_ms: 1,
..Default::default()
}),
Middleware::Dlq(Box::new(DeadLetterQueueMiddleware {
endpoint: dlq_endpoint.clone(),
})),
];
let route = Route::new(input.clone(), output_with_middlewares).with_batch_size(4);
let final_publisher = crate::middleware::apply_middlewares_to_publisher(
Box::new(mock_publisher.clone()),
&route.output,
"test_route",
)
.await
.unwrap();
let (work_tx, work_rx) =
async_channel::bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(1);
let (seq_tx, _sequencer_handle) = spawn_sequencer(1);
tokio::spawn(async move {
if let Ok((messages, commit)) = work_rx.recv().await {
let batch_len = messages.len();
match final_publisher.send_batch(messages).await {
Ok(SentBatch::Ack) => {
let _ = commit(vec![MessageDisposition::Ack; batch_len]).await;
}
Ok(SentBatch::Partial { failed, .. }) => {
let dispositions = if failed.is_empty() {
vec![MessageDisposition::Ack; batch_len]
} else {
vec![MessageDisposition::Nack; batch_len]
};
let _ = commit(dispositions).await;
}
Err(_) => {
let _ = commit(vec![MessageDisposition::Nack; batch_len]).await;
}
}
}
});
let mut messages = Vec::new();
for i in 1..=4 {
messages.push(CanonicalMessage::from_json(serde_json::json!({"id": i})).unwrap());
}
let commit = wrap_commit(Box::new(|_| Box::pin(async { Ok(()) })), 0, seq_tx.clone());
work_tx.send((messages, commit)).await.unwrap();
let dlq_channel = dlq_endpoint.channel().unwrap();
let start = std::time::Instant::now();
while dlq_channel.len() < 2 {
if start.elapsed() > std::time::Duration::from_secs(5) {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
let dlq_msgs = dlq_channel.drain_messages();
assert_eq!(dlq_msgs.len(), 2, "Expected 2 messages to go to DLQ");
let dlq_ids: std::collections::HashSet<u32> = dlq_msgs
.iter()
.map(|m| {
serde_json::from_slice::<serde_json::Value>(&m.payload).unwrap()["id"]
.as_u64()
.unwrap() as u32
})
.collect();
assert!(dlq_ids.contains(&2));
assert!(dlq_ids.contains(&4));
let attempts = mock_publisher.attempts.lock().unwrap();
assert_eq!(attempts.values().filter(|&&c| c == 2).count(), 2);
assert_eq!(attempts.values().filter(|&&c| c == 1).count(), 2);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_route_dlq_integration() {
let unique_id = fast_uuid_v7::gen_id().to_string();
let in_topic = format!("dlq_in_{}", unique_id);
let out_topic = format!("dlq_out_{}", unique_id);
let dlq_topic = format!("dlq_target_{}", unique_id);
let input = Endpoint::new_memory(&in_topic, 10);
let dlq_endpoint = Endpoint::new_memory(&dlq_topic, 10);
let mut output = Endpoint::new_memory(&out_topic, 10);
output.middlewares = vec![
Middleware::RandomPanic(RandomPanicMiddleware {
mode: FaultMode::Timeout, trigger_on_message: None, enabled: true,
..Default::default()
}),
Middleware::Retry(crate::models::RetryMiddleware {
max_attempts: 2,
initial_interval_ms: 10,
max_interval_ms: 100,
multiplier: 1.0,
}),
Middleware::Dlq(Box::new(crate::models::DeadLetterQueueMiddleware {
endpoint: dlq_endpoint.clone(),
})),
];
let route = Route::new(input.clone(), output);
route.deploy("test_dlq_integration").await.unwrap();
let input_ch = input.channel().unwrap();
input_ch.send_message("fail_msg".into()).await.unwrap();
let dlq_ch = dlq_endpoint.channel().unwrap();
let received = tokio::time::timeout(std::time::Duration::from_secs(5), async {
loop {
let batch = dlq_ch.drain_messages();
if !batch.is_empty() {
return batch[0].clone();
}
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
})
.await
.expect("Timed out waiting for DLQ");
assert_eq!(received.get_payload_str(), "fail_msg");
let out_ch_target = mq_bridge::endpoints::memory::get_or_create_channel(
&mq_bridge::models::MemoryConfig::new(&out_topic, None),
);
assert!(out_ch_target.is_empty(), "Message should not reach target");
Route::stop("test_dlq_integration").await;
}
/// A 5 MiB payload must pass through a memory-to-memory route byte-for-byte.
#[tokio::test(flavor = "multi_thread")]
async fn test_large_message_handling() {
    const PAYLOAD_LEN: usize = 5 * 1024 * 1024;
    let unique_id = fast_uuid_v7::gen_id().to_string();
    let input = Endpoint::new_memory(&format!("large_in_{}", unique_id), 5);
    let output = Endpoint::new_memory(&format!("large_out_{}", unique_id), 5);
    let route = Route::new(input.clone(), output.clone());
    route.deploy("test_large_msg").await.unwrap();

    let large_payload = vec![b'x'; PAYLOAD_LEN];
    let input_ch = input.channel().unwrap();
    input_ch
        .send_message(large_payload.clone().into())
        .await
        .unwrap();

    let mut verifier = route.connect_to_output("verifier").await.unwrap();
    let delivered = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
        .await
        .expect("Timed out receiving large message")
        .unwrap();
    assert_eq!(delivered.message.payload.len(), large_payload.len());
    assert_eq!(delivered.message.payload, large_payload.as_slice());
    Route::stop("test_large_msg").await;
}
/// Runs `test_map_responses_to_dispositions_logic` (defined elsewhere in this
/// module) as a plain synchronous test.
#[test]
fn test_map_responses_to_dispositions_unit() {
    test_map_responses_to_dispositions_logic()
}
}