Skip to main content

foxtive_worker/
pool.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::ReceivedMessage;
11use crate::metrics::WorkerMetrics;
12use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
13use crate::strategies::{
14    LeastLoadedBalancer, LoadBalancingStrategy, RandomBalancer, RoundRobinBalancer,
15};
16use crate::worker::Worker;
17
18/// A pool of workers with load balancing for message distribution.
19///
20/// The pool manages multiple worker instances and distributes incoming messages
21/// based on the configured load balancing strategy.
22///
23/// # Example
24/// ```rust,no_run
25/// use foxtive_worker::pool::WorkerPool;
26/// use foxtive_worker::strategies::LoadBalancingStrategy;
27/// use foxtive_worker::metrics::NoOpMetrics;
28/// use std::sync::Arc;
29///
30/// let pool = WorkerPool::new("my-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
31/// // Add workers...
32/// // Dispatch messages...
33/// ```
34pub struct WorkerPool {
35    name: String,
36    workers: Vec<Arc<dyn Worker>>,
37    strategy: LoadBalancingStrategy,
38    semaphore: Arc<Semaphore>,
39    concurrency_limit: usize,
40    least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
41    round_robin_balancer: Arc<RoundRobinBalancer>,
42    random_balancer: RandomBalancer,
43    /// Middleware list (Arc-wrapped for cloning)
44    middlewares: Vec<Arc<dyn Middleware>>,
45    /// Metrics collector for this pool
46    metrics_collector: Arc<dyn WorkerMetrics>,
47    /// Indicates if the worker pool is currently running.
48    is_running: Arc<AtomicBool>,
49    /// Cancellation token for graceful shutdown of all spawned tasks
50    cancellation_token: CancellationToken,
51    /// Notify for task completion signaling during shutdown
52    task_completion_notify: Arc<Notify>,
53    /// Track number of in-flight tasks for monitoring
54    in_flight_tasks: Arc<AtomicUsize>,
55}
56
57impl std::fmt::Debug for WorkerPool {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("WorkerPool")
60            .field("name", &self.name)
61            .field("worker_count", &self.workers.len())
62            .field("strategy", &self.strategy)
63            .field("is_running", &self.is_running.load(Ordering::SeqCst))
64            .finish()
65    }
66}
67
68impl WorkerPool {
69    /// Create a new worker pool with the given name, load balancing strategy, and metrics collector.
70    pub fn new(
71        name: impl Into<String>,
72        strategy: LoadBalancingStrategy,
73        metrics_collector: Arc<dyn WorkerMetrics>,
74    ) -> Self {
75        Self::with_concurrency(name, strategy, 1000, metrics_collector)
76    }
77
78    /// Create a new worker pool with custom concurrency limit.
79    ///
80    /// # Arguments
81    /// * `name` - Pool name
82    /// * `strategy` - Load balancing strategy
83    /// * `concurrency_limit` - Maximum concurrent messages across all workers
84    /// * `metrics_collector` - Metrics collector implementation
85    pub fn with_concurrency(
86        name: impl Into<String>,
87        strategy: LoadBalancingStrategy,
88        concurrency_limit: usize,
89        metrics_collector: Arc<dyn WorkerMetrics>,
90    ) -> Self {
91        let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
92            Some(Arc::new(LeastLoadedBalancer::new(0)))
93        } else {
94            None
95        };
96
97        Self {
98            name: name.into(),
99            workers: Vec::new(),
100            strategy,
101            semaphore: Arc::new(Semaphore::new(concurrency_limit)),
102            concurrency_limit,
103            least_loaded_balancer,
104            round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
105            random_balancer: RandomBalancer,
106            middlewares: Vec::new(),
107            metrics_collector,
108            is_running: Arc::new(AtomicBool::new(true)),
109            cancellation_token: CancellationToken::new(),
110            task_completion_notify: Arc::new(Notify::new()),
111            in_flight_tasks: Arc::new(AtomicUsize::new(0)),
112        }
113    }
114
115    /// Add a worker to the pool.
116    pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
117        self.workers.push(worker);
118
119        // Update least-loaded balancer if needed - now O(1) with atomic swap
120        if let Some(ref balancer) = self.least_loaded_balancer {
121            balancer.add_worker();
122        }
123        self.metrics_collector
124            .record_active_workers(self.workers.len());
125    }
126
127    /// Add multiple workers to the pool.
128    pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
129        for worker in workers {
130            self.add_worker(worker);
131        }
132    }
133
134    /// Get the number of workers in the pool.
135    pub fn worker_count(&self) -> usize {
136        self.workers.len()
137    }
138
139    /// Set middleware for the pool.
140    ///
141    /// This allows you to inject middleware processing into the message flow.
142    /// Middleware will be executed in order before messages reaches workers.
143    pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
144        self.middlewares = middlewares;
145        self
146    }
147
148    /// Get the pool name.
149    pub fn name(&self) -> &str {
150        &self.name
151    }
152
153    /// Dispatch a message to a worker based on the load balancing strategy.
154    ///
155    /// Spawns an async task to process the message, respecting concurrency limits.
156    /// If middleware is configured, the message flows through the chain before reaching the worker.
157    pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
158        if !self.is_running.load(Ordering::SeqCst) {
159            return Err(WorkerError::Shutdown);
160        }
161        if self.workers.is_empty() {
162            return Err(WorkerError::PoolExhausted);
163        }
164
165        // Pick a worker and track metrics
166        let worker_index = self.select_worker();
167        let worker = self.workers[worker_index].clone();
168        let worker_id = worker.id().to_string();
169        let queue_name = message.message.metadata.source.clone();
170
171        self.metrics_collector
172            .record_message_received(&worker_id, &queue_name);
173        let start_time = Instant::now();
174
175        if let Some(ref balancer) = self.least_loaded_balancer {
176            balancer.increment_load(worker_index);
177        }
178
179        let permit = self
180            .semaphore
181            .clone()
182            .acquire_owned()
183            .await
184            .map_err(|_| WorkerError::Shutdown)?;
185
186        self.metrics_collector
187            .record_in_flight_messages(self.semaphore.available_permits());
188
189        // Build the handler chain - wrap worker with middleware if configured
190        let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
191            let worker_handler = WorkerHandler(worker);
192            let boxed_middlewares: Vec<Box<dyn Middleware>> = self
193                .middlewares
194                .iter()
195                .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
196                .collect();
197
198            let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
199            Arc::new(ArcHandlerWrapper(chain.build()))
200        } else {
201            Arc::new(WorkerHandler(worker))
202        };
203
204        let metrics_collector_clone = self.metrics_collector.clone();
205        let least_loaded_balancer = self.least_loaded_balancer.clone();
206        let cancellation_token = self.cancellation_token.child_token();
207        let task_completion_notify = self.task_completion_notify.clone();
208        let in_flight_tasks = self.in_flight_tasks.clone();
209
210        // Track in-flight task count
211        in_flight_tasks.fetch_add(1, Ordering::SeqCst);
212
213        // Extract ack_handle and metadata before moving message into task
214        let ack_handle = message.ack_handle.clone();
215        let message_id = message.message.id.clone();
216        let attempt = message.message.metadata.attempt;
217
218        tokio::spawn(async move {
219            // Use tokio::select! to handle cancellation
220            let result = tokio::select! {
221                result = handler.handle(message) => result,  // No clone - move message into task
222                _ = cancellation_token.cancelled() => {
223                    tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
224                    // Decrement in-flight counter on cancellation
225                    in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
226                    task_completion_notify.notify_one();
227                    return;
228                }
229            };
230
231            match result {
232                Ok(middleware_result) => {
233                    match middleware_result {
234                        crate::middleware::MiddlewareResult::Acknowledged => {
235                            // Middleware (e.g., AckNackMiddleware) already handled acknowledgment
236                            tracing::debug!("Message {} already acknowledged by middleware", message_id);
237                            metrics_collector_clone.record_message_processed(
238                                &worker_id,
239                                &queue_name,
240                                start_time,
241                            );
242                            // Decrement in-flight counter
243                            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
244                            task_completion_notify.notify_one();
245                            return;
246                        }
247                        crate::middleware::MiddlewareResult::Continue => {
248                            // No middleware handled ack - pool should ack on success
249                            tracing::debug!("Message {} processed successfully", message_id);
250                            metrics_collector_clone.record_message_processed(
251                                &worker_id,
252                                &queue_name,
253                                start_time,
254                            );
255                            if let Err(e) = retry_ack(&ack_handle, &message_id).await {
256                                tracing::error!(
257                                    "Failed to ack message {} after retries: {}. Message may be redelivered.",
258                                    message_id,
259                                    e
260                                );
261                            }
262                        }
263                    }
264                }
265                Err(WorkerError::RetryableFailure { source, delay_ms }) => {
266                    tracing::warn!(
267                        "Message {} failed (will retry in {:?}): {}",
268                        message_id,
269                        delay_ms,
270                        source
271                    );
272                    metrics_collector_clone.record_message_retried(
273                        &worker_id,
274                        &queue_name,
275                        attempt,
276                    );
277                    metrics_collector_clone.record_message_failed(
278                        &worker_id,
279                        &queue_name,
280                        "RetryableFailure",
281                        start_time,
282                    );
283                    if let Err(e) = ack_handle.nack(true).await {
284                        tracing::error!("Failed to requeue message {}: {}", message_id, e);
285                    }
286                    sleep(delay_ms).await;
287                }
288                Err(WorkerError::RetriesExhausted { source }) => {
289                    tracing::error!(
290                        "Message {} exhausted all retries, sending to DLQ: {}",
291                        message_id,
292                        source
293                    );
294                    metrics_collector_clone
295                        .record_message_retries_exhausted(&worker_id, &queue_name);
296                    metrics_collector_clone.record_message_failed(
297                        &worker_id,
298                        &queue_name,
299                        "RetriesExhausted",
300                        start_time,
301                    );
302                    if let Err(e) = ack_handle.nack(false).await {
303                        tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
304                    }
305                }
306                Err(e) => {
307                    let error_type = format!("{:?}", e);
308                    tracing::error!("Message {} failed: {}", message_id, e);
309                    metrics_collector_clone.record_message_failed(
310                        &worker_id,
311                        &queue_name,
312                        &error_type,
313                        start_time,
314                    );
315                    if let Err(nack_err) = ack_handle.nack(false).await {
316                        tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
317                    }
318                }
319            }
320
321            // Decrement load for least-loaded balancer after processing completes
322            if let Some(ref balancer) = least_loaded_balancer {
323                balancer.decrement_load(worker_index);
324            }
325
326            drop(permit);
327
328            // Signal task completion and decrement counter
329            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
330            task_completion_notify.notify_one();
331        });
332
333        Ok(())
334    }
335
336    /// Select a worker based on the configured load balancing strategy.
337    fn select_worker(&self) -> usize {
338        match self.strategy {
339            LoadBalancingStrategy::RoundRobin => self.round_robin_balancer.next(self.workers.len()),
340            LoadBalancingStrategy::Random => self.random_balancer.next(self.workers.len()),
341            LoadBalancingStrategy::LeastLoaded => {
342                if let Some(ref balancer) = self.least_loaded_balancer {
343                    balancer.next()
344                } else {
345                    0 // Fallback
346                }
347            }
348        }
349    }
350
351    /// Shutdown the pool gracefully.
352    ///
353    /// This prevents new messages from being dispatched, cancels in-flight tasks,
354    /// and waits for completion with a timeout using efficient notification.
355    pub async fn shutdown(&self) -> WorkerResult<()> {
356        tracing::info!("Shutting down worker pool: {}", self.name);
357
358        self.is_running.store(false, Ordering::SeqCst);
359        self.metrics_collector.record_active_workers(0);
360
361        // Cancel all in-flight tasks
362        self.cancellation_token.cancel();
363        tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
364
365        // Close the semaphore to prevent new acquisitions
366        self.semaphore.close();
367
368        // Wait for permits to be returned with efficient notification
369        let shutdown_timeout = Duration::from_secs(30); // 30 second timeout
370        let start = Instant::now();
371
372        loop {
373            let available = self.semaphore.available_permits();
374            let in_flight = self.concurrency_limit.saturating_sub(available);
375
376            if in_flight == 0 {
377                break; // All tasks completed
378            }
379
380            if start.elapsed() >= shutdown_timeout {
381                tracing::warn!(
382                    "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
383                    self.name,
384                    in_flight
385                );
386                break;
387            }
388
389            // Wait efficiently for task completion notification instead of busy-wait
390            tokio::select! {
391                _ = self.task_completion_notify.notified() => {
392                    // A task completed, check again
393                    continue;
394                }
395                _ = tokio::time::sleep(Duration::from_millis(100)) => {
396                    // Periodic check in case notifications are missed
397                    continue;
398                }
399            }
400        }
401
402        self.metrics_collector.record_in_flight_messages(0);
403        tracing::info!("Worker pool {} shutdown complete", self.name);
404        Ok(())
405    }
406
407    /// Get the current number of in-flight tasks.
408    pub fn in_flight_count(&self) -> usize {
409        self.in_flight_tasks.load(Ordering::SeqCst)
410    }
411}
412
413impl HealthCheck for WorkerPool {
414    fn check_health(&self) -> HealthStatus {
415        let is_running = self.is_running.load(Ordering::SeqCst);
416        let worker_count = self.worker_count();
417        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
418
419        // Check if pool is running
420        if !is_running {
421            return HealthStatus::Unhealthy {
422                reason: "Pool is not running".to_string(),
423            };
424        }
425
426        // Check if workers are available
427        if worker_count == 0 {
428            return HealthStatus::Degraded {
429                reason: "No workers available".to_string(),
430            };
431        }
432
433        // Check if pool is saturated (90%+ capacity)
434        let saturation = in_flight as f64 / self.concurrency_limit as f64;
435        if saturation > 0.9 {
436            return HealthStatus::Degraded {
437                reason: format!(
438                    "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
439                    in_flight,
440                    saturation * 100.0
441                ),
442            };
443        }
444
445        HealthStatus::Healthy
446    }
447
448    fn status_message(&self) -> String {
449        let worker_count = self.worker_count();
450        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
451        let available_permits = self.semaphore.available_permits();
452
453        match self.check_health() {
454            HealthStatus::Healthy => {
455                format!(
456                    "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
457                    self.name, worker_count, in_flight, available_permits
458                )
459            }
460            HealthStatus::Degraded { ref reason } => {
461                format!(
462                    "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
463                    self.name, reason, worker_count, in_flight
464                )
465            }
466            HealthStatus::Unhealthy { ref reason } => {
467                format!(
468                    "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
469                    self.name, reason, worker_count, in_flight
470                )
471            }
472        }
473    }
474}
475
476/// Wrapper that converts a Worker into a MessageHandler.
477struct WorkerHandler(Arc<dyn Worker>);
478
479#[async_trait::async_trait]
480impl MessageHandler for WorkerHandler {
481    async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
482        // Workers always return Continue - they don't handle acknowledgment directly
483        self.0.process(message).await?;
484        Ok(crate::middleware::MiddlewareResult::Continue)
485    }
486}
487
488/// Wrapper to convert Arc<dyn Middleware> to Box<dyn Middleware>
489struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
490
491#[async_trait::async_trait]
492impl Middleware for ArcMiddlewareWrapper {
493    fn name(&self) -> &str {
494        self.0.name()
495    }
496
497    async fn handle(
498        &self,
499        message: ReceivedMessage<serde_json::Value>,
500        next: Box<dyn MessageHandler>,
501    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
502        self.0.handle(message, next).await
503    }
504}
505
506/// Wrapper to convert Box<dyn MessageHandler> to Arc
507struct ArcHandlerWrapper(Box<dyn MessageHandler>);
508
509#[async_trait::async_trait]
510impl MessageHandler for ArcHandlerWrapper {
511    async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
512        self.0.handle(message).await
513    }
514}
515
516/// Retry ack with exponential backoff on failure.
517///
518/// This helps handle transient network issues or broker unavailability.
519async fn retry_ack(
520    ack_handle: &Arc<dyn crate::message::AckHandle>,
521    message_id: &str,
522) -> WorkerResult<()> {
523    let max_retries = 3;
524    let base_delay_ms = 100;
525
526    for attempt in 0..max_retries {
527        match ack_handle.ack().await {
528            Ok(_) => return Ok(()),
529            Err(e) => {
530                if attempt < max_retries - 1 {
531                    let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
532                    tracing::warn!(
533                        "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
534                        attempt + 1,
535                        message_id,
536                        e,
537                        delay
538                    );
539                    sleep(delay).await;
540                } else {
541                    return Err(e);
542                }
543            }
544        }
545    }
546
547    unreachable!()
548}
549
550#[cfg(test)]
551mod tests {
552    use super::*;
553    use crate::message::{AckHandle, Message, MessageMetadata};
554    use crate::metrics::NoOpMetrics;
555    use async_trait::async_trait;
556    use std::sync::atomic::{AtomicUsize, Ordering};
557    use std::time::Duration; // Use NoOpMetrics for tests
558
559    #[derive(Debug)]
560    struct MockAckHandle;
561
562    #[async_trait]
563    impl AckHandle for MockAckHandle {
564        async fn ack(&self) -> WorkerResult<()> {
565            Ok(())
566        }
567
568        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
569            Ok(())
570        }
571    }
572
573    struct TestWorker {
574        id: String,
575        process_count: Arc<AtomicUsize>,
576    }
577
578    impl TestWorker {
579        fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
580            let count = Arc::new(AtomicUsize::new(0));
581            (
582                Self {
583                    id: id.to_string(),
584                    process_count: count.clone(),
585                },
586                count,
587            )
588        }
589    }
590
591    #[async_trait]
592    impl Worker for TestWorker {
593        fn id(&self) -> &str {
594            &self.id
595        }
596
597        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
598            self.process_count.fetch_add(1, Ordering::SeqCst);
599            Ok(())
600        }
601    }
602
603    fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
604        let message = Message {
605            id: id.to_string(),
606            payload: serde_json::json!({"test": "data"}),
607            metadata: MessageMetadata::new("test-queue"),
608        };
609        ReceivedMessage::new(message, Arc::new(MockAckHandle))
610    }
611
612    #[tokio::test]
613    async fn test_pool_creation() {
614        let pool = WorkerPool::new(
615            "test-pool",
616            LoadBalancingStrategy::RoundRobin,
617            Arc::new(NoOpMetrics),
618        );
619        assert_eq!(pool.name(), "test-pool");
620        assert_eq!(pool.worker_count(), 0);
621        assert!(pool.is_running.load(Ordering::SeqCst));
622    }
623
624    #[tokio::test]
625    async fn test_add_worker() {
626        let mut pool = WorkerPool::new(
627            "test-pool",
628            LoadBalancingStrategy::RoundRobin,
629            Arc::new(NoOpMetrics),
630        );
631        let (worker, _) = TestWorker::new("worker-1");
632        pool.add_worker(Arc::new(worker));
633
634        assert_eq!(pool.worker_count(), 1);
635    }
636
637    #[tokio::test]
638    async fn test_dispatch_empty_pool() {
639        let pool = WorkerPool::new(
640            "test-pool",
641            LoadBalancingStrategy::RoundRobin,
642            Arc::new(NoOpMetrics),
643        );
644        let message = create_test_message("msg-1");
645
646        let result = pool.dispatch(message).await;
647        assert!(matches!(result, Err(WorkerError::PoolExhausted)));
648    }
649
650    #[tokio::test]
651    async fn test_round_robin_distribution() {
652        let mut pool = WorkerPool::new(
653            "test-pool",
654            LoadBalancingStrategy::RoundRobin,
655            Arc::new(NoOpMetrics),
656        );
657
658        let (worker1, count1) = TestWorker::new("worker-1");
659        let (worker2, count2) = TestWorker::new("worker-2");
660
661        pool.add_worker(Arc::new(worker1));
662        pool.add_worker(Arc::new(worker2));
663
664        // Dispatch 4 messages
665        for i in 0..4 {
666            let message = create_test_message(&format!("msg-{}", i));
667            pool.dispatch(message).await.unwrap();
668        }
669
670        // Give tasks time to complete
671        tokio::time::sleep(Duration::from_millis(100)).await;
672
673        // Each worker should have processed 2 messages
674        assert_eq!(count1.load(Ordering::SeqCst), 2);
675        assert_eq!(count2.load(Ordering::SeqCst), 2);
676    }
677
678    #[tokio::test]
679    async fn test_pool_health() {
680        let pool = WorkerPool::new(
681            "test-pool",
682            LoadBalancingStrategy::RoundRobin,
683            Arc::new(NoOpMetrics),
684        );
685        assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. })); // Degraded because 0 workers
686
687        let mut pool = pool;
688        let (worker, _) = TestWorker::new("worker-1");
689        pool.add_worker(Arc::new(worker));
690
691        assert!(matches!(pool.check_health(), HealthStatus::Healthy));
692    }
693
694    #[tokio::test]
695    async fn test_concurrency_limit_enforcement() {
696        use std::sync::atomic::{AtomicUsize, Ordering};
697
698        // Create a worker that tracks concurrent executions
699        let concurrent_count = Arc::new(AtomicUsize::new(0));
700        let max_concurrent = Arc::new(AtomicUsize::new(0));
701
702        struct ConcurrentTestWorker {
703            id: String,
704            concurrent: Arc<AtomicUsize>,
705            max_concurrent: Arc<AtomicUsize>,
706        }
707
708        #[async_trait::async_trait]
709        impl Worker for ConcurrentTestWorker {
710            fn id(&self) -> &str {
711                &self.id
712            }
713
714            async fn process(
715                &self,
716                _message: ReceivedMessage<serde_json::Value>,
717            ) -> WorkerResult<()> {
718                // Increment concurrent count
719                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
720
721                // Track maximum
722                let mut max = self.max_concurrent.load(Ordering::SeqCst);
723                while current > max {
724                    match self.max_concurrent.compare_exchange_weak(
725                        max,
726                        current,
727                        Ordering::SeqCst,
728                        Ordering::SeqCst,
729                    ) {
730                        Ok(_) => break,
731                        Err(new_max) => max = new_max,
732                    }
733                }
734
735                // Simulate processing time
736                tokio::time::sleep(Duration::from_millis(50)).await;
737
738                // Decrement concurrent count
739                self.concurrent.fetch_sub(1, Ordering::SeqCst);
740                Ok(())
741            }
742        }
743
744        // Create pool with concurrency limit of 3
745        let mut pool = WorkerPool::with_concurrency(
746            "test-pool",
747            LoadBalancingStrategy::RoundRobin,
748            3, // Limit to 3 concurrent
749            Arc::new(NoOpMetrics),
750        );
751
752        // Add 1 worker (will handle all messages)
753        let worker = ConcurrentTestWorker {
754            id: "worker-1".to_string(),
755            concurrent: concurrent_count.clone(),
756            max_concurrent: max_concurrent.clone(),
757        };
758        pool.add_worker(Arc::new(worker));
759
760        // Dispatch 10 messages rapidly
761        for i in 0..10 {
762            let message = create_test_message(&format!("msg-{}", i));
763            pool.dispatch(message).await.unwrap();
764        }
765
766        // Wait for all to complete
767        tokio::time::sleep(Duration::from_millis(500)).await;
768
769        // Verify that concurrency never exceeded the limit
770        let actual_max = max_concurrent.load(Ordering::SeqCst);
771        assert!(
772            actual_max <= 3,
773            "Expected max concurrency <= 3, but got {}",
774            actual_max
775        );
776        assert!(
777            actual_max >= 2,
778            "Expected some concurrency (>= 2), but got {}",
779            actual_max
780        );
781    }
782
783    #[tokio::test]
784    async fn test_concurrency_limit_with_builder() {
785        use crate::builder::WorkerPoolBuilder;
786
787        let concurrent_count = Arc::new(AtomicUsize::new(0));
788        let max_concurrent = Arc::new(AtomicUsize::new(0));
789
790        struct TrackedWorker {
791            id: String,
792            concurrent: Arc<AtomicUsize>,
793            max_concurrent: Arc<AtomicUsize>,
794        }
795
796        #[async_trait::async_trait]
797        impl Worker for TrackedWorker {
798            fn id(&self) -> &str {
799                &self.id
800            }
801
802            async fn process(
803                &self,
804                _message: ReceivedMessage<serde_json::Value>,
805            ) -> WorkerResult<()> {
806                // Track concurrency
807                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
808
809                let mut max = self.max_concurrent.load(Ordering::SeqCst);
810                while current > max {
811                    match self.max_concurrent.compare_exchange_weak(
812                        max,
813                        current,
814                        Ordering::SeqCst,
815                        Ordering::SeqCst,
816                    ) {
817                        Ok(_) => break,
818                        Err(new_max) => max = new_max,
819                    }
820                }
821
822                tokio::time::sleep(Duration::from_millis(100)).await;
823
824                self.concurrent.fetch_sub(1, Ordering::SeqCst);
825                Ok(())
826            }
827        }
828
829        // Build pool with concurrency limit of 2
830        let pool = WorkerPoolBuilder::new("test-pool")
831            .with_concurrency_limit(2)
832            .add_worker(TrackedWorker {
833                id: "worker-1".to_string(),
834                concurrent: concurrent_count.clone(),
835                max_concurrent: max_concurrent.clone(),
836            })
837            .build()
838            .unwrap();
839
840        // Dispatch 6 messages rapidly
841        for i in 0..6 {
842            let message = create_test_message(&format!("msg-{}", i));
843            pool.dispatch(message).await.unwrap();
844        }
845
846        // Wait for all to complete
847        tokio::time::sleep(Duration::from_millis(800)).await;
848
849        // Verify that concurrency never exceeded the limit of 2
850        let actual_max = max_concurrent.load(Ordering::SeqCst);
851        assert!(
852            actual_max <= 2,
853            "Expected max concurrency <= 2, but got {}",
854            actual_max
855        );
856        assert!(
857            actual_max >= 1,
858            "Expected some concurrency (>= 1), but got {}",
859            actual_max
860        );
861    }
862
863    #[tokio::test]
864    async fn test_different_concurrency_limits() {
865        // Test that different pools can have different limits
866        let pool1 = WorkerPool::with_concurrency(
867            "pool1",
868            LoadBalancingStrategy::RoundRobin,
869            5,
870            Arc::new(NoOpMetrics),
871        );
872        let pool2 = WorkerPool::with_concurrency(
873            "pool2",
874            LoadBalancingStrategy::RoundRobin,
875            20,
876            Arc::new(NoOpMetrics),
877        );
878
879        // Verify they have different semaphore capacities
880        // We can't directly check semaphore capacity, but we can verify the pools work independently
881        assert_eq!(pool1.name(), "pool1");
882        assert_eq!(pool2.name(), "pool2");
883    }
884
885    #[tokio::test]
886    async fn test_pool_shutdown_prevents_dispatch() {
887        let mut pool = WorkerPool::new(
888            "test-pool",
889            LoadBalancingStrategy::RoundRobin,
890            Arc::new(NoOpMetrics),
891        );
892        let (worker, _) = TestWorker::new("worker-1");
893        pool.add_worker(Arc::new(worker));
894
895        pool.shutdown().await.unwrap();
896
897        let message = create_test_message("msg-after-shutdown");
898        let result = pool.dispatch(message).await;
899        assert!(matches!(result, Err(WorkerError::Shutdown)));
900        assert!(matches!(
901            pool.check_health(),
902            HealthStatus::Unhealthy { .. }
903        ));
904    }
905
906    /// Test that messages are not double-acknowledged when using AckNackMiddleware.
907    /// This verifies the fix for the "PRECONDITION_FAILED - unknown delivery tag" issue.
908    #[tokio::test]
909    async fn test_no_double_ack_with_middleware() {
910        use crate::AckNackMiddleware;
911        use std::sync::atomic::AtomicUsize;
912
913        // Track if ack was called
914        let ack_count = Arc::new(AtomicUsize::new(0));
915        let nack_count = Arc::new(AtomicUsize::new(0));
916
917        #[derive(Debug)]
918        struct TrackingAckHandle {
919            ack_count: Arc<AtomicUsize>,
920            nack_count: Arc<AtomicUsize>,
921        }
922
923        #[async_trait]
924        impl AckHandle for TrackingAckHandle {
925            async fn ack(&self) -> WorkerResult<()> {
926                self.ack_count.fetch_add(1, Ordering::SeqCst);
927                Ok(())
928            }
929
930            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
931                self.nack_count.fetch_add(1, Ordering::SeqCst);
932                Ok(())
933            }
934        }
935
936        // Create a successful worker
937        struct SuccessWorker;
938
939        #[async_trait]
940        impl Worker for SuccessWorker {
941            fn id(&self) -> &str {
942                "success-worker"
943            }
944
945            async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
946                Ok(())
947            }
948        }
949
950        let mut pool = WorkerPool::new(
951            "test-pool",
952            LoadBalancingStrategy::RoundRobin,
953            Arc::new(NoOpMetrics),
954        );
955
956        // Add AckNackMiddleware to auto-ack on success
957        pool.middlewares.push(Arc::new(AckNackMiddleware::default()));
958        pool.add_worker(Arc::new(SuccessWorker));
959
960        // Create message with tracking ack handle
961        let message = Message {
962            id: "test-msg".to_string(),
963            payload: serde_json::json!({"test": "data"}),
964            metadata: MessageMetadata::new("test-queue"),
965        };
966        let received = ReceivedMessage::new(
967            message,
968            Arc::new(TrackingAckHandle {
969                ack_count: ack_count.clone(),
970                nack_count: nack_count.clone(),
971            }),
972        );
973
974        // Dispatch the message
975        pool.dispatch(received).await.unwrap();
976
977        // Wait for processing to complete
978        tokio::time::sleep(Duration::from_millis(100)).await;
979
980        // Verify ack was called exactly once (by middleware, not by pool)
981        assert_eq!(
982            ack_count.load(Ordering::SeqCst),
983            1,
984            "Message should have been acked exactly once by middleware"
985        );
986        assert_eq!(
987            nack_count.load(Ordering::SeqCst),
988            0,
989            "Message should not have been nacked"
990        );
991    }
992
993    /// Test that pool handles ack/nack correctly WITHOUT AckNackMiddleware.
994    #[tokio::test]
995    async fn test_pool_handles_ack_without_middleware() {
996        use std::sync::atomic::AtomicUsize;
997
998        let ack_count = Arc::new(AtomicUsize::new(0));
999        let nack_count = Arc::new(AtomicUsize::new(0));
1000
1001        #[derive(Debug)]
1002        struct CountingAckHandle {
1003            ack_count: Arc<AtomicUsize>,
1004            nack_count: Arc<AtomicUsize>,
1005        }
1006
1007        #[async_trait]
1008        impl AckHandle for CountingAckHandle {
1009            async fn ack(&self) -> WorkerResult<()> {
1010                self.ack_count.fetch_add(1, Ordering::SeqCst);
1011                Ok(())
1012            }
1013
1014            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1015                self.nack_count.fetch_add(1, Ordering::SeqCst);
1016                Ok(())
1017            }
1018        }
1019
1020        // Create a successful worker
1021        struct SuccessWorker;
1022
1023        #[async_trait]
1024        impl Worker for SuccessWorker {
1025            fn id(&self) -> &str {
1026                "success-worker"
1027            }
1028
1029            async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
1030                Ok(())
1031            }
1032        }
1033
1034        let mut pool = WorkerPool::new(
1035            "test-pool",
1036            LoadBalancingStrategy::RoundRobin,
1037            Arc::new(NoOpMetrics),
1038        );
1039        // NO middleware - pool should handle ack
1040        pool.add_worker(Arc::new(SuccessWorker));
1041
1042        let message = Message {
1043            id: "test-msg".to_string(),
1044            payload: serde_json::json!({"test": "data"}),
1045            metadata: MessageMetadata::new("test-queue"),
1046        };
1047        let received = ReceivedMessage::new(
1048            message,
1049            Arc::new(CountingAckHandle {
1050                ack_count: ack_count.clone(),
1051                nack_count: nack_count.clone(),
1052            }),
1053        );
1054
1055        pool.dispatch(received).await.unwrap();
1056
1057        // Wait for processing to complete
1058        tokio::time::sleep(Duration::from_millis(100)).await;
1059
1060        // Verify pool acked the message
1061        assert_eq!(
1062            ack_count.load(Ordering::SeqCst),
1063            1,
1064            "Pool should have acked the message when no middleware is used"
1065        );
1066        assert_eq!(
1067            nack_count.load(Ordering::SeqCst),
1068            0,
1069            "Nack should not be called for successful message"
1070        );
1071    }
1072
1073    /// Test that pool handles nack correctly WITHOUT AckNackMiddleware.
1074    #[tokio::test]
1075    async fn test_pool_handles_nack_without_middleware() {
1076        use std::sync::atomic::AtomicUsize;
1077
1078        let ack_count = Arc::new(AtomicUsize::new(0));
1079        let nack_count = Arc::new(AtomicUsize::new(0));
1080
1081        #[derive(Debug)]
1082        struct CountingAckHandle {
1083            ack_count: Arc<AtomicUsize>,
1084            nack_count: Arc<AtomicUsize>,
1085        }
1086
1087        #[async_trait]
1088        impl AckHandle for CountingAckHandle {
1089            async fn ack(&self) -> WorkerResult<()> {
1090                self.ack_count.fetch_add(1, Ordering::SeqCst);
1091                Ok(())
1092            }
1093
1094            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1095                self.nack_count.fetch_add(1, Ordering::SeqCst);
1096                Ok(())
1097            }
1098        }
1099
1100        // Create a failing worker
1101        struct FailingWorker;
1102
1103        #[async_trait]
1104        impl Worker for FailingWorker {
1105            fn id(&self) -> &str {
1106                "failing-worker"
1107            }
1108
1109            async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
1110                Err(WorkerError::ProcessingFailed("Simulated failure".to_string()))
1111            }
1112        }
1113
1114        let mut pool = WorkerPool::new(
1115            "test-pool",
1116            LoadBalancingStrategy::RoundRobin,
1117            Arc::new(NoOpMetrics),
1118        );
1119        // NO middleware - pool should handle nack
1120        pool.add_worker(Arc::new(FailingWorker));
1121
1122        let message = Message {
1123            id: "test-msg".to_string(),
1124            payload: serde_json::json!({"test": "data"}),
1125            metadata: MessageMetadata::new("test-queue"),
1126        };
1127        let received = ReceivedMessage::new(
1128            message,
1129            Arc::new(CountingAckHandle {
1130                ack_count: ack_count.clone(),
1131                nack_count: nack_count.clone(),
1132            }),
1133        );
1134
1135        pool.dispatch(received).await.unwrap();
1136
1137        // Wait for processing to complete
1138        tokio::time::sleep(Duration::from_millis(100)).await;
1139
1140        // Verify pool nacked the message
1141        assert_eq!(
1142            nack_count.load(Ordering::SeqCst),
1143            1,
1144            "Pool should have nacked the message when no middleware is used"
1145        );
1146        assert_eq!(
1147            ack_count.load(Ordering::SeqCst),
1148            0,
1149            "Ack should not be called for failed message"
1150        );
1151    }
1152
1153    /// Integration test: Verify end-to-end flow with AckNackMiddleware and failing worker.
1154    #[tokio::test]
1155    async fn test_integration_ack_nack_middleware_with_failure() {
1156        use crate::AckNackMiddleware;
1157        use std::sync::atomic::AtomicUsize;
1158
1159        let ack_count = Arc::new(AtomicUsize::new(0));
1160        let nack_count = Arc::new(AtomicUsize::new(0));
1161
1162        #[derive(Debug)]
1163        struct CountingAckHandle {
1164            ack_count: Arc<AtomicUsize>,
1165            nack_count: Arc<AtomicUsize>,
1166        }
1167
1168        #[async_trait]
1169        impl AckHandle for CountingAckHandle {
1170            async fn ack(&self) -> WorkerResult<()> {
1171                self.ack_count.fetch_add(1, Ordering::SeqCst);
1172                Ok(())
1173            }
1174
1175            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1176                self.nack_count.fetch_add(1, Ordering::SeqCst);
1177                Ok(())
1178            }
1179        }
1180
1181        // Create a failing worker
1182        struct FailingWorker;
1183
1184        #[async_trait]
1185        impl Worker for FailingWorker {
1186            fn id(&self) -> &str {
1187                "failing-worker"
1188            }
1189
1190            async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
1191                Err(WorkerError::ProcessingFailed("Simulated failure".to_string()))
1192            }
1193        }
1194
1195        let mut pool = WorkerPool::new(
1196            "test-pool",
1197            LoadBalancingStrategy::RoundRobin,
1198            Arc::new(NoOpMetrics),
1199        );
1200
1201        // Add AckNackMiddleware to auto-nack on failure
1202        pool.middlewares.push(Arc::new(AckNackMiddleware::default()));
1203        pool.add_worker(Arc::new(FailingWorker));
1204
1205        let message = Message {
1206            id: "test-msg".to_string(),
1207            payload: serde_json::json!({"test": "data"}),
1208            metadata: MessageMetadata::new("test-queue"),
1209        };
1210        let received = ReceivedMessage::new(
1211            message,
1212            Arc::new(CountingAckHandle {
1213                ack_count: ack_count.clone(),
1214                nack_count: nack_count.clone(),
1215            }),
1216        );
1217
1218        pool.dispatch(received).await.unwrap();
1219
1220        // Wait for processing
1221        tokio::time::sleep(Duration::from_millis(100)).await;
1222
1223        // Verify nack was called exactly once (by middleware)
1224        assert_eq!(
1225            nack_count.load(Ordering::SeqCst),
1226            1,
1227            "Nack should be called exactly once by middleware"
1228        );
1229        assert_eq!(
1230            ack_count.load(Ordering::SeqCst),
1231            0,
1232            "Ack should not be called for failed message"
1233        );
1234    }
1235}