use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, Semaphore};
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use crate::error::{WorkerError, WorkerResult};
use crate::health::{HealthCheck, HealthStatus};
use crate::message::ReceivedMessage;
use crate::middleware::{MessageHandler, Middleware, MiddlewareChain};
use crate::metrics::WorkerMetrics;
use crate::strategies::{LoadBalancingStrategy, LeastLoadedBalancer, RandomBalancer, RoundRobinBalancer};
use crate::worker::Worker;
/// A named pool of workers that dispatches messages onto spawned tokio tasks.
///
/// Each dispatch picks a worker according to `strategy`, takes one permit from
/// `semaphore`, and runs the worker (optionally wrapped in the configured
/// middlewares) on its own task.
pub struct WorkerPool {
/// Pool name; appears in log lines, `Debug` output and health messages.
name: String,
/// Workers that messages are dispatched to, indexed by the selected worker index.
workers: Vec<Arc<dyn Worker>>,
/// Decides which balancer `select_worker` consults.
strategy: LoadBalancingStrategy,
/// Created with `concurrency_limit` permits; one permit is held per running task
/// and the semaphore is closed by `shutdown`.
semaphore: Arc<Semaphore>,
/// Number of permits `semaphore` was created with; used to derive in-flight
/// counts during shutdown and saturation in the health check.
concurrency_limit: usize,
/// `Some` only when the pool was built with `LoadBalancingStrategy::LeastLoaded`.
least_loaded_balancer: Option<Arc<LeastLoadedBalancer>>,
/// Balancer used for `LoadBalancingStrategy::RoundRobin`.
round_robin_balancer: Arc<RoundRobinBalancer>,
/// Balancer used for `LoadBalancingStrategy::Random`.
random_balancer: RandomBalancer,
/// Middlewares wrapped around the worker at dispatch time when non-empty.
middlewares: Vec<Arc<dyn Middleware>>,
/// Sink for all pool metrics (received/processed/failed messages, gauges).
metrics_collector: Arc<dyn WorkerMetrics>,
/// `true` from construction until `shutdown` stores `false`; `dispatch` rejects
/// messages once it is `false`.
is_running: Arc<AtomicBool>,
/// Cancelled by `shutdown`; every spawned task races its handler against a child token.
cancellation_token: CancellationToken,
/// Signalled by each task when it finishes; `shutdown` waits on it.
task_completion_notify: Arc<Notify>,
/// Number of spawned tasks that have not finished yet.
in_flight_tasks: Arc<AtomicUsize>,
}
/// Compact debug view: name, worker count, strategy and running flag only.
/// The remaining fields (trait objects, semaphore, token) are intentionally left out.
impl std::fmt::Debug for WorkerPool {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let worker_count = self.workers.len();
        let running = self.is_running.load(Ordering::SeqCst);

