use std::sync::Arc;
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinSet;
/// Outcome produced by a successfully executed job.
///
/// A job that fails is expected to return `Err(JobError)` from
/// `AsyncJob::execute`; `success` lets a job additionally report a soft failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobResult {
    /// Identifier of the job this result belongs to (filled in by the job itself).
    pub job_id: String,
    /// Whether the job reports that it succeeded.
    pub success: bool,
    /// Optional human-readable detail about the run.
    pub message: Option<String>,
}
/// Error reported by a job, or by the pool when it cannot run or join one.
///
/// Only a human-readable `message` is carried; callers that need to branch on
/// the cause currently have to inspect the text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl From<String> for JobError {
    fn from(message: String) -> Self {
        Self { message }
    }
}

/// Convenience conversion so string literals need no `.to_string()`.
impl From<&str> for JobError {
    fn from(message: &str) -> Self {
        Self {
            message: message.to_owned(),
        }
    }
}

impl std::fmt::Display for JobError {
    /// Displays the message verbatim, with no prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for JobError {}
/// A unit of asynchronous work that can be handed to a worker pool.
///
/// The `Send + 'static` bounds allow a job to be moved into a spawned task.
pub trait AsyncJob: 'static + Send {
    /// Returns this job's identifier.
    ///
    /// The pool calls this before consuming the job, to fill in the handle it
    /// returns to the submitter.
    fn job_id(&self) -> String;

