Skip to main content

mq_bridge/
route.rs

1//  mq-bridge
2//  © Copyright 2025, by Marco Mengelkoch
3//  Licensed under MIT License, see License file for more details
4//  git clone https://github.com/marcomq/mq-bridge
5
6use crate::endpoints::{create_consumer_from_route, create_publisher_from_route};
7use crate::errors::ProcessingError;
8pub use crate::models::Route;
9use crate::models::{Endpoint, EndpointType, RouteOptions};
10use crate::traits::{
11    BatchCommitFunc, ConsumerError, Handler, HandlerError, MessageConsumer, MessageDisposition,
12    MessagePublisher, PublisherError, SentBatch,
13};
14use async_channel::{bounded, Sender};
15use serde::de::DeserializeOwned;
16use std::collections::{BTreeMap, HashMap};
17use std::sync::{Arc, OnceLock, RwLock};
18use tokio::{
19    select,
20    task::{JoinHandle, JoinSet},
21};
22use tracing::{debug, error, info, trace, warn};
23
24// Re-export extensions for backward compatibility and internal usage
25pub use crate::extensions::{
26    get_endpoint_factory, get_middleware_factory, register_endpoint_factory,
27    register_middleware_factory,
28};
29
/// Handle to a spawned route task.
///
/// Wraps a `(JoinHandle<()>, Sender<()>)` pair: the join handle of the route's task and the
/// sender half of the channel used to request its shutdown.
#[derive(Debug)]
pub struct RouteHandle((JoinHandle<()>, Sender<()>));
32
33impl RouteHandle {
34    pub async fn stop(&self) {
35        let _ = self.0 .1.send(()).await;
36        self.0 .1.close();
37    }
38
39    pub async fn join(self) -> Result<(), tokio::task::JoinError> {
40        self.0 .0.await
41    }
42}
43
44async fn run_publisher_connect_hook(
45    route_name: &str,
46    publisher: &Arc<dyn MessagePublisher>,
47) -> anyhow::Result<()> {
48    if let Some(hook) = publisher.on_connect_hook() {
49        hook.await.map_err(|err| {
50            anyhow::anyhow!(
51                "Publisher on_connect hook failed for route '{}': {}",
52                route_name,
53                err
54            )
55        })?;
56    }
57    Ok(())
58}
59
60async fn run_consumer_connect_hook(
61    route_name: &str,
62    consumer: &dyn MessageConsumer,
63) -> anyhow::Result<()> {
64    if let Some(hook) = consumer.on_connect_hook() {
65        hook.await.map_err(|err| {
66            anyhow::anyhow!(
67                "Consumer on_connect hook failed for route '{}': {}",
68                route_name,
69                err
70            )
71        })?;
72    }
73    Ok(())
74}
75
76async fn run_publisher_disconnect_hook(route_name: &str, publisher: &Arc<dyn MessagePublisher>) {
77    if let Some(hook) = publisher.on_disconnect_hook() {
78        if let Err(err) = hook.await {
79            warn!(
80                "Publisher on_disconnect hook failed for route '{}': {}",
81                route_name, err
82            );
83        }
84    }
85}
86
87async fn run_consumer_disconnect_hook(route_name: &str, consumer: &dyn MessageConsumer) {
88    if let Some(hook) = consumer.on_disconnect_hook() {
89        if let Err(err) = hook.await {
90            warn!(
91                "Consumer on_disconnect hook failed for route '{}': {}",
92                route_name, err
93            );
94        }
95    }
96}
97
98impl From<(JoinHandle<()>, Sender<()>)> for RouteHandle {
99    fn from(tuple: (JoinHandle<()>, Sender<()>)) -> Self {
100        RouteHandle(tuple)
101    }
102}
103
/// Entry stored in `ROUTE_REGISTRY` for a deployed route.
struct ActiveRoute {
    /// Copy of the route configuration that was deployed.
    route: Route,
    /// Handle used to stop and join the route's background task.
    handle: RouteHandle,
}
108
/// Process-wide map from route name to its `ActiveRoute`; created lazily on first access.
static ROUTE_REGISTRY: OnceLock<RwLock<HashMap<String, ActiveRoute>>> = OnceLock::new();
/// Process-wide map from endpoint name to the `Endpoint` registered under that name;
/// created lazily on first access.
static ENDPOINT_REF_REGISTRY: OnceLock<RwLock<HashMap<String, Endpoint>>> = OnceLock::new();
111
112/// Registers a named endpoint that can be referenced by other endpoints using `ref: "name"`.
113/// This will overwrite any existing endpoint with the same name.
114pub fn register_endpoint(name: &str, endpoint: Endpoint) {
115    let registry = ENDPOINT_REF_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
116    let mut writer = registry
117        .write()
118        .expect("Named endpoint registry lock poisoned");
119    if writer.insert(name.to_string(), endpoint).is_some() {
120        debug!("Overwriting a registered endpoint named '{}'", name);
121    }
122}
123
124/// Retrieves a registered endpoint by name.
125pub fn get_endpoint(name: &str) -> Option<Endpoint> {
126    let registry = ENDPOINT_REF_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
127    let reader = registry
128        .read()
129        .expect("Named endpoint registry lock poisoned");
130    reader.get(name).cloned()
131}
132
133impl Route {
134    /// Creates a new route with default concurrency (1) and batch size (128).
135    ///
136    /// # Arguments
137    /// * `input` - The input/source endpoint for the route
138    /// * `output` - The output/sink endpoint for the route
139    pub fn new(input: Endpoint, output: Endpoint) -> Self {
140        Self {
141            input,
142            output,
143            ..Default::default()
144        }
145    }
146
147    /// Retrieves a registered (and running) route by name.
148    pub fn get(name: &str) -> Option<Self> {
149        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
150        let map = registry.read().expect("Route registry lock poisoned");
151        map.get(name).map(|active| active.route.clone())
152    }
153
154    /// Returns a list of all registered route names.
155    pub fn list() -> Vec<String> {
156        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
157        let map = registry.read().expect("Route registry lock poisoned");
158        map.keys().cloned().collect()
159    }
160
161    /// Returns true if the input is of type ref (and the output isn't)
162    pub fn is_ref(&self) -> bool {
163        matches!(self.input.endpoint_type, EndpointType::Ref(_))
164            && !matches!(self.output.endpoint_type, EndpointType::Ref(_))
165    }
166
167    /// Registers the route's output endpoint under the given name.
168    /// This allows other routes to reference this output using `ref: "name"`.
169    pub fn register_output_endpoint(&self, name: Option<&str>) -> Result<(), anyhow::Error> {
170        match name {
171            Some(name) => {
172                register_endpoint(name, self.output.clone());
173            }
174            None => {
175                if let EndpointType::Ref(name) = &self.input.endpoint_type {
176                    register_endpoint(name, self.output.clone());
177                } else {
178                    return Err(anyhow::anyhow!(
179                        "No name and input is not a reference endpoint"
180                    ));
181                }
182            }
183        };
184        Ok(())
185    }
186
187    /// Registers the route and starts it.
188    /// If a route with the same name is already running, it will be stopped first.
189    ///    
190    /// # Examples
191    /// ```
192    /// use mq_bridge::{Route, models::Endpoint};
193    ///
194    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10));
195    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
196    /// route.deploy("global_route").await.unwrap();
197    /// assert!(Route::get("global_route").is_some());
198    /// # });
199    /// ```
200    pub async fn deploy(&self, name: &str) -> anyhow::Result<()> {
201        Self::stop(name).await;
202
203        let handle = self.run(name).await?;
204        let active = ActiveRoute {
205            route: self.clone(),
206            handle,
207        };
208
209        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
210        let mut map = registry.write().expect("Route registry lock poisoned");
211        map.insert(name.to_string(), active);
212        Ok(())
213    }
214
215    /// Stops a running route by name and removes it from the registry.
216    /// Waits up to 5 seconds for the route task to join; if the timeout elapses
217    /// the task is aborted and the implementation awaits the aborted handle to
218    /// ensure the background task has fully terminated before returning.
219    pub async fn stop(name: &str) -> bool {
220        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
221        let active_opt = {
222            let mut map = registry.write().expect("Route registry lock poisoned");
223            map.remove(name)
224        };
225
226        if let Some(active) = active_opt {
227            // Move the handle out so we can operate on its internals.
228            let handle = active.handle;
229
230            // Signal the route to stop and close the shutdown channel.
231            let _ = handle.0 .1.send(()).await;
232            handle.0 .1.close();
233
234            // Extract the JoinHandle so we can monitor and, if needed, abort it.
235            let mut join_handle = handle.0 .0;
236            tokio::select! {
237                res = &mut join_handle => {
238                    // The task finished naturally within the 5s window
239                    let _ = res;
240                }
241                _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
242                    // The 5s timer finished first - abort the task to ensure it doesn't linger.
243                    join_handle.abort();
244                    // Await the handle one last time to ensure the task has fully shut down.
245                    let _ = join_handle.await;
246                }
247            }
248
249            true
250        } else {
251            false
252        }
253    }
254
255    /// Creates a new Publisher configured for this route's output.
256    /// This is useful if you want to send messages to the same destination as this route.
257    ///
258    /// # Examples
259    ///
260    /// ```
261    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
262    /// use mq_bridge::{Route, models::Endpoint};
263    ///
264    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10));
265    /// let publisher = route.create_publisher().await;
266    /// assert!(publisher.is_ok());
267    /// # });
268    /// ```
269    pub async fn create_publisher(&self) -> anyhow::Result<crate::Publisher> {
270        crate::Publisher::new(self.output.clone()).await
271    }
272
273    /// Creates a consumer connected to the route's output.
274    /// This is primarily useful for integration tests to verify messages reaching the destination.
275    pub async fn connect_to_output(
276        &self,
277        name: &str,
278    ) -> anyhow::Result<Box<dyn crate::traits::MessageConsumer>> {
279        create_consumer_from_route(name, &self.output).await
280    }
281
282    /// Validates the route configuration, checking if endpoints are supported and correctly configured.
283    /// Core types like file, memory, and response are always supported.
284    /// # Arguments
285    /// * `name` - The name of the route
286    /// * `allowed_endpoints` - An optional list of allowed endpoint types
287    pub fn check(
288        &self,
289        name: &str,
290        allowed_endpoints: Option<&[&str]>,
291    ) -> anyhow::Result<Vec<String>> {
292        let mut warnings = Vec::new();
293        warnings.extend(crate::endpoints::check_consumer(
294            name,
295            &self.input,
296            allowed_endpoints,
297        )?);
298        warnings.extend(crate::endpoints::check_publisher(
299            name,
300            &self.output,
301            allowed_endpoints,
302        )?);
303        Ok(warnings)
304    }
305
    /// Runs the message processing route with error handling, reconnects, and graceful shutdown.
    ///
    /// The configuration is validated first via [`Route::check`]; warnings are logged. A
    /// supervising background task is then spawned. On every iteration of its loop it spawns
    /// an inner task running [`Route::run_until_err`] and reacts to whichever happens first:
    ///
    /// * a shutdown request (or the shutdown channel being closed): the inner task is
    ///   notified, awaited, and the supervisor exits;
    /// * the inner task returning `Ok(false)`: the supervisor exits;
    /// * the inner task returning `Ok(true)`: the route is restarted immediately;
    /// * the inner task failing with `ProcessingError::NonRetryable` or
    ///   `ConsumerError::EndOfStream`: the supervisor exits;
    /// * any other error, or a failed join of the inner task: the route is restarted after
    ///   5 seconds.
    ///
    /// This function waits asynchronously (up to 5 seconds) until the route signals that it is
    /// initialized. The `name_str` parameter is just used for logging and tracing.
    ///
    /// It returns a [`RouteHandle`] that wraps the supervisor's `JoinHandle` and the `Sender`
    /// used to signal a graceful shutdown.
    ///
    /// # Errors
    /// Returns an error if [`Route::check`] fails, or if no ready signal arrives within
    /// 5 seconds (the supervising task is aborted in that case).
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use mq_bridge::{Route, route::RouteHandle, models::Endpoint};
    /// # async fn example() -> anyhow::Result<()> {
    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10));
    ///
    /// // Start the route (blocks until initialized) and convert to RouteHandle
    /// let handle: RouteHandle = route.run("my_route").await?.into();
    ///
    /// // Stop the route later
    /// handle.stop().await;
    /// handle.join().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn run(&self, name_str: &str) -> anyhow::Result<RouteHandle> {
        // Validate up front; a hard validation error aborts before anything is spawned.
        let warnings = self.check(name_str, None)?;
        for warning in warnings {
            tracing::warn!(route = name_str, "Configuration warning: {}", warning);
        }
        // External shutdown channel: the sender ends up inside the returned `RouteHandle`.
        let (shutdown_tx, shutdown_rx) = bounded(1);
        // Readiness channel: the route sends on it once it is initialized.
        let (ready_tx, ready_rx) = bounded(1);
        // Use `Arc` so route/name clones are cheap (pointer copy) in the reconnect loop.
        let route = Arc::new(self.clone());
        let name = Arc::new(name_str.to_string());

        let handle = tokio::spawn(async move {
            loop {
                let route_arc = Arc::clone(&route);
                let name_arc = Arc::clone(&name);
                // Create a new, per-iteration internal shutdown channel.
                // This avoids a race where both this loop and the inner task
                // try to consume the same external shutdown signal.
                let (internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
                let ready_tx_clone = ready_tx.clone();

                // The actual route logic is in `run_until_err`.
                let mut run_task = tokio::spawn(async move {
                    route_arc
                        .run_until_err(&name_arc, Some(internal_shutdown_rx), Some(ready_tx_clone))
                        .await
                });

                select! {
                    // The `_` pattern also matches a receive error, so a closed shutdown
                    // channel (e.g. every sender dropped) ends the route as well.
                    _ = shutdown_rx.recv() => {
                        info!("Shutdown signal received for route '{}'.", name);
                        // Notify the inner task to shut down.
                        let _ = internal_shutdown_tx.send(()).await;
                        // Wait for the inner task to finish gracefully.
                        let _ = run_task.await;
                        break;
                    }
                    res = &mut run_task => {
                        match res {
                            // `Ok(false)`: the route finished on its own and must not be restarted.
                            Ok(Ok(should_continue)) if !should_continue => {
                                info!("Route '{}' completed gracefully. Shutting down.", name);
                                break;
                            }
                            Ok(Err(e)) => {
                                // Errors that a reconnect cannot fix end the supervisor loop.
                                let is_permanent =
                                    e.downcast_ref::<ProcessingError>().is_some_and(|pe| matches!(pe, ProcessingError::NonRetryable(_)))
                                    || e.downcast_ref::<ConsumerError>().is_some_and(|ce| matches!(ce, ConsumerError::EndOfStream));

                                if is_permanent {
                                    error!("Route '{}' failed with a permanent error: {}. Shutting down.", name, e);
                                    break;
                                }

                                warn!("Route '{}' failed: {}. Reconnecting in 5 seconds...", name, e);
                                // NOTE: the shutdown channel is not polled during this delay; a
                                // signal sent meanwhile is picked up on the next loop iteration.
                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                            }
                            // The inner task could not be joined (reported as a panic).
                            Err(e) => {
                                error!("Route '{}' task panicked: {}. Reconnecting in 5 seconds...", name, e);
                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
                            }
                            _ => {} // The route should continue running.
                        }
                    }
                }
            }
        });

        // Wait for the first ready signal; anything else (timeout or closed channel) counts
        // as a failed start.
        match tokio::time::timeout(std::time::Duration::from_secs(5), ready_rx.recv()).await {
            Ok(Ok(_)) => Ok(RouteHandle((handle, shutdown_tx))),
            _ => {
                // Only the supervising task is aborted here.
                handle.abort();
                Err(anyhow::anyhow!(
                    "Route '{}' failed to start within 5 seconds or encountered an error",
                    name_str
                ))
            }
        }
    }
