Skip to main content

foxtive_worker/
pool.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::ReceivedMessage;
11use crate::metrics::WorkerMetrics;
12use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
13use crate::strategies::{
14    LeastLoadedBalancer, LoadBalancingStrategy, RandomBalancer, RoundRobinBalancer,
15};
16use crate::worker::Worker;
17
18/// A pool of workers with load balancing for message distribution.
19///
20/// The pool manages multiple worker instances and distributes incoming messages
21/// based on the configured load balancing strategy.
22///
23/// # Example
24/// ```rust,no_run
25/// use foxtive_worker::pool::WorkerPool;
26/// use foxtive_worker::strategies::LoadBalancingStrategy;
27/// use foxtive_worker::metrics::NoOpMetrics;
28/// use std::sync::Arc;
29///
30/// let pool = WorkerPool::new("my-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
31/// // Add workers...
32/// // Dispatch messages...
33/// ```
34pub struct WorkerPool {
35    name: String,
36    workers: Vec<Arc<dyn Worker>>,
37    strategy: LoadBalancingStrategy,
38    semaphore: Arc<Semaphore>,
39    concurrency_limit: usize,
40    least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
41    round_robin_balancer: Arc<RoundRobinBalancer>,
42    random_balancer: RandomBalancer,
43    /// Middleware list (Arc-wrapped for cloning)
44    middlewares: Vec<Arc<dyn Middleware>>,
45    /// Metrics collector for this pool
46    metrics_collector: Arc<dyn WorkerMetrics>,
47    /// Indicates if the worker pool is currently running.
48    is_running: Arc<AtomicBool>,
49    /// Cancellation token for graceful shutdown of all spawned tasks
50    cancellation_token: CancellationToken,
51    /// Notify for task completion signaling during shutdown
52    task_completion_notify: Arc<Notify>,
53    /// Track number of in-flight tasks for monitoring
54    in_flight_tasks: Arc<AtomicUsize>,
55}
56
57impl std::fmt::Debug for WorkerPool {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("WorkerPool")
60            .field("name", &self.name)
61            .field("worker_count", &self.workers.len())
62            .field("strategy", &self.strategy)
63            .field("is_running", &self.is_running.load(Ordering::SeqCst))
64            .finish()
65    }
66}
67
68impl WorkerPool {
69    /// Create a new worker pool with the given name, load balancing strategy, and metrics collector.
70    pub fn new(
71        name: impl Into<String>,
72        strategy: LoadBalancingStrategy,
73        metrics_collector: Arc<dyn WorkerMetrics>,
74    ) -> Self {
75        Self::with_concurrency(name, strategy, 1000, metrics_collector)
76    }
77
78    /// Create a new worker pool with custom concurrency limit.
79    ///
80    /// # Arguments
81    /// * `name` - Pool name
82    /// * `strategy` - Load balancing strategy
83    /// * `concurrency_limit` - Maximum concurrent messages across all workers
84    /// * `metrics_collector` - Metrics collector implementation
85    pub fn with_concurrency(
86        name: impl Into<String>,
87        strategy: LoadBalancingStrategy,
88        concurrency_limit: usize,
89        metrics_collector: Arc<dyn WorkerMetrics>,
90    ) -> Self {
91        let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
92            Some(Arc::new(LeastLoadedBalancer::new(0)))
93        } else {
94            None
95        };
96
97        Self {
98            name: name.into(),
99            workers: Vec::new(),
100            strategy,
101            semaphore: Arc::new(Semaphore::new(concurrency_limit)),
102            concurrency_limit,
103            least_loaded_balancer,
104            round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
105            random_balancer: RandomBalancer,
106            middlewares: Vec::new(),
107            metrics_collector,
108            is_running: Arc::new(AtomicBool::new(true)),
109            cancellation_token: CancellationToken::new(),
110            task_completion_notify: Arc::new(Notify::new()),
111            in_flight_tasks: Arc::new(AtomicUsize::new(0)),
112        }
113    }
114
115    /// Add a worker to the pool.
116    pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
117        self.workers.push(worker);
118
119        // Update least-loaded balancer if needed - now O(1) with atomic swap
120        if let Some(ref balancer) = self.least_loaded_balancer {
121            balancer.add_worker();
122        }
123        self.metrics_collector
124            .record_active_workers(self.workers.len());
125    }
126
127    /// Add multiple workers to the pool.
128    pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
129        for worker in workers {
130            self.add_worker(worker);
131        }
132    }
133
134    /// Get the number of workers in the pool.
135    pub fn worker_count(&self) -> usize {
136        self.workers.len()
137    }
138
139    /// Set middleware for the pool.
140    ///
141    /// This allows you to inject middleware processing into the message flow.
142    /// Middleware will be executed in order before messages reaches workers.
143    pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
144        self.middlewares = middlewares;
145        self
146    }
147
148    /// Get the pool name.
149    pub fn name(&self) -> &str {
150        &self.name
151    }
152
153    /// Dispatch a message to a worker based on the load balancing strategy.
154    ///
155    /// Spawns an async task to process the message, respecting concurrency limits.
156    /// If middleware is configured, the message flows through the chain before reaching the worker.
157    pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
158        if !self.is_running.load(Ordering::SeqCst) {
159            return Err(WorkerError::Shutdown);
160        }
161        if self.workers.is_empty() {
162            return Err(WorkerError::PoolExhausted);
163        }
164
165        // Pick a worker and track metrics
166        let worker_index = self.select_worker();
167        let worker = self.workers[worker_index].clone();
168        let worker_id = worker.id().to_string();
169        let queue_name = message.message.metadata.source.clone();
170
171        self.metrics_collector
172            .record_message_received(&worker_id, &queue_name);
173        let start_time = Instant::now();
174
175        if let Some(ref balancer) = self.least_loaded_balancer {
176            balancer.increment_load(worker_index);
177        }
178
179        let permit = self
180            .semaphore
181            .clone()
182            .acquire_owned()
183            .await
184            .map_err(|_| WorkerError::Shutdown)?;
185
186        self.metrics_collector
187            .record_in_flight_messages(self.semaphore.available_permits());
188
189        // Build the handler chain - wrap worker with middleware if configured
190        let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
191            let worker_handler = WorkerHandler(worker);
192            let boxed_middlewares: Vec<Box<dyn Middleware>> = self
193                .middlewares
194                .iter()
195                .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
196                .collect();
197
198            let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
199            Arc::new(ArcHandlerWrapper(chain.build()))
200        } else {
201            Arc::new(WorkerHandler(worker))
202        };
203
204        let metrics_collector_clone = self.metrics_collector.clone();
205        let least_loaded_balancer = self.least_loaded_balancer.clone();
206        let cancellation_token = self.cancellation_token.child_token();
207        let task_completion_notify = self.task_completion_notify.clone();
208        let in_flight_tasks = self.in_flight_tasks.clone();
209
210        // Track in-flight task count
211        in_flight_tasks.fetch_add(1, Ordering::SeqCst);
212
213        // Extract ack_handle and metadata before moving message into task
214        let ack_handle = message.ack_handle.clone();
215        let message_id = message.message.id.clone();
216        let attempt = message.message.metadata.attempt;
217
218        tokio::spawn(async move {
219            // Use tokio::select! to handle cancellation
220            let result = tokio::select! {
221                result = handler.handle(message) => result,  // No clone - move message into task
222                _ = cancellation_token.cancelled() => {
223                    tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
224                    // Decrement in-flight counter on cancellation
225                    in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
226                    task_completion_notify.notify_one();
227                    return;
228                }
229            };
230
231            match result {
232                Ok(middleware_result) => {
233                    match middleware_result {
234                        crate::middleware::MiddlewareResult::Acknowledged => {
235                            // Middleware (e.g., AckNackMiddleware) already handled acknowledgment
236                            tracing::debug!(
237                                "Message {} already acknowledged by middleware",
238                                message_id
239                            );
240                            metrics_collector_clone.record_message_processed(
241                                &worker_id,
242                                &queue_name,
243                                start_time,
244                            );
245                            // Decrement in-flight counter
246                            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
247                            task_completion_notify.notify_one();
248                            return;
249                        }
250                        crate::middleware::MiddlewareResult::Continue => {
251                            // No middleware handled ack - pool should ack on success
252                            tracing::debug!("Message {} processed successfully", message_id);
253                            metrics_collector_clone.record_message_processed(
254                                &worker_id,
255                                &queue_name,
256                                start_time,
257                            );
258                            if let Err(e) = retry_ack(&ack_handle, &message_id).await {
259                                tracing::error!(
260                                    "Failed to ack message {} after retries: {}. Message may be redelivered.",
261                                    message_id,
262                                    e
263                                );
264                            }
265                        }
266                    }
267                }
268                Err(WorkerError::RetryableFailure { source, delay_ms }) => {
269                    tracing::warn!(
270                        "Message {} failed (will retry in {:?}): {}",
271                        message_id,
272                        delay_ms,
273                        source
274                    );
275                    metrics_collector_clone.record_message_retried(
276                        &worker_id,
277                        &queue_name,
278                        attempt,
279                    );
280                    metrics_collector_clone.record_message_failed(
281                        &worker_id,
282                        &queue_name,
283                        "RetryableFailure",
284                        start_time,
285                    );
286                    if let Err(e) = ack_handle.nack(true).await {
287                        tracing::error!("Failed to requeue message {}: {}", message_id, e);
288                    }
289                    sleep(delay_ms).await;
290                }
291                Err(WorkerError::RetriesExhausted { source }) => {
292                    tracing::error!(
293                        "Message {} exhausted all retries, sending to DLQ: {}",
294                        message_id,
295                        source
296                    );
297                    metrics_collector_clone
298                        .record_message_retries_exhausted(&worker_id, &queue_name);
299                    metrics_collector_clone.record_message_failed(
300                        &worker_id,
301                        &queue_name,
302                        "RetriesExhausted",
303                        start_time,
304                    );
305                    if let Err(e) = ack_handle.nack(false).await {
306                        tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
307                    }
308                }
309                Err(e) => {
310                    let error_type = format!("{:?}", e);
311                    tracing::error!("Message {} failed: {}", message_id, e);
312                    metrics_collector_clone.record_message_failed(
313                        &worker_id,
314                        &queue_name,
315                        &error_type,
316                        start_time,
317                    );
318                    if let Err(nack_err) = ack_handle.nack(false).await {
319                        tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
320                    }
321                }
322            }
323
324            // Decrement load for least-loaded balancer after processing completes
325            if let Some(ref balancer) = least_loaded_balancer {
326                balancer.decrement_load(worker_index);
327            }
328
329            drop(permit);
330
331            // Signal task completion and decrement counter
332            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
333            task_completion_notify.notify_one();
334        });
335
336        Ok(())
337    }
338
339    /// Select a worker based on the configured load balancing strategy.
340    fn select_worker(&self) -> usize {
341        match self.strategy {
342            LoadBalancingStrategy::RoundRobin => self.round_robin_balancer.next(self.workers.len()),
343            LoadBalancingStrategy::Random => self.random_balancer.next(self.workers.len()),
344            LoadBalancingStrategy::LeastLoaded => {
345                if let Some(ref balancer) = self.least_loaded_balancer {
346                    balancer.next()
347                } else {
348                    0 // Fallback
349                }
350            }
351        }
352    }
353
354    /// Shutdown the pool gracefully.
355    ///
356    /// This prevents new messages from being dispatched, cancels in-flight tasks,
357    /// and waits for completion with a timeout using efficient notification.
358    pub async fn shutdown(&self) -> WorkerResult<()> {
359        tracing::info!("Shutting down worker pool: {}", self.name);
360
361        self.is_running.store(false, Ordering::SeqCst);
362        self.metrics_collector.record_active_workers(0);
363
364        // Cancel all in-flight tasks
365        self.cancellation_token.cancel();
366        tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
367
368        // Close the semaphore to prevent new acquisitions
369        self.semaphore.close();
370
371        // Wait for permits to be returned with efficient notification
372        let shutdown_timeout = Duration::from_secs(30); // 30 second timeout
373        let start = Instant::now();
374
375        loop {
376            let available = self.semaphore.available_permits();
377            let in_flight = self.concurrency_limit.saturating_sub(available);
378
379            if in_flight == 0 {
380                break; // All tasks completed
381            }
382
383            if start.elapsed() >= shutdown_timeout {
384                tracing::warn!(
385                    "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
386                    self.name,
387                    in_flight
388                );
389                break;
390            }
391
392            // Wait efficiently for task completion notification instead of busy-wait
393            tokio::select! {
394                _ = self.task_completion_notify.notified() => {
395                    // A task completed, check again
396                    continue;
397                }
398                _ = tokio::time::sleep(Duration::from_millis(100)) => {
399                    // Periodic check in case notifications are missed
400                    continue;
401                }
402            }
403        }
404
405        self.metrics_collector.record_in_flight_messages(0);
406        tracing::info!("Worker pool {} shutdown complete", self.name);
407        Ok(())
408    }
409
410    /// Get the current number of in-flight tasks.
411    pub fn in_flight_count(&self) -> usize {
412        self.in_flight_tasks.load(Ordering::SeqCst)
413    }
414}
415
416impl HealthCheck for WorkerPool {
417    fn check_health(&self) -> HealthStatus {
418        let is_running = self.is_running.load(Ordering::SeqCst);
419        let worker_count = self.worker_count();
420        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
421
422        // Check if pool is running
423        if !is_running {
424            return HealthStatus::Unhealthy {
425                reason: "Pool is not running".to_string(),
426            };
427        }
428
429        // Check if workers are available
430        if worker_count == 0 {
431            return HealthStatus::Degraded {
432                reason: "No workers available".to_string(),
433            };
434        }
435
436        // Check if pool is saturated (90%+ capacity)
437        let saturation = in_flight as f64 / self.concurrency_limit as f64;
438        if saturation > 0.9 {
439            return HealthStatus::Degraded {
440                reason: format!(
441                    "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
442                    in_flight,
443                    saturation * 100.0
444                ),
445            };
446        }
447
448        HealthStatus::Healthy
449    }
450
451    fn status_message(&self) -> String {
452        let worker_count = self.worker_count();
453        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
454        let available_permits = self.semaphore.available_permits();
455
456        match self.check_health() {
457            HealthStatus::Healthy => {
458                format!(
459                    "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
460                    self.name, worker_count, in_flight, available_permits
461                )
462            }
463            HealthStatus::Degraded { ref reason } => {
464                format!(
465                    "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
466                    self.name, reason, worker_count, in_flight
467                )
468            }
469            HealthStatus::Unhealthy { ref reason } => {
470                format!(
471                    "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
472                    self.name, reason, worker_count, in_flight
473                )
474            }
475        }
476    }
477}
478
479/// Wrapper that converts a Worker into a MessageHandler.
480struct WorkerHandler(Arc<dyn Worker>);
481
482#[async_trait::async_trait]
483impl MessageHandler for WorkerHandler {
484    async fn handle(
485        &self,
486        message: ReceivedMessage<serde_json::Value>,
487    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
488        // Workers always return Continue - they don't handle acknowledgment directly
489        self.0.process(message).await?;
490        Ok(crate::middleware::MiddlewareResult::Continue)
491    }
492}
493
494/// Wrapper to convert Arc<dyn Middleware> to Box<dyn Middleware>
495struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
496
497#[async_trait::async_trait]
498impl Middleware for ArcMiddlewareWrapper {
499    fn name(&self) -> &str {
500        self.0.name()
501    }
502
503    async fn handle(
504        &self,
505        message: ReceivedMessage<serde_json::Value>,
506        next: Box<dyn MessageHandler>,
507    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
508        self.0.handle(message, next).await
509    }
510}
511
512/// Wrapper to convert Box<dyn MessageHandler> to Arc
513struct ArcHandlerWrapper(Box<dyn MessageHandler>);
514
515#[async_trait::async_trait]
516impl MessageHandler for ArcHandlerWrapper {
517    async fn handle(
518        &self,
519        message: ReceivedMessage<serde_json::Value>,
520    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
521        self.0.handle(message).await
522    }
523}
524
525/// Retry ack with exponential backoff on failure.
526///
527/// This helps handle transient network issues or broker unavailability.
528async fn retry_ack(
529    ack_handle: &Arc<dyn crate::message::AckHandle>,
530    message_id: &str,
531) -> WorkerResult<()> {
532    let max_retries = 3;
533    let base_delay_ms = 100;
534
535    for attempt in 0..max_retries {
536        match ack_handle.ack().await {
537            Ok(_) => return Ok(()),
538            Err(e) => {
539                if attempt < max_retries - 1 {
540                    let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
541                    tracing::warn!(
542                        "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
543                        attempt + 1,
544                        message_id,
545                        e,
546                        delay
547                    );
548                    sleep(delay).await;
549                } else {
550                    return Err(e);
551                }
552            }
553        }
554    }
555
556    unreachable!()
557}
558
559#[cfg(test)]
560mod tests {
561    use super::*;
562    use crate::message::{AckHandle, Message, MessageMetadata};
563    use crate::metrics::NoOpMetrics;
564    use async_trait::async_trait;
565    use std::sync::atomic::{AtomicUsize, Ordering};
566    use std::time::Duration; // Use NoOpMetrics for tests
567
568    #[derive(Debug)]
569    struct MockAckHandle;
570
571    #[async_trait]
572    impl AckHandle for MockAckHandle {
573        async fn ack(&self) -> WorkerResult<()> {
574            Ok(())
575        }
576
577        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
578            Ok(())
579        }
580    }
581
582    struct TestWorker {
583        id: String,
584        process_count: Arc<AtomicUsize>,
585    }
586
587    impl TestWorker {
588        fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
589            let count = Arc::new(AtomicUsize::new(0));
590            (
591                Self {
592                    id: id.to_string(),
593                    process_count: count.clone(),
594                },
595                count,
596            )
597        }
598    }
599
600    #[async_trait]
601    impl Worker for TestWorker {
602        fn id(&self) -> &str {
603            &self.id
604        }
605
606        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
607            self.process_count.fetch_add(1, Ordering::SeqCst);
608            Ok(())
609        }
610    }
611
612    fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
613        let message = Message {
614            id: id.to_string(),
615            payload: serde_json::json!({"test": "data"}),
616            metadata: MessageMetadata::new("test-queue"),
617        };
618        ReceivedMessage::new(message, Arc::new(MockAckHandle))
619    }
620
621    #[tokio::test]
622    async fn test_pool_creation() {
623        let pool = WorkerPool::new(
624            "test-pool",
625            LoadBalancingStrategy::RoundRobin,
626            Arc::new(NoOpMetrics),
627        );
628        assert_eq!(pool.name(), "test-pool");
629        assert_eq!(pool.worker_count(), 0);
630        assert!(pool.is_running.load(Ordering::SeqCst));
631    }
632
633    #[tokio::test]
634    async fn test_add_worker() {
635        let mut pool = WorkerPool::new(
636            "test-pool",
637            LoadBalancingStrategy::RoundRobin,
638            Arc::new(NoOpMetrics),
639        );
640        let (worker, _) = TestWorker::new("worker-1");
641        pool.add_worker(Arc::new(worker));
642
643        assert_eq!(pool.worker_count(), 1);
644    }
645
646    #[tokio::test]
647    async fn test_dispatch_empty_pool() {
648        let pool = WorkerPool::new(
649            "test-pool",
650            LoadBalancingStrategy::RoundRobin,
651            Arc::new(NoOpMetrics),
652        );
653        let message = create_test_message("msg-1");
654
655        let result = pool.dispatch(message).await;
656        assert!(matches!(result, Err(WorkerError::PoolExhausted)));
657    }
658
659    #[tokio::test]
660    async fn test_round_robin_distribution() {
661        let mut pool = WorkerPool::new(
662            "test-pool",
663            LoadBalancingStrategy::RoundRobin,
664            Arc::new(NoOpMetrics),
665        );
666
667        let (worker1, count1) = TestWorker::new("worker-1");
668        let (worker2, count2) = TestWorker::new("worker-2");
669
670        pool.add_worker(Arc::new(worker1));
671        pool.add_worker(Arc::new(worker2));
672
673        // Dispatch 4 messages
674        for i in 0..4 {
675            let message = create_test_message(&format!("msg-{}", i));
676            pool.dispatch(message).await.unwrap();
677        }
678
679        // Give tasks time to complete
680        tokio::time::sleep(Duration::from_millis(100)).await;
681
682        // Each worker should have processed 2 messages
683        assert_eq!(count1.load(Ordering::SeqCst), 2);
684        assert_eq!(count2.load(Ordering::SeqCst), 2);
685    }
686
687    #[tokio::test]
688    async fn test_pool_health() {
689        let pool = WorkerPool::new(
690            "test-pool",
691            LoadBalancingStrategy::RoundRobin,
692            Arc::new(NoOpMetrics),
693        );
694        assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. })); // Degraded because 0 workers
695
696        let mut pool = pool;
697        let (worker, _) = TestWorker::new("worker-1");
698        pool.add_worker(Arc::new(worker));
699
700        assert!(matches!(pool.check_health(), HealthStatus::Healthy));
701    }
702
703    #[tokio::test]
704    async fn test_concurrency_limit_enforcement() {
705        use std::sync::atomic::{AtomicUsize, Ordering};
706
707        // Create a worker that tracks concurrent executions
708        let concurrent_count = Arc::new(AtomicUsize::new(0));
709        let max_concurrent = Arc::new(AtomicUsize::new(0));
710
711        struct ConcurrentTestWorker {
712            id: String,
713            concurrent: Arc<AtomicUsize>,
714            max_concurrent: Arc<AtomicUsize>,
715        }
716
717        #[async_trait::async_trait]
718        impl Worker for ConcurrentTestWorker {
719            fn id(&self) -> &str {
720                &self.id
721            }
722
723            async fn process(
724                &self,
725                _message: ReceivedMessage<serde_json::Value>,
726            ) -> WorkerResult<()> {
727                // Increment concurrent count
728                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
729
730                // Track maximum
731                let mut max = self.max_concurrent.load(Ordering::SeqCst);
732                while current > max {
733                    match self.max_concurrent.compare_exchange_weak(
734                        max,
735                        current,
736                        Ordering::SeqCst,
737                        Ordering::SeqCst,
738                    ) {
739                        Ok(_) => break,
740                        Err(new_max) => max = new_max,
741                    }
742                }
743
744                // Simulate processing time
745                tokio::time::sleep(Duration::from_millis(50)).await;
746
747                // Decrement concurrent count
748                self.concurrent.fetch_sub(1, Ordering::SeqCst);
749                Ok(())
750            }
751        }
752
753        // Create pool with concurrency limit of 3
754        let mut pool = WorkerPool::with_concurrency(
755            "test-pool",
756            LoadBalancingStrategy::RoundRobin,
757            3, // Limit to 3 concurrent
758            Arc::new(NoOpMetrics),
759        );
760
761        // Add 1 worker (will handle all messages)
762        let worker = ConcurrentTestWorker {
763            id: "worker-1".to_string(),
764            concurrent: concurrent_count.clone(),
765            max_concurrent: max_concurrent.clone(),
766        };
767        pool.add_worker(Arc::new(worker));
768
769        // Dispatch 10 messages rapidly
770        for i in 0..10 {
771            let message = create_test_message(&format!("msg-{}", i));
772            pool.dispatch(message).await.unwrap();
773        }
774
775        // Wait for all to complete
776        tokio::time::sleep(Duration::from_millis(500)).await;
777
778        // Verify that concurrency never exceeded the limit
779        let actual_max = max_concurrent.load(Ordering::SeqCst);
780        assert!(
781            actual_max <= 3,
782            "Expected max concurrency <= 3, but got {}",
783            actual_max
784        );
785        assert!(
786            actual_max >= 2,
787            "Expected some concurrency (>= 2), but got {}",
788            actual_max
789        );
790    }
791
792    #[tokio::test]
793    async fn test_concurrency_limit_with_builder() {
794        use crate::builder::WorkerPoolBuilder;
795
796        let concurrent_count = Arc::new(AtomicUsize::new(0));
797        let max_concurrent = Arc::new(AtomicUsize::new(0));
798
799        struct TrackedWorker {
800            id: String,
801            concurrent: Arc<AtomicUsize>,
802            max_concurrent: Arc<AtomicUsize>,
803        }
804
805        #[async_trait::async_trait]
806        impl Worker for TrackedWorker {
807            fn id(&self) -> &str {
808                &self.id
809            }
810
811            async fn process(
812                &self,
813                _message: ReceivedMessage<serde_json::Value>,
814            ) -> WorkerResult<()> {
815                // Track concurrency
816                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
817
818                let mut max = self.max_concurrent.load(Ordering::SeqCst);
819                while current > max {
820                    match self.max_concurrent.compare_exchange_weak(
821                        max,
822                        current,
823                        Ordering::SeqCst,
824                        Ordering::SeqCst,
825                    ) {
826                        Ok(_) => break,
827                        Err(new_max) => max = new_max,
828                    }
829                }
830
831                tokio::time::sleep(Duration::from_millis(100)).await;
832
833                self.concurrent.fetch_sub(1, Ordering::SeqCst);
834                Ok(())
835            }
836        }
837
838        // Build pool with concurrency limit of 2
839        let pool = WorkerPoolBuilder::new("test-pool")
840            .with_concurrency_limit(2)
841            .add_worker(TrackedWorker {
842                id: "worker-1".to_string(),
843                concurrent: concurrent_count.clone(),
844                max_concurrent: max_concurrent.clone(),
845            })
846            .build()
847            .unwrap();
848
849        // Dispatch 6 messages rapidly
850        for i in 0..6 {
851            let message = create_test_message(&format!("msg-{}", i));
852            pool.dispatch(message).await.unwrap();
853        }
854
855        // Wait for all to complete
856        tokio::time::sleep(Duration::from_millis(800)).await;
857
858        // Verify that concurrency never exceeded the limit of 2
859        let actual_max = max_concurrent.load(Ordering::SeqCst);
860        assert!(
861            actual_max <= 2,
862            "Expected max concurrency <= 2, but got {}",
863            actual_max
864        );
865        assert!(
866            actual_max >= 1,
867            "Expected some concurrency (>= 1), but got {}",
868            actual_max
869        );
870    }
871
872    #[tokio::test]
873    async fn test_different_concurrency_limits() {
874        // Test that different pools can have different limits
875        let pool1 = WorkerPool::with_concurrency(
876            "pool1",
877            LoadBalancingStrategy::RoundRobin,
878            5,
879            Arc::new(NoOpMetrics),
880        );
881        let pool2 = WorkerPool::with_concurrency(
882            "pool2",
883            LoadBalancingStrategy::RoundRobin,
884            20,
885            Arc::new(NoOpMetrics),
886        );
887
888        // Verify they have different semaphore capacities
889        // We can't directly check semaphore capacity, but we can verify the pools work independently
890        assert_eq!(pool1.name(), "pool1");
891        assert_eq!(pool2.name(), "pool2");
892    }
893
894    #[tokio::test]
895    async fn test_pool_shutdown_prevents_dispatch() {
896        let mut pool = WorkerPool::new(
897            "test-pool",
898            LoadBalancingStrategy::RoundRobin,
899            Arc::new(NoOpMetrics),
900        );
901        let (worker, _) = TestWorker::new("worker-1");
902        pool.add_worker(Arc::new(worker));
903
904        pool.shutdown().await.unwrap();
905
906        let message = create_test_message("msg-after-shutdown");
907        let result = pool.dispatch(message).await;
908        assert!(matches!(result, Err(WorkerError::Shutdown)));
909        assert!(matches!(
910            pool.check_health(),
911            HealthStatus::Unhealthy { .. }
912        ));
913    }
914
915    /// Test that messages are not double-acknowledged when using AckNackMiddleware.
916    /// This verifies the fix for the "PRECONDITION_FAILED - unknown delivery tag" issue.
917    #[tokio::test]
918    async fn test_no_double_ack_with_middleware() {
919        use crate::AckNackMiddleware;
920        use std::sync::atomic::AtomicUsize;
921
922        // Track if ack was called
923        let ack_count = Arc::new(AtomicUsize::new(0));
924        let nack_count = Arc::new(AtomicUsize::new(0));
925
926        #[derive(Debug)]
927        struct TrackingAckHandle {
928            ack_count: Arc<AtomicUsize>,
929            nack_count: Arc<AtomicUsize>,
930        }
931
932        #[async_trait]
933        impl AckHandle for TrackingAckHandle {
934            async fn ack(&self) -> WorkerResult<()> {
935                self.ack_count.fetch_add(1, Ordering::SeqCst);
936                Ok(())
937            }
938
939            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
940                self.nack_count.fetch_add(1, Ordering::SeqCst);
941                Ok(())
942            }
943        }
944
945        // Create a successful worker
946        struct SuccessWorker;
947
948        #[async_trait]
949        impl Worker for SuccessWorker {
950            fn id(&self) -> &str {
951                "success-worker"
952            }
953
954            async fn process(
955                &self,
956                _message: ReceivedMessage<serde_json::Value>,
957            ) -> WorkerResult<()> {
958                Ok(())
959            }
960        }
961
962        let mut pool = WorkerPool::new(
963            "test-pool",
964            LoadBalancingStrategy::RoundRobin,
965            Arc::new(NoOpMetrics),
966        );
967
968        // Add AckNackMiddleware to auto-ack on success
969        pool.middlewares
970            .push(Arc::new(AckNackMiddleware::default()));
971        pool.add_worker(Arc::new(SuccessWorker));
972
973        // Create message with tracking ack handle
974        let message = Message {
975            id: "test-msg".to_string(),
976            payload: serde_json::json!({"test": "data"}),
977            metadata: MessageMetadata::new("test-queue"),
978        };
979        let received = ReceivedMessage::new(
980            message,
981            Arc::new(TrackingAckHandle {
982                ack_count: ack_count.clone(),
983                nack_count: nack_count.clone(),
984            }),
985        );
986
987        // Dispatch the message
988        pool.dispatch(received).await.unwrap();
989
990        // Wait for processing to complete
991        tokio::time::sleep(Duration::from_millis(100)).await;
992
993        // Verify ack was called exactly once (by middleware, not by pool)
994        assert_eq!(
995            ack_count.load(Ordering::SeqCst),
996            1,
997            "Message should have been acked exactly once by middleware"
998        );
999        assert_eq!(
1000            nack_count.load(Ordering::SeqCst),
1001            0,
1002            "Message should not have been nacked"
1003        );
1004    }
1005
1006    /// Test that pool handles ack/nack correctly WITHOUT AckNackMiddleware.
1007    #[tokio::test]
1008    async fn test_pool_handles_ack_without_middleware() {
1009        use std::sync::atomic::AtomicUsize;
1010
1011        let ack_count = Arc::new(AtomicUsize::new(0));
1012        let nack_count = Arc::new(AtomicUsize::new(0));
1013
1014        #[derive(Debug)]
1015        struct CountingAckHandle {
1016            ack_count: Arc<AtomicUsize>,
1017            nack_count: Arc<AtomicUsize>,
1018        }
1019
1020        #[async_trait]
1021        impl AckHandle for CountingAckHandle {
1022            async fn ack(&self) -> WorkerResult<()> {
1023                self.ack_count.fetch_add(1, Ordering::SeqCst);
1024                Ok(())
1025            }
1026
1027            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1028                self.nack_count.fetch_add(1, Ordering::SeqCst);
1029                Ok(())
1030            }
1031        }
1032
1033        // Create a successful worker
1034        struct SuccessWorker;
1035
1036        #[async_trait]
1037        impl Worker for SuccessWorker {
1038            fn id(&self) -> &str {
1039                "success-worker"
1040            }
1041
1042            async fn process(
1043                &self,
1044                _message: ReceivedMessage<serde_json::Value>,
1045            ) -> WorkerResult<()> {
1046                Ok(())
1047            }
1048        }
1049
1050        let mut pool = WorkerPool::new(
1051            "test-pool",
1052            LoadBalancingStrategy::RoundRobin,
1053            Arc::new(NoOpMetrics),
1054        );
1055        // NO middleware - pool should handle ack
1056        pool.add_worker(Arc::new(SuccessWorker));
1057
1058        let message = Message {
1059            id: "test-msg".to_string(),
1060            payload: serde_json::json!({"test": "data"}),
1061            metadata: MessageMetadata::new("test-queue"),
1062        };
1063        let received = ReceivedMessage::new(
1064            message,
1065            Arc::new(CountingAckHandle {
1066                ack_count: ack_count.clone(),
1067                nack_count: nack_count.clone(),
1068            }),
1069        );
1070
1071        pool.dispatch(received).await.unwrap();
1072
1073        // Wait for processing to complete
1074        tokio::time::sleep(Duration::from_millis(100)).await;
1075
1076        // Verify pool acked the message
1077        assert_eq!(
1078            ack_count.load(Ordering::SeqCst),
1079            1,
1080            "Pool should have acked the message when no middleware is used"
1081        );
1082        assert_eq!(
1083            nack_count.load(Ordering::SeqCst),
1084            0,
1085            "Nack should not be called for successful message"
1086        );
1087    }
1088
1089    /// Test that pool handles nack correctly WITHOUT AckNackMiddleware.
1090    #[tokio::test]
1091    async fn test_pool_handles_nack_without_middleware() {
1092        use std::sync::atomic::AtomicUsize;
1093
1094        let ack_count = Arc::new(AtomicUsize::new(0));
1095        let nack_count = Arc::new(AtomicUsize::new(0));
1096
1097        #[derive(Debug)]
1098        struct CountingAckHandle {
1099            ack_count: Arc<AtomicUsize>,
1100            nack_count: Arc<AtomicUsize>,
1101        }
1102
1103        #[async_trait]
1104        impl AckHandle for CountingAckHandle {
1105            async fn ack(&self) -> WorkerResult<()> {
1106                self.ack_count.fetch_add(1, Ordering::SeqCst);
1107                Ok(())
1108            }
1109
1110            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1111                self.nack_count.fetch_add(1, Ordering::SeqCst);
1112                Ok(())
1113            }
1114        }
1115
1116        // Create a failing worker
1117        struct FailingWorker;
1118
1119        #[async_trait]
1120        impl Worker for FailingWorker {
1121            fn id(&self) -> &str {
1122                "failing-worker"
1123            }
1124
1125            async fn process(
1126                &self,
1127                _message: ReceivedMessage<serde_json::Value>,
1128            ) -> WorkerResult<()> {
1129                Err(WorkerError::ProcessingFailed(
1130                    "Simulated failure".to_string(),
1131                ))
1132            }
1133        }
1134
1135        let mut pool = WorkerPool::new(
1136            "test-pool",
1137            LoadBalancingStrategy::RoundRobin,
1138            Arc::new(NoOpMetrics),
1139        );
1140        // NO middleware - pool should handle nack
1141        pool.add_worker(Arc::new(FailingWorker));
1142
1143        let message = Message {
1144            id: "test-msg".to_string(),
1145            payload: serde_json::json!({"test": "data"}),
1146            metadata: MessageMetadata::new("test-queue"),
1147        };
1148        let received = ReceivedMessage::new(
1149            message,
1150            Arc::new(CountingAckHandle {
1151                ack_count: ack_count.clone(),
1152                nack_count: nack_count.clone(),
1153            }),
1154        );
1155
1156        pool.dispatch(received).await.unwrap();
1157
1158        // Wait for processing to complete
1159        tokio::time::sleep(Duration::from_millis(100)).await;
1160
1161        // Verify pool nacked the message
1162        assert_eq!(
1163            nack_count.load(Ordering::SeqCst),
1164            1,
1165            "Pool should have nacked the message when no middleware is used"
1166        );
1167        assert_eq!(
1168            ack_count.load(Ordering::SeqCst),
1169            0,
1170            "Ack should not be called for failed message"
1171        );
1172    }
1173
1174    /// Integration test: Verify end-to-end flow with AckNackMiddleware and failing worker.
1175    #[tokio::test]
1176    async fn test_integration_ack_nack_middleware_with_failure() {
1177        use crate::AckNackMiddleware;
1178        use std::sync::atomic::AtomicUsize;
1179
1180        let ack_count = Arc::new(AtomicUsize::new(0));
1181        let nack_count = Arc::new(AtomicUsize::new(0));
1182
1183        #[derive(Debug)]
1184        struct CountingAckHandle {
1185            ack_count: Arc<AtomicUsize>,
1186            nack_count: Arc<AtomicUsize>,
1187        }
1188
1189        #[async_trait]
1190        impl AckHandle for CountingAckHandle {
1191            async fn ack(&self) -> WorkerResult<()> {
1192                self.ack_count.fetch_add(1, Ordering::SeqCst);
1193                Ok(())
1194            }
1195
1196            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1197                self.nack_count.fetch_add(1, Ordering::SeqCst);
1198                Ok(())
1199            }
1200        }
1201
1202        // Create a failing worker
1203        struct FailingWorker;
1204
1205        #[async_trait]
1206        impl Worker for FailingWorker {
1207            fn id(&self) -> &str {
1208                "failing-worker"
1209            }
1210
1211            async fn process(
1212                &self,
1213                _message: ReceivedMessage<serde_json::Value>,
1214            ) -> WorkerResult<()> {
1215                Err(WorkerError::ProcessingFailed(
1216                    "Simulated failure".to_string(),
1217                ))
1218            }
1219        }
1220
1221        let mut pool = WorkerPool::new(
1222            "test-pool",
1223            LoadBalancingStrategy::RoundRobin,
1224            Arc::new(NoOpMetrics),
1225        );
1226
1227        // Add AckNackMiddleware to auto-nack on failure
1228        pool.middlewares
1229            .push(Arc::new(AckNackMiddleware::default()));
1230        pool.add_worker(Arc::new(FailingWorker));
1231
1232        let message = Message {
1233            id: "test-msg".to_string(),
1234            payload: serde_json::json!({"test": "data"}),
1235            metadata: MessageMetadata::new("test-queue"),
1236        };
1237        let received = ReceivedMessage::new(
1238            message,
1239            Arc::new(CountingAckHandle {
1240                ack_count: ack_count.clone(),
1241                nack_count: nack_count.clone(),
1242            }),
1243        );
1244
1245        pool.dispatch(received).await.unwrap();
1246
1247        // Wait for processing
1248        tokio::time::sleep(Duration::from_millis(100)).await;
1249
1250        // Verify nack was called exactly once (by middleware)
1251        assert_eq!(
1252            nack_count.load(Ordering::SeqCst),
1253            1,
1254            "Nack should be called exactly once by middleware"
1255        );
1256        assert_eq!(
1257            ack_count.load(Ordering::SeqCst),
1258            0,
1259            "Ack should not be called for failed message"
1260        );
1261    }
1262}