use crate::error::{NumRs2Error, Result};
use std::collections::VecDeque;
use std::sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Condvar, Mutex,
};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
/// A unit of work that can be scheduled on a work-stealing pool.
///
/// Implementors are moved onto worker threads, hence the `Send + 'static` bound.
pub trait Task: Send + 'static {
    /// Value produced by running the task.
    type Output: Send + 'static;

    /// Runs the task, consuming it.
    fn execute(self: Box<Self>) -> Self::Output;

    /// Optional cost hint for schedulers; the default reports no estimate.
    fn estimated_cost(&self) -> Option<u64> {
        Option::None
    }

    /// Whether [`Task::split`] can break this task into smaller pieces.
    /// The default says it cannot.
    fn can_split(&self) -> bool {
        false
    }

    /// Breaks the task into independently executable subtasks.
    ///
    /// The default performs no splitting and returns the task itself as the only element.
    fn split(self: Box<Self>) -> Vec<Box<dyn Task<Output = Self::Output>>>
    where
        Self: Sized,
    {
        let whole: Box<dyn Task<Output = Self::Output>> = self;
        std::iter::once(whole).collect()
    }
}
/// Outcome reported by a task once it has run.
#[derive(Debug)]
pub enum TaskResult<T> {
    /// The task finished and produced a value.
    Success(T),
    /// The task failed; the string carries a human-readable reason.
    Error(String),
    /// The task did not run to completion because it was cancelled.
    Cancelled,
}
/// Type-erased task as stored in the pool's queues (the `'static` bound is Box's default).
type BoxedTask = Box<dyn Task<Output = TaskResult<()>> + 'static>;
/// Double-ended task queue owned by a single worker.
///
/// The owning worker removes tasks from the front (`pop_front`). Thieves remove them
/// from the back (`steal`). Each side bumps its own counter.
struct WorkerQueue {
    /// Pending tasks; `push_back` appends, so the front holds the oldest one.
    deque: VecDeque<BoxedTask>,
    /// Number of tasks removed via `steal`.
    steal_count: u64,
    /// Number of tasks removed via `pop_front`.
    execute_count: u64,
}

impl WorkerQueue {
    /// Creates an empty queue with both counters at zero.
    fn new() -> Self {
        WorkerQueue {
            deque: VecDeque::new(),
            steal_count: 0,
            execute_count: 0,
        }
    }

    /// Appends `task` to the back of the queue.
    fn push_back(&mut self, task: BoxedTask) {
        self.deque.push_back(task)
    }

    /// Removes the front task (owner side), counting it as executed.
    fn pop_front(&mut self) -> Option<BoxedTask> {
        let front = self.deque.pop_front()?;
        self.execute_count += 1;
        Some(front)
    }

    /// Removes the back task (thief side), counting it as stolen.
    fn steal(&mut self) -> Option<BoxedTask> {
        let back = self.deque.pop_back()?;
        self.steal_count += 1;
        Some(back)
    }

    /// Number of queued tasks.
    fn len(&self) -> usize {
        self.deque.len()
    }

    /// Whether the queue holds no tasks.
    // Kept alongside `len` for API symmetry; not called in this file.
    #[allow(dead_code)]
    fn is_empty(&self) -> bool {
        self.deque.is_empty()
    }
}
/// Per-worker state shared between the worker thread and the pool handle.
struct WorkerState {
    /// Index of this worker within the pool.
    id: usize,
    /// This worker's task queue; other workers may steal from it.
    queue: Mutex<WorkerQueue>,
    /// Idle flag maintained by the worker loop (relaxed ordering, used as a hint).
    is_idle: AtomicBool,
    /// Number of tasks this worker has run.
    tasks_executed: AtomicUsize,
    /// Sum of wall-clock time spent running tasks.
    total_execution_time: Mutex<Duration>,
    /// When this worker last tried to steal; used to rate-limit stealing.
    last_steal_attempt: Mutex<Instant>,
}

impl WorkerState {
    /// Creates state for worker `id`: empty queue, idle, zeroed counters.
    fn new(id: usize) -> Self {
        let queue = Mutex::new(WorkerQueue::new());
        let is_idle = AtomicBool::new(true);
        let tasks_executed = AtomicUsize::new(0);
        let total_execution_time = Mutex::new(Duration::ZERO);
        let last_steal_attempt = Mutex::new(Instant::now());
        Self {
            id,
            queue,
            is_idle,
            tasks_executed,
            total_execution_time,
            last_steal_attempt,
        }
    }

