1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::ReceivedMessage;
11use crate::metrics::WorkerMetrics;
12use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
13use crate::strategies::{
14 LeastLoadedBalancer, LoadBalancingStrategy, RandomBalancer, RoundRobinBalancer,
15};
16use crate::worker::Worker;
17
18pub struct WorkerPool {
35 name: String,
36 workers: Vec<Arc<dyn Worker>>,
37 strategy: LoadBalancingStrategy,
38 semaphore: Arc<Semaphore>,
39 concurrency_limit: usize,
40 least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
41 round_robin_balancer: Arc<RoundRobinBalancer>,
42 random_balancer: RandomBalancer,
43 middlewares: Vec<Arc<dyn Middleware>>,
45 metrics_collector: Arc<dyn WorkerMetrics>,
47 is_running: Arc<AtomicBool>,
49 cancellation_token: CancellationToken,
51 task_completion_notify: Arc<Notify>,
53 in_flight_tasks: Arc<AtomicUsize>,
55}
56
57impl std::fmt::Debug for WorkerPool {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("WorkerPool")
60 .field("name", &self.name)
61 .field("worker_count", &self.workers.len())
62 .field("strategy", &self.strategy)
63 .field("is_running", &self.is_running.load(Ordering::SeqCst))
64 .finish()
65 }
66}
67
68impl WorkerPool {
69 pub fn new(
71 name: impl Into<String>,
72 strategy: LoadBalancingStrategy,
73 metrics_collector: Arc<dyn WorkerMetrics>,
74 ) -> Self {
75 Self::with_concurrency(name, strategy, 1000, metrics_collector)
76 }
77
78 pub fn with_concurrency(
86 name: impl Into<String>,
87 strategy: LoadBalancingStrategy,
88 concurrency_limit: usize,
89 metrics_collector: Arc<dyn WorkerMetrics>,
90 ) -> Self {
91 let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
92 Some(Arc::new(LeastLoadedBalancer::new(0)))
93 } else {
94 None
95 };
96
97 Self {
98 name: name.into(),
99 workers: Vec::new(),
100 strategy,
101 semaphore: Arc::new(Semaphore::new(concurrency_limit)),
102 concurrency_limit,
103 least_loaded_balancer,
104 round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
105 random_balancer: RandomBalancer,
106 middlewares: Vec::new(),
107 metrics_collector,
108 is_running: Arc::new(AtomicBool::new(true)),
109 cancellation_token: CancellationToken::new(),
110 task_completion_notify: Arc::new(Notify::new()),
111 in_flight_tasks: Arc::new(AtomicUsize::new(0)),
112 }
113 }
114
115 pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
117 self.workers.push(worker);
118
119 if let Some(ref balancer) = self.least_loaded_balancer {
121 balancer.add_worker();
122 }
123 self.metrics_collector
124 .record_active_workers(self.workers.len());
125 }
126
127 pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
129 for worker in workers {
130 self.add_worker(worker);
131 }
132 }
133
134 pub fn worker_count(&self) -> usize {
136 self.workers.len()
137 }
138
139 pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
144 self.middlewares = middlewares;
145 self
146 }
147
148 pub fn name(&self) -> &str {
150 &self.name
151 }
152
153 pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
158 if !self.is_running.load(Ordering::SeqCst) {
159 return Err(WorkerError::Shutdown);
160 }
161 if self.workers.is_empty() {
162 return Err(WorkerError::PoolExhausted);
163 }
164
165 let worker_index = self.select_worker();
167 let worker = self.workers[worker_index].clone();
168 let worker_id = worker.id().to_string();
169 let queue_name = message.message.metadata.source.clone();
170
171 self.metrics_collector
172 .record_message_received(&worker_id, &queue_name);
173 let start_time = Instant::now();
174
175 if let Some(ref balancer) = self.least_loaded_balancer {
176 balancer.increment_load(worker_index);
177 }
178
179 let permit = self
180 .semaphore
181 .clone()
182 .acquire_owned()
183 .await
184 .map_err(|_| WorkerError::Shutdown)?;
185
186 self.metrics_collector
187 .record_in_flight_messages(self.semaphore.available_permits());
188
189 let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
191 let worker_handler = WorkerHandler(worker);
192 let boxed_middlewares: Vec<Box<dyn Middleware>> = self
193 .middlewares
194 .iter()
195 .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
196 .collect();
197
198 let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
199 Arc::new(ArcHandlerWrapper(chain.build()))
200 } else {
201 Arc::new(WorkerHandler(worker))
202 };
203
204 let metrics_collector_clone = self.metrics_collector.clone();
205 let least_loaded_balancer = self.least_loaded_balancer.clone();
206 let cancellation_token = self.cancellation_token.child_token();
207 let task_completion_notify = self.task_completion_notify.clone();
208 let in_flight_tasks = self.in_flight_tasks.clone();
209
210 in_flight_tasks.fetch_add(1, Ordering::SeqCst);
212
213 let ack_handle = message.ack_handle.clone();
215 let message_id = message.message.id.clone();
216 let attempt = message.message.metadata.attempt;
217
218 tokio::spawn(async move {
219 let result = tokio::select! {
221 result = handler.handle(message) => result, _ = cancellation_token.cancelled() => {
223 tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
224 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
226 task_completion_notify.notify_one();
227 return;
228 }
229 };
230
231 match result {
232 Ok(middleware_result) => {
233 match middleware_result {
234 crate::middleware::MiddlewareResult::Acknowledged => {
235 tracing::debug!("Message {} already acknowledged by middleware", message_id);
237 metrics_collector_clone.record_message_processed(
238 &worker_id,
239 &queue_name,
240 start_time,
241 );
242 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
244 task_completion_notify.notify_one();
245 return;
246 }
247 crate::middleware::MiddlewareResult::Continue => {
248 tracing::debug!("Message {} processed successfully", message_id);
250 metrics_collector_clone.record_message_processed(
251 &worker_id,
252 &queue_name,
253 start_time,
254 );
255 if let Err(e) = retry_ack(&ack_handle, &message_id).await {
256 tracing::error!(
257 "Failed to ack message {} after retries: {}. Message may be redelivered.",
258 message_id,
259 e
260 );
261 }
262 }
263 }
264 }
265 Err(WorkerError::RetryableFailure { source, delay_ms }) => {
266 tracing::warn!(
267 "Message {} failed (will retry in {:?}): {}",
268 message_id,
269 delay_ms,
270 source
271 );
272 metrics_collector_clone.record_message_retried(
273 &worker_id,
274 &queue_name,
275 attempt,
276 );
277 metrics_collector_clone.record_message_failed(
278 &worker_id,
279 &queue_name,
280 "RetryableFailure",
281 start_time,
282 );
283 if let Err(e) = ack_handle.nack(true).await {
284 tracing::error!("Failed to requeue message {}: {}", message_id, e);
285 }
286 sleep(delay_ms).await;
287 }
288 Err(WorkerError::RetriesExhausted { source }) => {
289 tracing::error!(
290 "Message {} exhausted all retries, sending to DLQ: {}",
291 message_id,
292 source
293 );
294 metrics_collector_clone
295 .record_message_retries_exhausted(&worker_id, &queue_name);
296 metrics_collector_clone.record_message_failed(
297 &worker_id,
298 &queue_name,
299 "RetriesExhausted",
300 start_time,
301 );
302 if let Err(e) = ack_handle.nack(false).await {
303 tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
304 }
305 }
306 Err(e) => {
307 let error_type = format!("{:?}", e);
308 tracing::error!("Message {} failed: {}", message_id, e);
309 metrics_collector_clone.record_message_failed(
310 &worker_id,
311 &queue_name,
312 &error_type,
313 start_time,
314 );
315 if let Err(nack_err) = ack_handle.nack(false).await {
316 tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
317 }
318 }
319 }
320
321 if let Some(ref balancer) = least_loaded_balancer {
323 balancer.decrement_load(worker_index);
324 }
325
326 drop(permit);
327
328 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
330 task_completion_notify.notify_one();
331 });
332
333 Ok(())
334 }
335
336 fn select_worker(&self) -> usize {
338 match self.strategy {
339 LoadBalancingStrategy::RoundRobin => self.round_robin_balancer.next(self.workers.len()),
340 LoadBalancingStrategy::Random => self.random_balancer.next(self.workers.len()),
341 LoadBalancingStrategy::LeastLoaded => {
342 if let Some(ref balancer) = self.least_loaded_balancer {
343 balancer.next()
344 } else {
345 0 }
347 }
348 }
349 }
350
351 pub async fn shutdown(&self) -> WorkerResult<()> {
356 tracing::info!("Shutting down worker pool: {}", self.name);
357
358 self.is_running.store(false, Ordering::SeqCst);
359 self.metrics_collector.record_active_workers(0);
360
361 self.cancellation_token.cancel();
363 tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
364
365 self.semaphore.close();
367
368 let shutdown_timeout = Duration::from_secs(30); let start = Instant::now();
371
372 loop {
373 let available = self.semaphore.available_permits();
374 let in_flight = self.concurrency_limit.saturating_sub(available);
375
376 if in_flight == 0 {
377 break; }
379
380 if start.elapsed() >= shutdown_timeout {
381 tracing::warn!(
382 "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
383 self.name,
384 in_flight
385 );
386 break;
387 }
388
389 tokio::select! {
391 _ = self.task_completion_notify.notified() => {
392 continue;
394 }
395 _ = tokio::time::sleep(Duration::from_millis(100)) => {
396 continue;
398 }
399 }
400 }
401
402 self.metrics_collector.record_in_flight_messages(0);
403 tracing::info!("Worker pool {} shutdown complete", self.name);
404 Ok(())
405 }
406
407 pub fn in_flight_count(&self) -> usize {
409 self.in_flight_tasks.load(Ordering::SeqCst)
410 }
411}
412
413impl HealthCheck for WorkerPool {
414 fn check_health(&self) -> HealthStatus {
415 let is_running = self.is_running.load(Ordering::SeqCst);
416 let worker_count = self.worker_count();
417 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
418
419 if !is_running {
421 return HealthStatus::Unhealthy {
422 reason: "Pool is not running".to_string(),
423 };
424 }
425
426 if worker_count == 0 {
428 return HealthStatus::Degraded {
429 reason: "No workers available".to_string(),
430 };
431 }
432
433 let saturation = in_flight as f64 / self.concurrency_limit as f64;
435 if saturation > 0.9 {
436 return HealthStatus::Degraded {
437 reason: format!(
438 "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
439 in_flight,
440 saturation * 100.0
441 ),
442 };
443 }
444
445 HealthStatus::Healthy
446 }
447
448 fn status_message(&self) -> String {
449 let worker_count = self.worker_count();
450 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
451 let available_permits = self.semaphore.available_permits();
452
453 match self.check_health() {
454 HealthStatus::Healthy => {
455 format!(
456 "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
457 self.name, worker_count, in_flight, available_permits
458 )
459 }
460 HealthStatus::Degraded { ref reason } => {
461 format!(
462 "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
463 self.name, reason, worker_count, in_flight
464 )
465 }
466 HealthStatus::Unhealthy { ref reason } => {
467 format!(
468 "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
469 self.name, reason, worker_count, in_flight
470 )
471 }
472 }
473 }
474}
475
476struct WorkerHandler(Arc<dyn Worker>);
478
479#[async_trait::async_trait]
480impl MessageHandler for WorkerHandler {
481 async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
482 self.0.process(message).await?;
484 Ok(crate::middleware::MiddlewareResult::Continue)
485 }
486}
487
488struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
490
491#[async_trait::async_trait]
492impl Middleware for ArcMiddlewareWrapper {
493 fn name(&self) -> &str {
494 self.0.name()
495 }
496
497 async fn handle(
498 &self,
499 message: ReceivedMessage<serde_json::Value>,
500 next: Box<dyn MessageHandler>,
501 ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
502 self.0.handle(message, next).await
503 }
504}
505
506struct ArcHandlerWrapper(Box<dyn MessageHandler>);
508
509#[async_trait::async_trait]
510impl MessageHandler for ArcHandlerWrapper {
511 async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
512 self.0.handle(message).await
513 }
514}
515
516async fn retry_ack(
520 ack_handle: &Arc<dyn crate::message::AckHandle>,
521 message_id: &str,
522) -> WorkerResult<()> {
523 let max_retries = 3;
524 let base_delay_ms = 100;
525
526 for attempt in 0..max_retries {
527 match ack_handle.ack().await {
528 Ok(_) => return Ok(()),
529 Err(e) => {
530 if attempt < max_retries - 1 {
531 let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
532 tracing::warn!(
533 "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
534 attempt + 1,
535 message_id,
536 e,
537 delay
538 );
539 sleep(delay).await;
540 } else {
541 return Err(e);
542 }
543 }
544 }
545 }
546
547 unreachable!()
548}
549
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use crate::metrics::NoOpMetrics;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    // ---------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------

    /// Ack handle that accepts every ack/nack and records nothing.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Ack handle that counts how often `ack` and `nack` are called.
    #[derive(Debug)]
    struct CountingAckHandle {
        ack_count: Arc<AtomicUsize>,
        nack_count: Arc<AtomicUsize>,
    }

    impl CountingAckHandle {
        /// Returns a fresh handle together with its shared `(ack, nack)` counters.
        fn new() -> (Self, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let ack_count = Arc::new(AtomicUsize::new(0));
            let nack_count = Arc::new(AtomicUsize::new(0));
            let handle = Self {
                ack_count: ack_count.clone(),
                nack_count: nack_count.clone(),
            };
            (handle, ack_count, nack_count)
        }
    }

    #[async_trait]
    impl AckHandle for CountingAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            self.ack_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            self.nack_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Worker that always succeeds and counts processed messages.
    struct TestWorker {
        id: String,
        process_count: Arc<AtomicUsize>,
    }

    impl TestWorker {
        /// Returns the worker and a shared handle to its process counter.
        fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
            let count = Arc::new(AtomicUsize::new(0));
            (
                Self {
                    id: id.to_string(),
                    process_count: count.clone(),
                },
                count,
            )
        }
    }

    #[async_trait]
    impl Worker for TestWorker {
        fn id(&self) -> &str {
            &self.id
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            self.process_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Worker that always succeeds without doing anything.
    struct SuccessWorker;

    #[async_trait]
    impl Worker for SuccessWorker {
        fn id(&self) -> &str {
            "success-worker"
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Worker that always fails with a non-retryable processing error.
    struct FailingWorker;

    #[async_trait]
    impl Worker for FailingWorker {
        fn id(&self) -> &str {
            "failing-worker"
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            Err(WorkerError::ProcessingFailed("Simulated failure".to_string()))
        }
    }

    /// Worker that holds each message for `hold` and records the highest
    /// number of `process` calls that were running at the same time.
    struct ConcurrencyTrackingWorker {
        id: String,
        concurrent: Arc<AtomicUsize>,
        max_concurrent: Arc<AtomicUsize>,
        hold: Duration,
    }

    impl ConcurrencyTrackingWorker {
        /// Returns the worker and a shared handle to the observed maximum.
        fn new(id: &str, hold: Duration) -> (Self, Arc<AtomicUsize>) {
            let max_concurrent = Arc::new(AtomicUsize::new(0));
            let worker = Self {
                id: id.to_string(),
                concurrent: Arc::new(AtomicUsize::new(0)),
                max_concurrent: max_concurrent.clone(),
                hold,
            };
            (worker, max_concurrent)
        }
    }

    #[async_trait]
    impl Worker for ConcurrencyTrackingWorker {
        fn id(&self) -> &str {
            &self.id
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_concurrent.fetch_max(current, Ordering::SeqCst);

            tokio::time::sleep(self.hold).await;

            self.concurrent.fetch_sub(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Builds a message for `test-queue` whose acks/nacks go to `handle`.
    fn create_message_with_handle(
        id: &str,
        handle: impl AckHandle + 'static,
    ) -> ReceivedMessage<serde_json::Value> {
        let message = Message {
            id: id.to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(message, Arc::new(handle))
    }

    /// Builds a message whose acks/nacks are silently accepted.
    fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
        create_message_with_handle(id, MockAckHandle)
    }

    /// Builds an empty round-robin pool named `test-pool` with no-op metrics.
    fn round_robin_pool() -> WorkerPool {
        WorkerPool::new(
            "test-pool",
            LoadBalancingStrategy::RoundRobin,
            Arc::new(NoOpMetrics),
        )
    }

    /// Waits until every dispatched task has finished, instead of sleeping
    /// for a fixed time and hoping the tasks are done. Panics after 5 s.
    async fn wait_until_idle(pool: &WorkerPool) {
        let deadline = std::time::Instant::now() + Duration::from_secs(5);
        while pool.in_flight_count() > 0 {
            assert!(
                std::time::Instant::now() < deadline,
                "pool did not become idle within 5s"
            );
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    }

    // ---------------------------------------------------------------------
    // Construction and basic state
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = round_robin_pool();

        assert_eq!(pool.name(), "test-pool");
        assert_eq!(pool.worker_count(), 0);
        assert!(pool.is_running.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_add_worker() {
        let mut pool = round_robin_pool();
        let (worker, _) = TestWorker::new("worker-1");
        pool.add_worker(Arc::new(worker));

        assert_eq!(pool.worker_count(), 1);
    }

    #[tokio::test]
    async fn test_dispatch_empty_pool() {
        let pool = round_robin_pool();

        let result = pool.dispatch(create_test_message("msg-1")).await;
        assert!(matches!(result, Err(WorkerError::PoolExhausted)));
    }

    #[tokio::test]
    async fn test_round_robin_distribution() {
        let mut pool = round_robin_pool();

        let (worker1, count1) = TestWorker::new("worker-1");
        let (worker2, count2) = TestWorker::new("worker-2");
        pool.add_worker(Arc::new(worker1));
        pool.add_worker(Arc::new(worker2));

        for i in 0..4 {
            let message = create_test_message(&format!("msg-{}", i));
            pool.dispatch(message).await.unwrap();
        }
        wait_until_idle(&pool).await;

        // Four messages over two workers: two each.
        assert_eq!(count1.load(Ordering::SeqCst), 2);
        assert_eq!(count2.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_pool_health() {
        let mut pool = round_robin_pool();

        // A running pool without workers is degraded, not unhealthy.
        assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. }));

        let (worker, _) = TestWorker::new("worker-1");
        pool.add_worker(Arc::new(worker));

        assert!(matches!(pool.check_health(), HealthStatus::Healthy));
    }

    // ---------------------------------------------------------------------
    // Concurrency limits
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_concurrency_limit_enforcement() {
        let mut pool = WorkerPool::with_concurrency(
            "test-pool",
            LoadBalancingStrategy::RoundRobin,
            3,
            Arc::new(NoOpMetrics),
        );

        let (worker, max_concurrent) =
            ConcurrencyTrackingWorker::new("worker-1", Duration::from_millis(50));
        pool.add_worker(Arc::new(worker));

        for i in 0..10 {
            let message = create_test_message(&format!("msg-{}", i));
            pool.dispatch(message).await.unwrap();
        }
        wait_until_idle(&pool).await;

        let actual_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            actual_max <= 3,
            "Expected max concurrency <= 3, but got {}",
            actual_max
        );
        assert!(
            actual_max >= 2,
            "Expected some concurrency (>= 2), but got {}",
            actual_max
        );
    }

    #[tokio::test]
    async fn test_concurrency_limit_with_builder() {
        use crate::builder::WorkerPoolBuilder;

        let (worker, max_concurrent) =
            ConcurrencyTrackingWorker::new("worker-1", Duration::from_millis(100));

        let pool = WorkerPoolBuilder::new("test-pool")
            .with_concurrency_limit(2)
            .add_worker(worker)
            .build()
            .unwrap();

        for i in 0..6 {
            let message = create_test_message(&format!("msg-{}", i));
            pool.dispatch(message).await.unwrap();
        }
        wait_until_idle(&pool).await;

        let actual_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            actual_max <= 2,
            "Expected max concurrency <= 2, but got {}",
            actual_max
        );
        assert!(
            actual_max >= 1,
            "Expected some concurrency (>= 1), but got {}",
            actual_max
        );
    }

    #[tokio::test]
    async fn test_different_concurrency_limits() {
        let pool1 = WorkerPool::with_concurrency(
            "pool1",
            LoadBalancingStrategy::RoundRobin,
            5,
            Arc::new(NoOpMetrics),
        );
        let pool2 = WorkerPool::with_concurrency(
            "pool2",
            LoadBalancingStrategy::RoundRobin,
            20,
            Arc::new(NoOpMetrics),
        );

        assert_eq!(pool1.name(), "pool1");
        assert_eq!(pool2.name(), "pool2");

        // Each pool keeps its own limit and starts with all permits free.
        assert_eq!(pool1.concurrency_limit, 5);
        assert_eq!(pool2.concurrency_limit, 20);
        assert_eq!(pool1.semaphore.available_permits(), 5);
        assert_eq!(pool2.semaphore.available_permits(), 20);
    }

    // ---------------------------------------------------------------------
    // Shutdown
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_pool_shutdown_prevents_dispatch() {
        let mut pool = round_robin_pool();
        let (worker, _) = TestWorker::new("worker-1");
        pool.add_worker(Arc::new(worker));

        pool.shutdown().await.unwrap();

        let result = pool.dispatch(create_test_message("msg-after-shutdown")).await;
        assert!(matches!(result, Err(WorkerError::Shutdown)));
        assert!(matches!(
            pool.check_health(),
            HealthStatus::Unhealthy { .. }
        ));
    }

    // ---------------------------------------------------------------------
    // Ack / nack ownership, with and without middleware
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_no_double_ack_with_middleware() {
        use crate::AckNackMiddleware;

        let (handle, ack_count, nack_count) = CountingAckHandle::new();

        let mut pool = round_robin_pool();
        pool.middlewares.push(Arc::new(AckNackMiddleware::default()));
        pool.add_worker(Arc::new(SuccessWorker));

        pool.dispatch(create_message_with_handle("test-msg", handle))
            .await
            .unwrap();
        wait_until_idle(&pool).await;

        assert_eq!(
            ack_count.load(Ordering::SeqCst),
            1,
            "Message should have been acked exactly once by middleware"
        );
        assert_eq!(
            nack_count.load(Ordering::SeqCst),
            0,
            "Message should not have been nacked"
        );
    }

    #[tokio::test]
    async fn test_pool_handles_ack_without_middleware() {
        let (handle, ack_count, nack_count) = CountingAckHandle::new();

        let mut pool = round_robin_pool();
        pool.add_worker(Arc::new(SuccessWorker));

        pool.dispatch(create_message_with_handle("test-msg", handle))
            .await
            .unwrap();
        wait_until_idle(&pool).await;

        assert_eq!(
            ack_count.load(Ordering::SeqCst),
            1,
            "Pool should have acked the message when no middleware is used"
        );
        assert_eq!(
            nack_count.load(Ordering::SeqCst),
            0,
            "Nack should not be called for successful message"
        );
    }

    #[tokio::test]
    async fn test_pool_handles_nack_without_middleware() {
        let (handle, ack_count, nack_count) = CountingAckHandle::new();

        let mut pool = round_robin_pool();
        pool.add_worker(Arc::new(FailingWorker));

        pool.dispatch(create_message_with_handle("test-msg", handle))
            .await
            .unwrap();
        wait_until_idle(&pool).await;

        assert_eq!(
            nack_count.load(Ordering::SeqCst),
            1,
            "Pool should have nacked the message when no middleware is used"
        );
        assert_eq!(
            ack_count.load(Ordering::SeqCst),
            0,
            "Ack should not be called for failed message"
        );
    }

    #[tokio::test]
    async fn test_integration_ack_nack_middleware_with_failure() {
        use crate::AckNackMiddleware;

        let (handle, ack_count, nack_count) = CountingAckHandle::new();

        let mut pool = round_robin_pool();
        pool.middlewares.push(Arc::new(AckNackMiddleware::default()));
        pool.add_worker(Arc::new(FailingWorker));

        pool.dispatch(create_message_with_handle("test-msg", handle))
            .await
            .unwrap();
        wait_until_idle(&pool).await;

        assert_eq!(
            nack_count.load(Ordering::SeqCst),
            1,
            "Nack should be called exactly once by middleware"
        );
        assert_eq!(
            ack_count.load(Ordering::SeqCst),
            0,
            "Ack should not be called for failed message"
        );
    }
}