Skip to main content

mq_bridge/
route.rs

1//  mq-bridge
2//  © Copyright 2025, by Marco Mengelkoch
3//  Licensed under MIT License, see License file for more details
4//  git clone https://github.com/marcomq/mq-bridge
5
6use crate::endpoints::{create_consumer_from_route, create_publisher_from_route};
7pub use crate::models::Route;
8use crate::models::{Endpoint, RouteOptions};
9use crate::traits::{
10    BatchCommitFunc, ConsumerError, Handler, HandlerError, MessageDisposition, PublisherError,
11    SentBatch,
12};
13use async_channel::{bounded, Sender};
14use serde::de::DeserializeOwned;
15use std::collections::{BTreeMap, HashMap};
16use std::sync::{Arc, OnceLock, RwLock};
17use tokio::{
18    select,
19    sync::Semaphore,
20    task::{JoinHandle, JoinSet},
21};
22use tracing::{debug, error, info, warn};
23
24// Re-export extensions for backward compatibility and internal usage
25pub use crate::extensions::{
26    get_endpoint_factory, get_middleware_factory, register_endpoint_factory,
27    register_middleware_factory,
28};
29
/// Handle to a running route: the spawned background task plus the sender
/// used to request a graceful shutdown (see [`RouteHandle::stop`] /
/// [`RouteHandle::join`]).
#[derive(Debug)]
pub struct RouteHandle((JoinHandle<()>, Sender<()>));
32
33impl RouteHandle {
34    pub async fn stop(&self) {
35        let _ = self.0 .1.send(()).await;
36        self.0 .1.close();
37    }
38
39    pub async fn join(self) -> Result<(), tokio::task::JoinError> {
40        self.0 .0.await
41    }
42}
43
44impl From<(JoinHandle<()>, Sender<()>)> for RouteHandle {
45    fn from(tuple: (JoinHandle<()>, Sender<()>)) -> Self {
46        RouteHandle(tuple)
47    }
48}
49
/// A deployed route together with the handle controlling its background task.
struct ActiveRoute {
    // Clone of the route configuration as deployed (returned by `Route::get`).
    route: Route,
    // Handle used to stop/join the running route task on `Route::stop`.
    handle: RouteHandle,
}
54
/// Process-wide registry of deployed routes, keyed by route name.
/// Lazily initialized on first access via `get_or_init`.
static ROUTE_REGISTRY: OnceLock<RwLock<HashMap<String, ActiveRoute>>> = OnceLock::new();
56
57impl Route {
58    /// Creates a new route with default concurrency (1) and batch size (128).
59    ///
60    /// # Arguments
61    /// * `input` - The input/source endpoint for the route
62    /// * `output` - The output/sink endpoint for the route
63    pub fn new(input: Endpoint, output: Endpoint) -> Self {
64        Self {
65            input,
66            output,
67            ..Default::default()
68        }
69    }
70
71    /// Retrieves a registered (and running) route by name.
72    pub fn get(name: &str) -> Option<Self> {
73        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
74        let map = registry.read().expect("Route registry lock poisoned");
75        map.get(name).map(|active| active.route.clone())
76    }
77
78    /// Returns a list of all registered route names.
79    pub fn list() -> Vec<String> {
80        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
81        let map = registry.read().expect("Route registry lock poisoned");
82        map.keys().cloned().collect()
83    }
84
85    /// Registers the route and starts it.
86    /// If a route with the same name is already running, it will be stopped first.
87    ///    
88    /// # Examples
89    /// ```
90    /// use mq_bridge::{Route, models::Endpoint};
91    ///
92    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10));
93    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
94    /// route.deploy("global_route").await.unwrap();
95    /// assert!(Route::get("global_route").is_some());
96    /// # });
97    /// ```
98    pub async fn deploy(&self, name: &str) -> anyhow::Result<()> {
99        Self::stop(name).await;
100
101        let handle = self.run(name).await?;
102        let active = ActiveRoute {
103            route: self.clone(),
104            handle,
105        };
106
107        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
108        let mut map = registry.write().expect("Route registry lock poisoned");
109        map.insert(name.to_string(), active);
110        Ok(())
111    }
112
113    /// Stops a running route by name and removes it from the registry.
114    pub async fn stop(name: &str) -> bool {
115        let registry = ROUTE_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()));
116        let active_opt = {
117            let mut map = registry.write().expect("Route registry lock poisoned");
118            map.remove(name)
119        };
120
121        if let Some(active) = active_opt {
122            active.handle.stop().await;
123            let _ = active.handle.join().await;
124            true
125        } else {
126            false
127        }
128    }
129
130    /// Creates a new Publisher configured for this route's output.
131    /// This is useful if you want to send messages to the same destination as this route.
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
137    /// use mq_bridge::{Route, models::Endpoint};
138    ///
139    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10));
140    /// let publisher = route.create_publisher().await;
141    /// assert!(publisher.is_ok());
142    /// # });
143    /// ```
144    pub async fn create_publisher(&self) -> anyhow::Result<crate::Publisher> {
145        crate::Publisher::new(self.output.clone()).await
146    }
147
148    /// Creates a consumer connected to the route's output.
149    /// This is primarily useful for integration tests to verify messages reaching the destination.
150    pub async fn connect_to_output(
151        &self,
152        name: &str,
153    ) -> anyhow::Result<Box<dyn crate::traits::MessageConsumer>> {
154        create_consumer_from_route(name, &self.output).await
155    }
156
157    /// Validates the route configuration, checking if endpoints are supported and correctly configured.
158    /// Core types like file, memory, and response are always supported.
159    /// # Arguments
160    /// * `name` - The name of the route
161    /// * `allowed_endpoints` - An optional list of allowed endpoint types
162    pub fn check(&self, name: &str, allowed_endpoints: Option<&[&str]>) -> anyhow::Result<()> {
163        crate::endpoints::check_consumer(name, &self.input, allowed_endpoints)?;
164        crate::endpoints::check_publisher(name, &self.output, allowed_endpoints)?;
165        Ok(())
166    }
167
168    /// Runs the message processing route with concurrency, error handling, and graceful shutdown.
169    ///
170    /// This function spawns the necessary background tasks to process messages. It waits asynchronously
171    /// until the route is successfully initialized (i.e., connections are established) or until
172    /// a timeout occurs.
173    /// The name_str parameter is just used for logging and tracing.
174    ///
175    /// It returns a `JoinHandle` for the main route task and a `Sender` channel
176    /// that can be used to signal a graceful shutdown. The result is typically converted into a
177    /// [`RouteHandle`] for easier management.
178    ///
179    /// # Examples
180    ///
181    /// ```no_run
182    /// # use mq_bridge::{Route, route::RouteHandle, models::Endpoint};
183    /// # async fn example() -> anyhow::Result<()> {
184    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10));
185    ///
186    /// // Start the route (blocks until initialized) and convert to RouteHandle
187    /// let handle: RouteHandle = route.run("my_route").await?.into();
188    ///
189    /// // Stop the route later
190    /// handle.stop().await;
191    /// handle.join().await?;
192    /// # Ok(())
193    /// # }
194    /// ```
195    pub async fn run(&self, name_str: &str) -> anyhow::Result<RouteHandle> {
196        self.check(name_str, None)?;
197        let (shutdown_tx, shutdown_rx) = bounded(1);
198        let (ready_tx, ready_rx) = bounded(1);
199        // Use `Arc` so route/name clones are cheap (pointer copy) in the reconnect loop.
200        let route = Arc::new(self.clone());
201        let name = Arc::new(name_str.to_string());
202
203        let handle = tokio::spawn(async move {
204            loop {
205                let route_arc = Arc::clone(&route);
206                let name_arc = Arc::clone(&name);
207                // Create a new, per-iteration internal shutdown channel.
208                // This avoids a race where both this loop and the inner task
209                // try to consume the same external shutdown signal.
210                let (internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
211                let ready_tx_clone = ready_tx.clone();
212
213                // The actual route logic is in `run_until_err`.
214                let mut run_task = tokio::spawn(async move {
215                    route_arc
216                        .run_until_err(&name_arc, Some(internal_shutdown_rx), Some(ready_tx_clone))
217                        .await
218                });
219
220                select! {
221                    _ = shutdown_rx.recv() => {
222                        info!("Shutdown signal received for route '{}'.", name);
223                        // Notify the inner task to shut down.
224                        let _ = internal_shutdown_tx.send(()).await;
225                        // Wait for the inner task to finish gracefully.
226                        let _ = run_task.await;
227                        break;
228                    }
229                    res = &mut run_task => {
230                        match res {
231                            Ok(Ok(should_continue)) if !should_continue => {
232                                info!("Route '{}' completed gracefully. Shutting down.", name);
233                                break;
234                            }
235                            Ok(Err(e)) => {
236                                error!("Route '{}' failed: {}. Reconnecting in 5 seconds...", name, e);
237                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
238                            }
239                            Err(e) => {
240                                error!("Route '{}' task panicked: {}. Reconnecting in 5 seconds...", name, e);
241                                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
242                            }
243                            _ => {} // The route should continue running.
244                        }
245                    }
246                }
247            }
248        });
249
250        match tokio::time::timeout(std::time::Duration::from_secs(5), ready_rx.recv()).await {
251            Ok(Ok(_)) => Ok(RouteHandle((handle, shutdown_tx))),
252            _ => {
253                handle.abort();
254                Err(anyhow::anyhow!(
255                    "Route '{}' failed to start within 5 seconds or encountered an error",
256                    name_str
257                ))
258            }
259        }
260    }
261
262    /// The core logic of running the route, designed to be called within a reconnect loop.
263    pub async fn run_until_err(
264        &self,
265        name: &str,
266        shutdown_rx: Option<async_channel::Receiver<()>>,
267        ready_tx: Option<Sender<()>>,
268    ) -> anyhow::Result<bool> {
269        let (_internal_shutdown_tx, internal_shutdown_rx) = bounded(1);
270        let shutdown_rx = shutdown_rx.unwrap_or(internal_shutdown_rx);
271        if self.options.concurrency == 1 {
272            self.run_sequentially(name, shutdown_rx, ready_tx).await
273        } else {
274            self.run_concurrently(name, shutdown_rx, ready_tx).await
275        }
276    }
277
    /// A simplified, sequential runner for when concurrency is 1.
    ///
    /// Receives batches from the consumer, publishes them, and commits
    /// acknowledgements through the sequencer so commits stay ordered even
    /// though individual commit calls run on spawned tasks.
    ///
    /// Returns `Ok(true)` when stopped via `shutdown_rx`, `Ok(false)` when the
    /// consumer reaches end of stream, and `Err(_)` for failures the outer
    /// reconnect loop should recover from.
    async fn run_sequentially(
        &self,
        name: &str,
        shutdown_rx: async_channel::Receiver<()>,
        ready_tx: Option<Sender<()>>,
    ) -> anyhow::Result<bool> {
        let publisher = create_publisher_from_route(name, &self.output).await?;
        let mut consumer = create_consumer_from_route(name, &self.input).await?;
        // Channel through which spawned commit tasks report fatal errors.
        let (err_tx, err_rx) = bounded(1);
        // Bounds the number of commit tasks in flight at once.
        let commit_semaphore = Arc::new(Semaphore::new(self.options.commit_concurrency_limit));
        let mut commit_tasks = JoinSet::new();

        // Sequencer setup to ensure ordered commits even with parallel commit tasks
        let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.commit_concurrency_limit);
        let mut seq_counter = 0u64;

        // Signal readiness once both connections are established.
        if let Some(tx) = ready_tx {
            let _ = tx.send(()).await;
        }
        let run_result = loop {
            select! {
                Ok(err) = err_rx.recv() => break Err(err),

                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received in sequential runner for route '{}'.", name);
                    break Ok(true); // Stopped by shutdown signal
                }
                res = consumer.receive_batch(self.options.batch_size) => {
                    let received_batch = match res {
                        Ok(batch) => {
                            if batch.messages.is_empty() {
                                continue; // No messages, loop to select! again
                            }
                            batch
                        }
                        Err(ConsumerError::EndOfStream) => {
                            info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
                            break Ok(false); // Graceful exit
                        }
                        Err(ConsumerError::Connection(e)) => {
                            // Propagate error to trigger reconnect by the outer loop
                            break Err(e);
                        },
                        Err(ConsumerError::Gap { requested, base }) => {
                            // Propagate gap error to trigger reconnect by the outer loop
                            break Err(anyhow::anyhow!("Consumer gap: requested offset {requested} but earliest available is {base}"));
                        }
                    };
                    debug!("Received a batch of {} messages sequentially", received_batch.messages.len());

                    // Process the batch sequentially without spawning a new task
                    let seq = seq_counter;
                    seq_counter += 1;
                    // Route this batch's commit through the sequencer under its sequence number.
                    let commit = wrap_commit(received_batch.commit, seq, seq_tx.clone());
                    let batch_len = received_batch.messages.len();

                    match publisher.send_batch(received_batch.messages).await {
                        Ok(SentBatch::Ack) => {
                            // Throttle in-flight commits; a closed semaphore is a fatal error here.
                            let permit = commit_semaphore.clone().acquire_owned().await.map_err(|e| anyhow::anyhow!("Semaphore error: {}", e))?;
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                // Permit is dropped here, releasing the slot
                                drop(permit);
                            });
                        }
                        Ok(SentBatch::Partial { responses, failed }) => {
                            // Retryable failures abort the run so the outer loop reconnects;
                            // non-retryable failures are logged and the messages dropped.
                            let has_retryable = failed.iter().any(|(_, e)| matches!(e, PublisherError::Retryable(_)));
                            if has_retryable {
                                let failed_count = failed.len();
                                let (_, first_error) = failed
                                    .into_iter()
                                    .find(|(_, e)| matches!(e, PublisherError::Retryable(_)))
                                    .expect("has_retryable is true");
                                break Err(anyhow::anyhow!(
                                    "Failed to send {} messages in batch. First retryable error: {}",
                                    failed_count,
                                    first_error
                                ));
                            }
                            for (msg, e) in &failed {
                                error!("Dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
                            }
                            let permit = commit_semaphore.clone().acquire_owned().await.map_err(|e| anyhow::anyhow!("Semaphore error: {}", e))?;
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                let dispositions = map_responses_to_dispositions(batch_len, responses, &failed);
                                if let Err(e) = commit(dispositions).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Err(e) => break Err(e.into()), // Propagate error to trigger reconnect
                    }
                }
            }
        };

        // Closing the sequencer input lets the sequencer task drain and exit.
        drop(seq_tx);
        // Drain errors while waiting for tasks to finish to prevent deadlocks and lost errors
        loop {
            select! {
                res = err_rx.recv() => {
                    if let Ok(err) = res {
                        error!("Error reported during shutdown: {}", err);
                    }
                }
                res = commit_tasks.join_next() => {
                    if res.is_none() {
                        break;
                    }
                }
            }
        }
        drop(err_rx);
        let _ = sequencer_handle.await;
        run_result
    }
