1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::ReceivedMessage;
11use crate::metrics::WorkerMetrics;
12use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
13use crate::strategies::{
14 LeastLoadedBalancer, LoadBalancingStrategy, RandomBalancer, RoundRobinBalancer,
15};
16use crate::worker::Worker;
17
/// A pool of [`Worker`]s that dispatches received messages to a selected
/// worker and processes each message on its own spawned Tokio task.
pub struct WorkerPool {
    /// Pool name; appears in log lines and health status messages.
    name: String,
    /// Registered workers; `select_worker` returns an index into this vector.
    workers: Vec<Arc<dyn Worker>>,
    /// Strategy used by `select_worker` to choose which worker handles a message.
    strategy: LoadBalancingStrategy,
    /// Created with `concurrency_limit` permits; `dispatch` acquires one permit
    /// per message and the spawned task releases it when it finishes.
    semaphore: Arc<Semaphore>,
    /// Number of permits `semaphore` was created with.
    concurrency_limit: usize,
    /// Only `Some` when the pool was built with `LoadBalancingStrategy::LeastLoaded`.
    least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
    /// Balancer consulted for `LoadBalancingStrategy::RoundRobin`.
    round_robin_balancer: Arc<RoundRobinBalancer>,
    /// Balancer consulted for `LoadBalancingStrategy::Random`.
    random_balancer: RandomBalancer,
    /// Middlewares wrapped around the worker on every dispatch; when empty the
    /// worker is invoked directly.
    middlewares: Vec<Arc<dyn Middleware>>,
    /// Sink for the metrics recorded by the pool.
    metrics_collector: Arc<dyn WorkerMetrics>,
    /// `true` from construction until `shutdown` stores `false`.
    is_running: Arc<AtomicBool>,
    /// Cancelled by `shutdown`; every dispatched task races its handler against
    /// a child of this token.
    cancellation_token: CancellationToken,
    /// Notified by a dispatched task when it ends, so `shutdown` can re-check progress.
    task_completion_notify: Arc<Notify>,
    /// Number of spawned message tasks that have not finished yet.
    in_flight_tasks: Arc<AtomicUsize>,
}
56
57impl std::fmt::Debug for WorkerPool {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("WorkerPool")
60 .field("name", &self.name)
61 .field("worker_count", &self.workers.len())
62 .field("strategy", &self.strategy)
63 .field("is_running", &self.is_running.load(Ordering::SeqCst))
64 .finish()
65 }
66}
67
68impl WorkerPool {
69 pub fn new(
71 name: impl Into<String>,
72 strategy: LoadBalancingStrategy,
73 metrics_collector: Arc<dyn WorkerMetrics>,
74 ) -> Self {
75 Self::with_concurrency(name, strategy, 1000, metrics_collector)
76 }
77
78 pub fn with_concurrency(
86 name: impl Into<String>,
87 strategy: LoadBalancingStrategy,
88 concurrency_limit: usize,
89 metrics_collector: Arc<dyn WorkerMetrics>,
90 ) -> Self {
91 let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
92 Some(Arc::new(LeastLoadedBalancer::new(0)))
93 } else {
94 None
95 };
96
97 Self {
98 name: name.into(),
99 workers: Vec::new(),
100 strategy,
101 semaphore: Arc::new(Semaphore::new(concurrency_limit)),
102 concurrency_limit,
103 least_loaded_balancer,
104 round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
105 random_balancer: RandomBalancer,
106 middlewares: Vec::new(),
107 metrics_collector,
108 is_running: Arc::new(AtomicBool::new(true)),
109 cancellation_token: CancellationToken::new(),
110 task_completion_notify: Arc::new(Notify::new()),
111 in_flight_tasks: Arc::new(AtomicUsize::new(0)),
112 }
113 }
114
115 pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
117 self.workers.push(worker);
118
119 if let Some(ref balancer) = self.least_loaded_balancer {
121 balancer.add_worker();
122 }
123 self.metrics_collector
124 .record_active_workers(self.workers.len());
125 }
126
127 pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
129 for worker in workers {
130 self.add_worker(worker);
131 }
132 }
133
    /// Returns the number of workers currently registered with the pool.
    pub fn worker_count(&self) -> usize {
        self.workers.len()
    }
138
    /// Builder-style setter: replaces the pool's middleware list with
    /// `middlewares` (any previously configured middlewares are discarded)
    /// and returns the pool.
    pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
        self.middlewares = middlewares;
        self
    }