    /// Current number of tasks waiting in this worker's queue (briefly locks it).
    fn queue_length(&self) -> usize {
        let guard = self.queue.lock().expect("lock should not be poisoned");
        guard.len()
    }

    /// Reads the idle flag.
    fn is_idle(&self) -> bool {
        self.is_idle.load(Ordering::Relaxed)
    }

    /// Writes the idle flag.
    fn set_idle(&self, idle: bool) {
        self.is_idle.store(idle, Ordering::Relaxed)
    }

    /// Tasks executed per second of accumulated execution time (0.0 before any time accrues).
    // Not called in this file; retained as a diagnostic.
    #[allow(dead_code)]
    fn throughput(&self) -> f64 {
        let busy = *self
            .total_execution_time
            .lock()
            .expect("lock should not be poisoned");
        if busy.is_zero() {
            return 0.0;
        }
        let executed = self.tasks_executed.load(Ordering::Relaxed) as f64;
        executed / busy.as_secs_f64()
    }
}
/// Tuning knobs for a work-stealing pool.
#[derive(Debug, Clone)]
pub struct WorkStealingConfig {
    /// Number of worker threads to spawn.
    pub num_threads: usize,
    /// Upper bound on steal tries per steal round.
    pub max_steal_attempts: usize,
    /// Minimum time between two steal rounds of the same worker.
    pub steal_interval: Duration,
    /// Longest time an idle worker sleeps before re-polling for work.
    pub idle_timeout: Duration,
    /// Task-splitting switch.
    /// NOTE(review): not consulted by the pool code in this file; confirm intended use.
    pub enable_task_splitting: bool,
    /// Capacity of each worker queue; `submit` also bounds the global queue relative to it.
    pub max_queue_size: usize,
    /// Whether idle workers try to steal from other workers' queues.
    pub adaptive_stealing: bool,
}

impl Default for WorkStealingConfig {
    /// One thread per available CPU (falls back to 4 if unknown), with conservative
    /// stealing and idle timings.
    fn default() -> Self {
        let num_threads = std::thread::available_parallelism()
            .map(|parallelism| parallelism.get())
            .unwrap_or(4);
        WorkStealingConfig {
            num_threads,
            max_steal_attempts: 3,
            steal_interval: Duration::from_millis(1),
            idle_timeout: Duration::from_millis(10),
            enable_task_splitting: true,
            max_queue_size: 1000,
            adaptive_stealing: true,
        }
    }
}
pub struct WorkStealingPool {
config: WorkStealingConfig,
workers: Vec<Arc<WorkerState>>,
#[allow(dead_code)]
threads: Vec<JoinHandle<()>>,
shutdown: Arc<AtomicBool>,
global_queue: Arc<Mutex<VecDeque<BoxedTask>>>,
idle_workers: Arc<(Mutex<Vec<usize>>, Condvar)>,
stats: Arc<Mutex<PoolStats>>,
}
/// Snapshot of pool-wide counters and gauges.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Tasks accepted by the submit methods.
    pub tasks_submitted: u64,
    /// Tasks that finished running, whatever their outcome.
    pub tasks_completed: u64,
    /// Tasks taken from another worker's queue.
    pub tasks_stolen: u64,
    /// Steal attempts made by workers.
    pub total_steal_attempts: u64,
    /// Average time tasks waited in a queue.
    /// NOTE(review): confirm where this is populated.
    pub average_queue_time: Duration,
    /// Smoothed task execution time.
    pub average_execution_time: Duration,
    /// Per-worker busy indicator (0.0 idle, 1.0 busy) at the time of the snapshot.
    pub worker_utilization: Vec<f64>,
    /// `(max - min) / max` over worker queue lengths, or 0.0 when all queues are empty.
    pub queue_imbalance: f64,
}
impl WorkStealingPool {
    /// Creates a pool with `num_threads` workers and default settings otherwise.
    ///
    /// # Errors
    ///
    /// Same as [`WorkStealingPool::with_config`]: fails if `num_threads` is zero or a
    /// worker thread cannot be spawned.
    pub fn new(num_threads: usize) -> Result<Self> {
        Self::with_config(WorkStealingConfig {
            num_threads,
            ..Default::default()
        })
    }