411
412    /// The core logic of running the route, designed to be called within a reconnect loop.
413    pub async fn run_until_err(
414        &self,
415        name: &str,
416        shutdown_rx: Option<async_channel::Receiver<()>>,
417        ready_tx: Option<Sender<()>>,
418    ) -> anyhow::Result<bool> {
419        let (_internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
420        let shutdown_rx = shutdown_rx.unwrap_or(internal_shutdown_rx);
421        if self.options.concurrency == 1 {
422            self.run_sequentially(name, shutdown_rx, ready_tx).await
423        } else {
424            self.run_concurrently(name, shutdown_rx, ready_tx).await
425        }
426    }
427
    /// A simplified, sequential runner for when concurrency is 1.
    ///
    /// Creates the publisher and the consumer, runs their connect hooks, signals readiness on
    /// `ready_tx`, and then loops: receive a batch, send it to the publisher, and commit the
    /// resulting dispositions. Batches are sent one at a time on this task; only the commits of
    /// fully or partially successful batches are spawned as background tasks.
    ///
    /// Before returning, outstanding commit tasks and the sequencer are awaited and the
    /// disconnect hooks of consumer and publisher are run.
    ///
    /// # Returns
    /// * `Ok(true)` - stopped because `shutdown_rx` fired
    /// * `Ok(false)` - the consumer reported `ConsumerError::EndOfStream`
    ///
    /// # Errors
    /// Fails when an endpoint cannot be created, a connect hook fails, the consumer reports a
    /// connection or gap error, the publisher fails (directly, or transiently without retry
    /// middleware on the output), or a spawned commit reports an error.
    async fn run_sequentially(
        &self,
        name: &str,
        shutdown_rx: async_channel::Receiver<()>,
        ready_tx: Option<Sender<()>>,
    ) -> anyhow::Result<bool> {
        let publisher = create_publisher_from_route(name, &self.output).await?;
        let mut consumer = create_consumer_from_route(name, &self.input).await?;
        // If a connect hook fails, run the disconnect hooks before bailing out.
        if let Err(err) = run_publisher_connect_hook(name, &publisher).await {
            run_publisher_disconnect_hook(name, &publisher).await;
            return Err(err);
        }
        if let Err(err) = run_consumer_connect_hook(name, consumer.as_ref()).await {
            run_consumer_disconnect_hook(name, consumer.as_ref()).await;
            run_publisher_disconnect_hook(name, &publisher).await;
            return Err(err);
        }
        // Spawned commit tasks report their failures to the main loop through this channel.
        let (err_tx, err_rx) = bounded(1);
        let mut commit_tasks = JoinSet::new();

        // Sequencer setup to ensure ordered commits even with parallel commit tasks
        let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
        // Sequence number assigned to the next received batch.
        let mut seq_counter = 0u64;

        // Both endpoints are connected: tell the caller that the route is ready.
        if let Some(tx) = ready_tx {
            let _ = tx.send(()).await;
        }
        // Buffer for the ids of the current batch, reused across iterations.
        let mut message_ids = Vec::with_capacity(self.options.batch_size);
        // Check if retry middleware is present on output
        let has_retry_middleware = self.output.has_retry_middleware();
        let run_result = loop {
            select! {
                // A spawned commit task reported a failure: stop the route with that error.
                Ok(err) = err_rx.recv() => break Err(err),

                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received in sequential runner for route '{}'.", name);
                    break Ok(true); // Stopped by shutdown signal
                }
                res = consumer.receive_batch(self.options.batch_size) => {
                    let received_batch = match res {
                        Ok(batch) => {
                            if batch.messages.is_empty() {
                                continue; // No messages, loop to select! again
                            }
                            batch
                        }
                        Err(ConsumerError::EndOfStream) => {
                            info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
                            break Ok(false); // Graceful exit
                        }
                        Err(ConsumerError::Connection(e)) => {
                            // Propagate error to trigger reconnect by the outer loop
                            break Err(e);
                        },
                        Err(ConsumerError::Gap { requested, base }) => {
                            // Propagate gap error to trigger reconnect by the outer loop
                            break Err(anyhow::anyhow!("Consumer gap: requested offset {requested} but earliest available is {base}"));
                        }
                    };
                    debug!("Received a batch of {} messages sequentially", received_batch.messages.len());

                    // Process the batch sequentially without spawning a new task
                    let seq = seq_counter;
                    seq_counter += 1;
                    // Every branch below takes the commit function out of this `Option` exactly once.
                    let mut commit_opt = Some(wrap_commit(received_batch.commit, seq, seq_tx.clone()));
                    let batch_len = received_batch.messages.len();
                    message_ids.clear();
                    message_ids.extend(received_batch.messages.iter().map(|m| m.message_id));
                    // Ids of messages that carry a `reply_to` metadata entry, i.e. expect a response.
                    let request_ids: std::collections::HashSet<u128> = received_batch
                        .messages
                        .iter()
                        .filter(|m| m.metadata.contains_key("reply_to"))
                        .map(|m| m.message_id)
                        .collect();

                    match publisher.send_batch(received_batch.messages).await {
                        // Whole batch accepted: acknowledge every message in a background task.
                        Ok(SentBatch::Ack) => {
                            for id in &message_ids {
                                if request_ids.contains(id) {
                                    warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop broken.", id);
                                }
                            }
                            let commit = commit_opt.take().expect("Commit already used");
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
                                    error!("Commit failed: {}", e);
                                    // Non-blocking report; if the channel is full or closed the
                                    // error is only logged.
                                    match err_tx.try_send(e) {
                                        Ok(_) => trace!("Reported commit error to main task"),
                                        Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
                                    }
                                }
                            });
                        }
                        Ok(SentBatch::Partial { responses, failed }) => {
                            // Connection and Retryable are both "transient" errors - treat them the same.
                            // Connection errors from handlers or publishers both indicate a temporary
                            // failure that should either be retried (with middleware) or crash the route.
                            let has_transient = failed.iter().any(|(_, e)| {
                                matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_))
                            });
                            if has_transient {
                                let (_, first_err) = failed
                                    .iter()
                                    .find(|(_, e)| matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_)))
                                    .expect("has_transient is true");
                                let err = anyhow::anyhow!(
                                    "Transient error in batch send ({} messages failed). First error: {}",
                                    failed.len(),
                                    first_err
                                );
                                // Ack/nack the batch to fill the sequencer slot.
                                // This commit is awaited inline rather than spawned.
                                let commit = commit_opt.take().expect("Commit already used");
                                let dispositions =
                                    map_responses_to_dispositions(&message_ids, responses, &failed, &request_ids);
                                if let Err(commit_err) = commit(dispositions).await {
                                    warn!("Commit after transient failure also failed: {}", commit_err);
                                }
                                // Without retry middleware on the output, a transient failure
                                // ends this run with an error.
                                if !has_retry_middleware {
                                    break Err(err);
                                }
                                warn!("Transient error in batch, message(s) Nack'ed for re-delivery: {}", err);
                                tokio::task::yield_now().await;
                                continue;
                            }
                            // Only non-retryable errors remain - drop them with a log.
                            for (msg, e) in &failed {
                                error!("Dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
                            }
                            let commit = commit_opt.take().expect("Commit already used");
                            let err_tx = err_tx.clone();
                            // Move the id buffer into the task; `message_ids` is left as an empty
                            // `Vec` and is refilled on the next iteration.
                            let ids = std::mem::take(&mut message_ids);
                            let req_ids = request_ids;
                            commit_tasks.spawn(async move {
                                let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &req_ids);
                                if let Err(e) = commit(dispositions).await {
                                    error!("Commit failed: {}", e);
                                    match err_tx.try_send(e) {
                                        Ok(_) => trace!("Reported commit error to main task"),
                                        Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
                                    }
                                }
                            });
                        }
                        Err(e) => {
                            // Any direct error from send_batch crashes the route.
                            warn!("Publisher error, sending {} Nacks to commit", batch_len);
                            let commit = commit_opt.take().expect("Commit already used");
                            let nack_result = commit(vec![MessageDisposition::Nack; batch_len]).await;
                            debug!("Nack commit result: {:?}", nack_result);
                            break Err(e.into());
                        }
                    }

                    // Yield to the scheduler before receiving the next batch.
                    tokio::task::yield_now().await;
                }
            }
        };

        // Drop this function's sender to the sequencer before waiting for the commit tasks
        // and the sequencer below.
        drop(seq_tx);
        // Drain errors while waiting for tasks to finish to prevent deadlocks and lost errors
        loop {
            select! {
                res = err_rx.recv() => {
                    if let Ok(err) = res {
                        error!("Error reported during shutdown: {}", err);
                    }
                }
                res = commit_tasks.join_next() => {
                    // `None` means every spawned commit task has finished.
                    if res.is_none() {
                        break;
                    }
                }
            }
        }
        drop(err_rx);
        let _ = sequencer_handle.await;
        // Disconnect hooks run on every exit path that gets past the connect hooks.
        run_consumer_disconnect_hook(name, consumer.as_ref()).await;
        run_publisher_disconnect_hook(name, &publisher).await;
        run_result
    }
