use super::task::{Task, TaskResult};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
use uuid::Uuid;
/// Bounded registry of tasks spawned onto the Tokio runtime.
///
/// The worker map lives behind an `Arc`, so clones of a pool observe and
/// modify the same set of registered workers.
pub struct WorkerPool {
    /// Registered workers keyed by worker id; shared by every clone of the pool.
    workers: Arc<Mutex<HashMap<Uuid, WorkerInfo>>>,
    /// Number of entries `workers` may hold before `execute` rejects new tasks.
    max_workers: usize,
}
/// Bookkeeping record stored for each task submitted through `WorkerPool::execute`.
#[derive(Debug)]
struct WorkerInfo {
    /// Join handle of the spawned Tokio task; awaited by `WorkerPool::shutdown`.
    handle: JoinHandle<TaskResult>,
    /// Identifier reported by the task's `task_id()`.
    task_id: String,
    /// Instant at which the worker was registered; used to report its runtime.
    start_time: std::time::Instant,
    /// Workload category derived from the task's `task_type()`.
    worker_type: WorkerType,
}
/// Workload category assigned to a worker from its task type string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkerType {
    /// Assigned to `"convert"` tasks.
    CpuIntensive,
    /// Assigned to `"match"` and `"validate"` tasks.
    IoIntensive,
    /// Assigned to `"sync"` tasks and to every unrecognised task type.
    Mixed,
}
impl WorkerPool {
    /// Creates an empty pool that registers at most `max_workers` workers.
    pub fn new(max_workers: usize) -> Self {
        Self {
            workers: Arc::new(Mutex::new(HashMap::new())),
            max_workers,
        }
    }

    /// Spawns `task` on the Tokio runtime and registers it under a fresh UUIDv7 worker id.
    ///
    /// This returns as soon as the task has been spawned. The `Ok` value is a fixed
    /// `"Task submitted"` acknowledgement, not the result produced by the task itself.
    ///
    /// # Errors
    ///
    /// Returns `Err("Worker pool is full")` when the registry already holds
    /// `max_workers` entries.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex has been poisoned.
    ///
    /// NOTE(review): entries are only removed by `shutdown`, so tasks that have already
    /// finished keep counting against the capacity. Confirm whether that is intended.
    pub async fn execute(&self, task: Box<dyn Task + Send + Sync>) -> Result<TaskResult, String> {
        let worker_id = Uuid::now_v7();
        let task_id = task.task_id();
        let worker_type = self.determine_worker_type(task.task_type());
        {
            // The capacity check and the insertion must happen under the same guard.
            // Otherwise concurrent submissions through clones of the pool can all pass
            // the check and push the registry past `max_workers`. `tokio::spawn` does
            // not block and nothing is awaited while the guard is held.
            let mut workers = self.workers.lock().unwrap();
            if workers.len() >= self.max_workers {
                return Err("Worker pool is full".to_string());
            }
            let handle = tokio::spawn(async move { task.execute().await });
            workers.insert(
                worker_id,
                WorkerInfo {
                    handle,
                    task_id,
                    start_time: std::time::Instant::now(),
                    worker_type,
                },
            );
        }
        Ok(TaskResult::Success("Task submitted".to_string()))
    }

    /// Maps a task type string to the workload category recorded for its worker.
    fn determine_worker_type(&self, task_type: &str) -> WorkerType {
        match task_type {
            "convert" => WorkerType::CpuIntensive,
            "match" | "validate" => WorkerType::IoIntensive,
            "sync" => WorkerType::Mixed,
            // Unknown task types fall back to the mixed category.
            _ => WorkerType::Mixed,
        }
    }

    /// Returns the number of workers currently held in the registry.
    pub fn get_active_count(&self) -> usize {
        self.workers.lock().unwrap().len()
    }

    /// Returns the maximum number of workers the pool registers.
    pub fn get_capacity(&self) -> usize {
        self.max_workers
    }

