1use crate::error::{NumRs2Error, Result};
11use std::collections::VecDeque;
12use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use std::sync::{Arc, Condvar, Mutex};
14use std::thread::{self, JoinHandle};
15use std::time::{Duration, Instant};
16
/// Tuning knobs for a [`ThreadPool`].
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// Number of worker threads to spawn. `None` means "one per CPU" as
    /// reported by `std::thread::available_parallelism` (4 if that query fails).
    pub num_threads: Option<usize>,
    /// When `true`, worker `i` is given CPU `i % <cpu count>` as its affinity
    /// target. NOTE(review): the visible pinning routine does not actually pin
    /// the thread yet — confirm before relying on this.
    pub enable_thread_pinning: bool,
    /// Presumably enables adaptive sizing between `min_threads` and
    /// `max_threads`. NOTE(review): not read by the visible pool code.
    pub adaptive_threads: bool,
    /// Lower bound for adaptive sizing. NOTE(review): not read by the visible
    /// pool code.
    pub min_threads: usize,
    /// Upper bound for adaptive sizing. NOTE(review): not read by the visible
    /// pool code.
    pub max_threads: usize,
    /// Intended bound on queued tasks. NOTE(review): the visible queues are
    /// unbounded and never consult this value.
    pub queue_capacity: usize,
    /// Minimum time between two steal attempts made by the same worker.
    pub steal_interval: Duration,
    /// Longest time an idle worker blocks waiting for a wake-up before it
    /// polls the queues again.
    pub idle_timeout: Duration,
}
37
38impl Default for ThreadPoolConfig {
39 fn default() -> Self {
40 let num_cpus = thread::available_parallelism().map_or(4, |n| n.get());
41 Self {
42 num_threads: Some(num_cpus),
43 enable_thread_pinning: false,
44 adaptive_threads: false,
45 min_threads: 1,
46 max_threads: num_cpus * 2,
47 queue_capacity: 1000,
48 steal_interval: Duration::from_millis(1),
49 idle_timeout: Duration::from_millis(10),
50 }
51 }
52}
53
/// Scheduling priority attached to a submitted task.
///
/// The derived `Ord` follows declaration order, so
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Lowest priority.
    Low = 0,
    /// Priority used by `ThreadPool::submit`.
    Normal = 1,
    /// Above `Normal`.
    High = 2,
    /// Highest priority.
    Critical = 3,
}
62
/// A unit of work queued on the pool, together with its scheduling metadata.
pub struct PoolTask {
    /// Pool-assigned identifier; the same value is returned to the submitter.
    pub(crate) id: u64,
    /// Priority requested at submission time.
    pub(crate) priority: Priority,
    /// Moment the task was created by `submit_with_priority`.
    pub(crate) submitted_at: Instant,
    /// Optional caller-supplied cost hint (units not specified here).
    pub(crate) estimated_cost: Option<u64>,
    /// Ids of tasks this one depends on. NOTE(review): `submit_with_priority`
    /// always leaves this empty and no visible code consults it.
    pub(crate) dependencies: Vec<u64>,
    /// The closure to run; it is consumed when the task executes.
    pub(crate) task: Box<dyn FnOnce() + Send + 'static>,
}
72
73impl std::fmt::Debug for PoolTask {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("PoolTask")
76 .field("id", &self.id)
77 .field("priority", &self.priority)
78 .field("submitted_at", &self.submitted_at)
79 .field("estimated_cost", &self.estimated_cost)
80 .field("dependencies", &self.dependencies)
81 .finish()
82 }
83}
84
/// Per-worker scheduling state.
///
/// The state is shared (via `Arc`) between the worker thread that owns it,
/// the pool handle, and peer workers that steal from its queue.
/// `#[repr(align(64))]` aligns each instance to 64 bytes, presumably to keep
/// different workers' hot fields off the same cache line.
/// NOTE(review): this assumes 64-byte cache lines on the target CPUs.
#[repr(align(64))]
struct WorkerState {
    /// Position of this worker in the pool's worker list.
    id: usize,
    /// Local task queue. The owner pops from the front; thieves pop from the back.
    deque: Mutex<VecDeque<PoolTask>>,
    /// `true` while the worker reports itself as having no work.
    is_idle: AtomicBool,
    /// Number of tasks this worker has run.
    tasks_executed: AtomicUsize,
    /// Number of tasks other workers have stolen *from* this worker's queue.
    tasks_stolen: AtomicUsize,
    /// Sum of the execution times of the tasks this worker has run.
    total_execution_time: Mutex<Duration>,
    /// When this worker last attempted a steal; used for rate limiting.
    last_steal_time: Mutex<Instant>,
    /// CPU this worker should be pinned to, if pinning is enabled.
    cpu_affinity: Option<usize>,
    /// Zero-sized, so it contributes no bytes; the `repr(align)` attribute
    /// alone determines the padding.
    _padding: [u8; 0], }
