scirs2_integrate/scheduling.rs

//! Work-stealing schedulers for adaptive algorithms
//!
//! This module provides work-stealing task schedulers optimized for adaptive
//! numerical algorithms. These schedulers dynamically balance workload across
//! threads, which is particularly important for algorithms with irregular
//! computational patterns.
//!
//! # Work-Stealing Concepts
//!
//! Work-stealing is a scheduling technique where idle threads "steal" work
//! from busy threads' task queues. This is especially effective for:
//! - Adaptive algorithms with unpredictable work distribution
//! - Recursive divide-and-conquer algorithms
//! - Dynamic load balancing scenarios
//!
//! # Examples
//!
//! ```
//! use scirs2_integrate::scheduling::{WorkStealingPool, Task};
//!
//! // Create work-stealing pool with 4 threads
//! let pool = WorkStealingPool::new(4);
//!
//! // Submit a simple task
//! let task = Task::new(|| 0.5 * 0.5); // Simple computation
//! pool.submit(task);
//! ```
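//!
//! A fuller round trip, batching tasks with [`WorkStealingPool::submit_all`],
//! blocking on [`WorkStealingPool::execute_and_wait`], and then inspecting the
//! pool's [`PoolStatistics`] (a minimal sketch that uses only the APIs defined
//! in this module):
//!
//! ```
//! use scirs2_integrate::scheduling::{WorkStealingPool, Task};
//!
//! let pool: WorkStealingPool<Task<_, f64>> = WorkStealingPool::new(2);
//!
//! // Batch-submit several tasks; the cost hint is used when idle workers
//! // choose a victim to steal from.
//! let tasks: Vec<_> = (0..8)
//!     .map(|i| Task::with_cost(move || i as f64 * 0.5, (i + 1) as f64))
//!     .collect();
//! pool.submit_all(tasks);
//!
//! // Block until every queue is drained and no task is still running.
//! pool.execute_and_wait().unwrap();
//!
//! assert_eq!(pool.statistics().total_tasks, 8);
//! ```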

use crate::common::IntegrateFloat;
use crate::error::IntegrateResult;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};

/// Generic task that can be executed by the work-stealing scheduler
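///
/// # Examples
///
/// A minimal sketch of a hand-written task (the `RangeSum` type is purely
/// illustrative, not part of this crate) that reports a cost proportional to
/// its size so the scheduler can balance work:
///
/// ```
/// use scirs2_integrate::scheduling::WorkStealingTask;
///
/// struct RangeSum {
///     lo: u64,
///     hi: u64,
/// }
///
/// impl WorkStealingTask for RangeSum {
///     type Output = u64;
///
///     // The actual work: sum the half-open range [lo, hi).
///     fn execute(&mut self) -> u64 {
///         (self.lo..self.hi).sum()
///     }
///
///     // Larger ranges advertise a larger cost, which steers work stealing.
///     fn estimated_cost(&self) -> f64 {
///         (self.hi - self.lo) as f64
///     }
/// }
///
/// let mut task = RangeSum { lo: 0, hi: 10 };
/// assert_eq!(task.execute(), 45);
/// assert_eq!(task.estimated_cost(), 10.0);
/// ```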
pub trait WorkStealingTask: Send + 'static {
    type Output: Send;

    /// Execute the task
    fn execute(&mut self) -> Self::Output;

    /// Estimate computational cost (for load balancing)
    fn estimated_cost(&self) -> f64 {
        1.0
    }

    /// Check if task can be subdivided for better load balancing
    fn can_subdivide(&self) -> bool {
        false
    }

    /// Subdivide task into smaller tasks (if possible)
    fn subdivide(&self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>>
    where
        Self: Sized,
    {
        vec![]
    }
}

/// Simple boxed task for closures
pub struct Task<F, R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    func: Option<F>,
    cost_estimate: f64,
}

impl<F, R> Task<F, R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    /// Create a new task from closure
    pub fn new(func: F) -> Self {
        Self {
            func: Some(func),
            cost_estimate: 1.0,
        }
    }

    /// Create task with cost estimate
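    ///
    /// The cost is only a relative hint: idle workers steal from the victim
    /// whose local queue has the highest total estimated cost. A short sketch:
    ///
    /// ```
    /// use scirs2_integrate::scheduling::Task;
    ///
    /// // Advertise this task as roughly ten times the default cost of 1.0.
    /// let _task = Task::with_cost(|| (0..1_000u64).sum::<u64>(), 10.0);
    /// ```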
    pub fn with_cost(func: F, cost: f64) -> Self {
        Self {
            func: Some(func),
            cost_estimate: cost,
        }
    }
}

impl<F, R> WorkStealingTask for Task<F, R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    type Output = R;

    fn execute(&mut self) -> Self::Output {
        (self.func.take().unwrap())()
    }

    fn estimated_cost(&self) -> f64 {
        self.cost_estimate
    }
}

/// Work-stealing deque for efficient task queue operations
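///
/// The owning worker operates on the back of the deque (`push_back`/`pop_back`)
/// while thieves remove from the front (`steal_front`), following the usual
/// work-stealing convention of taking tasks from opposite ends. The running
/// `total_cost` lets `try_steal_work` pick the busiest victim without walking
/// the queue.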
#[derive(Debug)]
struct WorkStealingDeque<T> {
    items: VecDeque<T>,
    total_cost: f64,
}

impl<T: WorkStealingTask> WorkStealingDeque<T> {
    fn new() -> Self {
        Self {
            items: VecDeque::new(),
            total_cost: 0.0,
        }
    }

    fn push_back(&mut self, task: T) {
        self.total_cost += task.estimated_cost();
        self.items.push_back(task);
    }

    fn pop_back(&mut self) -> Option<T> {
        if let Some(task) = self.items.pop_back() {
            self.total_cost -= task.estimated_cost();
            Some(task)
        } else {
            None
        }
    }

    fn steal_front(&mut self) -> Option<T> {
        if let Some(task) = self.items.pop_front() {
            self.total_cost -= task.estimated_cost();
            Some(task)
        } else {
            None
        }
    }

    #[allow(dead_code)]
    fn len(&self) -> usize {
        self.items.len()
    }

    fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    fn total_cost(&self) -> f64 {
        self.total_cost
    }
}

/// Worker thread state
struct WorkerState<T: WorkStealingTask> {
    /// Local task queue
    local_queue: Mutex<WorkStealingDeque<T>>,
    /// Number of tasks completed by this worker
    completed_tasks: AtomicUsize,
    /// Total computation time for this worker
    computation_time: Mutex<Duration>,
}

impl<T: WorkStealingTask> WorkerState<T> {
    fn new() -> Self {
        Self {
            local_queue: Mutex::new(WorkStealingDeque::new()),
            completed_tasks: AtomicUsize::new(0),
            computation_time: Mutex::new(Duration::ZERO),
        }
    }
}

/// Work-stealing thread pool for adaptive algorithms
pub struct WorkStealingPool<T: WorkStealingTask> {
    /// Worker threads
    workers: Vec<JoinHandle<()>>,
    /// Worker states (shared between threads)
    worker_states: Arc<Vec<WorkerState<T>>>,
    /// Global task queue for initial distribution
    global_queue: Arc<Mutex<WorkStealingDeque<T>>>,
    /// Number of tasks currently claimed by workers (being fetched or executed)
    active_tasks: Arc<AtomicUsize>,
    /// Shutdown flag
    shutdown: Arc<AtomicBool>,
    /// Condition variable for thread coordination
    cv: Arc<Condvar>,
    /// Mutex for condition variable
    #[allow(dead_code)]
    cv_mutex: Arc<Mutex<()>>,
    /// Pool statistics
    stats: Arc<Mutex<PoolStatistics>>,
}

/// Statistics about pool performance
#[derive(Debug, Clone, Default)]
pub struct PoolStatistics {
    /// Total tasks executed
    pub total_tasks: usize,
    /// Total computation time across all threads
    pub total_computation_time: Duration,
    /// Number of attempted work-stealing operations (successful or not)
    pub steal_attempts: usize,
    /// Successful steals
    pub successful_steals: usize,
    /// Load balancing efficiency in [0.0, 1.0]: 1.0 minus the variance of the
    /// workers' shares of completed tasks around the ideal 1/num_workers share,
    /// clamped at 0.0.
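    /// For example, two workers finishing 6 and 4 of 10 tasks have shares
    /// 0.6 and 0.4, variance 0.01, and therefore efficiency 0.99.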
    pub load_balance_efficiency: f64,
}

impl<T: WorkStealingTask + 'static> WorkStealingPool<T> {
    /// Create a new work-stealing pool with the specified number of threads
    /// (clamped to at least 1)
    pub fn new(num_threads: usize) -> Self {
        let num_threads = num_threads.max(1);

        let worker_states = Arc::new(
            (0..num_threads)
                .map(|_| WorkerState::new())
                .collect::<Vec<_>>(),
        );

        let global_queue = Arc::new(Mutex::new(WorkStealingDeque::new()));
        let active_tasks = Arc::new(AtomicUsize::new(0));
        let shutdown = Arc::new(AtomicBool::new(false));
        let cv = Arc::new(Condvar::new());
        let cv_mutex = Arc::new(Mutex::new(()));
        let stats = Arc::new(Mutex::new(PoolStatistics::default()));

        let workers = (0..num_threads)
            .map(|worker_id| {
                let worker_states = Arc::clone(&worker_states);
                let global_queue = Arc::clone(&global_queue);
                let active_tasks = Arc::clone(&active_tasks);
                let shutdown = Arc::clone(&shutdown);
                let cv = Arc::clone(&cv);
                let cv_mutex = Arc::clone(&cv_mutex);
                let stats = Arc::clone(&stats);

                thread::spawn(move || {
                    Self::worker_thread(
                        worker_id,
                        worker_states,
                        global_queue,
                        active_tasks,
                        shutdown,
                        cv,
                        cv_mutex,
                        stats,
                    );
                })
            })
            .collect();

        Self {
            workers,
            worker_states,
            global_queue,
            active_tasks,
            shutdown,
            cv,
            cv_mutex,
            stats,
        }
    }

    /// Submit a single task for execution
    pub fn submit(&self, task: T) {
        let mut global_queue = self.global_queue.lock().unwrap();
        global_queue.push_back(task);
        drop(global_queue);

        // Notify workers
        self.cv.notify_one();
    }

    /// Submit multiple tasks for execution
    pub fn submit_all(&self, tasks: Vec<T>) {
        let mut global_queue = self.global_queue.lock().unwrap();
        for task in tasks {
            global_queue.push_back(task);
        }
        drop(global_queue);

        // Notify all workers
        self.cv.notify_all();
    }

    /// Wait until all submitted tasks have been executed
    pub fn execute_and_wait(&self) -> IntegrateResult<()> {
        // Poll until all queues are empty AND no task is still in flight.
        // Workers bump `active_tasks` *before* dequeuing, so a task can never
        // be invisible to both checks at once.
        loop {
            let global_empty = self.global_queue.lock().unwrap().is_empty();
            let locals_empty = self
                .worker_states
                .iter()
                .all(|state| state.local_queue.lock().unwrap().is_empty());
            let no_active_tasks = self.active_tasks.load(Ordering::SeqCst) == 0;

            if global_empty && locals_empty && no_active_tasks {
                break;
            }

            // Small delay to avoid busy waiting
            thread::sleep(Duration::from_micros(100));
        }

        Ok(())
    }

    /// Get current pool statistics
    pub fn statistics(&self) -> PoolStatistics {
        let mut stats = self.stats.lock().unwrap();

        // Update statistics from worker states
        stats.total_tasks = self
            .worker_states
            .iter()
            .map(|state| state.completed_tasks.load(Ordering::Relaxed))
            .sum();

        stats.total_computation_time = self
            .worker_states
            .iter()
            .map(|state| *state.computation_time.lock().unwrap())
            .sum();

        // Calculate load balance efficiency
        if stats.total_tasks > 0 {
            let worker_loads: Vec<f64> = self
                .worker_states
                .iter()
                .map(|state| {
                    let completed = state.completed_tasks.load(Ordering::Relaxed);
                    completed as f64 / stats.total_tasks as f64
                })
                .collect();

            let ideal_load = 1.0 / self.worker_states.len() as f64;
            let load_variance: f64 = worker_loads
                .iter()
                .map(|&load| (load - ideal_load).powi(2))
                .sum::<f64>()
                / self.worker_states.len() as f64;

            stats.load_balance_efficiency = (1.0 - load_variance).max(0.0);
        }

        stats.clone()
    }

    /// Worker thread main loop
    fn worker_thread(
        worker_id: usize,
        worker_states: Arc<Vec<WorkerState<T>>>,
        global_queue: Arc<Mutex<WorkStealingDeque<T>>>,
        active_tasks: Arc<AtomicUsize>,
        shutdown: Arc<AtomicBool>,
        cv: Arc<Condvar>,
        cv_mutex: Arc<Mutex<()>>,
        stats: Arc<Mutex<PoolStatistics>>,
    ) {
        let my_state = &worker_states[worker_id];

        while !shutdown.load(Ordering::Relaxed) {
            // Claim activity *before* dequeuing so that `execute_and_wait`
            // can never observe empty queues and a zero counter while a task
            // is in flight between the pop and this increment.
            active_tasks.fetch_add(1, Ordering::SeqCst);

            // Try to get work from the local queue first
            let mut task_opt = my_state.local_queue.lock().unwrap().pop_back();

            // If no local work, try the global queue
            if task_opt.is_none() {
                task_opt = global_queue.lock().unwrap().pop_back();
            }

            // If still no work, try stealing from other workers
            if task_opt.is_none() {
                task_opt = Self::try_steal_work(worker_id, &worker_states, &stats);
            }

            if let Some(mut task) = task_opt {
                // Execute the task
                let start_time = Instant::now();
                let _result = task.execute();
                let computation_time = start_time.elapsed();

                // Update statistics before releasing the active-task claim so
                // that `statistics()` called after `execute_and_wait()` sees
                // every completed task.
                my_state.completed_tasks.fetch_add(1, Ordering::Relaxed);
                *my_state.computation_time.lock().unwrap() += computation_time;

                // Mark task as completed
                active_tasks.fetch_sub(1, Ordering::SeqCst);
            } else {
                // No work available: release the claim and wait for notification
                active_tasks.fetch_sub(1, Ordering::SeqCst);
                let _guard = cv
                    .wait_timeout(cv_mutex.lock().unwrap(), Duration::from_millis(10))
                    .unwrap();
            }
        }
    }

    /// Try to steal work from other workers
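    ///
    /// Victims are ranked by the total estimated cost of their local queue,
    /// and a task is taken from the *front* of the richest victim's deque,
    /// the opposite end from the one its owner pops, following the usual
    /// work-stealing convention.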
    fn try_steal_work(
        worker_id: usize,
        worker_states: &[WorkerState<T>],
        stats: &Arc<Mutex<PoolStatistics>>,
    ) -> Option<T> {
        // Update steal attempt counter
        stats.lock().unwrap().steal_attempts += 1;

        // Find worker with most work (highest cost)
        let mut best_victim = None;
        let mut best_cost = 0.0;

        for (victim_id, victim_state) in worker_states.iter().enumerate() {
            if victim_id == worker_id {
                continue; // Don't steal from ourselves
            }

            let queue = victim_state.local_queue.lock().unwrap();
            let cost = queue.total_cost();

            if cost > best_cost && !queue.is_empty() {
                best_cost = cost;
                best_victim = Some(victim_id);
            }
        }

        // Try to steal from the best victim
        if let Some(victim_id) = best_victim {
            let victim_state = &worker_states[victim_id];
            let mut victim_queue = victim_state.local_queue.lock().unwrap();

            if let Some(stolen_task) = victim_queue.steal_front() {
                // Update successful steal counter
                stats.lock().unwrap().successful_steals += 1;
                return Some(stolen_task);
            }
        }

        None
    }
}

impl<T: WorkStealingTask> Drop for WorkStealingPool<T> {
    fn drop(&mut self) {
        // Signal shutdown
        self.shutdown.store(true, Ordering::Relaxed);
        self.cv.notify_all();

        // Wait for all workers to finish
        while let Some(worker) = self.workers.pop() {
            let _ = worker.join();
        }
    }
}

/// Adaptive integration task for work-stealing scheduler
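///
/// # Examples
///
/// A minimal sketch of using the task stand-alone (outside the pool): a
/// single trapezoid over `[0, 1]` is a coarse estimate for `x^2`, so the
/// task reports that it should be subdivided further:
///
/// ```
/// use scirs2_integrate::scheduling::{AdaptiveIntegrationTask, WorkStealingTask};
///
/// let mut task = AdaptiveIntegrationTask::new(|x: f64| x * x, (0.0, 1.0), 1e-6, 8);
///
/// // One trapezoid over [0, 1]: h * (f(0) + f(1)) / 2 = 0.5.
/// let coarse = task.execute().unwrap();
/// assert!((coarse - 0.5).abs() < 1e-12);
///
/// // The error estimate exceeds the tolerance, so the task can be split
/// // into two half-interval tasks with tightened tolerances.
/// assert!(task.can_subdivide());
/// assert_eq!(task.subdivide().len(), 2);
/// ```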
pub struct AdaptiveIntegrationTask<F: IntegrateFloat, Func> {
    /// Function to integrate
    integrand: Func,
    /// Integration interval
    interval: (F, F),
    /// Tolerance for this region
    tolerance: F,
    /// Current depth (for subdivision control)
    depth: usize,
    /// Maximum subdivision depth
    max_depth: usize,
}

impl<F: IntegrateFloat, Func> AdaptiveIntegrationTask<F, Func>
where
    Func: Fn(F) -> F + Send + Clone + 'static,
{
    /// Create new adaptive integration task
    pub fn new(integrand: Func, interval: (F, F), tolerance: F, max_depth: usize) -> Self {
        Self {
            integrand,
            interval,
            tolerance,
            depth: 0,
            max_depth,
        }
    }

    /// Simple trapezoidal rule integration
    fn integrate_region(&self) -> F {
        let (a, b) = self.interval;
        let h = b - a;
        let fa = (self.integrand)(a);
        let fb = (self.integrand)(b);
        h * (fa + fb) / F::from(2.0).unwrap()
    }

    /// Estimate integration error using subdivision
    fn estimate_error(&self) -> F {
        let (a, b) = self.interval;
        let mid = (a + b) / F::from(2.0).unwrap();

        // Coarse estimate (full interval)
        let coarse = self.integrate_region();

        // Fine estimate (two half intervals)
        let left_task = AdaptiveIntegrationTask {
            integrand: self.integrand.clone(),
            interval: (a, mid),
            tolerance: self.tolerance,
            depth: self.depth + 1,
            max_depth: self.max_depth,
        };

        let right_task = AdaptiveIntegrationTask {
            integrand: self.integrand.clone(),
            interval: (mid, b),
            tolerance: self.tolerance,
            depth: self.depth + 1,
            max_depth: self.max_depth,
        };

        let fine = left_task.integrate_region() + right_task.integrate_region();

        (fine - coarse).abs()
    }
}

impl<F: IntegrateFloat + Send, Func> WorkStealingTask for AdaptiveIntegrationTask<F, Func>
where
    Func: Fn(F) -> F + Send + Clone + 'static,
{
    type Output = IntegrateResult<F>;

    fn execute(&mut self) -> Self::Output {
        let result = self.integrate_region();
        Ok(result)
    }

    fn estimated_cost(&self) -> f64 {
        let (a, b) = self.interval;
        (b - a).to_f64().unwrap_or(1.0)
    }

    fn can_subdivide(&self) -> bool {
        self.depth < self.max_depth && self.estimate_error() > self.tolerance
    }

    fn subdivide(&self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>> {
        let (a, b) = self.interval;
        let mid = (a + b) / F::from(2.0).unwrap();

        let left_task = AdaptiveIntegrationTask {
            integrand: self.integrand.clone(),
            interval: (a, mid),
            tolerance: self.tolerance / F::from(2.0).unwrap(),
            depth: self.depth + 1,
            max_depth: self.max_depth,
        };

        let right_task = AdaptiveIntegrationTask {
            integrand: self.integrand.clone(),
            interval: (mid, b),
            tolerance: self.tolerance / F::from(2.0).unwrap(),
            depth: self.depth + 1,
            max_depth: self.max_depth,
        };

        vec![
            Box::new(left_task) as Box<dyn WorkStealingTask<Output = Self::Output>>,
            Box::new(right_task) as Box<dyn WorkStealingTask<Output = Self::Output>>,
        ]
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicI32;

    #[test]
    fn test_work_stealing_pool_basic() {
        let pool: WorkStealingPool<Task<_, i32>> = WorkStealingPool::new(2);

        // Submit some simple tasks
        for i in 0..10 {
            let task = Task::new(move || i * 2);
            pool.submit(task);
        }

        // Wait for completion
        assert!(pool.execute_and_wait().is_ok());

        // Check statistics
        let stats = pool.statistics();
        assert_eq!(stats.total_tasks, 10);
        assert!(stats.load_balance_efficiency >= 0.0);
    }

    #[test]
    fn test_task_subdivision() {
        let integrand = |x: f64| x * x;
        let task = AdaptiveIntegrationTask::new(integrand, (0.0, 1.0), 1e-6, 5);

        assert!(task.can_subdivide());

        let subtasks = task.subdivide();
        assert_eq!(subtasks.len(), 2);
    }

    #[test]
    fn test_load_balancing() {
        let pool: WorkStealingPool<Task<_, ()>> = WorkStealingPool::new(4);
        let counter = Arc::new(AtomicI32::new(0));

        // Submit tasks with varying computational cost
        for i in 0..20 {
            let counter_clone = Arc::clone(&counter);
            let sleep_time = (i % 5) * 10; // Variable work

            let task = Task::with_cost(
                move || {
                    thread::sleep(Duration::from_millis(sleep_time));
                    counter_clone.fetch_add(1, Ordering::Relaxed);
                },
                sleep_time as f64,
            );

            pool.submit(task);
        }

        pool.execute_and_wait().unwrap();

        assert_eq!(counter.load(Ordering::Relaxed), 20);

        let stats = pool.statistics();
        assert_eq!(stats.total_tasks, 20);
        assert!(stats.steal_attempts > 0); // Should have attempted work stealing
    }
}