1use crate::endpoints::{create_consumer_from_route, create_publisher_from_route};
7use crate::errors::ProcessingError;
8pub use crate::models::Route;
9use crate::models::{Endpoint, EndpointType, RouteOptions};
10use crate::traits::{
11 BatchCommitFunc, ConsumerError, Handler, HandlerError, MessageConsumer, MessageDisposition,
12 MessagePublisher, PublisherError, SentBatch,
13};
14use async_channel::{bounded, Sender};
15use serde::de::DeserializeOwned;
16use std::collections::{BTreeMap, HashMap};
17use std::sync::{Arc, OnceLock, RwLock};
18use tokio::{
19 select,
20 task::{JoinHandle, JoinSet},
21};
22use tracing::{debug, error, info, trace, warn};
23
24pub use crate::extensions::{
26 get_endpoint_factory, get_middleware_factory, register_endpoint_factory,
27 register_middleware_factory,
28};
29
30#[derive(Debug)]
31pub struct RouteHandle((JoinHandle<()>, Sender<()>));
32
33impl RouteHandle {
34 pub async fn stop(&self) {
35 let _ = self.0 .1.send(()).await;
36 self.0 .1.close();
37 }
38
39 pub async fn join(self) -> Result<(), tokio::task::JoinError> {
40 self.0 .0.await
41 }
42}
43
44async fn run_publisher_connect_hook(
45 route_name: &str,
46 publisher: &Arc<dyn MessagePublisher>,
47) -> anyhow::Result<()> {
48 if let Some(hook) = publisher.on_connect_hook() {
49 hook.await.map_err(|err| {
50 anyhow::anyhow!(
51 "Publisher on_connect hook failed for route '{}': {}",
52 route_name,
53 err
54 )
55 })?;
56 }
57 Ok(())
58}
59
60async fn run_consumer_connect_hook(
61 route_name: &str,
62 consumer: &dyn MessageConsumer,
63) -> anyhow::Result<()> {
64 if let Some(hook) = consumer.on_connect_hook() {
65 hook.await.map_err(|err| {
66 anyhow::anyhow!(
67 "Consumer on_connect hook failed for route '{}': {}",
68 route_name,
69 err
70 )
71 })?;
72 }
73 Ok(())
74}
75
76async fn run_publisher_disconnect_hook(route_name: &str, publisher: &Arc<dyn MessagePublisher>) {
77 if let Some(hook) = publisher.on_disconnect_hook() {
78 if let Err(err) = hook.await {
79 warn!(
80 "Publisher on_disconnect hook failed for route '{}': {}",
81 route_name, err
82 );
83 }
84 }
85}
86
87async fn run_consumer_disconnect_hook(route_name: &str, consumer: &dyn MessageConsumer) {
88 if let Some(hook) = consumer.on_disconnect_hook() {
89 if let Err(err) = hook.await {
90 warn!(
91 "Consumer on_disconnect hook failed for route '{}': {}",
92 route_name, err
93 );
94 }
95 }
96}
97
98impl From<(JoinHandle<()>, Sender<()>)> for RouteHandle {
99 fn from(tuple: (JoinHandle<()>, Sender<()>)) -> Self {
100 RouteHandle(tuple)
101 }
102}
103
104struct ActiveRoute {
105 route: Route,
106 handle: RouteHandle,
107}
108
109static ROUTE_REGISTRY: OnceLock<RwLock<HashMap<String, ActiveRoute>>> = OnceLock::new();
110static ENDPOINT_REF_REGISTRY: OnceLock<RwLock<HashMap<String, Endpoint>>> = OnceLock::new();
111
112pub fn register_endpoint(name: &str, endpoint: Endpoint) {
115 let registry = ENDPOINT_REF_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
116 let mut writer = registry
117 .write()
118 .expect("Named endpoint registry lock poisoned");
119 if writer.insert(name.to_string(), endpoint).is_some() {
120 debug!("Overwriting a registered endpoint named '{}'", name);
121 }
122}
123
124pub fn get_endpoint(name: &str) -> Option<Endpoint> {
126 let registry = ENDPOINT_REF_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
127 let reader = registry
128 .read()
129 .expect("Named endpoint registry lock poisoned");
130 reader.get(name).cloned()
131}
132
133impl Route {
134 pub fn new(input: Endpoint, output: Endpoint) -> Self {
140 Self {
141 input,
142 output,
143 ..Default::default()
144 }
145 }
146
147 pub fn get(name: &str) -> Option<Self> {
149 let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
150 let map = registry.read().expect("Route registry lock poisoned");
151 map.get(name).map(|active| active.route.clone())
152 }
153
154 pub fn list() -> Vec<String> {
156 let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
157 let map = registry.read().expect("Route registry lock poisoned");
158 map.keys().cloned().collect()
159 }
160
161 pub fn is_ref(&self) -> bool {
163 matches!(self.input.endpoint_type, EndpointType::Ref(_))
164 && !matches!(self.output.endpoint_type, EndpointType::Ref(_))
165 }
166
167 pub fn register_output_endpoint(&self, name: Option<&str>) -> Result<(), anyhow::Error> {
170 match name {
171 Some(name) => {
172 register_endpoint(name, self.output.clone());
173 }
174 None => {
175 if let EndpointType::Ref(name) = &self.input.endpoint_type {
176 register_endpoint(name, self.output.clone());
177 } else {
178 return Err(anyhow::anyhow!(
179 "No name and input is not a reference endpoint"
180 ));
181 }
182 }
183 };
184 Ok(())
185 }
186
187 pub async fn deploy(&self, name: &str) -> anyhow::Result<()> {
201 Self::stop(name).await;
202
203 let handle = self.run(name).await?;
204 let active = ActiveRoute {
205 route: self.clone(),
206 handle,
207 };
208
209 let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
210 let mut map = registry.write().expect("Route registry lock poisoned");
211 map.insert(name.to_string(), active);
212 Ok(())
213 }
214
215 pub async fn stop(name: &str) -> bool {
220 let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
221 let active_opt = {
222 let mut map = registry.write().expect("Route registry lock poisoned");
223 map.remove(name)
224 };
225
226 if let Some(active) = active_opt {
227 let handle = active.handle;
229
230 let _ = handle.0 .1.send(()).await;
232 handle.0 .1.close();
233
234 let mut join_handle = handle.0 .0;
236 tokio::select! {
237 res = &mut join_handle => {
238 let _ = res;
240 }
241 _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
242 join_handle.abort();
244 let _ = join_handle.await;
246 }
247 }
248
249 true
250 } else {
251 false
252 }
253 }
254
255 pub async fn create_publisher(&self) -> anyhow::Result<crate::Publisher> {
270 crate::Publisher::new(self.output.clone()).await
271 }
272
273 pub async fn connect_to_output(
276 &self,
277 name: &str,
278 ) -> anyhow::Result<Box<dyn crate::traits::MessageConsumer>> {
279 create_consumer_from_route(name, &self.output).await
280 }
281
282 pub fn check(
288 &self,
289 name: &str,
290 allowed_endpoints: Option<&[&str]>,
291 ) -> anyhow::Result<Vec<String>> {
292 let mut warnings = Vec::new();
293 warnings.extend(crate::endpoints::check_consumer(
294 name,
295 &self.input,
296 allowed_endpoints,
297 )?);
298 warnings.extend(crate::endpoints::check_publisher(
299 name,
300 &self.output,
301 allowed_endpoints,
302 )?);
303 Ok(warnings)
304 }
305
306 pub async fn run(&self, name_str: &str) -> anyhow::Result<RouteHandle> {
334 let warnings = self.check(name_str, None)?;
335 for warning in warnings {
336 tracing::warn!(route = name_str, "Configuration warning: {}", warning);
337 }
338 let (shutdown_tx, shutdown_rx) = bounded(1);
339 let (ready_tx, ready_rx) = bounded(1);
340 let route = Arc::new(self.clone());
342 let name = Arc::new(name_str.to_string());
343
344 let handle = tokio::spawn(async move {
345 loop {
346 let route_arc = Arc::clone(&route);
347 let name_arc = Arc::clone(&name);
348 let (internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
352 let ready_tx_clone = ready_tx.clone();
353
354 let mut run_task = tokio::spawn(async move {
356 route_arc
357 .run_until_err(&name_arc, Some(internal_shutdown_rx), Some(ready_tx_clone))
358 .await
359 });
360
361 select! {
362 _ = shutdown_rx.recv() => {
363 info!("Shutdown signal received for route '{}'.", name);
364 let _ = internal_shutdown_tx.send(()).await;
366 let _ = run_task.await;
368 break;
369 }
370 res = &mut run_task => {
371 match res {
372 Ok(Ok(should_continue)) if !should_continue => {
373 info!("Route '{}' completed gracefully. Shutting down.", name);
374 break;
375 }
376 Ok(Err(e)) => {
377 let is_permanent =
378 e.downcast_ref::<ProcessingError>().is_some_and(|pe| matches!(pe, ProcessingError::NonRetryable(_)))
379 || e.downcast_ref::<ConsumerError>().is_some_and(|ce| matches!(ce, ConsumerError::EndOfStream));
380
381 if is_permanent {
382 error!("Route '{}' failed with a permanent error: {}. Shutting down.", name, e);
383 break;
384 }
385
386 warn!("Route '{}' failed: {}. Reconnecting in 5 seconds...", name, e);
387 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
388 }
389 Err(e) => {
390 error!("Route '{}' task panicked: {}. Reconnecting in 5 seconds...", name, e);
391 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
392 }
393 _ => {} }
395 }
396 }
397 }
398 });
399
400 match tokio::time::timeout(std::time::Duration::from_secs(5), ready_rx.recv()).await {
401 Ok(Ok(_)) => Ok(RouteHandle((handle, shutdown_tx))),
402 _ => {
403 handle.abort();
404 Err(anyhow::anyhow!(
405 "Route '{}' failed to start within 5 seconds or encountered an error",
406 name_str
407 ))
408 }
409 }
410 }
411
412 pub async fn run_until_err(
414 &self,
415 name: &str,
416 shutdown_rx: Option<async_channel::Receiver<()>>,
417 ready_tx: Option<Sender<()>>,
418 ) -> anyhow::Result<bool> {
419 let (_internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
420 let shutdown_rx = shutdown_rx.unwrap_or(internal_shutdown_rx);
421 if self.options.concurrency == 1 {
422 self.run_sequentially(name, shutdown_rx, ready_tx).await
423 } else {
424 self.run_concurrently(name, shutdown_rx, ready_tx).await
425 }
426 }
427
428 async fn run_sequentially(
430 &self,
431 name: &str,
432 shutdown_rx: async_channel::Receiver<()>,
433 ready_tx: Option<Sender<()>>,
434 ) -> anyhow::Result<bool> {
435 let publisher = create_publisher_from_route(name, &self.output).await?;
436 let mut consumer = create_consumer_from_route(name, &self.input).await?;
437 if let Err(err) = run_publisher_connect_hook(name, &publisher).await {
438 run_publisher_disconnect_hook(name, &publisher).await;
439 return Err(err);
440 }
441 if let Err(err) = run_consumer_connect_hook(name, consumer.as_ref()).await {
442 run_consumer_disconnect_hook(name, consumer.as_ref()).await;
443 run_publisher_disconnect_hook(name, &publisher).await;
444 return Err(err);
445 }
446 let (err_tx, err_rx) = bounded(1);
447 let mut commit_tasks = JoinSet::new();
448
449 let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
451 let mut seq_counter = 0u64;
452
453 if let Some(tx) = ready_tx {
454 let _ = tx.send(()).await;
455 }
456 let mut message_ids = Vec::with_capacity(self.options.batch_size);
457 let has_retry_middleware = self.output.has_retry_middleware();
459 let run_result = loop {
460 select! {
461 Ok(err) = err_rx.recv() => break Err(err),
462
463 _ = shutdown_rx.recv() => {
464 info!("Shutdown signal received in sequential runner for route '{}'.", name);
465 break Ok(true); }
467 res = consumer.receive_batch(self.options.batch_size) => {
468 let received_batch = match res {
469 Ok(batch) => {
470 if batch.messages.is_empty() {
471 continue; }
473 batch
474 }
475 Err(ConsumerError::EndOfStream) => {
476 info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
477 break Ok(false); }
479 Err(ConsumerError::Connection(e)) => {
480 break Err(e);
482 },
483 Err(ConsumerError::Gap { requested, base }) => {
484 break Err(anyhow::anyhow!("Consumer gap: requested offset {requested} but earliest available is {base}"));
486 }
487 };
488 debug!("Received a batch of {} messages sequentially", received_batch.messages.len());
489
490 let seq = seq_counter;
492 seq_counter += 1;
493 let mut commit_opt = Some(wrap_commit(received_batch.commit, seq, seq_tx.clone()));
494 let batch_len = received_batch.messages.len();
495 message_ids.clear();
496 message_ids.extend(received_batch.messages.iter().map(|m| m.message_id));
497 let request_ids: std::collections::HashSet<u128> = received_batch
498 .messages
499 .iter()
500 .filter(|m| m.metadata.contains_key("reply_to"))
501 .map(|m| m.message_id)
502 .collect();
503
504 match publisher.send_batch(received_batch.messages).await {
505 Ok(SentBatch::Ack) => {
506 for id in &message_ids {
507 if request_ids.contains(id) {
508 warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop broken.", id);
509 }
510 }
511 let commit = commit_opt.take().expect("Commit already used");
512 let err_tx = err_tx.clone();
513 commit_tasks.spawn(async move {
514 if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
515 error!("Commit failed: {}", e);
516 match err_tx.try_send(e) {
517 Ok(_) => trace!("Reported commit error to main task"),
518 Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
519 }
520 }
521 });
522 }
523 Ok(SentBatch::Partial { responses, failed }) => {
524 let has_transient = failed.iter().any(|(_, e)| {
528 matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_))
529 });
530 if has_transient {
531 let (_, first_err) = failed
532 .iter()
533 .find(|(_, e)| matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_)))
534 .expect("has_transient is true");
535 let err = anyhow::anyhow!(
536 "Transient error in batch send ({} messages failed). First error: {}",
537 failed.len(),
538 first_err
539 );
540 let commit = commit_opt.take().expect("Commit already used");
542 let dispositions =
543 map_responses_to_dispositions(&message_ids, responses, &failed, &request_ids);
544 if let Err(commit_err) = commit(dispositions).await {
545 warn!("Commit after transient failure also failed: {}", commit_err);
546 }
547 if !has_retry_middleware {
548 break Err(err);
549 }
550 warn!("Transient error in batch, message(s) Nack'ed for re-delivery: {}", err);
551 tokio::task::yield_now().await;
552 continue;
553 }
554 for (msg, e) in &failed {
556 error!("Dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
557 }
558 let commit = commit_opt.take().expect("Commit already used");
559 let err_tx = err_tx.clone();
560 let ids = std::mem::take(&mut message_ids);
561 let req_ids = request_ids;
562 commit_tasks.spawn(async move {
563 let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &req_ids);
564 if let Err(e) = commit(dispositions).await {
565 error!("Commit failed: {}", e);
566 match err_tx.try_send(e) {
567 Ok(_) => trace!("Reported commit error to main task"),
568 Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
569 }
570 }
571 });
572 }
573 Err(e) => {
574 warn!("Publisher error, sending {} Nacks to commit", batch_len);
576 let commit = commit_opt.take().expect("Commit already used");
577 let nack_result = commit(vec![MessageDisposition::Nack; batch_len]).await;
578 debug!("Nack commit result: {:?}", nack_result);
579 break Err(e.into());
580 }
581 }
582
583 tokio::task::yield_now().await;
584 }
585 }
586 };
587
588 drop(seq_tx);
589 loop {
591 select! {
592 res = err_rx.recv() => {
593 if let Ok(err) = res {
594 error!("Error reported during shutdown: {}", err);
595 }
596 }
597 res = commit_tasks.join_next() => {
598 if res.is_none() {
599 break;
600 }
601 }
602 }
603 }
604 drop(err_rx);
605 let _ = sequencer_handle.await;
606 run_consumer_disconnect_hook(name, consumer.as_ref()).await;
607 run_publisher_disconnect_hook(name, &publisher).await;
608 run_result
609 }
610
611 async fn run_concurrently(
613 &self,
614 name: &str,
615 shutdown_rx: async_channel::Receiver<()>,
616 ready_tx: Option<Sender<()>>,
617 ) -> anyhow::Result<bool> {
618 let publisher = create_publisher_from_route(name, &self.output).await?;
619 let mut consumer = create_consumer_from_route(name, &self.input).await?;
620 if let Err(err) = run_publisher_connect_hook(name, &publisher).await {
621 run_publisher_disconnect_hook(name, &publisher).await;
622 return Err(err);
623 }
624 if let Err(err) = run_consumer_connect_hook(name, consumer.as_ref()).await {
625 run_consumer_disconnect_hook(name, consumer.as_ref()).await;
626 run_publisher_disconnect_hook(name, &publisher).await;
627 return Err(err);
628 }
629 if let Some(tx) = ready_tx {
630 let _ = tx.send(()).await;
631 }
632 let (err_tx, err_rx) = bounded(1); let work_capacity = self
635 .options
636 .concurrency
637 .saturating_mul(self.options.batch_size);
638 let (work_tx, work_rx) =
639 bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(work_capacity);
640 let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
644
645 let batch_size = self.options.batch_size;
647 let mut join_set = JoinSet::new();
648 for i in 0..self.options.concurrency {
649 let work_rx_clone = work_rx.clone();
650 let publisher = Arc::clone(&publisher);
651 let err_tx = err_tx.clone();
652 let mut commit_tasks = JoinSet::new();
653 let has_retry_middleware = self.output.has_retry_middleware();
654 join_set.spawn(async move {
655 debug!("Starting worker {}", i);
656 let mut message_ids = Vec::with_capacity(batch_size);
657 while let Ok((messages, commit_func)) = work_rx_clone.recv().await {
658 let mut commit_opt = Some(commit_func);
659 let batch_len = messages.len();
660 message_ids.clear();
661 message_ids.extend(messages.iter().map(|m| m.message_id));
662 let request_ids: std::collections::HashSet<u128> = messages
663 .iter()
664 .filter(|m| m.metadata.contains_key("reply_to"))
665 .map(|m| m.message_id)
666 .collect();
667
668 match publisher.send_batch(messages).await {
669 Ok(SentBatch::Ack) => {
670 for id in &message_ids {
671 if request_ids.contains(id) {
672 warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop broken.", id);
673 }
674 }
675 let commit = commit_opt.take().expect("Commit already used");
676 let err_tx = err_tx.clone();
677 commit_tasks.spawn(async move {
678 if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
679 error!("Commit failed: {}", e);
680 match err_tx.try_send(e) {
681 Ok(_) => trace!("Reported commit error to main task"),
682 Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
683 }
684 }
685 });
686 }
687 Ok(SentBatch::Partial { responses, failed }) => {
688 let has_transient = failed.iter().any(|(_, e)| {
690 matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_))
691 });
692 if has_transient {
693 let (_, first_err) = failed
694 .iter()
695 .find(|(_, e)| matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_)))
696 .expect("has_transient is true");
697 let e = anyhow::anyhow!(
698 "Transient error in batch send ({} messages failed). First error: {}",
699 failed.len(),
700 first_err
701 );
702 let commit = commit_opt.take().expect("Commit already used");
703 let dispositions =
705 map_responses_to_dispositions(&message_ids, responses, &failed, &request_ids);
706 if let Err(commit_err) = commit(dispositions).await {
707 warn!("Commit after transient failure also failed: {}", commit_err);
708 }
709 if !has_retry_middleware {
710 match err_tx.try_send(e) {
711 Ok(_) => trace!("Reported error to main task"),
712 Err(err_send) => warn!(error=?err_send, "Could not send error to main task, it might be down or busy."),
713 }
714 break;
715 }
716 warn!("Transient error in batch, message(s) Nack'ed for re-delivery: {}", e);
717 tokio::task::yield_now().await;
718 continue;
719 }
720 for (msg, e) in &failed {
722 error!("Worker dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
723 }
724 let commit = commit_opt.take().expect("Commit already used");
725 let err_tx = err_tx.clone();
726 let ids = std::mem::take(&mut message_ids);
727 let req_ids = request_ids;
728 commit_tasks.spawn(async move {
729 let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &req_ids);
730 if let Err(e) = commit(dispositions).await {
731 error!("Commit failed: {}", e);
732 match err_tx.try_send(e) {
733 Ok(_) => trace!("Reported commit error to main task"),
734 Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
735 }
736 }
737 });
738 }
739 Err(e) => {
740 error!("Worker failed to send message batch: {}", e);
741 let commit = commit_opt.take().expect("Commit already used");
742 let nack_result = commit(vec![MessageDisposition::Nack; batch_len]).await;
744 debug!("Nack commit result: {:?}", nack_result);
745 match err_tx.try_send(e.into()) {
747 Ok(_) => trace!("Reported error to main task"),
748 Err(err_send) => warn!(error=?err_send, "Could not send error to main task, it might be down or busy."),
749 }
750 break;
751 }
752 }
753 }
754 while commit_tasks.join_next().await.is_some() {}
756 });
757 }
758
759 let mut seq_counter = 0u64;
760 let mut loop_error: Option<anyhow::Error> = None;
762 loop {
763 select! {
764 biased; Ok(err) = err_rx.recv() => {
767 error!("A worker reported a critical error. Shutting down route.");
768 loop_error = Some(err);
769 break;
770 }
771
772 Some(res) = join_set.join_next() => {
773 match res {
774 Ok(_) => {
775 error!("A worker task finished unexpectedly. Shutting down route.");
776 loop_error = Some(anyhow::anyhow!("Worker task finished unexpectedly"));
777 }
778 Err(e) => {
779 error!("A worker task panicked: {}. Shutting down route.", e);
780 loop_error = Some(e.into());
781 }
782 }
783 break;
784 }
785
786 _ = shutdown_rx.recv() => {
787 info!("Shutdown signal received in concurrent runner for route '{}'.", name);
788 break;
789 }
790
791 res = consumer.receive_batch(self.options.batch_size) => {
792 let (messages, commit) = match res {
793 Ok(batch) => {
794 if batch.messages.is_empty() {
795 continue; }
797 (batch.messages, batch.commit)
798 }
799 Err(ConsumerError::EndOfStream) => {
800 info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
801 break; }
803 Err(ConsumerError::Connection(e)) => {
804 loop_error = Some(e);
806 break;
807 }
808 Err(ConsumerError::Gap { requested, base }) => {
809 loop_error = Some(ConsumerError::Gap { requested, base }.into());
811 break;
812 }
813 };
814 debug!("Received a batch of {} messages concurrently", messages.len());
815
816 let seq = seq_counter;
821 let wrapped_commit = wrap_commit(commit, seq, seq_tx.clone());
822
823 match work_tx.send((messages, wrapped_commit)).await {
824 Ok(()) => {
825 seq_counter += 1;
826 }
827 Err(e) => {
828 warn!("Work channel closed, cannot process more messages concurrently. Shutting down.");
829 let (msgs_back, wrapped_commit_back) = e.into_inner();
832 let _ = (wrapped_commit_back)(vec![crate::traits::MessageDisposition::Nack; msgs_back.len()]).await;
833 break;
834 }
835 }
836
837 tokio::task::yield_now().await;
838 }
839 }
840 }
841
842 drop(work_tx);
847 while join_set.join_next().await.is_some() {}
849
850 drop(seq_tx);
852 let _ = sequencer_handle.await;
853 run_consumer_disconnect_hook(name, consumer.as_ref()).await;
854 run_publisher_disconnect_hook(name, &publisher).await;
855
856 if let Some(err) = loop_error {
857 return Err(err);
858 }
859
860 if let Ok(err) = err_rx.try_recv() {
861 return Err(err);
862 }
863
864 Ok(shutdown_rx.is_empty())
867 }
868
869 pub fn with_options(mut self, options: RouteOptions) -> Self {
870 self.options = options;
871 self
872 }
873 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
874 self.options.concurrency = concurrency.max(1);
875 self
876 }
877
878 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
879 self.options.batch_size = batch_size.max(1);
880 self
881 }
882 pub fn with_commit_concurrency_limit(mut self, limit: usize) -> Self {
883 self.options.commit_concurrency_limit = limit.max(1);
884 self
885 }
886
887 pub fn with_handler(mut self, handler: impl Handler + 'static) -> Self {
888 self.output.handler = Some(Arc::new(handler));
889 self
890 }
891
892 pub fn add_handler<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
915 where
916 T: DeserializeOwned + Send + Sync + 'static,
917 H: crate::type_handler::IntoTypedHandler<T, Args>,
918 Args: Send + Sync + 'static,
919 {
920 let handler = Arc::new(handler);
922 let wrapper = move |msg: crate::CanonicalMessage| {
923 let handler = handler.clone();
924 async move {
925 let data = msg.parse::<T>().map_err(|e| {
926 HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
927 })?;
928 let ctx = crate::MessageContext::from(msg);
929 handler.call(data, ctx).await
930 }
931 };
932 let wrapper = Arc::new(wrapper);
933
934 let prev_handler = self.output.handler.take();
935
936 let new_handler = if let Some(h) = prev_handler {
937 if let Some(extended) = h.register_handler(type_name, wrapper.clone()) {
938 extended
939 } else {
940 Arc::new(
941 crate::type_handler::TypeHandler::new()
942 .with_fallback(h)
943 .add_handler(type_name, wrapper),
944 )
945 }
946 } else {
947 Arc::new(crate::type_handler::TypeHandler::new().add_handler(type_name, wrapper))
948 };
949
950 self.output.handler = Some(new_handler);
951 self
952 }
953 pub fn add_handlers<T, H, Args>(mut self, handlers: HashMap<&str, H>) -> Self
954 where
955 T: DeserializeOwned + Send + Sync + 'static,
956 H: crate::type_handler::IntoTypedHandler<T, Args>,
957 Args: Send + Sync + 'static,
958 {
959 for (type_name, handler) in handlers {
960 self = self.add_handler(type_name, handler);
961 }
962 self
963 }
964}
965
966type SequencerItem = (
967 Vec<MessageDisposition>,
968 BatchCommitFunc,
969 tokio::sync::oneshot::Sender<anyhow::Result<()>>,
970);
971
972fn spawn_sequencer(buffer_size: usize) -> (Sender<(u64, SequencerItem)>, JoinHandle<()>) {
973 let (seq_tx, seq_rx) = bounded::<(u64, SequencerItem)>(buffer_size);
974 let sequencer_handle = tokio::spawn(async move {
975 let mut buffer: BTreeMap<u64, SequencerItem> = BTreeMap::new();
976 let mut next_seq = 0u64;
977
978 loop {
979 if let Some((dispositions, commit_func, notify)) = buffer.remove(&next_seq) {
985 let result = commit_func(dispositions).await;
986 let _ = notify.send(result);
987 next_seq += 1;
988 tokio::task::yield_now().await;
990 continue;
991 }
992
993 match seq_rx.recv().await {
995 Ok((seq, item)) => {
996 if seq < next_seq {
997 let (_, _, notify) = item;
998 trace!(
999 seq,
1000 next_seq,
1001 "Sequencer received late item (seq < next_seq)"
1002 );
1003 let _ = notify.send(Err(anyhow::anyhow!(
1004 "Sequencer received late item (seq {} < next_seq {})",
1005 seq,
1006 next_seq
1007 )));
1008 } else {
1009 buffer.insert(seq, item);
1010 }
1011 }
1012 Err(_) => {
1013 for (_, (_, _, notify)) in buffer {
1015 let _ = notify.send(Err(anyhow::anyhow!("Sequencer is shutting down")));
1016 }
1017 break;
1018 }
1019 }
1020 }
1021 });
1022 (seq_tx, sequencer_handle)
1023}
1024
1025fn wrap_commit(
1026 commit: BatchCommitFunc,
1027 seq: u64,
1028 seq_tx: Sender<(u64, SequencerItem)>,
1029) -> BatchCommitFunc {
1030 Box::new(move |dispositions| {
1031 Box::pin(async move {
1032 let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
1033 if seq_tx
1034 .send((seq, (dispositions, commit, notify_tx)))
1035 .await
1036 .is_ok()
1037 {
1038 match notify_rx.await {
1039 Ok(res) => res,
1040 Err(_) => Err(anyhow::anyhow!(
1041 "Sequencer dropped the commit channel unexpectedly"
1042 )),
1043 }
1044 } else {
1045 Err(anyhow::anyhow!(
1046 "Failed to send commit to sequencer, route is likely shutting down"
1047 ))
1048 }
1049 })
1050 })
1051}
1052
1053fn map_responses_to_dispositions(
1054 message_ids: &[u128],
1055 responses: Option<Vec<crate::CanonicalMessage>>,
1056 failed: &[(crate::CanonicalMessage, PublisherError)],
1057 request_ids: &std::collections::HashSet<u128>,
1058) -> Vec<MessageDisposition> {
1059 let mut dispositions = Vec::with_capacity(message_ids.len());
1060 let failed_ids: std::collections::HashSet<u128> =
1061 failed.iter().map(|(m, _)| m.message_id).collect();
1062
1063 let mut response_map: std::collections::HashMap<u128, crate::CanonicalMessage> = responses
1065 .unwrap_or_default()
1066 .into_iter()
1067 .map(|r| (r.message_id, r))
1068 .collect();
1069
1070 for id in message_ids {
1071 if failed_ids.contains(id) {
1072 dispositions.push(MessageDisposition::Nack);
1073 } else if let Some(resp) = response_map.remove(id) {
1074 dispositions.push(MessageDisposition::Reply(resp));
1076 } else if request_ids.contains(id) {
1077 error!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Nacking to avoid committing a lost response.", id);
1078 dispositions.push(MessageDisposition::Nack);
1079 } else {
1080 dispositions.push(MessageDisposition::Ack);
1082 }
1083 }
1084 dispositions
1085}
1086
/// Checks how `map_responses_to_dispositions` combines failures, replies and
/// reply-expecting requests into per-message dispositions.
#[test]
fn test_map_responses_to_dispositions_logic() {
    use crate::{traits::PublisherError, CanonicalMessage};
    use anyhow::anyhow;

    let ids = vec![1, 2, 3, 4];

    // Ids 1 and 4 get replies.
    let mut resp1 = CanonicalMessage::from("resp1");
    resp1.message_id = 1;
    let mut resp4 = CanonicalMessage::from("resp4");
    resp4.message_id = 4;
    let responses = Some(vec![resp1, resp4]);

    // Id 2 failed to publish.
    let mut msg2 = CanonicalMessage::from("msg2");
    msg2.message_id = 2;
    let failed = vec![(msg2, PublisherError::NonRetryable(anyhow!("failed")))];

    // Id 3 expected a reply but none came back.
    let mut request_ids = std::collections::HashSet::new();
    request_ids.insert(3);

    let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &request_ids);

    assert_eq!(dispositions.len(), 4);
    assert!(matches!(dispositions[0], MessageDisposition::Reply(_)));
    assert!(matches!(dispositions[1], MessageDisposition::Nack));
    assert!(matches!(dispositions[2], MessageDisposition::Nack));
    assert!(matches!(dispositions[3], MessageDisposition::Reply(_)));
}
1118
1119pub fn get_route(name: &str) -> Option<Route> {
1120 Route::get(name)
1121}
1122
1123pub fn list_routes() -> Vec<String> {
1124 Route::list()
1125}
1126
1127pub async fn stop_route(name: &str) -> bool {
1128 Route::stop(name).await
1129}
1130
1131#[cfg(test)]
1132mod tests {
1133 use super::*;
1134 use crate::models::{Endpoint, EndpointType, FaultMode, Middleware, RandomPanicMiddleware};
1135 use crate::traits::{
1136 CustomMiddlewareFactory, MessageConsumer, MessagePublisher, ReceivedBatch,
1137 };
1138 use crate::CanonicalMessage;
1139 use std::any::Any;
1140 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
1141 use std::sync::Arc;
1142 use std::time::Duration;
1143
1144 #[derive(Debug, Default)]
1145 struct CommitObservation {
1146 completed: Mutex<Vec<u64>>,
1147 active: std::sync::atomic::AtomicUsize,
1148 max_active: std::sync::atomic::AtomicUsize,
1149 }
1150
1151 #[derive(Debug)]
1152 struct CommitTrackingMiddlewareFactory {
1153 observation: Arc<CommitObservation>,
1154 }
1155
1156 #[derive(Debug)]
1157 struct ReorderingPublisherMiddlewareFactory;
1158
1159 struct CommitTrackingConsumer {
1160 inner: Box<dyn MessageConsumer>,
1161 observation: Arc<CommitObservation>,
1162 }
1163
1164 struct ReorderingPublisher {
1165 inner: Box<dyn MessagePublisher>,
1166 }
1167
1168 #[async_trait::async_trait]
1169 impl CustomMiddlewareFactory for CommitTrackingMiddlewareFactory {
1170 async fn apply_consumer(
1171 &self,
1172 consumer: Box<dyn MessageConsumer>,
1173 _route_name: &str,
1174 _config: &serde_json::Value,
1175 ) -> anyhow::Result<Box<dyn MessageConsumer>> {
1176 Ok(Box::new(CommitTrackingConsumer {
1177 inner: consumer,
1178 observation: Arc::clone(&self.observation),
1179 }))
1180 }
1181 }
1182
1183 #[async_trait::async_trait]
1184 impl CustomMiddlewareFactory for ReorderingPublisherMiddlewareFactory {
1185 async fn apply_publisher(
1186 &self,
1187 publisher: Box<dyn MessagePublisher>,
1188 _route_name: &str,
1189 _config: &serde_json::Value,
1190 ) -> anyhow::Result<Box<dyn MessagePublisher>> {
1191 Ok(Box::new(ReorderingPublisher { inner: publisher }))
1192 }
1193 }
1194
1195 #[async_trait::async_trait]
1196 impl MessageConsumer for CommitTrackingConsumer {
1197 async fn receive_batch(
1198 &mut self,
1199 max_messages: usize,
1200 ) -> Result<ReceivedBatch, ConsumerError> {
1201 let mut batch = self.inner.receive_batch(max_messages).await?;
1202 let seq = batch
1203 .messages
1204 .first()
1205 .and_then(|message| message.get_payload_str().parse::<u64>().ok())
1206 .expect("tracking test expects numeric payloads");
1207 let original_commit = batch.commit;
1208 let observation = Arc::clone(&self.observation);
1209 batch.commit = Box::new(move |dispositions| {
1210 let observation = Arc::clone(&observation);
1211 Box::pin(async move {
1212 let active_now = observation.active.fetch_add(1, Ordering::SeqCst) + 1;
1213 let _ = observation.max_active.fetch_update(
1214 Ordering::SeqCst,
1215 Ordering::SeqCst,
1216 |current| (active_now > current).then_some(active_now),
1217 );
1218
1219 tokio::time::sleep(Duration::from_millis(20)).await;
1220 let result = original_commit(dispositions).await;
1221 observation.completed.lock().unwrap().push(seq);
1222 observation.active.fetch_sub(1, Ordering::SeqCst);
1223 result
1224 })
1225 });
1226 Ok(batch)
1227 }
1228
1229 fn as_any(&self) -> &dyn Any {
1230 self
1231 }
1232 }
1233
1234 #[async_trait::async_trait]
1235 impl MessagePublisher for ReorderingPublisher {
1236 async fn send_batch(
1237 &self,
1238 messages: Vec<crate::CanonicalMessage>,
1239 ) -> Result<SentBatch, PublisherError> {
1240 let seq = messages
1241 .first()
1242 .and_then(|message| message.get_payload_str().parse::<u64>().ok())
1243 .expect("tracking test expects numeric payloads");
1244 let delay_ms = 10 * (6u64.saturating_sub(seq.min(6)));
1245 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
1246 self.inner.send_batch(messages).await
1247 }
1248
1249 async fn send(&self, msg: crate::CanonicalMessage) -> Result<Sent, PublisherError> {
1250 self.inner.send(msg).await
1251 }
1252
1253 async fn flush(&self) -> anyhow::Result<()> {
1254 self.inner.flush().await
1255 }
1256
1257 fn as_any(&self) -> &dyn Any {
1258 self
1259 }
1260 }
1261
1262 async fn assert_route_commits_are_ordered_and_non_overlapping(concurrency: usize) {
1263 let unique_id = fast_uuid_v7::gen_id().to_string();
1264 let tracking_name = format!("track_commit_{}", unique_id);
1265 let reorder_name = format!("reorder_publish_{}", unique_id);
1266 let in_topic = format!("ordered_commit_in_{}", unique_id);
1267 let observation = Arc::new(CommitObservation::default());
1268
1269 register_middleware_factory(
1270 &tracking_name,
1271 Arc::new(CommitTrackingMiddlewareFactory {
1272 observation: Arc::clone(&observation),
1273 }),
1274 );
1275 register_middleware_factory(
1276 &reorder_name,
1277 Arc::new(ReorderingPublisherMiddlewareFactory),
1278 );
1279
1280 let input = Endpoint::new_memory(&in_topic, 32).add_middleware(Middleware::Custom {
1281 name: tracking_name,
1282 config: serde_json::Value::Null,
1283 });
1284 let output = Endpoint::new(EndpointType::Null).add_middleware(Middleware::Custom {
1285 name: reorder_name,
1286 config: serde_json::Value::Null,
1287 });
1288
1289 let route = Route::new(input.clone(), output)
1290 .with_concurrency(concurrency)
1291 .with_batch_size(1)
1292 .with_commit_concurrency_limit(1);
1293
1294 let input_channel = input.channel().unwrap();
1295 let messages = (0..6)
1296 .map(|seq| crate::CanonicalMessage::from(seq.to_string()))
1297 .collect();
1298 input_channel.fill_messages(messages).await.unwrap();
1299 input_channel.close();
1300
1301 tokio::time::timeout(
1302 std::time::Duration::from_secs(5),
1303 route.run_until_err("ordered_commit_regression", None, None),
1304 )
1305 .await
1306 .expect("Route should not hang while draining finite input")
1307 .expect("Route should complete without commit errors");
1308 assert_eq!(
1309 *observation.completed.lock().unwrap(),
1310 vec![0, 1, 2, 3, 4, 5],
1311 "Commit execution must follow receive order",
1312 );
1313 assert_eq!(
1314 observation.max_active.load(Ordering::SeqCst),
1315 1,
1316 "Broker-facing commit functions must never overlap",
1317 );
1318 }
1319
1320 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1321 async fn test_sequential_route_commits_are_ordered_and_non_overlapping() {
1322 assert_route_commits_are_ordered_and_non_overlapping(1).await;
1323 }
1324
1325 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1326 async fn test_concurrent_route_commits_are_ordered_and_non_overlapping() {
1327 assert_route_commits_are_ordered_and_non_overlapping(4).await;
1328 }
1329
1330 async fn run_consumer_fault_test(
1332 mode: FaultMode,
1333 expected_payload: &str,
1334 route_should_restart: bool,
1335 concurrency: usize,
1336 ) {
1337 let unique_suffix = fast_uuid_v7::gen_id().to_string();
1338 let in_topic = format!("fault_in_{}_{}_{}", mode, concurrency, unique_suffix);
1339 let out_topic = format!("fault_out_{}_{}_{}", mode, concurrency, unique_suffix);
1340
1341 let fault_config = RandomPanicMiddleware {
1342 mode,
1343 trigger_on_message: Some(1), enabled: true,
1345 ..Default::default()
1346 };
1347
1348 let input = Endpoint::new_memory(&in_topic, 10)
1349 .add_middleware(Middleware::RandomPanic(fault_config));
1350 let output = Endpoint::new_memory(&out_topic, 10);
1351
1352 let route_name = format!("fault_test_{}_{}", mode, concurrency);
1353 let route = Route::new(input.clone(), output.clone()).with_concurrency(concurrency);
1354
1355 route
1357 .deploy(&route_name)
1358 .await
1359 .expect("Failed to deploy route");
1360 let input_ch = input.channel().unwrap();
1362 input_ch
1363 .send_message("persistent_msg".into())
1364 .await
1365 .unwrap();
1366
1367 if route_should_restart {
1368 tokio::time::sleep(std::time::Duration::from_secs(6)).await;
1371 } else {
1372 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1374 }
1375
1376 let mut verifier = route.connect_to_output("verifier").await.unwrap();
1378 let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
1379 .await
1380 .expect("Timed out waiting for message after fault")
1381 .expect("Stream closed while waiting for message");
1382
1383 assert_eq!(received.message.get_payload_str(), expected_payload);
1384 (received.commit)(MessageDisposition::Ack).await.unwrap();
1385
1386 Route::stop(&route_name).await;
1388 }
1389
1390 async fn run_publisher_fault_test(
1392 mode: FaultMode,
1393 expected_payload: &str,
1394 route_should_restart: bool,
1395 ) {
1396 let unique_suffix = fast_uuid_v7::gen_id().to_string();
1397 let in_topic = format!("pub_fault_in_{}_{}", mode, unique_suffix);
1398 let out_topic = format!("pub_fault_out_{}_{}", mode, unique_suffix);
1399
1400 let fault_config = RandomPanicMiddleware {
1401 mode,
1402 trigger_on_message: Some(1), enabled: true,
1404 ..Default::default()
1405 };
1406
1407 let mut input = Endpoint::new_memory(&in_topic, 10);
1408 if let EndpointType::Memory(ref mut cfg) = input.endpoint_type {
1410 cfg.enable_nack = true;
1411 }
1412 let output = Endpoint::new_memory(&out_topic, 10)
1414 .add_middleware(Middleware::RandomPanic(fault_config));
1415
1416 let route_name = format!("pub_fault_test_{}", mode);
1417 let route = Route::new(input.clone(), output.clone());
1418
1419 route
1420 .deploy(&route_name)
1421 .await
1422 .expect("Failed to deploy route");
1423
1424 let input_ch = input.channel().unwrap();
1425 input_ch
1426 .send_message(expected_payload.into())
1427 .await
1428 .unwrap();
1429
1430 if route_should_restart {
1431 tokio::time::sleep(std::time::Duration::from_secs(6)).await;
1432 } else {
1433 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1434 }
1435
1436 let mut verifier = route.connect_to_output("verifier").await.unwrap();
1437 let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
1438 .await
1439 .expect("Timed out waiting for message after publisher fault")
1440 .expect("Stream closed");
1441
1442 assert_eq!(received.message.get_payload_str(), expected_payload);
1443 (received.commit)(MessageDisposition::Ack).await.unwrap();
1444
1445 Route::stop(&route_name).await;
1446 }
1447
1448 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1449 #[ignore = "Takes too much time for regular tests"]
1450 async fn test_route_recovery_from_faults() {
1451 let original_payload = "persistent_msg";
1452
1453 run_consumer_fault_test(FaultMode::Panic, original_payload, true, 2).await;
1455 run_consumer_fault_test(FaultMode::Disconnect, original_payload, true, 2).await;
1456 run_consumer_fault_test(FaultMode::Timeout, original_payload, true, 2).await;
1457 run_consumer_fault_test(FaultMode::Nack, original_payload, true, 2).await;
1458
1459 run_consumer_fault_test(FaultMode::JsonFormatError, "{invalid json}", false, 2).await;
1461 }
1462
1463 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1464 #[ignore = "Takes too much time for regular tests"]
1465 async fn test_route_recovery_from_faults_sequential() {
1466 let original_payload = "persistent_msg";
1467
1468 run_consumer_fault_test(FaultMode::Panic, original_payload, true, 1).await;
1470 run_consumer_fault_test(FaultMode::Disconnect, original_payload, true, 1).await;
1471 }
1472
1473 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1474 #[ignore = "Takes too much time for regular tests"]
1475 async fn test_publisher_recovery_from_faults() {
1476 let original_payload = "persistent_msg";
1477 run_publisher_fault_test(FaultMode::Disconnect, original_payload, true).await;
1482 run_publisher_fault_test(FaultMode::Timeout, original_payload, true).await;
1483 }
1484
1485 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1486 async fn test_route_sequencer_deadlock_fix() {
1487 let unique_id = fast_uuid_v7::gen_id().to_string();
1492 let factory_name = format!("fail_factory_{}", unique_id);
1493 let in_topic = format!("deadlock_in_{}", unique_id);
1494 let out_topic = format!("deadlock_out_{}", unique_id);
1495
1496 #[derive(Debug)]
1497 struct FailingMiddlewareFactory {
1498 fail_flag: Arc<AtomicBool>,
1499 }
1500
1501 #[async_trait::async_trait]
1502 impl CustomMiddlewareFactory for FailingMiddlewareFactory {
1503 async fn apply_publisher(
1504 &self,
1505 publisher: Box<dyn MessagePublisher>,
1506 _route_name: &str,
1507 _config: &serde_json::Value,
1508 ) -> anyhow::Result<Box<dyn MessagePublisher>> {
1509 Ok(Box::new(FailingPublisher {
1510 inner: publisher,
1511 fail_flag: self.fail_flag.clone(),
1512 }))
1513 }
1514 async fn apply_consumer(
1515 &self,
1516 consumer: Box<dyn MessageConsumer>,
1517 _route_name: &str,
1518 _config: &serde_json::Value,
1519 ) -> anyhow::Result<Box<dyn MessageConsumer>> {
1520 Ok(consumer)
1521 }
1522 }
1523
1524 struct FailingPublisher {
1525 inner: Box<dyn MessagePublisher>,
1526 fail_flag: Arc<AtomicBool>,
1527 }
1528
1529 #[async_trait::async_trait]
1530 impl MessagePublisher for FailingPublisher {
1531 async fn send_batch(
1532 &self,
1533 messages: Vec<crate::CanonicalMessage>,
1534 ) -> Result<SentBatch, PublisherError> {
1535 if self
1538 .fail_flag
1539 .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
1540 .is_ok()
1541 {
1542 return Err(PublisherError::Retryable(anyhow::anyhow!(
1543 "Simulated failure"
1544 )));
1545 }
1546 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1549 self.inner.send_batch(messages).await
1550 }
1551 async fn send(
1552 &self,
1553 msg: crate::CanonicalMessage,
1554 ) -> Result<crate::traits::Sent, PublisherError> {
1555 self.inner.send(msg).await
1556 }
1557 async fn flush(&self) -> anyhow::Result<()> {
1558 self.inner.flush().await
1559 }
1560 fn as_any(&self) -> &dyn Any {
1561 self
1562 }
1563 }
1564
1565 let fail_flag = Arc::new(AtomicBool::new(true));
1566 register_middleware_factory(
1567 &factory_name,
1568 Arc::new(FailingMiddlewareFactory {
1569 fail_flag: fail_flag.clone(),
1570 }),
1571 );
1572
1573 let input = Endpoint::new_memory(&in_topic, 100);
1574 let output = Endpoint::new_memory(&out_topic, 100).add_middleware(Middleware::Custom {
1575 name: factory_name,
1576 config: serde_json::Value::Null,
1577 });
1578
1579 let route = Route::new(input.clone(), output.clone())
1581 .with_concurrency(2)
1582 .with_batch_size(1);
1583
1584 let input_ch = input.channel().unwrap();
1586 input_ch.send_message("msg1".into()).await.unwrap();
1587 input_ch.send_message("msg2".into()).await.unwrap();
1588 input_ch.send_message("msg3".into()).await.unwrap();
1589
1590 let run_fut = async {
1593 let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1);
1594 route
1595 .run_until_err("deadlock_test", Some(shutdown_rx), None)
1596 .await
1597 };
1598
1599 let result = tokio::time::timeout(std::time::Duration::from_secs(5), run_fut).await;
1601
1602 match result {
1603 Ok(res) => {
1604 assert!(
1606 res.is_err(),
1607 "Route should have failed with simulated error"
1608 );
1609 }
1610 Err(_) => {
1611 panic!("Route deadlocked! The sequencer likely didn't receive the Nack for the failed batch.");
1612 }
1613 }
1614 }
1615
1616 #[tokio::test]
1617 async fn test_sequencer_ordered_commits() {
1618 use std::time::Duration;
1619 use tokio::time::timeout;
1620
1621 let (seq_tx, sequencer_handle) = spawn_sequencer(16);
1622 let processed: Arc<Mutex<Vec<u64>>> = Arc::new(Mutex::new(Vec::new()));
1623
1624 let seqs = [2u64, 0u64, 1u64, 3u64];
1626 let mut receivers = Vec::new();
1627
1628 for seq in seqs.iter().cloned() {
1629 let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
1630 let processed_clone = processed.clone();
1631 let commit: BatchCommitFunc = Box::new(move |_dispositions| {
1632 let processed = processed_clone.clone();
1633 Box::pin(async move {
1634 tokio::time::sleep(Duration::from_millis(10 * seq)).await;
1636 processed.lock().unwrap().push(seq);
1637 Ok(())
1638 })
1639 });
1640 seq_tx
1641 .send((seq, (Vec::new(), commit, notify_tx)))
1642 .await
1643 .unwrap();
1644 receivers.push(notify_rx);
1645 }
1646
1647 for rx in receivers {
1649 let res = timeout(Duration::from_secs(2), rx)
1650 .await
1651 .expect("Sequencer notify timed out");
1652 assert!(res.is_ok(), "Sequencer reported an error on commit");
1653 assert!(res.unwrap().is_ok(), "Commit returned an error");
1654 }
1655
1656 drop(seq_tx);
1658 let _ = sequencer_handle.await;
1659
1660 let result = processed.lock().unwrap().clone();
1661 assert_eq!(
1662 result,
1663 vec![0u64, 1u64, 2u64, 3u64],
1664 "Sequencer must process commits in order"
1665 );
1666 }
1667
1668 #[tokio::test]
1669 async fn test_sequencer_shutdown_notifies_pending() {
1670 use std::time::Duration;
1671 use tokio::time::timeout;
1672
1673 let (seq_tx, sequencer_handle) = spawn_sequencer(8);
1674
1675 let (notify_tx1, notify_rx1) = tokio::sync::oneshot::channel();
1677 let (notify_tx2, notify_rx2) = tokio::sync::oneshot::channel();
1678
1679 let commit1: BatchCommitFunc = Box::new(|_dispositions| {
1680 Box::pin(async move {
1681 panic!("Commit should not be executed during shutdown drain");
1683 #[allow(unreachable_code)]
1684 Ok(())
1685 })
1686 });
1687
1688 let commit2: BatchCommitFunc = Box::new(|_dispositions| {
1689 Box::pin(async move {
1690 panic!("Commit should not be executed during shutdown drain");
1691 #[allow(unreachable_code)]
1692 Ok(())
1693 })
1694 });
1695
1696 seq_tx
1697 .send((1u64, (Vec::new(), commit1, notify_tx1)))
1698 .await
1699 .unwrap();
1700 seq_tx
1701 .send((2u64, (Vec::new(), commit2, notify_tx2)))
1702 .await
1703 .unwrap();
1704
1705 drop(seq_tx);
1707
1708 let r1 = timeout(Duration::from_secs(1), notify_rx1)
1710 .await
1711 .expect("Timeout waiting for notify_rx1")
1712 .expect("Sequencer closed notify channel");
1713 assert!(
1714 r1.is_err(),
1715 "Pending commit should receive Err on sequencer shutdown"
1716 );
1717
1718 let r2 = timeout(Duration::from_secs(1), notify_rx2)
1719 .await
1720 .expect("Timeout waiting for notify_rx2")
1721 .expect("Sequencer closed notify channel");
1722 assert!(
1723 r2.is_err(),
1724 "Pending commit should receive Err on sequencer shutdown"
1725 );
1726
1727 let _ = sequencer_handle.await;
1728 }
1729
1730 use crate::traits::{BoxFuture, CustomEndpointFactory, Sent};
1731 use std::sync::Mutex;
1732
1733 type ConsumerBehavior =
1734 Arc<Mutex<dyn FnMut() -> Result<Box<dyn MessageConsumer>, anyhow::Error> + Send + Sync>>;
1735 type PublisherBehavior =
1736 Arc<Mutex<dyn FnMut() -> Result<Box<dyn MessagePublisher>, anyhow::Error> + Send + Sync>>;
1737
1738 struct MockEndpointFactory {
1739 create_consumer_fail: bool,
1740 consumer_behavior: ConsumerBehavior,
1741 publisher_behavior: PublisherBehavior,
1742 }
1743
1744 impl std::fmt::Debug for MockEndpointFactory {
1745 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1746 f.debug_struct("MockEndpointFactory")
1747 .field("create_consumer_fail", &self.create_consumer_fail)
1748 .finish()
1749 }
1750 }
1751
1752 impl MockEndpointFactory {
1753 fn new() -> Self {
1754 Self {
1755 create_consumer_fail: false,
1756 consumer_behavior: Arc::new(Mutex::new(|| Err(anyhow::anyhow!("Not implemented")))),
1757 publisher_behavior: Arc::new(Mutex::new(|| {
1758 Ok(Box::new(NoOpPublisher) as Box<dyn MessagePublisher>)
1759 })),
1760 }
1761 }
1762 }
1763
1764 #[derive(Clone)]
1765 struct NoOpPublisher;
1766 #[async_trait::async_trait]
1767 impl MessagePublisher for NoOpPublisher {
1768 async fn send_batch(
1769 &self,
1770 _: Vec<crate::CanonicalMessage>,
1771 ) -> Result<SentBatch, PublisherError> {
1772 Ok(SentBatch::Ack)
1773 }
1774 async fn send(&self, _: crate::CanonicalMessage) -> Result<Sent, PublisherError> {
1775 Ok(Sent::Ack)
1776 }
1777 fn as_any(&self) -> &dyn Any {
1778 self
1779 }
1780 }
1781
1782 #[async_trait::async_trait]
1783 impl CustomEndpointFactory for MockEndpointFactory {
1784 async fn create_consumer(
1785 &self,
1786 _: &str,
1787 _: &serde_json::Value,
1788 ) -> anyhow::Result<Box<dyn MessageConsumer>> {
1789 if self.create_consumer_fail {
1790 return Err(anyhow::anyhow!("Endpoint unavailable"));
1791 }
1792 (self.consumer_behavior.lock().unwrap())()
1793 }
1794 async fn create_publisher(
1795 &self,
1796 _: &str,
1797 _: &serde_json::Value,
1798 ) -> anyhow::Result<Box<dyn MessagePublisher>> {
1799 (self.publisher_behavior.lock().unwrap())()
1800 }
1801 }
1802
/// Counters and failure switches shared, through `Arc`s, by every
/// `HookConsumer` and `HookPublisher` created for one test.
#[derive(Clone, Default)]
struct HookState {
    /// Incremented by the consumer's connect hook.
    consumer_connects: Arc<AtomicUsize>,
    /// Incremented by the consumer's disconnect hook.
    consumer_disconnects: Arc<AtomicUsize>,
    /// Incremented by the publisher's connect hook.
    publisher_connects: Arc<AtomicUsize>,
    /// Incremented by the publisher's disconnect hook.
    publisher_disconnects: Arc<AtomicUsize>,
    /// Incremented by both connect hooks.
    shared_mutations: Arc<AtomicUsize>,
    /// When `true`, the consumer connect hook returns an error.
    fail_consumer_connect: Arc<AtomicBool>,
    /// When `true`, the consumer disconnect hook returns an error.
    fail_consumer_disconnect: Arc<AtomicBool>,
    /// When `true`, the publisher disconnect hook returns an error.
    fail_publisher_disconnect: Arc<AtomicBool>,
}