    /// Returns a snapshot of the registry: total size, per-category counts and capacity.
    pub fn get_worker_stats(&self) -> WorkerStats {
        let workers = self.workers.lock().unwrap();
        let mut cpu = 0;
        let mut io = 0;
        let mut mixed = 0;
        for info in workers.values() {
            match info.worker_type {
                WorkerType::CpuIntensive => cpu += 1,
                WorkerType::IoIntensive => io += 1,
                WorkerType::Mixed => mixed += 1,
            }
        }
        WorkerStats {
            total_active: workers.len(),
            cpu_intensive_count: cpu,
            io_intensive_count: io,
            mixed_count: mixed,
            max_capacity: self.max_workers,
        }
    }

    /// Empties the registry and waits for every worker that was registered to finish.
    ///
    /// A progress line is written to stderr for each worker unless output is quiet or
    /// in JSON mode. The outcome of each task, including a join error, is discarded.
    pub async fn shutdown(&self) {
        // Take the whole map out under the lock so the guard is released before the
        // first `.await` below.
        let drained = { std::mem::take(&mut *self.workers.lock().unwrap()) };
        for (id, info) in drained {
            if !crate::cli::output::is_quiet() && !crate::cli::output::active_mode().is_json() {
                eprintln!(
                    "Waiting for worker {} to complete task {}",
                    id, info.task_id
                );
            }
            let _ = info.handle.await;
        }
    }

    /// Returns a description of every registered worker, including how long ago it
    /// was registered.
    pub fn list_active_workers(&self) -> Vec<ActiveWorkerInfo> {
        let workers = self.workers.lock().unwrap();
        workers
            .iter()
            .map(|(id, info)| ActiveWorkerInfo {
                worker_id: *id,
                task_id: info.task_id.clone(),
                worker_type: info.worker_type.clone(),
                runtime: info.start_time.elapsed(),
            })
            .collect()
    }
}
impl Clone for WorkerPool {
    /// Produces another handle to the same worker registry with the same capacity.
    fn clone(&self) -> Self {
        // Only the `Arc` is duplicated; the map behind it stays shared.
        let workers = Arc::clone(&self.workers);
        let max_workers = self.max_workers;
        Self {
            workers,
            max_workers,
        }
    }
}
/// Snapshot of the pool registry as returned by `WorkerPool::get_worker_stats`.
#[derive(Debug, Clone)]
pub struct WorkerStats {
    /// Number of workers in the registry when the snapshot was taken.
    pub total_active: usize,
    /// Registered workers classified as `WorkerType::CpuIntensive`.
    pub cpu_intensive_count: usize,
    /// Registered workers classified as `WorkerType::IoIntensive`.
    pub io_intensive_count: usize,
    /// Registered workers classified as `WorkerType::Mixed`.
    pub mixed_count: usize,
    /// The pool's configured `max_workers`.
    pub max_capacity: usize,
}
/// Public description of one registered worker, as returned by
/// `WorkerPool::list_active_workers`.
#[derive(Debug, Clone)]
pub struct ActiveWorkerInfo {
    /// Id under which the worker is registered in the pool.
    pub worker_id: Uuid,
    /// Identifier reported by the task the worker is running.
    pub task_id: String,
    /// Workload category recorded for the worker.
    pub worker_type: WorkerType,
    /// Time elapsed since the worker was registered, measured when the list was built.
    pub runtime: std::time::Duration,
}
/// A standalone worker descriptor: an identifier plus a mutable status.
///
/// NOTE(review): nothing in this file connects `Worker` to `WorkerPool`; confirm
/// how callers use it.
pub struct Worker {
    /// Identifier assigned when the worker is constructed.
    id: Uuid,
    /// Current status; changed through `set_status`.
    status: WorkerStatus,
}
/// Status values a `Worker` can carry.
#[derive(Debug, Clone)]
pub enum WorkerStatus {
    /// Initial status of a worker created with `Worker::new`.
    Idle,
    /// Worker is busy. Presumably the payload names the task in progress — TODO confirm
    /// against callers.
    Busy(String),
    /// Worker has been stopped.
    Stopped,
    /// Worker is in an error state. Presumably the payload is an error message — TODO
    /// confirm against callers.
    Error(String),
}
impl Worker {
    /// Creates an idle worker with a freshly generated UUIDv7 identifier.
    pub fn new() -> Self {
        Self {
            id: Uuid::now_v7(),
            status: WorkerStatus::Idle,
        }
    }

