1use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::{Notify, Semaphore};
5use tokio::time::sleep;
6use tokio_util::sync::CancellationToken;
7
8use crate::error::{WorkerError, WorkerResult};
9use crate::health::{HealthCheck, HealthStatus};
10use crate::message::{AckHandle, ReceivedMessage};
11use crate::metrics::WorkerMetrics;
12use crate::middleware::{MessageHandler, Middleware, MiddlewareChain, MiddlewareResult};
13use crate::strategies::{
14 LeastLoadedBalancer, LoadBalancingStrategy, RandomBalancer, RoundRobinBalancer,
15};
16use crate::worker::Worker;
17
18pub struct WorkerPool {
35 name: String,
36 workers: Vec<Arc<dyn Worker>>,
37 strategy: LoadBalancingStrategy,
38 semaphore: Arc<Semaphore>,
39 concurrency_limit: usize,
40 least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
41 round_robin_balancer: Arc<RoundRobinBalancer>,
42 random_balancer: RandomBalancer,
43 middlewares: Vec<Arc<dyn Middleware>>,
45 metrics_collector: Arc<dyn WorkerMetrics>,
47 is_running: Arc<AtomicBool>,
49 cancellation_token: CancellationToken,
51 task_completion_notify: Arc<Notify>,
53 in_flight_tasks: Arc<AtomicUsize>,
55}
56
57impl std::fmt::Debug for WorkerPool {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("WorkerPool")
60 .field("name", &self.name)
61 .field("worker_count", &self.workers.len())
62 .field("strategy", &self.strategy)
63 .field("is_running", &self.is_running.load(Ordering::SeqCst))
64 .finish()
65 }
66}
67
68impl WorkerPool {
69 pub fn new(
71 name: impl Into<String>,
72 strategy: LoadBalancingStrategy,
73 metrics_collector: Arc<dyn WorkerMetrics>,
74 ) -> Self {
75 Self::with_concurrency(name, strategy, 1000, metrics_collector)
76 }
77
78 pub fn with_concurrency(
86 name: impl Into<String>,
87 strategy: LoadBalancingStrategy,
88 concurrency_limit: usize,
89 metrics_collector: Arc<dyn WorkerMetrics>,
90 ) -> Self {
91 let least_loaded_balancer = if matches!(strategy, LoadBalancingStrategy::LeastLoaded) {
92 Some(Arc::new(LeastLoadedBalancer::new(0)))
93 } else {
94 None
95 };
96
97 Self {
98 name: name.into(),
99 workers: Vec::new(),
100 strategy,
101 semaphore: Arc::new(Semaphore::new(concurrency_limit)),
102 concurrency_limit,
103 least_loaded_balancer,
104 round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
105 random_balancer: RandomBalancer,
106 middlewares: Vec::new(),
107 metrics_collector,
108 is_running: Arc::new(AtomicBool::new(true)),
109 cancellation_token: CancellationToken::new(),
110 task_completion_notify: Arc::new(Notify::new()),
111 in_flight_tasks: Arc::new(AtomicUsize::new(0)),
112 }
113 }
114
115 pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
117 self.workers.push(worker);
118
119 if let Some(ref balancer) = self.least_loaded_balancer {
121 balancer.add_worker();
122 }
123 self.metrics_collector
124 .record_active_workers(self.workers.len());
125 }
126
127 pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
129 for worker in workers {
130 self.add_worker(worker);
131 }
132 }
133
134 pub fn worker_count(&self) -> usize {
136 self.workers.len()
137 }
138
139 pub fn with_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
144 self.middlewares = middlewares;
145 self
146 }
147
148 pub fn name(&self) -> &str {
150 &self.name
151 }
152
153 pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
158 if !self.is_running.load(Ordering::SeqCst) {
159 return Err(WorkerError::Shutdown);
160 }
161 if self.workers.is_empty() {
162 return Err(WorkerError::PoolExhausted);
163 }
164
165 let worker_index = self.select_worker();
167 let worker = self.workers[worker_index].clone();
168 let worker_id = worker.id().to_string();
169 let queue_name = message.message.metadata.source.clone();
170
171 self.metrics_collector
172 .record_message_received(&worker_id, &queue_name);
173 let start_time = Instant::now();
174
175 if let Some(ref balancer) = self.least_loaded_balancer {
176 balancer.increment_load(worker_index);
177 }
178
179 let permit = self
180 .semaphore
181 .clone()
182 .acquire_owned()
183 .await
184 .map_err(|_| WorkerError::Shutdown)?;
185
186 self.metrics_collector
187 .record_in_flight_messages(self.semaphore.available_permits());
188
189 let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
191 let worker_handler = WorkerHandler(worker);
192 let boxed_middlewares: Vec<Box<dyn Middleware>> = self
193 .middlewares
194 .iter()
195 .map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
196 .collect();
197
198 let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
199 Arc::new(ArcHandlerWrapper(chain.build()))
200 } else {
201 Arc::new(WorkerHandler(worker))
202 };
203
204 let metrics_collector_clone = self.metrics_collector.clone();
205 let least_loaded_balancer = self.least_loaded_balancer.clone();
206 let cancellation_token = self.cancellation_token.child_token();
207 let task_completion_notify = self.task_completion_notify.clone();
208 let in_flight_tasks = self.in_flight_tasks.clone();
209
210 in_flight_tasks.fetch_add(1, Ordering::SeqCst);
212
213 let ack_handle = message.ack_handle.clone();
215 let message_id = message.message.id.clone();
216 let attempt = message.message.metadata.attempt;
217 let retry_message = message.message.clone();
219
220 tokio::spawn(async move {
221 let result = tokio::select! {
223 result = handler.handle(message) => result, _ = cancellation_token.cancelled() => {
225 tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
226 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
228 task_completion_notify.notify_one();
229 return;
230 }
231 };
232
233 match result {
234 Ok(middleware_result) => {
235 match middleware_result {
236 MiddlewareResult::Acknowledged => {
237 tracing::debug!(
239 "Message {} already acknowledged by middleware",
240 message_id
241 );
242 metrics_collector_clone.record_message_processed(
243 &worker_id,
244 &queue_name,
245 start_time,
246 );
247 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
249 task_completion_notify.notify_one();
250 return;
251 }
252 MiddlewareResult::Continue => {
253 tracing::debug!("Message {} processed successfully", message_id);
255 metrics_collector_clone.record_message_processed(
256 &worker_id,
257 &queue_name,
258 start_time,
259 );
260 if let Err(e) = retry_ack(&ack_handle, &message_id).await {
261 tracing::error!(
262 "Failed to ack message {} after retries: {}. Message may be redelivered.",
263 message_id,
264 e
265 );
266 }
267 }
268 }
269 }
270 Err(WorkerError::RetryableFailure { source, delay_ms }) => {
271 tracing::warn!(
272 "Message {} failed (will retry in {:?}): {}",
273 message_id,
274 delay_ms,
275 source
276 );
277 metrics_collector_clone.record_message_retried(
278 &worker_id,
279 &queue_name,
280 attempt,
281 );
282 metrics_collector_clone.record_message_failed(
283 &worker_id,
284 &queue_name,
285 "RetryableFailure",
286 start_time,
287 );
288
289 if let Err(e) = ack_handle
293 .retry_with_delay(&retry_message, delay_ms.as_millis() as u64)
294 .await
295 {
296 tracing::error!(
297 "Failed to schedule retry for message {}: {}",
298 message_id,
299 e
300 );
301 if let Err(nack_err) = ack_handle.nack(true).await {
303 tracing::error!(
304 "Fallback nack also failed for message {}: {}",
305 message_id,
306 nack_err
307 );
308 }
309 }
310 }
311 Err(WorkerError::RetriesExhausted { source }) => {
312 tracing::error!(
313 "Message {} exhausted all retries, sending to DLQ: {}",
314 message_id,
315 source
316 );
317 metrics_collector_clone
318 .record_message_retries_exhausted(&worker_id, &queue_name);
319 metrics_collector_clone.record_message_failed(
320 &worker_id,
321 &queue_name,
322 "RetriesExhausted",
323 start_time,
324 );
325 if let Err(e) = ack_handle
327 .send_to_dlq(&retry_message, &source.to_string())
328 .await
329 {
330 tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
331 if let Err(nack_err) = ack_handle.nack(false).await {
333 tracing::error!(
334 "Fallback nack also failed for message {}: {}",
335 message_id,
336 nack_err
337 );
338 }
339 }
340 }
341 Err(e) => {
342 let error_type = format!("{:?}", e);
343 tracing::error!("Message {} failed: {}", message_id, e);
344 metrics_collector_clone.record_message_failed(
345 &worker_id,
346 &queue_name,
347 &error_type,
348 start_time,
349 );
350 if let Err(nack_err) = ack_handle.nack(false).await {
351 tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
352 }
353 }
354 }
355
356 if let Some(ref balancer) = least_loaded_balancer {
358 balancer.decrement_load(worker_index);
359 }
360
361 drop(permit);
362
363 in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
365 task_completion_notify.notify_one();
366 });
367
368 Ok(())
369 }
370
371 fn select_worker(&self) -> usize {
373 match self.strategy {
374 LoadBalancingStrategy::RoundRobin => self.round_robin_balancer.next(self.workers.len()),
375 LoadBalancingStrategy::Random => self.random_balancer.next(self.workers.len()),
376 LoadBalancingStrategy::LeastLoaded => {
377 if let Some(ref balancer) = self.least_loaded_balancer {
378 balancer.next()
379 } else {
380 0 }
382 }
383 }
384 }
385
386 pub async fn shutdown(&self) -> WorkerResult<()> {
391 tracing::info!("Shutting down worker pool: {}", self.name);
392
393 self.is_running.store(false, Ordering::SeqCst);
394 self.metrics_collector.record_active_workers(0);
395
396 self.cancellation_token.cancel();
398 tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
399
400 self.semaphore.close();
402
403 let shutdown_timeout = Duration::from_secs(30); let start = Instant::now();
406
407 loop {
408 let available = self.semaphore.available_permits();
409 let in_flight = self.concurrency_limit.saturating_sub(available);
410
411 if in_flight == 0 {
412 break; }
414
415 if start.elapsed() >= shutdown_timeout {
416 tracing::warn!(
417 "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
418 self.name,
419 in_flight
420 );
421 break;
422 }
423
424 tokio::select! {
426 _ = self.task_completion_notify.notified() => {
427 continue;
429 }
430 _ = tokio::time::sleep(Duration::from_millis(100)) => {
431 continue;
433 }
434 }
435 }
436
437 self.metrics_collector.record_in_flight_messages(0);
438 tracing::info!("Worker pool {} shutdown complete", self.name);
439 Ok(())
440 }
441
442 pub fn in_flight_count(&self) -> usize {
444 self.in_flight_tasks.load(Ordering::SeqCst)
445 }
446}
447
448impl HealthCheck for WorkerPool {
449 fn check_health(&self) -> HealthStatus {
450 let is_running = self.is_running.load(Ordering::SeqCst);
451 let worker_count = self.worker_count();
452 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
453
454 if !is_running {
456 return HealthStatus::Unhealthy {
457 reason: "Pool is not running".to_string(),
458 };
459 }
460
461 if worker_count == 0 {
463 return HealthStatus::Degraded {
464 reason: "No workers available".to_string(),
465 };
466 }
467
468 let saturation = in_flight as f64 / self.concurrency_limit as f64;
470 if saturation > 0.9 {
471 return HealthStatus::Degraded {
472 reason: format!(
473 "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
474 in_flight,
475 saturation * 100.0
476 ),
477 };
478 }
479
480 HealthStatus::Healthy
481 }
482
483 fn status_message(&self) -> String {
484 let worker_count = self.worker_count();
485 let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
486 let available_permits = self.semaphore.available_permits();
487
488 match self.check_health() {
489 HealthStatus::Healthy => {
490 format!(
491 "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
492 self.name, worker_count, in_flight, available_permits
493 )
494 }
495 HealthStatus::Degraded { ref reason } => {
496 format!(
497 "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
498 self.name, reason, worker_count, in_flight
499 )
500 }
501 HealthStatus::Unhealthy { ref reason } => {
502 format!(
503 "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
504 self.name, reason, worker_count, in_flight
505 )
506 }
507 }
508 }
509}
510
511struct WorkerHandler(Arc<dyn Worker>);
513
514#[async_trait::async_trait]
515impl MessageHandler for WorkerHandler {
516 async fn handle(
517 &self,
518 message: ReceivedMessage<serde_json::Value>,
519 ) -> Result<MiddlewareResult, WorkerError> {
520 self.0.process(message).await?;
522 Ok(MiddlewareResult::Continue)
523 }
524}
525
526struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
528
529#[async_trait::async_trait]
530impl Middleware for ArcMiddlewareWrapper {
531 fn name(&self) -> &str {
532 self.0.name()
533 }
534
535 async fn handle(
536 &self,
537 message: ReceivedMessage<serde_json::Value>,
538 next: Box<dyn MessageHandler>,
539 ) -> Result<MiddlewareResult, WorkerError> {
540 self.0.handle(message, next).await
541 }
542}
543
544struct ArcHandlerWrapper(Box<dyn MessageHandler>);
546
547#[async_trait::async_trait]
548impl MessageHandler for ArcHandlerWrapper {
549 async fn handle(
550 &self,
551 message: ReceivedMessage<serde_json::Value>,
552 ) -> Result<MiddlewareResult, WorkerError> {
553 self.0.handle(message).await
554 }
555}
556
557async fn retry_ack(
561 ack_handle: &Arc<dyn AckHandle>,
562 message_id: &str,
563) -> WorkerResult<()> {
564 let max_retries = 3;
565 let base_delay_ms = 100;
566
567 for attempt in 0..max_retries {
568 match ack_handle.ack().await {
569 Ok(_) => return Ok(()),
570 Err(e) => {
571 if attempt < max_retries - 1 {
572 let delay = Duration::from_millis(base_delay_ms * (2u64.pow(attempt as u32)));
573 tracing::warn!(
574 "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
575 attempt + 1,
576 message_id,
577 e,
578 delay
579 );
580 sleep(delay).await;
581 } else {
582 return Err(e);
583 }
584 }
585 }
586 }
587
588 unreachable!()
589}
590
591#[cfg(test)]
592mod tests {
593 use super::*;
594 use crate::message::{AckHandle, Message, MessageMetadata};
595 use crate::metrics::NoOpMetrics;
596 use async_trait::async_trait;
597 use std::sync::atomic::{AtomicUsize, Ordering};
598 use std::time::Duration;
599 #[derive(Debug)]
602 struct MockAckHandle;
603
604 #[async_trait]
605 impl AckHandle for MockAckHandle {
606 async fn ack(&self) -> WorkerResult<()> {
607 Ok(())
608 }
609
610 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
611 Ok(())
612 }
613 }
614
615 struct TestWorker {
616 id: String,
617 process_count: Arc<AtomicUsize>,
618 }
619
620 impl TestWorker {
621 fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
622 let count = Arc::new(AtomicUsize::new(0));
623 (
624 Self {
625 id: id.to_string(),
626 process_count: count.clone(),
627 },
628 count,
629 )
630 }
631 }
632
633 #[async_trait]
634 impl Worker for TestWorker {
635 fn id(&self) -> &str {
636 &self.id
637 }
638
639 async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
640 self.process_count.fetch_add(1, Ordering::SeqCst);
641 Ok(())
642 }
643 }
644
645 fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
646 let message = Message {
647 id: id.to_string(),
648 payload: serde_json::json!({"test": "data"}),
649 metadata: MessageMetadata::new("test-queue"),
650 };
651 ReceivedMessage::new(message, Arc::new(MockAckHandle))
652 }
653
654 #[tokio::test]
655 async fn test_pool_creation() {
656 let pool = WorkerPool::new(
657 "test-pool",
658 LoadBalancingStrategy::RoundRobin,
659 Arc::new(NoOpMetrics),
660 );
661 assert_eq!(pool.name(), "test-pool");
662 assert_eq!(pool.worker_count(), 0);
663 assert!(pool.is_running.load(Ordering::SeqCst));
664 }
665
666 #[tokio::test]
667 async fn test_add_worker() {
668 let mut pool = WorkerPool::new(
669 "test-pool",
670 LoadBalancingStrategy::RoundRobin,
671 Arc::new(NoOpMetrics),
672 );
673 let (worker, _) = TestWorker::new("worker-1");
674 pool.add_worker(Arc::new(worker));
675
676 assert_eq!(pool.worker_count(), 1);
677 }
678
679 #[tokio::test]
680 async fn test_dispatch_empty_pool() {
681 let pool = WorkerPool::new(
682 "test-pool",
683 LoadBalancingStrategy::RoundRobin,
684 Arc::new(NoOpMetrics),
685 );
686 let message = create_test_message("msg-1");
687
688 let result = pool.dispatch(message).await;
689 assert!(matches!(result, Err(WorkerError::PoolExhausted)));
690 }
691
692 #[tokio::test]
693 async fn test_round_robin_distribution() {
694 let mut pool = WorkerPool::new(
695 "test-pool",
696 LoadBalancingStrategy::RoundRobin,
697 Arc::new(NoOpMetrics),
698 );
699
700 let (worker1, count1) = TestWorker::new("worker-1");
701 let (worker2, count2) = TestWorker::new("worker-2");
702
703 pool.add_worker(Arc::new(worker1));
704 pool.add_worker(Arc::new(worker2));
705
706 for i in 0..4 {
708 let message = create_test_message(&format!("msg-{}", i));
709 pool.dispatch(message).await.unwrap();
710 }
711
712 tokio::time::sleep(Duration::from_millis(100)).await;
714
715 assert_eq!(count1.load(Ordering::SeqCst), 2);
717 assert_eq!(count2.load(Ordering::SeqCst), 2);
718 }
719
720 #[tokio::test]
721 async fn test_pool_health() {
722 let pool = WorkerPool::new(
723 "test-pool",
724 LoadBalancingStrategy::RoundRobin,
725 Arc::new(NoOpMetrics),
726 );
727 assert!(matches!(pool.check_health(), HealthStatus::Degraded { .. })); let mut pool = pool;
730 let (worker, _) = TestWorker::new("worker-1");
731 pool.add_worker(Arc::new(worker));
732
733 assert!(matches!(pool.check_health(), HealthStatus::Healthy));
734 }
735
736 #[tokio::test]
737 async fn test_concurrency_limit_enforcement() {
738 use std::sync::atomic::{AtomicUsize, Ordering};
739
740 let concurrent_count = Arc::new(AtomicUsize::new(0));
742 let max_concurrent = Arc::new(AtomicUsize::new(0));
743
744 struct ConcurrentTestWorker {
745 id: String,
746 concurrent: Arc<AtomicUsize>,
747 max_concurrent: Arc<AtomicUsize>,
748 }
749
750 #[async_trait::async_trait]
751 impl Worker for ConcurrentTestWorker {
752 fn id(&self) -> &str {
753 &self.id
754 }
755
756 async fn process(
757 &self,
758 _message: ReceivedMessage<serde_json::Value>,
759 ) -> WorkerResult<()> {
760 let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
762
763 let mut max = self.max_concurrent.load(Ordering::SeqCst);
765 while current > max {
766 match self.max_concurrent.compare_exchange_weak(
767 max,
768 current,
769 Ordering::SeqCst,
770 Ordering::SeqCst,
771 ) {
772 Ok(_) => break,
773 Err(new_max) => max = new_max,
774 }
775 }
776
777 tokio::time::sleep(Duration::from_millis(50)).await;
779
780 self.concurrent.fetch_sub(1, Ordering::SeqCst);
782 Ok(())
783 }
784 }
785
786 let mut pool = WorkerPool::with_concurrency(
788 "test-pool",
789 LoadBalancingStrategy::RoundRobin,
790 3, Arc::new(NoOpMetrics),
792 );
793
794 let worker = ConcurrentTestWorker {
796 id: "worker-1".to_string(),
797 concurrent: concurrent_count.clone(),
798 max_concurrent: max_concurrent.clone(),
799 };
800 pool.add_worker(Arc::new(worker));
801
802 for i in 0..10 {
804 let message = create_test_message(&format!("msg-{}", i));
805 pool.dispatch(message).await.unwrap();
806 }
807
808 tokio::time::sleep(Duration::from_millis(500)).await;
810
811 let actual_max = max_concurrent.load(Ordering::SeqCst);
813 assert!(
814 actual_max <= 3,
815 "Expected max concurrency <= 3, but got {}",
816 actual_max
817 );
818 assert!(
819 actual_max >= 2,
820 "Expected some concurrency (>= 2), but got {}",
821 actual_max
822 );
823 }
824
825 #[tokio::test]
826 async fn test_concurrency_limit_with_builder() {
827 use crate::builder::WorkerPoolBuilder;
828
829 let concurrent_count = Arc::new(AtomicUsize::new(0));
830 let max_concurrent = Arc::new(AtomicUsize::new(0));
831
832 struct TrackedWorker {
833 id: String,
834 concurrent: Arc<AtomicUsize>,
835 max_concurrent: Arc<AtomicUsize>,
836 }
837
838 #[async_trait::async_trait]
839 impl Worker for TrackedWorker {
840 fn id(&self) -> &str {
841 &self.id
842 }
843
844 async fn process(
845 &self,
846 _message: ReceivedMessage<serde_json::Value>,
847 ) -> WorkerResult<()> {
848 let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
850
851 let mut max = self.max_concurrent.load(Ordering::SeqCst);
852 while current > max {
853 match self.max_concurrent.compare_exchange_weak(
854 max,
855 current,
856 Ordering::SeqCst,
857 Ordering::SeqCst,
858 ) {
859 Ok(_) => break,
860 Err(new_max) => max = new_max,
861 }
862 }
863
864 tokio::time::sleep(Duration::from_millis(100)).await;
865
866 self.concurrent.fetch_sub(1, Ordering::SeqCst);
867 Ok(())
868 }
869 }
870
871 let pool = WorkerPoolBuilder::new("test-pool")
873 .with_concurrency_limit(2)
874 .add_worker(TrackedWorker {
875 id: "worker-1".to_string(),
876 concurrent: concurrent_count.clone(),
877 max_concurrent: max_concurrent.clone(),
878 })
879 .build()
880 .unwrap();
881
882 for i in 0..6 {
884 let message = create_test_message(&format!("msg-{}", i));
885 pool.dispatch(message).await.unwrap();
886 }
887
888 tokio::time::sleep(Duration::from_millis(800)).await;
890
891 let actual_max = max_concurrent.load(Ordering::SeqCst);
893 assert!(
894 actual_max <= 2,
895 "Expected max concurrency <= 2, but got {}",
896 actual_max
897 );
898 assert!(
899 actual_max >= 1,
900 "Expected some concurrency (>= 1), but got {}",
901 actual_max
902 );
903 }
904
905 #[tokio::test]
906 async fn test_different_concurrency_limits() {
907 let pool1 = WorkerPool::with_concurrency(
909 "pool1",
910 LoadBalancingStrategy::RoundRobin,
911 5,
912 Arc::new(NoOpMetrics),
913 );
914 let pool2 = WorkerPool::with_concurrency(
915 "pool2",
916 LoadBalancingStrategy::RoundRobin,
917 20,
918 Arc::new(NoOpMetrics),
919 );
920
921 assert_eq!(pool1.name(), "pool1");
924 assert_eq!(pool2.name(), "pool2");
925 }
926
927 #[tokio::test]
928 async fn test_pool_shutdown_prevents_dispatch() {
929 let mut pool = WorkerPool::new(
930 "test-pool",
931 LoadBalancingStrategy::RoundRobin,
932 Arc::new(NoOpMetrics),
933 );
934 let (worker, _) = TestWorker::new("worker-1");
935 pool.add_worker(Arc::new(worker));
936
937 pool.shutdown().await.unwrap();
938
939 let message = create_test_message("msg-after-shutdown");
940 let result = pool.dispatch(message).await;
941 assert!(matches!(result, Err(WorkerError::Shutdown)));
942 assert!(matches!(
943 pool.check_health(),
944 HealthStatus::Unhealthy { .. }
945 ));
946 }
947
948 #[tokio::test]
951 async fn test_no_double_ack_with_middleware() {
952 use crate::AckNackMiddleware;
953 use std::sync::atomic::AtomicUsize;
954
955 let ack_count = Arc::new(AtomicUsize::new(0));
957 let nack_count = Arc::new(AtomicUsize::new(0));
958
959 #[derive(Debug)]
960 struct TrackingAckHandle {
961 ack_count: Arc<AtomicUsize>,
962 nack_count: Arc<AtomicUsize>,
963 }
964
965 #[async_trait]
966 impl AckHandle for TrackingAckHandle {
967 async fn ack(&self) -> WorkerResult<()> {
968 self.ack_count.fetch_add(1, Ordering::SeqCst);
969 Ok(())
970 }
971
972 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
973 self.nack_count.fetch_add(1, Ordering::SeqCst);
974 Ok(())
975 }
976 }
977
978 struct SuccessWorker;
980
981 #[async_trait]
982 impl Worker for SuccessWorker {
983 fn id(&self) -> &str {
984 "success-worker"
985 }
986
987 async fn process(
988 &self,
989 _message: ReceivedMessage<serde_json::Value>,
990 ) -> WorkerResult<()> {
991 Ok(())
992 }
993 }
994
995 let mut pool = WorkerPool::new(
996 "test-pool",
997 LoadBalancingStrategy::RoundRobin,
998 Arc::new(NoOpMetrics),
999 );
1000
1001 pool.middlewares
1003 .push(Arc::new(AckNackMiddleware::default()));
1004 pool.add_worker(Arc::new(SuccessWorker));
1005
1006 let message = Message {
1008 id: "test-msg".to_string(),
1009 payload: serde_json::json!({"test": "data"}),
1010 metadata: MessageMetadata::new("test-queue"),
1011 };
1012 let received = ReceivedMessage::new(
1013 message,
1014 Arc::new(TrackingAckHandle {
1015 ack_count: ack_count.clone(),
1016 nack_count: nack_count.clone(),
1017 }),
1018 );
1019
1020 pool.dispatch(received).await.unwrap();
1022
1023 tokio::time::sleep(Duration::from_millis(100)).await;
1025
1026 assert_eq!(
1028 ack_count.load(Ordering::SeqCst),
1029 1,
1030 "Message should have been acked exactly once by middleware"
1031 );
1032 assert_eq!(
1033 nack_count.load(Ordering::SeqCst),
1034 0,
1035 "Message should not have been nacked"
1036 );
1037 }
1038
1039 #[tokio::test]
1041 async fn test_pool_handles_ack_without_middleware() {
1042 use std::sync::atomic::AtomicUsize;
1043
1044 let ack_count = Arc::new(AtomicUsize::new(0));
1045 let nack_count = Arc::new(AtomicUsize::new(0));
1046
1047 #[derive(Debug)]
1048 struct CountingAckHandle {
1049 ack_count: Arc<AtomicUsize>,
1050 nack_count: Arc<AtomicUsize>,
1051 }
1052
1053 #[async_trait]
1054 impl AckHandle for CountingAckHandle {
1055 async fn ack(&self) -> WorkerResult<()> {
1056 self.ack_count.fetch_add(1, Ordering::SeqCst);
1057 Ok(())
1058 }
1059
1060 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1061 self.nack_count.fetch_add(1, Ordering::SeqCst);
1062 Ok(())
1063 }
1064 }
1065
1066 struct SuccessWorker;
1068
1069 #[async_trait]
1070 impl Worker for SuccessWorker {
1071 fn id(&self) -> &str {
1072 "success-worker"
1073 }
1074
1075 async fn process(
1076 &self,
1077 _message: ReceivedMessage<serde_json::Value>,
1078 ) -> WorkerResult<()> {
1079 Ok(())
1080 }
1081 }
1082
1083 let mut pool = WorkerPool::new(
1084 "test-pool",
1085 LoadBalancingStrategy::RoundRobin,
1086 Arc::new(NoOpMetrics),
1087 );
1088 pool.add_worker(Arc::new(SuccessWorker));
1090
1091 let message = Message {
1092 id: "test-msg".to_string(),
1093 payload: serde_json::json!({"test": "data"}),
1094 metadata: MessageMetadata::new("test-queue"),
1095 };
1096 let received = ReceivedMessage::new(
1097 message,
1098 Arc::new(CountingAckHandle {
1099 ack_count: ack_count.clone(),
1100 nack_count: nack_count.clone(),
1101 }),
1102 );
1103
1104 pool.dispatch(received).await.unwrap();
1105
1106 tokio::time::sleep(Duration::from_millis(100)).await;
1108
1109 assert_eq!(
1111 ack_count.load(Ordering::SeqCst),
1112 1,
1113 "Pool should have acked the message when no middleware is used"
1114 );
1115 assert_eq!(
1116 nack_count.load(Ordering::SeqCst),
1117 0,
1118 "Nack should not be called for successful message"
1119 );
1120 }
1121
1122 #[tokio::test]
1124 async fn test_pool_handles_nack_without_middleware() {
1125 use std::sync::atomic::AtomicUsize;
1126
1127 let ack_count = Arc::new(AtomicUsize::new(0));
1128 let nack_count = Arc::new(AtomicUsize::new(0));
1129
1130 #[derive(Debug)]
1131 struct CountingAckHandle {
1132 ack_count: Arc<AtomicUsize>,
1133 nack_count: Arc<AtomicUsize>,
1134 }
1135
1136 #[async_trait]
1137 impl AckHandle for CountingAckHandle {
1138 async fn ack(&self) -> WorkerResult<()> {
1139 self.ack_count.fetch_add(1, Ordering::SeqCst);
1140 Ok(())
1141 }
1142
1143 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1144 self.nack_count.fetch_add(1, Ordering::SeqCst);
1145 Ok(())
1146 }
1147 }
1148
1149 struct FailingWorker;
1151
1152 #[async_trait]
1153 impl Worker for FailingWorker {
1154 fn id(&self) -> &str {
1155 "failing-worker"
1156 }
1157
1158 async fn process(
1159 &self,
1160 _message: ReceivedMessage<serde_json::Value>,
1161 ) -> WorkerResult<()> {
1162 Err(WorkerError::ProcessingFailed(
1163 "Simulated failure".to_string(),
1164 ))
1165 }
1166 }
1167
1168 let mut pool = WorkerPool::new(
1169 "test-pool",
1170 LoadBalancingStrategy::RoundRobin,
1171 Arc::new(NoOpMetrics),
1172 );
1173 pool.add_worker(Arc::new(FailingWorker));
1175
1176 let message = Message {
1177 id: "test-msg".to_string(),
1178 payload: serde_json::json!({"test": "data"}),
1179 metadata: MessageMetadata::new("test-queue"),
1180 };
1181 let received = ReceivedMessage::new(
1182 message,
1183 Arc::new(CountingAckHandle {
1184 ack_count: ack_count.clone(),
1185 nack_count: nack_count.clone(),
1186 }),
1187 );
1188
1189 pool.dispatch(received).await.unwrap();
1190
1191 tokio::time::sleep(Duration::from_millis(100)).await;
1193
1194 assert_eq!(
1196 nack_count.load(Ordering::SeqCst),
1197 1,
1198 "Pool should have nacked the message when no middleware is used"
1199 );
1200 assert_eq!(
1201 ack_count.load(Ordering::SeqCst),
1202 0,
1203 "Ack should not be called for failed message"
1204 );
1205 }
1206
1207 #[tokio::test]
1209 async fn test_integration_ack_nack_middleware_with_failure() {
1210 use crate::AckNackMiddleware;
1211 use std::sync::atomic::AtomicUsize;
1212
1213 let ack_count = Arc::new(AtomicUsize::new(0));
1214 let nack_count = Arc::new(AtomicUsize::new(0));
1215
1216 #[derive(Debug)]
1217 struct CountingAckHandle {
1218 ack_count: Arc<AtomicUsize>,
1219 nack_count: Arc<AtomicUsize>,
1220 }
1221
1222 #[async_trait]
1223 impl AckHandle for CountingAckHandle {
1224 async fn ack(&self) -> WorkerResult<()> {
1225 self.ack_count.fetch_add(1, Ordering::SeqCst);
1226 Ok(())
1227 }
1228
1229 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
1230 self.nack_count.fetch_add(1, Ordering::SeqCst);
1231 Ok(())
1232 }
1233 }
1234
1235 struct FailingWorker;
1237
1238 #[async_trait]
1239 impl Worker for FailingWorker {
1240 fn id(&self) -> &str {
1241 "failing-worker"
1242 }
1243
1244 async fn process(
1245 &self,
1246 _message: ReceivedMessage<serde_json::Value>,
1247 ) -> WorkerResult<()> {
1248 Err(WorkerError::ProcessingFailed(
1249 "Simulated failure".to_string(),
1250 ))
1251 }
1252 }
1253
1254 let mut pool = WorkerPool::new(
1255 "test-pool",
1256 LoadBalancingStrategy::RoundRobin,
1257 Arc::new(NoOpMetrics),
1258 );
1259
1260 pool.middlewares
1262 .push(Arc::new(AckNackMiddleware::default()));
1263 pool.add_worker(Arc::new(FailingWorker));
1264
1265 let message = Message {
1266 id: "test-msg".to_string(),
1267 payload: serde_json::json!({"test": "data"}),
1268 metadata: MessageMetadata::new("test-queue"),
1269 };
1270 let received = ReceivedMessage::new(
1271 message,
1272 Arc::new(CountingAckHandle {
1273 ack_count: ack_count.clone(),
1274 nack_count: nack_count.clone(),
1275 }),
1276 );
1277
1278 pool.dispatch(received).await.unwrap();
1279
1280 tokio::time::sleep(Duration::from_millis(100)).await;
1282
1283 assert_eq!(
1285 nack_count.load(Ordering::SeqCst),
1286 1,
1287 "Nack should be called exactly once by middleware"
1288 );
1289 assert_eq!(
1290 ack_count.load(Ordering::SeqCst),
1291 0,
1292 "Ack should not be called for failed message"
1293 );
1294 }
1295}