/// Consumer that records lifecycle hook calls in its `HookState`.
struct HookConsumer {
    state: HookState,
}

/// Publisher that records lifecycle hook calls in its `HookState`.
struct HookPublisher {
    state: HookState,
}
1822
1823 #[async_trait::async_trait]
1824 impl MessageConsumer for HookConsumer {
1825 fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1826 Some(Box::pin(async move {
1827 self.state.consumer_connects.fetch_add(1, Ordering::SeqCst);
1828 self.state.shared_mutations.fetch_add(1, Ordering::SeqCst);
1829 if self.state.fail_consumer_connect.load(Ordering::SeqCst) {
1830 return Err(anyhow::anyhow!("consumer hook failed"));
1831 }
1832 Ok(())
1833 }))
1834 }
1835
1836 fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1837 Some(Box::pin(async move {
1838 self.state
1839 .consumer_disconnects
1840 .fetch_add(1, Ordering::SeqCst);
1841 if self.state.fail_consumer_disconnect.load(Ordering::SeqCst) {
1842 return Err(anyhow::anyhow!("consumer disconnect hook failed"));
1843 }
1844 Ok(())
1845 }))
1846 }
1847
1848 async fn receive_batch(&mut self, _max: usize) -> Result<ReceivedBatch, ConsumerError> {
1849 Err(ConsumerError::EndOfStream)
1850 }
1851
1852 fn as_any(&self) -> &dyn Any {
1853 self
1854 }
1855 }
1856
1857 #[async_trait::async_trait]
1858 impl MessagePublisher for HookPublisher {
1859 fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1860 Some(Box::pin(async move {
1861 self.state.publisher_connects.fetch_add(1, Ordering::SeqCst);
1862 self.state.shared_mutations.fetch_add(1, Ordering::SeqCst);
1863 Ok(())
1864 }))
1865 }
1866
1867 fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1868 Some(Box::pin(async move {
1869 self.state
1870 .publisher_disconnects
1871 .fetch_add(1, Ordering::SeqCst);
1872 if self.state.fail_publisher_disconnect.load(Ordering::SeqCst) {
1873 return Err(anyhow::anyhow!("publisher disconnect hook failed"));
1874 }
1875 Ok(())
1876 }))
1877 }
1878
1879 async fn send_batch(
1880 &self,
1881 _: Vec<crate::CanonicalMessage>,
1882 ) -> Result<SentBatch, PublisherError> {
1883 Ok(SentBatch::Ack)
1884 }
1885
1886 fn as_any(&self) -> &dyn Any {
1887 self
1888 }
1889 }
1890
1891 fn hook_route(state: HookState, concurrency: usize) -> Route {
1892 let unique_id = fast_uuid_v7::gen_id().to_string();
1893 let factory_name = format!("hooks_{}", unique_id);
1894 let mut factory = MockEndpointFactory::new();
1895
1896 let consumer_state = state.clone();
1897 factory.consumer_behavior = Arc::new(Mutex::new(move || {
1898 Ok(Box::new(HookConsumer {
1899 state: consumer_state.clone(),
1900 }) as Box<dyn MessageConsumer>)
1901 }));
1902
1903 let publisher_state = state;
1904 factory.publisher_behavior = Arc::new(Mutex::new(move || {
1905 Ok(Box::new(HookPublisher {
1906 state: publisher_state.clone(),
1907 }) as Box<dyn MessagePublisher>)
1908 }));
1909
1910 register_endpoint_factory(&factory_name, Arc::new(factory));
1911
1912 let input = Endpoint {
1913 endpoint_type: EndpointType::Custom {
1914 name: factory_name.clone(),
1915 config: serde_json::Value::Null,
1916 },
1917 middlewares: vec![],
1918 handler: None,
1919 };
1920 let output = Endpoint {
1921 endpoint_type: EndpointType::Custom {
1922 name: factory_name,
1923 config: serde_json::Value::Null,
1924 },
1925 middlewares: vec![],
1926 handler: None,
1927 };
1928 Route::new(input, output).with_concurrency(concurrency)
1929 }
1930
1931 #[tokio::test]
1932 async fn test_lifecycle_hooks_called_once_sequentially() {
1933 let state = HookState::default();
1934 let route = hook_route(state.clone(), 1);
1935
1936 let stopped_by_shutdown = route
1937 .run_until_err("test_lifecycle_sequential", None, None)
1938 .await
1939 .unwrap();
1940
1941 assert!(!stopped_by_shutdown);
1942 assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
1943 assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
1944 assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
1945 assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
1946 assert_eq!(state.shared_mutations.load(Ordering::SeqCst), 2);
1947 }
1948
1949 #[tokio::test]
1950 async fn test_lifecycle_hooks_called_once_concurrently() {
1951 let state = HookState::default();
1952 let route = hook_route(state.clone(), 4);
1953
1954 route
1955 .run_until_err("test_lifecycle_concurrent", None, None)
1956 .await
1957 .unwrap();
1958
1959 assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
1960 assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
1961 assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
1962 assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
1963 }
1964
1965 #[tokio::test]
1966 async fn test_lifecycle_on_connect_failure_stops_route() {
1967 let state = HookState::default();
1968 state.fail_consumer_connect.store(true, Ordering::SeqCst);
1969 let route = hook_route(state.clone(), 1);
1970
1971 let err = route
1972 .run_until_err("test_lifecycle_connect_failure", None, None)
1973 .await
1974 .unwrap_err();
1975
1976 assert!(err.to_string().contains("on_connect hook failed"));
1977 assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
1978 assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
1979 }
1980
1981 #[tokio::test]
1982 async fn test_lifecycle_on_disconnect_failure_does_not_stop_route() {
1983 let state = HookState::default();
1984 state.fail_consumer_disconnect.store(true, Ordering::SeqCst);
1985 state
1986 .fail_publisher_disconnect
1987 .store(true, Ordering::SeqCst);
1988 let route = hook_route(state.clone(), 1);
1989
1990 let stopped_by_shutdown = route
1991 .run_until_err("test_lifecycle_disconnect_failure", None, None)
1992 .await
1993 .unwrap();
1994
1995 assert!(!stopped_by_shutdown);
1996 assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
1997 assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
1998 }
1999
2000 #[tokio::test]
2001 async fn test_start_fails_on_unavailable_endpoint() {
2002 let unique_id = fast_uuid_v7::gen_id().to_string();
2004 let factory_name = format!("unavailable_{}", unique_id);
2005
2006 let factory = Arc::new(MockEndpointFactory {
2007 create_consumer_fail: true,
2008 ..MockEndpointFactory::new()
2009 });
2010 register_endpoint_factory(&factory_name, factory);
2011
2012 let input = Endpoint {
2013 endpoint_type: EndpointType::Custom {
2014 name: factory_name,
2015 config: serde_json::Value::Null,
2016 },
2017 middlewares: vec![],
2018 handler: None,
2019 };
2020 let output = Endpoint::new_memory("out", 10);
2021 let route = Route::new(input, output);
2022
2023 let result = route.run("test_start_fail").await;
2026 assert!(result.is_err());
2027 assert!(result.unwrap_err().to_string().contains("failed to start"));
2028 }
2029
2030 #[tokio::test]
2031 async fn test_reconnect_on_consumer_error() {
2032 let unique_id = fast_uuid_v7::gen_id().to_string();
2034 let factory_name = format!("reconnect_{}", unique_id);
2035
2036 let connection_attempts = Arc::new(AtomicUsize::new(0));
2038 let attempts_clone = connection_attempts.clone();
2039
2040 let consumer_logic = move || -> Result<Box<dyn MessageConsumer>, anyhow::Error> {
2041 let attempt = attempts_clone.fetch_add(1, Ordering::SeqCst);
2042
2043 struct FlakyConsumer {
2044 attempt: usize,
2045 }
2046 #[async_trait::async_trait]
2047 impl MessageConsumer for FlakyConsumer {
2048 async fn receive_batch(
2049 &mut self,
2050 _max: usize,
2051 ) -> Result<ReceivedBatch, ConsumerError> {
2052 if self.attempt == 0 {
2053 self.attempt = 999; Ok(ReceivedBatch {
2056 messages: vec![crate::CanonicalMessage::from("msg1")],
2057 commit: Box::new(|_| Box::pin(async { Ok(()) })),
2058 })
2059 } else if self.attempt == 999 {
2060 Err(ConsumerError::Connection(anyhow::anyhow!(
2062 "Connection dropped"
2063 )))
2064 } else {
2065 tokio::time::sleep(Duration::from_millis(100)).await;
2068 Ok(ReceivedBatch {
2069 messages: vec![crate::CanonicalMessage::from("msg2")],
2070 commit: Box::new(|_| Box::pin(async { Ok(()) })),
2071 })
2072 }
2073 }
2074 fn as_any(&self) -> &dyn Any {
2075 self
2076 }
2077 }
2078 Ok(Box::new(FlakyConsumer { attempt }))
2079 };
2080
2081 let mut factory = MockEndpointFactory::new();
2082 factory.consumer_behavior = Arc::new(Mutex::new(consumer_logic));
2083 register_endpoint_factory(&factory_name, Arc::new(factory));
2084
2085 let input = Endpoint {
2086 endpoint_type: EndpointType::Custom {
2087 name: factory_name,
2088 config: serde_json::Value::Null,
2089 },
2090 middlewares: vec![],
2091 handler: None,
2092 };
2093 let output = Endpoint::new_memory(&format!("out_{}", unique_id), 10);
2094 let route = Route::new(input, output.clone());
2095
2096 route.deploy("test_reconnect").await.unwrap();
2097
2098 let mut verifier = create_consumer_from_route("verifier", &output)
2100 .await
2101 .unwrap();
2102
2103 let msg1 = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
2105 .await
2106 .expect("Timed out waiting for msg1")
2107 .unwrap();
2108 assert_eq!(msg1.message.get_payload_str(), "msg1");
2109
2110 let msg2 = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
2113 .await
2114 .expect("Timed out waiting for msg2")
2115 .unwrap();
2116 assert_eq!(msg2.message.get_payload_str(), "msg2");
2117
2118 assert!(connection_attempts.load(Ordering::SeqCst) >= 2);
2119 Route::stop("test_reconnect").await;
2120 }
2121
2122 #[tokio::test]
2123 async fn test_non_retryable_handler_error_does_not_crash_route() {
2124 let unique_id = fast_uuid_v7::gen_id().to_string();
2125 let in_topic = format!("bad_input_in_{}", unique_id);
2126 let out_topic = format!("bad_input_out_{}", unique_id); let input = Endpoint::new_memory(&in_topic, 10);
2129 let output = Endpoint::new_memory(&out_topic, 10);
2130
2131 let handler = |msg: crate::CanonicalMessage| async move {
2133 if msg.get_payload_str() == "poison" {
2134 Err(HandlerError::NonRetryable(anyhow::anyhow!("Invalid input")))
2135 } else {
2136 Ok(crate::Handled::Publish(msg))
2137 }
2138 };
2139
2140 let route = Route::new(input.clone(), output).with_handler(handler);
2141 route.deploy("test_invalid_input").await.unwrap();
2142
2143 let input_ch = input.channel().unwrap();
2144 let out_channel = route.output.channel().unwrap();
2145
2146 input_ch.send_message("poison".into()).await.unwrap();
2148
2149 input_ch.send_message("valid".into()).await.unwrap();
2151
2152 let received = tokio::time::timeout(std::time::Duration::from_secs(5), async {
2154 loop {
2155 if let Some(msg) = out_channel.drain_messages().pop() {
2156 return msg;
2157 }
2158 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
2159 }
2160 })
2161 .await
2162 .expect("Timed out waiting for valid message to be processed");
2163 assert_eq!(received.get_payload_str(), "valid");
2164 Route::stop("test_invalid_input").await;
2165 }
2166
2167 #[tokio::test(flavor = "multi_thread")]
2168 async fn test_dlq_and_retry_batch_integration() {
2169 use crate::models::{DeadLetterQueueMiddleware, Middleware, RetryMiddleware};
2170 use crate::traits::{MessagePublisher, PublisherError, SentBatch};
2171 use std::collections::HashMap;
2172 use std::sync::Mutex;
2173
2174 #[derive(Clone)]
2176 struct PartialFailPublisher {
2177 attempts: Arc<Mutex<HashMap<u128, usize>>>,
2178 }
2179
2180 #[async_trait::async_trait]
2181 impl MessagePublisher for PartialFailPublisher {
2182 async fn send_batch(
2183 &self,
2184 messages: Vec<CanonicalMessage>,
2185 ) -> Result<SentBatch, PublisherError> {
2186 let mut failed = Vec::new();
2187 let mut attempts = self.attempts.lock().unwrap();
2188
2189 for msg in messages {
2190 let msg_num: u32 = serde_json::from_slice::<serde_json::Value>(&msg.payload)
2191 .unwrap()["id"]
2192 .as_u64()
2193 .unwrap() as u32;
2194
2195 let attempt_count = attempts.entry(msg.message_id).or_insert(0);
2196 *attempt_count += 1;
2197
2198 if msg_num % 2 == 0 {
2199 failed.push((
2201 msg,
2202 PublisherError::Retryable(anyhow::anyhow!("simulated failure")),
2203 ));
2204 }
2205 }
2207
2208 if failed.is_empty() {
2209 Ok(SentBatch::Ack)
2210 } else {
2211 Ok(SentBatch::Partial {
2212 responses: None,
2213 failed,
2214 })
2215 }
2216 }
2217 async fn send(
2218 &self,
2219 _msg: CanonicalMessage,
2220 ) -> Result<crate::traits::Sent, PublisherError> {
2221 unimplemented!()
2222 }
2223 fn as_any(&self) -> &dyn Any {
2224 self
2225 }
2226 }
2227
2228 let in_topic = "batch_retry_dlq_in";
2230 let out_topic = "batch_retry_dlq_out";
2231 let dlq_topic = "batch_retry_dlq_dlq";
2232
2233 let input = Endpoint::new_memory(in_topic, 10);
2234 let dlq_endpoint = Endpoint::new_memory(dlq_topic, 10);
2235
2236 let mock_publisher = PartialFailPublisher {
2237 attempts: Arc::new(Mutex::new(HashMap::new())),
2238 };
2239
2240 let mut output_with_middlewares = Endpoint::new_memory(out_topic, 10);
2241 output_with_middlewares.middlewares = vec![
2242 Middleware::Retry(RetryMiddleware {
2243 max_attempts: 2,
2244 initial_interval_ms: 1,
2245 ..Default::default()
2246 }),
2247 Middleware::Dlq(Box::new(DeadLetterQueueMiddleware {
2248 endpoint: dlq_endpoint.clone(),
2249 })),
2250 ];
2251
2252 let route = Route::new(input.clone(), output_with_middlewares).with_batch_size(4);
2253 let final_publisher = crate::middleware::apply_middlewares_to_publisher(
2255 Box::new(mock_publisher.clone()),
2256 &route.output,
2257 "test_route",
2258 )
2259 .await
2260 .unwrap();
2261
2262 let (work_tx, work_rx) =
2265 async_channel::bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(1);
2266 let (seq_tx, _sequencer_handle) = spawn_sequencer(1);
2267
2268 tokio::spawn(async move {
2270 if let Ok((messages, commit)) = work_rx.recv().await {
2271 let batch_len = messages.len();
2272 match final_publisher.send_batch(messages).await {
2273 Ok(SentBatch::Ack) => {
2274 let _ = commit(vec![MessageDisposition::Ack; batch_len]).await;
2275 }
2276 Ok(SentBatch::Partial { failed, .. }) => {
2277 let dispositions = if failed.is_empty() {
2279 vec![MessageDisposition::Ack; batch_len]
2280 } else {
2281 vec![MessageDisposition::Nack; batch_len]
2284 };
2285 let _ = commit(dispositions).await;
2286 }
2287 Err(_) => {
2288 let _ = commit(vec![MessageDisposition::Nack; batch_len]).await;
2289 }
2290 }
2291 }
2292 });
2293
2294 let mut messages = Vec::new();
2296 for i in 1..=4 {
2297 messages.push(CanonicalMessage::from_json(serde_json::json!({"id": i})).unwrap());
2299 }
2300 let commit = wrap_commit(Box::new(|_| Box::pin(async { Ok(()) })), 0, seq_tx.clone());
2301 work_tx.send((messages, commit)).await.unwrap();
2302
2303 let dlq_channel = dlq_endpoint.channel().unwrap();
2305
2306 let start = std::time::Instant::now();
2307 while dlq_channel.len() < 2 {
2308 if start.elapsed() > std::time::Duration::from_secs(5) {
2309 break;
2310 }
2311 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2312 }
2313
2314 let dlq_msgs = dlq_channel.drain_messages();
2315
2316 assert_eq!(dlq_msgs.len(), 2, "Expected 2 messages to go to DLQ");
2317
2318 let dlq_ids: std::collections::HashSet<u32> = dlq_msgs
2319 .iter()
2320 .map(|m| {
2321 serde_json::from_slice::<serde_json::Value>(&m.payload).unwrap()["id"]
2322 .as_u64()
2323 .unwrap() as u32
2324 })
2325 .collect();
2326
2327 assert!(dlq_ids.contains(&2));
2328 assert!(dlq_ids.contains(&4));
2329
2330 let attempts = mock_publisher.attempts.lock().unwrap();
2332 assert_eq!(attempts.values().filter(|&&c| c == 2).count(), 2);
2334 assert_eq!(attempts.values().filter(|&&c| c == 1).count(), 2);
2336 }
2337
2338 #[tokio::test(flavor = "multi_thread")]
2339 async fn test_route_dlq_integration() {
2340 let unique_id = fast_uuid_v7::gen_id().to_string();
2347 let in_topic = format!("dlq_in_{}", unique_id);
2348 let out_topic = format!("dlq_out_{}", unique_id);
2349 let dlq_topic = format!("dlq_target_{}", unique_id);
2350 let input = Endpoint::new_memory(&in_topic, 10);
2351 let dlq_endpoint = Endpoint::new_memory(&dlq_topic, 10);
2352
2353 let mut output = Endpoint::new_memory(&out_topic, 10);
2354 output.middlewares = vec![
2355 Middleware::RandomPanic(RandomPanicMiddleware {
2357 mode: FaultMode::Timeout, trigger_on_message: None, enabled: true,
2360 ..Default::default()
2361 }),
2362 Middleware::Retry(crate::models::RetryMiddleware {
2364 max_attempts: 2,
2365 initial_interval_ms: 10,
2366 max_interval_ms: 100,
2367 multiplier: 1.0,
2368 }),
2369 Middleware::Dlq(Box::new(crate::models::DeadLetterQueueMiddleware {
2371 endpoint: dlq_endpoint.clone(),
2372 })),
2373 ];
2374
2375 let route = Route::new(input.clone(), output);
2376 route.deploy("test_dlq_integration").await.unwrap();
2377
2378 let input_ch = input.channel().unwrap();
2380 input_ch.send_message("fail_msg".into()).await.unwrap();
2381
2382 let dlq_ch = dlq_endpoint.channel().unwrap();
2387
2388 let received = tokio::time::timeout(std::time::Duration::from_secs(5), async {
2390 loop {
2391 let batch = dlq_ch.drain_messages();
2392 if !batch.is_empty() {
2393 return batch[0].clone();
2394 }
2395 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2396 }
2397 })
2398 .await
2399 .expect("Timed out waiting for DLQ");
2400
2401 assert_eq!(received.get_payload_str(), "fail_msg");
2402
2403 let out_ch_target = mq_bridge::endpoints::memory::get_or_create_channel(
2404 &mq_bridge::models::MemoryConfig::new(&out_topic, None),
2405 );
2406 assert!(out_ch_target.is_empty(), "Message should not reach target");
2407
2408 Route::stop("test_dlq_integration").await;
2409 }
2410
2411 #[tokio::test(flavor = "multi_thread")]
2412 async fn test_large_message_handling() {
2413 let unique_id = fast_uuid_v7::gen_id().to_string();
2414 let in_topic = format!("large_in_{}", unique_id);
2415 let out_topic = format!("large_out_{}", unique_id);
2416
2417 let input = Endpoint::new_memory(&in_topic, 5); let output = Endpoint::new_memory(&out_topic, 5);
2419
2420 let route = Route::new(input.clone(), output.clone());
2421 route.deploy("test_large_msg").await.unwrap();
2422
2423 let large_payload = vec![b'x'; 5 * 1024 * 1024]; let input_ch = input.channel().unwrap();
2425
2426 input_ch
2427 .send_message(large_payload.clone().into())
2428 .await
2429 .unwrap();
2430
2431 let mut verifier = route.connect_to_output("verifier").await.unwrap();
2432 let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
2433 .await
2434 .expect("Timed out receiving large message")
2435 .unwrap();
2436
2437 assert_eq!(received.message.payload.len(), large_payload.len());
2438 assert_eq!(received.message.payload, large_payload.as_slice());
2439
2440 Route::stop("test_large_msg").await;
2441 }
2442
    /// Synchronous `#[test]` entry point that delegates to
    /// `test_map_responses_to_dispositions_logic` (defined outside this view;
    /// presumably it holds the actual assertions — confirm at its definition).
    #[test]
    fn test_map_responses_to_dispositions_unit() {
        test_map_responses_to_dispositions_logic();
    }
2447}