    /// Creates a pool from `config` and starts one thread per worker.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::RuntimeError` in two cases:
    /// - `config.num_threads` is zero. Such a pool would accept tasks but never run them.
    /// - The OS refuses to spawn a worker thread. Any workers already started are then
    ///   stopped and joined before returning.
    pub fn with_config(config: WorkStealingConfig) -> Result<Self> {
        if config.num_threads == 0 {
            return Err(NumRs2Error::RuntimeError(
                "Work-stealing pool requires at least one thread".to_string(),
            ));
        }

        let shutdown = Arc::new(AtomicBool::new(false));
        let global_queue = Arc::new(Mutex::new(VecDeque::new()));
        let idle_workers = Arc::new((Mutex::new(Vec::new()), Condvar::new()));
        let stats = Arc::new(Mutex::new(PoolStats::default()));

        let workers: Vec<Arc<WorkerState>> = (0..config.num_threads)
            .map(|id| Arc::new(WorkerState::new(id)))
            .collect();

        let mut threads = Vec::with_capacity(workers.len());
        for worker in &workers {
            let worker = Arc::clone(worker);
            let all_workers = workers.clone();
            let shutdown_flag = Arc::clone(&shutdown);
            let global = Arc::clone(&global_queue);
            let idle = Arc::clone(&idle_workers);
            let pool_stats = Arc::clone(&stats);
            let worker_config = config.clone();

            // `Builder::spawn` reports failure instead of panicking like `thread::spawn`.
            let spawned = thread::Builder::new()
                .name(format!("work-stealing-worker-{}", worker.id))
                .spawn(move || {
                    Self::worker_main(
                        worker,
                        all_workers,
                        shutdown_flag,
                        global,
                        idle,
                        pool_stats,
                        worker_config,
                    );
                });

            match spawned {
                Ok(handle) => threads.push(handle),
                Err(err) => {
                    Self::abort_startup(&shutdown, &idle_workers, threads);
                    return Err(NumRs2Error::RuntimeError(format!(
                        "Failed to spawn worker thread: {}",
                        err
                    )));
                }
            }
        }

        Ok(Self {
            config,
            workers,
            threads,
            shutdown,
            global_queue,
            idle_workers,
            stats,
        })
    }

    /// Queues `task` on the least-loaded worker. If that worker's queue is full, the
    /// task goes to the global queue instead.
    ///
    /// # Errors
    ///
    /// Fails if the pool is shutting down, or if the chosen worker queue is full and the
    /// global queue already holds `2 * max_queue_size` tasks.
    pub fn submit<T>(&self, task: T) -> Result<()>
    where
        T: Task<Output = TaskResult<()>> + 'static,
    {
        self.ensure_running()?;
        let boxed_task: BoxedTask = Box::new(task);

        if let Some(worker_id) = self.find_least_loaded_worker() {
            let worker = &self.workers[worker_id];
            let mut queue = worker.queue.lock().expect("lock should not be poisoned");
            if queue.len() < self.config.max_queue_size {
                queue.push_back(boxed_task);
                // Release the queue before touching the idle lock; workers take the idle
                // lock first and then their queue, so holding both here could deadlock.
                drop(queue);
                if worker.is_idle() {
                    self.notify_worker(worker_id);
                }
                self.record_submission();
                return Ok(());
            }
        }

        let enqueued = {
            let mut global = self
                .global_queue
                .lock()
                .expect("lock should not be poisoned");
            // saturating_mul: avoid overflow for very large configured capacities.
            if global.len() < self.config.max_queue_size.saturating_mul(2) {
                global.push_back(boxed_task);
                true
            } else {
                false
            }
        };
        if !enqueued {
            return Err(NumRs2Error::RuntimeError("All queues are full".to_string()));
        }
        // Notify only after the global lock is released; see `notify_idle_workers`.
        self.notify_idle_workers();
        self.record_submission();
        Ok(())
    }