610
611    /// The main concurrent runner for when concurrency > 1.
612    async fn run_concurrently(
613        &self,
614        name: &str,
615        shutdown_rx: async_channel::Receiver<()>,
616        ready_tx: Option<Sender<()>>,
617    ) -> anyhow::Result<bool> {
618        let publisher = create_publisher_from_route(name, &self.output).await?;
619        let mut consumer = create_consumer_from_route(name, &self.input).await?;
620        if let Err(err) = run_publisher_connect_hook(name, &publisher).await {
621            run_publisher_disconnect_hook(name, &publisher).await;
622            return Err(err);
623        }
624        if let Err(err) = run_consumer_connect_hook(name, consumer.as_ref()).await {
625            run_consumer_disconnect_hook(name, consumer.as_ref()).await;
626            run_publisher_disconnect_hook(name, &publisher).await;
627            return Err(err);
628        }
629        if let Some(tx) = ready_tx {
630            let _ = tx.send(()).await;
631        }
632        let (err_tx, err_rx) = bounded(1); // For critical, route-stopping errors
633                                           // channel capacity: a small buffer proportional to concurrency
634        let work_capacity = self
635            .options
636            .concurrency
637            .saturating_mul(self.options.batch_size);
638        let (work_tx, work_rx) =
639            bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(work_capacity);
640        // --- Ordered Commit Sequencer ---
641        // To prevent data loss with cumulative-ack brokers (Kafka/AMQP), commits must happen in order.
642        // We assign a sequence number to each batch and use a sequencer task to enforce order.
643        let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
644
645        // --- Worker Pool ---
646        let batch_size = self.options.batch_size;
647        let mut join_set = JoinSet::new();
648        for i in 0..self.options.concurrency {
649            let work_rx_clone = work_rx.clone();
650            let publisher = Arc::clone(&publisher);
651            let err_tx = err_tx.clone();
652            let mut commit_tasks = JoinSet::new();
653            let has_retry_middleware = self.output.has_retry_middleware();
654            join_set.spawn(async move {
655                debug!("Starting worker {}", i);
656                let mut message_ids = Vec::with_capacity(batch_size);
657                while let Ok((messages, commit_func)) = work_rx_clone.recv().await {
658                    let mut commit_opt = Some(commit_func);
659                    let batch_len = messages.len();
660                    message_ids.clear();
661                    message_ids.extend(messages.iter().map(|m| m.message_id));
662                    let request_ids: std::collections::HashSet<u128> = messages
663                        .iter()
664                        .filter(|m| m.metadata.contains_key("reply_to"))
665                        .map(|m| m.message_id)
666                        .collect();
667
668                    match publisher.send_batch(messages).await {
669                        Ok(SentBatch::Ack) => {
670                            for id in &message_ids {
671                                if request_ids.contains(id) {
672                                    warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop broken.", id);
673                                }
674                            }
675                            let commit = commit_opt.take().expect("Commit already used");
676                            let err_tx = err_tx.clone();
677                            commit_tasks.spawn(async move {
678                                if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
679                                    error!("Commit failed: {}", e);
680                                    match err_tx.try_send(e) {
681                                        Ok(_) => trace!("Reported commit error to main task"),
682                                        Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
683                                    }
684                                }
685                            });
686                        }
687                        Ok(SentBatch::Partial { responses, failed }) => {
688                            // Connection and Retryable are both "transient" errors.
689                            let has_transient = failed.iter().any(|(_, e)| {
690                                matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_))
691                            });
692                            if has_transient {
693                                let (_, first_err) = failed
694                                    .iter()
695                                    .find(|(_, e)| matches!(e, PublisherError::Retryable(_) | PublisherError::Connection(_)))
696                                    .expect("has_transient is true");
697                                let e = anyhow::anyhow!(
698                                    "Transient error in batch send ({} messages failed). First error: {}",
699                                    failed.len(),
700                                    first_err
701                                );
702                                let commit = commit_opt.take().expect("Commit already used");
703                                // Ack/nack the batch to fill the sequencer slot.
704                                let dispositions =
705                                    map_responses_to_dispositions(&message_ids, responses, &failed, &request_ids);
706                                if let Err(commit_err) = commit(dispositions).await {
707                                    warn!("Commit after transient failure also failed: {}", commit_err);
708                                }
709                                if !has_retry_middleware {
710                                    match err_tx.try_send(e) {
711                                        Ok(_) => trace!("Reported error to main task"),
712                                        Err(err_send) => warn!(error=?err_send, "Could not send error to main task, it might be down or busy."),
713                                    }
714                                    break;
715                                }
716                                warn!("Transient error in batch, message(s) Nack'ed for re-delivery: {}", e);
717                                tokio::task::yield_now().await;
718                                continue;
719                            }
720                            // Only non-retryable errors remain - drop them with a log.
721                            for (msg, e) in &failed {
722                                error!("Worker dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
723                            }
724                            let commit = commit_opt.take().expect("Commit already used");
725                            let err_tx = err_tx.clone();
726                            let ids = std::mem::take(&mut message_ids);
727                            let req_ids = request_ids;
728                            commit_tasks.spawn(async move {
729                                let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &req_ids);
730                                if let Err(e) = commit(dispositions).await {
731                                    error!("Commit failed: {}", e);
732                                    match err_tx.try_send(e) {
733                                        Ok(_) => trace!("Reported commit error to main task"),
734                                        Err(err_send) => warn!(error=?err_send, "Could not send commit error to main task, it might be down or busy."),
735                                    }
736                                }
737                            });
738                        }
739                        Err(e) => {
740                            error!("Worker failed to send message batch: {}", e);
741                            let commit = commit_opt.take().expect("Commit already used");
742                            // Nack the commit to fill the sequencer slot and prevent a deadlock.
743                            let nack_result = commit(vec![MessageDisposition::Nack; batch_len]).await;
744                            debug!("Nack commit result: {:?}", nack_result);
745                            // Send the error back to the main task to tear down the route.
746                            match err_tx.try_send(e.into()) {
747                                Ok(_) => trace!("Reported error to main task"),
748                                Err(err_send) => warn!(error=?err_send, "Could not send error to main task, it might be down or busy."),
749                            }
750                            break;
751                        }
752                    }
753                }
754                // Wait for all in-flight commits to complete
755                while commit_tasks.join_next().await.is_some() {}
756            });
757        }
758
759        let mut seq_counter = 0u64;
760        // Holds an error that caused the loop to break, to be returned after graceful shutdown.
761        let mut loop_error: Option<anyhow::Error> = None;
762        loop {
763            select! {
764                biased; // Prioritize checking for errors
765
766                Ok(err) = err_rx.recv() => {
767                    error!("A worker reported a critical error. Shutting down route.");
768                    loop_error = Some(err);
769                    break;
770                }
771
772                Some(res) = join_set.join_next() => {
773                    match res {
774                        Ok(_) => {
775                            error!("A worker task finished unexpectedly. Shutting down route.");
776                            loop_error = Some(anyhow::anyhow!("Worker task finished unexpectedly"));
777                        }
778                        Err(e) => {
779                            error!("A worker task panicked: {}. Shutting down route.", e);
780                            loop_error = Some(e.into());
781                        }
782                    }
783                    break;
784                }
785
786                _ = shutdown_rx.recv() => {
787                    info!("Shutdown signal received in concurrent runner for route '{}'.", name);
788                    break;
789                }
790
791                res = consumer.receive_batch(self.options.batch_size) => {
792                    let (messages, commit) = match res {
793                        Ok(batch) => {
794                            if batch.messages.is_empty() {
795                                continue; // No messages, loop to select! again
796                            }
797                            (batch.messages, batch.commit)
798                        }
799                        Err(ConsumerError::EndOfStream) => {
800                            info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
801                            break; // Graceful exit
802                        }
803                        Err(ConsumerError::Connection(e)) => {
804                            // Propagate error to trigger reconnect by the outer loop
805                            loop_error = Some(e);
806                            break;
807                        }
808                        Err(ConsumerError::Gap { requested, base }) => {
809                            // Propagate gap error to trigger reconnect by the outer loop
810                            loop_error = Some(ConsumerError::Gap { requested, base }.into());
811                            break;
812                        }
813                    };
814                    debug!("Received a batch of {} messages concurrently", messages.len());
815
816                    // Wrap the commit function to route it through the sequencer.
817                    // Only advance the sequence counter after we've successfully enqueued
818                    // the work item to avoid creating sequence gaps if the work channel
819                    // is closed while producing batches.
820                    let seq = seq_counter;
821                    let wrapped_commit = wrap_commit(commit, seq, seq_tx.clone());
822
823                    match work_tx.send((messages, wrapped_commit)).await {
824                        Ok(()) => {
825                            seq_counter += 1;
826                        }
827                        Err(e) => {
828                            warn!("Work channel closed, cannot process more messages concurrently. Shutting down.");
829                            // Recover the moved tuple so we can invoke the wrapped commit
830                            // and resolve the batch with a NACK.
831                            let (msgs_back, wrapped_commit_back) = e.into_inner();
832                            let _ = (wrapped_commit_back)(vec![crate::traits::MessageDisposition::Nack; msgs_back.len()]).await;
833                            break;
834                        }
835                    }
836
837                    tokio::task::yield_now().await;
838                }
839            }
840        }
841
842        // --- Graceful Shutdown ---
843        // Close the work channel so workers drain their current messages and exit the loop.
844        // This applies on both normal shutdown AND error paths, ensuring in-flight commits
845        // are not aborted mid-sequence.
846        drop(work_tx);
847        // Wait for all worker tasks to complete.
848        while join_set.join_next().await.is_some() {}
849
850        // Close sequencer
851        drop(seq_tx);
852        let _ = sequencer_handle.await;
853        run_consumer_disconnect_hook(name, consumer.as_ref()).await;
854        run_publisher_disconnect_hook(name, &publisher).await;
855
856        if let Some(err) = loop_error {
857            return Err(err);
858        }
859
860        if let Ok(err) = err_rx.try_recv() {
861            return Err(err);
862        }
863
864        // Return true if shutdown was requested (channel is empty means it was closed/consumed),
865        // false if we reached end-of-stream naturally.
866        Ok(shutdown_rx.is_empty())
867    }
868
869    pub fn with_options(mut self, options: RouteOptions) -> Self {
870        self.options = options;
871        self
872    }
873    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
874        self.options.concurrency = concurrency.max(1);
875        self
876    }
877
878    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
879        self.options.batch_size = batch_size.max(1);
880        self
881    }
882    pub fn with_commit_concurrency_limit(mut self, limit: usize) -> Self {
883        self.options.commit_concurrency_limit = limit.max(1);
884        self
885    }
886
887    pub fn with_handler(mut self, handler: impl Handler + 'static) -> Self {
888        self.output.handler = Some(Arc::new(handler));
889        self
890    }
891
892    /// Registers a typed handler for the route.
893    ///
894    /// The handler can accept either:
895    /// - `fn(T) -> Future<Output = Result<Handled, HandlerError>>`
896    /// - `fn(T, MessageContext) -> Future<Output = Result<Handled, HandlerError>>`
897    ///
898    /// # Examples
899    ///
900    /// ```
901    /// # use mq_bridge::{Route, models::Endpoint};
902    /// # use serde::Deserialize;
903    ///
904    /// #[derive(Deserialize)]
905    /// struct MyData { id: u32 }
906    ///
907    /// async fn my_handler(data: MyData) -> anyhow::Result<()> {
908    ///     Ok(())
909    /// }
910    ///
911    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10))
912    ///     .add_handler("my_type", my_handler);
913    /// ```
914    pub fn add_handler<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
915    where
916        T: DeserializeOwned + Send + Sync + 'static,
917        H: crate::type_handler::IntoTypedHandler<T, Args>,
918        Args: Send + Sync + 'static,
919    {
920        // Create the wrapper closure that handles deserialization and context extraction
921        let handler = Arc::new(handler);
922        let wrapper = move |msg: crate::CanonicalMessage| {
923            let handler = handler.clone();
924            async move {
925                let data = msg.parse::<T>().map_err(|e| {
926                    HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
927                })?;
928                let ctx = crate::MessageContext::from(msg);
929                handler.call(data, ctx).await
930            }
931        };
932        let wrapper = Arc::new(wrapper);
933
934        let prev_handler = self.output.handler.take();
935
936        let new_handler = if let Some(h) = prev_handler {
937            if let Some(extended) = h.register_handler(type_name, wrapper.clone()) {
938                extended
939            } else {
940                Arc::new(
941                    crate::type_handler::TypeHandler::new()
942                        .with_fallback(h)
943                        .add_handler(type_name, wrapper),
944                )
945            }
946        } else {
947            Arc::new(crate::type_handler::TypeHandler::new().add_handler(type_name, wrapper))
948        };
949
950        self.output.handler = Some(new_handler);
951        self
952    }
953    pub fn add_handlers<T, H, Args>(mut self, handlers: HashMap<&str, H>) -> Self
954    where
955        T: DeserializeOwned + Send + Sync + 'static,
956        H: crate::type_handler::IntoTypedHandler<T, Args>,
957        Args: Send + Sync + 'static,
958    {
959        for (type_name, handler) in handlers {
960            self = self.add_handler(type_name, handler);
961        }
962        self
963    }
964}
965
966type SequencerItem = (
967    Vec<MessageDisposition>,
968    BatchCommitFunc,
969    tokio::sync::oneshot::Sender<anyhow::Result<()>>,
970);
971
972fn spawn_sequencer(buffer_size: usize) -> (Sender<(u64, SequencerItem)>, JoinHandle<()>) {
973    let (seq_tx, seq_rx) = bounded::<(u64, SequencerItem)>(buffer_size);
974    let sequencer_handle = tokio::spawn(async move {
975        let mut buffer: BTreeMap<u64, SequencerItem> = BTreeMap::new();
976        let mut next_seq = 0u64;
977
978        loop {
979            // If we have the next item in sequence, execute its commit directly.
980            // Using a plain await (no select!) here is essential: if we raced a recv
981            // against the commit future, a recv win would drop the commit future and
982            // the notify sender, leaving the caller permanently blocked while next_seq
983            // stays unadvanced — a deadlock.
984            if let Some((dispositions, commit_func, notify)) = buffer.remove(&next_seq) {
985                let result = commit_func(dispositions).await;
986                let _ = notify.send(result);
987                next_seq += 1;
988                // Yield to allow other tasks to run, preventing busy-loop when buffer has many messages
989                tokio::task::yield_now().await;
990                continue;
991            }
992
993            // Wait for the next item from any worker.
994            match seq_rx.recv().await {
995                Ok((seq, item)) => {
996                    if seq < next_seq {
997                        let (_, _, notify) = item;
998                        trace!(
999                            seq,
1000                            next_seq,
1001                            "Sequencer received late item (seq < next_seq)"
1002                        );
1003                        let _ = notify.send(Err(anyhow::anyhow!(
1004                            "Sequencer received late item (seq {} < next_seq {})",
1005                            seq,
1006                            next_seq
1007                        )));
1008                    } else {
1009                        buffer.insert(seq, item);
1010                    }
1011                }
1012                Err(_) => {
1013                    // seq_tx was dropped — drain and notify any remaining buffered commits.
1014                    for (_, (_, _, notify)) in buffer {
1015                        let _ = notify.send(Err(anyhow::anyhow!("Sequencer is shutting down")));
1016                    }
1017                    break;
1018                }
1019            }
1020        }
1021    });
1022    (seq_tx, sequencer_handle)
1023}
1024
1025fn wrap_commit(
1026    commit: BatchCommitFunc,
1027    seq: u64,
1028    seq_tx: Sender<(u64, SequencerItem)>,
1029) -> BatchCommitFunc {
1030    Box::new(move |dispositions| {
1031        Box::pin(async move {
1032            let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
1033            if seq_tx
1034                .send((seq, (dispositions, commit, notify_tx)))
1035                .await
1036                .is_ok()
1037            {
1038                match notify_rx.await {
1039                    Ok(res) => res,
1040                    Err(_) => Err(anyhow::anyhow!(
1041                        "Sequencer dropped the commit channel unexpectedly"
1042                    )),
1043                }
1044            } else {
1045                Err(anyhow::anyhow!(
1046                    "Failed to send commit to sequencer, route is likely shutting down"
1047                ))
1048            }
1049        })
1050    })
1051}
1052
1053fn map_responses_to_dispositions(
1054    message_ids: &[u128],
1055    responses: Option<Vec<crate::CanonicalMessage>>,
1056    failed: &[(crate::CanonicalMessage, PublisherError)],
1057    request_ids: &std::collections::HashSet<u128>,
1058) -> Vec<MessageDisposition> {
1059    let mut dispositions = Vec::with_capacity(message_ids.len());
1060    let failed_ids: std::collections::HashSet<u128> =
1061        failed.iter().map(|(m, _)| m.message_id).collect();
1062
1063    // Create a map from message_id to response message for efficient lookup.
1064    let mut response_map: std::collections::HashMap<u128, crate::CanonicalMessage> = responses
1065        .unwrap_or_default()
1066        .into_iter()
1067        .map(|r| (r.message_id, r))
1068        .collect();
1069
1070    for id in message_ids {
1071        if failed_ids.contains(id) {
1072            dispositions.push(MessageDisposition::Nack);
1073        } else if let Some(resp) = response_map.remove(id) {
1074            // If a response exists for this specific ID, use it.
1075            dispositions.push(MessageDisposition::Reply(resp));
1076        } else if request_ids.contains(id) {
1077            error!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Nacking to avoid committing a lost response.", id);
1078            dispositions.push(MessageDisposition::Nack);
1079        } else {
1080            // Otherwise, it was a successful send that did not produce a response.
1081            dispositions.push(MessageDisposition::Ack);
1082        }
1083    }
1084    dispositions
1085}
1086
/// Checks how `map_responses_to_dispositions` resolves a batch that mixes
/// replies, a failed message, and a request whose reply never arrived.
///
/// Marked `#[test]` so the harness actually runs it; with `#[cfg(test)]`
/// alone the function was compiled but never executed.
#[test]
fn test_map_responses_to_dispositions_logic() {
    use crate::{traits::PublisherError, CanonicalMessage};
    use anyhow::anyhow;

    let ids = vec![1, 2, 3, 4];

    let mut resp1 = CanonicalMessage::from("resp1");
    resp1.message_id = 1;
    let mut resp4 = CanonicalMessage::from("resp4");
    resp4.message_id = 4;

    let responses = Some(vec![
        resp1, // Corresponds to id 1
        resp4, // Corresponds to id 4
    ]);

    let mut msg2 = CanonicalMessage::from("msg2");
    msg2.message_id = 2;
    let failed = vec![(msg2, PublisherError::NonRetryable(anyhow!("failed")))];

    let mut request_ids = std::collections::HashSet::new();
    request_ids.insert(3); // id 3 expects a reply but won't get one
    let dispositions = map_responses_to_dispositions(&ids, responses, &failed, &request_ids);

    assert_eq!(dispositions.len(), 4);
    // Replies must land in the slot of the message they answer.
    assert!(matches!(&dispositions[0], MessageDisposition::Reply(r) if r.message_id == 1)); // from responses
    assert!(matches!(dispositions[1], MessageDisposition::Nack)); // from failed
    assert!(matches!(dispositions[2], MessageDisposition::Nack)); // missing reply
    assert!(matches!(&dispositions[3], MessageDisposition::Reply(r) if r.message_id == 4)); // from responses
}
1118
/// Free-function shorthand for `Route::get(name)`.
///
/// Returns whatever `Route::get` yields for `name`: `Some(route)` or `None`.
pub fn get_route(name: &str) -> Option<Route> {
    Route::get(name)
}
1122
/// Free-function shorthand for `Route::list()`.
///
/// Returns the strings produced by `Route::list` (presumably route names —
/// confirm against `Route::list`).
pub fn list_routes() -> Vec<String> {
    Route::list()
}
1126
/// Free-function shorthand for `Route::stop(name).await`.
///
/// Returns the `bool` reported by `Route::stop`; see that method for what
/// `true` and `false` mean.
pub async fn stop_route(name: &str) -> bool {
    Route::stop(name).await
}
1130
1131#[cfg(test)]
1132mod tests {
1133    use super::*;
1134    use crate::models::{Endpoint, EndpointType, FaultMode, Middleware, RandomPanicMiddleware};
1135    use crate::traits::{
1136        CustomMiddlewareFactory, MessageConsumer, MessagePublisher, ReceivedBatch,
1137    };
1138    use crate::CanonicalMessage;
1139    use std::any::Any;
1140    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
1141    use std::sync::Arc;
1142    use std::time::Duration;
1143
/// Shared record of what the commit-tracking consumer saw during a test run.
#[derive(Debug, Default)]
struct CommitObservation {
    /// Sequence numbers, in the order their commits finished.
    completed: Mutex<Vec<u64>>,
    /// Number of commit calls currently in flight.
    active: AtomicUsize,
    /// Highest value `active` has reached so far.
    max_active: AtomicUsize,
}
1150
/// Middleware factory that wraps consumers in a `CommitTrackingConsumer`.
#[derive(Debug)]
struct CommitTrackingMiddlewareFactory {
    // Handed (as a cloned `Arc`) to every consumer this factory wraps.
    observation: Arc<CommitObservation>,
}
1155
/// Middleware factory that wraps publishers in a `ReorderingPublisher`.
#[derive(Debug)]
struct ReorderingPublisherMiddlewareFactory;
1158
/// Consumer wrapper that records when each batch's commit runs and finishes.
struct CommitTrackingConsumer {
    // The consumer that actually receives the batches.
    inner: Box<dyn MessageConsumer>,
    // Where commit activity is recorded.
    observation: Arc<CommitObservation>,
}
1163
/// Publisher wrapper that delays `send_batch` by an amount derived from the
/// batch's first payload before forwarding to `inner`.
struct ReorderingPublisher {
    // The publisher every call is forwarded to.
    inner: Box<dyn MessagePublisher>,
}
1167
1168    #[async_trait::async_trait]
1169    impl CustomMiddlewareFactory for CommitTrackingMiddlewareFactory {
1170        async fn apply_consumer(
1171            &self,
1172            consumer: Box<dyn MessageConsumer>,
1173            _route_name: &str,
1174            _config: &serde_json::Value,
1175        ) -> anyhow::Result<Box<dyn MessageConsumer>> {
1176            Ok(Box::new(CommitTrackingConsumer {
1177                inner: consumer,
1178                observation: Arc::clone(&self.observation),
1179            }))
1180        }
1181    }
1182
1183    #[async_trait::async_trait]
1184    impl CustomMiddlewareFactory for ReorderingPublisherMiddlewareFactory {
1185        async fn apply_publisher(
1186            &self,
1187            publisher: Box<dyn MessagePublisher>,
1188            _route_name: &str,
1189            _config: &serde_json::Value,
1190        ) -> anyhow::Result<Box<dyn MessagePublisher>> {
1191            Ok(Box::new(ReorderingPublisher { inner: publisher }))
1192        }
1193    }
1194
1195    #[async_trait::async_trait]
1196    impl MessageConsumer for CommitTrackingConsumer {
1197        async fn receive_batch(
1198            &mut self,
1199            max_messages: usize,
1200        ) -> Result<ReceivedBatch, ConsumerError> {
1201            let mut batch = self.inner.receive_batch(max_messages).await?;
1202            let seq = batch
1203                .messages
1204                .first()
1205                .and_then(|message| message.get_payload_str().parse::<u64>().ok())
1206                .expect("tracking test expects numeric payloads");
1207            let original_commit = batch.commit;
1208            let observation = Arc::clone(&self.observation);
1209            batch.commit = Box::new(move |dispositions| {
1210                let observation = Arc::clone(&observation);
1211                Box::pin(async move {
1212                    let active_now = observation.active.fetch_add(1, Ordering::SeqCst) + 1;
1213                    let _ = observation.max_active.fetch_update(
1214                        Ordering::SeqCst,
1215                        Ordering::SeqCst,
1216                        |current| (active_now > current).then_some(active_now),
1217                    );
1218
1219                    tokio::time::sleep(Duration::from_millis(20)).await;
1220                    let result = original_commit(dispositions).await;
1221                    observation.completed.lock().unwrap().push(seq);
1222                    observation.active.fetch_sub(1, Ordering::SeqCst);
1223                    result
1224                })
1225            });
1226            Ok(batch)
1227        }
1228
1229        fn as_any(&self) -> &dyn Any {
1230            self
1231        }
1232    }
1233
1234    #[async_trait::async_trait]
1235    impl MessagePublisher for ReorderingPublisher {
1236        async fn send_batch(
1237            &self,
1238            messages: Vec<crate::CanonicalMessage>,
1239        ) -> Result<SentBatch, PublisherError> {
1240            let seq = messages
1241                .first()
1242                .and_then(|message| message.get_payload_str().parse::<u64>().ok())
1243                .expect("tracking test expects numeric payloads");
1244            let delay_ms = 10 * (6u64.saturating_sub(seq.min(6)));
1245            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
1246            self.inner.send_batch(messages).await
1247        }
1248
1249        async fn send(&self, msg: crate::CanonicalMessage) -> Result<Sent, PublisherError> {
1250            self.inner.send(msg).await
1251        }
1252
1253        async fn flush(&self) -> anyhow::Result<()> {
1254            self.inner.flush().await
1255        }
1256
1257        fn as_any(&self) -> &dyn Any {
1258            self
1259        }
1260    }
1261
1262    async fn assert_route_commits_are_ordered_and_non_overlapping(concurrency: usize) {
1263        let unique_id = fast_uuid_v7::gen_id().to_string();
1264        let tracking_name = format!("track_commit_{}", unique_id);
1265        let reorder_name = format!("reorder_publish_{}", unique_id);
1266        let in_topic = format!("ordered_commit_in_{}", unique_id);
1267        let observation = Arc::new(CommitObservation::default());
1268
1269        register_middleware_factory(
1270            &tracking_name,
1271            Arc::new(CommitTrackingMiddlewareFactory {
1272                observation: Arc::clone(&observation),
1273            }),
1274        );
1275        register_middleware_factory(
1276            &reorder_name,
1277            Arc::new(ReorderingPublisherMiddlewareFactory),
1278        );
1279
1280        let input = Endpoint::new_memory(&in_topic, 32).add_middleware(Middleware::Custom {
1281            name: tracking_name,
1282            config: serde_json::Value::Null,
1283        });
1284        let output = Endpoint::new(EndpointType::Null).add_middleware(Middleware::Custom {
1285            name: reorder_name,
1286            config: serde_json::Value::Null,
1287        });
1288
1289        let route = Route::new(input.clone(), output)
1290            .with_concurrency(concurrency)
1291            .with_batch_size(1)
1292            .with_commit_concurrency_limit(1);
1293
1294        let input_channel = input.channel().unwrap();
1295        let messages = (0..6)
1296            .map(|seq| crate::CanonicalMessage::from(seq.to_string()))
1297            .collect();
1298        input_channel.fill_messages(messages).await.unwrap();
1299        input_channel.close();
1300
1301        tokio::time::timeout(
1302            std::time::Duration::from_secs(5),
1303            route.run_until_err("ordered_commit_regression", None, None),
1304        )
1305        .await
1306        .expect("Route should not hang while draining finite input")
1307        .expect("Route should complete without commit errors");
1308        assert_eq!(
1309            *observation.completed.lock().unwrap(),
1310            vec![0, 1, 2, 3, 4, 5],
1311            "Commit execution must follow receive order",
1312        );
1313        assert_eq!(
1314            observation.max_active.load(Ordering::SeqCst),
1315            1,
1316            "Broker-facing commit functions must never overlap",
1317        );
1318    }
1319
/// Runs the ordered, non-overlapping commit check with a concurrency of 1
/// on a two-thread runtime.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sequential_route_commits_are_ordered_and_non_overlapping() {
    assert_route_commits_are_ordered_and_non_overlapping(1).await;
}
1324
/// Runs the ordered, non-overlapping commit check with a concurrency of 4
/// on a four-thread runtime.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_concurrent_route_commits_are_ordered_and_non_overlapping() {
    assert_route_commits_are_ordered_and_non_overlapping(4).await;
}
1329
1330    // Helper function to run a fault injection test on the consumer side.
1331    async fn run_consumer_fault_test(
1332        mode: FaultMode,
1333        expected_payload: &str,
1334        route_should_restart: bool,
1335        concurrency: usize,
1336    ) {
1337        let unique_suffix = fast_uuid_v7::gen_id().to_string();
1338        let in_topic = format!("fault_in_{}_{}_{}", mode, concurrency, unique_suffix);
1339        let out_topic = format!("fault_out_{}_{}_{}", mode, concurrency, unique_suffix);
1340
1341        let fault_config = RandomPanicMiddleware {
1342            mode,
1343            trigger_on_message: Some(1), // Panic on the first message
1344            enabled: true,
1345            ..Default::default()
1346        };
1347
1348        let input = Endpoint::new_memory(&in_topic, 10)
1349            .add_middleware(Middleware::RandomPanic(fault_config));
1350        let output = Endpoint::new_memory(&out_topic, 10);
1351
1352        let route_name = format!("fault_test_{}_{}", mode, concurrency);
1353        let route = Route::new(input.clone(), output.clone()).with_concurrency(concurrency);
1354
1355        // Start the route
1356        route
1357            .deploy(&route_name)
1358            .await
1359            .expect("Failed to deploy route");
1360        // Send a message. The consumer will inject a fault when it tries to receive it.
1361        let input_ch = input.channel().unwrap();
1362        input_ch
1363            .send_message("persistent_msg".into())
1364            .await
1365            .unwrap();
1366
1367        if route_should_restart {
1368            // The route's worker will fail, and the supervisor will wait 5 seconds before restarting.
1369            // We wait for a bit longer than that to ensure recovery has happened.
1370            tokio::time::sleep(std::time::Duration::from_secs(6)).await;
1371        } else {
1372            // Route doesn't restart, just wait a bit for the (faulty) message to pass through.
1373            tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1374        }
1375
1376        // Verify the outcome.
1377        let mut verifier = route.connect_to_output("verifier").await.unwrap();
1378        let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
1379            .await
1380            .expect("Timed out waiting for message after fault")
1381            .expect("Stream closed while waiting for message");
1382
1383        assert_eq!(received.message.get_payload_str(), expected_payload);
1384        (received.commit)(MessageDisposition::Ack).await.unwrap();
1385
1386        // Cleanup
1387        Route::stop(&route_name).await;
1388    }
1389
1390    // Helper function to run a fault injection test on the publisher side.
1391    async fn run_publisher_fault_test(
1392        mode: FaultMode,
1393        expected_payload: &str,
1394        route_should_restart: bool,
1395    ) {
1396        let unique_suffix = fast_uuid_v7::gen_id().to_string();
1397        let in_topic = format!("pub_fault_in_{}_{}", mode, unique_suffix);
1398        let out_topic = format!("pub_fault_out_{}_{}", mode, unique_suffix);
1399
1400        let fault_config = RandomPanicMiddleware {
1401            mode,
1402            trigger_on_message: Some(1), // Trigger on the first message
1403            enabled: true,
1404            ..Default::default()
1405        };
1406
1407        let mut input = Endpoint::new_memory(&in_topic, 10);
1408        // Enable NACK on input so messages aren't lost when publisher crashes
1409        if let EndpointType::Memory(ref mut cfg) = input.endpoint_type {
1410            cfg.enable_nack = true;
1411        }
1412        // Apply fault middleware to output
1413        let output = Endpoint::new_memory(&out_topic, 10)
1414            .add_middleware(Middleware::RandomPanic(fault_config));
1415
1416        let route_name = format!("pub_fault_test_{}", mode);
1417        let route = Route::new(input.clone(), output.clone());
1418
1419        route
1420            .deploy(&route_name)
1421            .await
1422            .expect("Failed to deploy route");
1423
1424        let input_ch = input.channel().unwrap();
1425        input_ch
1426            .send_message(expected_payload.into())
1427            .await
1428            .unwrap();
1429
1430        if route_should_restart {
1431            tokio::time::sleep(std::time::Duration::from_secs(6)).await;
1432        } else {
1433            tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1434        }
1435
1436        let mut verifier = route.connect_to_output("verifier").await.unwrap();
1437        let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
1438            .await
1439            .expect("Timed out waiting for message after publisher fault")
1440            .expect("Stream closed");
1441
1442        assert_eq!(received.message.get_payload_str(), expected_payload);
1443        (received.commit)(MessageDisposition::Ack).await.unwrap();
1444
1445        Route::stop(&route_name).await;
1446    }
1447
1448    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1449    #[ignore = "Takes too much time for regular tests"]
1450    async fn test_route_recovery_from_faults() {
1451        let original_payload = "persistent_msg";
1452
1453        // Test with concurrency > 1
1454        run_consumer_fault_test(FaultMode::Panic, original_payload, true, 2).await;
1455        run_consumer_fault_test(FaultMode::Disconnect, original_payload, true, 2).await;
1456        run_consumer_fault_test(FaultMode::Timeout, original_payload, true, 2).await;
1457        run_consumer_fault_test(FaultMode::Nack, original_payload, true, 2).await;
1458
1459        // This fault replaces the message but does not restart the route.
1460        run_consumer_fault_test(FaultMode::JsonFormatError, "{invalid json}", false, 2).await;
1461    }
1462
1463    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1464    #[ignore = "Takes too much time for regular tests"]
1465    async fn test_route_recovery_from_faults_sequential() {
1466        let original_payload = "persistent_msg";
1467
1468        // Test with concurrency = 1
1469        run_consumer_fault_test(FaultMode::Panic, original_payload, true, 1).await;
1470        run_consumer_fault_test(FaultMode::Disconnect, original_payload, true, 1).await;
1471    }
1472
1473    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1474    #[ignore = "Takes too much time for regular tests"]
1475    async fn test_publisher_recovery_from_faults() {
1476        let original_payload = "persistent_msg";
1477        // Test publisher-side faults causing restart/retry.
1478        // `FaultMode::Panic` is not tested here because the `MemoryConsumer` used for input
1479        // does not support crash-safe at-least-once delivery. A panic in the publisher
1480        // worker would cause the in-flight message to be lost.
1481        run_publisher_fault_test(FaultMode::Disconnect, original_payload, true).await;
1482        run_publisher_fault_test(FaultMode::Timeout, original_payload, true).await;
1483    }
1484
1485    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1486    async fn test_route_sequencer_deadlock_fix() {
1487        // This test ensures that when a worker fails to send a batch (and thus drops the commit handle),
1488        // the sequencer doesn't deadlock waiting for that sequence number.
1489        // The fix ensures that even on failure, the commit function is called (with Nack) to fill the sequence gap.
1490
1491        let unique_id = fast_uuid_v7::gen_id().to_string();
1492        let factory_name = format!("fail_factory_{}", unique_id);
1493        let in_topic = format!("deadlock_in_{}", unique_id);
1494        let out_topic = format!("deadlock_out_{}", unique_id);
1495
1496        #[derive(Debug)]
1497        struct FailingMiddlewareFactory {
1498            fail_flag: Arc<AtomicBool>,
1499        }
1500
1501        #[async_trait::async_trait]
1502        impl CustomMiddlewareFactory for FailingMiddlewareFactory {
1503            async fn apply_publisher(
1504                &self,
1505                publisher: Box<dyn MessagePublisher>,
1506                _route_name: &str,
1507                _config: &serde_json::Value,
1508            ) -> anyhow::Result<Box<dyn MessagePublisher>> {
1509                Ok(Box::new(FailingPublisher {
1510                    inner: publisher,
1511                    fail_flag: self.fail_flag.clone(),
1512                }))
1513            }
1514            async fn apply_consumer(
1515                &self,
1516                consumer: Box<dyn MessageConsumer>,
1517                _route_name: &str,
1518                _config: &serde_json::Value,
1519            ) -> anyhow::Result<Box<dyn MessageConsumer>> {
1520                Ok(consumer)
1521            }
1522        }
1523
1524        struct FailingPublisher {
1525            inner: Box<dyn MessagePublisher>,
1526            fail_flag: Arc<AtomicBool>,
1527        }
1528
1529        #[async_trait::async_trait]
1530        impl MessagePublisher for FailingPublisher {
1531            async fn send_batch(
1532                &self,
1533                messages: Vec<crate::CanonicalMessage>,
1534            ) -> Result<SentBatch, PublisherError> {
1535                // We want to fail one batch to trigger the error path in the worker.
1536                // We use compare_exchange to ensure only one failure happens.
1537                if self
1538                    .fail_flag
1539                    .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
1540                    .is_ok()
1541                {
1542                    return Err(PublisherError::Retryable(anyhow::anyhow!(
1543                        "Simulated failure"
1544                    )));
1545                }
1546                // Add a small delay for successful batches to ensure the failed one (if it created a gap)
1547                // would block the sequencer if the gap wasn't filled.
1548                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1549                self.inner.send_batch(messages).await
1550            }
1551            async fn send(
1552                &self,
1553                msg: crate::CanonicalMessage,
1554            ) -> Result<crate::traits::Sent, PublisherError> {
1555                self.inner.send(msg).await
1556            }
1557            async fn flush(&self) -> anyhow::Result<()> {
1558                self.inner.flush().await
1559            }
1560            fn as_any(&self) -> &dyn Any {
1561                self
1562            }
1563        }
1564
1565        let fail_flag = Arc::new(AtomicBool::new(true));
1566        register_middleware_factory(
1567            &factory_name,
1568            Arc::new(FailingMiddlewareFactory {
1569                fail_flag: fail_flag.clone(),
1570            }),
1571        );
1572
1573        let input = Endpoint::new_memory(&in_topic, 100);
1574        let output = Endpoint::new_memory(&out_topic, 100).add_middleware(Middleware::Custom {
1575            name: factory_name,
1576            config: serde_json::Value::Null,
1577        });
1578
1579        // Concurrency > 1 is required to have multiple workers and potential out-of-order completion
1580        let route = Route::new(input.clone(), output.clone())
1581            .with_concurrency(2)
1582            .with_batch_size(1);
1583
1584        // Send messages
1585        let input_ch = input.channel().unwrap();
1586        input_ch.send_message("msg1".into()).await.unwrap();
1587        input_ch.send_message("msg2".into()).await.unwrap();
1588        input_ch.send_message("msg3".into()).await.unwrap();
1589
1590        // Run the route. It should fail eventually due to the simulated error,
1591        // but it MUST NOT deadlock.
1592        let run_fut = async {
1593            let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1);
1594            route
1595                .run_until_err("deadlock_test", Some(shutdown_rx), None)
1596                .await
1597        };
1598
1599        // If deadlock exists, this timeout will trigger.
1600        let result = tokio::time::timeout(std::time::Duration::from_secs(5), run_fut).await;
1601
1602        match result {
1603            Ok(res) => {
1604                // We expect an error because the publisher returns Err.
1605                assert!(
1606                    res.is_err(),
1607                    "Route should have failed with simulated error"
1608                );
1609            }
1610            Err(_) => {
1611                panic!("Route deadlocked! The sequencer likely didn't receive the Nack for the failed batch.");
1612            }
1613        }
1614    }
1615
1616    #[tokio::test]
1617    async fn test_sequencer_ordered_commits() {
1618        use std::time::Duration;
1619        use tokio::time::timeout;
1620
1621        let (seq_tx, sequencer_handle) = spawn_sequencer(16);
1622        let processed: Arc<Mutex<Vec<u64>>> = Arc::new(Mutex::new(Vec::new()));
1623
1624        // Send sequences out of order to ensure the sequencer enforces ordering.
1625        let seqs = [2u64, 0u64, 1u64, 3u64];
1626        let mut receivers = Vec::new();
1627
1628        for seq in seqs.iter().cloned() {
1629            let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
1630            let processed_clone = processed.clone();
1631            let commit: BatchCommitFunc = Box::new(move |_dispositions| {
1632                let processed = processed_clone.clone();
1633                Box::pin(async move {
1634                    // Simulate variable work durations
1635                    tokio::time::sleep(Duration::from_millis(10 * seq)).await;
1636                    processed.lock().unwrap().push(seq);
1637                    Ok(())
1638                })
1639            });
1640            seq_tx
1641                .send((seq, (Vec::new(), commit, notify_tx)))
1642                .await
1643                .unwrap();
1644            receivers.push(notify_rx);
1645        }
1646
1647        // Wait for all commits to complete (with timeout to catch deadlocks)
1648        for rx in receivers {
1649            let res = timeout(Duration::from_secs(2), rx)
1650                .await
1651                .expect("Sequencer notify timed out");
1652            assert!(res.is_ok(), "Sequencer reported an error on commit");
1653            assert!(res.unwrap().is_ok(), "Commit returned an error");
1654        }
1655
1656        // Close sender to allow sequencer task to exit and await it.
1657        drop(seq_tx);
1658        let _ = sequencer_handle.await;
1659
1660        let result = processed.lock().unwrap().clone();
1661        assert_eq!(
1662            result,
1663            vec![0u64, 1u64, 2u64, 3u64],
1664            "Sequencer must process commits in order"
1665        );
1666    }
1667
1668    #[tokio::test]
1669    async fn test_sequencer_shutdown_notifies_pending() {
1670        use std::time::Duration;
1671        use tokio::time::timeout;
1672
1673        let (seq_tx, sequencer_handle) = spawn_sequencer(8);
1674
1675        // Prepare two pending items for sequences 1 and 2 while sequence 0 is missing.
1676        let (notify_tx1, notify_rx1) = tokio::sync::oneshot::channel();
1677        let (notify_tx2, notify_rx2) = tokio::sync::oneshot::channel();
1678
1679        let commit1: BatchCommitFunc = Box::new(|_dispositions| {
1680            Box::pin(async move {
1681                // Should not be executed because next_seq is missing (0)
1682                panic!("Commit should not be executed during shutdown drain");
1683                #[allow(unreachable_code)]
1684                Ok(())
1685            })
1686        });
1687
1688        let commit2: BatchCommitFunc = Box::new(|_dispositions| {
1689            Box::pin(async move {
1690                panic!("Commit should not be executed during shutdown drain");
1691                #[allow(unreachable_code)]
1692                Ok(())
1693            })
1694        });
1695
1696        seq_tx
1697            .send((1u64, (Vec::new(), commit1, notify_tx1)))
1698            .await
1699            .unwrap();
1700        seq_tx
1701            .send((2u64, (Vec::new(), commit2, notify_tx2)))
1702            .await
1703            .unwrap();
1704
1705        // Trigger shutdown of the sequencer by dropping the sender.
1706        drop(seq_tx);
1707
1708        // Sequencer should drain buffered items and reply with an error to the notifiers.
1709        let r1 = timeout(Duration::from_secs(1), notify_rx1)
1710            .await
1711            .expect("Timeout waiting for notify_rx1")
1712            .expect("Sequencer closed notify channel");
1713        assert!(
1714            r1.is_err(),
1715            "Pending commit should receive Err on sequencer shutdown"
1716        );
1717
1718        let r2 = timeout(Duration::from_secs(1), notify_rx2)
1719            .await
1720            .expect("Timeout waiting for notify_rx2")
1721            .expect("Sequencer closed notify channel");
1722        assert!(
1723            r2.is_err(),
1724            "Pending commit should receive Err on sequencer shutdown"
1725        );
1726
1727        let _ = sequencer_handle.await;
1728    }
1729
1730    use crate::traits::{BoxFuture, CustomEndpointFactory, Sent};
1731    use std::sync::Mutex;
1732
1733    type ConsumerBehavior =
1734        Arc<Mutex<dyn FnMut() -> Result<Box<dyn MessageConsumer>, anyhow::Error> + Send + Sync>>;
1735    type PublisherBehavior =
1736        Arc<Mutex<dyn FnMut() -> Result<Box<dyn MessagePublisher>, anyhow::Error> + Send + Sync>>;
1737
1738    struct MockEndpointFactory {
1739        create_consumer_fail: bool,
1740        consumer_behavior: ConsumerBehavior,
1741        publisher_behavior: PublisherBehavior,
1742    }
1743
1744    impl std::fmt::Debug for MockEndpointFactory {
1745        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1746            f.debug_struct("MockEndpointFactory")
1747                .field("create_consumer_fail", &self.create_consumer_fail)
1748                .finish()
1749        }
1750    }
1751
1752    impl MockEndpointFactory {
1753        fn new() -> Self {
1754            Self {
1755                create_consumer_fail: false,
1756                consumer_behavior: Arc::new(Mutex::new(|| Err(anyhow::anyhow!("Not implemented")))),
1757                publisher_behavior: Arc::new(Mutex::new(|| {
1758                    Ok(Box::new(NoOpPublisher) as Box<dyn MessagePublisher>)
1759                })),
1760            }
1761        }
1762    }
1763
1764    #[derive(Clone)]
1765    struct NoOpPublisher;
1766    #[async_trait::async_trait]
1767    impl MessagePublisher for NoOpPublisher {
1768        async fn send_batch(
1769            &self,
1770            _: Vec<crate::CanonicalMessage>,
1771        ) -> Result<SentBatch, PublisherError> {
1772            Ok(SentBatch::Ack)
1773        }
1774        async fn send(&self, _: crate::CanonicalMessage) -> Result<Sent, PublisherError> {
1775            Ok(Sent::Ack)
1776        }
1777        fn as_any(&self) -> &dyn Any {
1778            self
1779        }
1780    }
1781
1782    #[async_trait::async_trait]
1783    impl CustomEndpointFactory for MockEndpointFactory {
1784        async fn create_consumer(
1785            &self,
1786            _: &str,
1787            _: &serde_json::Value,
1788        ) -> anyhow::Result<Box<dyn MessageConsumer>> {
1789            if self.create_consumer_fail {
1790                return Err(anyhow::anyhow!("Endpoint unavailable"));
1791            }
1792            (self.consumer_behavior.lock().unwrap())()
1793        }
1794        async fn create_publisher(
1795            &self,
1796            _: &str,
1797            _: &serde_json::Value,
1798        ) -> anyhow::Result<Box<dyn MessagePublisher>> {
1799            (self.publisher_behavior.lock().unwrap())()
1800        }
1801    }
1802
/// Counters and failure switches shared between the hook test doubles and the tests.
///
/// Every field is behind an `Arc`, so clones of a `HookState` observe and modify the
/// same values.
#[derive(Default, Clone)]
struct HookState {
    // --- call counters ---
    /// Number of consumer `on_connect` hook invocations.
    consumer_connects: Arc<AtomicUsize>,
    /// Number of consumer `on_disconnect` hook invocations.
    consumer_disconnects: Arc<AtomicUsize>,
    /// Number of publisher `on_connect` hook invocations.
    publisher_connects: Arc<AtomicUsize>,
    /// Number of publisher `on_disconnect` hook invocations.
    publisher_disconnects: Arc<AtomicUsize>,
    /// Bumped by both connect hooks.
    shared_mutations: Arc<AtomicUsize>,

