Skip to main content

numrs2/parallel/
thread_pool.rs

//! Thread pool with per-worker deques and work stealing
//!
//! This module provides a thread pool implementation with:
//! - A task deque per worker thread; idle workers steal from the busiest peer
//! - A CPU-affinity option (the platform pinning hook is currently a no-op)
//! - Configuration fields for adaptive thread counts (not yet acted on)
//! - Task priorities and cost hints recorded on each task (not yet used for ordering)
//! - A per-task `dependencies` field reserved for dependency tracking (not yet populated)
9
10use crate::error::{NumRs2Error, Result};
11use std::collections::VecDeque;
12use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use std::sync::{Arc, Condvar, Mutex};
14use std::thread::{self, JoinHandle};
15use std::time::{Duration, Instant};
16
/// Configuration for [`ThreadPool`].
///
/// Typically built with struct-update syntax on top of the defaults, e.g.
/// `ThreadPoolConfig { num_threads: Some(2), ..Default::default() }`.
///
/// NOTE(review): `adaptive_threads`, `min_threads`, `max_threads` and
/// `queue_capacity` are accepted but not yet acted on by the pool in this
/// module; they are kept so callers that set them keep compiling.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// Number of worker threads. `None` means one per core reported by
    /// `std::thread::available_parallelism` (4 if that cannot be determined).
    pub num_threads: Option<usize>,
    /// Request that each worker be pinned to a CPU core. The platform hook
    /// that would apply the pinning is currently a no-op.
    pub enable_thread_pinning: bool,
    /// Enable adaptive thread-count adjustment (not yet implemented).
    pub adaptive_threads: bool,
    /// Lower bound on the thread count in adaptive mode (not yet implemented).
    pub min_threads: usize,
    /// Upper bound on the thread count in adaptive mode (not yet implemented).
    pub max_threads: usize,
    /// Intended per-worker queue capacity; the queues are currently unbounded.
    pub queue_capacity: usize,
    /// Minimum time between two work-stealing attempts by the same worker.
    pub steal_interval: Duration,
    /// How long an idle worker waits for a wake-up before re-scanning for work.
    pub idle_timeout: Duration,
}
37
38impl Default for ThreadPoolConfig {
39    fn default() -> Self {
40        let num_cpus = thread::available_parallelism().map_or(4, |n| n.get());
41        Self {
42            num_threads: Some(num_cpus),
43            enable_thread_pinning: false,
44            adaptive_threads: false,
45            min_threads: 1,
46            max_threads: num_cpus * 2,
47            queue_capacity: 1000,
48            steal_interval: Duration::from_millis(1),
49            idle_timeout: Duration::from_millis(10),
50        }
51    }
52}
53
/// Urgency level attached to every submitted task.
///
/// The derived ordering follows the explicit discriminants, so
/// `Low < Normal < High < Critical`. The pool stores the priority on each
/// task; in this module it is not used to reorder the queues.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Least urgent (discriminant `0`).
    Low = 0,
    /// The level used by `ThreadPool::submit` (discriminant `1`).
    Normal = 1,
    /// Above normal (discriminant `2`).
    High = 2,
    /// Most urgent (discriminant `3`).
    Critical = 3,
}
62
63/// Task with metadata
64pub struct PoolTask {
65    pub(crate) id: u64,
66    pub(crate) priority: Priority,
67    pub(crate) submitted_at: Instant,
68    pub(crate) estimated_cost: Option<u64>,
69    pub(crate) dependencies: Vec<u64>,
70    pub(crate) task: Box<dyn FnOnce() + Send + 'static>,
71}
72
73impl std::fmt::Debug for PoolTask {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("PoolTask")
76            .field("id", &self.id)
77            .field("priority", &self.priority)
78            .field("submitted_at", &self.submitted_at)
79            .field("estimated_cost", &self.estimated_cost)
80            .field("dependencies", &self.dependencies)
81            .finish()
82    }
83}
84
85/// Thread-local worker state with work-stealing deque
86/// Cache-aligned to prevent false sharing between worker threads
87#[repr(align(64))]
88struct WorkerState {
89    id: usize,
90    deque: Mutex<VecDeque<PoolTask>>,
91    is_idle: AtomicBool,
92    tasks_executed: AtomicUsize,
93    tasks_stolen: AtomicUsize,
94    total_execution_time: Mutex<Duration>,
95    last_steal_time: Mutex<Instant>,
96    cpu_affinity: Option<usize>,
97    // Cache-line padding to prevent false sharing
98    _padding: [u8; 0], // Padding will be added by alignment
99}
100
101impl WorkerState {
102    fn new(id: usize, cpu_affinity: Option<usize>) -> Self {
103        Self {
104            id,
105            deque: Mutex::new(VecDeque::new()),
106            is_idle: AtomicBool::new(true),
107            tasks_executed: AtomicUsize::new(0),
108            tasks_stolen: AtomicUsize::new(0),
109            total_execution_time: Mutex::new(Duration::ZERO),
110            last_steal_time: Mutex::new(Instant::now()),
111            cpu_affinity,
112            _padding: [],
113        }
114    }
115
116    fn push_task(&self, task: PoolTask) -> Result<()> {
117        let mut deque = self
118            .deque
119            .lock()
120            .map_err(|_| NumRs2Error::RuntimeError("Failed to acquire deque lock".to_string()))?;
121        deque.push_back(task);
122        Ok(())
123    }
124
125    fn pop_task(&self) -> Result<Option<PoolTask>> {
126        let mut deque = self
127            .deque
128            .lock()
129            .map_err(|_| NumRs2Error::RuntimeError("Failed to acquire deque lock".to_string()))?;
130        Ok(deque.pop_front())
131    }
132
133    fn steal_task(&self) -> Result<Option<PoolTask>> {
134        let mut deque = self
135            .deque
136            .lock()
137            .map_err(|_| NumRs2Error::RuntimeError("Failed to acquire deque lock".to_string()))?;
138        let task = deque.pop_back();
139        if task.is_some() {
140            self.tasks_stolen.fetch_add(1, Ordering::Relaxed);
141        }
142        Ok(task)
143    }
144
145    fn queue_len(&self) -> usize {
146        self.deque.lock().map(|d| d.len()).unwrap_or(0)
147    }
148
149    fn is_idle(&self) -> bool {
150        self.is_idle.load(Ordering::Relaxed)
151    }
152
153    fn set_idle(&self, idle: bool) {
154        self.is_idle.store(idle, Ordering::Relaxed);
155    }
156}
157
158/// Enhanced thread pool with work-stealing and advanced features
159pub struct ThreadPool {
160    config: ThreadPoolConfig,
161    workers: Vec<Arc<WorkerState>>,
162    threads: Vec<JoinHandle<()>>,
163    shutdown: Arc<AtomicBool>,
164    global_queue: Arc<Mutex<VecDeque<PoolTask>>>,
165    idle_notify: Arc<(Mutex<()>, Condvar)>,
166    next_task_id: AtomicUsize,
167    stats: Arc<Mutex<ThreadPoolStats>>,
168    completed_tasks: Arc<Mutex<Vec<u64>>>,
169}
170
/// Snapshot of the pool's counters and timings.
///
/// Fields the pool does not update keep their `Default` value (zero or empty).
#[derive(Debug, Clone, Default)]
pub struct ThreadPoolStats {
    /// Tasks accepted by the pool.
    pub tasks_submitted: u64,
    /// Tasks that finished running.
    pub tasks_completed: u64,
    /// Tasks taken by one worker from another worker's deque.
    pub tasks_stolen: u64,
    /// Average time a task waited in a queue before starting.
    pub average_queue_time: Duration,
    /// Smoothed average of task run time.
    pub average_execution_time: Duration,
    /// Per-worker busy indicator at snapshot time (`0.0` idle, `1.0` busy).
    pub worker_utilization: Vec<f64>,
    /// Workers not idle at snapshot time.
    pub active_threads: usize,
}
182
183impl ThreadPool {
184    /// Create a new thread pool with default configuration
185    pub fn new() -> Result<Self> {
186        Self::with_config(ThreadPoolConfig::default())
187    }
188
189    /// Create a new thread pool with custom configuration
190    pub fn with_config(config: ThreadPoolConfig) -> Result<Self> {
191        let num_threads = config
192            .num_threads
193            .unwrap_or_else(|| thread::available_parallelism().map_or(4, |n| n.get()));
194
195        let shutdown = Arc::new(AtomicBool::new(false));
196        let global_queue = Arc::new(Mutex::new(VecDeque::new()));
197        let idle_notify = Arc::new((Mutex::new(()), Condvar::new()));
198        let stats = Arc::new(Mutex::new(ThreadPoolStats::default()));
199        let completed_tasks = Arc::new(Mutex::new(Vec::new()));
200
201        let mut workers = Vec::new();
202        let mut threads = Vec::new();
203
204        // Create worker states
205        for i in 0..num_threads {
206            let cpu_affinity = if config.enable_thread_pinning {
207                Some(i % num_cpus::get())
208            } else {
209                None
210            };
211            workers.push(Arc::new(WorkerState::new(i, cpu_affinity)));
212        }
213
214        // Spawn worker threads
215        for worker in &workers {
216            let worker_clone = Arc::clone(worker);
217            let workers_clone = workers.clone();
218            let shutdown_clone = Arc::clone(&shutdown);
219            let global_queue_clone = Arc::clone(&global_queue);
220            let idle_notify_clone = Arc::clone(&idle_notify);
221            let stats_clone = Arc::clone(&stats);
222            let completed_tasks_clone = Arc::clone(&completed_tasks);
223            let config_clone = config.clone();
224
225            let handle = thread::spawn(move || {
226                // Set thread affinity if enabled
227                if let Some(cpu_id) = worker_clone.cpu_affinity {
228                    Self::set_thread_affinity(cpu_id);
229                }
230
231                Self::worker_main(
232                    worker_clone,
233                    workers_clone,
234                    shutdown_clone,
235                    global_queue_clone,
236                    idle_notify_clone,
237                    stats_clone,
238                    completed_tasks_clone,
239                    config_clone,
240                );
241            });
242
243            threads.push(handle);
244        }
245
246        Ok(Self {
247            config,
248            workers,
249            threads,
250            shutdown,
251            global_queue,
252            idle_notify,
253            next_task_id: AtomicUsize::new(0),
254            stats,
255            completed_tasks,
256        })
257    }
258
259    /// Submit a task to the pool
260    pub fn submit<F>(&self, task: F) -> Result<u64>
261    where
262        F: FnOnce() + Send + 'static,
263    {
264        self.submit_with_priority(task, Priority::Normal, None)
265    }
266
267    /// Submit a task with priority and cost estimate
268    pub fn submit_with_priority<F>(
269        &self,
270        task: F,
271        priority: Priority,
272        estimated_cost: Option<u64>,
273    ) -> Result<u64>
274    where
275        F: FnOnce() + Send + 'static,
276    {
277        if self.shutdown.load(Ordering::Relaxed) {
278            return Err(NumRs2Error::RuntimeError(
279                "Thread pool is shutting down".to_string(),
280            ));
281        }
282
283        let task_id = self.next_task_id.fetch_add(1, Ordering::Relaxed) as u64;
284
285        let pool_task = PoolTask {
286            id: task_id,
287            priority,
288            submitted_at: Instant::now(),
289            estimated_cost,
290            dependencies: Vec::new(),
291            task: Box::new(task),
292        };
293
294        // Find least loaded worker
295        let target_worker = self.find_least_loaded_worker();
296
297        if let Some(worker_idx) = target_worker {
298            self.workers[worker_idx].push_task(pool_task)?;
299
300            // Wake up worker if idle
301            if self.workers[worker_idx].is_idle() {
302                let (lock, cvar) = &*self.idle_notify;
303                let _guard = lock.lock().map_err(|_| {
304                    NumRs2Error::RuntimeError("Failed to acquire idle notify lock".to_string())
305                })?;
306                cvar.notify_one();
307            }
308        } else {
309            // Fallback to global queue
310            let mut global = self.global_queue.lock().map_err(|_| {
311                NumRs2Error::RuntimeError("Failed to acquire global queue lock".to_string())
312            })?;
313            global.push_back(pool_task);
314
315            let (lock, cvar) = &*self.idle_notify;
316            let _guard = lock.lock().map_err(|_| {
317                NumRs2Error::RuntimeError("Failed to acquire idle notify lock".to_string())
318            })?;
319            cvar.notify_all();
320        }
321
322        // Update stats
323        if let Ok(mut stats) = self.stats.lock() {
324            stats.tasks_submitted += 1;
325        }
326
327        Ok(task_id)
328    }
329
330    /// Get pool statistics
331    pub fn statistics(&self) -> ThreadPoolStats {
332        if let Ok(mut stats) = self.stats.lock() {
333            stats.worker_utilization = self
334                .workers
335                .iter()
336                .map(|w| if w.is_idle() { 0.0 } else { 1.0 })
337                .collect();
338
339            stats.active_threads = self.workers.iter().filter(|w| !w.is_idle()).count();
340
341            stats.clone()
342        } else {
343            ThreadPoolStats::default()
344        }
345    }
346
347    /// Get number of worker threads
348    pub fn num_threads(&self) -> usize {
349        self.workers.len()
350    }
351
352    /// Get number of pending tasks
353    pub fn pending_tasks(&self) -> usize {
354        let global_count = self.global_queue.lock().map(|q| q.len()).unwrap_or(0);
355
356        let worker_count: usize = self.workers.iter().map(|w| w.queue_len()).sum();
357
358        global_count + worker_count
359    }
360
361    /// Wait for all pending tasks to complete
362    pub fn wait(&self) -> Result<()> {
363        // Wait until there are no pending tasks AND all workers are idle
364        while self.pending_tasks() > 0 || self.has_active_workers() {
365            thread::sleep(Duration::from_millis(1));
366        }
367        Ok(())
368    }
369
370    /// Check if any workers are actively executing tasks
371    fn has_active_workers(&self) -> bool {
372        self.workers.iter().any(|w| !w.is_idle())
373    }
374
375    /// Shutdown the thread pool gracefully
376    pub fn shutdown(self) -> Result<()> {
377        self.shutdown.store(true, Ordering::Relaxed);
378
379        // Wake up all workers
380        let (lock, cvar) = &*self.idle_notify;
381        let _guard = lock.lock().map_err(|_| {
382            NumRs2Error::RuntimeError("Failed to acquire idle notify lock".to_string())
383        })?;
384        cvar.notify_all();
385        drop(_guard);
386
387        // Join all threads
388        for handle in self.threads {
389            if let Err(_e) = handle.join() {
390                // Log error but continue shutting down other threads
391            }
392        }
393
394        Ok(())
395    }
396
397    // Private helper methods
398
399    fn find_least_loaded_worker(&self) -> Option<usize> {
400        self.workers
401            .iter()
402            .enumerate()
403            .min_by_key(|(_, w)| w.queue_len())
404            .map(|(idx, _)| idx)
405    }
406
407    fn worker_main(
408        worker: Arc<WorkerState>,
409        workers: Vec<Arc<WorkerState>>,
410        shutdown: Arc<AtomicBool>,
411        global_queue: Arc<Mutex<VecDeque<PoolTask>>>,
412        idle_notify: Arc<(Mutex<()>, Condvar)>,
413        stats: Arc<Mutex<ThreadPoolStats>>,
414        completed_tasks: Arc<Mutex<Vec<u64>>>,
415        config: ThreadPoolConfig,
416    ) {
417        let worker_id = worker.id;
418
419        while !shutdown.load(Ordering::Relaxed) {
420            let mut task_found = false;
421
422            // 1. Try local queue
423            if let Ok(Some(task)) = worker.pop_task() {
424                Self::execute_task(task, &worker, &stats, &completed_tasks);
425                task_found = true;
426            }
427
428            // 2. Try global queue
429            if !task_found {
430                if let Ok(mut global) = global_queue.try_lock() {
431                    if let Some(task) = global.pop_front() {
432                        drop(global);
433                        Self::execute_task(task, &worker, &stats, &completed_tasks);
434                        task_found = true;
435                    }
436                }
437            }
438
439            // 3. Try work stealing
440            if !task_found {
441                if let Some(stolen_task) = Self::try_steal_work(&worker, &workers, &config) {
442                    Self::execute_task(stolen_task, &worker, &stats, &completed_tasks);
443                    task_found = true;
444                }
445            }
446
447            // 4. Park if no work found
448            if !task_found {
449                worker.set_idle(true);
450
451                let (lock, cvar) = &*idle_notify;
452                if let Ok(guard) = lock.lock() {
453                    let _result = cvar.wait_timeout(guard, config.idle_timeout);
454                }
455
456                worker.set_idle(false);
457
458                // Check shutdown again after waking up
459                if shutdown.load(Ordering::Relaxed) {
460                    break;
461                }
462            }
463        }
464    }
465
466    fn execute_task(
467        task: PoolTask,
468        worker: &Arc<WorkerState>,
469        stats: &Arc<Mutex<ThreadPoolStats>>,
470        completed_tasks: &Arc<Mutex<Vec<u64>>>,
471    ) {
472        let start_time = Instant::now();
473        let task_id = task.id;
474
475        // Execute the task
476        (task.task)();
477
478        let execution_time = start_time.elapsed();
479
480        // Update worker stats
481        worker.tasks_executed.fetch_add(1, Ordering::Relaxed);
482        if let Ok(mut total_time) = worker.total_execution_time.lock() {
483            *total_time += execution_time;
484        }
485
486        // Mark task as completed
487        if let Ok(mut completed) = completed_tasks.lock() {
488            completed.push(task_id);
489        }
490
491        // Update global stats
492        if let Ok(mut global_stats) = stats.lock() {
493            global_stats.tasks_completed += 1;
494
495            // Update average execution time (exponential moving average)
496            let alpha = 0.1;
497            global_stats.average_execution_time = Duration::from_secs_f64(
498                alpha * execution_time.as_secs_f64()
499                    + (1.0 - alpha) * global_stats.average_execution_time.as_secs_f64(),
500            );
501        }
502    }
503
504    fn try_steal_work(
505        worker: &Arc<WorkerState>,
506        workers: &[Arc<WorkerState>],
507        config: &ThreadPoolConfig,
508    ) -> Option<PoolTask> {
509        let now = Instant::now();
510
511        // Check steal interval
512        if let Ok(mut last_steal) = worker.last_steal_time.lock() {
513            if now.duration_since(*last_steal) < config.steal_interval {
514                return None;
515            }
516            *last_steal = now;
517        }
518
519        // Find victim with most tasks
520        let victim = workers
521            .iter()
522            .filter(|w| w.id != worker.id)
523            .max_by_key(|w| w.queue_len())?;
524
525        if victim.queue_len() > 1 {
526            if let Ok(Some(task)) = victim.steal_task() {
527                return Some(task);
528            }
529        }
530
531        None
532    }
533
534    fn set_thread_affinity(_cpu_id: usize) {
535        // Platform-specific implementation would go here
536        // For now, this is a no-op as it requires platform-specific code
537        #[cfg(target_os = "linux")]
538        {
539            // On Linux, we could use libc::pthread_setaffinity_np
540            // But for pure Rust, we'll skip this for now
541        }
542    }
543}
544
545impl Default for ThreadPool {
546    fn default() -> Self {
547        Self::new().expect("Failed to create default thread pool")
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    /// Submits `count` tasks to `pool`.
    ///
    /// Each task sleeps for `delay` and then increments `counter` once.
    fn submit_counting_tasks(
        pool: &ThreadPool,
        counter: &Arc<AtomicU32>,
        count: usize,
        delay: Duration,
    ) {
        for _ in 0..count {
            let counter_clone = Arc::clone(counter);
            pool.submit(move || {
                thread::sleep(delay);
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .expect("Failed to submit task");
        }
    }

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        assert!(pool.num_threads() > 0);
        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_task_submission() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        let counter = Arc::new(AtomicU32::new(0));

        submit_counting_tasks(&pool, &counter, 10, Duration::ZERO);

        pool.wait().expect("Failed to wait for tasks");
        // Joining the workers guarantees any task still finishing has completed.
        pool.shutdown().expect("Failed to shut down thread pool");
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_priority_tasks() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        pool.submit_with_priority(
            move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            },
            Priority::High,
            None,
        )
        .expect("Failed to submit high priority task");

        pool.wait().expect("Failed to wait for tasks");
        pool.shutdown().expect("Failed to shut down thread pool");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_statistics() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");

        for _ in 0..5 {
            pool.submit(|| {
                thread::sleep(Duration::from_millis(10));
            })
            .expect("Failed to submit task");
        }

        pool.wait().expect("Failed to wait for tasks");

        let stats = pool.statistics();
        assert_eq!(stats.tasks_submitted, 5);
        assert!(stats.tasks_completed <= stats.tasks_submitted);
        assert!(stats.active_threads <= pool.num_threads());
        assert_eq!(stats.worker_utilization.len(), pool.num_threads());

        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_work_stealing() {
        let config = ThreadPoolConfig {
            num_threads: Some(2),
            ..Default::default()
        };
        let pool = ThreadPool::with_config(config).expect("Failed to create thread pool");
        let counter = Arc::new(AtomicU32::new(0));

        // Enough slow tasks that an idle worker has something to steal.
        submit_counting_tasks(&pool, &counter, 20, Duration::from_millis(5));

        pool.wait().expect("Failed to wait for tasks");
        // Joining replaces the old fixed sleep: it returns only after every
        // worker has finished its current task.
        pool.shutdown().expect("Failed to shut down thread pool");

        assert_eq!(counter.load(Ordering::SeqCst), 20);
    }

    #[test]
    fn test_wait_on_idle_pool_returns() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        pool.wait().expect("Failed to wait for tasks");
        assert_eq!(pool.pending_tasks(), 0);
        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_submit_returns_distinct_ids() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        let first = pool.submit(|| {}).expect("Failed to submit task");
        let second = pool.submit(|| {}).expect("Failed to submit task");
        assert_ne!(first, second);
        pool.wait().expect("Failed to wait for tasks");
        pool.shutdown().expect("Failed to shut down thread pool");
    }
}