    /// Puts `task` at the front of the global queue so the next idle worker runs it
    /// ahead of other global work. Urgent submissions ignore the queue capacity limit.
    ///
    /// # Errors
    ///
    /// Fails only if the pool is shutting down.
    pub fn submit_urgent<T>(&self, task: T) -> Result<()>
    where
        T: Task<Output = TaskResult<()>> + 'static,
    {
        self.ensure_running()?;
        let boxed_task: BoxedTask = Box::new(task);
        self.global_queue
            .lock()
            .expect("lock should not be poisoned")
            .push_front(boxed_task);
        self.notify_idle_workers();
        self.record_submission();
        Ok(())
    }

    /// Returns a snapshot of pool statistics, refreshing the per-worker busy flags and
    /// the queue imbalance.
    pub fn statistics(&self) -> PoolStats {
        // Sample the queues before taking the stats lock so workers finishing tasks are
        // not blocked behind a scan of every queue.
        let utilization: Vec<f64> = self
            .workers
            .iter()
            .map(|worker| if worker.is_idle() { 0.0 } else { 1.0 })
            .collect();
        let imbalance = self.calculate_queue_imbalance();

        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        stats.worker_utilization = utilization;
        stats.queue_imbalance = imbalance;
        stats.clone()
    }

    /// Number of workers currently running a task.
    pub fn active_workers(&self) -> usize {
        self.workers
            .iter()
            .filter(|worker| !worker.is_idle())
            .count()
    }

    /// Tasks waiting in the global queue plus all worker queues (not counting running ones).
    pub fn pending_tasks(&self) -> usize {
        let global_count = self
            .global_queue
            .lock()
            .expect("lock should not be poisoned")
            .len();
        let worker_count: usize = self
            .workers
            .iter()
            .map(|worker| worker.queue_length())
            .sum();
        global_count + worker_count
    }

