use crate::common::IntegrateFloat;
use crate::error::IntegrateResult;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
pub trait WorkStealingTask: Send + 'static {
type Output: Send;
fn execute(&mut self) -> Self::Output;
fn estimated_cost(&self) -> f64 {
1.0
}
fn can_subdivide(&self) -> bool {
false
}
fn subdivide(&self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>>
where
Self: Sized,
{
vec![]
}
}
pub struct Task<F, R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
func: Option<F>,
cost_estimate: f64,
}
impl<F, R> Task<F, R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
pub fn new(func: F) -> Self {
Self {
func: Some(func),
cost_estimate: 1.0,
}
}
pub fn with_cost(func: F, cost: f64) -> Self {
Self {
func: Some(func),
cost_estimate: cost,
}
}
}
impl<F, R> WorkStealingTask for Task<F, R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
type Output = R;
fn execute(&mut self) -> Self::Output {
(self.func.take().expect("Operation failed"))()
}
fn estimated_cost(&self) -> f64 {
self.cost_estimate
}
}
#[derive(Debug)]
struct WorkStealingDeque<T> {
items: VecDeque<T>,
total_cost: f64,
}
impl<T: WorkStealingTask> WorkStealingDeque<T> {
fn new() -> Self {
Self {
items: VecDeque::new(),
total_cost: 0.0,
}
}
fn push_back(&mut self, task: T) {
self.total_cost += task.estimated_cost();
self.items.push_back(task);
}
fn pop_back(&mut self) -> Option<T> {
if let Some(task) = self.items.pop_back() {
self.total_cost -= task.estimated_cost();
Some(task)
} else {
None
}
}
fn steal_front(&mut self) -> Option<T> {
if let Some(task) = self.items.pop_front() {
self.total_cost -= task.estimated_cost();
Some(task)
} else {
None
}
}
#[allow(dead_code)]
fn len(&self) -> usize {
self.items.len()
}
fn is_empty(&self) -> bool {
self.items.is_empty()
}
fn total_cost(&self) -> f64 {
self.total_cost
}
}
struct WorkerState<T: WorkStealingTask> {
local_queue: Mutex<WorkStealingDeque<T>>,
completed_tasks: AtomicUsize,
computation_time: Mutex<Duration>,
}
impl<T: WorkStealingTask> WorkerState<T> {
fn new() -> Self {
Self {
local_queue: Mutex::new(WorkStealingDeque::new()),
completed_tasks: AtomicUsize::new(0),
computation_time: Mutex::new(Duration::ZERO),
}
}
}
pub struct WorkStealingPool<T: WorkStealingTask> {
workers: Vec<JoinHandle<()>>,
worker_states: Arc<Vec<WorkerState<T>>>,
global_queue: Arc<Mutex<WorkStealingDeque<T>>>,
active_tasks: Arc<AtomicUsize>,
shutdown: Arc<AtomicBool>,
cv: Arc<Condvar>,
#[allow(dead_code)]
cv_mutex: Arc<Mutex<()>>,
stats: Arc<Mutex<PoolStatistics>>,
}
#[derive(Debug, Clone, Default)]
pub struct PoolStatistics {
pub total_tasks: usize,
pub total_computation_time: Duration,
pub steal_attempts: usize,
pub successful_steals: usize,
pub load_balance_efficiency: f64,
}
impl<T: WorkStealingTask + 'static> WorkStealingPool<T> {
pub fn new(_numthreads: usize) -> Self {
let _num_threads = _numthreads.max(1);
let worker_states = Arc::new(
(0.._num_threads)
.map(|_| WorkerState::new())
.collect::<Vec<_>>(),
);
let global_queue = Arc::new(Mutex::new(WorkStealingDeque::new()));
let active_tasks = Arc::new(AtomicUsize::new(0));
let shutdown = Arc::new(AtomicBool::new(false));
let cv = Arc::new(Condvar::new());
let cv_mutex = Arc::new(Mutex::new(()));
let stats = Arc::new(Mutex::new(PoolStatistics::default()));
let workers = (0.._num_threads)
.map(|worker_id| {
let worker_states = Arc::clone(&worker_states);
let global_queue = Arc::clone(&global_queue);
let active_tasks = Arc::clone(&active_tasks);
let shutdown = Arc::clone(&shutdown);
let cv = Arc::clone(&cv);
let cv_mutex = Arc::clone(&cv_mutex);
let stats = Arc::clone(&stats);
thread::spawn(move || {
Self::worker_thread(
worker_id,
worker_states,
global_queue,
active_tasks,
shutdown,
cv,
cv_mutex,
stats,
);
})
})
.collect();
Self {
workers,
worker_states,
global_queue,
active_tasks,
shutdown,
cv,
cv_mutex,
stats,
}
}
pub fn submit(&self, task: T) {
let mut global_queue = self.global_queue.lock().expect("Operation failed");
global_queue.push_back(task);
drop(global_queue);
self.cv.notify_one();
}
pub fn submit_all(&self, tasks: Vec<T>) {
let mut global_queue = self.global_queue.lock().expect("Operation failed");
for task in tasks {
global_queue.push_back(task);
}
drop(global_queue);
self.cv.notify_all();
}
pub fn execute_and_wait(&self) -> IntegrateResult<()> {
loop {
let global_empty = self
.global_queue
.lock()
.expect("Operation failed")
.is_empty();
let locals_empty = self.worker_states.iter().all(|state| {
state
.local_queue
.lock()
.expect("Operation failed")
.is_empty()
});
let no_active_tasks = self.active_tasks.load(Ordering::Relaxed) == 0;
if global_empty && locals_empty && no_active_tasks {
break;
}
thread::sleep(Duration::from_micros(100));
}
Ok(())
}
pub fn statistics(&self) -> PoolStatistics {
let mut stats = self.stats.lock().expect("Operation failed");
stats.total_tasks = self
.worker_states
.iter()
.map(|state| state.completed_tasks.load(Ordering::Relaxed))
.sum();
stats.total_computation_time = self
.worker_states
.iter()
.map(|state| *state.computation_time.lock().expect("Operation failed"))
.sum();
if stats.total_tasks > 0 {
let worker_loads: Vec<f64> = self
.worker_states
.iter()
.map(|state| {
let completed = state.completed_tasks.load(Ordering::Relaxed);
completed as f64 / stats.total_tasks as f64
})
.collect();
let ideal_load = 1.0 / self.worker_states.len() as f64;
let load_variance: f64 = worker_loads
.iter()
.map(|&load| (load - ideal_load).powi(2))
.sum::<f64>()
/ self.worker_states.len() as f64;
stats.load_balance_efficiency = (1.0 - load_variance).max(0.0);
}
stats.clone()
}
fn worker_thread(
worker_id: usize,
worker_states: Arc<Vec<WorkerState<T>>>,
global_queue: Arc<Mutex<WorkStealingDeque<T>>>,
active_tasks: Arc<AtomicUsize>,
shutdown: Arc<AtomicBool>,
cv: Arc<Condvar>,
cv_mutex: Arc<Mutex<()>>,
stats: Arc<Mutex<PoolStatistics>>,
) {
let my_state = &worker_states[worker_id];
while !shutdown.load(Ordering::Relaxed) {
let mut task_opt = my_state
.local_queue
.lock()
.expect("Operation failed")
.pop_back();
if task_opt.is_none() {
task_opt = global_queue.lock().expect("Operation failed").pop_back();
}
if task_opt.is_none() {
task_opt = Self::try_steal_work(worker_id, &worker_states, &stats);
}
if let Some(mut task) = task_opt {
active_tasks.fetch_add(1, Ordering::Relaxed);
let start_time = Instant::now();
let _result = task.execute();
let computation_time = start_time.elapsed();
active_tasks.fetch_sub(1, Ordering::Relaxed);
my_state.completed_tasks.fetch_add(1, Ordering::Relaxed);
*my_state.computation_time.lock().expect("Operation failed") += computation_time;
} else {
let _guard = cv
.wait_timeout(
cv_mutex.lock().expect("Operation failed"),
Duration::from_millis(10),
)
.expect("Operation failed");
}
}
}
fn try_steal_work(
worker_id: usize,
worker_states: &[WorkerState<T>],
stats: &Arc<Mutex<PoolStatistics>>,
) -> Option<T> {
stats.lock().expect("Operation failed").steal_attempts += 1;
let mut best_victim = None;
let mut best_cost = 0.0;
for (victim_id, victim_state) in worker_states.iter().enumerate() {
if victim_id == worker_id {
continue; }
let queue = victim_state.local_queue.lock().expect("Operation failed");
let cost = queue.total_cost();
if cost > best_cost && !queue.is_empty() {
best_cost = cost;
best_victim = Some(victim_id);
}
}
if let Some(victim_id) = best_victim {
let victim_state = &worker_states[victim_id];
let mut victim_queue = victim_state.local_queue.lock().expect("Operation failed");
if let Some(stolen_task) = victim_queue.steal_front() {
stats.lock().expect("Operation failed").successful_steals += 1;
return Some(stolen_task);
}
}
None
}
}
impl<T: WorkStealingTask> Drop for WorkStealingPool<T> {
fn drop(&mut self) {
self.shutdown.store(true, Ordering::Relaxed);
self.cv.notify_all();
while let Some(worker) = self.workers.pop() {
let _ = worker.join();
}
}
}
pub struct AdaptiveIntegrationTask<F: IntegrateFloat, Func> {
integrand: Func,
interval: (F, F),
tolerance: F,
depth: usize,
max_depth: usize,
}
impl<F: IntegrateFloat, Func> AdaptiveIntegrationTask<F, Func>
where
Func: Fn(F) -> F + Send + Clone + 'static,
{
pub fn new(integrand: Func, interval: (F, F), tolerance: F, max_depth: usize) -> Self {
Self {
integrand,
interval,
tolerance,
depth: 0,
max_depth,
}
}
fn integrate_region(&self) -> F {
let (a, b) = self.interval;
let h = b - a;
let fa = (self.integrand)(a);
let fb = (self.integrand)(b);
h * (fa + fb) / F::from(2.0).expect("Failed to convert constant to float")
}
fn estimate_error(&self) -> F {
let (a, b) = self.interval;
let mid = (a + b) / F::from(2.0).expect("Failed to convert constant to float");
let coarse = self.integrate_region();
let left_task = AdaptiveIntegrationTask {
integrand: self.integrand.clone(),
interval: (a, mid),
tolerance: self.tolerance,
depth: self.depth + 1,
max_depth: self.max_depth,
};
let right_task = AdaptiveIntegrationTask {
integrand: self.integrand.clone(),
interval: (mid, b),
tolerance: self.tolerance,
depth: self.depth + 1,
max_depth: self.max_depth,
};
let fine = left_task.integrate_region() + right_task.integrate_region();
(fine - coarse).abs()
}
}
impl<F: IntegrateFloat + Send, Func> WorkStealingTask for AdaptiveIntegrationTask<F, Func>
where
Func: Fn(F) -> F + Send + Clone + 'static,
{
type Output = IntegrateResult<F>;
fn execute(&mut self) -> Self::Output {
let result = self.integrate_region();
Ok(result)
}
fn estimated_cost(&self) -> f64 {
let (a, b) = self.interval;
(b - a).to_f64().unwrap_or(1.0)
}
fn can_subdivide(&self) -> bool {
self.depth < self.max_depth && self.estimate_error() > self.tolerance
}
fn subdivide(&self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>> {
let (a, b) = self.interval;
let mid = (a + b) / F::from(2.0).expect("Failed to convert constant to float");
let left_task = AdaptiveIntegrationTask {
integrand: self.integrand.clone(),
interval: (a, mid),
tolerance: self.tolerance / F::from(2.0).expect("Failed to convert constant to float"),
depth: self.depth + 1,
max_depth: self.max_depth,
};
let right_task = AdaptiveIntegrationTask {
integrand: self.integrand.clone(),
interval: (mid, b),
tolerance: self.tolerance / F::from(2.0).expect("Failed to convert constant to float"),
depth: self.depth + 1,
max_depth: self.max_depth,
};
vec![
Box::new(left_task) as Box<dyn WorkStealingTask<Output = Self::Output>>,
Box::new(right_task) as Box<dyn WorkStealingTask<Output = Self::Output>>,
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicI32;
#[test]
fn test_work_stealing_pool_basic() {
let pool: WorkStealingPool<Task<_, i32>> = WorkStealingPool::new(2);
for i in 0..10 {
let task = Task::new(move || i * 2);
pool.submit(task);
}
assert!(pool.execute_and_wait().is_ok());
let stats = pool.statistics();
assert_eq!(stats.total_tasks, 10);
assert!(stats.load_balance_efficiency >= 0.0);
}
#[test]
fn test_task_subdivision() {
let integrand = |x: f64| x * x;
let task = AdaptiveIntegrationTask::new(integrand, (0.0, 1.0), 1e-6, 5);
assert!(task.can_subdivide());
let subtasks = task.subdivide();
assert_eq!(subtasks.len(), 2);
}
#[test]
fn test_load_balancing() {
let pool: WorkStealingPool<Task<_, ()>> = WorkStealingPool::new(4);
let counter = Arc::new(AtomicI32::new(0));
for i in 0..20 {
let counter_clone = Arc::clone(&counter);
let sleep_time = (i % 5) * 10;
let task = Task::with_cost(
move || {
thread::sleep(Duration::from_millis(sleep_time));
counter_clone.fetch_add(1, Ordering::Relaxed);
},
sleep_time as f64,
);
pool.submit(task);
}
pool.execute_and_wait().expect("Operation failed");
assert_eq!(counter.load(Ordering::Relaxed), 20);
let stats = pool.statistics();
assert_eq!(stats.total_tasks, 20);
assert!(stats.steal_attempts > 0); }
}