100
101impl WorkerState {
102 fn new(id: usize, cpu_affinity: Option<usize>) -> Self {
103 Self {
104 id,
105 deque: Mutex::new(VecDeque::new()),
106 is_idle: AtomicBool::new(true),
107 tasks_executed: AtomicUsize::new(0),
108 tasks_stolen: AtomicUsize::new(0),
109 total_execution_time: Mutex::new(Duration::ZERO),
110 last_steal_time: Mutex::new(Instant::now()),
111 cpu_affinity,
112 _padding: [],
113 }
114 }
115
116 fn push_task(&self, task: PoolTask) -> Result<()> {
117 let mut deque = self
118 .deque
119 .lock()
120 .map_err(|_| NumRs2Error::RuntimeError("Failed to acquire deque lock".to_string()))?;
121 deque.push_back(task);
122 Ok(())
123 }
124
125 fn pop_task(&self) -> Result<Option<PoolTask>> {
126 let mut deque = self
127 .deque
128 .lock()
129 .map_err(|_| NumRs2Error::RuntimeError("Failed to acquire deque lock".to_string()))?;
130 Ok(deque.pop_front())
131 }
132
133 fn steal_task(&self) -> Result<Option<PoolTask>> {
134 let mut deque = self
135 .deque
136 .lock()
137 .map_err(|_| NumRs2Error::RuntimeError("Failed to acquire deque lock".to_string()))?;
138 let task = deque.pop_back();
139 if task.is_some() {
140 self.tasks_stolen.fetch_add(1, Ordering::Relaxed);
141 }
142 Ok(task)
143 }
144
145 fn queue_len(&self) -> usize {
146 self.deque.lock().map(|d| d.len()).unwrap_or(0)
147 }
148
149 fn is_idle(&self) -> bool {
150 self.is_idle.load(Ordering::Relaxed)
151 }
152
153 fn set_idle(&self, idle: bool) {
154 self.is_idle.store(idle, Ordering::Relaxed);
155 }
156}
157
/// Work-stealing thread pool.
///
/// Each worker owns a local task queue. Submissions go to the worker with the
/// shortest queue, and workers that run out of work take tasks from peers.
pub struct ThreadPool {
    /// Configuration the pool was created with.
    config: ThreadPoolConfig,
    /// Per-worker state, in worker-id order.
    workers: Vec<Arc<WorkerState>>,
    /// Join handles of the worker threads, in worker order.
    threads: Vec<JoinHandle<()>>,
    /// Set to `true` to ask the workers to leave their loop.
    shutdown: Arc<AtomicBool>,
    /// Shared fallback queue, used when no worker can be chosen for a task.
    global_queue: Arc<Mutex<VecDeque<PoolTask>>>,
    /// Mutex/condition-variable pair that idle workers block on until notified
    /// or until their idle timeout elapses.
    idle_notify: Arc<(Mutex<()>, Condvar)>,
    /// Counter that supplies the ids returned by `submit*`.
    next_task_id: AtomicUsize,
    /// Aggregate statistics, shared with the workers.
    stats: Arc<Mutex<ThreadPoolStats>>,
    /// Ids of tasks appended as each one finishes executing.
    /// NOTE(review): grows without bound and is not read in the visible code —
    /// confirm whether a consumer exists elsewhere.
    completed_tasks: Arc<Mutex<Vec<u64>>>,
}
170
/// Snapshot of pool activity, as returned by `ThreadPool::statistics`.
#[derive(Debug, Clone, Default)]
pub struct ThreadPoolStats {
    /// Number of tasks accepted by `submit` / `submit_with_priority`.
    pub tasks_submitted: u64,
    /// Number of tasks that have finished executing.
    pub tasks_completed: u64,
    /// Intended count of tasks moved between workers by work stealing.
    /// NOTE(review): verify which code path populates this field.
    pub tasks_stolen: u64,
    /// Intended average time tasks wait in a queue before running.
    /// NOTE(review): verify which code path populates this field.
    pub average_queue_time: Duration,
    /// Exponential moving average of task execution time.
    pub average_execution_time: Duration,
    /// One entry per worker: 1.0 if it was busy at snapshot time, else 0.0.
    pub worker_utilization: Vec<f64>,
    /// Number of workers that were busy at snapshot time.
    pub active_threads: usize,
}
182
183impl ThreadPool {
184 pub fn new() -> Result<Self> {
186 Self::with_config(ThreadPoolConfig::default())
187 }
188
189 pub fn with_config(config: ThreadPoolConfig) -> Result<Self> {
191 let num_threads = config
192 .num_threads
193 .unwrap_or_else(|| thread::available_parallelism().map_or(4, |n| n.get()));
194
195 let shutdown = Arc::new(AtomicBool::new(false));
196 let global_queue = Arc::new(Mutex::new(VecDeque::new()));
197 let idle_notify = Arc::new((Mutex::new(()), Condvar::new()));
198 let stats = Arc::new(Mutex::new(ThreadPoolStats::default()));
199 let completed_tasks = Arc::new(Mutex::new(Vec::new()));
200
201 let mut workers = Vec::new();
202 let mut threads = Vec::new();
203
204 for i in 0..num_threads {
206 let cpu_affinity = if config.enable_thread_pinning {
207 Some(i % num_cpus::get())
208 } else {
209 None
210 };
211 workers.push(Arc::new(WorkerState::new(i, cpu_affinity)));
212 }
213
214 for worker in &workers {
216 let worker_clone = Arc::clone(worker);
217 let workers_clone = workers.clone();
218 let shutdown_clone = Arc::clone(&shutdown);
219 let global_queue_clone = Arc::clone(&global_queue);
220 let idle_notify_clone = Arc::clone(&idle_notify);
221 let stats_clone = Arc::clone(&stats);
222 let completed_tasks_clone = Arc::clone(&completed_tasks);
223 let config_clone = config.clone();
224
225 let handle = thread::spawn(move || {
226 if let Some(cpu_id) = worker_clone.cpu_affinity {
228 Self::set_thread_affinity(cpu_id);
229 }
230
231 Self::worker_main(
232 worker_clone,
233 workers_clone,
234 shutdown_clone,
235 global_queue_clone,
236 idle_notify_clone,
237 stats_clone,
238 completed_tasks_clone,
239 config_clone,
240 );
241 });
242
243 threads.push(handle);
244 }
245
246 Ok(Self {
247 config,
248 workers,
249 threads,
250 shutdown,
251 global_queue,
252 idle_notify,
253 next_task_id: AtomicUsize::new(0),
254 stats,
255 completed_tasks,
256 })
257 }
258
259 pub fn submit<F>(&self, task: F) -> Result<u64>
261 where
262 F: FnOnce() + Send + 'static,
263 {
264 self.submit_with_priority(task, Priority::Normal, None)
265 }
266
267 pub fn submit_with_priority<F>(
269 &self,
270 task: F,
271 priority: Priority,
272 estimated_cost: Option<u64>,
273 ) -> Result<u64>
274 where
275 F: FnOnce() + Send + 'static,
276 {
277 if self.shutdown.load(Ordering::Relaxed) {
278 return Err(NumRs2Error::RuntimeError(
279 "Thread pool is shutting down".to_string(),
280 ));
281 }
282
283 let task_id = self.next_task_id.fetch_add(1, Ordering::Relaxed) as u64;
284
285 let pool_task = PoolTask {
286 id: task_id,
287 priority,
288 submitted_at: Instant::now(),
289 estimated_cost,
290 dependencies: Vec::new(),
291 task: Box::new(task),
292 };
293
294 let target_worker = self.find_least_loaded_worker();
296
297 if let Some(worker_idx) = target_worker {
298 self.workers[worker_idx].push_task(pool_task)?;
299
300 if self.workers[worker_idx].is_idle() {
302 let (lock, cvar) = &*self.idle_notify;
303 let _guard = lock.lock().map_err(|_| {
304 NumRs2Error::RuntimeError("Failed to acquire idle notify lock".to_string())
305 })?;
306 cvar.notify_one();
307 }
308 } else {
309 let mut global = self.global_queue.lock().map_err(|_| {
311 NumRs2Error::RuntimeError("Failed to acquire global queue lock".to_string())
312 })?;
313 global.push_back(pool_task);
314
315 let (lock, cvar) = &*self.idle_notify;
316 let _guard = lock.lock().map_err(|_| {
317 NumRs2Error::RuntimeError("Failed to acquire idle notify lock".to_string())
318 })?;
319 cvar.notify_all();
320 }
321
322 if let Ok(mut stats) = self.stats.lock() {
324 stats.tasks_submitted += 1;
325 }
326
327 Ok(task_id)
328 }
329
330 pub fn statistics(&self) -> ThreadPoolStats {
332 if let Ok(mut stats) = self.stats.lock() {
333 stats.worker_utilization = self
334 .workers
335 .iter()
336 .map(|w| if w.is_idle() { 0.0 } else { 1.0 })
337 .collect();
338
339 stats.active_threads = self.workers.iter().filter(|w| !w.is_idle()).count();
340
341 stats.clone()
342 } else {
343 ThreadPoolStats::default()
344 }
345 }
346
347 pub fn num_threads(&self) -> usize {
349 self.workers.len()
350 }
351
352 pub fn pending_tasks(&self) -> usize {
354 let global_count = self.global_queue.lock().map(|q| q.len()).unwrap_or(0);
355
356 let worker_count: usize = self.workers.iter().map(|w| w.queue_len()).sum();
357
358 global_count + worker_count
359 }
360
361 pub fn wait(&self) -> Result<()> {
363 while self.pending_tasks() > 0 || self.has_active_workers() {
365 thread::sleep(Duration::from_millis(1));
366 }
367 Ok(())
368 }
369
370 fn has_active_workers(&self) -> bool {
372 self.workers.iter().any(|w| !w.is_idle())
373 }
374
375 pub fn shutdown(self) -> Result<()> {
377 self.shutdown.store(true, Ordering::Relaxed);
378
379 let (lock, cvar) = &*self.idle_notify;
381 let _guard = lock.lock().map_err(|_| {
382 NumRs2Error::RuntimeError("Failed to acquire idle notify lock".to_string())
383 })?;
384 cvar.notify_all();
385 drop(_guard);
386
387 for handle in self.threads {
389 if let Err(_e) = handle.join() {
390 }
392 }
393
394 Ok(())
395 }
396
397 fn find_least_loaded_worker(&self) -> Option<usize> {
400 self.workers
401 .iter()
402 .enumerate()
403 .min_by_key(|(_, w)| w.queue_len())
404 .map(|(idx, _)| idx)
405 }
406
407 fn worker_main(
408 worker: Arc<WorkerState>,
409 workers: Vec<Arc<WorkerState>>,
410 shutdown: Arc<AtomicBool>,
411 global_queue: Arc<Mutex<VecDeque<PoolTask>>>,
412 idle_notify: Arc<(Mutex<()>, Condvar)>,
413 stats: Arc<Mutex<ThreadPoolStats>>,
414 completed_tasks: Arc<Mutex<Vec<u64>>>,
415 config: ThreadPoolConfig,
416 ) {
417 let worker_id = worker.id;
418
419 while !shutdown.load(Ordering::Relaxed) {
420 let mut task_found = false;
421
422 if let Ok(Some(task)) = worker.pop_task() {
424 Self::execute_task(task, &worker, &stats, &completed_tasks);
425 task_found = true;
426 }
427
428 if !task_found {
430 if let Ok(mut global) = global_queue.try_lock() {
431 if let Some(task) = global.pop_front() {
432 drop(global);
433 Self::execute_task(task, &worker, &stats, &completed_tasks);
434 task_found = true;
435 }
436 }
437 }
438
439 if !task_found {
441 if let Some(stolen_task) = Self::try_steal_work(&worker, &workers, &config) {
442 Self::execute_task(stolen_task, &worker, &stats, &completed_tasks);
443 task_found = true;
444 }
445 }
446
447 if !task_found {
449 worker.set_idle(true);
450
451 let (lock, cvar) = &*idle_notify;
452 if let Ok(guard) = lock.lock() {
453 let _result = cvar.wait_timeout(guard, config.idle_timeout);
454 }
455
456 worker.set_idle(false);
457
458 if shutdown.load(Ordering::Relaxed) {
460 break;
461 }
462 }
463 }
464 }
465
466 fn execute_task(
467 task: PoolTask,
468 worker: &Arc<WorkerState>,
469 stats: &Arc<Mutex<ThreadPoolStats>>,
470 completed_tasks: &Arc<Mutex<Vec<u64>>>,
471 ) {
472 let start_time = Instant::now();
473 let task_id = task.id;
474
475 (task.task)();
477
478 let execution_time = start_time.elapsed();
479
480 worker.tasks_executed.fetch_add(1, Ordering::Relaxed);
482 if let Ok(mut total_time) = worker.total_execution_time.lock() {
483 *total_time += execution_time;
484 }
485
486 if let Ok(mut completed) = completed_tasks.lock() {
488 completed.push(task_id);
489 }
490
491 if let Ok(mut global_stats) = stats.lock() {
493 global_stats.tasks_completed += 1;
494
495 let alpha = 0.1;
497 global_stats.average_execution_time = Duration::from_secs_f64(
498 alpha * execution_time.as_secs_f64()
499 + (1.0 - alpha) * global_stats.average_execution_time.as_secs_f64(),
500 );
501 }
502 }
503
504 fn try_steal_work(
505 worker: &Arc<WorkerState>,
506 workers: &[Arc<WorkerState>],
507 config: &ThreadPoolConfig,
508 ) -> Option<PoolTask> {
509 let now = Instant::now();
510
511 if let Ok(mut last_steal) = worker.last_steal_time.lock() {
513 if now.duration_since(*last_steal) < config.steal_interval {
514 return None;
515 }
516 *last_steal = now;
517 }
518
519 let victim = workers
521 .iter()
522 .filter(|w| w.id != worker.id)
523 .max_by_key(|w| w.queue_len())?;
524
525 if victim.queue_len() > 1 {
526 if let Ok(Some(task)) = victim.steal_task() {
527 return Some(task);
528 }
529 }
530
531 None
532 }
533
534 fn set_thread_affinity(_cpu_id: usize) {
535 #[cfg(target_os = "linux")]
538 {
539 }
542 }
543}
544
545impl Default for ThreadPool {
546 fn default() -> Self {
547 Self::new().expect("Failed to create default thread pool")
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    // Every test shuts its pool down at the end. Otherwise its worker threads
    // keep polling for the rest of the test process.

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        assert!(pool.num_threads() > 0);
        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_task_submission() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        let counter = Arc::new(AtomicU32::new(0));

        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            pool.submit(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .expect("Failed to submit task");
        }

        pool.wait().expect("Failed to wait for tasks");
        assert_eq!(counter.load(Ordering::SeqCst), 10);
        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_priority_tasks() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        pool.submit_with_priority(
            move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            },
            Priority::High,
            None,
        )
        .expect("Failed to submit high priority task");

        pool.wait().expect("Failed to wait for tasks");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_statistics() {
        let pool = ThreadPool::new().expect("Failed to create thread pool");

        for _ in 0..5 {
            pool.submit(|| {
                thread::sleep(Duration::from_millis(10));
            })
            .expect("Failed to submit task");
        }

        thread::sleep(Duration::from_millis(100));

        let stats = pool.statistics();
        assert_eq!(stats.tasks_submitted, 5);
        assert!(stats.active_threads <= pool.num_threads());
        pool.shutdown().expect("Failed to shut down thread pool");
    }

    #[test]
    fn test_work_stealing() {
        let config = ThreadPoolConfig {
            num_threads: Some(2),
            ..Default::default()
        };
        let pool = ThreadPool::with_config(config).expect("Failed to create thread pool");
        let counter = Arc::new(AtomicU32::new(0));

        for _ in 0..20 {
            let counter_clone = Arc::clone(&counter);
            pool.submit(move || {
                thread::sleep(Duration::from_millis(5));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .expect("Failed to submit task");
        }

        pool.wait().expect("Failed to wait for tasks");

        // Extra settling margin after `wait`.
        thread::sleep(Duration::from_millis(200));

        assert_eq!(counter.load(Ordering::SeqCst), 20);
        pool.shutdown().expect("Failed to shut down thread pool");
    }
}