use crate::endpoints::{create_consumer_from_route, create_publisher_from_route};
pub use crate::models::Route;
use crate::models::{Endpoint, RouteOptions};
use crate::traits::{
    BatchCommitFunc, ConsumerError, Handler, HandlerError, MessageDisposition, PublisherError,
    SentBatch,
};
use async_channel::{bounded, Sender};
use serde::de::DeserializeOwned;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, OnceLock, RwLock};
use tokio::{
    select,
    sync::Semaphore,
    task::{JoinHandle, JoinSet},
};
use tracing::{debug, error, info, warn};

pub use crate::extensions::{
    get_endpoint_factory, get_middleware_factory, register_endpoint_factory,
    register_middleware_factory,
};

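/// Handle to a running route: the supervisor task's [`JoinHandle`] paired
/// with the shutdown [`Sender`].
///
/// Shutdown sketch (not a compiled doctest; constructing the route is
/// elided):
///
/// ```ignore
/// let handle = route.run("my_route").await?;
/// // ... later: signal shutdown, then wait for the supervisor task to exit.
/// handle.stop().await;
/// handle.join().await?;
/// ```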
#[derive(Debug)]
pub struct RouteHandle((JoinHandle<()>, Sender<()>));

impl RouteHandle {
    pub async fn stop(&self) {
        let _ = self.0 .1.send(()).await;
        self.0 .1.close();
    }

    pub async fn join(self) -> Result<(), tokio::task::JoinError> {
        self.0 .0.await
    }
}

impl From<(JoinHandle<()>, Sender<()>)> for RouteHandle {
    fn from(tuple: (JoinHandle<()>, Sender<()>)) -> Self {
        RouteHandle(tuple)
    }
}

struct ActiveRoute {
    route: Route,
    handle: RouteHandle,
}

static ROUTE_REGISTRY: OnceLock<RwLock<HashMap<String, ActiveRoute>>> = OnceLock::new();

impl Route {
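    /// Builds a route from an input and an output [`Endpoint`], using the
    /// default [`RouteOptions`].
    ///
    /// Minimal sketch reusing the in-memory endpoints from the test module;
    /// topic names, capacities, and option values here are placeholders:
    ///
    /// ```ignore
    /// let input = Endpoint::new_memory("orders_in", 10);
    /// let output = Endpoint::new_memory("orders_out", 10);
    /// let route = Route::new(input, output)
    ///     .with_concurrency(4)
    ///     .with_batch_size(100);
    /// route.check("orders", None)?;
    /// ```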
    pub fn new(input: Endpoint, output: Endpoint) -> Self {
        Self {
            input,
            output,
            ..Default::default()
        }
    }

    pub fn get(name: &str) -> Option<Self> {
        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
        let map = registry.read().expect("Route registry lock poisoned");
        map.get(name).map(|active| active.route.clone())
    }

    pub fn list() -> Vec<String> {
        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
        let map = registry.read().expect("Route registry lock poisoned");
        map.keys().cloned().collect()
    }

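    /// Runs the route under the given name and records it in the global
    /// registry, first stopping any route already deployed under that name.
    ///
    /// Sketch (route construction elided; the name is a placeholder):
    ///
    /// ```ignore
    /// route.deploy("orders").await?;
    /// assert!(Route::list().contains(&"orders".to_string()));
    /// ```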
    pub async fn deploy(&self, name: &str) -> anyhow::Result<()> {
        Self::stop(name).await;

        let handle = self.run(name).await?;
        let active = ActiveRoute {
            route: self.clone(),
            handle,
        };

        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
        let mut map = registry.write().expect("Route registry lock poisoned");
        map.insert(name.to_string(), active);
        Ok(())
    }

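    /// Stops and unregisters a deployed route by name, returning `true` if a
    /// route with that name was registered.
    ///
    /// ```ignore
    /// // Placeholder name; returns false if nothing is deployed under it.
    /// let was_running = Route::stop("orders").await;
    /// ```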
    pub async fn stop(name: &str) -> bool {
        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
        let active_opt = {
            let mut map = registry.write().expect("Route registry lock poisoned");
            map.remove(name)
        };

        if let Some(active) = active_opt {
            active.handle.stop().await;
            let _ = active.handle.join().await;
            true
        } else {
            false
        }
    }

    pub async fn create_publisher(&self) -> anyhow::Result<crate::Publisher> {
        crate::Publisher::new(self.output.clone()).await
    }

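    /// Creates a consumer attached to the route's output endpoint, which is
    /// handy for verifying what a deployed route produced.
    ///
    /// Sketch mirroring the test module; the consumer name is a placeholder:
    ///
    /// ```ignore
    /// let mut verifier = route.connect_to_output("verifier").await?;
    /// let received = verifier.receive().await.expect("stream closed");
    /// (received.commit)(MessageDisposition::Ack).await?;
    /// ```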
    pub async fn connect_to_output(
        &self,
        name: &str,
    ) -> anyhow::Result<Box<dyn crate::traits::MessageConsumer>> {
        create_consumer_from_route(name, &self.output).await
    }

    pub fn check(&self, name: &str, allowed_endpoints: Option<&[&str]>) -> anyhow::Result<()> {
        crate::endpoints::check_consumer(name, &self.input, allowed_endpoints)?;
        crate::endpoints::check_publisher(name, &self.output, allowed_endpoints)?;
        Ok(())
    }

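    /// Runs the route under a supervisor task that restarts it after errors
    /// or panics (with a 5-second backoff) and waits up to 5 seconds for the
    /// readiness signal before handing back a [`RouteHandle`].
    ///
    /// Sketch (route construction elided):
    ///
    /// ```ignore
    /// let handle = route.run("orders").await?;
    /// // `run` does not touch the registry; use `deploy` for that.
    /// handle.stop().await;
    /// handle.join().await?;
    /// ```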
    pub async fn run(&self, name_str: &str) -> anyhow::Result<RouteHandle> {
        self.check(name_str, None)?;
        let (shutdown_tx, shutdown_rx) = bounded(1);
        let (ready_tx, ready_rx) = bounded(1);
        let route = Arc::new(self.clone());
        let name = Arc::new(name_str.to_string());

        let handle = tokio::spawn(async move {
            loop {
                let route_arc = Arc::clone(&route);
                let name_arc = Arc::clone(&name);
                let (internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
                let ready_tx_clone = ready_tx.clone();

                let mut run_task = tokio::spawn(async move {
                    route_arc
                        .run_until_err(&name_arc, Some(internal_shutdown_rx), Some(ready_tx_clone))
                        .await
                });

                select! {
                    _ = shutdown_rx.recv() => {
                        info!("Shutdown signal received for route '{}'.", name);
                        let _ = internal_shutdown_tx.send(()).await;
                        let _ = run_task.await;
                        break;
                    }
                    res = &mut run_task => {
                        match res {
                            Ok(Ok(should_continue)) if !should_continue => {
                                info!("Route '{}' completed gracefully. Shutting down.", name);
                                break;
                            }
                            Ok(Err(e)) => {
                                error!("Route '{}' failed: {}. Reconnecting in 5 seconds...", name, e);
                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                            }
                            Err(e) => {
                                error!("Route '{}' task panicked: {}. Reconnecting in 5 seconds...", name, e);
                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                            }
                            _ => {}
                        }
                    }
                }
            }
        });

        match tokio::time::timeout(std::time::Duration::from_secs(5), ready_rx.recv()).await {
            Ok(Ok(_)) => Ok(RouteHandle((handle, shutdown_tx))),
            _ => {
                handle.abort();
                Err(anyhow::anyhow!(
                    "Route '{}' failed to start within 5 seconds or encountered an error",
                    name_str
                ))
            }
        }
    }

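    /// Runs the route once, without the supervision and restart logic of
    /// [`Route::run`]. The returned flag tells the supervisor whether the
    /// route may be restarted (`true`) or has completed for good (`false`).
    /// When `shutdown_rx` is `None`, an internal, never-signalled channel is
    /// used in its place.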
    pub async fn run_until_err(
        &self,
        name: &str,
        shutdown_rx: Option<async_channel::Receiver<()>>,
        ready_tx: Option<Sender<()>>,
    ) -> anyhow::Result<bool> {
        let (_internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
        let shutdown_rx = shutdown_rx.unwrap_or(internal_shutdown_rx);
        if self.options.concurrency == 1 {
            self.run_sequentially(name, shutdown_rx, ready_tx).await
        } else {
            self.run_concurrently(name, shutdown_rx, ready_tx).await
        }
    }

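    /// Single-consumer pipeline: receive a batch, publish it, then commit in
    /// a spawned task gated by `commit_concurrency_limit`. Commits are routed
    /// through the sequencer so they complete in receive order.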
    async fn run_sequentially(
        &self,
        name: &str,
        shutdown_rx: async_channel::Receiver<()>,
        ready_tx: Option<Sender<()>>,
    ) -> anyhow::Result<bool> {
        let publisher = create_publisher_from_route(name, &self.output).await?;
        let mut consumer = create_consumer_from_route(name, &self.input).await?;
        let (err_tx, err_rx) = bounded(1);
        let commit_semaphore = Arc::new(Semaphore::new(self.options.commit_concurrency_limit));
        let mut commit_tasks = JoinSet::new();

        let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
        let mut seq_counter = 0u64;

        if let Some(tx) = ready_tx {
            let _ = tx.send(()).await;
        }
        let run_result = loop {
            select! {
                Ok(err) = err_rx.recv() => break Err(err),

                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received in sequential runner for route '{}'.", name);
                    break Ok(true);
                }
                res = consumer.receive_batch(self.options.batch_size) => {
                    let received_batch = match res {
                        Ok(batch) => {
                            if batch.messages.is_empty() {
                                continue;
                            }
                            batch
                        }
                        Err(ConsumerError::EndOfStream) => {
                            info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
                            break Ok(false);
                        }
                        Err(ConsumerError::Connection(e)) => {
                            break Err(e);
                        }
                        Err(ConsumerError::Gap { requested, base }) => {
                            break Err(anyhow::anyhow!("Consumer gap: requested offset {requested} but earliest available is {base}"));
                        }
                    };
                    debug!("Received a batch of {} messages sequentially", received_batch.messages.len());

                    let seq = seq_counter;
                    seq_counter += 1;
                    let commit = wrap_commit(received_batch.commit, seq, seq_tx.clone());
                    let batch_len = received_batch.messages.len();

                    match publisher.send_batch(received_batch.messages).await {
                        Ok(SentBatch::Ack) => {
                            let permit = commit_semaphore.clone().acquire_owned().await.map_err(|e| anyhow::anyhow!("Semaphore error: {}", e))?;
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Ok(SentBatch::Partial { responses, failed }) => {
                            let has_retryable = failed.iter().any(|(_, e)| matches!(e, PublisherError::Retryable(_)));
                            if has_retryable {
                                let failed_count = failed.len();
                                let (_, first_error) = failed
                                    .into_iter()
                                    .find(|(_, e)| matches!(e, PublisherError::Retryable(_)))
                                    .expect("has_retryable is true");
                                break Err(anyhow::anyhow!(
                                    "Failed to send {} messages in batch. First retryable error: {}",
                                    failed_count,
                                    first_error
                                ));
                            }
                            for (msg, e) in &failed {
                                error!("Dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
                            }
                            let permit = commit_semaphore.clone().acquire_owned().await.map_err(|e| anyhow::anyhow!("Semaphore error: {}", e))?;
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                let dispositions = map_responses_to_dispositions(batch_len, responses, &failed);
                                if let Err(e) = commit(dispositions).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Err(e) => break Err(e.into()),
                    }
                }
            }
        };

        drop(seq_tx);
        loop {
            select! {
                res = err_rx.recv() => {
                    if let Ok(err) = res {
                        error!("Error reported during shutdown: {}", err);
                    }
                }
                res = commit_tasks.join_next() => {
                    if res.is_none() {
                        break;
                    }
                }
            }
        }
        drop(err_rx);
        let _ = sequencer_handle.await;
        run_result
    }

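    /// Fan-out pipeline: one consumer feeds a bounded work channel, a pool of
    /// `concurrency` workers publishes the batches, and commits are funnelled
    /// through the sequencer so they still complete in receive order.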
    async fn run_concurrently(
        &self,
        name: &str,
        shutdown_rx: async_channel::Receiver<()>,
        ready_tx: Option<Sender<()>>,
    ) -> anyhow::Result<bool> {
        let publisher = create_publisher_from_route(name, &self.output).await?;
        let mut consumer = create_consumer_from_route(name, &self.input).await?;
        if let Some(tx) = ready_tx {
            let _ = tx.send(()).await;
        }
        let (err_tx, err_rx) = bounded(1);
        let work_capacity = self
            .options
            .concurrency
            .saturating_mul(self.options.batch_size);
        let (work_tx, work_rx) =
            bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(work_capacity);
        let commit_semaphore = Arc::new(Semaphore::new(self.options.commit_concurrency_limit));

        let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.concurrency * 2);

        let mut join_set = JoinSet::new();
        for i in 0..self.options.concurrency {
            let work_rx_clone = work_rx.clone();
            let publisher = Arc::clone(&publisher);
            let err_tx = err_tx.clone();
            let commit_semaphore = commit_semaphore.clone();
            let mut commit_tasks = JoinSet::new();
            join_set.spawn(async move {
                debug!("Starting worker {}", i);
                while let Ok((messages, commit)) = work_rx_clone.recv().await {
                    let batch_len = messages.len();
                    match publisher.send_batch(messages).await {
                        Ok(SentBatch::Ack) => {
                            let permit = match commit_semaphore.clone().acquire_owned().await {
                                Ok(p) => p,
                                Err(_) => {
                                    warn!("Semaphore closed, worker exiting");
                                    break;
                                }
                            };
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Ok(SentBatch::Partial { responses, failed }) => {
                            let has_retryable = failed.iter().any(|(_, e)| matches!(e, PublisherError::Retryable(_)));
                            if has_retryable {
                                let failed_count = failed.len();
                                let (_, first_error) = failed
                                    .into_iter()
                                    .find(|(_, e)| matches!(e, PublisherError::Retryable(_)))
                                    .expect("has_retryable is true");
                                let e = anyhow::anyhow!(
                                    "Failed to send {} messages in batch. First retryable error: {}",
                                    failed_count,
                                    first_error
                                );
                                error!("Worker failed to send message batch: {}", e);
                                if err_tx.send(e).await.is_err() {
                                    warn!("Could not send error to main task, it might be down.");
                                }
                                break;
                            }
                            for (msg, e) in &failed {
                                error!("Worker dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
                            }
                            let permit = match commit_semaphore.clone().acquire_owned().await {
                                Ok(p) => p,
                                Err(_) => {
                                    warn!("Semaphore closed, worker exiting");
                                    break;
                                }
                            };
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                let dispositions = map_responses_to_dispositions(batch_len, responses, &failed);
                                if let Err(e) = commit(dispositions).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Err(e) => {
                            error!("Worker failed to send message batch: {}", e);
                            if err_tx.send(e.into()).await.is_err() {
                                warn!("Could not send error to main task, it might be down.");
                            }
                            break;
                        }
                    }
                }
                while commit_tasks.join_next().await.is_some() {}
            });
        }

        let mut seq_counter = 0u64;
        loop {
            select! {
                biased;

                Ok(err) = err_rx.recv() => {
                    error!("A worker reported a critical error. Shutting down route.");
                    return Err(err);
                }

                Some(res) = join_set.join_next() => {
                    match res {
                        Ok(_) => {
                            error!("A worker task finished unexpectedly. Shutting down route.");
                            return Err(anyhow::anyhow!("Worker task finished unexpectedly"));
                        }
                        Err(e) => {
                            error!("A worker task panicked: {}. Shutting down route.", e);
                            return Err(e.into());
                        }
                    }
                }

                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received in concurrent runner for route '{}'.", name);
                    break;
                }

                res = consumer.receive_batch(self.options.batch_size) => {
                    let (messages, commit) = match res {
                        Ok(batch) => {
                            if batch.messages.is_empty() {
                                continue;
                            }
                            (batch.messages, batch.commit)
                        }
                        Err(ConsumerError::EndOfStream) => {
                            info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
                            break;
                        }
                        Err(ConsumerError::Connection(e)) => {
                            return Err(e);
                        }
                        Err(ConsumerError::Gap { requested, base }) => {
                            return Err(ConsumerError::Gap { requested, base }.into());
                        }
                    };
                    debug!("Received a batch of {} messages concurrently", messages.len());

                    let seq = seq_counter;
                    seq_counter += 1;
                    let wrapped_commit = wrap_commit(commit, seq, seq_tx.clone());

                    if work_tx.send((messages, wrapped_commit)).await.is_err() {
                        warn!("Work channel closed, cannot process more messages concurrently. Shutting down.");
                        break;
                    }
                }
            }
        }

        drop(work_tx);
        while join_set.join_next().await.is_some() {}

        drop(seq_tx);
        let _ = sequencer_handle.await;

        if let Ok(err) = err_rx.try_recv() {
            return Err(err);
        }

        Ok(shutdown_rx.is_empty())
    }

    pub fn with_options(mut self, options: RouteOptions) -> Self {
        self.options = options;
        self
    }

    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.options.concurrency = concurrency.max(1);
        self
    }

    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.options.batch_size = batch_size.max(1);
        self
    }

    pub fn with_commit_concurrency_limit(mut self, limit: usize) -> Self {
        self.options.commit_concurrency_limit = limit.max(1);
        self
    }

    pub fn with_handler(mut self, handler: impl Handler + 'static) -> Self {
        self.output.handler = Some(Arc::new(handler));
        self
    }

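    /// Registers a typed handler for messages with the given type name,
    /// composing with any previously registered handler through
    /// [`crate::type_handler::TypeHandler`].
    ///
    /// Sketch of a registration. The payload struct, handler signature, and
    /// type name below are illustrative; the exact handler shapes accepted
    /// are determined by the `IntoTypedHandler` implementations.
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct OrderCreated { id: u64 }
    ///
    /// async fn on_order(order: OrderCreated, _ctx: MessageContext) -> Result<(), HandlerError> {
    ///     println!("order {}", order.id);
    ///     Ok(())
    /// }
    ///
    /// let route = Route::new(input, output).add_handler("OrderCreated", on_order);
    /// ```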
    pub fn add_handler<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
    where
        T: DeserializeOwned + Send + Sync + 'static,
        H: crate::type_handler::IntoTypedHandler<T, Args>,
        Args: Send + Sync + 'static,
    {
        let handler = Arc::new(handler);
        let wrapper = move |msg: crate::CanonicalMessage| {
            let handler = handler.clone();
            async move {
                let data = msg.parse::<T>().map_err(|e| {
                    HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
                })?;
                let ctx = crate::MessageContext::from(msg);
                handler.call(data, ctx).await
            }
        };
        let wrapper = Arc::new(wrapper);

        let prev_handler = self.output.handler.take();

        let new_handler = if let Some(h) = prev_handler {
            if let Some(extended) = h.register_handler(type_name, wrapper.clone()) {
                extended
            } else {
                Arc::new(
                    crate::type_handler::TypeHandler::new()
                        .with_fallback(h)
                        .add_handler(type_name, wrapper),
                )
            }
        } else {
            Arc::new(crate::type_handler::TypeHandler::new().add_handler(type_name, wrapper))
        };

        self.output.handler = Some(new_handler);
        self
    }

    pub fn add_handlers<T, H, Args>(mut self, handlers: HashMap<&str, H>) -> Self
    where
        T: DeserializeOwned + Send + Sync + 'static,
        H: crate::type_handler::IntoTypedHandler<T, Args>,
        Args: Send + Sync + 'static,
    {
        for (type_name, handler) in handlers {
            self = self.add_handler(type_name, handler);
        }
        self
    }
}

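/// A commit queued with the sequencer: the dispositions to commit, the
/// underlying commit function, and a oneshot sender used to report the
/// commit result back to the wrapped commit produced by [`wrap_commit`].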
type SequencerItem = (
    Vec<MessageDisposition>,
    BatchCommitFunc,
    tokio::sync::oneshot::Sender<anyhow::Result<()>>,
);

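/// Spawns the sequencer task that executes queued commits strictly in
/// sequence-number order, buffering out-of-order arrivals and skipping ahead
/// after a 5-second timeout if an expected sequence number never shows up.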
fn spawn_sequencer(buffer_size: usize) -> (Sender<(u64, SequencerItem)>, JoinHandle<()>) {
    let (seq_tx, seq_rx) = bounded::<(u64, SequencerItem)>(buffer_size);

    let sequencer_handle = tokio::spawn(async move {
        let mut buffer: BTreeMap<u64, SequencerItem> = BTreeMap::new();
        let mut next_seq = 0u64;
        let mut deadline: Option<tokio::time::Instant> = None;
        const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

        loop {
            while let Some((dispositions, commit_func, notify)) = buffer.remove(&next_seq) {
                let res = commit_func(dispositions).await;
                let _ = notify.send(res);
                next_seq += 1;
            }

            if !buffer.is_empty() {
                if deadline.is_none() {
                    deadline = Some(tokio::time::Instant::now() + TIMEOUT);
                }
            } else {
                deadline = None;
            }

            let timeout_fut = async {
                if let Some(d) = deadline {
                    tokio::time::sleep_until(d).await
                } else {
                    std::future::pending().await
                }
            };

            select! {
                res = seq_rx.recv() => {
                    match res {
                        Ok((seq, item)) => {
                            if seq < next_seq {
                                let (_, _, notify) = item;
                                let _ = notify.send(Err(anyhow::anyhow!("Sequencer received late item (seq {} < next_seq {})", seq, next_seq)));
                            } else {
                                buffer.insert(seq, item);
                            }
                        }
                        Err(_) => {
                            for (_, (_, _, notify)) in std::mem::take(&mut buffer) {
                                let _ = notify.send(Err(anyhow::anyhow!("Sequencer shutting down")));
                            }
                            break;
                        }
                    }
                }
                _ = timeout_fut => {
                    if let Some(&first_seq) = buffer.keys().next() {
                        if first_seq > next_seq {
                            warn!("Sequencer timed out waiting for seq {}. Jumping to {}.", next_seq, first_seq);
                            next_seq = first_seq;
                        } else {
                            next_seq += 1;
                        }
                    } else {
                        next_seq += 1;
                    }
                    deadline = None;
                }
            }
        }
    });
    (seq_tx, sequencer_handle)
}

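/// Wraps a batch commit function so that, instead of committing directly, it
/// enqueues the commit with the sequencer under the given sequence number and
/// waits on a oneshot channel for the sequencer to report the result.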
fn wrap_commit(
    commit: BatchCommitFunc,
    seq: u64,
    seq_tx: Sender<(u64, SequencerItem)>,
) -> BatchCommitFunc {
    Box::new(move |dispositions| {
        Box::pin(async move {
            let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
            if seq_tx
                .send((seq, (dispositions, commit, notify_tx)))
                .await
                .is_ok()
            {
                match notify_rx.await {
                    Ok(res) => res,
                    Err(_) => Err(anyhow::anyhow!(
                        "Sequencer dropped the commit channel unexpectedly"
                    )),
                }
            } else {
                Err(anyhow::anyhow!(
                    "Failed to send commit to sequencer, route is likely shutting down"
                ))
            }
        })
    })
}

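/// Maps a publisher's per-message responses onto commit dispositions: a full
/// set of responses becomes [`MessageDisposition::Reply`] values, anything
/// else (including partial failures) falls back to acking the whole batch.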
fn map_responses_to_dispositions(
    total_count: usize,
    responses: Option<Vec<crate::CanonicalMessage>>,
    failed: &[(crate::CanonicalMessage, PublisherError)],
) -> Vec<MessageDisposition> {
    if failed.is_empty() {
        if let Some(resps) = responses {
            if resps.len() == total_count {
                return resps.into_iter().map(MessageDisposition::Reply).collect();
            }
        } else {
            return vec![MessageDisposition::Ack; total_count];
        }
    }

    vec![MessageDisposition::Ack; total_count]
}

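/// Convenience wrapper around [`Route::get`] for looking up a deployed route
/// by name.
///
/// Sketch (the route name is a placeholder):
///
/// ```ignore
/// if let Some(route) = get_route("orders") {
///     println!("currently deployed: {:?}", list_routes());
///     stop_route("orders").await;
/// }
/// ```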
pub fn get_route(name: &str) -> Option<Route> {
    Route::get(name)
}

pub fn list_routes() -> Vec<String> {
    Route::list()
}

pub async fn stop_route(name: &str) -> bool {
    Route::stop(name).await
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Endpoint, Middleware};
    use crate::traits::{CustomMiddlewareFactory, MessageConsumer, ReceivedBatch};
    use std::any::Any;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    #[derive(Debug)]
    struct PanicMiddlewareFactory {
        should_panic: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl CustomMiddlewareFactory for PanicMiddlewareFactory {
        async fn apply_consumer(
            &self,
            consumer: Box<dyn MessageConsumer>,
            _route_name: &str,
            _config: &serde_json::Value,
        ) -> anyhow::Result<Box<dyn MessageConsumer>> {
            Ok(Box::new(PanicConsumer {
                inner: consumer,
                should_panic: self.should_panic.clone(),
            }))
        }
    }

    struct PanicConsumer {
        inner: Box<dyn MessageConsumer>,
        should_panic: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl MessageConsumer for PanicConsumer {
        async fn receive_batch(
            &mut self,
            max_messages: usize,
        ) -> Result<ReceivedBatch, ConsumerError> {
            if self
                .should_panic
                .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
            {
                panic!("Simulated panic for testing recovery");
            }
            self.inner.receive_batch(max_messages).await
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[ignore = "Takes too much time for regular tests"]
    async fn test_route_recovery_from_panic() {
        let unique_suffix = fast_uuid_v7::gen_id().to_string();
        let in_topic = format!("panic_in_{}", unique_suffix);
        let out_topic = format!("panic_out_{}", unique_suffix);

        let should_panic = Arc::new(AtomicBool::new(true));
        let factory = PanicMiddlewareFactory {
            should_panic: should_panic.clone(),
        };
        register_middleware_factory("panic_factory", Arc::new(factory));

        let input = Endpoint::new_memory(&in_topic, 10).add_middleware(Middleware::Custom {
            name: "panic_factory".to_string(),
            config: serde_json::Value::Null,
        });
        let output = Endpoint::new_memory(&out_topic, 10);

        let route = Route::new(input.clone(), output.clone());

        route
            .deploy("panic_test")
            .await
            .expect("Failed to deploy route");
        let input_ch = input.channel().unwrap();
        input_ch
            .send_message("persistent_msg".into())
            .await
            .unwrap();

        let panic_wait_start = std::time::Instant::now();
        while panic_wait_start.elapsed() < std::time::Duration::from_secs(5) {
            if !should_panic.load(Ordering::SeqCst) {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        }
        assert!(
            !should_panic.load(Ordering::SeqCst),
            "Route should have panicked"
        );

        tokio::time::sleep(std::time::Duration::from_secs(5)).await;

        let mut verifier = route.connect_to_output("verifier").await.unwrap();
        let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
            .await
            .expect("Timed out waiting for message after recovery")
            .expect("Stream closed");

        assert_eq!(received.message.get_payload_str(), "persistent_msg");
        (received.commit)(MessageDisposition::Ack).await.unwrap();

        Route::stop("panic_test").await;
    }
}