    /// Returns the worker's identifier.
    pub fn id(&self) -> Uuid {
        self.id
    }

    /// Returns the worker's current status.
    pub fn status(&self) -> &WorkerStatus {
        &self.status
    }

    /// Replaces the worker's status with `status`.
    pub fn set_status(&mut self, status: WorkerStatus) {
        self.status = status;
    }
}

/// `Worker::new` takes no arguments, so expose it through `Default` as well.
impl Default for Worker {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A new pool reports its configured capacity and no registered workers.
    #[tokio::test]
    async fn test_worker_pool_capacity() {
        let pool = WorkerPool::new(2);
        assert_eq!(pool.get_capacity(), 2);
        assert_eq!(pool.get_active_count(), 0);
        let stats = pool.get_worker_stats();
        assert_eq!(stats.max_capacity, 2);
        assert_eq!(stats.total_active, 0);
    }

    // Submitting one task succeeds and registers exactly one worker.
    #[tokio::test]
    async fn test_execute_and_active_count() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        #[derive(Clone)]
        struct DummyTask {
            id: String,
            tp: &'static str,
        }
        #[async_trait]
        impl Task for DummyTask {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success(self.id.clone())
            }
            fn task_type(&self) -> &'static str {
                self.tp
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }
        let pool = WorkerPool::new(1);
        let task = DummyTask {
            id: "t1".into(),
            tp: "convert",
        };
        let res = pool.execute(Box::new(task.clone())).await;
        assert!(matches!(res, Ok(TaskResult::Success(_))));
        assert_eq!(pool.get_active_count(), 1);
    }

    // A pool of capacity one rejects the second submission.
    #[tokio::test]
    async fn test_reject_when_full() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        #[derive(Clone)]
        struct DummyTask;
        #[async_trait]
        impl Task for DummyTask {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success("".into())
            }
            fn task_type(&self) -> &'static str {
                "match"
            }
            fn task_id(&self) -> String {
                "".into()
            }
        }
        let pool = WorkerPool::new(1);
        let _ = pool.execute(Box::new(DummyTask)).await;
        let err = pool.execute(Box::new(DummyTask)).await;
        assert!(err.is_err());
    }

    // The listing and the stats both reflect a single registered "sync" worker.
    #[tokio::test]
    async fn test_list_active_workers_and_stats() {
        use super::WorkerType;
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        #[derive(Clone)]
        struct DummyTask2;
        #[async_trait]
        impl Task for DummyTask2 {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success("".into())
            }
            fn task_type(&self) -> &'static str {
                "sync"
            }
            fn task_id(&self) -> String {
                "tok2".into()
            }
        }
        let pool = WorkerPool::new(2);
        let _ = pool.execute(Box::new(DummyTask2)).await;
        let workers = pool.list_active_workers();
        assert_eq!(workers.len(), 1);
        let info = &workers[0];
        assert_eq!(info.task_id, "tok2");
        assert_eq!(info.worker_type, WorkerType::Mixed);
        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 1);
    }

    // Four tasks submitted concurrently through clones of the pool all run.
    #[tokio::test]
    async fn test_worker_job_distribution() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        #[derive(Clone)]
        struct CountingTask {
            id: String,
            counter: Arc<AtomicUsize>,
        }
        #[async_trait]
        impl Task for CountingTask {
            async fn execute(&self) -> TaskResult {
                self.counter.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                TaskResult::Success(format!("task-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "convert"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }
        let pool = WorkerPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for i in 0..4 {
            let task = CountingTask {
                id: format!("task-{}", i),
                counter: Arc::clone(&counter),
            };
            let pool_clone = pool.clone();
            let handle = tokio::spawn(async move { pool_clone.execute(Box::new(task)).await });
            handles.push(handle);
        }
        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok(), "Task submission should succeed");
        }
        // Give the spawned tasks time to run before reading the counter.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        let final_count = counter.load(Ordering::SeqCst);
        assert_eq!(final_count, 4, "All 4 tasks should have been executed");
    }

    // Submission succeeds whether or not the task itself will report a failure.
    #[tokio::test]
    async fn test_worker_error_recovery() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        #[derive(Clone)]
        struct FailingTask {
            id: String,
            should_fail: bool,
        }
        #[async_trait]
        impl Task for FailingTask {
            async fn execute(&self) -> TaskResult {
                if self.should_fail {
                    TaskResult::Failed("Intentional failure".to_string())
                } else {
                    TaskResult::Success(format!("success-{}", self.id))
                }
            }
            fn task_type(&self) -> &'static str {
                "sync"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }
        let pool = WorkerPool::new(2);
        let success_task = FailingTask {
            id: "success".to_string(),
            should_fail: false,
        };
        let result = pool.execute(Box::new(success_task)).await;
        assert!(result.is_ok(), "Successful task should be submitted");
        let fail_task = FailingTask {
            id: "fail".to_string(),
            should_fail: true,
        };
        let result = pool.execute(Box::new(fail_task)).await;
        assert!(
            result.is_ok(),
            "Failing task should still be submitted successfully"
        );
        assert!(
            pool.get_active_count() <= 2,
            "Active count should be within limits"
        );
    }

    // Compares how long submission takes on a capacity-1 pool and a capacity-2 pool.
    // This measures submission latency only, not task execution time.
    // NOTE(review): both durations are tiny, so the final comparison is sensitive to
    // scheduling noise; consider whether this assertion is worth keeping.
    #[tokio::test]
    async fn test_parallel_processing_performance() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        use std::time::Instant;
        #[derive(Clone)]
        struct CpuIntensiveTask {
            id: String,
            duration_ms: u64,
        }
        #[async_trait]
        impl Task for CpuIntensiveTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(tokio::time::Duration::from_millis(self.duration_ms)).await;
                TaskResult::Success(format!("completed-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "convert"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }
        let sequential_pool = WorkerPool::new(1);
        let start = Instant::now();
        for i in 0..2 {
            let task = CpuIntensiveTask {
                id: format!("seq-{}", i),
                duration_ms: 10,
            };
            // The second submission is rejected because the pool holds one entry already.
            if let Err(e) = sequential_pool.execute(Box::new(task)).await {
                println!("Sequential task {} failed: {}", i, e);
            }
        }
        let sequential_time = start.elapsed();
        let parallel_pool = WorkerPool::new(2);
        let start = Instant::now();
        let task = CpuIntensiveTask {
            id: "par-0".to_string(),
            duration_ms: 10,
        };
        if let Err(e) = parallel_pool.execute(Box::new(task)).await {
            println!("Parallel task failed: {}", e);
        }
        let parallel_time = start.elapsed();
        println!("Sequential submission time: {:?}", sequential_time);
        println!("Parallel submission time: {:?}", parallel_time);
        assert!(
            parallel_time <= sequential_time * 2,
            "Parallel submission should not be significantly slower"
        );
    }

    // Checks the task-type mapping and the stats of an empty pool.
    #[tokio::test]
    async fn test_resource_management() {
        let pool = WorkerPool::new(3);
        assert_eq!(
            pool.determine_worker_type("convert"),
            WorkerType::CpuIntensive
        );
        assert_eq!(pool.determine_worker_type("sync"), WorkerType::Mixed);
        assert_eq!(pool.determine_worker_type("match"), WorkerType::IoIntensive);
        assert_eq!(
            pool.determine_worker_type("validate"),
            WorkerType::IoIntensive
        );
        assert_eq!(pool.determine_worker_type("unknown"), WorkerType::Mixed);
        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 0);
        assert_eq!(stats.max_capacity, 3);
        assert_eq!(stats.cpu_intensive_count, 0);
        assert_eq!(stats.io_intensive_count, 0);
        assert_eq!(stats.mixed_count, 0);
        assert_eq!(pool.get_capacity(), 3);
        assert_eq!(pool.get_active_count(), 0);
    }

    // `shutdown` waits for running tasks and leaves the registry empty.
    #[tokio::test]
    async fn test_worker_pool_shutdown() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        #[derive(Clone)]
        struct SlowTask {
            id: String,
        }
        #[async_trait]
        impl Task for SlowTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
                TaskResult::Success(format!("slow-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                "mixed"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }
        let pool = WorkerPool::new(2);
        for i in 0..2 {
            let task = SlowTask {
                id: format!("slow-{}", i),
            };
            pool.execute(Box::new(task)).await.unwrap();
        }
        assert!(pool.get_active_count() <= 2);
        let start = std::time::Instant::now();
        pool.shutdown().await;
        let shutdown_time = start.elapsed();
        // Each task sleeps 50 ms, so waiting for them must take a noticeable time.
        assert!(shutdown_time >= std::time::Duration::from_millis(30));
        assert_eq!(pool.get_active_count(), 0);
    }

    // Every submitted task is listed with the worker type derived from its task type,
    // and the stats agree with the listing.
    #[tokio::test]
    async fn test_active_worker_tracking() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        #[derive(Clone)]
        struct TrackableTask {
            id: String,
            task_type: &'static str,
        }
        #[async_trait]
        impl Task for TrackableTask {
            async fn execute(&self) -> TaskResult {
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                TaskResult::Success(format!("tracked-{}", self.id))
            }
            fn task_type(&self) -> &'static str {
                self.task_type
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }
        let pool = WorkerPool::new(3);
        let tasks = vec![
            ("cpu-task", "convert"),
            ("io-task", "match"),
            ("mixed-task", "sync"),
        ];
        // Taken before any worker is registered, so it bounds every reported runtime.
        let submitted_at = std::time::Instant::now();
        for (id, task_type) in tasks {
            let task = TrackableTask {
                id: id.to_string(),
                task_type,
            };
            pool.execute(Box::new(task)).await.unwrap();
        }
        let active_workers = pool.list_active_workers();
        assert_eq!(
            active_workers.len(),
            3,
            "All three submitted tasks should be tracked"
        );
        for worker in &active_workers {
            let expected_type = match worker.task_id.as_str() {
                "cpu-task" => WorkerType::CpuIntensive,
                "io-task" => WorkerType::IoIntensive,
                "mixed-task" => WorkerType::Mixed,
                other => panic!("Unexpected task id {}", other),
            };
            assert_eq!(
                worker.worker_type, expected_type,
                "Worker type should follow the task type"
            );
            // The worker was registered after `submitted_at` and its runtime was sampled
            // before this call, so the bound holds on a monotonic clock.
            assert!(
                worker.runtime <= submitted_at.elapsed(),
                "Runtime should not exceed the time since submission began"
            );
        }
        let stats = pool.get_worker_stats();
        assert_eq!(stats.total_active, 3);
        assert_eq!(stats.cpu_intensive_count, 1);
        assert_eq!(stats.io_intensive_count, 1);
        assert_eq!(stats.mixed_count, 1);
        assert_eq!(stats.max_capacity, 3);
        // Let the spawned tasks run to completion before the runtime is dropped.
        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;
    }

    // `Worker::new` assigns a version-7 UUID.
    #[test]
    fn worker_id_is_uuidv7() {
        let w = Worker::new();
        assert_eq!(w.id().get_version_num(), 7);
    }

    // Two workers created back to back receive different ids.
    #[test]
    fn consecutive_workers_have_distinct_ids() {
        let a = Worker::new();
        let b = Worker::new();
        assert_ne!(a.id(), b.id());
    }

    // The worker id the pool assigns on submission is a version-7 UUID.
    #[tokio::test]
    async fn worker_pool_execute_dispatches_uuidv7_worker_id() {
        use crate::core::parallel::task::{Task, TaskResult};
        use async_trait::async_trait;
        struct DummyTask;
        #[async_trait]
        impl Task for DummyTask {
            async fn execute(&self) -> TaskResult {
                TaskResult::Success("done".into())
            }
            fn task_type(&self) -> &'static str {
                "match"
            }
            fn task_id(&self) -> String {
                "dummy".into()
            }
        }
        let pool = WorkerPool::new(1);
        let res = pool.execute(Box::new(DummyTask)).await;
        assert!(matches!(res, Ok(TaskResult::Success(_))));
        let workers = pool.list_active_workers();
        assert_eq!(workers.len(), 1);
        assert_eq!(workers[0].worker_id.get_version_num(), 7);
    }
}