402
    /// The main concurrent runner for when concurrency > 1.
    ///
    /// Spawns a pool of `options.concurrency` worker tasks fed through a
    /// bounded work channel. Each worker publishes batches and spawns commit
    /// tasks throttled by `commit_semaphore` and ordered via the sequencer.
    ///
    /// Returns `Ok(true)` when stopped via `shutdown_rx`, `Ok(false)` on
    /// consumer end-of-stream, and `Err(_)` for failures the outer reconnect
    /// loop should recover from.
    async fn run_concurrently(
        &self,
        name: &str,
        shutdown_rx: async_channel::Receiver<()>,
        ready_tx: Option<Sender<()>>,
    ) -> anyhow::Result<bool> {
        let publisher = create_publisher_from_route(name, &self.output).await?;
        let mut consumer = create_consumer_from_route(name, &self.input).await?;
        // Signal readiness once both connections are established.
        if let Some(tx) = ready_tx {
            let _ = tx.send(()).await;
        }
        let (err_tx, err_rx) = bounded(1); // For critical, route-stopping errors
        // Work-channel capacity: a small buffer proportional to concurrency.
        let work_capacity = self
            .options
            .concurrency
            .saturating_mul(self.options.batch_size);
        let (work_tx, work_rx) =
            bounded::<(Vec<crate::CanonicalMessage>, BatchCommitFunc)>(work_capacity);
        let commit_semaphore = Arc::new(Semaphore::new(self.options.commit_concurrency_limit));

        // --- Ordered Commit Sequencer ---
        // To prevent data loss with cumulative-ack brokers (Kafka/AMQP), commits must happen in order.
        // We assign a sequence number to each batch and use a sequencer task to enforce order.
        let (seq_tx, sequencer_handle) = spawn_sequencer(self.options.concurrency * 2);

        // --- Worker Pool ---
        let mut join_set = JoinSet::new();
        for i in 0..self.options.concurrency {
            let work_rx_clone = work_rx.clone();
            let publisher = Arc::clone(&publisher);
            let err_tx = err_tx.clone();
            let commit_semaphore = commit_semaphore.clone();
            let mut commit_tasks = JoinSet::new();
            join_set.spawn(async move {
                debug!("Starting worker {}", i);
                // Workers exit when the work channel is closed (graceful shutdown)
                // or when a fatal publish error occurs.
                while let Ok((messages, commit)) = work_rx_clone.recv().await {
                    let batch_len = messages.len();
                    match publisher.send_batch(messages).await {
                        Ok(SentBatch::Ack) => {
                            let permit = match commit_semaphore.clone().acquire_owned().await {
                                Ok(p) => p,
                                Err(_) => {
                                    warn!("Semaphore closed, worker exiting");
                                    break;
                                }
                            };
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                if let Err(e) = commit(vec![MessageDisposition::Ack; batch_len]).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Ok(SentBatch::Partial { responses, failed }) => {
                            // Retryable failures tear down the route (outer loop reconnects);
                            // non-retryable failures are logged and the messages dropped.
                            let has_retryable = failed.iter().any(|(_, e)| matches!(e, PublisherError::Retryable(_)));
                            if has_retryable {
                                let failed_count = failed.len();
                                let (_, first_error) = failed
                                    .into_iter()
                                    .find(|(_, e)| matches!(e, PublisherError::Retryable(_)))
                                    .expect("has_retryable is true");
                                let e = anyhow::anyhow!(
                                    "Failed to send {} messages in batch. First retryable error: {}",
                                    failed_count,
                                    first_error
                                );
                                error!("Worker failed to send message batch: {}", e);
                                if err_tx.send(e).await.is_err() {
                                    warn!("Could not send error to main task, it might be down.");
                                }
                                break; // Stop processing this batch
                            }
                            for (msg, e) in &failed {
                                error!("Worker dropping message (ID: {:032x}) due to non-retryable error: {}", msg.message_id, e);
                            }
                            let permit = match commit_semaphore.clone().acquire_owned().await {
                                Ok(p) => p,
                                Err(_) => {
                                    warn!("Semaphore closed, worker exiting");
                                    break;
                                }
                            };
                            let err_tx = err_tx.clone();
                            commit_tasks.spawn(async move {
                                let dispositions = map_responses_to_dispositions(batch_len, responses, &failed);
                                if let Err(e) = commit(dispositions).await {
                                    error!("Commit failed: {}", e);
                                    let _ = err_tx.send(e).await;
                                }
                                drop(permit);
                            });
                        }
                        Err(e) => {
                            error!("Worker failed to send message batch: {}", e);
                            // Send the error back to the main task to tear down the route.
                            if err_tx.send(e.into()).await.is_err() {
                                warn!("Could not send error to main task, it might be down.");
                            }
                            break;
                        }
                    }
                }
                // Wait for all in-flight commits to complete
                while commit_tasks.join_next().await.is_some() {}
            });
        }

        let mut seq_counter = 0u64;
        loop {
            select! {
                biased; // Prioritize checking for errors

                Ok(err) = err_rx.recv() => {
                    error!("A worker reported a critical error. Shutting down route.");
                    return Err(err);
                }

                Some(res) = join_set.join_next() => {
                    match res {
                        Ok(_) => {
                            error!("A worker task finished unexpectedly. Shutting down route.");
                            return Err(anyhow::anyhow!("Worker task finished unexpectedly"));
                        }
                        Err(e) => {
                            error!("A worker task panicked: {}. Shutting down route.", e);
                            return Err(e.into());
                        }
                    }
                }

                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received in concurrent runner for route '{}'.", name);
                    break;
                }

                res = consumer.receive_batch(self.options.batch_size) => {
                    let (messages, commit) = match res {
                        Ok(batch) => {
                            if batch.messages.is_empty() {
                                continue; // No messages, loop to select! again
                            }
                            (batch.messages, batch.commit)
                        }
                        Err(ConsumerError::EndOfStream) => {
                            info!("Consumer for route '{}' reached end of stream. Shutting down.", name);
                            break; // Graceful exit
                        }
                        Err(ConsumerError::Connection(e)) => {
                            // Propagate error to trigger reconnect by the outer loop
                            return Err(e);
                        }
                        Err(ConsumerError::Gap { requested, base }) => {
                            // Propagate gap error to trigger reconnect by the outer loop
                            return Err(ConsumerError::Gap { requested, base }.into());
                        }
                    };
                    debug!("Received a batch of {} messages concurrently", messages.len());

                    // Wrap the commit function to route it through the sequencer
                    let seq = seq_counter;
                    seq_counter += 1;
                    let wrapped_commit = wrap_commit(commit, seq, seq_tx.clone());

                    if work_tx.send((messages, wrapped_commit)).await.is_err() {
                        warn!("Work channel closed, cannot process more messages concurrently. Shutting down.");
                        break;
                    }
                }
            }
        }

        // --- Graceful Shutdown ---
        // Close the work channel. Workers will finish their current message and then exit the loop.
        drop(work_tx);
        // Wait for all worker tasks to complete.
        while join_set.join_next().await.is_some() {}

        // Close sequencer
        drop(seq_tx);
        let _ = sequencer_handle.await;

        // Surface any error a worker reported while we were draining.
        if let Ok(err) = err_rx.try_recv() {
            return Err(err);
        }

        // Return true if shutdown was requested (channel is empty means it was closed/consumed),
        // false if we reached end-of-stream naturally.
        Ok(shutdown_rx.is_empty())
    }
