1#[cfg(feature = "unstable-mpmc")]
2use std::sync::mpmc as channel;
3#[cfg(not(feature = "unstable-mpmc"))]
4use std::sync::mpsc as channel;
5use std::{
6 cmp::Ordering as CmpOrdering,
7 collections::BinaryHeap,
8 panic::{catch_unwind, AssertUnwindSafe},
9 sync::{
10 atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
11 Arc, Barrier, Condvar, Mutex,
12 },
13 thread::{Builder, JoinHandle},
14 time::Instant,
15};
16
17use crate::error::Result;
18
/// Worker count used by the `Default` pool constructors when the platform
/// cannot report its available parallelism.
pub const DEFAULT_POOL_CAPACITY: usize = 4;

/// A boxed, run-once closure that can be handed over to another thread.
pub type Task = Box<dyn FnOnce() + Send + 'static>;
22
23enum WorkerMessage {
24 Task(Task),
25 Terminate,
26}
27
28pub struct ThreadPool {
29 workers: Vec<Worker>,
30 senders: Vec<channel::Sender<WorkerMessage>>,
31 next_worker: AtomicUsize,
32}
33
34impl Default for ThreadPool {
35 fn default() -> Self {
36 let default_capacity = std::thread::available_parallelism()
37 .map(|n| n.get())
38 .unwrap_or(DEFAULT_POOL_CAPACITY);
39 Self::new(default_capacity)
40 }
41}
42
43impl ThreadPool {
44 pub fn new(capacity: usize) -> Self {
45 let mut workers = Vec::with_capacity(capacity);
46 let mut senders = Vec::with_capacity(capacity);
47
48 for id in 0..capacity {
49 let (sender, receiver) = channel::channel::<WorkerMessage>();
50 workers.push(Worker::new(id, receiver));
51 senders.push(sender);
52 }
53
54 Self {
55 workers,
56 senders,
57 next_worker: AtomicUsize::new(0),
58 }
59 }
60
61 pub fn exec<F>(&self, task: F) -> Result<()>
62 where
63 F: FnOnce() + Send + 'static,
64 {
65 let index = self.next_worker.fetch_add(1, Ordering::Relaxed) % self.senders.len();
67 Ok(self.senders[index].send(WorkerMessage::Task(Box::new(task)))?)
68 }
69
70 pub fn workers_len(&self) -> usize {
71 self.workers.len()
72 }
73}
74
75impl Drop for ThreadPool {
76 fn drop(&mut self) {
77 for sender in &self.senders {
78 let _ = sender.send(WorkerMessage::Terminate);
79 }
80 for worker in &mut self.workers {
81 if let Some(t) = worker.take_thread() {
82 t.join().unwrap();
83 }
84 }
85 }
86}
87
88struct Worker {
89 #[allow(dead_code)]
90 id: usize,
91 thread: Option<JoinHandle<()>>,
92}
93
94impl Worker {
95 pub fn new(id: usize, receiver: channel::Receiver<WorkerMessage>) -> Self {
96 let thread = Some(
97 Builder::new()
98 .name(format!("thread-pool-worker-{id}"))
99 .spawn(move || {
100 while let Ok(message) = receiver.recv() {
101 match message {
102 WorkerMessage::Task(task) => task(),
103 WorkerMessage::Terminate => break,
104 }
105 }
106 })
107 .expect("Couldn't create the worker thread id={id}"),
108 );
109
110 Self { id, thread }
111 }
112
113 pub fn take_thread(&mut self) -> Option<JoinHandle<()>> {
114 self.thread.take()
115 }
116}
117
/// Scheduling priority of a task submitted to the compute pool.
///
/// The discriminants ascend with urgency, so the derived ordering ranks
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    /// Least urgent; runs after every other priority.
    Low = 0,
    /// Regular priority.
    Normal = 1,
    /// Runs ahead of `Normal` and `Low` tasks.
    High = 2,
    /// Most urgent; runs ahead of every other priority.
    Critical = 3,
}
125
/// Lock-free counters describing the activity of a compute pool.
///
/// All counters start at zero (`Default`) and are updated with atomic
/// operations, so a shared reference is enough to read or update them.
#[derive(Debug, Default)]
pub struct ComputePoolMetrics {
    /// Number of tasks handed to the pool.
    pub tasks_submitted: AtomicU64,
    /// Number of tasks that ran to completion without panicking.
    pub tasks_completed: AtomicU64,
    /// Number of tasks that panicked while running.
    pub tasks_failed: AtomicU64,
    /// Number of workers currently executing a task.
    pub active_workers: AtomicUsize,
    /// Number of queued `Low` priority tasks.
    pub queue_depth_low: AtomicUsize,
    /// Number of queued `Normal` priority tasks.
    pub queue_depth_normal: AtomicUsize,
    /// Number of queued `High` priority tasks.
    pub queue_depth_high: AtomicUsize,
    /// Number of queued `Critical` priority tasks.
    pub queue_depth_critical: AtomicUsize,
    /// Sum of the execution times of all finished tasks, in nanoseconds.
    pub total_execution_time_ns: AtomicU64,
}

