Skip to main content

foxtive_worker/
pool.rs

1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::ReceivedMessage;
11use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
12use crate::metrics::WorkerMetrics;
13use crate::strategies::{LoadBalancingStrategy, LeastLoadedBalancer, RandomBalancer, RoundRobinBalancer};
14use crate::worker::Worker;
15
/// A pool of workers with load balancing for message distribution.
///
/// The pool manages multiple worker instances and distributes incoming messages
/// based on the configured load balancing strategy.
///
/// # Example
/// ```rust,no_run
/// use foxtive_worker::pool::WorkerPool;
/// use foxtive_worker::strategies::LoadBalancingStrategy;
/// use foxtive_worker::metrics::NoOpMetrics;
/// use std::sync::Arc;
///
/// let pool = WorkerPool::new("my-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
/// // Add workers...
/// // Dispatch messages...
/// ```
pub struct WorkerPool {
    /// Pool name, used in log lines and health status messages.
    name: String,
    /// Registered workers; the index chosen by the balancer selects into this list.
    workers: Vec<Arc<dyn Worker>>,
    /// Strategy consulted on every dispatch to pick a worker index.
    strategy: LoadBalancingStrategy,
    /// One permit is acquired per dispatched message and held until its task finishes.
    semaphore: Arc<Semaphore>,
    /// Number of permits the semaphore was created with.
    concurrency_limit: usize,
    /// Per-worker load tracker; `Some` only when the strategy is `LeastLoaded`.
    least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
    /// Balancer used when the strategy is `RoundRobin`.
    round_robin_balancer: Arc<RoundRobinBalancer>,
    /// Balancer used when the strategy is `Random`.
    random_balancer: RandomBalancer,
    /// Middleware list (Arc-wrapped for cloning)
    middlewares: Vec<Arc<dyn Middleware>>,
    /// Metrics collector for this pool
    metrics_collector: Arc<dyn WorkerMetrics>,
    /// Indicates if the worker pool is currently running.
    is_running: Arc<AtomicBool>,
    /// Cancellation token for graceful shutdown of all spawned tasks
    cancellation_token: CancellationToken,
    /// Notify for task completion signaling during shutdown
    task_completion_notify: Arc<Notify>,
    /// Track number of in-flight tasks for monitoring
    in_flight_tasks: Arc<AtomicUsize>,
}
54
55impl std::fmt::Debug for WorkerPool {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("WorkerPool")
58            .field("name", &self.name)
59            .field("worker_count", &self.workers.len())
60            .field("strategy", &self.strategy)
61            .field("is_running", &self.is_running.load(Ordering::SeqCst))
62            .finish()
63    }
64}
65
66impl WorkerPool {
67    /// Create a new worker pool with the given name, load balancing strategy, and metrics collector.
68    pub fn new(
69        name: impl Into<String>,
70        strategy: LoadBalancingStrategy,
71        metrics_collector: Arc<dyn WorkerMetrics>,
72    ) -> Self {
73        Self::with_concurrency(name, strategy, 1000, metrics_collector)
74    }
75
76    /// Create a new worker pool with custom concurrency limit.
77    ///
78    /// # Arguments
79    /// * `name` - Pool name
80    /// * `strategy` - Load balancing strategy
81    /// * `concurrency_limit` - Maximum concurrent messages across all workers
82    /// * `metrics_collector` - Metrics collector implementation
83    pub fn with_concurrency(
84        name: impl Into<String>,
85        strategy: LoadBalancingStrategy,
86        concurrency_limit: usize,
87        metrics_collector: Arc<dyn WorkerMetrics>,
88    ) -> Self {
89        let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
90            Some(Arc::new(LeastLoadedBalancer::new(0)))
91        } else {
92            None
93        };
94
95        Self {
96            name: name.into(),
97            workers: Vec::new(),
98            strategy,
99            semaphore: Arc::new(Semaphore::new(concurrency_limit)),
100            concurrency_limit,
101            least_loaded_balancer,
102            round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
103            random_balancer: RandomBalancer,
104            middlewares: Vec::new(),
105            metrics_collector,
106            is_running: Arc::new(AtomicBool::new(true)),
107            cancellation_token: CancellationToken::new(),
108            task_completion_notify: Arc::new(Notify::new()),
109            in_flight_tasks: Arc::new(AtomicUsize::new(0)),
110        }
111    }
112
113    /// Add a worker to the pool.
114    pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
115        self.workers.push(worker);
116
117        // Update least-loaded balancer if needed - now O(1) with atomic swap
118        if let Some(ref balancer) = self.least_loaded_balancer {
119            balancer.add_worker();
120        }
121        self.metrics_collector.record_active_workers(self.workers.len());
122    }
123
124    /// Add multiple workers to the pool.
125    pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
126        for worker in workers {
127            self.add_worker(worker);
128        }
129    }
130
131    /// Get the number of workers in the pool.
132    pub fn worker_count(&self) -> usize {
133        self.workers.len()
134    }
135
136    /// Set middleware for the pool.
137    ///
138    /// This allows you to inject middleware processing into the message flow.
139    /// Middleware will be executed in order before messages reaches workers.
140    pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
141        self.middlewares = middlewares;
142        self
143    }
144
145    /// Get the pool name.
146    pub fn name(&self) -> &str {
147        &self.name
148    }
149
150    /// Dispatch a message to a worker based on the load balancing strategy.
151    ///
152    /// Spawns an async task to process the message, respecting concurrency limits.
153    /// If middleware is configured, the message flows through the chain before reaching the worker.
154    pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
155        if !self.is_running.load(Ordering::SeqCst) {
156            return Err(WorkerError::Shutdown);
157        }
158        if self.workers.is_empty() {
159            return Err(WorkerError::PoolExhausted);
160        }
161
162        // Pick a worker and track metrics
163        let worker_index = self.select_worker();
164        let worker = self.workers[worker_index].clone();
165        let worker_id = worker.id().to_string();
166        let queue_name = message.message.metadata.source.clone();
167
168        self.metrics_collector.record_message_received(&worker_id, &queue_name);
169        let start_time = Instant::now();
170
171        if let Some(ref balancer) = self.least_loaded_balancer {
172            balancer.increment_load(worker_index);
173        }
174
175        let permit = self.semaphore.clone().acquire_owned().await
176            .map_err(|_| WorkerError::Shutdown)?;
177
178        self.metrics_collector.record_in_flight_messages(self.semaphore.available_permits());
179
180        // Build the handler chain - wrap worker with middleware if configured
181        let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
182            let worker_handler = WorkerHandler(worker);
183            let boxed_middlewares: Vec<Box<dyn Middleware>> = self.middlewares.iter()
184                .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
185                .collect();
186
187            let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
188            Arc::new(ArcHandlerWrapper(chain.build()))
189        } else {
190            Arc::new(WorkerHandler(worker))
191        };
192
193        let metrics_collector_clone = self.metrics_collector.clone();
194        let least_loaded_balancer = self.least_loaded_balancer.clone();
195        let cancellation_token = self.cancellation_token.child_token();
196        let task_completion_notify = self.task_completion_notify.clone();
197        let in_flight_tasks = self.in_flight_tasks.clone();
198
199        // Track in-flight task count
200        in_flight_tasks.fetch_add(1, Ordering::SeqCst);
201
202        // Extract ack_handle and metadata before moving message into task
203        let ack_handle = message.ack_handle.clone();
204        let message_id = message.message.id.clone();
205        let attempt = message.message.metadata.attempt;
206
207        tokio::spawn(async move {
208            // Use tokio::select! to handle cancellation
209            let result = tokio::select! {
210                result = handler.handle(message) => result,  // No clone - move message into task
211                _ = cancellation_token.cancelled() => {
212                    tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
213                    // Decrement in-flight counter on cancellation
214                    in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
215                    task_completion_notify.notify_one();
216                    return;
217                }
218            };
219            
220            match result {
221                Ok(_) => {
222                    tracing::debug!("Message {} processed successfully", message_id);
223                    metrics_collector_clone.record_message_processed(&worker_id, &queue_name, start_time);
224                    // Retry ack with exponential backoff on failure
225                    if let Err(e) = retry_ack(&ack_handle, &message_id).await {
226                        tracing::error!("Failed to ack message {} after retries: {}. Message may be redelivered.", message_id, e);
227                        // Consider sending to DLQ or implementing idempotency at application level
228                    }
229                }
230                Err(WorkerError::RetryableFailure { source, delay_ms }) => {
231                    tracing::warn!(
232                        "Message {} failed (will retry in {:?}): {}",
233                        message_id,
234                        delay_ms,
235                        source
236                    );
237                    metrics_collector_clone.record_message_retried(&worker_id, &queue_name, attempt);
238                    metrics_collector_clone.record_message_failed(&worker_id, &queue_name, "RetryableFailure", start_time);
239                    if let Err(e) = ack_handle.nack(true).await {
240                        tracing::error!("Failed to requeue message {}: {}", message_id, e);
241                    }
242                    sleep(delay_ms).await;
243                }
244                Err(WorkerError::RetriesExhausted { source }) => {
245                    tracing::error!(
246                        "Message {} exhausted all retries, sending to DLQ: {}",
247                        message_id,
248                        source
249                    );
250                    metrics_collector_clone.record_message_retries_exhausted(&worker_id, &queue_name);
251                    metrics_collector_clone.record_message_failed(&worker_id, &queue_name, "RetriesExhausted", start_time);
252                    if let Err(e) = ack_handle.nack(false).await {
253                        tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
254                    }
255                }
256                Err(e) => {
257                    // Skip nack if middleware already handled it
258                    if matches!(e, WorkerError::AlreadyAcknowledged) {
259                        // Decrement in-flight counter
260                        in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
261                        task_completion_notify.notify_one();
262                        return;
263                    }
264                    
265                    let error_type = format!("{:?}", e);
266                    tracing::error!("Message {} failed: {}", message_id, e);
267                    metrics_collector_clone.record_message_failed(&worker_id, &queue_name, &error_type, start_time);
268                    if let Err(nack_err) = ack_handle.nack(false).await {
269                        tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
270                    }
271                }
272            }
273            
274            // Decrement load for least-loaded balancer after processing completes
275            if let Some(ref balancer) = least_loaded_balancer {
276                balancer.decrement_load(worker_index);
277            }
278            
279            drop(permit);
280            
281            // Signal task completion and decrement counter
282            in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
283            task_completion_notify.notify_one();
284        });
285
286        Ok(())
287    }
288
289    /// Select a worker based on the configured load balancing strategy.
290    fn select_worker(&self) -> usize {
291        match self.strategy {
292            LoadBalancingStrategy::RoundRobin => {
293                self.round_robin_balancer.next(self.workers.len())
294            }
295            LoadBalancingStrategy::Random => {
296                self.random_balancer.next(self.workers.len())
297            }
298            LoadBalancingStrategy::LeastLoaded => {
299                if let Some(ref balancer) = self.least_loaded_balancer {
300                    balancer.next()
301                } else {
302                    0 // Fallback
303                }
304            }
305        }
306    }
307
308    /// Shutdown the pool gracefully.
309    ///
310    /// This prevents new messages from being dispatched, cancels in-flight tasks,
311    /// and waits for completion with a timeout using efficient notification.
312    pub async fn shutdown(&self) -> WorkerResult<()> {
313        tracing::info!("Shutting down worker pool: {}", self.name);
314
315        self.is_running.store(false, Ordering::SeqCst);
316        self.metrics_collector.record_active_workers(0);
317
318        // Cancel all in-flight tasks
319        self.cancellation_token.cancel();
320        tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
321
322        // Close the semaphore to prevent new acquisitions
323        self.semaphore.close();
324        
325        // Wait for permits to be returned with efficient notification
326        let shutdown_timeout = Duration::from_secs(30); // 30 second timeout
327        let start = Instant::now();
328        
329        loop {
330            let available = self.semaphore.available_permits();
331            let in_flight = self.concurrency_limit.saturating_sub(available);
332            
333            if in_flight == 0 {
334                break; // All tasks completed
335            }
336            
337            if start.elapsed() >= shutdown_timeout {
338                tracing::warn!(
339                    "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
340                    self.name, in_flight
341                );
342                break;
343            }
344            
345            // Wait efficiently for task completion notification instead of busy-wait
346            tokio::select! {
347                _ = self.task_completion_notify.notified() => {
348                    // A task completed, check again
349                    continue;
350                }
351                _ = tokio::time::sleep(Duration::from_millis(100)) => {
352                    // Periodic check in case notifications are missed
353                    continue;
354                }
355            }
356        }
357        
358        self.metrics_collector.record_in_flight_messages(0);
359        tracing::info!("Worker pool {} shutdown complete", self.name);
360        Ok(())
361    }
362    
363    /// Get the current number of in-flight tasks.
364    pub fn in_flight_count(&self) -> usize {
365        self.in_flight_tasks.load(Ordering::SeqCst)
366    }
367}
368
369impl HealthCheck for WorkerPool {
370    fn check_health(&self) -> HealthStatus {
371        let is_running = self.is_running.load(Ordering::SeqCst);
372        let worker_count = self.worker_count();
373        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
374        
375        // Check if pool is running
376        if !is_running {
377            return HealthStatus::Unhealthy { 
378                reason: "Pool is not running".to_string() 
379            };
380        }
381        
382        // Check if workers are available
383        if worker_count == 0 {
384            return HealthStatus::Degraded { 
385                reason: "No workers available".to_string() 
386            };
387        }
388        
389        // Check if pool is saturated (90%+ capacity)
390        let saturation = in_flight as f64 / self.concurrency_limit as f64;
391        if saturation > 0.9 {
392            return HealthStatus::Degraded { 
393                reason: format!("Pool near capacity: {} in-flight messages ({:.0}% saturation)", 
394                    in_flight, saturation * 100.0)
395            };
396        }
397        
398        HealthStatus::Healthy
399    }
400
401    fn status_message(&self) -> String {
402        let worker_count = self.worker_count();
403        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
404        let available_permits = self.semaphore.available_permits();
405        
406        match self.check_health() {
407            HealthStatus::Healthy => {
408                format!(
409                    "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
410                    self.name, worker_count, in_flight, available_permits
411                )
412            }
413            HealthStatus::Degraded { ref reason } => {
414                format!(
415                    "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
416                    self.name, reason, worker_count, in_flight
417                )
418            }
419            HealthStatus::Unhealthy { ref reason } => {
420                format!(
421                    "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
422                    self.name, reason, worker_count, in_flight
423                )
424            }
425        }
426    }
427}
428
429/// Wrapper that converts a Worker into a MessageHandler.
430struct WorkerHandler(Arc<dyn Worker>);
431
432#[async_trait::async_trait]
433impl MessageHandler for WorkerHandler {
434    async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
435        self.0.process(message).await
436    }
437}
438
439/// Wrapper to convert Arc<dyn Middleware> to Box<dyn Middleware>
440struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
441
442#[async_trait::async_trait]
443impl Middleware for ArcMiddlewareWrapper {
444    fn name(&self) -> &str {
445        self.0.name()
446    }
447
448    async fn handle(
449        &self,
450        message: ReceivedMessage<serde_json::Value>,
451        next: Box<dyn MessageHandler>,
452    ) -> WorkerResult<()> {
453        self.0.handle(message, next).await
454    }
455}
456
457/// Wrapper to convert Box<dyn MessageHandler> to Arc
458struct ArcHandlerWrapper(Box<dyn MessageHandler>);
459
460#[async_trait::async_trait]
461impl MessageHandler for ArcHandlerWrapper {
462    async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
463        self.0.handle(message).await
464    }
465}
466
467/// Retry ack with exponential backoff on failure.
468///
469/// This helps handle transient network issues or broker unavailability.
470async fn retry_ack(ack_handle: &Arc<dyn crate::message::AckHandle>, message_id: &str) -> WorkerResult<()> {
471    let max_retries = 3;
472    let base_delay_ms = 100;
473    
474    for attempt in 0..max_retries {
475        match ack_handle.ack().await {
476            Ok(_) => return Ok(()),
477            Err(e) => {
478                if attempt < max_retries - 1 {
479                    let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
480                    tracing::warn!(
481                        "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
482                        attempt + 1,
483                        message_id,
484                        e,
485                        delay
486                    );
487                    sleep(delay).await;
488                } else {
489                    return Err(e);
490                }
491            }
492        }
493    }
494    
495    unreachable!()
496}
497
498#[cfg(test)]
499mod tests {
500    use super::*;
501    use crate::message::{Message, MessageMetadata, AckHandle};
502    use async_trait::async_trait;
503    use std::sync::atomic::{AtomicUsize, Ordering};
504    use std::time::Duration;
505    use crate::metrics::NoOpMetrics; // Use NoOpMetrics for tests
506
507    #[derive(Debug)]
508    struct MockAckHandle;
509
510    #[async_trait]
511    impl AckHandle for MockAckHandle {
512        async fn ack(&self) -> WorkerResult<()> {
513            Ok(())
514        }
515
516        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
517            Ok(())
518        }
519    }
520
521    struct TestWorker {
522        id: String,
523        process_count: Arc<AtomicUsize>,
524    }
525
526    impl TestWorker {
527        fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
528            let count = Arc::new(AtomicUsize::new(0));
529            (
530                Self {
531                    id: id.to_string(),
532                    process_count: count.clone(),
533                },
534                count,
535            )
536        }
537    }
538
539    #[async_trait]
540    impl Worker for TestWorker {
541        fn id(&self) -> &str {
542            &self.id
543        }
544
545        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
546            self.process_count.fetch_add(1, Ordering::SeqCst);
547            Ok(())
548        }
549    }
550
551    fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
552        let message = Message {
553            id: id.to_string(),
554            payload: serde_json::json!({"test": "data"}),
555            metadata: MessageMetadata::new("test-queue"),
556        };
557        ReceivedMessage::new(message, Arc::new(MockAckHandle))
558    }
559
560    #[tokio::test]
561    async fn test_pool_creation() {
562        let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
563        assert_eq!(pool.name(), "test-pool");
564        assert_eq!(pool.worker_count(), 0);
565        assert!(pool.is_running.load(Ordering::SeqCst));
566    }
567
568    #[tokio::test]
569    async fn test_add_worker() {
570        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
571        let (worker, _) = TestWorker::new("worker-1");
572        pool.add_worker(Arc::new(worker));
573        
574        assert_eq!(pool.worker_count(), 1);
575    }
576
577    #[tokio::test]
578    async fn test_dispatch_empty_pool() {
579        let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
580        let message = create_test_message("msg-1");
581        
582        let result = pool.dispatch(message).await;
583        assert!(matches!(result, Err(WorkerError::PoolExhausted)));
584    }
585
586    #[tokio::test]
587    async fn test_round_robin_distribution() {
588        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
589        
590        let (worker1, count1) = TestWorker::new("worker-1");
591        let (worker2, count2) = TestWorker::new("worker-2");
592        
593        pool.add_worker(Arc::new(worker1));
594        pool.add_worker(Arc::new(worker2));
595        
596        // Dispatch 4 messages
597        for i in 0..4 {
598            let message = create_test_message(&format!("msg-{}", i));
599            pool.dispatch(message).await.unwrap();
600        }
601        
602        // Give tasks time to complete
603        tokio::time::sleep(Duration::from_millis(100)).await;
604        
605        // Each worker should have processed 2 messages
606        assert_eq!(count1.load(Ordering::SeqCst), 2);
607        assert_eq!(count2.load(Ordering::SeqCst), 2);
608    }
609
610    #[tokio::test]
611    async fn test_pool_health() {
612        let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
613        assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. })); // Degraded because 0 workers
614        
615        let mut pool = pool;
616        let (worker, _) = TestWorker::new("worker-1");
617        pool.add_worker(Arc::new(worker));
618        
619        assert!(matches!(pool.check_health(), HealthStatus::Healthy));
620    }
621
622    #[tokio::test]
623    async fn test_concurrency_limit_enforcement() {
624        use std::sync::atomic::{AtomicUsize, Ordering};
625        
626        // Create a worker that tracks concurrent executions
627        let concurrent_count = Arc::new(AtomicUsize::new(0));
628        let max_concurrent = Arc::new(AtomicUsize::new(0));
629        
630        struct ConcurrentTestWorker {
631            id: String,
632            concurrent: Arc<AtomicUsize>,
633            max_concurrent: Arc<AtomicUsize>,
634        }
635
636        #[async_trait::async_trait]
637        impl Worker for ConcurrentTestWorker {
638            fn id(&self) -> &str {
639                &self.id
640            }
641
642            async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
643                // Increment concurrent count
644                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
645                
646                // Track maximum
647                let mut max = self.max_concurrent.load(Ordering::SeqCst);
648                while current > max {
649                    match self.max_concurrent.compare_exchange_weak(
650                        max,
651                        current,
652                        Ordering::SeqCst,
653                        Ordering::SeqCst,
654                    ) {
655                        Ok(_) => break,
656                        Err(new_max) => max = new_max,
657                    }
658                }
659                
660                // Simulate processing time
661                tokio::time::sleep(Duration::from_millis(50)).await;
662                
663                // Decrement concurrent count
664                self.concurrent.fetch_sub(1, Ordering::SeqCst);
665                Ok(())
666            }
667        }
668
669        // Create pool with concurrency limit of 3
670        let mut pool = WorkerPool::with_concurrency(
671            "test-pool",
672            LoadBalancingStrategy::RoundRobin,
673            3, // Limit to 3 concurrent
674            Arc::new(NoOpMetrics),
675        );
676        
677        // Add 1 worker (will handle all messages)
678        let worker = ConcurrentTestWorker {
679            id: "worker-1".to_string(),
680            concurrent: concurrent_count.clone(),
681            max_concurrent: max_concurrent.clone(),
682        };
683        pool.add_worker(Arc::new(worker));
684        
685        // Dispatch 10 messages rapidly
686        for i in 0..10 {
687            let message = create_test_message(&format!("msg-{}", i));
688            pool.dispatch(message).await.unwrap();
689        }
690        
691        // Wait for all to complete
692        tokio::time::sleep(Duration::from_millis(500)).await;
693        
694        // Verify that concurrency never exceeded the limit
695        let actual_max = max_concurrent.load(Ordering::SeqCst);
696        assert!(
697            actual_max <= 3,
698            "Expected max concurrency <= 3, but got {}",
699            actual_max
700        );
701        assert!(
702            actual_max >= 2,
703            "Expected some concurrency (>= 2), but got {}",
704            actual_max
705        );
706    }
707
708    #[tokio::test]
709    async fn test_concurrency_limit_with_builder() {
710        use crate::builder::WorkerPoolBuilder;
711        
712        let concurrent_count = Arc::new(AtomicUsize::new(0));
713        let max_concurrent = Arc::new(AtomicUsize::new(0));
714        
715        struct TrackedWorker {
716            id: String,
717            concurrent: Arc<AtomicUsize>,
718            max_concurrent: Arc<AtomicUsize>,
719        }
720
721        #[async_trait::async_trait]
722        impl Worker for TrackedWorker {
723            fn id(&self) -> &str {
724                &self.id
725            }
726
727            async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
728                // Track concurrency
729                let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
730                
731                let mut max = self.max_concurrent.load(Ordering::SeqCst);
732                while current > max {
733                    match self.max_concurrent.compare_exchange_weak(
734                        max,
735                        current,
736                        Ordering::SeqCst,
737                        Ordering::SeqCst,
738                    ) {
739                        Ok(_) => break,
740                        Err(new_max) => max = new_max,
741                    }
742                }
743                
744                tokio::time::sleep(Duration::from_millis(100)).await;
745                
746                self.concurrent.fetch_sub(1, Ordering::SeqCst);
747                Ok(())
748            }
749        }
750        
751        // Build pool with concurrency limit of 2
752        let pool = WorkerPoolBuilder::new("test-pool")
753            .with_concurrency_limit(2)
754            .add_worker(TrackedWorker {
755                id: "worker-1".to_string(),
756                concurrent: concurrent_count.clone(),
757                max_concurrent: max_concurrent.clone(),
758            })
759            .build()
760            .unwrap();
761        
762        // Dispatch 6 messages rapidly
763        for i in 0..6 {
764            let message = create_test_message(&format!("msg-{}", i));
765            pool.dispatch(message).await.unwrap();
766        }
767        
768        // Wait for all to complete
769        tokio::time::sleep(Duration::from_millis(800)).await;
770        
771        // Verify that concurrency never exceeded the limit of 2
772        let actual_max = max_concurrent.load(Ordering::SeqCst);
773        assert!(
774            actual_max <= 2,
775            "Expected max concurrency <= 2, but got {}",
776            actual_max
777        );
778        assert!(
779            actual_max >= 1,
780            "Expected some concurrency (>= 1), but got {}",
781            actual_max
782        );
783    }
784
785    #[tokio::test]
786    async fn test_different_concurrency_limits() {
787        // Test that different pools can have different limits
788        let pool1 = WorkerPool::with_concurrency("pool1", LoadBalancingStrategy::RoundRobin, 5, Arc::new(NoOpMetrics));
789        let pool2 = WorkerPool::with_concurrency("pool2", LoadBalancingStrategy::RoundRobin, 20, Arc::new(NoOpMetrics));
790        
791        // Verify they have different semaphore capacities
792        // We can't directly check semaphore capacity, but we can verify the pools work independently
793        assert_eq!(pool1.name(), "pool1");
794        assert_eq!(pool2.name(), "pool2");
795    }
796
797    #[tokio::test]
798    async fn test_pool_shutdown_prevents_dispatch() {
799        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
800        let (worker, _) = TestWorker::new("worker-1");
801        pool.add_worker(Arc::new(worker));
802
803        pool.shutdown().await.unwrap();
804
805        let message = create_test_message("msg-after-shutdown");
806        let result = pool.dispatch(message).await;
807        assert!(matches!(result, Err(WorkerError::Shutdown)));
808        assert!(matches!(pool.check_health(), HealthStatus::Unhealthy { .. }));
809    }
810}