1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::ReceivedMessage;
11use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
12use crate::metrics::WorkerMetrics;
13use crate::strategies::{LoadBalancingStrategy, LeastLoadedBalancer, RandomBalancer, RoundRobinBalancer};
14use crate::worker::Worker;
15
16pub struct WorkerPool {
33 name: String,
34 workers: Vec<Arc<dyn Worker>>,
35 strategy: LoadBalancingStrategy,
36 semaphore: Arc<Semaphore>,
37 concurrency_limit: usize,
38 least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
39 round_robin_balancer: Arc<RoundRobinBalancer>,
40 random_balancer: RandomBalancer,
41 middlewares: Vec<Arc<dyn Middleware>>,
43 metrics_collector: Arc<dyn WorkerMetrics>,
45 is_running: Arc<AtomicBool>,
47 cancellation_token: CancellationToken,
49 task_completion_notify: Arc<Notify>,
51 in_flight_tasks: Arc<AtomicUsize>,
53}
54
55impl std::fmt::Debug for WorkerPool {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("WorkerPool")
58 .field("name", &self.name)
59 .field("worker_count", &self.workers.len())
60 .field("strategy", &self.strategy)
61 .field("is_running", &self.is_running.load(Ordering::SeqCst))
62 .finish()
63 }
64}
65
66impl WorkerPool {
67 pub fn new(
69 name: impl Into<String>,
70 strategy: LoadBalancingStrategy,
71 metrics_collector: Arc<dyn WorkerMetrics>,
72 ) -> Self {
73 Self::with_concurrency(name, strategy, 1000, metrics_collector)
74 }
75
76 pub fn with_concurrency(
84 name: impl Into<String>,
85 strategy: LoadBalancingStrategy,
86 concurrency_limit: usize,
87 metrics_collector: Arc<dyn WorkerMetrics>,
88 ) -> Self {
89 let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
90 Some(Arc::new(LeastLoadedBalancer::new(0)))
91 } else {
92 None
93 };
94
95 Self {
96 name: name.into(),
97 workers: Vec::new(),
98 strategy,
99 semaphore: Arc::new(Semaphore::new(concurrency_limit)),
100 concurrency_limit,
101 least_loaded_balancer,
102 round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
103 random_balancer: RandomBalancer,
104 middlewares: Vec::new(),
105 metrics_collector,
106 is_running: Arc::new(AtomicBool::new(true)),
107 cancellation_token: CancellationToken::new(),
108 task_completion_notify: Arc::new(Notify::new()),
109 in_flight_tasks: Arc::new(AtomicUsize::new(0)),
110 }
111 }
112
113 pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
115 self.workers.push(worker);
116
117 if let Some(ref balancer) = self.least_loaded_balancer {
119 balancer.add_worker();
120 }
121 self.metrics_collector.record_active_workers(self.workers.len());
122 }
123
124 pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
126 for worker in workers {
127 self.add_worker(worker);
128 }
129 }
130
131 pub fn worker_count(&self) -> usize {
133 self.workers.len()
134 }
135
136 pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
141 self.middlewares = middlewares;
142 self
143 }
144
145 pub fn name(&self) -> &str {
147 &self.name
148 }
149
150 pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
155 if !self.is_running.load(Ordering::SeqCst) {
156 return Err(WorkerError::Shutdown);
157 }
158 if self.workers.is_empty() {
159 return Err(WorkerError::PoolExhausted);
160 }
161
162 let worker_index = self.select_worker();
164 let worker = self.workers[worker_index].clone();
165 let worker_id = worker.id().to_string();
166 let queue_name = message.message.metadata.source.clone();
167
168 self.metrics_collector.record_message_received(&worker_id, &queue_name);
169 let start_time = Instant::now();
170
171 if let Some(ref balancer) = self.least_loaded_balancer {
172 balancer.increment_load(worker_index);
173 }
174
175 let permit = self.semaphore.clone().acquire_owned().await
176 .map_err(|_| WorkerError::Shutdown)?;
177
178 self.metrics_collector.record_in_flight_messages(self.semaphore.available_permits());
179
180 let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
182 let worker_handler = WorkerHandler(worker);
183 let boxed_middlewares: Vec<Box<dyn Middleware>> = self.middlewares.iter()
184 .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
185 .collect();
186
187 let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
188 Arc::new(ArcHandlerWrapper(chain.build()))
189 } else {
190 Arc::new(WorkerHandler(worker))
191 };
192
193 let metrics_collector_clone = self.metrics_collector.clone();
194 let least_loaded_balancer = self.least_loaded_balancer.clone();
195 let cancellation_token = self.cancellation_token.child_token();
196 let task_completion_notify = self.task_completion_notify.clone();
197 let in_flight_tasks = self.in_flight_tasks.clone();
198
199 in_flight_tasks.fetch_add(1, Ordering::SeqCst);
201
202 let ack_handle = message.ack_handle.clone();
204 let message_id = message.message.id.clone();
205 let attempt = message.message.metadata.attempt;
206
207 tokio::spawn(async move {
208 let result = tokio::select! {
210 result = handler.handle(message) => result, _ = cancellation_token.cancelled() => {
212 tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
213 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
215 task_completion_notify.notify_one();
216 return;
217 }
218 };
219
220 match result {
221 Ok(_) => {
222 tracing::debug!("Message {} processed successfully", message_id);
223 metrics_collector_clone.record_message_processed(&worker_id, &queue_name, start_time);
224 if let Err(e) = retry_ack(&ack_handle, &message_id).await {
226 tracing::error!("Failed to ack message {} after retries: {}. Message may be redelivered.", message_id, e);
227 }
229 }
230 Err(WorkerError::RetryableFailure { source, delay_ms }) => {
231 tracing::warn!(
232 "Message {} failed (will retry in {:?}): {}",
233 message_id,
234 delay_ms,
235 source
236 );
237 metrics_collector_clone.record_message_retried(&worker_id, &queue_name, attempt);
238 metrics_collector_clone.record_message_failed(&worker_id, &queue_name, "RetryableFailure", start_time);
239 if let Err(e) = ack_handle.nack(true).await {
240 tracing::error!("Failed to requeue message {}: {}", message_id, e);
241 }
242 sleep(delay_ms).await;
243 }
244 Err(WorkerError::RetriesExhausted { source }) => {
245 tracing::error!(
246 "Message {} exhausted all retries, sending to DLQ: {}",
247 message_id,
248 source
249 );
250 metrics_collector_clone.record_message_retries_exhausted(&worker_id, &queue_name);
251 metrics_collector_clone.record_message_failed(&worker_id, &queue_name, "RetriesExhausted", start_time);
252 if let Err(e) = ack_handle.nack(false).await {
253 tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
254 }
255 }
256 Err(e) => {
257 if matches!(e, WorkerError::AlreadyAcknowledged) {
259 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
261 task_completion_notify.notify_one();
262 return;
263 }
264
265 let error_type = format!("{:?}", e);
266 tracing::error!("Message {} failed: {}", message_id, e);
267 metrics_collector_clone.record_message_failed(&worker_id, &queue_name, &error_type, start_time);
268 if let Err(nack_err) = ack_handle.nack(false).await {
269 tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
270 }
271 }
272 }
273
274 if let Some(ref balancer) = least_loaded_balancer {
276 balancer.decrement_load(worker_index);
277 }
278
279 drop(permit);
280
281 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
283 task_completion_notify.notify_one();
284 });
285
286 Ok(())
287 }
288
289 fn select_worker(&self) -> usize {
291 match self.strategy {
292 LoadBalancingStrategy::RoundRobin => {
293 self.round_robin_balancer.next(self.workers.len())
294 }
295 LoadBalancingStrategy::Random => {
296 self.random_balancer.next(self.workers.len())
297 }
298 LoadBalancingStrategy::LeastLoaded => {
299 if let Some(ref balancer) = self.least_loaded_balancer {
300 balancer.next()
301 } else {
302 0 }
304 }
305 }
306 }
307
308 pub async fn shutdown(&self) -> WorkerResult<()> {
313 tracing::info!("Shutting down worker pool: {}", self.name);
314
315 self.is_running.store(false, Ordering::SeqCst);
316 self.metrics_collector.record_active_workers(0);
317
318 self.cancellation_token.cancel();
320 tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
321
322 self.semaphore.close();
324
325 let shutdown_timeout = Duration::from_secs(30); let start = Instant::now();
328
329 loop {
330 let available = self.semaphore.available_permits();
331 let in_flight = self.concurrency_limit.saturating_sub(available);
332
333 if in_flight == 0 {
334 break; }
336
337 if start.elapsed() >= shutdown_timeout {
338 tracing::warn!(
339 "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
340 self.name, in_flight
341 );
342 break;
343 }
344
345 tokio::select! {
347 _ = self.task_completion_notify.notified() => {
348 continue;
350 }
351 _ = tokio::time::sleep(Duration::from_millis(100)) => {
352 continue;
354 }
355 }
356 }
357
358 self.metrics_collector.record_in_flight_messages(0);
359 tracing::info!("Worker pool {} shutdown complete", self.name);
360 Ok(())
361 }
362
363 pub fn in_flight_count(&self) -> usize {
365 self.in_flight_tasks.load(Ordering::SeqCst)
366 }
367}
368
369impl HealthCheck for WorkerPool {
370 fn check_health(&self) -> HealthStatus {
371 let is_running = self.is_running.load(Ordering::SeqCst);
372 let worker_count = self.worker_count();
373 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
374
375 if !is_running {
377 return HealthStatus::Unhealthy {
378 reason: "Pool is not running".to_string()
379 };
380 }
381
382 if worker_count == 0 {
384 return HealthStatus::Degraded {
385 reason: "No workers available".to_string()
386 };
387 }
388
389 let saturation = in_flight as f64 / self.concurrency_limit as f64;
391 if saturation > 0.9 {
392 return HealthStatus::Degraded {
393 reason: format!("Pool near capacity: {} in-flight messages ({:.0}% saturation)",
394 in_flight, saturation * 100.0)
395 };
396 }
397
398 HealthStatus::Healthy
399 }
400
401 fn status_message(&self) -> String {
402 let worker_count = self.worker_count();
403 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
404 let available_permits = self.semaphore.available_permits();
405
406 match self.check_health() {
407 HealthStatus::Healthy => {
408 format!(
409 "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
410 self.name, worker_count, in_flight, available_permits
411 )
412 }
413 HealthStatus::Degraded { ref reason } => {
414 format!(
415 "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
416 self.name, reason, worker_count, in_flight
417 )
418 }
419 HealthStatus::Unhealthy { ref reason } => {
420 format!(
421 "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
422 self.name, reason, worker_count, in_flight
423 )
424 }
425 }
426 }
427}
428
429struct WorkerHandler(Arc<dyn Worker>);
431
432#[async_trait::async_trait]
433impl MessageHandler for WorkerHandler {
434 async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
435 self.0.process(message).await
436 }
437}
438
439struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
441
442#[async_trait::async_trait]
443impl Middleware for ArcMiddlewareWrapper {
444 fn name(&self) -> &str {
445 self.0.name()
446 }
447
448 async fn handle(
449 &self,
450 message: ReceivedMessage<serde_json::Value>,
451 next: Box<dyn MessageHandler>,
452 ) -> WorkerResult<()> {
453 self.0.handle(message, next).await
454 }
455}
456
457struct ArcHandlerWrapper(Box<dyn MessageHandler>);
459
460#[async_trait::async_trait]
461impl MessageHandler for ArcHandlerWrapper {
462 async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
463 self.0.handle(message).await
464 }
465}
466
467async fn retry_ack(ack_handle: &Arc<dyn crate::message::AckHandle>, message_id: &str) -> WorkerResult<()> {
471 let max_retries = 3;
472 let base_delay_ms = 100;
473
474 for attempt in 0..max_retries {
475 match ack_handle.ack().await {
476 Ok(_) => return Ok(()),
477 Err(e) => {
478 if attempt < max_retries - 1 {
479 let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
480 tracing::warn!(
481 "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
482 attempt + 1,
483 message_id,
484 e,
485 delay
486 );
487 sleep(delay).await;
488 } else {
489 return Err(e);
490 }
491 }
492 }
493 }
494
495 unreachable!()
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use crate::metrics::NoOpMetrics;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Ack handle whose `ack`/`nack` always succeed.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Worker that succeeds immediately and counts how many messages it saw.
    struct TestWorker {
        id: String,
        process_count: Arc<AtomicUsize>,
    }

    impl TestWorker {
        /// Returns the worker together with a handle on its message counter.
        fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
            let count = Arc::new(AtomicUsize::new(0));
            (
                Self {
                    id: id.to_string(),
                    process_count: count.clone(),
                },
                count,
            )
        }
    }

    #[async_trait]
    impl Worker for TestWorker {
        fn id(&self) -> &str {
            &self.id
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            self.process_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Worker that takes `processing_time` per message and records the highest
    /// number of `process` calls that were running at the same time.
    struct ConcurrencyTrackingWorker {
        id: String,
        processing_time: Duration,
        concurrent: Arc<AtomicUsize>,
        max_concurrent: Arc<AtomicUsize>,
    }

    impl ConcurrencyTrackingWorker {
        fn new(
            id: &str,
            processing_time: Duration,
            concurrent: &Arc<AtomicUsize>,
            max_concurrent: &Arc<AtomicUsize>,
        ) -> Self {
            Self {
                id: id.to_string(),
                processing_time,
                concurrent: concurrent.clone(),
                max_concurrent: max_concurrent.clone(),
            }
        }
    }

    #[async_trait]
    impl Worker for ConcurrencyTrackingWorker {
        fn id(&self) -> &str {
            &self.id
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_concurrent.fetch_max(current, Ordering::SeqCst);

            tokio::time::sleep(self.processing_time).await;

            self.concurrent.fetch_sub(1, Ordering::SeqCst);
            Ok(())
        }
    }

    fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
        let message = Message {
            id: id.to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(message, Arc::new(MockAckHandle))
    }

    /// Waits until every dispatched message has been fully processed, failing
    /// the test if the pool has not drained within five seconds.
    async fn wait_until_idle(pool: &WorkerPool) {
        let drained = tokio::time::timeout(Duration::from_secs(5), async {
            while pool.in_flight_count() > 0 {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        })
        .await;

        assert!(
            drained.is_ok(),
            "pool still had {} in-flight tasks after 5s",
            pool.in_flight_count()
        );
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
        assert_eq!(pool.name(), "test-pool");
        assert_eq!(pool.worker_count(), 0);
        assert!(pool.is_running.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_add_worker() {
        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
        let (worker, _) = TestWorker::new("worker-1");
        pool.add_worker(Arc::new(worker));

        assert_eq!(pool.worker_count(), 1);
    }

    #[tokio::test]
    async fn test_dispatch_empty_pool() {
        let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
        let message = create_test_message("msg-1");

        let result = pool.dispatch(message).await;
        assert!(matches!(result, Err(WorkerError::PoolExhausted)));
    }

    #[tokio::test]
    async fn test_round_robin_distribution() {
        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));

        let (worker1, count1) = TestWorker::new("worker-1");
        let (worker2, count2) = TestWorker::new("worker-2");

        pool.add_worker(Arc::new(worker1));
        pool.add_worker(Arc::new(worker2));

        for i in 0..4 {
            let message = create_test_message(&format!("msg-{}", i));
            pool.dispatch(message).await.unwrap();
        }

        wait_until_idle(&pool).await;

        assert_eq!(count1.load(Ordering::SeqCst), 2);
        assert_eq!(count2.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_pool_health() {
        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
        assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. }));

        let (worker, _) = TestWorker::new("worker-1");
        pool.add_worker(Arc::new(worker));

        assert!(matches!(pool.check_health(), HealthStatus::Healthy));
    }

    #[tokio::test]
    async fn test_concurrency_limit_enforcement() {
        let concurrent_count = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let mut pool = WorkerPool::with_concurrency(
            "test-pool",
            LoadBalancingStrategy::RoundRobin,
            3,
            Arc::new(NoOpMetrics),
        );
        pool.add_worker(Arc::new(ConcurrencyTrackingWorker::new(
            "worker-1",
            Duration::from_millis(50),
            &concurrent_count,
            &max_concurrent,
        )));

        for i in 0..10 {
            let message = create_test_message(&format!("msg-{}", i));
            pool.dispatch(message).await.unwrap();
        }

        wait_until_idle(&pool).await;

        let actual_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            actual_max <= 3,
            "Expected max concurrency <= 3, but got {}",
            actual_max
        );
        assert!(
            actual_max >= 2,
            "Expected some concurrency (>= 2), but got {}",
            actual_max
        );
    }

    #[tokio::test]
    async fn test_concurrency_limit_with_builder() {
        use crate::builder::WorkerPoolBuilder;

        let concurrent_count = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let pool = WorkerPoolBuilder::new("test-pool")
            .with_concurrency_limit(2)
            .add_worker(ConcurrencyTrackingWorker::new(
                "worker-1",
                Duration::from_millis(100),
                &concurrent_count,
                &max_concurrent,
            ))
            .build()
            .unwrap();

        for i in 0..6 {
            let message = create_test_message(&format!("msg-{}", i));
            pool.dispatch(message).await.unwrap();
        }

        wait_until_idle(&pool).await;

        let actual_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            actual_max <= 2,
            "Expected max concurrency <= 2, but got {}",
            actual_max
        );
        assert!(
            actual_max >= 1,
            "Expected some concurrency (>= 1), but got {}",
            actual_max
        );
    }

    #[tokio::test]
    async fn test_different_concurrency_limits() {
        let pool1 = WorkerPool::with_concurrency("pool1", LoadBalancingStrategy::RoundRobin, 5, Arc::new(NoOpMetrics));
        let pool2 = WorkerPool::with_concurrency("pool2", LoadBalancingStrategy::RoundRobin, 20, Arc::new(NoOpMetrics));

        assert_eq!(pool1.name(), "pool1");
        assert_eq!(pool1.concurrency_limit, 5);
        assert_eq!(pool1.semaphore.available_permits(), 5);

        assert_eq!(pool2.name(), "pool2");
        assert_eq!(pool2.concurrency_limit, 20);
        assert_eq!(pool2.semaphore.available_permits(), 20);
    }

    #[tokio::test]
    async fn test_pool_shutdown_prevents_dispatch() {
        let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, Arc::new(NoOpMetrics));
        let (worker, _) = TestWorker::new("worker-1");
        pool.add_worker(Arc::new(worker));

        pool.shutdown().await.unwrap();

        let message = create_test_message("msg-after-shutdown");
        let result = pool.dispatch(message).await;
        assert!(matches!(result, Err(WorkerError::Shutdown)));
        assert!(matches!(pool.check_health(), HealthStatus::Unhealthy { .. }));
    }
}