impl ComputePoolMetrics {
    /// Ordering used by every accessor: each counter is read on its own and
    /// is not used to synchronize other memory.
    const LOAD_ORDERING: Ordering = Ordering::Relaxed;

    /// Current value of [`Self::tasks_submitted`].
    pub fn tasks_submitted(&self) -> u64 {
        self.tasks_submitted.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::tasks_completed`].
    pub fn tasks_completed(&self) -> u64 {
        self.tasks_completed.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::tasks_failed`].
    pub fn tasks_failed(&self) -> u64 {
        self.tasks_failed.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::active_workers`].
    pub fn active_workers(&self) -> usize {
        self.active_workers.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::queue_depth_low`].
    pub fn queue_depth_low(&self) -> usize {
        self.queue_depth_low.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::queue_depth_normal`].
    pub fn queue_depth_normal(&self) -> usize {
        self.queue_depth_normal.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::queue_depth_high`].
    pub fn queue_depth_high(&self) -> usize {
        self.queue_depth_high.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::queue_depth_critical`].
    pub fn queue_depth_critical(&self) -> usize {
        self.queue_depth_critical.load(Self::LOAD_ORDERING)
    }

    /// Current value of [`Self::total_execution_time_ns`].
    pub fn total_execution_time_ns(&self) -> u64 {
        self.total_execution_time_ns.load(Self::LOAD_ORDERING)
    }
}
176
177struct PriorityTask {
178 task: Task,
179 priority: TaskPriority,
180 sequence: u64,
181}
182
183impl PartialEq for PriorityTask {
184 fn eq(&self, other: &Self) -> bool {
185 self.priority == other.priority && self.sequence == other.sequence
186 }
187}
188
189impl Eq for PriorityTask {}
190
191impl PartialOrd for PriorityTask {
192 fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
193 Some(self.cmp(other))
194 }
195}
196
197impl Ord for PriorityTask {
198 fn cmp(&self, other: &Self) -> CmpOrdering {
199 match self.priority.cmp(&other.priority) {
200 CmpOrdering::Equal => other.sequence.cmp(&self.sequence),
201 ord => ord,
202 }
203 }
204}
205
206struct ComputeSharedState {
207 queue: Mutex<BinaryHeap<PriorityTask>>,
208 condvar: Condvar,
209 shutdown: AtomicBool,
210}
211
212pub struct ComputeThreadPool {
213 workers: Vec<JoinHandle<()>>,
214 state: Arc<ComputeSharedState>,
215 sequence: AtomicU64,
216 metrics: Arc<ComputePoolMetrics>,
217}
218
219impl Default for ComputeThreadPool {
220 fn default() -> Self {
221 let default_capacity = std::thread::available_parallelism()
222 .map(|n| n.get())
223 .unwrap_or(DEFAULT_POOL_CAPACITY);
224 Self::new(default_capacity)
225 }
226}
227
228impl ComputeThreadPool {
229 pub fn new(capacity: usize) -> Self {
230 let state = Arc::new(ComputeSharedState {
231 queue: Mutex::new(BinaryHeap::new()),
232 condvar: Condvar::new(),
233 shutdown: AtomicBool::new(false),
234 });
235 let metrics = Arc::new(ComputePoolMetrics::default());
236
237 let mut workers = Vec::with_capacity(capacity);
238 let barrier = Arc::new(Barrier::new(capacity + 1));
240
241 for id in 0..capacity {
242 let state_clone = Arc::clone(&state);
243 let barrier_clone = Arc::clone(&barrier);
244 let metrics_clone = Arc::clone(&metrics);
245 let thread = Builder::new()
246 .name(format!("compute-worker-{id}"))
247 .spawn(move || {
248 barrier_clone.wait();
250
251 loop {
252 let task = {
253 let mut queue = state_clone.queue.lock().unwrap();
254
255 while queue.is_empty() && !state_clone.shutdown.load(Ordering::Relaxed)
256 {
257 queue = state_clone.condvar.wait(queue).unwrap();
258 }
259
260 if state_clone.shutdown.load(Ordering::Relaxed) && queue.is_empty() {
261 break;
262 }
263
264 let t = queue.pop();
265 if let Some(ref pt) = t {
266 match pt.priority {
267 TaskPriority::Low => metrics_clone
268 .queue_depth_low
269 .fetch_sub(1, Ordering::Relaxed),
270 TaskPriority::Normal => metrics_clone
271 .queue_depth_normal
272 .fetch_sub(1, Ordering::Relaxed),
273 TaskPriority::High => metrics_clone
274 .queue_depth_high
275 .fetch_sub(1, Ordering::Relaxed),
276 TaskPriority::Critical => metrics_clone
277 .queue_depth_critical
278 .fetch_sub(1, Ordering::Relaxed),
279 };
280 }
281 t
282 };
283
284 if let Some(priority_task) = task {
285 metrics_clone.active_workers.fetch_add(1, Ordering::Relaxed);
286 let start = Instant::now();
287
288 let result = catch_unwind(AssertUnwindSafe(|| (priority_task.task)()));
289
290 let duration = start.elapsed();
291 metrics_clone
292 .total_execution_time_ns
293 .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
294 metrics_clone.active_workers.fetch_sub(1, Ordering::Relaxed);
295
296 if result.is_ok() {
297 metrics_clone
298 .tasks_completed
299 .fetch_add(1, Ordering::Relaxed);
300 } else {
301 metrics_clone.tasks_failed.fetch_add(1, Ordering::Relaxed);
302 }
303 }
304 }
305 })
306 .expect("Failed to create compute worker thread");
307 workers.push(thread);
308 }
309
310 barrier.wait();
312
313 Self {
314 workers,
315 state,
316 sequence: AtomicU64::new(0),
317 metrics,
318 }
319 }
320
321 pub fn spawn<F>(&self, task: F, priority: TaskPriority)
322 where
323 F: FnOnce() + Send + 'static,
324 {
325 let sequence = self.sequence.fetch_add(1, Ordering::Relaxed);
326 let priority_task = PriorityTask {
327 task: Box::new(task),
328 priority,
329 sequence,
330 };
331
332 self.metrics.tasks_submitted.fetch_add(1, Ordering::Relaxed);
333 match priority {
334 TaskPriority::Low => self.metrics.queue_depth_low.fetch_add(1, Ordering::Relaxed),
335 TaskPriority::Normal => self
336 .metrics
337 .queue_depth_normal
338 .fetch_add(1, Ordering::Relaxed),
339 TaskPriority::High => self
340 .metrics
341 .queue_depth_high
342 .fetch_add(1, Ordering::Relaxed),
343 TaskPriority::Critical => self
344 .metrics
345 .queue_depth_critical
346 .fetch_add(1, Ordering::Relaxed),
347 };
348
349 let mut queue = self.state.queue.lock().unwrap();
350 queue.push(priority_task);
351 self.state.condvar.notify_one();
352 }
353
354 pub fn metrics(&self) -> Arc<ComputePoolMetrics> {
355 self.metrics.clone()
356 }
357}
358
359impl Drop for ComputeThreadPool {
360 fn drop(&mut self) {
361 self.state.shutdown.store(true, Ordering::SeqCst);
362 self.state.condvar.notify_all();
363
364 for worker in self.workers.drain(..) {
365 let _ = worker.join();
366 }
367 }
368}
369
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Barrier, Mutex,
        },
        time::{Duration, Instant},
    };

    use super::*;

    /// Polls `condition` until it holds, panicking with `what` in the message
    /// if it does not become true within two seconds.
    fn wait_until(what: &str, condition: impl Fn() -> bool) {
        let deadline = Instant::now() + Duration::from_secs(2);
        while !condition() {
            assert!(Instant::now() < deadline, "Timed out waiting for {what}");
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert_eq!(pool.workers_len(), 4);
    }

    #[test]
    fn test_task_execution() {
        let pool = ThreadPool::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        pool.exec(move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        })
        .unwrap();

        // Dropping the pool joins the workers after they ran everything that
        // was queued, so no timing-dependent sleep is needed.
        drop(pool);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiple_tasks() {
        let pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter_clone = counter.clone();
            pool.exec(move || {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        drop(pool);
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_pool_cleanup() {
        let counter = Arc::new(AtomicUsize::new(0));
        {
            let pool = ThreadPool::new(2);
            let counter_clone = counter.clone();

            // The task is deliberately slow: leaving the scope must block
            // until it has finished.
            pool.exec(move || {
                std::thread::sleep(Duration::from_millis(50));
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_compute_pool_priority() {
        let pool = ComputeThreadPool::new(1);
        let result = Arc::new(Mutex::new(Vec::new()));

        // `started` tells the test that the only worker is busy with the first
        // task; `release` keeps it busy until the other tasks are queued, so
        // the pop order depends on priorities alone and not on timing.
        let started = Arc::new(Barrier::new(2));
        let release = Arc::new(Barrier::new(2));

        let first_sink = Arc::clone(&result);
        let started_in_task = Arc::clone(&started);
        let release_in_task = Arc::clone(&release);
        pool.spawn(
            move || {
                started_in_task.wait();
                release_in_task.wait();
                first_sink.lock().unwrap().push(1);
            },
            TaskPriority::Low,
        );

        started.wait();

        let queued = [
            (2, TaskPriority::Low),
            (3, TaskPriority::High),
            (4, TaskPriority::Normal),
        ];
        for (value, priority) in queued {
            let sink = Arc::clone(&result);
            pool.spawn(
                move || {
                    sink.lock().unwrap().push(value);
                },
                priority,
            );
        }

        release.wait();
        // Dropping the pool drains the queue and joins the worker.
        drop(pool);

        let res = result.lock().unwrap();
        assert_eq!(*res, vec![1, 3, 4, 2]);
    }

    #[test]
    fn test_compute_pool_metrics() {
        let pool = ComputeThreadPool::new(2);
        let metrics = pool.metrics();

        // Two blocking tasks plus the test thread meet at this barrier.
        let barrier = Arc::new(Barrier::new(3));

        for _ in 0..2 {
            let barrier_in_task = Arc::clone(&barrier);
            pool.spawn(
                move || {
                    barrier_in_task.wait();
                },
                TaskPriority::Normal,
            );
        }

        // Both workers must be inside their task before more work is queued.
        wait_until("both workers to become active", || {
            metrics.active_workers() == 2
        });

        pool.spawn(|| {}, TaskPriority::Low);
        pool.spawn(|| {}, TaskPriority::High);

        assert_eq!(metrics.tasks_submitted(), 4);
        assert_eq!(metrics.active_workers(), 2);
        assert_eq!(metrics.queue_depth_low(), 1);
        assert_eq!(metrics.queue_depth_high(), 1);
        assert_eq!(metrics.queue_depth_normal(), 0);

        barrier.wait();

        wait_until("tasks to complete", || metrics.tasks_completed() >= 4);

        assert_eq!(metrics.tasks_completed(), 4);
        assert_eq!(metrics.active_workers(), 0);
        assert_eq!(metrics.queue_depth_low(), 0);
        assert_eq!(metrics.queue_depth_high(), 0);
        assert!(metrics.total_execution_time_ns() > 0);
    }
}