147
148 pub fn name(&self) -> &str {
150 &self.name
151 }
152
153 pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
158 if !self.is_running.load(Ordering::SeqCst) {
159 return Err(WorkerError::Shutdown);
160 }
161 if self.workers.is_empty() {
162 return Err(WorkerError::PoolExhausted);
163 }
164
165 let worker_index = self.select_worker();
167 let worker = self.workers[worker_index].clone();
168 let worker_id = worker.id().to_string();
169 let queue_name = message.message.metadata.source.clone();
170
171 self.metrics_collector
172 .record_message_received(&worker_id, &queue_name);
173 let start_time = Instant::now();
174
175 if let Some(ref balancer) = self.least_loaded_balancer {
176 balancer.increment_load(worker_index);
177 }
178
179 let permit = self
180 .semaphore
181 .clone()
182 .acquire_owned()
183 .await
184 .map_err(|_| WorkerError::Shutdown)?;
185
186 self.metrics_collector
187 .record_in_flight_messages(self.semaphore.available_permits());
188
189 let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
191 let worker_handler = WorkerHandler(worker);
192 let boxed_middlewares: Vec<Box<dyn Middleware>> = self
193 .middlewares
194 .iter()
195 .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
196 .collect();
197
198 let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
199 Arc::new(ArcHandlerWrapper(chain.build()))
200 } else {
201 Arc::new(WorkerHandler(worker))
202 };
203
204 let metrics_collector_clone = self.metrics_collector.clone();
205 let least_loaded_balancer = self.least_loaded_balancer.clone();
206 let cancellation_token = self.cancellation_token.child_token();
207 let task_completion_notify = self.task_completion_notify.clone();
208 let in_flight_tasks = self.in_flight_tasks.clone();
209
210 in_flight_tasks.fetch_add(1, Ordering::SeqCst);
212
213 let ack_handle = message.ack_handle.clone();
215 let message_id = message.message.id.clone();
216 let attempt = message.message.metadata.attempt;
217
218 tokio::spawn(async move {
219 let result = tokio::select! {
221 result = handler.handle(message) => result, _ = cancellation_token.cancelled() => {
223 tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
224 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
226 task_completion_notify.notify_one();
227 return;
228 }
229 };
230
231 match result {
232 Ok(middleware_result) => {
233 match middleware_result {
234 crate::middleware::MiddlewareResult::Acknowledged => {
235 tracing::debug!(
237 "Message {} already acknowledged by middleware",
238 message_id
239 );
240 metrics_collector_clone.record_message_processed(
241 &worker_id,
242 &queue_name,
243 start_time,
244 );
245 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
247 task_completion_notify.notify_one();
248 return;
249 }
250 crate::middleware::MiddlewareResult::Continue => {
251 tracing::debug!("Message {} processed successfully", message_id);
253 metrics_collector_clone.record_message_processed(
254 &worker_id,
255 &queue_name,
256 start_time,
257 );
258 if let Err(e) = retry_ack(&ack_handle, &message_id).await {
259 tracing::error!(
260 "Failed to ack message {} after retries: {}. Message may be redelivered.",
261 message_id,
262 e
263 );
264 }
265 }
266 }
267 }
268 Err(WorkerError::RetryableFailure { source, delay_ms }) => {
269 tracing::warn!(
270 "Message {} failed (will retry in {:?}): {}",
271 message_id,
272 delay_ms,
273 source
274 );
275 metrics_collector_clone.record_message_retried(
276 &worker_id,
277 &queue_name,
278 attempt,
279 );
280 metrics_collector_clone.record_message_failed(
281 &worker_id,
282 &queue_name,
283 "RetryableFailure",
284 start_time,
285 );
286 if let Err(e) = ack_handle.nack(true).await {
287 tracing::error!("Failed to requeue message {}: {}", message_id, e);
288 }
289 sleep(delay_ms).await;
290 }
291 Err(WorkerError::RetriesExhausted { source }) => {
292 tracing::error!(
293 "Message {} exhausted all retries, sending to DLQ: {}",
294 message_id,
295 source
296 );
297 metrics_collector_clone
298 .record_message_retries_exhausted(&worker_id, &queue_name);
299 metrics_collector_clone.record_message_failed(
300 &worker_id,
301 &queue_name,
302 "RetriesExhausted",
303 start_time,
304 );
305 if let Err(e) = ack_handle.nack(false).await {
306 tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
307 }
308 }
309 Err(e) => {
310 let error_type = format!("{:?}", e);
311 tracing::error!("Message {} failed: {}", message_id, e);
312 metrics_collector_clone.record_message_failed(
313 &worker_id,
314 &queue_name,
315 &error_type,
316 start_time,
317 );
318 if let Err(nack_err) = ack_handle.nack(false).await {
319 tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
320 }
321 }
322 }
323
324 if let Some(ref balancer) = least_loaded_balancer {
326 balancer.decrement_load(worker_index);
327 }
328
329 drop(permit);
330
331 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
333 task_completion_notify.notify_one();
334 });
335
336 Ok(())
337 }
338
339 fn select_worker(&self) -> usize {
341 match self.strategy {
342 LoadBalancingStrategy::RoundRobin => self.round_robin_balancer.next(self.workers.len()),
343 LoadBalancingStrategy::Random => self.random_balancer.next(self.workers.len()),
344 LoadBalancingStrategy::LeastLoaded => {
345 if let Some(ref balancer) = self.least_loaded_balancer {
346 balancer.next()
347 } else {
348 0 }
350 }
351 }
352 }
353
354 pub async fn shutdown(&self) -> WorkerResult<()> {
359 tracing::info!("Shutting down worker pool: {}", self.name);
360
361 self.is_running.store(false, Ordering::SeqCst);
362 self.metrics_collector.record_active_workers(0);
363
364 self.cancellation_token.cancel();
366 tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
367
368 self.semaphore.close();
370
371 let shutdown_timeout = Duration::from_secs(30); let start = Instant::now();
374
375 loop {
376 let available = self.semaphore.available_permits();
377 let in_flight = self.concurrency_limit.saturating_sub(available);
378
379 if in_flight == 0 {
380 break; }
382
383 if start.elapsed() >= shutdown_timeout {
384 tracing::warn!(
385 "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
386 self.name,
387 in_flight
388 );
389 break;
390 }
391
392 tokio::select! {
394 _ = self.task_completion_notify.notified() => {
395 continue;
397 }
398 _ = tokio::time::sleep(Duration::from_millis(100)) => {
399 continue;
401 }
402 }
403 }
404
405 self.metrics_collector.record_in_flight_messages(0);
406 tracing::info!("Worker pool {} shutdown complete", self.name);
407 Ok(())
408 }
409
    /// Returns the current value of the in-flight task counter, i.e. the number
    /// of spawned message tasks that have not finished yet.
    pub fn in_flight_count(&self) -> usize {
        self.in_flight_tasks.load(Ordering::SeqCst)
    }