    // --- failure switches ---
    /// Makes the consumer `on_connect` hook return an error.
    fail_consumer_connect: Arc<AtomicBool>,
    /// Makes the consumer `on_disconnect` hook return an error.
    fail_consumer_disconnect: Arc<AtomicBool>,
    /// Makes the publisher `on_disconnect` hook return an error.
    fail_publisher_disconnect: Arc<AtomicBool>,
}
1814
1815    struct HookConsumer {
1816        state: HookState,
1817    }
1818
1819    struct HookPublisher {
1820        state: HookState,
1821    }
1822
1823    #[async_trait::async_trait]
1824    impl MessageConsumer for HookConsumer {
1825        fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1826            Some(Box::pin(async move {
1827                self.state.consumer_connects.fetch_add(1, Ordering::SeqCst);
1828                self.state.shared_mutations.fetch_add(1, Ordering::SeqCst);
1829                if self.state.fail_consumer_connect.load(Ordering::SeqCst) {
1830                    return Err(anyhow::anyhow!("consumer hook failed"));
1831                }
1832                Ok(())
1833            }))
1834        }
1835
1836        fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1837            Some(Box::pin(async move {
1838                self.state
1839                    .consumer_disconnects
1840                    .fetch_add(1, Ordering::SeqCst);
1841                if self.state.fail_consumer_disconnect.load(Ordering::SeqCst) {
1842                    return Err(anyhow::anyhow!("consumer disconnect hook failed"));
1843                }
1844                Ok(())
1845            }))
1846        }
1847
1848        async fn receive_batch(&mut self, _max: usize) -> Result<ReceivedBatch, ConsumerError> {
1849            Err(ConsumerError::EndOfStream)
1850        }
1851
1852        fn as_any(&self) -> &dyn Any {
1853            self
1854        }
1855    }
1856
1857    #[async_trait::async_trait]
1858    impl MessagePublisher for HookPublisher {
1859        fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1860            Some(Box::pin(async move {
1861                self.state.publisher_connects.fetch_add(1, Ordering::SeqCst);
1862                self.state.shared_mutations.fetch_add(1, Ordering::SeqCst);
1863                Ok(())
1864            }))
1865        }
1866
1867        fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
1868            Some(Box::pin(async move {
1869                self.state
1870                    .publisher_disconnects
1871                    .fetch_add(1, Ordering::SeqCst);
1872                if self.state.fail_publisher_disconnect.load(Ordering::SeqCst) {
1873                    return Err(anyhow::anyhow!("publisher disconnect hook failed"));
1874                }
1875                Ok(())
1876            }))
1877        }
1878
1879        async fn send_batch(
1880            &self,
1881            _: Vec<crate::CanonicalMessage>,
1882        ) -> Result<SentBatch, PublisherError> {
1883            Ok(SentBatch::Ack)
1884        }
1885
1886        fn as_any(&self) -> &dyn Any {
1887            self
1888        }
1889    }
1890
1891    fn hook_route(state: HookState, concurrency: usize) -> Route {
1892        let unique_id = fast_uuid_v7::gen_id().to_string();
1893        let factory_name = format!("hooks_{}", unique_id);
1894        let mut factory = MockEndpointFactory::new();
1895
1896        let consumer_state = state.clone();
1897        factory.consumer_behavior = Arc::new(Mutex::new(move || {
1898            Ok(Box::new(HookConsumer {
1899                state: consumer_state.clone(),
1900            }) as Box<dyn MessageConsumer>)
1901        }));
1902
1903        let publisher_state = state;
1904        factory.publisher_behavior = Arc::new(Mutex::new(move || {
1905            Ok(Box::new(HookPublisher {
1906                state: publisher_state.clone(),
1907            }) as Box<dyn MessagePublisher>)
1908        }));
1909
1910        register_endpoint_factory(&factory_name, Arc::new(factory));
1911
1912        let input = Endpoint {
1913            endpoint_type: EndpointType::Custom {
1914                name: factory_name.clone(),
1915                config: serde_json::Value::Null,
1916            },
1917            middlewares: vec![],
1918            handler: None,
1919        };
1920        let output = Endpoint {
1921            endpoint_type: EndpointType::Custom {
1922                name: factory_name,
1923                config: serde_json::Value::Null,
1924            },
1925            middlewares: vec![],
1926            handler: None,
1927        };
1928        Route::new(input, output).with_concurrency(concurrency)
1929    }
1930
1931    #[tokio::test]
1932    async fn test_lifecycle_hooks_called_once_sequentially() {
1933        let state = HookState::default();
1934        let route = hook_route(state.clone(), 1);
1935
1936        let stopped_by_shutdown = route
1937            .run_until_err("test_lifecycle_sequential", None, None)
1938            .await
1939            .unwrap();
1940
1941        assert!(!stopped_by_shutdown);
1942        assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
1943        assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
1944        assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
1945        assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
1946        assert_eq!(state.shared_mutations.load(Ordering::SeqCst), 2);
1947    }
1948
1949    #[tokio::test]
1950    async fn test_lifecycle_hooks_called_once_concurrently() {
1951        let state = HookState::default();
1952        let route = hook_route(state.clone(), 4);
1953
1954        route
1955            .run_until_err("test_lifecycle_concurrent", None, None)
1956            .await
1957            .unwrap();
1958
1959        assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
1960        assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
1961        assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
1962        assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
1963    }
1964
1965    #[tokio::test]
1966    async fn test_lifecycle_on_connect_failure_stops_route() {
1967        let state = HookState::default();
1968        state.fail_consumer_connect.store(true, Ordering::SeqCst);
1969        let route = hook_route(state.clone(), 1);
1970
1971        let err = route
1972            .run_until_err("test_lifecycle_connect_failure", None, None)
1973            .await
1974            .unwrap_err();
1975
1976        assert!(err.to_string().contains("on_connect hook failed"));
1977        assert_eq!(state.publisher_connects.load(Ordering::SeqCst), 1);
1978        assert_eq!(state.consumer_connects.load(Ordering::SeqCst), 1);
1979    }
1980
1981    #[tokio::test]
1982    async fn test_lifecycle_on_disconnect_failure_does_not_stop_route() {
1983        let state = HookState::default();
1984        state.fail_consumer_disconnect.store(true, Ordering::SeqCst);
1985        state
1986            .fail_publisher_disconnect
1987            .store(true, Ordering::SeqCst);
1988        let route = hook_route(state.clone(), 1);
1989
1990        let stopped_by_shutdown = route
1991            .run_until_err("test_lifecycle_disconnect_failure", None, None)
1992            .await
1993            .unwrap();
1994
1995        assert!(!stopped_by_shutdown);
1996        assert_eq!(state.consumer_disconnects.load(Ordering::SeqCst), 1);
1997        assert_eq!(state.publisher_disconnects.load(Ordering::SeqCst), 1);
1998    }
1999
2000    #[tokio::test]
2001    async fn test_start_fails_on_unavailable_endpoint() {
2002        // tokio::time::pause();
2003        let unique_id = fast_uuid_v7::gen_id().to_string();
2004        let factory_name = format!("unavailable_{}", unique_id);
2005
2006        let factory = Arc::new(MockEndpointFactory {
2007            create_consumer_fail: true,
2008            ..MockEndpointFactory::new()
2009        });
2010        register_endpoint_factory(&factory_name, factory);
2011
2012        let input = Endpoint {
2013            endpoint_type: EndpointType::Custom {
2014                name: factory_name,
2015                config: serde_json::Value::Null,
2016            },
2017            middlewares: vec![],
2018            handler: None,
2019        };
2020        let output = Endpoint::new_memory("out", 10);
2021        let route = Route::new(input, output);
2022
2023        // The route should fail to start because the input endpoint fails to create.
2024        // The run() method waits for a ready signal which never comes.
2025        let result = route.run("test_start_fail").await;
2026        assert!(result.is_err());
2027        assert!(result.unwrap_err().to_string().contains("failed to start"));
2028    }
2029
2030    #[tokio::test]
2031    async fn test_reconnect_on_consumer_error() {
2032        // tokio::time::pause();
2033        let unique_id = fast_uuid_v7::gen_id().to_string();
2034        let factory_name = format!("reconnect_{}", unique_id);
2035
2036        // Shared state to track connection attempts
2037        let connection_attempts = Arc::new(AtomicUsize::new(0));
2038        let attempts_clone = connection_attempts.clone();
2039
2040        let consumer_logic = move || -> Result<Box<dyn MessageConsumer>, anyhow::Error> {
2041            let attempt = attempts_clone.fetch_add(1, Ordering::SeqCst);
2042
2043            struct FlakyConsumer {
2044                attempt: usize,
2045            }
2046            #[async_trait::async_trait]
2047            impl MessageConsumer for FlakyConsumer {
2048                async fn receive_batch(
2049                    &mut self,
2050                    _max: usize,
2051                ) -> Result<ReceivedBatch, ConsumerError> {
2052                    if self.attempt == 0 {
2053                        // First connection works for one batch, then fails
2054                        self.attempt = 999; // prevent infinite loop in this instance
2055                        Ok(ReceivedBatch {
2056                            messages: vec![crate::CanonicalMessage::from("msg1")],
2057                            commit: Box::new(|_| Box::pin(async { Ok(()) })),
2058                        })
2059                    } else if self.attempt == 999 {
2060                        // Simulate connection drop
2061                        Err(ConsumerError::Connection(anyhow::anyhow!(
2062                            "Connection dropped"
2063                        )))
2064                    } else {
2065                        // Subsequent connections work
2066                        // Sleep a bit to prevent busy loop in test
2067                        tokio::time::sleep(Duration::from_millis(100)).await;
2068                        Ok(ReceivedBatch {
2069                            messages: vec![crate::CanonicalMessage::from("msg2")],
2070                            commit: Box::new(|_| Box::pin(async { Ok(()) })),
2071                        })
2072                    }
2073                }
2074                fn as_any(&self) -> &dyn Any {
2075                    self
2076                }
2077            }
2078            Ok(Box::new(FlakyConsumer { attempt }))
2079        };
2080
2081        let mut factory = MockEndpointFactory::new();
2082        factory.consumer_behavior = Arc::new(Mutex::new(consumer_logic));
2083        register_endpoint_factory(&factory_name, Arc::new(factory));
2084
2085        let input = Endpoint {
2086            endpoint_type: EndpointType::Custom {
2087                name: factory_name,
2088                config: serde_json::Value::Null,
2089            },
2090            middlewares: vec![],
2091            handler: None,
2092        };
2093        let output = Endpoint::new_memory(&format!("out_{}", unique_id), 10);
2094        let route = Route::new(input, output.clone());
2095
2096        route.deploy("test_reconnect").await.unwrap();
2097
2098        // Wait for reconnection and messages
2099        let mut verifier = create_consumer_from_route("verifier", &output)
2100            .await
2101            .unwrap();
2102
2103        // Should receive msg1
2104        let msg1 = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
2105            .await
2106            .expect("Timed out waiting for msg1")
2107            .unwrap();
2108        assert_eq!(msg1.message.get_payload_str(), "msg1");
2109
2110        // Route encounters error, sleeps 5s (skipped by pause), reconnects.
2111        // Should receive msg2
2112        let msg2 = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
2113            .await
2114            .expect("Timed out waiting for msg2")
2115            .unwrap();
2116        assert_eq!(msg2.message.get_payload_str(), "msg2");
2117
2118        assert!(connection_attempts.load(Ordering::SeqCst) >= 2);
2119        Route::stop("test_reconnect").await;
2120    }
2121
2122    #[tokio::test]
2123    async fn test_non_retryable_handler_error_does_not_crash_route() {
2124        let unique_id = fast_uuid_v7::gen_id().to_string();
2125        let in_topic = format!("bad_input_in_{}", unique_id);
2126        let out_topic = format!("bad_input_out_{}", unique_id); // Not used, but good practice
2127
2128        let input = Endpoint::new_memory(&in_topic, 10);
2129        let output = Endpoint::new_memory(&out_topic, 10);
2130
2131        // A handler that fails on specific input
2132        let handler = |msg: crate::CanonicalMessage| async move {
2133            if msg.get_payload_str() == "poison" {
2134                Err(HandlerError::NonRetryable(anyhow::anyhow!("Invalid input")))
2135            } else {
2136                Ok(crate::Handled::Publish(msg))
2137            }
2138        };
2139
2140        let route = Route::new(input.clone(), output).with_handler(handler);
2141        route.deploy("test_invalid_input").await.unwrap();
2142
2143        let input_ch = input.channel().unwrap();
2144        let out_channel = route.output.channel().unwrap();
2145
2146        // 1. Send poison message
2147        input_ch.send_message("poison".into()).await.unwrap();
2148
2149        // 2. Send valid message
2150        input_ch.send_message("valid".into()).await.unwrap();
2151
2152        // 3. Verify the valid message was processed and published
2153        let received = tokio::time::timeout(std::time::Duration::from_secs(5), async {
2154            loop {
2155                if let Some(msg) = out_channel.drain_messages().pop() {
2156                    return msg;
2157                }
2158                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
2159            }
2160        })
2161        .await
2162        .expect("Timed out waiting for valid message to be processed");
2163        assert_eq!(received.get_payload_str(), "valid");
2164        Route::stop("test_invalid_input").await;
2165    }
2166
2167    #[tokio::test(flavor = "multi_thread")]
2168    async fn test_dlq_and_retry_batch_integration() {
2169        use crate::models::{DeadLetterQueueMiddleware, Middleware, RetryMiddleware};
2170        use crate::traits::{MessagePublisher, PublisherError, SentBatch};
2171        use std::collections::HashMap;
2172        use std::sync::Mutex;
2173
2174        // Mock publisher that fails messages with even-numbered IDs
2175        #[derive(Clone)]
2176        struct PartialFailPublisher {
2177            attempts: Arc<Mutex<HashMap<u128, usize>>>,
2178        }
2179
2180        #[async_trait::async_trait]
2181        impl MessagePublisher for PartialFailPublisher {
2182            async fn send_batch(
2183                &self,
2184                messages: Vec<CanonicalMessage>,
2185            ) -> Result<SentBatch, PublisherError> {
2186                let mut failed = Vec::new();
2187                let mut attempts = self.attempts.lock().unwrap();
2188
2189                for msg in messages {
2190                    let msg_num: u32 = serde_json::from_slice::<serde_json::Value>(&msg.payload)
2191                        .unwrap()["id"]
2192                        .as_u64()
2193                        .unwrap() as u32;
2194
2195                    let attempt_count = attempts.entry(msg.message_id).or_insert(0);
2196                    *attempt_count += 1;
2197
2198                    if msg_num % 2 == 0 {
2199                        // Fail even numbers
2200                        failed.push((
2201                            msg,
2202                            PublisherError::Retryable(anyhow::anyhow!("simulated failure")),
2203                        ));
2204                    }
2205                    // Odd numbers succeed implicitly by not being in `failed`
2206                }
2207
2208                if failed.is_empty() {
2209                    Ok(SentBatch::Ack)
2210                } else {
2211                    Ok(SentBatch::Partial {
2212                        responses: None,
2213                        failed,
2214                    })
2215                }
2216            }
2217            async fn send(
2218                &self,
2219                _msg: CanonicalMessage,
2220            ) -> Result<crate::traits::Sent, PublisherError> {
2221                unimplemented!()
2222            }
2223            fn as_any(&self) -> &dyn Any {
2224                self
2225            }
2226        }
2227
2228        // 1. Setup
2229        let in_topic = "batch_retry_dlq_in";
2230        let out_topic = "batch_retry_dlq_out";
2231        let dlq_topic = "batch_retry_dlq_dlq";
2232
2233        let input = Endpoint::new_memory(in_topic, 10);
2234        let dlq_endpoint = Endpoint::new_memory(dlq_topic, 10);
2235
2236        let mock_publisher = PartialFailPublisher {
2237            attempts: Arc::new(Mutex::new(HashMap::new())),
2238        };
2239
2240        let mut output_with_middlewares = Endpoint::new_memory(out_topic, 10);
2241        output_with_middlewares.middlewares = vec![
2242            Middleware::Retry(RetryMiddleware {
2243                max_attempts: 2,
2244                initial_interval_ms: 1,
2245                ..Default::default()
2246            }),
2247            Middleware::Dlq(Box::new(DeadLetterQueueMiddleware {
2248                endpoint: dlq_endpoint.clone(),
2249            })),
2250        ];
2251
2252        let route = Route::new(input.clone(), output_with_middlewares).with_batch_size(4);
2253        // Inject the mock publisher into the route's output
2254        let final_publisher = crate::middleware::apply_middlewares_to_publisher(
2255            Box::new(mock_publisher.clone()),
2256            &route.output,
2257            "test_route",
2258        )
2259        .await
2260        .unwrap();
2261
2262        // We need a way to run the route with our mocked publisher.
2263        // The simplest way is to manually drive the core logic.
2264        let (work_tx, work_rx) =
2265            async_channel::bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(1);
2266        let (seq_tx, _sequencer_handle) = spawn_sequencer(1);
2267
2268        // Spawn a worker to process one batch
2269        tokio::spawn(async move {
2270            if let Ok((messages, commit)) = work_rx.recv().await {
2271                let batch_len = messages.len();
2272                match final_publisher.send_batch(messages).await {
2273                    Ok(SentBatch::Ack) => {
2274                        let _ = commit(vec![MessageDisposition::Ack; batch_len]).await;
2275                    }
2276                    Ok(SentBatch::Partial { failed, .. }) => {
2277                        // In a real route, we'd map responses, but here we just care about failure.
2278                        let dispositions = if failed.is_empty() {
2279                            vec![MessageDisposition::Ack; batch_len]
2280                        } else {
2281                            // This is a simplification for the test. A real implementation
2282                            // would map dispositions based on message IDs.
2283                            vec![MessageDisposition::Nack; batch_len]
2284                        };
2285                        let _ = commit(dispositions).await;
2286                    }
2287                    Err(_) => {
2288                        let _ = commit(vec![MessageDisposition::Nack; batch_len]).await;
2289                    }
2290                }
2291            }
2292        });
2293
2294        // 2. Send a batch of messages
2295        let mut messages = Vec::new();
2296        for i in 1..=4 {
2297            // 1 (ok), 2 (fail), 3 (ok), 4 (fail)
2298            messages.push(CanonicalMessage::from_json(serde_json::json!({"id": i})).unwrap());
2299        }
2300        let commit = wrap_commit(Box::new(|_| Box::pin(async { Ok(()) })), 0, seq_tx.clone());
2301        work_tx.send((messages, commit)).await.unwrap();
2302
2303        // 3. Verify
2304        let dlq_channel = dlq_endpoint.channel().unwrap();
2305
2306        let start = std::time::Instant::now();
2307        while dlq_channel.len() < 2 {
2308            if start.elapsed() > std::time::Duration::from_secs(5) {
2309                break;
2310            }
2311            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2312        }
2313
2314        let dlq_msgs = dlq_channel.drain_messages();
2315
2316        assert_eq!(dlq_msgs.len(), 2, "Expected 2 messages to go to DLQ");
2317
2318        let dlq_ids: std::collections::HashSet<u32> = dlq_msgs
2319            .iter()
2320            .map(|m| {
2321                serde_json::from_slice::<serde_json::Value>(&m.payload).unwrap()["id"]
2322                    .as_u64()
2323                    .unwrap() as u32
2324            })
2325            .collect();
2326
2327        assert!(dlq_ids.contains(&2));
2328        assert!(dlq_ids.contains(&4));
2329
2330        // Verify retry attempts
2331        let attempts = mock_publisher.attempts.lock().unwrap();
2332        // Messages 2 and 4 should be tried `max_attempts` times.
2333        assert_eq!(attempts.values().filter(|&&c| c == 2).count(), 2);
2334        // Messages 1 and 3 should be tried once.
2335        assert_eq!(attempts.values().filter(|&&c| c == 1).count(), 2);
2336    }
2337
2338    #[tokio::test(flavor = "multi_thread")]
2339    async fn test_route_dlq_integration() {
2340        // Setup: Input -> [Panic(Disconnect) -> Retry -> DLQ] -> Output
2341        // Panic(Disconnect) simulates transient failure.
2342        // Retry handles it up to N times.
2343        // If max attempts reached, DLQ catches it.
2344        // Note: Middleware application order is [Panic, Retry, DLQ] in list to wrap as DLQ(Retry(Panic(Endpoint))).
2345
2346        let unique_id = fast_uuid_v7::gen_id().to_string();
2347        let in_topic = format!("dlq_in_{}", unique_id);
2348        let out_topic = format!("dlq_out_{}", unique_id);
2349        let dlq_topic = format!("dlq_target_{}", unique_id);
2350        let input = Endpoint::new_memory(&in_topic, 10);
2351        let dlq_endpoint = Endpoint::new_memory(&dlq_topic, 10);
2352
2353        let mut output = Endpoint::new_memory(&out_topic, 10);
2354        output.middlewares = vec![
2355            // Inner-most: Fail always
2356            Middleware::RandomPanic(RandomPanicMiddleware {
2357                mode: FaultMode::Timeout, // Returns Retryable error, does NOT cause route restart
2358                trigger_on_message: None, // Fail always
2359                enabled: true,
2360                ..Default::default()
2361            }),
2362            // Middle: Retry
2363            Middleware::Retry(crate::models::RetryMiddleware {
2364                max_attempts: 2,
2365                initial_interval_ms: 10,
2366                max_interval_ms: 100,
2367                multiplier: 1.0,
2368            }),
2369            // Outer-most: DLQ
2370            Middleware::Dlq(Box::new(crate::models::DeadLetterQueueMiddleware {
2371                endpoint: dlq_endpoint.clone(),
2372            })),
2373        ];
2374
2375        let route = Route::new(input.clone(), output);
2376        route.deploy("test_dlq_integration").await.unwrap();
2377
2378        // Send message
2379        let input_ch = input.channel().unwrap();
2380        input_ch.send_message("fail_msg".into()).await.unwrap();
2381
2382        // Verify:
2383        // 1. Output channel is empty (msg failed to go there)
2384        // 2. DLQ channel has message
2385
2386        let dlq_ch = dlq_endpoint.channel().unwrap();
2387
2388        // Wait for DLQ
2389        let received = tokio::time::timeout(std::time::Duration::from_secs(5), async {
2390            loop {
2391                let batch = dlq_ch.drain_messages();
2392                if !batch.is_empty() {
2393                    return batch[0].clone();
2394                }
2395                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
2396            }
2397        })
2398        .await
2399        .expect("Timed out waiting for DLQ");
2400
2401        assert_eq!(received.get_payload_str(), "fail_msg");
2402
2403        let out_ch_target = mq_bridge::endpoints::memory::get_or_create_channel(
2404            &mq_bridge::models::MemoryConfig::new(&out_topic, None),
2405        );
2406        assert!(out_ch_target.is_empty(), "Message should not reach target");
2407
2408        Route::stop("test_dlq_integration").await;
2409    }
2410
2411    #[tokio::test(flavor = "multi_thread")]
2412    async fn test_large_message_handling() {
2413        let unique_id = fast_uuid_v7::gen_id().to_string();
2414        let in_topic = format!("large_in_{}", unique_id);
2415        let out_topic = format!("large_out_{}", unique_id);
2416
2417        let input = Endpoint::new_memory(&in_topic, 5); // Small capacity
2418        let output = Endpoint::new_memory(&out_topic, 5);
2419
2420        let route = Route::new(input.clone(), output.clone());
2421        route.deploy("test_large_msg").await.unwrap();
2422
2423        let large_payload = vec![b'x'; 5 * 1024 * 1024]; // 5MB
2424        let input_ch = input.channel().unwrap();
2425
2426        input_ch
2427            .send_message(large_payload.clone().into())
2428            .await
2429            .unwrap();
2430
2431        let mut verifier = route.connect_to_output("verifier").await.unwrap();
2432        let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
2433            .await
2434            .expect("Timed out receiving large message")
2435            .unwrap();
2436
2437        assert_eq!(received.message.payload.len(), large_payload.len());
2438        assert_eq!(received.message.payload, large_payload.as_slice());
2439
2440        Route::stop("test_large_msg").await;
2441    }
2442
2443    #[test]
2444    fn test_map_responses_to_dispositions_unit() {
2445        test_map_responses_to_dispositions_logic();
2446    }
2447}