596
597    pub fn with_options(mut self, options: RouteOptions) -> Self {
598        self.options = options;
599        self
600    }
601
602    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
603        self.options.concurrency = concurrency.max(1);
604        self
605    }
606
607    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
608        self.options.batch_size = batch_size.max(1);
609        self
610    }
611
612    pub fn with_commit_concurrency_limit(mut self, limit: usize) -> Self {
613        self.options.commit_concurrency_limit = limit.max(1);
614        self
615    }
616
617    pub fn with_handler(mut self, handler: impl Handler + 'static) -> Self {
618        self.output.handler = Some(Arc::new(handler));
619        self
620    }
621
    /// Registers a typed handler for the route.
    ///
    /// The handler can accept either:
    /// - `fn(T) -> Future<Output = Result<Handled, HandlerError>>`
    /// - `fn(T, MessageContext) -> Future<Output = Result<Handled, HandlerError>>`
    ///
    /// # Examples
    ///
    /// ```
    /// # use mq_bridge::{Route, models::Endpoint, Handled, HandlerError};
    /// # use serde::Deserialize;
    /// # use std::sync::Arc;
    ///
    /// #[derive(Deserialize)]
    /// struct MyData {
    ///     id: u32,
    /// }
    ///
    /// async fn my_handler(data: MyData) -> Result<Handled, HandlerError> {
    ///     Ok(Handled::Ack)
    /// }
    ///
    /// let route = Route::new(Endpoint::new_memory("in", 10), Endpoint::new_memory("out", 10))
    ///     .add_handler("my_type", my_handler);
    /// ```
    pub fn add_handler<T, H, Args>(mut self, type_name: &str, handler: H) -> Self
    where
        T: DeserializeOwned + Send + Sync + 'static,
        H: crate::type_handler::IntoTypedHandler<T, Args>,
        Args: Send + Sync + 'static,
    {
        // Create the wrapper closure that handles deserialization and context extraction
        let handler = Arc::new(handler);
        let wrapper = move |msg: crate::CanonicalMessage| {
            let handler = handler.clone();
            async move {
                // Deserialization failures are non-retryable: the payload will
                // never parse differently on a retry.
                let data = msg.parse::<T>().map_err(|e| {
                    HandlerError::NonRetryable(anyhow::anyhow!("Deserialization failed: {}", e))
                })?;
                let ctx = crate::MessageContext::from(msg);
                handler.call(data, ctx).await
            }
        };
        let wrapper = Arc::new(wrapper);

        let prev_handler = self.output.handler.take();

        // Three cases:
        // 1. An existing handler accepts additional registrations -> extend it.
        // 2. An existing handler cannot be extended -> wrap it as the fallback
        //    of a fresh TypeHandler that also carries the new registration.
        // 3. No handler yet -> install a fresh TypeHandler.
        let new_handler = if let Some(h) = prev_handler {
            if let Some(extended) = h.register_handler(type_name, wrapper.clone()) {
                extended
            } else {
                Arc::new(
                    crate::type_handler::TypeHandler::new()
                        .with_fallback(h)
                        .add_handler(type_name, wrapper),
                )
            }
        } else {
            Arc::new(crate::type_handler::TypeHandler::new().add_handler(type_name, wrapper))
        };

        self.output.handler = Some(new_handler);
        self
    }