    /// Stops accepting work and wakes idle workers so they exit.
    ///
    /// Workers finish their current task first. Queued tasks are not run. Threads are
    /// joined when the pool is dropped.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn shutdown(&self) -> Result<()> {
        self.shutdown.store(true, Ordering::Relaxed);
        let (idle_lock, condvar) = &*self.idle_workers;
        // Notify under the idle lock. Workers re-check the flag under that same lock
        // before waiting, so this wakeup cannot be lost.
        let _guard = idle_lock.lock().expect("lock should not be poisoned");
        condvar.notify_all();
        Ok(())
    }

    /// Rejects new work once shutdown has begun.
    fn ensure_running(&self) -> Result<()> {
        if self.shutdown.load(Ordering::Relaxed) {
            Err(NumRs2Error::RuntimeError(
                "Pool is shutting down".to_string(),
            ))
        } else {
            Ok(())
        }
    }

    /// Bumps the submitted-task counter.
    fn record_submission(&self) {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .tasks_submitted += 1;
    }

    /// Index of the worker with the shortest queue. Ties go to an idle worker, so a busy
    /// worker with an empty queue isn't preferred over one that can start right away.
    fn find_least_loaded_worker(&self) -> Option<usize> {
        self.workers
            .iter()
            .enumerate()
            .min_by_key(|(_, worker)| (worker.queue_length(), !worker.is_idle()))
            .map(|(idx, _)| idx)
    }

    /// Wakes waiting workers if `worker_id` is currently waiting.
    fn notify_worker(&self, worker_id: usize) {
        let (idle_lock, condvar) = &*self.idle_workers;
        let idle = idle_lock.lock().expect("lock should not be poisoned");
        if idle.contains(&worker_id) {
            // `notify_one` wakes an arbitrary waiter, not necessarily `worker_id`. Wake
            // them all; the others re-check and may steal the work instead.
            condvar.notify_all();
        }
    }

    /// Wakes every waiting worker.
    ///
    /// Takes the idle lock so the notification is ordered after any worker's
    /// "queues are empty" re-check. Callers must NOT hold the global queue lock: workers
    /// lock idle -> global.
    fn notify_idle_workers(&self) {
        let (idle_lock, condvar) = &*self.idle_workers;
        let _guard = idle_lock.lock().expect("lock should not be poisoned");
        condvar.notify_all();
    }

    /// `(max - min) / max` over worker queue lengths, or 0.0 when every queue is empty
    /// or there are no workers.
    fn calculate_queue_imbalance(&self) -> f64 {
        let lengths: Vec<usize> = self
            .workers
            .iter()
            .map(|worker| worker.queue_length())
            .collect();
        match (lengths.iter().max(), lengths.iter().min()) {
            (Some(&max_len), Some(&min_len)) if max_len > 0 => {
                (max_len - min_len) as f64 / max_len as f64
            }
            _ => 0.0,
        }
    }

    /// Signals and joins workers that were already started when construction failed.
    fn abort_startup(
        shutdown: &AtomicBool,
        idle_workers: &(Mutex<Vec<usize>>, Condvar),
        threads: Vec<JoinHandle<()>>,
    ) {
        shutdown.store(true, Ordering::Relaxed);
        {
            let (idle_lock, condvar) = idle_workers;
            let _guard = idle_lock.lock().expect("lock should not be poisoned");
            condvar.notify_all();
        }
        for handle in threads {
            let _ = handle.join();
        }
    }

    /// Main loop of a worker thread. It runs tasks until shutdown, and sleeps up to
    /// `idle_timeout` when no work is found.
    ///
    /// The idle flag is cleared only while a task is running. While waiting, the worker's
    /// id sits in the shared idle list. It is removed again after every wake-up, so the
    /// list holds at most one entry per worker.
    fn worker_main(
        worker: Arc<WorkerState>,
        workers: Vec<Arc<WorkerState>>,
        shutdown: Arc<AtomicBool>,
        global_queue: Arc<Mutex<VecDeque<BoxedTask>>>,
        idle_workers: Arc<(Mutex<Vec<usize>>, Condvar)>,
        stats: Arc<Mutex<PoolStats>>,
        config: WorkStealingConfig,
    ) {
        let worker_id = worker.id;
        while !shutdown.load(Ordering::Relaxed) {
            if let Some(task) = Self::find_task(&worker, &workers, &global_queue, &stats, &config)
            {
                worker.set_idle(false);
                Self::execute_task(task, &worker, &stats);
                continue;
            }

            worker.set_idle(true);
            let (idle_lock, condvar) = &*idle_workers;
            let mut idle = idle_lock.lock().expect("lock should not be poisoned");
            // Re-check under the idle lock. Submitters enqueue before they take this lock
            // to notify, so work that arrived after the scan above is visible here. A
            // notification that comes later finds us waiting, so none is lost.
            let work_arrived = worker.queue_length() > 0
                || !global_queue
                    .lock()
                    .expect("lock should not be poisoned")
                    .is_empty();
            if work_arrived || shutdown.load(Ordering::Relaxed) {
                continue;
            }
            idle.push(worker_id);
            let (mut idle, _timed_out) = condvar
                .wait_timeout(idle, config.idle_timeout)
                .expect("lock should not be poisoned");
            idle.retain(|&id| id != worker_id);
        }
    }

    /// Next task for `worker`. It looks in the worker's own queue first, then the global
    /// queue, then (if enabled) other workers' queues.
    fn find_task(
        worker: &Arc<WorkerState>,
        workers: &[Arc<WorkerState>],
        global_queue: &Mutex<VecDeque<BoxedTask>>,
        stats: &Arc<Mutex<PoolStats>>,
        config: &WorkStealingConfig,
    ) -> Option<BoxedTask> {
        // Blocking locks: contention is brief, and a failed `try_lock` here used to send
        // the worker to sleep even though work was queued.
        let own = worker
            .queue
            .lock()
            .expect("lock should not be poisoned")
            .pop_front();
        if own.is_some() {
            return own;
        }
        let global = global_queue
            .lock()
            .expect("lock should not be poisoned")
            .pop_front();
        if global.is_some() {
            return global;
        }
        if config.adaptive_stealing {
            Self::try_steal_work(worker, workers, stats, config)
        } else {
            None
        }
    }

    /// Runs `task` on the current worker, then updates per-worker and pool statistics.
    ///
    /// Every outcome counts as completed: success, error, cancellation, or panic.
    fn execute_task(task: BoxedTask, worker: &Arc<WorkerState>, stats: &Arc<Mutex<PoolStats>>) {
        const EMA_ALPHA: f64 = 0.1;

        let start_time = Instant::now();
        // Contain panics so a faulty task cannot kill this worker thread. A dead worker
        // would strand any tasks later routed to its queue.
        let _outcome =
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || task.execute()));
        let execution_time = start_time.elapsed();

        worker.tasks_executed.fetch_add(1, Ordering::Relaxed);
        *worker
            .total_execution_time
            .lock()
            .expect("lock should not be poisoned") += execution_time;

        let mut global_stats = stats.lock().expect("lock should not be poisoned");
        global_stats.tasks_completed += 1;
        // Seed the moving average with the first sample instead of decaying up from zero.
        global_stats.average_execution_time = if global_stats.tasks_completed == 1 {
            execution_time
        } else {
            Duration::from_secs_f64(
                EMA_ALPHA * execution_time.as_secs_f64()
                    + (1.0 - EMA_ALPHA) * global_stats.average_execution_time.as_secs_f64(),
            )
        };
    }

    /// Tries to take a task from the back of the busiest other worker's queue.
    ///
    /// Each worker tries at most once per `steal_interval`. A round is abandoned as soon
    /// as every other queue is empty. Updates `total_steal_attempts` and `tasks_stolen`.
    fn try_steal_work(
        worker: &Arc<WorkerState>,
        workers: &[Arc<WorkerState>],
        stats: &Arc<Mutex<PoolStats>>,
        config: &WorkStealingConfig,
    ) -> Option<BoxedTask> {
        {
            let now = Instant::now();
            let mut last_steal = worker
                .last_steal_attempt
                .lock()
                .expect("lock should not be poisoned");
            if now.duration_since(*last_steal) < config.steal_interval {
                return None;
            }
            *last_steal = now;
        }

        for _ in 0..config.max_steal_attempts {
            let (victim_len, victim) = workers
                .iter()
                .filter(|w| w.id != worker.id)
                .map(|w| (w.queue_length(), w))
                .max_by_key(|&(len, _)| len)?;
            // A single queued task is also worth stealing: its owner may be stuck on a
            // long-running task.
            if victim_len == 0 {
                return None;
            }
            stats
                .lock()
                .expect("lock should not be poisoned")
                .total_steal_attempts += 1;
            let stolen = victim
                .queue
                .try_lock()
                .ok()
                .and_then(|mut queue| queue.steal());
            if stolen.is_some() {
                stats
                    .lock()
                    .expect("lock should not be poisoned")
                    .tasks_stolen += 1;
                return stolen;
            }
        }
        None
    }
}
/// Adapts a one-shot closure into a task.
///
/// The closure is stored in an `Option` so it can be taken out exactly once.
pub struct ClosureTask<F, T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// The wrapped closure; `None` once it has been taken.
    closure: Option<F>,
    /// Records the closure's return type without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}