414}
415
416impl HealthCheck for WorkerPool {
417 fn check_health(&self) -> HealthStatus {
418 let is_running = self.is_running.load(Ordering::SeqCst);
419 let worker_count = self.worker_count();
420 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
421
422 if !is_running {
424 return HealthStatus::Unhealthy {
425 reason: "Pool is not running".to_string(),
426 };
427 }
428
429 if worker_count == 0 {
431 return HealthStatus::Degraded {
432 reason: "No workers available".to_string(),
433 };
434 }
435
436 let saturation = in_flight as f64 / self.concurrency_limit as f64;
438 if saturation > 0.9 {
439 return HealthStatus::Degraded {
440 reason: format!(
441 "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
442 in_flight,
443 saturation * 100.0
444 ),
445 };
446 }
447
448 HealthStatus::Healthy
449 }
450
451 fn status_message(&self) -> String {
452 let worker_count = self.worker_count();
453 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
454 let available_permits = self.semaphore.available_permits();
455
456 match self.check_health() {
457 HealthStatus::Healthy => {
458 format!(
459 "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
460 self.name, worker_count, in_flight, available_permits
461 )
462 }
463 HealthStatus::Degraded { ref reason } => {
464 format!(
465 "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
466 self.name, reason, worker_count, in_flight
467 )
468 }
469 HealthStatus::Unhealthy { ref reason } => {
470 format!(
471 "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
472 self.name, reason, worker_count, in_flight
473 )
474 }
475 }
476 }
477}
478
479struct WorkerHandler(Arc<dyn Worker>);
481
482#[async_trait::async_trait]
483impl MessageHandler for WorkerHandler {
484 async fn handle(
485 &self,
486 message: ReceivedMessage<serde_json::Value>,
487 ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
488 self.0.process(message).await?;
490 Ok(crate::middleware::MiddlewareResult::Continue)
491 }
492}
493
/// Newtype around a shared `Arc<dyn Middleware>` that itself implements
/// [`Middleware`], so a shared middleware can be boxed as `Box<dyn Middleware>`.
struct ArcMiddlewareWrapper(Arc<dyn Middleware>);