        let mut repr = formatter.debug_struct("WorkerPool");
        repr.field("name", &self.name);
        repr.field("worker_count", &worker_count);
        repr.field("strategy", &self.strategy);
        repr.field("is_running", &running);
        repr.finish()
    }
}
impl WorkerPool {
/// Creates an empty pool with the default concurrency limit of 1000.
///
/// Equivalent to calling [`WorkerPool::with_concurrency`] with a limit of 1000.
pub fn new(
    name: impl Into<String>,
    strategy: LoadBalancingStrategy,
    metrics_collector: Arc<dyn WorkerMetrics>,
) -> Self {
    const DEFAULT_CONCURRENCY_LIMIT: usize = 1000;

    Self::with_concurrency(name, strategy, DEFAULT_CONCURRENCY_LIMIT, metrics_collector)
}
/// Creates an empty, running pool whose semaphore holds `concurrency_limit` permits.
///
/// A `LeastLoadedBalancer` (starting with zero workers) is only allocated when
/// `strategy` is `LoadBalancingStrategy::LeastLoaded`; the round-robin and random
/// balancers are always created.
pub fn with_concurrency(
    name: impl Into<String>,
    strategy: LoadBalancingStrategy,
    concurrency_limit: usize,
    metrics_collector: Arc<dyn WorkerMetrics>,
) -> Self {
    let uses_least_loaded = matches!(strategy, LoadBalancingStrategy::LeastLoaded);
    let least_loaded_balancer =
        uses_least_loaded.then(|| Arc::new(LeastLoadedBalancer::new(0)));

    Self {
        name: name.into(),
        workers: Vec::new(),
        strategy,
        semaphore: Arc::new(Semaphore::new(concurrency_limit)),
        concurrency_limit,
        least_loaded_balancer,
        round_robin_balancer: Arc::new(RoundRobinBalancer::new()),
        random_balancer: RandomBalancer,
        middlewares: Vec::new(),
        metrics_collector,
        is_running: Arc::new(AtomicBool::new(true)),
        cancellation_token: CancellationToken::new(),
        task_completion_notify: Arc::new(Notify::new()),
        in_flight_tasks: Arc::new(AtomicUsize::new(0)),
    }
}
/// Appends `worker` to the pool, registers it with the least-loaded balancer
/// (when one exists) and reports the new worker total to the metrics collector.
pub fn add_worker(&mut self, worker: Arc<dyn Worker>) {
    self.workers.push(worker);

    if let Some(balancer) = self.least_loaded_balancer.as_ref() {
        balancer.add_worker();
    }

    let total_workers = self.workers.len();
    self.metrics_collector.record_active_workers(total_workers);
}
/// Adds every worker in `workers`, in order, via [`WorkerPool::add_worker`].
pub fn add_workers(&mut self, workers: Vec<Arc<dyn Worker>>) {
    workers
        .into_iter()
        .for_each(|worker| self.add_worker(worker));
}
/// Returns the number of workers currently registered in the pool.
pub fn worker_count(&self) -> usize {
self.workers.len()
}
/// Builder-style setter: replaces the pool's middleware list with `middlewares`
/// and returns the pool. Any previously configured middlewares are discarded.
pub fn with_middlewares(self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
    let mut pool = self;
    pool.middlewares = middlewares;
    pool
}
/// Returns the pool's name as given at construction.
pub fn name(&self) -> &str {
&self.name
}
pub async fn dispatch(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
if !self.is_running.load(Ordering::SeqCst) {
return Err(WorkerError::Shutdown);
}
if self.workers.is_empty() {
return Err(WorkerError::PoolExhausted);
}
let worker_index = self.select_worker();
let worker = self.workers[worker_index].clone();
let worker_id = worker.id().to_string();
let queue_name = message.message.metadata.source.clone();
self.metrics_collector.record_message_received(&worker_id, &queue_name);
let start_time = Instant::now();
if let Some(ref balancer) = self.least_loaded_balancer {
balancer.increment_load(worker_index);
}
let permit = self.semaphore.clone().acquire_owned().await
.map_err(|_| WorkerError::Shutdown)?;
self.metrics_collector.record_in_flight_messages(self.semaphore.available_permits());
let handler: Arc<dyn MessageHandler> = if !self.middlewares.is_empty() {
let worker_handler = WorkerHandler(worker);
let boxed_middlewares: Vec<Box<dyn Middleware>> = self.middlewares.iter()
.map(|m| Box::new(ArcMiddlewareWrapper(m.clone())) as Box<dyn Middleware>)
.collect();
let chain = MiddlewareChain::new(boxed_middlewares, Box::new(worker_handler));
Arc::new(ArcHandlerWrapper(chain.build()))
} else {
Arc::new(WorkerHandler(worker))
};
let metrics_collector_clone = self.metrics_collector.clone();
let least_loaded_balancer = self.least_loaded_balancer.clone();
let cancellation_token = self.cancellation_token.child_token();
let task_completion_notify = self.task_completion_notify.clone();
let in_flight_tasks = self.in_flight_tasks.clone();
in_flight_tasks.fetch_add(1, Ordering::SeqCst);
let ack_handle = message.ack_handle.clone();
let message_id = message.message.id.clone();
let attempt = message.message.metadata.attempt;
tokio::spawn(async move {
let result = tokio::select! {
result = handler.handle(message) => result, _ = cancellation_token.cancelled() => {
tracing::warn!("Message {} processing cancelled due to shutdown", message_id);
in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
task_completion_notify.notify_one();
return;
}
};
match result {
Ok(_) => {
tracing::debug!("Message {} processed successfully", message_id);
metrics_collector_clone.record_message_processed(&worker_id, &queue_name, start_time);
if let Err(e) = retry_ack(&ack_handle, &message_id).await {
tracing::error!("Failed to ack message {} after retries: {}. Message may be redelivered.", message_id, e);
}
}
Err(WorkerError::RetryableFailure { source, delay_ms }) => {
tracing::warn!(
"Message {} failed (will retry in {:?}): {}",
message_id,
delay_ms,
source
);
metrics_collector_clone.record_message_retried(&worker_id, &queue_name, attempt);
metrics_collector_clone.record_message_failed(&worker_id, &queue_name, "RetryableFailure", start_time);
if let Err(e) = ack_handle.nack(true).await {
tracing::error!("Failed to requeue message {}: {}", message_id, e);
}
sleep(delay_ms).await;
}
Err(WorkerError::RetriesExhausted { source }) => {
tracing::error!(
"Message {} exhausted all retries, sending to DLQ: {}",
message_id,
source
);
metrics_collector_clone.record_message_retries_exhausted(&worker_id, &queue_name);
metrics_collector_clone.record_message_failed(&worker_id, &queue_name, "RetriesExhausted", start_time);
if let Err(e) = ack_handle.nack(false).await {
tracing::error!("Failed to send message {} to DLQ: {}", message_id, e);
}
}
Err(e) => {
if matches!(e, WorkerError::AlreadyAcknowledged) {
in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
task_completion_notify.notify_one();
return;
}
let error_type = format!("{:?}", e);
tracing::error!("Message {} failed: {}", message_id, e);
metrics_collector_clone.record_message_failed(&worker_id, &queue_name, &error_type, start_time);
if let Err(nack_err) = ack_handle.nack(false).await {
tracing::error!("Failed to nack message {}: {}", message_id, nack_err);
}
}
}
if let Some(ref balancer) = least_loaded_balancer {
balancer.decrement_load(worker_index);
}
drop(permit);
in_flight_tasks.fetch_sub(1, Ordering::SeqCst);
task_completion_notify.notify_one();
});
Ok(())
}
fn select_worker(&self) -> usize {
match self.strategy {
LoadBalancingStrategy::RoundRobin => {
self.round_robin_balancer.next(self.workers.len())
}
LoadBalancingStrategy::Random => {
self.random_balancer.next(self.workers.len())
}
LoadBalancingStrategy::LeastLoaded => {
if let Some(ref balancer) = self.least_loaded_balancer {
balancer.next()
} else {
0 }
}
}
}
/// Stops the pool and waits (up to 30 seconds) for in-flight tasks to finish.
///
/// Steps, in order: mark the pool as not running, report zero active workers,
/// cancel the pool's cancellation token (which aborts running tasks), close the
/// semaphore so pending and future dispatches fail, then wait until every
/// permit has been returned or the timeout expires.
///
/// The in-flight gauge is set to the number of tasks still holding a permit
/// when the wait ends: zero after a clean shutdown, non-zero if the timeout
/// was hit. Always returns `Ok(())`, including on timeout.
pub async fn shutdown(&self) -> WorkerResult<()> {
    const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
    const POLL_INTERVAL: Duration = Duration::from_millis(100);

    tracing::info!("Shutting down worker pool: {}", self.name);
    self.is_running.store(false, Ordering::SeqCst);
    self.metrics_collector.record_active_workers(0);
    self.cancellation_token.cancel();
    tracing::info!("Cancelled all in-flight tasks for pool {}", self.name);
    self.semaphore.close();

    let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
    let still_running = loop {
        // Permits not currently available are held by running tasks.
        let available = self.semaphore.available_permits();
        let in_flight = self.concurrency_limit.saturating_sub(available);
        if in_flight == 0 {
            break 0;
        }

        let time_left = deadline.saturating_duration_since(Instant::now());
        if time_left.is_zero() {
            tracing::warn!(
                "Shutdown timeout reached for pool {}. {} tasks still running. Forcing shutdown.",
                self.name, in_flight
            );
            break in_flight;
        }

        // Wake on a task completion, with a periodic re-check as a safety net.
        // The poll is capped by the remaining time so the deadline is honoured.
        tokio::select! {
            _ = self.task_completion_notify.notified() => {}
            _ = tokio::time::sleep(POLL_INTERVAL.min(time_left)) => {}
        }
    };

    // Report what is really left instead of unconditionally claiming zero.
    self.metrics_collector.record_in_flight_messages(still_running);
    tracing::info!("Worker pool {} shutdown complete", self.name);
    Ok(())
}
/// Returns the current value of the in-flight task counter, i.e. the number of
/// spawned message tasks that have not finished yet.
pub fn in_flight_count(&self) -> usize {
self.in_flight_tasks.load(Ordering::SeqCst)
}
}
impl HealthCheck for WorkerPool {
    /// Classifies the pool's health.
    ///
    /// * `Unhealthy` — the pool has been shut down.
    /// * `Degraded` — no workers are registered, or more than 90% of the
    ///   concurrency limit is in flight.
    /// * `Healthy` — otherwise.
    fn check_health(&self) -> HealthStatus {
        if !self.is_running.load(Ordering::SeqCst) {
            return HealthStatus::Unhealthy {
                reason: String::from("Pool is not running"),
            };
        }

        if self.workers.is_empty() {
            return HealthStatus::Degraded {
                reason: String::from("No workers available"),
            };
        }

        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
        let saturation = in_flight as f64 / self.concurrency_limit as f64;
        if saturation > 0.9 {
            let reason = format!(
                "Pool near capacity: {} in-flight messages ({:.0}% saturation)",
                in_flight,
                saturation * 100.0
            );
            return HealthStatus::Degraded { reason };
        }

        HealthStatus::Healthy
    }

