1#[cfg(feature = "unstable-mpmc")]
2use std::sync::mpmc as channel;
3#[cfg(not(feature = "unstable-mpmc"))]
4use std::sync::mpsc as channel;
5
6use parking_lot::{Condvar, Mutex};
7use std::{
8 cmp::Ordering as CmpOrdering,
9 collections::BinaryHeap,
10 panic::{catch_unwind, AssertUnwindSafe},
11 sync::{
12 atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
13 Arc, Barrier,
14 },
15 thread::{Builder, JoinHandle},
16 time::Instant,
17};
18
19use crate::error::Result;
20
/// Fallback worker count used by the pools' `Default` implementations when
/// `std::thread::available_parallelism()` returns an error.
pub const DEFAULT_POOL_CAPACITY: usize = 4;
22
/// A boxed, run-once unit of work that may be moved to another thread.
pub type Task = Box<dyn FnOnce() + Send + 'static>;
24
/// Message sent over a worker's channel by [`ThreadPool`].
enum WorkerMessage {
    /// Run the contained task on the receiving worker thread.
    Task(Task),
    /// Leave the receive loop so the worker thread can be joined.
    Terminate,
}
29
/// A fixed-size pool of worker threads, each fed through its own channel.
///
/// Tasks passed to `exec` are distributed over the workers in round-robin order.
pub struct ThreadPool {
    /// Worker handles; their threads are joined when the pool is dropped.
    workers: Vec<Worker>,
    /// One sending half per worker: `senders[i]` feeds `workers[i]`.
    senders: Vec<channel::Sender<WorkerMessage>>,
    /// Ever-increasing counter; its value modulo the worker count selects the next worker.
    next_worker: AtomicUsize,
}
35
36impl Default for ThreadPool {
37 fn default() -> Self {
38 let default_capacity = std::thread::available_parallelism()
39 .map(|n| n.get())
40 .unwrap_or(DEFAULT_POOL_CAPACITY);
41 Self::new(default_capacity)
42 }
43}
44
45impl ThreadPool {
46 pub fn new(capacity: usize) -> Self {
47 let mut workers = Vec::with_capacity(capacity);
48 let mut senders = Vec::with_capacity(capacity);
49
50 for id in 0..capacity {
51 let (sender, receiver) = channel::channel::<WorkerMessage>();
52 workers.push(Worker::new(id, receiver));
53 senders.push(sender);
54 }
55
56 Self {
57 workers,
58 senders,
59 next_worker: AtomicUsize::new(0),
60 }
61 }
62
63 pub fn exec<F>(&self, task: F) -> Result<()>
64 where
65 F: FnOnce() + Send + 'static,
66 {
67 let index = self.next_worker.fetch_add(1, Ordering::Relaxed) % self.senders.len();
69 Ok(self.senders[index].send(WorkerMessage::Task(Box::new(task)))?)
70 }
71
72 pub fn workers_len(&self) -> usize {
73 self.workers.len()
74 }
75}
76
77impl Drop for ThreadPool {
78 fn drop(&mut self) {
79 for sender in &self.senders {
80 let _ = sender.send(WorkerMessage::Terminate);
81 }
82 for worker in &mut self.workers {
83 if let Some(t) = worker.take_thread() {
84 t.join().unwrap();
85 }
86 }
87 }
88}
89
/// Handle to one thread of a [`ThreadPool`].
struct Worker {
    // NOTE(review): the field is stored but not read in the code visible here,
    // hence the lint allowance — confirm whether it is kept for diagnostics.
    #[allow(dead_code)]
    id: usize,
    /// Join handle of the worker thread; `None` once it has been taken for joining.
    thread: Option<JoinHandle<()>>,
}
95
96impl Worker {
97 pub fn new(id: usize, receiver: channel::Receiver<WorkerMessage>) -> Self {
98 let thread = Some(
99 Builder::new()
100 .name(format!("thread-pool-worker-{id}"))
101 .spawn(move || {
102 while let Ok(message) = receiver.recv() {
103 match message {
104 WorkerMessage::Task(task) => task(),
105 WorkerMessage::Terminate => break,
106 }
107 }
108 })
109 .expect("Couldn't create the worker thread id={id}"),
110 );
111
112 Self { id, thread }
113 }
114
115 pub fn take_thread(&mut self) -> Option<JoinHandle<()>> {
116 self.thread.take()
117 }
118}
119
/// Scheduling priority of a task submitted to [`ComputeThreadPool`].
///
/// The derived ordering follows the discriminants, so
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
127
/// Atomic counters describing the activity of a [`ComputeThreadPool`].
#[derive(Debug, Default)]
pub struct ComputePoolMetrics {
    /// Incremented once per `spawn` call.
    pub tasks_submitted: AtomicU64,
    /// Tasks whose closure returned without panicking.
    pub tasks_completed: AtomicU64,
    /// Tasks whose closure panicked.
    pub tasks_failed: AtomicU64,
    /// Workers currently executing a task.
    pub active_workers: AtomicUsize,
    /// Submitted `Low` tasks that no worker has picked up yet.
    pub queue_depth_low: AtomicUsize,
    /// Submitted `Normal` tasks that no worker has picked up yet.
    pub queue_depth_normal: AtomicUsize,
    /// Submitted `High` tasks that no worker has picked up yet.
    pub queue_depth_high: AtomicUsize,
    /// Submitted `Critical` tasks that no worker has picked up yet.
    pub queue_depth_critical: AtomicUsize,
    /// Sum of the measured run time of all executed tasks, in nanoseconds.
    pub total_execution_time_ns: AtomicU64,
}
140
141impl ComputePoolMetrics {
142 pub fn tasks_submitted(&self) -> u64 {
143 self.tasks_submitted.load(Ordering::Relaxed)
144 }
145
146 pub fn tasks_completed(&self) -> u64 {
147 self.tasks_completed.load(Ordering::Relaxed)
148 }
149
150 pub fn tasks_failed(&self) -> u64 {
151 self.tasks_failed.load(Ordering::Relaxed)
152 }
153
154 pub fn active_workers(&self) -> usize {
155 self.active_workers.load(Ordering::Relaxed)
156 }
157
158 pub fn queue_depth_low(&self) -> usize {
159 self.queue_depth_low.load(Ordering::Relaxed)
160 }
161
162 pub fn queue_depth_normal(&self) -> usize {
163 self.queue_depth_normal.load(Ordering::Relaxed)
164 }
165
166 pub fn queue_depth_high(&self) -> usize {
167 self.queue_depth_high.load(Ordering::Relaxed)
168 }
169
170 pub fn queue_depth_critical(&self) -> usize {
171 self.queue_depth_critical.load(Ordering::Relaxed)
172 }
173
174 pub fn total_execution_time_ns(&self) -> u64 {
175 self.total_execution_time_ns.load(Ordering::Relaxed)
176 }
177}
178
/// A queued task together with the keys that order it inside the heap.
struct PriorityTask {
    /// The work to run.
    task: Task,
    /// Primary ordering key: higher priorities are popped first.
    priority: TaskPriority,
    /// Submission counter value; breaks ties so that equal priorities run in FIFO order.
    sequence: u64,
}
184
185impl PartialEq for PriorityTask {
186 fn eq(&self, other: &Self) -> bool {
187 self.priority == other.priority && self.sequence == other.sequence
188 }
189}
190
// Marker impl: `BinaryHeap` needs `Ord`, which in turn requires `Eq`.
impl Eq for PriorityTask {}
192
impl PartialOrd for PriorityTask {
    /// Delegates to [`Ord::cmp`], so the partial and total orders always agree.
    fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
        Some(self.cmp(other))
    }
}
198
199impl Ord for PriorityTask {
200 fn cmp(&self, other: &Self) -> CmpOrdering {
201 match self.priority.cmp(&other.priority) {
202 CmpOrdering::Equal => other.sequence.cmp(&self.sequence),
203 ord => ord,
204 }
205 }
206}
207
/// State shared between a [`ComputeThreadPool`] handle and its worker threads.
struct ComputeSharedState {
    /// Pending tasks; the heap yields the highest priority first and the oldest
    /// entry first within one priority.
    queue: Mutex<BinaryHeap<PriorityTask>>,
    /// Signalled when a task is pushed and when the pool shuts down.
    condvar: Condvar,
    /// Set when the pool is dropped; workers exit once it is set and the queue is empty.
    shutdown: AtomicBool,
}
213
/// A thread pool whose workers share one priority queue of tasks and record
/// execution metrics.
pub struct ComputeThreadPool {
    /// Join handles of the worker threads, joined when the pool is dropped.
    workers: Vec<JoinHandle<()>>,
    /// Queue, condition variable and shutdown flag shared with the workers.
    state: Arc<ComputeSharedState>,
    /// Source of the per-task sequence numbers used for FIFO tie-breaking.
    sequence: AtomicU64,
    /// Counters updated by `spawn` and by the workers.
    metrics: Arc<ComputePoolMetrics>,
}
220
221impl Default for ComputeThreadPool {
222 fn default() -> Self {
223 let default_capacity = std::thread::available_parallelism()
224 .map(|n| n.get())
225 .unwrap_or(DEFAULT_POOL_CAPACITY);
226 Self::new(default_capacity)
227 }
228}
229
230impl ComputeThreadPool {
231 pub fn new(capacity: usize) -> Self {
232 let state = Arc::new(ComputeSharedState {
233 queue: Mutex::new(BinaryHeap::new()),
234 condvar: Condvar::new(),
235 shutdown: AtomicBool::new(false),
236 });
237 let metrics = Arc::new(ComputePoolMetrics::default());
238
239 let mut workers = Vec::with_capacity(capacity);
240 let barrier = Arc::new(Barrier::new(capacity + 1));
242
243 for id in 0..capacity {
244 let state_clone = Arc::clone(&state);
245 let barrier_clone = Arc::clone(&barrier);
246 let metrics_clone = Arc::clone(&metrics);
247 let thread = Builder::new()
248 .name(format!("compute-worker-{id}"))
249 .spawn(move || {
250 barrier_clone.wait();
252
253 loop {
254 let task = {
255 let mut queue = state_clone.queue.lock();
256
257 while queue.is_empty() && !state_clone.shutdown.load(Ordering::Relaxed)
258 {
259 state_clone.condvar.wait(&mut queue);
260 }
261
262 if state_clone.shutdown.load(Ordering::Relaxed) && queue.is_empty() {
263 break;
264 }
265
266 let t = queue.pop();
267 if let Some(ref pt) = t {
268 match pt.priority {
269 TaskPriority::Low => metrics_clone
270 .queue_depth_low
271 .fetch_sub(1, Ordering::Relaxed),
272 TaskPriority::Normal => metrics_clone
273 .queue_depth_normal
274 .fetch_sub(1, Ordering::Relaxed),
275 TaskPriority::High => metrics_clone
276 .queue_depth_high
277 .fetch_sub(1, Ordering::Relaxed),
278 TaskPriority::Critical => metrics_clone
279 .queue_depth_critical
280 .fetch_sub(1, Ordering::Relaxed),
281 };
282 }
283 t
284 };
285
286 if let Some(priority_task) = task {
287 metrics_clone.active_workers.fetch_add(1, Ordering::Relaxed);
288 let start = Instant::now();
289
290 let result = catch_unwind(AssertUnwindSafe(|| (priority_task.task)()));
291
292 let duration = start.elapsed();
293 metrics_clone
294 .total_execution_time_ns
295 .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
296 metrics_clone.active_workers.fetch_sub(1, Ordering::Relaxed);
297
298 if result.is_ok() {
299 metrics_clone
300 .tasks_completed
301 .fetch_add(1, Ordering::Relaxed);
302 } else {
303 metrics_clone.tasks_failed.fetch_add(1, Ordering::Relaxed);
304 }
305 }
306 }
307 })
308 .expect("Failed to create compute worker thread");
309 workers.push(thread);
310 }
311
312 barrier.wait();
314
315 Self {
316 workers,
317 state,
318 sequence: AtomicU64::new(0),
319 metrics,
320 }
321 }
322
323 pub fn spawn<F>(&self, task: F, priority: TaskPriority)
324 where
325 F: FnOnce() + Send + 'static,
326 {
327 let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
328 let priority_task = PriorityTask {
329 task: Box::new(task),
330 priority,
331 sequence,
332 };
333
334 self.metrics.tasks_submitted.fetch_add(1, Ordering::Relaxed);
335 match priority {
336 TaskPriority::Low => self.metrics.queue_depth_low.fetch_add(1, Ordering::Relaxed),
337 TaskPriority::Normal => self
338 .metrics
339 .queue_depth_normal
340 .fetch_add(1, Ordering::Relaxed),
341 TaskPriority::High => self
342 .metrics
343 .queue_depth_high
344 .fetch_add(1, Ordering::Relaxed),
345 TaskPriority::Critical => self
346 .metrics
347 .queue_depth_critical
348 .fetch_add(1, Ordering::Relaxed),
349 };
350
351 let mut queue = self.state.queue.lock();
352 queue.push(priority_task);
353 self.state.condvar.notify_one();
354 }
355
356 pub fn metrics(&self) -> Arc<ComputePoolMetrics> {
357 self.metrics.clone()
358 }
359}
360
361impl Drop for ComputeThreadPool {
362 fn drop(&mut self) {
363 self.state.shutdown.store(true, Ordering::SeqCst);
364 self.state.condvar.notify_all();
365
366 for worker in self.workers.drain(..) {
367 let _ = worker.join();
368 }
369 }
370}
371
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        sync::{Arc, Barrier, Mutex},
        time::{Duration, Instant},
    };

    use super::*;

    /// Upper bound for any polling wait in these tests.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(2);

    /// Polls `condition` until it holds; panics naming `what` once `WAIT_TIMEOUT` has elapsed.
    ///
    /// Used instead of fixed sleeps so the tests do not depend on scheduler timing.
    fn wait_until(what: &str, condition: impl Fn() -> bool) {
        let deadline = Instant::now() + WAIT_TIMEOUT;
        while !condition() {
            assert!(Instant::now() < deadline, "Timed out waiting for {}", what);
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert_eq!(pool.workers_len(), 4);
    }

    #[test]
    fn test_task_execution() {
        let pool = ThreadPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);

        pool.exec(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        })
        .unwrap();

        wait_until("the task to run", || counter.load(Ordering::SeqCst) == 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiple_tasks() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            pool.exec(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        wait_until("all tasks to run", || counter.load(Ordering::SeqCst) == 10);
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_pool_cleanup() {
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let pool = ThreadPool::new(2);
            let counter_clone = Arc::clone(&counter);

            pool.exec(move || {
                // Deliberately slow: dropping the pool must wait for this task.
                std::thread::sleep(Duration::from_millis(50));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_compute_pool_priority() {
        // A single worker makes the execution order observable.
        let pool = ComputeThreadPool::new(1);
        let result = Arc::new(Mutex::new(Vec::new()));

        let gate = Arc::new(Barrier::new(2));
        let worker_gate = Arc::clone(&gate);

        let r1 = Arc::clone(&result);
        pool.spawn(
            move || {
                // First rendezvous: tells the test that the worker is busy.
                worker_gate.wait();
                // Second rendezvous: holds the worker until the other tasks are queued.
                worker_gate.wait();
                r1.lock().unwrap().push(1);
            },
            TaskPriority::Low,
        );

        gate.wait();

        // Queued while the only worker is blocked, so the heap decides their order.
        for (value, priority) in [
            (2, TaskPriority::Low),
            (3, TaskPriority::High),
            (4, TaskPriority::Normal),
        ] {
            let sink = Arc::clone(&result);
            pool.spawn(
                move || {
                    sink.lock().unwrap().push(value);
                },
                priority,
            );
        }

        gate.wait();

        wait_until("all tasks to run", || result.lock().unwrap().len() == 4);

        let res = result.lock().unwrap();
        assert_eq!(*res, vec![1, 3, 4, 2]);
    }

    #[test]
    fn test_compute_pool_metrics() {
        let pool = ComputeThreadPool::new(2);
        let metrics = pool.metrics();

        // Two blocking tasks plus the test thread.
        let barrier = Arc::new(Barrier::new(3));

        for _ in 0..2 {
            let task_barrier = Arc::clone(&barrier);
            pool.spawn(
                move || {
                    task_barrier.wait();
                },
                TaskPriority::Normal,
            );
        }

        // Both workers must be inside a blocking task before more work is queued.
        wait_until("both workers to become active", || {
            metrics.active_workers() == 2
        });

        pool.spawn(|| {}, TaskPriority::Low);
        pool.spawn(|| {}, TaskPriority::High);

        assert_eq!(metrics.tasks_submitted(), 4);
        assert_eq!(metrics.active_workers(), 2);
        assert_eq!(metrics.queue_depth_low(), 1);
        assert_eq!(metrics.queue_depth_high(), 1);
        assert_eq!(metrics.queue_depth_normal(), 0);

        // Release the two blocked workers.
        barrier.wait();

        wait_until("tasks to complete", || metrics.tasks_completed() >= 4);

        assert_eq!(metrics.tasks_completed(), 4);
        assert_eq!(metrics.active_workers(), 0);
        assert_eq!(metrics.queue_depth_low(), 0);
        assert_eq!(metrics.queue_depth_high(), 0);
        assert!(metrics.total_execution_time_ns() > 0);
    }
}