#[async_trait::async_trait]
impl Middleware for ArcMiddlewareWrapper {
    /// Name of the wrapped middleware.
    fn name(&self) -> &str {
        self.0.name()
    }

    /// Forwards the call to the wrapped middleware unchanged.
    async fn handle(
        &self,
        message: ReceivedMessage<serde_json::Value>,
        next: Box<dyn MessageHandler>,
    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
        self.0.handle(message, next).await
    }
}
511
/// Newtype around a boxed [`MessageHandler`] that itself implements
/// [`MessageHandler`], so the boxed handler can be stored behind an
/// `Arc<dyn MessageHandler>`.
struct ArcHandlerWrapper(Box<dyn MessageHandler>);

#[async_trait::async_trait]
impl MessageHandler for ArcHandlerWrapper {
    /// Forwards the call to the wrapped handler unchanged.
    async fn handle(
        &self,
        message: ReceivedMessage<serde_json::Value>,
    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
        self.0.handle(message).await
    }
}
524
525async fn retry_ack(
529 ack_handle: &Arc<dyn crate::message::AckHandle>,
530 message_id: &str,
531) -> WorkerResult<()> {
532 let max_retries = 3;
533 let base_delay_ms = 100;
534
535 for attempt in 0..max_retries {
536 match ack_handle.ack().await {
537 Ok(_) => return Ok(()),
538 Err(e) => {
539 if attempt < max_retries - 1 {
540 let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
541 tracing::warn!(
542 "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
543 attempt + 1,
544 message_id,
545 e,
546 delay
547 );
548 sleep(delay).await;
549 } else {
550 return Err(e);
551 }
552 }
553 }
554 }
555
556 unreachable!()
557}
558
559#[cfg(test)]
560mod tests {
561 use super::*;
562 use crate::message::{AckHandle, Message, MessageMetadata};
563 use crate::metrics::NoOpMetrics;
564 use async_trait::async_trait;
565 use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Ack handle stub for tests: `ack` and `nack` both return `Ok(())` and do
    /// nothing else.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }
581
582 struct TestWorker {
583 id: String,
584 process_count: Arc<AtomicUsize>,
585 }
586
587 impl TestWorker {
588 fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
589 let count = Arc::new(AtomicUsize::new(0));
590 (
591 Self {
592 id: id.to_string(),
593 process_count: count.clone(),
594 },
595 count,
596 )
597 }
598 }
599
600 #[async_trait]
601 impl Worker for TestWorker {
602 fn id(&self) -> &str {
603 &self.id
604 }
605
606 async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
607 self.process_count.fetch_add(1, Ordering::SeqCst);
608 Ok(())
609 }
610 }
611
612 fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
613 let message = Message {
614 id: id.to_string(),
615 payload: serde_json::json!({"test": "data"}),
616 metadata: MessageMetadata::new("test-queue"),
617 };
618 ReceivedMessage::new(message, Arc::new(MockAckHandle))
619 }
620
621 #[tokio::test]
622 async fn test_pool_creation() {
623 let pool = WorkerPool::new(
624 "test-pool",
625 LoadBalancingStrategy::RoundRobin,
626 Arc::new(NoOpMetrics),
627 );
628 assert_eq!(pool.name(), "test-pool");
629 assert_eq!(pool.worker_count(), 0);
630 assert!(pool.is_running.load(Ordering::SeqCst));
631 }
632
633 #[tokio::test]
634 async fn test_add_worker() {
635 let mut pool = WorkerPool::new(
636 "test-pool",
637 LoadBalancingStrategy::RoundRobin,
638 Arc::new(NoOpMetrics),
639 );
640 let (worker, _) = TestWorker::new("worker-1");
641 pool.add_worker(Arc::new(worker));
642
643 assert_eq!(pool.worker_count(), 1);
644 }
645
646 #[tokio::test]
647 async fn test_dispatch_empty_pool() {
648 let pool = WorkerPool::new(
649 "test-pool",
650 LoadBalancingStrategy::RoundRobin,
651 Arc::new(NoOpMetrics),
652 );
653 let message = create_test_message("msg-1");
654
655 let result = pool.dispatch(message).await;
656 assert!(matches!(result, Err(WorkerError::PoolExhausted)));
657 }
658
659 #[tokio::test]
660 async fn test_round_robin_distribution() {
661 let mut pool = WorkerPool::new(
662 "test-pool",
663 LoadBalancingStrategy::RoundRobin,
664 Arc::new(NoOpMetrics),
665 );
666
667 let (worker1, count1) = TestWorker::new("worker-1");
668 let (worker2, count2) = TestWorker::new("worker-2");
669
670 pool.add_worker(Arc::new(worker1));
671 pool.add_worker(Arc::new(worker2));
672
673 for i in 0..4 {
675 let message = create_test_message(&format!("msg-{}", i));
676 pool.dispatch(message).await.unwrap();
677 }
678
679 tokio::time::sleep(Duration::from_millis(100)).await;
681
682 assert_eq!(count1.load(Ordering::SeqCst), 2);
684 assert_eq!(count2.load(Ordering::SeqCst), 2);
685 }
686
687 #[tokio::test]
688 async fn test_pool_health() {
689 let pool = WorkerPool::new(
690 "test-pool",
691 LoadBalancingStrategy::RoundRobin,
692 Arc::new(NoOpMetrics),
693 );
694 assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. })); let mut pool = pool;
697 let (worker, _) = TestWorker::new("worker-1");
698 pool.add_worker(Arc::new(worker));
699
700 assert!(matches!(pool.check_health(), HealthStatus::Healthy));
701 }
702
703 #[tokio::test]
704 async fn test_concurrency_limit_enforcement() {
705 use std::sync::atomic::{AtomicUsize, Ordering};
706
707 let concurrent_count = Arc::new(AtomicUsize::new(0));
709 let max_concurrent = Arc::new(AtomicUsize::new(0));
710
711 struct ConcurrentTestWorker {
712 id: String,
713 concurrent: Arc<AtomicUsize>,
714 max_concurrent: Arc<AtomicUsize>,
715 }
716
717 #[async_trait::async_trait]
718 impl Worker for ConcurrentTestWorker {
719 fn id(&self) -> &str {
720 &self.id
721 }
722
723 async fn process(
724 &self,
725 _message: ReceivedMessage<serde_json::Value>,
726 ) -> WorkerResult<()> {
727 let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
729
730 let mut max = self.max_concurrent.load(Ordering::SeqCst);
732 while current > max {
733 match self.max_concurrent.compare_exchange_weak(
734 max,
735 current,
736 Ordering::SeqCst,
737 Ordering::SeqCst,
738 ) {
739 Ok(_) => break,
740 Err(new_max) => max = new_max,
741 }
742 }
743
744 tokio::time::sleep(Duration::from_millis(50)).await;
746
747 self.concurrent.fetch_sub(1, Ordering::SeqCst);
749 Ok(())
750 }
751 }
752
753 let mut pool = WorkerPool::with_concurrency(
755 "test-pool",
756 LoadBalancingStrategy::RoundRobin,
757 3, Arc::new(NoOpMetrics),
759 );
760
761 let worker = ConcurrentTestWorker {
763 id: "worker-1".to_string(),
764 concurrent: concurrent_count.clone(),
765 max_concurrent: max_concurrent.clone(),
766 };
767 pool.add_worker(Arc::new(worker));
768
769 for i in 0..10 {
771 let message = create_test_message(&format!("msg-{}", i));
772 pool.dispatch(message).await.unwrap();
773 }
774
775 tokio::time::sleep(Duration::from_millis(500)).await;
777
778 let actual_max = max_concurrent.load(Ordering::SeqCst);
780 assert!(
781 actual_max <= 3,
782 "Expected max concurrency <= 3, but got {}",
783 actual_max
784 );
785 assert!(
786 actual_max >= 2,
787 "Expected some concurrency (>= 2), but got {}",
788 actual_max
789 );
790 }
791
792 #[tokio::test]
793 async fn test_concurrency_limit_with_builder() {
794 use crate::builder::WorkerPoolBuilder;
795
796 let concurrent_count = Arc::new(AtomicUsize::new(0));
797 let max_concurrent = Arc::new(AtomicUsize::new(0));
798
799 struct TrackedWorker {
800 id: String,
801 concurrent: Arc<AtomicUsize>,
802 max_concurrent: Arc<AtomicUsize>,
803 }
804
805 #[async_trait::async_trait]
806 impl Worker for TrackedWorker {
807 fn id(&self) -> &str {
808 &self.id
809 }
810
811 async fn process(
812 &self,
813 _message: ReceivedMessage<serde_json::Value>,
814 ) -> WorkerResult<()> {
815 let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
817
818 let mut max = self.max_concurrent.load(Ordering::SeqCst);
819 while current > max {
820 match self.max_concurrent.compare_exchange_weak(
821 max,
822 current,
823 Ordering::SeqCst,
824 Ordering::SeqCst,
825 ) {
826 Ok(_) => break,
827 Err(new_max) => max = new_max,
828 }
829 }
830
831 tokio::time::sleep(Duration::from_millis(100)).await;
832
833 self.concurrent.fetch_sub(1, Ordering::SeqCst);
834 Ok(())
835 }
836 }
837
838 let pool = WorkerPoolBuilder::new("test-pool")
840 .with_concurrency_limit(2)
841 .add_worker(TrackedWorker {
842 id: "worker-1".to_string(),
843 concurrent: concurrent_count.clone(),
844 max_concurrent: max_concurrent.clone(),
845 })
846 .build()
847 .unwrap();
848
849 for i in 0..6 {
851 let message = create_test_message(&format!("msg-{}", i));
852 pool.dispatch(message).await.unwrap();
853 }
854
855 tokio::time::sleep(Duration::from_millis(800)).await;
857
858 let actual_max = max_concurrent.load(Ordering::SeqCst);
860 assert!(
861 actual_max <= 2,
862 "Expected max concurrency <= 2, but got {}",
863 actual_max
864 );
865 assert!(
866 actual_max >= 1,
867 "Expected some concurrency (>= 1), but got {}",
868 actual_max
869 );
870 }
871
872 #[tokio::test]
873 async fn test_different_concurrency_limits() {
874 let pool1 = WorkerPool::with_concurrency(
876 "pool1",
877 LoadBalancingStrategy::RoundRobin,
878 5,
879 Arc::new(NoOpMetrics),
880 );
881 let pool2 = WorkerPool::with_concurrency(
882 "pool2",
883 LoadBalancingStrategy::RoundRobin,
884 20,
885 Arc::new(NoOpMetrics),
886 );
887
888 assert_eq!(pool1.name(), "pool1");
891 assert_eq!(pool2.name(), "pool2");
892 }
893
894 #[tokio::test]
895 async fn test_pool_shutdown_prevents_dispatch() {
896 let mut pool = WorkerPool::new(
897 "test-pool",
898 LoadBalancingStrategy::RoundRobin,
899 Arc::new(NoOpMetrics),
900 );
901 let (worker, _) = TestWorker::new("worker-1");
902 pool.add_worker(Arc::new(worker));
903
904 pool.shutdown().await.unwrap();
905
906 let message = create_test_message("msg-after-shutdown");
907 let result = pool.dispatch(message).await;
908 assert!(matches!(result, Err(WorkerError::Shutdown)));
909 assert!(matches!(
910 pool.check_health(),
911 HealthStatus::Unhealthy { .. }
912 ));
913 }
914
915 #[tokio::test]
918 async fn test_no_double_ack_with_middleware() {
919 use crate::AckNackMiddleware;
920 use std::sync::atomic::AtomicUsize;
921
922 let ack_count = Arc::new(AtomicUsize::new(0));
924 let nack_count = Arc::new(AtomicUsize::new(0));
925
926 #[derive(Debug)]
927 struct TrackingAckHandle {
928 ack_count: Arc<AtomicUsize>,
929 nack_count: Arc<AtomicUsize>,
930 }
931
932 #[async_trait]
933 impl AckHandle for TrackingAckHandle {
934 async fn ack(&self) -> WorkerResult<()> {
935 self.ack_count.fetch_add(1, Ordering::SeqCst);
936 Ok(())
937 }
938
939 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
940 self.nack_count.fetch_add(1, Ordering::SeqCst);
941 Ok(())
942 }
943 }
944
945 struct SuccessWorker;
947
948 #[async_trait]
949 impl Worker for SuccessWorker {
950 fn id(&self) -> &str {
951 "success-worker"
952 }
953
954 async fn process(
955 &self,
956 _message: ReceivedMessage<serde_json::Value>,
957 ) -> WorkerResult<()> {
958 Ok(())
959 }
960 }
961
962 let mut pool = WorkerPool::new(
963 "test-pool",
964 LoadBalancingStrategy::RoundRobin,
965 Arc::new(NoOpMetrics),
966 );
967
968 pool.middlewares
970 .push(Arc::new(AckNackMiddleware::default()));
971 pool.add_worker(Arc::new(SuccessWorker));
972
973 let message = Message {
975 id: "test-msg".to_string(),
976 payload: serde_json::json!({"test": "data"}),
977 metadata: MessageMetadata::new("test-queue"),
978 };
979 let received = ReceivedMessage::new(
980 message,
981 Arc::new(TrackingAckHandle {
982 ack_count: ack_count.clone(),
983 nack_count: nack_count.clone(),
984 }),
985 );
986
987 pool.dispatch(received).await.unwrap();
989
990 tokio::time::sleep(Duration::from_millis(100)).await;
992
993 assert_eq!(
995 ack_count.load(Ordering::SeqCst),
996 1,
997 "Message should have been acked exactly once by middleware"
998 );
999 assert_eq!(
1000 nack_count.load(Ordering::SeqCst),
1001 0,
1002 "Message should not have been nacked"
1003 );
1004 }
1005
1006 #[tokio::test]
1008 async fn test_pool_handles_ack_without_middleware() {
1009 use std::sync::atomic::AtomicUsize;
1010
1011 let ack_count = Arc::new(AtomicUsize::new(0));
1012 let nack_count = Arc::new(AtomicUsize::new(0));
1013
1014 #[derive(Debug)]
1015 struct CountingAckHandle {
1016 ack_count: Arc<AtomicUsize>,
1017 nack_count: Arc<AtomicUsize>,
1018 }
1019
1020 #[async_trait]
1021 impl AckHandle for CountingAckHandle {
1022 async fn ack(&self) -> WorkerResult<()> {
1023 self.ack_count.fetch_add(1, Ordering::SeqCst);
1024 Ok(())
1025 }
1026
1027 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1028 self.nack_count.fetch_add(1, Ordering::SeqCst);
1029 Ok(())
1030 }
1031 }
1032
1033 struct SuccessWorker;
1035
1036 #[async_trait]
1037 impl Worker for SuccessWorker {
1038 fn id(&self) -> &str {
1039 "success-worker"
1040 }
1041
1042 async fn process(
1043 &self,
1044 _message: ReceivedMessage<serde_json::Value>,
1045 ) -> WorkerResult<()> {
1046 Ok(())
1047 }
1048 }
1049
1050 let mut pool = WorkerPool::new(
1051 "test-pool",
1052 LoadBalancingStrategy::RoundRobin,
1053 Arc::new(NoOpMetrics),
1054 );
1055 pool.add_worker(Arc::new(SuccessWorker));
1057
1058 let message = Message {
1059 id: "test-msg".to_string(),
1060 payload: serde_json::json!({"test": "data"}),
1061 metadata: MessageMetadata::new("test-queue"),
1062 };
1063 let received = ReceivedMessage::new(
1064 message,
1065 Arc::new(CountingAckHandle {
1066 ack_count: ack_count.clone(),
1067 nack_count: nack_count.clone(),
1068 }),
1069 );
1070
1071 pool.dispatch(received).await.unwrap();
1072
1073 tokio::time::sleep(Duration::from_millis(100)).await;
1075
1076 assert_eq!(
1078 ack_count.load(Ordering::SeqCst),
1079 1,
1080 "Pool should have acked the message when no middleware is used"
1081 );
1082 assert_eq!(
1083 nack_count.load(Ordering::SeqCst),
1084 0,
1085 "Nack should not be called for successful message"
1086 );
1087 }
1088
1089 #[tokio::test]
1091 async fn test_pool_handles_nack_without_middleware() {
1092 use std::sync::atomic::AtomicUsize;
1093
1094 let ack_count = Arc::new(AtomicUsize::new(0));
1095 let nack_count = Arc::new(AtomicUsize::new(0));
1096
1097 #[derive(Debug)]
1098 struct CountingAckHandle {
1099 ack_count: Arc<AtomicUsize>,
1100 nack_count: Arc<AtomicUsize>,
1101 }
1102
1103 #[async_trait]
1104 impl AckHandle for CountingAckHandle {
1105 async fn ack(&self) -> WorkerResult<()> {
1106 self.ack_count.fetch_add(1, Ordering::SeqCst);
1107 Ok(())
1108 }
1109
1110 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1111 self.nack_count.fetch_add(1, Ordering::SeqCst);
1112 Ok(())
1113 }
1114 }
1115
1116 struct FailingWorker;
1118
1119 #[async_trait]
1120 impl Worker for FailingWorker {
1121 fn id(&self) -> &str {
1122 "failing-worker"
1123 }
1124
1125 async fn process(
1126 &self,
1127 _message: ReceivedMessage<serde_json::Value>,
1128 ) -> WorkerResult<()> {
1129 Err(WorkerError::ProcessingFailed(
1130 "Simulated failure".to_string(),
1131 ))
1132 }
1133 }
1134
1135 let mut pool = WorkerPool::new(
1136 "test-pool",
1137 LoadBalancingStrategy::RoundRobin,
1138 Arc::new(NoOpMetrics),
1139 );
1140 pool.add_worker(Arc::new(FailingWorker));
1142
1143 let message = Message {
1144 id: "test-msg".to_string(),
1145 payload: serde_json::json!({"test": "data"}),
1146 metadata: MessageMetadata::new("test-queue"),
1147 };
1148 let received = ReceivedMessage::new(
1149 message,
1150 Arc::new(CountingAckHandle {
1151 ack_count: ack_count.clone(),
1152 nack_count: nack_count.clone(),
1153 }),
1154 );
1155
1156 pool.dispatch(received).await.unwrap();
1157
1158 tokio::time::sleep(Duration::from_millis(100)).await;
1160
1161 assert_eq!(
1163 nack_count.load(Ordering::SeqCst),
1164 1,
1165 "Pool should have nacked the message when no middleware is used"
1166 );
1167 assert_eq!(
1168 ack_count.load(Ordering::SeqCst),
1169 0,
1170 "Ack should not be called for failed message"
1171 );
1172 }
1173
1174 #[tokio::test]
1176 async fn test_integration_ack_nack_middleware_with_failure() {
1177 use crate::AckNackMiddleware;
1178 use std::sync::atomic::AtomicUsize;
1179
1180 let ack_count = Arc::new(AtomicUsize::new(0));
1181 let nack_count = Arc::new(AtomicUsize::new(0));
1182
1183 #[derive(Debug)]
1184 struct CountingAckHandle {
1185 ack_count: Arc<AtomicUsize>,
1186 nack_count: Arc<AtomicUsize>,
1187 }
1188
1189 #[async_trait]
1190 impl AckHandle for CountingAckHandle {
1191 async fn ack(&self) -> WorkerResult<()> {
1192 self.ack_count.fetch_add(1, Ordering::SeqCst);
1193 Ok(())
1194 }
1195
1196 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1197 self.nack_count.fetch_add(1, Ordering::SeqCst);
1198 Ok(())
1199 }
1200 }
1201
1202 struct FailingWorker;
1204
1205 #[async_trait]
1206 impl Worker for FailingWorker {
1207 fn id(&self) -> &str {
1208 "failing-worker"
1209 }
1210
1211 async fn process(
1212 &self,
1213 _message: ReceivedMessage<serde_json::Value>,
1214 ) -> WorkerResult<()> {
1215 Err(WorkerError::ProcessingFailed(
1216 "Simulated failure".to_string(),
1217 ))
1218 }
1219 }
1220
1221 let mut pool = WorkerPool::new(
1222 "test-pool",
1223 LoadBalancingStrategy::RoundRobin,
1224 Arc::new(NoOpMetrics),
1225 );
1226
1227 pool.middlewares
1229 .push(Arc::new(AckNackMiddleware::default()));
1230 pool.add_worker(Arc::new(FailingWorker));
1231
1232 let message = Message {
1233 id: "test-msg".to_string(),
1234 payload: serde_json::json!({"test": "data"}),
1235 metadata: MessageMetadata::new("test-queue"),
1236 };
1237 let received = ReceivedMessage::new(
1238 message,
1239 Arc::new(CountingAckHandle {
1240 ack_count: ack_count.clone(),
1241 nack_count: nack_count.clone(),
1242 }),
1243 );
1244
1245 pool.dispatch(received).await.unwrap();
1246
1247 tokio::time::sleep(Duration::from_millis(100)).await;
1249
1250 assert_eq!(
1252 nack_count.load(Ordering::SeqCst),
1253 1,
1254 "Nack should be called exactly once by middleware"
1255 );
1256 assert_eq!(
1257 ack_count.load(Ordering::SeqCst),
1258 0,
1259 "Ack should not be called for failed message"
1260 );
1261 }
1262}