    /// Consumes the job and runs it to completion.
    ///
    /// The returned future must be `Send` so that it can be spawned onto the runtime.
    fn execute(self) -> impl core::future::Future<Output = Result<JobResult, JobError>> + Send;
}
/// Receipt returned when a job has been accepted by the pool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobHandle {
    // The pool only writes this field; it exists for callers. The allow keeps
    // builds where no caller reads it free of dead-code warnings.
    #[allow(dead_code)]
    /// Identifier of the accepted job, as reported by its `job_id()`.
    pub job_id: String,
}
/// A bounded pool that runs `AsyncJob`s as tokio tasks, at most `workers` at a time.
///
/// Submission is non-blocking (`try_submit`). Results are collected with
/// `poll_completed` or `shutdown`.
pub struct WorkerPool<T: AsyncJob> {
    /// Maximum number of concurrently running jobs; fixed at construction.
    workers: usize,
    /// One permit per worker; a running job holds a permit until it finishes.
    semaphore: Arc<Semaphore>,
    /// Spawned jobs whose results have not been collected yet.
    active_jobs: Arc<Mutex<JoinSet<Result<JobResult, JobError>>>>,
    /// Set by `shutdown`; once set, `try_submit` rejects new jobs.
    shutting_down: Arc<Mutex<bool>>,
    // `fn() -> T` instead of `T`: the pool never stores a `T`, so its
    // `Send`/`Sync` must not depend on `T: Sync` (a plain `PhantomData<T>` made
    // the pool `!Sync` for `Send`-only job types). Still covariant in `T`.
    _phantom: std::marker::PhantomData<fn() -> T>,
}
impl<T: AsyncJob> WorkerPool<T> {
pub fn new(workers: usize) -> Self {
Self {
workers,
semaphore: Arc::new(Semaphore::new(workers)),
active_jobs: Arc::new(Mutex::new(JoinSet::new())),
shutting_down: Arc::new(Mutex::new(false)),
_phantom: std::marker::PhantomData,
}
}
pub fn total_workers(&self) -> usize {
self.workers
}
pub fn active_workers(&self) -> usize {
self.workers - self.semaphore.available_permits()
}
pub fn available_workers(&self) -> usize {
self.semaphore.available_permits()
}
pub async fn try_submit(&self, job: T) -> Result<Option<JobHandle>, JobError> {
if *self.shutting_down.lock().await {
return Err(JobError::from("Worker pool is shutting down".to_string()));
}
let permit = match self.semaphore.clone().try_acquire_owned() {
Ok(permit) => permit,
Err(_) => return Ok(None),
};
let job_id = job.job_id();
let mut jobs = self.active_jobs.lock().await;
jobs.spawn(async move {
let _permit = permit;
job.execute().await
});
drop(jobs);
Ok(Some(JobHandle { job_id }))
}
pub async fn poll_completed(&self) -> Vec<Result<JobResult, JobError>> {
let mut results = Vec::new();
let mut jobs = self.active_jobs.lock().await;
while let Some(result) = jobs.try_join_next() {
match result {
Ok(job_result) => results.push(job_result),
Err(e) => results.push(Err(JobError::from(format!("Job panicked: {}", e)))),
}
}
results
}
pub async fn shutdown(&self) -> Result<Vec<Result<JobResult, JobError>>, JobError> {
*self.shutting_down.lock().await = true;
let mut results = Vec::new();
let mut jobs = self.active_jobs.lock().await;
while let Some(result) = jobs.join_next().await {
match result {
Ok(job_result) => results.push(job_result),
Err(e) => results.push(Err(JobError::from(format!("Job panicked: {}", e)))),
}
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::{sleep, Instant};

    /// Longest any test waits for submitted jobs to finish before giving up.
    /// A broken pool then fails the test instead of hanging it.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    /// Maximum number of submit attempts made by `submit_with_retry`.
    const MAX_SUBMIT_ATTEMPTS: usize = 100;

    /// Job that sleeps for `duration_ms`, then succeeds or fails on request.
    struct TestJob {
        id: String,
        duration_ms: u64,
        should_fail: bool,
    }

    impl TestJob {
        fn new(id: impl Into<String>, duration_ms: u64, should_fail: bool) -> Self {
            Self {
                id: id.into(),
                duration_ms,
                should_fail,
            }
        }
    }

    impl AsyncJob for TestJob {
        async fn execute(self) -> Result<JobResult, JobError> {
            sleep(Duration::from_millis(self.duration_ms)).await;
            if self.should_fail {
                Err(JobError::from(format!(
                    "Job {} failed as requested",
                    self.id
                )))
            } else {
                Ok(JobResult {
                    job_id: self.id.clone(),
                    success: true,
                    message: Some(format!(
                        "Job {} completed after {}ms",
                        self.id, self.duration_ms
                    )),
                })
            }
        }

        fn job_id(&self) -> String {
            self.id.clone()
        }
    }

    /// Submits `count` succeeding jobs named `job-0`, `job-1`, ..., each lasting
    /// `duration_ms`. Sleeps 10ms and retries whenever the pool is full. Gives
    /// up after `MAX_SUBMIT_ATTEMPTS` attempts and returns how many were accepted.
    async fn submit_with_retry(
        pool: &WorkerPool<TestJob>,
        count: usize,
        duration_ms: u64,
    ) -> usize {
        let mut submitted = 0;
        for _ in 0..MAX_SUBMIT_ATTEMPTS {
            if submitted == count {
                break;
            }
            let job = TestJob::new(format!("job-{submitted}"), duration_ms, false);
            if pool.try_submit(job).await.unwrap().is_some() {
                submitted += 1;
            } else {
                sleep(Duration::from_millis(10)).await;
            }
        }
        submitted
    }

    /// Polls the pool until at least `expected` results have arrived or
    /// `WAIT_LIMIT` elapses. Returns everything collected.
    async fn collect_completed(
        pool: &WorkerPool<TestJob>,
        expected: usize,
    ) -> Vec<Result<JobResult, JobError>> {
        let deadline = Instant::now() + WAIT_LIMIT;
        let mut collected = Vec::new();
        loop {
            collected.extend(pool.poll_completed().await);
            if collected.len() >= expected || Instant::now() >= deadline {
                return collected;
            }
            sleep(Duration::from_millis(5)).await;
        }
    }

    /// Asserts all three worker counters at once.
    #[track_caller]
    fn assert_worker_counts(
        pool: &WorkerPool<TestJob>,
        total: usize,
        available: usize,
        active: usize,
    ) {
        assert_eq!(pool.total_workers(), total);
        assert_eq!(pool.available_workers(), available);
        assert_eq!(pool.active_workers(), active);
    }

    #[tokio::test]
    async fn test_worker_pool_basic_execution() {
        let pool = WorkerPool::new(2);
        pool.try_submit(TestJob::new("test-1", 10, false))
            .await
            .unwrap()
            .unwrap();

        let completed = collect_completed(&pool, 1).await;
        assert_eq!(completed.len(), 1, "Should have 1 completed job");
        let result = completed[0].as_ref().expect("job should succeed");
        assert!(result.success);
        assert_eq!(result.job_id, "test-1");
    }

    #[tokio::test]
    async fn test_worker_pool_concurrent_execution() {
        let pool = WorkerPool::new(2);
        let submitted = submit_with_retry(&pool, 4, 50).await;
        assert_eq!(submitted, 4, "Should have submitted all 4 jobs");

        let all_completed = collect_completed(&pool, 4).await;
        assert_eq!(all_completed.len(), 4, "All 4 jobs should complete");
        for result in &all_completed {
            assert!(result.is_ok());
            assert!(result.as_ref().unwrap().success);
        }
    }

    #[tokio::test]
    async fn test_worker_pool_concurrency_limit() {
        let pool = WorkerPool::new(2);
        let start = Instant::now();
        let submitted = submit_with_retry(&pool, 3, 100).await;
        assert_eq!(submitted, 3, "Should have submitted all 3 jobs");

        // Time is measured until the jobs actually finish. A fixed sleep here
        // would satisfy the bound by itself. With two workers, the third 100ms
        // job cannot start until one of the first two frees its slot.
        let completed = collect_completed(&pool, 3).await;
        let elapsed = start.elapsed();
        assert_eq!(completed.len(), 3, "All 3 jobs should complete");
        assert!(
            elapsed >= Duration::from_millis(150),
            "Should take at least 150ms"
        );
    }

    #[tokio::test]
    async fn test_worker_pool_error_handling() {
        let pool = WorkerPool::new(2);
        pool.try_submit(TestJob::new("failing-job", 10, true))
            .await
            .unwrap()
            .unwrap();

        let completed = collect_completed(&pool, 1).await;
        assert_eq!(completed.len(), 1, "Should have 1 completed job");
        let err = completed[0].as_ref().expect_err("job should fail");
        assert!(err.message.contains("failed as requested"));
    }

    #[tokio::test]
    async fn test_worker_pool_try_submit() {
        let pool = WorkerPool::new(1);
        pool.try_submit(TestJob::new("long-job", 200, false))
            .await
            .unwrap()
            .unwrap();

        let result = pool
            .try_submit(TestJob::new("quick-job", 10, false))
            .await
            .unwrap();
        assert!(
            result.is_none(),
            "Should not be able to submit when pool is full"
        );

        // The long job's permit is released when its future completes, so once
        // its result is collected the single worker is free again.
        assert_eq!(collect_completed(&pool, 1).await.len(), 1);

        let result = pool
            .try_submit(TestJob::new("another-job", 10, false))
            .await
            .unwrap();
        assert!(
            result.is_some(),
            "Should be able to submit after worker is free"
        );
    }

    #[tokio::test]
    async fn test_worker_pool_shutdown() {
        let pool = WorkerPool::new(2);
        let submitted = submit_with_retry(&pool, 3, 50).await;
        assert_eq!(submitted, 3, "Should have submitted all 3 jobs");

        let results = pool.shutdown().await.unwrap();
        assert_eq!(
            results.len(),
            3,
            "Should get results for all 3 submitted jobs"
        );

        let err = pool
            .try_submit(TestJob::new("late-job", 10, false))
            .await
            .err()
            .expect("submitting after shutdown should fail");
        assert!(err.message.contains("shutting down"));
    }

    #[tokio::test]
    async fn test_worker_pool_worker_counts() {
        let pool = WorkerPool::new(3);
        assert_worker_counts(&pool, 3, 3, 0);

        // `try_submit` takes the worker permit before it returns, so the
        // counters change immediately.
        pool.try_submit(TestJob::new("job-1", 100, false))
            .await
            .unwrap()
            .unwrap();
        assert_worker_counts(&pool, 3, 2, 1);

        pool.try_submit(TestJob::new("job-2", 100, false))
            .await
            .unwrap()
            .unwrap();
        assert_worker_counts(&pool, 3, 1, 2);

        let completed = collect_completed(&pool, 2).await;
        assert_eq!(completed.len(), 2, "Both jobs should complete");
        assert_worker_counts(&pool, 3, 3, 0);
    }
}