impl<F, T> ClosureTask<F, T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Wraps `closure` without running it.
    pub fn new(closure: F) -> Self {
        let closure = Some(closure);
        ClosureTask {
            closure,
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<F> Task for ClosureTask<F, ()>
where
F: FnOnce() + Send + 'static,
{
type Output = TaskResult<()>;
fn execute(mut self: Box<Self>) -> Self::Output {
if let Some(closure) = self.closure.take() {
closure();
TaskResult::Success(())
} else {
TaskResult::Error("Task already executed".to_string())
}
}
}
/// Convenience constructor: wraps a `()`-returning closure as a [`ClosureTask`].
pub fn task<F>(closure: F) -> ClosureTask<F, ()>
where
    F: FnOnce() + Send + 'static,
{
    ClosureTask::<F, ()>::new(closure)
}
// Unit tests for the work-stealing pool. Several tests wait with fixed sleeps
// (100-200 ms) for workers to drain their queues, so they are timing-dependent.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
// A freshly built pool records its thread count and has nothing pending.
// NOTE(review): `active_workers() == 0` depends on worker idle-flag timing; confirm it is not flaky.
#[test]
fn test_work_stealing_pool_creation() {
let pool =
WorkStealingPool::new(2).expect("work-stealing pool creation with 2 threads succeeds");
assert_eq!(pool.config.num_threads, 2);
assert_eq!(pool.active_workers(), 0); assert_eq!(pool.pending_tasks(), 0);
}
// Five counter-incrementing tasks must all run within the sleep window.
// `tasks_completed` is only bounded from above: the pool updates its stats after
// the task body returns, so the last update may still be in flight.
#[test]
fn test_task_submission() {
let pool =
WorkStealingPool::new(2).expect("work-stealing pool creation with 2 threads succeeds");
let counter = Arc::new(AtomicU32::new(0));
for _ in 0..5 {
let counter_clone = Arc::clone(&counter);
let task = task(move || {
counter_clone.fetch_add(1, Ordering::SeqCst);
});
pool.submit(task).expect("task submission should succeed");
}
std::thread::sleep(Duration::from_millis(100));
assert_eq!(counter.load(Ordering::SeqCst), 5);
let stats = pool.statistics();
assert_eq!(stats.tasks_submitted, 5);
assert!(stats.tasks_completed <= 5); }
// Submits a slow normal task and then an urgent one on a single-worker pool.
// Only asserts that something ran; it does not check execution order.
#[test]
fn test_urgent_task_submission() {
let pool =
WorkStealingPool::new(1).expect("work-stealing pool creation with 1 thread succeeds");
let execution_order = Arc::new(Mutex::new(Vec::new()));
{
let order_clone = Arc::clone(&execution_order);
let task = task(move || {
std::thread::sleep(Duration::from_millis(50)); order_clone
.lock()
.expect("lock should not be poisoned")
.push(1);
});
pool.submit(task).expect("task submission should succeed");
}
{
let order_clone = Arc::clone(&execution_order);
let urgent_task = task(move || {
order_clone
.lock()
.expect("lock should not be poisoned")
.push(2);
});
pool.submit_urgent(urgent_task)
.expect("urgent task submission should succeed");
}
std::thread::sleep(Duration::from_millis(200));
let order = execution_order.lock().expect("lock should not be poisoned");
assert!(!order.is_empty());
}
// statistics() reports every submission, one utilization entry per worker, and a
// non-negative imbalance.
#[test]
fn test_pool_statistics() {
let pool =
WorkStealingPool::new(2).expect("work-stealing pool creation with 2 threads succeeds");
for i in 0..3 {
let task = task(move || {
std::thread::sleep(Duration::from_millis(10 * i as u64));
});
pool.submit(task).expect("task submission should succeed");
}
std::thread::sleep(Duration::from_millis(100));
let stats = pool.statistics();
assert_eq!(stats.tasks_submitted, 3);
assert!(stats.worker_utilization.len() == 2);
assert!(stats.queue_imbalance >= 0.0);
}
// Calls the private imbalance helper while slow tasks may still be queued; only
// range-checks the result.
#[test]
fn test_queue_imbalance_calculation() {
let pool =
WorkStealingPool::new(3).expect("work-stealing pool creation with 3 threads succeeds");
for _ in 0..10 {
let task = task(|| {
std::thread::sleep(Duration::from_millis(100)); });
pool.submit(task).expect("task submission should succeed");
}
let imbalance = pool.calculate_queue_imbalance();
assert!((0.0..=1.0).contains(&imbalance));
}
// WorkerState starts idle with an empty queue and zero executed tasks; set_idle toggles the flag.
#[test]
fn test_worker_state() {
let worker = WorkerState::new(0);
assert_eq!(worker.id, 0);
assert_eq!(worker.queue_length(), 0);
assert!(worker.is_idle());
assert_eq!(worker.tasks_executed.load(Ordering::Relaxed), 0);
worker.set_idle(false);
assert!(!worker.is_idle());
}
// Running a ClosureTask directly invokes the closure and reports Success.
#[test]
fn test_closure_task() {
let executed = Arc::new(AtomicU32::new(0));
let executed_clone = Arc::clone(&executed);
let task = ClosureTask::new(move || {
executed_clone.store(42, Ordering::SeqCst);
});
let result = Box::new(task).execute();
assert!(matches!(result, TaskResult::Success(())));
assert_eq!(executed.load(Ordering::SeqCst), 42);
}
// shutdown() succeeds even while a task may still be running.
#[test]
fn test_pool_shutdown() {
let pool =
WorkStealingPool::new(2).expect("work-stealing pool creation with 2 threads succeeds");
let task = task(|| {
std::thread::sleep(Duration::from_millis(10));
});
pool.submit(task).expect("task submission should succeed");
assert!(pool.shutdown().is_ok());
}
}