Skip to main content

foxtive_worker/
pool.rs

1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::{AckHandle, ReceivedMessage};
11use crate::metrics::WorkerMetrics;
12use crate::middleware::{MessageHandler, Middleware, MiddlewareChain, MiddlewareResult};
13use crate::strategies::{
14    LeastLoadedBalancer, LoadBalancingStrategy, RandomBalancer, RoundRobinBalancer,
15};
16use crate::worker::Worker;
17
18/// A pool of workers with load balancing for message distribution.
19///
20/// The pool manages multiple worker instances and distributes incoming messages
21/// based on the configured load balancing strategy.
22///
23/// # Example
24/// ```rust,no_run
25/// use foxtive_worker::pool::WorkerPool;
26/// use foxtive_worker::strategies::LoadBalancingStrategy;
27/// use foxtive_worker::metrics::NoOpMetrics;
28/// use std::sync::Arc;
29///
30/// let pool = WorkerPool::new("my-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
31/// // Add workers...
32/// // Dispatch messages...
33/// ```
34pub struct WorkerPool {
35    name: String,
36    workers: Vec<Arc<dyn Worker>>,
37    strategy: LoadBalancingStrategy,
38    semaphore: Arc<Semaphore>,
39    concurrency_limit: usize,
40    least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
41    round_robin_balancer: Arc<RoundRobinBalancer>,
42    random_balancer: RandomBalancer,
43    /// Middleware list (Arc-wrapped for cloning)
44    middlewares: Vec<Arc<dyn Middleware>>,
45    /// Metrics collector for this pool
46    metrics_collector: Arc<dyn WorkerMetrics>,
47    /// Indicates if the worker pool is currently running.
48    is_running: Arc<AtomicBool>,
49    /// Cancellation token for graceful shutdown of all spawned tasks
50    cancellation_token: CancellationToken,
51    /// Notify for task completion signaling during shutdown
52    task_completion_notify: Arc<Notify>,
53    /// Track number of in-flight tasks for monitoring
54    in_flight_tasks: Arc<AtomicUsize>,
55}
56
57impl std::fmt::Debug for WorkerPool {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("WorkerPool")
60            .field("name", &self.name)
61            .field("worker_count", &self.workers.len())
62            .field("strategy", &self.strategy)
63            .field("is_running", &self.is_running.load(Ordering::SeqCst))
64            .finish()
65    }
66}
67
68impl WorkerPool {
69    /// Create a new worker pool with the given name, load balancing strategy, and metrics collector.
70    pub fn new(
71        name: impl Into<String>,
72        strategy: LoadBalancingStrategy,
73        metrics_collector: Arc<dyn WorkerMetrics>,
74    ) -> Self {
75        Self::with_concurrency(name, strategy, 1000, metrics_collector)
76    }
77
78    /// Create a new worker pool with custom concurrency limit.
79    ///
80    /// # Arguments
81    /// * `name` - Pool name
82    /// * `strategy` - Load balancing strategy
83    /// * `concurrency_limit` - Maximum concurrent messages across all workers
84    /// * `metrics_collector` - Metrics collector implementation
85    pub fn with_concurrency(
86        name: impl Into<String>,
87        strategy: LoadBalancingStrategy,
88        concurrency_limit: usize,
89        metrics_collector: Arc<dyn WorkerMetrics>,
90    ) -> Self {
91        let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
92            Some(Arc::new(LeastLoadedBalancer::new(0)))
93        } else {
94            None
95        };
96
97        Self {
98            name: name.into(),
99            workers: Vec::new(),
100            strategy,
101            semaphore: Arc::new(Semaphore::new(concurrency_limit)),
102            concurrency_limit,
103            least_loaded_balancer,
104            round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
105            random_balancer: RandomBalancer,
106            middlewares: Vec::new(),
107            metrics_collector,
108            is_running: Arc::new(AtomicBool::new(true)),
109            cancellation_token: CancellationToken::new(),
110            task_completion_notify: Arc::new(Notify::new()),
111            in_flight_tasks: Arc::new(AtomicUsize::new(0)),
112        }
113    }
114
115    /// Add a worker to the pool.
116    pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
117        self.workers.push(worker);
118
119        // Update least-loaded balancer if needed - now O(1) with atomic swap
120        if let Some(ref balancer) = self.least_loaded_balancer {
121            balancer.add_worker();
122        }
123        self.metrics_collector
124            .record_active_workers(self.workers.len());
125    }
126
127    /// Add multiple workers to the pool.
128    pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
129        for worker in workers {
130            self.add_worker(worker);
131        }
132    }
133
134    /// Get the number of workers in the pool.
135    pub fn worker_count(&self) -> usize {
136        self.workers.len()
137    }
138
139    /// Set middleware for the pool.
140    ///
141    /// This allows you to inject middleware processing into the message flow.
142    /// Middleware will be executed in order before messages reaches workers.
143    pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
144        self.middlewares = middlewares;
145        self
146    }
147
148    /// Get the pool name.
149    pub fn name(&self) -> &str {
150        &self.name
151    }
152
153    /// Dispatch a message to a worker based on the load balancing strategy.
154    ///
155    /// Spawns an async task to process the message, respecting concurrency limits.
156    /// If middleware is configured, the message flows through the chain before reaching the worker.
157    pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
158        if !self.is_running.load(Ordering::SeqCst) {
159            return Err(WorkerError::Shutdown);
160        }
161        if self.workers.is_empty() {
162            return Err(WorkerError::PoolExhausted);
163        }
164
165        // Pick a worker and track metrics
166        let worker_index = self.select_worker();
167        let worker = self.workers[worker_index].clone();
168        let worker_id = worker.id().to_string();
169        let queue_name = message.message.metadata.source.clone();
170
171        self.metrics_collector
172            .record_message_received(&worker_id, &queue_name);
173        let start_time = Instant::now();
174
175        if let Some(ref balancer) = self.least_loaded_balancer {
176            balancer.increment_load(worker_index);
177        }
178
179        let permit = self
180            .semaphore
181            .clone()
182            .acquire_owned()
183            .await
184            .map_err(|_| WorkerError::Shutdown)?;
185
186        self.metrics_collector
187            .record_in_flight_messages(self.semaphore.available_permits());
188
189        // Build the handler chain - wrap worker with middleware if configured
190        let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
191            let worker_handler = WorkerHandler(worker);
192            let boxed_middlewares: Vec<Box<dyn Middleware>> = self
193                .middlewares
194                .iter()
195                .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
196                .collect();
197
198            let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
199            Arc::new(ArcHandlerWrapper(chain.build()))
200        } else {
201            Arc::new(WorkerHandler(worker))
202        };
203
204        let metrics_collector_clone = self.metrics_collector.clone();
205        let least_loaded_balancer = self.least_loaded_balancer.clone();
206        let cancellation_token = self.cancellation_token.child_token();
207        let task_completion_notify = self.task_completion_notify.clone();
208        let in_flight_tasks = self.in_flight_tasks.clone();
209
210        // Track in-flight task count
211        in_flight_tasks.fetch_add(1, Ordering::SeqCst);
212
213        // Extract ack_handle and message data before moving message into task
214        let ack_handle = message.ack_handle.clone();
215        let message_id = message.message.id.clone();
216        let attempt = message.message.metadata.attempt;
217        // Clone the full message for retry (preserves routing_key and all metadata)
218        let retry_message = message.message.clone();
219
220        tokio::spawn(async move {
221            // Use tokio::select! to handle cancellation
222            let result = tokio::select! {
223                result = handler.handle(message) => result,  // No clone - move message into task
224                _ = cancellation_token.cancelled() => {
225                    tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
226                    // Decrement in-flight counter on cancellation
227                    in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
228                    task_completion_notify.notify_one();
229                    return;
230                }
231            };
232
233            match result {
234                Ok(middleware_result) => {
235                    match middleware_result {
236                        MiddlewareResult::Acknowledged => {
237                            // Middleware (e.g., AckNackMiddleware) already handled acknowledgment
238                            tracing::debug!(
239                                "Message {} already acknowledged by middleware",
240                                message_id
241                            );
242                            metrics_collector_clone.record_message_processed(
243                                &worker_id,
244                                &queue_name,
245                                start_time,
246                            );
247                            // Decrement in-flight counter
248                            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
249                            task_completion_notify.notify_one();
250                            return;
251                        }
252                        MiddlewareResult::Continue => {
253                            // No middleware handled ack - pool should ack on success
254                            tracing::debug!("Message {} processed successfully", message_id);
255                            metrics_collector_clone.record_message_processed(
256                                &worker_id,
257                                &queue_name,
258                                start_time,
259                            );
260                            if let Err(e) = retry_ack(&ack_handle, &message_id).await {
261                                tracing::error!(
262                                    "Failed to ack message {} after retries: {}. Message may be redelivered.",
263                                    message_id,
264                                    e
265                                );
266                            }
267                        }
268                    }
269                }
270                Err(WorkerError::RetryableFailure { source, delay_ms }) => {
271                    tracing::warn!(
272                        "Message {} failed (will retry in {:?}): {}",
273                        message_id,
274                        delay_ms,
275                        source
276                    );
277                    metrics_collector_clone.record_message_retried(
278                        &worker_id,
279                        &queue_name,
280                        attempt,
281                    );
282                    metrics_collector_clone.record_message_failed(
283                        &worker_id,
284                        &queue_name,
285                        "RetryableFailure",
286                        start_time,
287                    );
288
289                    // Use delayed retry if supported by backend, otherwise nack with requeue
290                    // The retry_with_delay method will handle backend-specific retry logic
291                    // Pass the original message to preserve all metadata including routing_key
292                    if let Err(e) = ack_handle
293                        .retry_with_delay(&retry_message, delay_ms.as_millis() as u64)
294                        .await
295                    {
296                        tracing::error!(
297                            "Failed to schedule retry for message {}: {}",
298                            message_id,
299                            e
300                        );
301                        // Fallback to immediate nack with requeue if retry fails
302                        if let Err(nack_err) = ack_handle.nack(true).await {
303                            tracing::error!(
304                                "Fallback nack also failed for message {}: {}",
305                                message_id,
306                                nack_err
307                            );
308                        }
309                    }
310                }
311                Err(WorkerError::RetriesExhausted { source }) => {
312                    tracing::error!(
313                        "Message {} exhausted all retries, sending to DLQ: {}",
314                        message_id,
315                        source
316                    );
317                    metrics_collector_clone
318                        .record_message_retries_exhausted(&worker_id, &queue_name);
319                    metrics_collector_clone.record_message_failed(
320                        &worker_id,
321                        &queue_name,
322                        "RetriesExhausted",
323                        start_time,
324                    );
325                    // Send to DLQ using the ack handle's send_to_dlq method
326                    if let Err(e) = ack_handle
327                        .send_to_dlq(&retry_message, &source.to_string())
328                        .await
329                    {
330                        tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
331                        // Fallback: nack without requeue (message will be discarded)
332                        if let Err(nack_err) = ack_handle.nack(false).await {
333                            tracing::error!(
334                                "Fallback nack also failed for message {}: {}",
335                                message_id,
336                                nack_err
337                            );
338                        }
339                    }
340                }
341                Err(e) => {
342                    let error_type = format!("{:?}", e);
343                    tracing::error!("Message {} failed: {}", message_id, e);
344                    metrics_collector_clone.record_message_failed(
345                        &worker_id,
346                        &queue_name,
347                        &error_type,
348                        start_time,
349                    );
350                    if let Err(nack_err) = ack_handle.nack(false).await {
351                        tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
352                    }
353                }
354            }
355
356            // Decrement load for least-loaded balancer after processing completes
357            if let Some(ref balancer) = least_loaded_balancer {
358                balancer.decrement_load(worker_index);
359            }
360
361            drop(permit);
362
363            // Signal task completion and decrement counter
364            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
365            task_completion_notify.notify_one();
366        });
367
368        Ok(())
369    }
370
371    /// Select a worker based on the configured load balancing strategy.
372    fn select_worker(&self) -> usize {
373        match self.strategy {
374            LoadBalancingStrategy::RoundRobin => self.round_robin_balancer.next(self.workers.len()),
375            LoadBalancingStrategy::Random => self.random_balancer.next(self.workers.len()),
376            LoadBalancingStrategy::LeastLoaded => {
377                if let Some(ref balancer) = self.least_loaded_balancer {
378                    balancer.next()
379                } else {
380                    0 // Fallback
381                }
382            }
383        }
384    }
385
386    /// Shutdown the pool gracefully.
387    ///
388    /// This prevents new messages from being dispatched, cancels in-flight tasks,
389    /// and waits for completion with a timeout using efficient notification.
390    pub async fn shutdown(&self) -> WorkerResult<()> {
391        tracing::info!("Shutting down worker pool: {}", self.name);
392
393        self.is_running.store(false, Ordering::SeqCst);
394        self.metrics_collector.record_active_workers(0);
395
396        // Cancel all in-flight tasks
397        self.cancellation_token.cancel();
398        tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
399
400        // Close the semaphore to prevent new acquisitions
401        self.semaphore.close();
402
403        // Wait for permits to be returned with efficient notification
404        let shutdown_timeout = Duration::from_secs(30); // 30 second timeout
405        let start = Instant::now();
406
407        loop {
408            let available = self.semaphore.available_permits();
409            let in_flight = self.concurrency_limit.saturating_sub(available);
410
411            if in_flight == 0 {
412                break; // All tasks completed
413            }
414
415            if start.elapsed() >= shutdown_timeout {
416                tracing::warn!(
417                    "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
418                    self.name,
419                    in_flight
420                );
421                break;
422            }
423
424            // Wait efficiently for task completion notification instead of busy-wait
425            tokio::select! {
426                _ = self.task_completion_notify.notified() => {
427                    // A task completed, check again
428                    continue;
429                }
430                _ = tokio::time::sleep(Duration::from_millis(100)) => {
431                    // Periodic check in case notifications are missed
432                    continue;
433                }
434            }
435        }
436
437        self.metrics_collector.record_in_flight_messages(0);
438        tracing::info!("Worker pool {} shutdown complete", self.name);
439        Ok(())
440    }
441
442    /// Get the current number of in-flight tasks.
443    pub fn in_flight_count(&self) -> usize {
444        self.in_flight_tasks.load(Ordering::SeqCst)
445    }
446}
447
448impl HealthCheck for WorkerPool {
449    fn check_health(&self) -> HealthStatus {
450        let is_running = self.is_running.load(Ordering::SeqCst);
451        let worker_count = self.worker_count();
452        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
453
454        // Check if pool is running
455        if !is_running {
456            return HealthStatus::Unhealthy {
457                reason: "Pool is not running".to_string(),
458            };
459        }
460
461        // Check if workers are available
462        if worker_count == 0 {
463            return HealthStatus::Degraded {
464                reason: "No workers available".to_string(),
465            };
466        }
467
468        // Check if pool is saturated (90%+ capacity)
469        let saturation = in_flight as f64 / self.concurrency_limit as f64;
470        if saturation > 0.9 {
471            return HealthStatus::Degraded {
472                reason: format!(
473                    "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
474                    in_flight,
475                    saturation * 100.0
476                ),
477            };
478        }
479
480        HealthStatus::Healthy
481    }
482
483    fn status_message(&self) -> String {
484        let worker_count = self.worker_count();
485        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
486        let available_permits = self.semaphore.available_permits();
487
488        match self.check_health() {
489            HealthStatus::Healthy => {
490                format!(
491                    "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
492                    self.name, worker_count, in_flight, available_permits
493                )
494            }
495            HealthStatus::Degraded { ref reason } => {
496                format!(
497                    "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
498                    self.name, reason, worker_count, in_flight
499                )
500            }
501            HealthStatus::Unhealthy { ref reason } => {
502                format!(
503                    "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
504                    self.name, reason, worker_count, in_flight
505                )
506            }
507        }
508    }
509}
510
511/// Wrapper that converts a Worker into a MessageHandler.
512struct WorkerHandler(Arc<dyn Worker>);
513
514#[async_trait::async_trait]
515impl MessageHandler for WorkerHandler {
516    async fn handle(
517        &self,
518        message: ReceivedMessage<serde_json::Value>,
519    ) -> Result<MiddlewareResult, WorkerError> {
520        // Workers always return Continue - they don't handle acknowledgment directly
521        self.0.process(message).await?;
522        Ok(MiddlewareResult::Continue)
523    }
524}
525
526/// Wrapper to convert Arc<dyn Middleware> to Box<dyn Middleware>
527struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
528
529#[async_trait::async_trait]
530impl Middleware for ArcMiddlewareWrapper {
531    fn name(&self) -> &str {
532        self.0.name()
533    }
534
535    async fn handle(
536        &self,
537        message: ReceivedMessage<serde_json::Value>,
538        next: Box<dyn MessageHandler>,
539    ) -> Result<MiddlewareResult, WorkerError> {
540        self.0.handle(message, next).await
541    }
542}
543
544/// Wrapper to convert Box<dyn MessageHandler> to Arc
545struct ArcHandlerWrapper(Box<dyn MessageHandler>);
546
547#[async_trait::async_trait]
548impl MessageHandler for ArcHandlerWrapper {
549    async fn handle(
550        &self,
551        message: ReceivedMessage<serde_json::Value>,
552    ) -> Result<MiddlewareResult, WorkerError> {
553        self.0.handle(message).await
554    }
555}
556
557/// Retry ack with exponential backoff on failure.
558///
559/// This helps handle transient network issues or broker unavailability.
560async fn retry_ack(
561    ack_handle: &Arc<dyn AckHandle>,
562    message_id: &str,
563) -> WorkerResult<()> {
564    let max_retries = 3;
565    let base_delay_ms = 100;
566
567    for attempt in 0..max_retries {
568        match ack_handle.ack().await {
569            Ok(_) => return Ok(()),
570            Err(e) => {
571                if attempt < max_retries - 1 {
572                    let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
573                    tracing::warn!(
574                        "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
575                        attempt + 1,
576                        message_id,
577                        e,
578                        delay
579                    );
580                    sleep(delay).await;
581                } else {
582                    return Err(e);
583                }
584            }
585        }
586    }
587
588    unreachable!()
589}
590
591#[cfg(test)]
592mod tests {
593    use super::*;
594    use crate::message::{AckHandle, Message, MessageMetadata};
595    use crate::metrics::NoOpMetrics;
596    use async_trait::async_trait;
597    use std::sync::atomic::{AtomicUsize, Ordering};
598    use std::time::Duration;
599    // Use NoOpMetrics for tests
600
601    #[derive(Debug)]
602    struct MockAckHandle;
603
604    #[async_trait]
605    impl AckHandle for MockAckHandle {
606        async fn ack(&self) -> WorkerResult<()> {
607            Ok(())
608        }
609
610        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
611            Ok(())
612        }
613    }
614
615    struct TestWorker {
616        id: String,
617        process_count: Arc<AtomicUsize>,
618    }
619
620    impl TestWorker {
621        fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
622            let count = Arc::new(AtomicUsize::new(0));
623            (
624                Self {
625                    id: id.to_string(),
626                    process_count: count.clone(),
627                },
628                count,
629            )
630        }
631    }
632
633    #[async_trait]
634    impl Worker for TestWorker {
635        fn id(&self) -> &str {
636            &self.id
637        }
638
639        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
640            self.process_count.fetch_add(1, Ordering::SeqCst);
641            Ok(())
642        }
643    }
644
645    fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
646        let message = Message {
647            id: id.to_string(),
648            payload: serde_json::json!({"test": "data"}),
649            metadata: MessageMetadata::new("test-queue"),
650        };
651        ReceivedMessage::new(message, Arc::new(MockAckHandle))
652    }
653
654    #[tokio::test]
655    async fn test_pool_creation() {
656        let pool = WorkerPool::new(
657            "test-pool",
658            LoadBalancingStrategy::RoundRobin,
659            Arc::new(NoOpMetrics),
660        );
661        assert_eq!(pool.name(), "test-pool");
662        assert_eq!(pool.worker_count(), 0);
663        assert!(pool.is_running.load(Ordering::SeqCst));
664    }
665
666    #[tokio::test]
667    async fn test_add_worker() {
668        let mut pool = WorkerPool::new(
669            "test-pool",
670            LoadBalancingStrategy::RoundRobin,
671            Arc::new(NoOpMetrics),
672        );
673        let (worker, _) = TestWorker::new("worker-1");
674        pool.add_worker(Arc::new(worker));
675
676        assert_eq!(pool.worker_count(), 1);
677    }
678
679    #[tokio::test]
680    async fn test_dispatch_empty_pool() {
681        let pool = WorkerPool::new(
682            "test-pool",
683            LoadBalancingStrategy::RoundRobin,
684            Arc::new(NoOpMetrics),
685        );
686        let message = create_test_message("msg-1");
687
688        let result = pool.dispatch(message).await;
689        assert!(matches!(result, Err(WorkerError::PoolExhausted)));
690    }
691
692    #[tokio::test]
693    async fn test_round_robin_distribution() {
694        let mut pool = WorkerPool::new(
695            "test-pool",
696            LoadBalancingStrategy::RoundRobin,
697            Arc::new(NoOpMetrics),
698        );
699
700        let (worker1, count1) = TestWorker::new("worker-1");
701        let (worker2, count2) = TestWorker::new("worker-2");
702
703        pool.add_worker(Arc::new(worker1));
704        pool.add_worker(Arc::new(worker2));
705
706        // Dispatch 4 messages
707        for i in 0..4 {
708            let message = create_test_message(&format!("msg-{}", i));
709            pool.dispatch(message).await.unwrap();
710        }
711
712        // Give tasks time to complete
713        tokio::time::sleep(Duration::from_millis(100)).await;
714
715        // Each worker should have processed 2 messages
716        assert_eq!(count1.load(Ordering::SeqCst), 2);
717        assert_eq!(count2.load(Ordering::SeqCst), 2);
718    }
719
720    #[tokio::test]
721    async fn test_pool_health() {
722        let pool = WorkerPool::new(
723            "test-pool",
724            LoadBalancingStrategy::RoundRobin,
725            Arc::new(NoOpMetrics),
726        );
727        assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. })); // Degraded because 0 workers
728
729        let mut pool = pool;
730        let (worker, _) = TestWorker::new("worker-1");
731        pool.add_worker(Arc::new(worker));
732
733        assert!(matches!(pool.check_health(), HealthStatus::Healthy));
734    }
735
736    #[tokio::test]
737    async fn test_concurrency_limit_enforcement() {
738        use std::sync::atomic::{AtomicUsize, Ordering};
739
740        // Create a worker that tracks concurrent executions
741        let concurrent_count = Arc::new(AtomicUsize::new(0));
742        let max_concurrent = Arc::new(AtomicUsize::new(0));
743
744        struct ConcurrentTestWorker {
745            id: String,
746            concurrent: Arc<AtomicUsize>,
747            max_concurrent: Arc<AtomicUsize>,
748        }
749
750        #[async_trait::async_trait]
751        impl Worker for ConcurrentTestWorker {
752            fn id(&self) -> &str {
753                &self.id
754            }
755
756            async fn process(
757                &self,
758                _message: ReceivedMessage<serde_json::Value>,
759            ) -> WorkerResult<()> {
760                // Increment concurrent count
761                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
762
763                // Track maximum
764                let mut max = self.max_concurrent.load(Ordering::SeqCst);
765                while current > max {
766                    match self.max_concurrent.compare_exchange_weak(
767                        max,
768                        current,
769                        Ordering::SeqCst,
770                        Ordering::SeqCst,
771                    ) {
772                        Ok(_) => break,
773                        Err(new_max) => max = new_max,
774                    }
775                }
776
777                // Simulate processing time
778                tokio::time::sleep(Duration::from_millis(50)).await;
779
780                // Decrement concurrent count
781                self.concurrent.fetch_sub(1, Ordering::SeqCst);
782                Ok(())
783            }
784        }
785
786        // Create pool with concurrency limit of 3
787        let mut pool = WorkerPool::with_concurrency(
788            "test-pool",
789            LoadBalancingStrategy::RoundRobin,
790            3, // Limit to 3 concurrent
791            Arc::new(NoOpMetrics),
792        );
793
794        // Add 1 worker (will handle all messages)
795        let worker = ConcurrentTestWorker {
796            id: "worker-1".to_string(),
797            concurrent: concurrent_count.clone(),
798            max_concurrent: max_concurrent.clone(),
799        };
800        pool.add_worker(Arc::new(worker));
801
802        // Dispatch 10 messages rapidly
803        for i in 0..10 {
804            let message = create_test_message(&format!("msg-{}", i));
805            pool.dispatch(message).await.unwrap();
806        }
807
808        // Wait for all to complete
809        tokio::time::sleep(Duration::from_millis(500)).await;
810
811        // Verify that concurrency never exceeded the limit
812        let actual_max = max_concurrent.load(Ordering::SeqCst);
813        assert!(
814            actual_max <= 3,
815            "Expected max concurrency <= 3, but got {}",
816            actual_max
817        );
818        assert!(
819            actual_max >= 2,
820            "Expected some concurrency (>= 2), but got {}",
821            actual_max
822        );
823    }
824
825    #[tokio::test]
826    async fn test_concurrency_limit_with_builder() {
827        use crate::builder::WorkerPoolBuilder;
828
829        let concurrent_count = Arc::new(AtomicUsize::new(0));
830        let max_concurrent = Arc::new(AtomicUsize::new(0));
831
832        struct TrackedWorker {
833            id: String,
834            concurrent: Arc<AtomicUsize>,
835            max_concurrent: Arc<AtomicUsize>,
836        }
837
838        #[async_trait::async_trait]
839        impl Worker for TrackedWorker {
840            fn id(&self) -> &str {
841                &self.id
842            }
843
844            async fn process(
845                &self,
846                _message: ReceivedMessage<serde_json::Value>,
847            ) -> WorkerResult<()> {
848                // Track concurrency
849                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
850
851                let mut max = self.max_concurrent.load(Ordering::SeqCst);
852                while current > max {
853                    match self.max_concurrent.compare_exchange_weak(
854                        max,
855                        current,
856                        Ordering::SeqCst,
857                        Ordering::SeqCst,
858                    ) {
859                        Ok(_) => break,
860                        Err(new_max) => max = new_max,
861                    }
862                }
863
864                tokio::time::sleep(Duration::from_millis(100)).await;
865
866                self.concurrent.fetch_sub(1, Ordering::SeqCst);
867                Ok(())
868            }
869        }
870
871        // Build pool with concurrency limit of 2
872        let pool = WorkerPoolBuilder::new("test-pool")
873            .with_concurrency_limit(2)
874            .add_worker(TrackedWorker {
875                id: "worker-1".to_string(),
876                concurrent: concurrent_count.clone(),
877                max_concurrent: max_concurrent.clone(),
878            })
879            .build()
880            .unwrap();
881
882        // Dispatch 6 messages rapidly
883        for i in 0..6 {
884            let message = create_test_message(&format!("msg-{}", i));
885            pool.dispatch(message).await.unwrap();
886        }
887
888        // Wait for all to complete
889        tokio::time::sleep(Duration::from_millis(800)).await;
890
891        // Verify that concurrency never exceeded the limit of 2
892        let actual_max = max_concurrent.load(Ordering::SeqCst);
893        assert!(
894            actual_max <= 2,
895            "Expected max concurrency <= 2, but got {}",
896            actual_max
897        );
898        assert!(
899            actual_max >= 1,
900            "Expected some concurrency (>= 1), but got {}",
901            actual_max
902        );
903    }
904
905    #[tokio::test]
906    async fn test_different_concurrency_limits() {
907        // Test that different pools can have different limits
908        let pool1 = WorkerPool::with_concurrency(
909            "pool1",
910            LoadBalancingStrategy::RoundRobin,
911            5,
912            Arc::new(NoOpMetrics),
913        );
914        let pool2 = WorkerPool::with_concurrency(
915            "pool2",
916            LoadBalancingStrategy::RoundRobin,
917            20,
918            Arc::new(NoOpMetrics),
919        );
920
921        // Verify they have different semaphore capacities
922        // We can't directly check semaphore capacity, but we can verify the pools work independently
923        assert_eq!(pool1.name(), "pool1");
924        assert_eq!(pool2.name(), "pool2");
925    }
926
927    #[tokio::test]
928    async fn test_pool_shutdown_prevents_dispatch() {
929        let mut pool = WorkerPool::new(
930            "test-pool",
931            LoadBalancingStrategy::RoundRobin,
932            Arc::new(NoOpMetrics),
933        );
934        let (worker, _) = TestWorker::new("worker-1");
935        pool.add_worker(Arc::new(worker));
936
937        pool.shutdown().await.unwrap();
938
939        let message = create_test_message("msg-after-shutdown");
940        let result = pool.dispatch(message).await;
941        assert!(matches!(result, Err(WorkerError::Shutdown)));
942        assert!(matches!(
943            pool.check_health(),
944            HealthStatus::Unhealthy { .. }
945        ));
946    }
947
948    /// Test that messages are not double-acknowledged when using AckNackMiddleware.
949    /// This verifies the fix for the "PRECONDITION_FAILED - unknown delivery tag" issue.
950    #[tokio::test]
951    async fn test_no_double_ack_with_middleware() {
952        use crate::AckNackMiddleware;
953        use std::sync::atomic::AtomicUsize;
954
955        // Track if ack was called
956        let ack_count = Arc::new(AtomicUsize::new(0));
957        let nack_count = Arc::new(AtomicUsize::new(0));
958
959        #[derive(Debug)]
960        struct TrackingAckHandle {
961            ack_count: Arc<AtomicUsize>,
962            nack_count: Arc<AtomicUsize>,
963        }
964
965        #[async_trait]
966        impl AckHandle for TrackingAckHandle {
967            async fn ack(&self) -> WorkerResult<()> {
968                self.ack_count.fetch_add(1, Ordering::SeqCst);
969                Ok(())
970            }
971
972            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
973                self.nack_count.fetch_add(1, Ordering::SeqCst);
974                Ok(())
975            }
976        }
977
978        // Create a successful worker
979        struct SuccessWorker;
980
981        #[async_trait]
982        impl Worker for SuccessWorker {
983            fn id(&self) -> &str {
984                "success-worker"
985            }
986
987            async fn process(
988                &self,
989                _message: ReceivedMessage<serde_json::Value>,
990            ) -> WorkerResult<()> {
991                Ok(())
992            }
993        }
994
995        let mut pool = WorkerPool::new(
996            "test-pool",
997            LoadBalancingStrategy::RoundRobin,
998            Arc::new(NoOpMetrics),
999        );
1000
1001        // Add AckNackMiddleware to auto-ack on success
1002        pool.middlewares
1003            .push(Arc::new(AckNackMiddleware::default()));
1004        pool.add_worker(Arc::new(SuccessWorker));
1005
1006        // Create message with tracking ack handle
1007        let message = Message {
1008            id: "test-msg".to_string(),
1009            payload: serde_json::json!({"test": "data"}),
1010            metadata: MessageMetadata::new("test-queue"),
1011        };
1012        let received = ReceivedMessage::new(
1013            message,
1014            Arc::new(TrackingAckHandle {
1015                ack_count: ack_count.clone(),
1016                nack_count: nack_count.clone(),
1017            }),
1018        );
1019
1020        // Dispatch the message
1021        pool.dispatch(received).await.unwrap();
1022
1023        // Wait for processing to complete
1024        tokio::time::sleep(Duration::from_millis(100)).await;
1025
1026        // Verify ack was called exactly once (by middleware, not by pool)
1027        assert_eq!(
1028            ack_count.load(Ordering::SeqCst),
1029            1,
1030            "Message should have been acked exactly once by middleware"
1031        );
1032        assert_eq!(
1033            nack_count.load(Ordering::SeqCst),
1034            0,
1035            "Message should not have been nacked"
1036        );
1037    }
1038
1039    /// Test that pool handles ack/nack correctly WITHOUT AckNackMiddleware.
1040    #[tokio::test]
1041    async fn test_pool_handles_ack_without_middleware() {
1042        use std::sync::atomic::AtomicUsize;
1043
1044        let ack_count = Arc::new(AtomicUsize::new(0));
1045        let nack_count = Arc::new(AtomicUsize::new(0));
1046
1047        #[derive(Debug)]
1048        struct CountingAckHandle {
1049            ack_count: Arc<AtomicUsize>,
1050            nack_count: Arc<AtomicUsize>,
1051        }
1052
1053        #[async_trait]
1054        impl AckHandle for CountingAckHandle {
1055            async fn ack(&self) -> WorkerResult<()> {
1056                self.ack_count.fetch_add(1, Ordering::SeqCst);
1057                Ok(())
1058            }
1059
1060            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1061                self.nack_count.fetch_add(1, Ordering::SeqCst);
1062                Ok(())
1063            }
1064        }
1065
1066        // Create a successful worker
1067        struct SuccessWorker;
1068
1069        #[async_trait]
1070        impl Worker for SuccessWorker {
1071            fn id(&self) -> &str {
1072                "success-worker"
1073            }
1074
1075            async fn process(
1076                &self,
1077                _message: ReceivedMessage<serde_json::Value>,
1078            ) -> WorkerResult<()> {
1079                Ok(())
1080            }
1081        }
1082
1083        let mut pool = WorkerPool::new(
1084            "test-pool",
1085            LoadBalancingStrategy::RoundRobin,
1086            Arc::new(NoOpMetrics),
1087        );
1088        // NO middleware - pool should handle ack
1089        pool.add_worker(Arc::new(SuccessWorker));
1090
1091        let message = Message {
1092            id: "test-msg".to_string(),
1093            payload: serde_json::json!({"test": "data"}),
1094            metadata: MessageMetadata::new("test-queue"),
1095        };
1096        let received = ReceivedMessage::new(
1097            message,
1098            Arc::new(CountingAckHandle {
1099                ack_count: ack_count.clone(),
1100                nack_count: nack_count.clone(),
1101            }),
1102        );
1103
1104        pool.dispatch(received).await.unwrap();
1105
1106        // Wait for processing to complete
1107        tokio::time::sleep(Duration::from_millis(100)).await;
1108
1109        // Verify pool acked the message
1110        assert_eq!(
1111            ack_count.load(Ordering::SeqCst),
1112            1,
1113            "Pool should have acked the message when no middleware is used"
1114        );
1115        assert_eq!(
1116            nack_count.load(Ordering::SeqCst),
1117            0,
1118            "Nack should not be called for successful message"
1119        );
1120    }
1121
1122    /// Test that pool handles nack correctly WITHOUT AckNackMiddleware.
1123    #[tokio::test]
1124    async fn test_pool_handles_nack_without_middleware() {
1125        use std::sync::atomic::AtomicUsize;
1126
1127        let ack_count = Arc::new(AtomicUsize::new(0));
1128        let nack_count = Arc::new(AtomicUsize::new(0));
1129
1130        #[derive(Debug)]
1131        struct CountingAckHandle {
1132            ack_count: Arc<AtomicUsize>,
1133            nack_count: Arc<AtomicUsize>,
1134        }
1135
1136        #[async_trait]
1137        impl AckHandle for CountingAckHandle {
1138            async fn ack(&self) -> WorkerResult<()> {
1139                self.ack_count.fetch_add(1, Ordering::SeqCst);
1140                Ok(())
1141            }
1142
1143            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1144                self.nack_count.fetch_add(1, Ordering::SeqCst);
1145                Ok(())
1146            }
1147        }
1148
1149        // Create a failing worker
1150        struct FailingWorker;
1151
1152        #[async_trait]
1153        impl Worker for FailingWorker {
1154            fn id(&self) -> &str {
1155                "failing-worker"
1156            }
1157
1158            async fn process(
1159                &self,
1160                _message: ReceivedMessage<serde_json::Value>,
1161            ) -> WorkerResult<()> {
1162                Err(WorkerError::ProcessingFailed(
1163                    "Simulated failure".to_string(),
1164                ))
1165            }
1166        }
1167
1168        let mut pool = WorkerPool::new(
1169            "test-pool",
1170            LoadBalancingStrategy::RoundRobin,
1171            Arc::new(NoOpMetrics),
1172        );
1173        // NO middleware - pool should handle nack
1174        pool.add_worker(Arc::new(FailingWorker));
1175
1176        let message = Message {
1177            id: "test-msg".to_string(),
1178            payload: serde_json::json!({"test": "data"}),
1179            metadata: MessageMetadata::new("test-queue"),
1180        };
1181        let received = ReceivedMessage::new(
1182            message,
1183            Arc::new(CountingAckHandle {
1184                ack_count: ack_count.clone(),
1185                nack_count: nack_count.clone(),
1186            }),
1187        );
1188
1189        pool.dispatch(received).await.unwrap();
1190
1191        // Wait for processing to complete
1192        tokio::time::sleep(Duration::from_millis(100)).await;
1193
1194        // Verify pool nacked the message
1195        assert_eq!(
1196            nack_count.load(Ordering::SeqCst),
1197            1,
1198            "Pool should have nacked the message when no middleware is used"
1199        );
1200        assert_eq!(
1201            ack_count.load(Ordering::SeqCst),
1202            0,
1203            "Ack should not be called for failed message"
1204        );
1205    }
1206
1207    /// Integration test: Verify end-to-end flow with AckNackMiddleware and failing worker.
1208    #[tokio::test]
1209    async fn test_integration_ack_nack_middleware_with_failure() {
1210        use crate::AckNackMiddleware;
1211        use std::sync::atomic::AtomicUsize;
1212
1213        let ack_count = Arc::new(AtomicUsize::new(0));
1214        let nack_count = Arc::new(AtomicUsize::new(0));
1215
1216        #[derive(Debug)]
1217        struct CountingAckHandle {
1218            ack_count: Arc<AtomicUsize>,
1219            nack_count: Arc<AtomicUsize>,
1220        }
1221
1222        #[async_trait]
1223        impl AckHandle for CountingAckHandle {
1224            async fn ack(&self) -> WorkerResult<()> {
1225                self.ack_count.fetch_add(1, Ordering::SeqCst);
1226                Ok(())
1227            }
1228
1229            async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1230                self.nack_count.fetch_add(1, Ordering::SeqCst);
1231                Ok(())
1232            }
1233        }
1234
1235        // Create a failing worker
1236        struct FailingWorker;
1237
1238        #[async_trait]
1239        impl Worker for FailingWorker {
1240            fn id(&self) -> &str {
1241                "failing-worker"
1242            }
1243
1244            async fn process(
1245                &self,
1246                _message: ReceivedMessage<serde_json::Value>,
1247            ) -> WorkerResult<()> {
1248                Err(WorkerError::ProcessingFailed(
1249                    "Simulated failure".to_string(),
1250                ))
1251            }
1252        }
1253
1254        let mut pool = WorkerPool::new(
1255            "test-pool",
1256            LoadBalancingStrategy::RoundRobin,
1257            Arc::new(NoOpMetrics),
1258        );
1259
1260        // Add AckNackMiddleware to auto-nack on failure
1261        pool.middlewares
1262            .push(Arc::new(AckNackMiddleware::default()));
1263        pool.add_worker(Arc::new(FailingWorker));
1264
1265        let message = Message {
1266            id: "test-msg".to_string(),
1267            payload: serde_json::json!({"test": "data"}),
1268            metadata: MessageMetadata::new("test-queue"),
1269        };
1270        let received = ReceivedMessage::new(
1271            message,
1272            Arc::new(CountingAckHandle {
1273                ack_count: ack_count.clone(),
1274                nack_count: nack_count.clone(),
1275            }),
1276        );
1277
1278        pool.dispatch(received).await.unwrap();
1279
1280        // Wait for processing
1281        tokio::time::sleep(Duration::from_millis(100)).await;
1282
1283        // Verify nack was called exactly once (by middleware)
1284        assert_eq!(
1285            nack_count.load(Ordering::SeqCst),
1286            1,
1287            "Nack should be called exactly once by middleware"
1288        );
1289        assert_eq!(
1290            ack_count.load(Ordering::SeqCst),
1291            0,
1292            "Ack should not be called for failed message"
1293        );
1294    }
1295}