686    pub fn add_handlers<T, H, Args>(mut self, handlers: HashMap<&str, H>) -> Self
687    where
688        T: DeserializeOwned + Send + Sync + 'static,
689        H: crate::type_handler::IntoTypedHandler<T, Args>,
690        Args: Send + Sync + 'static,
691    {
692        for (type_name, handler) in handlers {
693            self = self.add_handler(type_name, handler);
694        }
695        self
696    }
697}
698
/// One unit of work queued to the commit sequencer: the per-message
/// dispositions of a batch, the commit function to run once the batch's
/// sequence number comes up, and a oneshot sender used to report the commit
/// result back to the task that is awaiting it.
type SequencerItem = (
    Vec<MessageDisposition>,
    BatchCommitFunc,
    tokio::sync::oneshot::Sender<anyhow::Result<()>>,
);
704
/// Spawns a background task that executes batch commits in strict sequence order.
///
/// Workers send `(sequence_number, item)` pairs on the returned channel
/// (capacity `buffer_size`). The task parks out-of-order items in a `BTreeMap`
/// and runs each item's commit function only once every lower sequence number
/// has been committed or timed out; the commit result is reported back through
/// the item's oneshot sender. Returns the sender half and the task's handle.
fn spawn_sequencer(buffer_size: usize) -> (Sender<(u64, SequencerItem)>, JoinHandle<()>) {
    let (seq_tx, seq_rx) = bounded::<(u64, SequencerItem)>(buffer_size);

    let sequencer_handle = tokio::spawn(async move {
        // Items parked until their turn, ordered by sequence number.
        let mut buffer: BTreeMap<u64, SequencerItem> = BTreeMap::new();
        // Next sequence number allowed to commit.
        let mut next_seq = 0u64;
        // Armed only while a gap exists (buffer non-empty but `next_seq` missing).
        let mut deadline: Option<tokio::time::Instant> = None;
        // How long to wait for a missing sequence number before skipping ahead.
        const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

        loop {
            // Drain every item that is now in order: run its commit and notify
            // the waiter with the result.
            while let Some((dispositions, commit_func, notify)) = buffer.remove(&next_seq) {
                let res = commit_func(dispositions).await;
                let _ = notify.send(res);
                next_seq += 1;
            }

            // Arm the gap timeout while something is parked waiting for a
            // missing predecessor; disarm it when the buffer is empty.
            // NOTE(review): the deadline is not refreshed when `next_seq`
            // advances, so a newly opened gap may get less than the full
            // TIMEOUT — confirm whether that is intended.
            if !buffer.is_empty() {
                if deadline.is_none() {
                    deadline = Some(tokio::time::Instant::now() + TIMEOUT);
                }
            } else {
                deadline = None;
            }

            // Sleep until the armed deadline, or forever when none is armed.
            let timeout_fut = async {
                if let Some(d) = deadline {
                    tokio::time::sleep_until(d).await
                } else {
                    std::future::pending().await
                }
            };

            select! {
                res = seq_rx.recv() => {
                    match res {
                        Ok((seq, item)) => {
                            if seq < next_seq {
                                // This slot was already skipped by a timeout;
                                // fail the commit rather than run it out of order.
                                let (_, _, notify) = item;
                                let _ = notify.send(Err(anyhow::anyhow!("Sequencer received late item (seq {} < next_seq {})", seq, next_seq)));
                            } else {
                                buffer.insert(seq, item);
                            }
                        }
                        Err(_) => {
                            // Channel closed: fail every parked commit and stop.
                            for (_, (_, _, notify)) in std::mem::take(&mut buffer) {
                                let _ = notify.send(Err(anyhow::anyhow!("Sequencer shutting down")));
                            }
                            break;
                        }
                    }
                }
                _ = timeout_fut => {
                    // Gave up waiting for `next_seq`: jump to the earliest parked
                    // item so the pipeline keeps making progress.
                    if let Some(&first_seq) = buffer.keys().next() {
                        if first_seq > next_seq {
                            warn!("Sequencer timed out waiting for seq {}. Jumping to {}.", next_seq, first_seq);
                            next_seq = first_seq;
                        } else {
                            next_seq += 1;
                        }
                    } else {
                        next_seq += 1;
                    }
                    deadline = None;
                }
            }
        }
    });
    (seq_tx, sequencer_handle)
}
774
775fn wrap_commit(
776    commit: BatchCommitFunc,
777    seq: u64,
778    seq_tx: Sender<(u64, SequencerItem)>,
779) -> BatchCommitFunc {
780    Box::new(move |dispositions| {
781        Box::pin(async move {
782            let (notify_tx, notify_rx) = tokio::sync::oneshot::channel();
783            // Send to sequencer
784            if seq_tx
785                .send((seq, (dispositions, commit, notify_tx)))
786                .await
787                .is_ok()
788            {
789                // Wait for sequencer to execute the commit
790                match notify_rx.await {
791                    Ok(res) => res,
792                    Err(_) => Err(anyhow::anyhow!(
793                        "Sequencer dropped the commit channel unexpectedly"
794                    )),
795                }
796            } else {
797                Err(anyhow::anyhow!(
798                    "Failed to send commit to sequencer, route is likely shutting down"
799                ))
800            }
801        })
802    })
803}
804
805fn map_responses_to_dispositions(
806    total_count: usize,
807    responses: Option<Vec<crate::CanonicalMessage>>,
808    failed: &[(crate::CanonicalMessage, PublisherError)],
809) -> Vec<MessageDisposition> {
810    if failed.is_empty() {
811        if let Some(resps) = responses {
812            if resps.len() == total_count {
813                return resps.into_iter().map(MessageDisposition::Reply).collect();
814            }
815        } else {
816            // If there are no failures and no responses, everything is Ack.
817            return vec![MessageDisposition::Ack; total_count];
818        }
819    }
820
821    // If we have failures, we should Nack them.
822    // However, we don't have easy access to the original indices here to map 1:1 perfectly
823    // if we don't assume order.
824    // But `send_batch` usually processes in order.
825    // If `responses` is Some, it contains responses for successful messages in order.
826
827    // Simplified logic assuming order preservation for successful messages:
828    // We construct a vector of dispositions.
829    // Since we can't easily match by ID without iterating everything, and `failed` might be sparse,
830    // we'll use a heuristic:
831    // If we have explicit responses, we use them.
832    // If we have failures, we might not be able to map them back to the exact index in the batch
833    // without O(N^2) or a map, because `failed` is a subset.
834    //
835    // For F10 implementation, we will assume that if *any* message failed in the batch,
836    // and we are in a Partial state, we might want to Nack the ones that failed.
837    // But since we can't easily map back to the index in `received_batch.messages` (which we don't have here in this helper),
838    // and `commit` expects a vector of size `total_count` corresponding to the input batch...
839
840    // Current best effort: Return Ack for everything to avoid hanging, but log that we can't map precisely yet.
841    // In a real implementation of F10, `send_batch` should probably return `Vec<Result<Sent, PublisherError>>` to map 1:1.
842    vec![MessageDisposition::Ack; total_count]
843}
844
/// Looks up a deployed route by name.
///
/// Free-function convenience wrapper that delegates to [`Route::get`];
/// returns `None` when no route is registered under `name`.
pub fn get_route(name: &str) -> Option<Route> {
    Route::get(name)
}
848
/// Returns the names of all currently registered routes.
///
/// Free-function convenience wrapper that delegates to [`Route::list`].
pub fn list_routes() -> Vec<String> {
    Route::list()
}
852
/// Stops the route registered under `name`.
///
/// Free-function convenience wrapper that delegates to [`Route::stop`].
/// NOTE(review): the boolean presumably reports whether a route with that
/// name existed and was stopped — confirm against `Route::stop`.
pub async fn stop_route(name: &str) -> bool {
    Route::stop(name).await
}
856
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Endpoint, Middleware};
    use crate::traits::{CustomMiddlewareFactory, MessageConsumer, ReceivedBatch};
    use std::any::Any;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Middleware factory that decorates a consumer with [`PanicConsumer`],
    /// used to exercise the route's panic-recovery path.
    #[derive(Debug)]
    struct PanicMiddlewareFactory {
        should_panic: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl CustomMiddlewareFactory for PanicMiddlewareFactory {
        async fn apply_consumer(
            &self,
            consumer: Box<dyn MessageConsumer>,
            _route_name: &str,
            _config: &serde_json::Value,
        ) -> anyhow::Result<Box<dyn MessageConsumer>> {
            Ok(Box::new(PanicConsumer {
                inner: consumer,
                should_panic: self.should_panic.clone(),
            }))
        }
    }

    /// Consumer decorator that panics on the first `receive_batch` call
    /// (while `should_panic` is true) and delegates to `inner` afterwards.
    struct PanicConsumer {
        inner: Box<dyn MessageConsumer>,
        should_panic: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl MessageConsumer for PanicConsumer {
        async fn receive_batch(
            &mut self,
            max_messages: usize,
        ) -> Result<ReceivedBatch, ConsumerError> {
            // Panic on the first call to verify route recovery.
            // compare_exchange flips the flag atomically so exactly one call
            // panics, even with concurrent consumers.
            if self
                .should_panic
                .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
            {
                panic!("Simulated panic for testing recovery");
            }
            self.inner.receive_batch(max_messages).await
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// End-to-end check that a route whose consumer panics is restarted and
    /// still delivers the pending message afterwards.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[ignore = "Takes too much time for regular tests"]
    async fn test_route_recovery_from_panic() {
        // Use unique topic names to avoid interference from other tests sharing the static memory channels
        let unique_suffix = fast_uuid_v7::gen_id().to_string();
        let in_topic = format!("panic_in_{}", unique_suffix);
        let out_topic = format!("panic_out_{}", unique_suffix);

        let should_panic = Arc::new(AtomicBool::new(true));
        let factory = PanicMiddlewareFactory {
            should_panic: should_panic.clone(),
        };
        register_middleware_factory("panic_factory", Arc::new(factory));

        // Input endpoint wrapped with the panic middleware; plain output.
        let input = Endpoint::new_memory(&in_topic, 10).add_middleware(Middleware::Custom {
            name: "panic_factory".to_string(),
            config: serde_json::Value::Null,
        });
        let output = Endpoint::new_memory(&out_topic, 10);

        let route = Route::new(input.clone(), output.clone());

        // Start the route
        route
            .deploy("panic_test")
            .await
            .expect("Failed to deploy route");
        // 1. Send a message. The consumer will panic before picking it up.
        let input_ch = input.channel().unwrap();
        input_ch
            .send_message("persistent_msg".into())
            .await
            .unwrap();

        // 2. Wait for the panic to occur and the route to enter sleep.
        // We loop briefly to allow the spawned task to execute and panic.
        let panic_wait_start = std::time::Instant::now();
        while panic_wait_start.elapsed() < std::time::Duration::from_secs(5) {
            if !should_panic.load(Ordering::SeqCst) {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        }
        assert!(
            !should_panic.load(Ordering::SeqCst),
            "Route should have panicked"
        );

        // 3. Wait for recovery (5s backoff + restart time).
        // We sleep the minimum backoff, then poll with a generous timeout to handle loaded CI environments.
        // NOTE(review): this assumes the route restart backoff is ~5s — confirm
        // against the supervisor/restart loop (not visible in this file section).
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;

        // 4. Verify the message is processed after recovery.
        let mut verifier = route.connect_to_output("verifier").await.unwrap();
        let received = tokio::time::timeout(std::time::Duration::from_secs(10), verifier.receive())
            .await
            .expect("Timed out waiting for message after recovery")
            .expect("Stream closed");

        assert_eq!(received.message.get_payload_str(), "persistent_msg");
        // not necessary here, but it's a good idea to commit
        (received.commit)(MessageDisposition::Ack).await.unwrap();

        // Cleanup
        Route::stop("panic_test").await;
    }
}