    /// Renders a one-line, human-readable summary of the current health status,
    /// including worker and in-flight counts (and free permits when healthy).
    fn status_message(&self) -> String {
        let workers = self.worker_count();
        let in_flight = self.in_flight_tasks.load(Ordering::SeqCst);
        let free_permits = self.semaphore.available_permits();

        let status = self.check_health();
        match &status {
            HealthStatus::Healthy => format!(
                "WorkerPool '{}' is healthy with {} workers. {} in-flight, {} available permits.",
                self.name, workers, in_flight, free_permits
            ),
            HealthStatus::Degraded { reason } => format!(
                "WorkerPool '{}' is degraded: {}. {} workers, {} in-flight.",
                self.name, reason, workers, in_flight
            ),
            HealthStatus::Unhealthy { reason } => format!(
                "WorkerPool '{}' is unhealthy: {}. {} workers, {} in-flight.",
                self.name, reason, workers, in_flight
            ),
        }
    }
}
/// Adapter that exposes a shared `Worker` as a `MessageHandler`, so it can be
/// used directly or as the innermost handler of a middleware chain.
struct WorkerHandler(Arc<dyn Worker>);
#[async_trait::async_trait]
impl MessageHandler for WorkerHandler {
/// Forwards the message to the wrapped worker's `process` and returns its result.
async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
self.0.process(message).await
}
}
/// Adapter that lets a shared `Arc<dyn Middleware>` be boxed as a
/// `Box<dyn Middleware>`; every call is delegated to the wrapped middleware.
struct ArcMiddlewareWrapper(Arc<dyn Middleware>);
#[async_trait::async_trait]
impl Middleware for ArcMiddlewareWrapper {
/// Returns the wrapped middleware's name.
fn name(&self) -> &str {
self.0.name()
}
/// Delegates to the wrapped middleware, passing `message` and `next` through unchanged.
async fn handle(
&self,
message: ReceivedMessage<serde_json::Value>,
next: Box<dyn MessageHandler>,
) -> WorkerResult<()> {
self.0.handle(message, next).await
}
}
/// Adapter that wraps a boxed `MessageHandler` so it can be stored behind
/// `Arc<dyn MessageHandler>`; `handle` is delegated to the boxed handler.
struct ArcHandlerWrapper(Box<dyn MessageHandler>);
#[async_trait::async_trait]
impl MessageHandler for ArcHandlerWrapper {
/// Forwards the message to the wrapped handler and returns its result.
async fn handle(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
self.0.handle(message).await
}
}
/// Acknowledges a message, retrying on failure with exponential back-off.
///
/// Up to three attempts are made. After the first and second failures the
/// function logs a warning and sleeps 100 ms and 200 ms respectively; the error
/// of the third failed attempt is returned to the caller.
async fn retry_ack(ack_handle: &Arc<dyn crate::message::AckHandle>, message_id: &str) -> WorkerResult<()> {
    const MAX_ATTEMPTS: u32 = 3;
    const BASE_DELAY_MS: u64 = 100;

    let mut attempts_made: u32 = 0;
    loop {
        let error = match ack_handle.ack().await {
            Ok(_) => return Ok(()),
            Err(error) => error,
        };

        attempts_made += 1;
        if attempts_made >= MAX_ATTEMPTS {
            return Err(error);
        }

        // Delay doubles with every failed attempt: 100 ms, then 200 ms.
        let delay = Duration::from_millis(BASE_DELAY_MS * 2u64.pow(attempts_made - 1));
        tracing::warn!(
            "Attempt {} failed to ack message {}: {}. Retrying in {:?}",
            attempts_made,
            message_id,
            error,
            delay
        );
        sleep(delay).await;
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::{Message, MessageMetadata, AckHandle};
use async_trait::async_trait;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use crate::metrics::NoOpMetrics;
/// Test double for `AckHandle` whose `ack` and `nack` always succeed and do nothing.
#[derive(Debug)]
struct MockAckHandle;
#[async_trait]
impl AckHandle for MockAckHandle {
/// Always reports a successful acknowledgement.
async fn ack(&self) -> WorkerResult<()> {
Ok(())
}
/// Always reports a successful negative acknowledgement; `_requeue` is ignored.
async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
Ok(())
}
}
/// Test worker that counts how many messages it has processed.
struct TestWorker {
// Identifier returned from `Worker::id`.
id: String,
// Incremented once per `process` call; shared with the test via `new`.
process_count: Arc<AtomicUsize>,
}
impl TestWorker {
/// Builds a worker with the given id and returns it together with a handle
/// to its process counter, so the test can inspect the count afterwards.
fn new(id: &str) -> (Self, Arc<AtomicUsize>) {
let count = Arc::new(AtomicUsize::new(0));
(
Self {
id: id.to_string(),
process_count: count.clone(),
},
count,
)
}
}
#[async_trait]
impl Worker for TestWorker {
/// Returns the id given at construction.
fn id(&self) -> &str {
&self.id
}
/// Bumps the process counter and succeeds; the message itself is ignored.
async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
self.process_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
}
/// Builds a received message with the given id, a fixed JSON payload, metadata
/// created for the source "test-queue", and a `MockAckHandle` as its ack handle.
fn create_test_message(id: &str) -> ReceivedMessage<serde_json::Value> {
    let payload = serde_json::json!({"test": "data"});
    let metadata = MessageMetadata::new("test-queue");
    let envelope = Message {
        id: id.to_owned(),
        payload,
        metadata,
    };

    ReceivedMessage::new(envelope, Arc::new(MockAckHandle))
}
/// A freshly created pool keeps its name, has no workers and is running.
#[tokio::test]
async fn test_pool_creation() {
    let metrics = Arc::new(NoOpMetrics);
    let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, metrics);

    let running = pool.is_running.load(Ordering::SeqCst);

    assert_eq!(pool.name(), "test-pool");
    assert_eq!(pool.worker_count(), 0);
    assert!(running);
}
/// Adding one worker raises the worker count to one.
#[tokio::test]
async fn test_add_worker() {
    let metrics = Arc::new(NoOpMetrics);
    let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, metrics);

    let (worker, _processed) = TestWorker::new("worker-1");
    pool.add_worker(Arc::new(worker));

    assert_eq!(pool.worker_count(), 1);
}
/// Dispatching to a pool without workers fails with `PoolExhausted`.
#[tokio::test]
async fn test_dispatch_empty_pool() {
    let metrics = Arc::new(NoOpMetrics);
    let pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, metrics);

    let outcome = pool.dispatch(create_test_message("msg-1")).await;

    assert!(matches!(outcome, Err(WorkerError::PoolExhausted)));
}
/// Four messages dispatched round-robin over two workers land two on each.
#[tokio::test]
async fn test_round_robin_distribution() {
    let metrics = Arc::new(NoOpMetrics);
    let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, metrics);

    let (first_worker, first_hits) = TestWorker::new("worker-1");
    let (second_worker, second_hits) = TestWorker::new("worker-2");
    pool.add_worker(Arc::new(first_worker));
    pool.add_worker(Arc::new(second_worker));

    for message_id in (0..4).map(|i| format!("msg-{}", i)) {
        pool.dispatch(create_test_message(&message_id)).await.unwrap();
    }

    // Give the spawned tasks time to run before inspecting the counters.
    tokio::time::sleep(Duration::from_millis(100)).await;

    assert_eq!(first_hits.load(Ordering::SeqCst), 2);
    assert_eq!(second_hits.load(Ordering::SeqCst), 2);
}
/// A pool without workers is degraded; it becomes healthy once a worker is added.
#[tokio::test]
async fn test_pool_health() {
    let metrics = Arc::new(NoOpMetrics);
    let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, metrics);

    let before = pool.check_health();
    assert!(matches!(before, HealthStatus::Degraded { .. }));

    let (worker, _processed) = TestWorker::new("worker-1");
    pool.add_worker(Arc::new(worker));

    let after = pool.check_health();
    assert!(matches!(after, HealthStatus::Healthy));
}
#[tokio::test]
async fn test_concurrency_limit_enforcement() {
use std::sync::atomic::{AtomicUsize, Ordering};
let concurrent_count = Arc::new(AtomicUsize::new(0));
let max_concurrent = Arc::new(AtomicUsize::new(0));
struct ConcurrentTestWorker {
id: String,
concurrent: Arc<AtomicUsize>,
max_concurrent: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl Worker for ConcurrentTestWorker {
fn id(&self) -> &str {
&self.id
}
async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
let mut max = self.max_concurrent.load(Ordering::SeqCst);
while current > max {
match self.max_concurrent.compare_exchange_weak(
max,
current,
Ordering::SeqCst,
Ordering::SeqCst,
) {
Ok(_) => break,
Err(new_max) => max = new_max,
}
}
tokio::time::sleep(Duration::from_millis(50)).await;
self.concurrent.fetch_sub(1, Ordering::SeqCst);
Ok(())
}
}
let mut pool = WorkerPool::with_concurrency(
"test-pool",
LoadBalancingStrategy::RoundRobin,
3, Arc::new(NoOpMetrics),
);
let worker = ConcurrentTestWorker {
id: "worker-1".to_string(),
concurrent: concurrent_count.clone(),
max_concurrent: max_concurrent.clone(),
};
pool.add_worker(Arc::new(worker));
for i in 0..10 {
let message = create_test_message(&format!("msg-{}", i));
pool.dispatch(message).await.unwrap();
}
tokio::time::sleep(Duration::from_millis(500)).await;
let actual_max = max_concurrent.load(Ordering::SeqCst);
assert!(
actual_max <= 3,
"Expected max concurrency <= 3, but got {}",
actual_max
);
assert!(
actual_max >= 2,
"Expected some concurrency (>= 2), but got {}",
actual_max
);
}
/// A pool built through `WorkerPoolBuilder` with a concurrency limit of 2 never
/// processes more than two messages at the same time.
#[tokio::test]
async fn test_concurrency_limit_with_builder() {
    use crate::builder::WorkerPoolBuilder;

    let concurrent_count = Arc::new(AtomicUsize::new(0));
    let max_concurrent = Arc::new(AtomicUsize::new(0));

    /// Worker that tracks how many `process` calls overlap and the peak overlap.
    struct TrackedWorker {
        id: String,
        concurrent: Arc<AtomicUsize>,
        max_concurrent: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl Worker for TrackedWorker {
        fn id(&self) -> &str {
            &self.id
        }

        async fn process(&self, _message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
            let current = self.concurrent.fetch_add(1, Ordering::SeqCst) + 1;
            // Atomically raise the recorded peak to `current` if it is higher.
            self.max_concurrent.fetch_max(current, Ordering::SeqCst);
            // Stay "busy" long enough for other dispatches to overlap.
            tokio::time::sleep(Duration::from_millis(100)).await;
            self.concurrent.fetch_sub(1, Ordering::SeqCst);
            Ok(())
        }
    }

    let pool = WorkerPoolBuilder::new("test-pool")
        .with_concurrency_limit(2)
        .add_worker(TrackedWorker {
            id: "worker-1".to_string(),
            concurrent: concurrent_count.clone(),
            max_concurrent: max_concurrent.clone(),
        })
        .build()
        .unwrap();

    for i in 0..6 {
        let message = create_test_message(&format!("msg-{}", i));
        pool.dispatch(message).await.unwrap();
    }

    // Let the remaining in-flight tasks finish.
    tokio::time::sleep(Duration::from_millis(800)).await;

    let actual_max = max_concurrent.load(Ordering::SeqCst);
    assert!(
        actual_max <= 2,
        "Expected max concurrency <= 2, but got {}",
        actual_max
    );
    assert!(
        actual_max >= 1,
        "Expected some concurrency (>= 1), but got {}",
        actual_max
    );
}
/// Pools created with different limits keep their own name, limit and number
/// of semaphore permits.
#[tokio::test]
async fn test_different_concurrency_limits() {
    let pool1 = WorkerPool::with_concurrency("pool1", LoadBalancingStrategy::RoundRobin, 5, Arc::new(NoOpMetrics));
    let pool2 = WorkerPool::with_concurrency("pool2", LoadBalancingStrategy::RoundRobin, 20, Arc::new(NoOpMetrics));

    assert_eq!(pool1.name(), "pool1");
    assert_eq!(pool2.name(), "pool2");

    // The limit passed in must be what the pool stores and what its semaphore offers.
    assert_eq!(pool1.concurrency_limit, 5);
    assert_eq!(pool1.semaphore.available_permits(), 5);
    assert_eq!(pool2.concurrency_limit, 20);
    assert_eq!(pool2.semaphore.available_permits(), 20);
}
/// After `shutdown`, dispatch is rejected with `Shutdown` and the pool reports unhealthy.
#[tokio::test]
async fn test_pool_shutdown_prevents_dispatch() {
    let metrics = Arc::new(NoOpMetrics);
    let mut pool = WorkerPool::new("test-pool", LoadBalancingStrategy::RoundRobin, metrics);
    let (worker, _processed) = TestWorker::new("worker-1");
    pool.add_worker(Arc::new(worker));

    pool.shutdown().await.unwrap();

    let outcome = pool.dispatch(create_test_message("msg-after-shutdown")).await;
    assert!(matches!(outcome, Err(WorkerError::Shutdown)));

    let health = pool.check_health();
    assert!(matches!(health, HealthStatus::Unhealthy { .. }));
}
}