// dataforge/multithreading/thread_pool.rs
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

use crate::error::{DataForgeError, Result};

7type Job = Box<dyn FnOnce() + Send + 'static>;
8
9pub struct ThreadPool {
11 workers: Vec<Worker>,
12 sender: Option<mpsc::Sender<Job>>,
13}
14
15impl ThreadPool {
16 pub fn new(size: usize) -> Result<ThreadPool> {
18 if size == 0 {
19 return Err(DataForgeError::validation("Thread pool size must be greater than 0"));
20 }
21
22 let (sender, receiver) = mpsc::channel();
23 let receiver = Arc::new(Mutex::new(receiver));
24 let mut workers = Vec::with_capacity(size);
25
26 for id in 0..size {
27 workers.push(Worker::new(id, Arc::clone(&receiver))?);
28 }
29
30 Ok(ThreadPool {
31 workers,
32 sender: Some(sender),
33 })
34 }
35
36 pub fn execute<F>(&self, f: F) -> Result<()>
38 where
39 F: FnOnce() + Send + 'static,
40 {
41 let job = Box::new(f);
42
43 if let Some(sender) = &self.sender {
44 sender.send(job)
45 .map_err(|_| DataForgeError::generator("Failed to send job to thread pool"))?;
46 } else {
47 return Err(DataForgeError::generator("Thread pool has been shut down"));
48 }
49
50 Ok(())
51 }
52
53 pub fn size(&self) -> usize {
55 self.workers.len()
56 }
57}
58
59impl Drop for ThreadPool {
60 fn drop(&mut self) {
61 drop(self.sender.take());
62
63 for worker in &mut self.workers {
64 if let Some(thread) = worker.thread.take() {
65 thread.join().unwrap();
66 }
67 }
68 }
69}
70
/// One pool thread plus the handle needed to join it at shutdown.
struct Worker {
    /// Join handle of the spawned thread; taken by the pool's `Drop` so it is joined once.
    thread: Option<thread::JoinHandle<()>>,
    /// Index assigned at spawn time (it is also embedded in the thread's name).
    /// Kept for diagnostics; nothing reads the field back yet, hence the allowance.
    #[allow(dead_code)]
    id: usize,
}

78impl Worker {
79 fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Result<Worker> {
80 let thread = thread::Builder::new()
81 .name(format!("dataforge-worker-{}", id))
82 .spawn(move || loop {
83 let message = receiver.lock().unwrap().recv();
84
85 match message {
86 Ok(job) => {
87 job();
88 }
89 Err(_) => {
90 break;
91 }
92 }
93 })
94 .map_err(|e| DataForgeError::generator(&format!("Failed to spawn worker thread: {}", e)))?;
95
96 Ok(Worker {
97 id,
98 thread: Some(thread),
99 })
100 }
101}
102
/// Plain snapshot of thread-pool counters.
///
/// NOTE(review): nothing in this module constructs this type yet; the field meanings below
/// are inferred from their names and should be confirmed against the code that fills it.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ThreadPoolStats {
    /// Presumably the number of workers currently running a job.
    pub active_threads: usize,
    /// Presumably the total number of workers in the pool.
    pub total_threads: usize,
    /// Presumably the number of jobs waiting in the queue.
    pub queued_jobs: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_ok());
        assert_eq!(pool.unwrap().size(), 4);
    }

    #[test]
    fn test_thread_pool_zero_size() {
        let pool = ThreadPool::new(0);
        assert!(pool.is_err());
    }

    #[test]
    fn test_thread_pool_execution() {
        let pool = ThreadPool::new(2).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter = Arc::clone(&counter);
            pool.execute(move || {
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        // Dropping the pool closes the queue and joins every worker, and workers drain all
        // queued jobs before exiting. That makes this a deterministic barrier, unlike a
        // fixed sleep, which can flake on a loaded machine.
        drop(pool);
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn test_execute_after_shutdown_fails() {
        let mut pool = ThreadPool::new(1).unwrap();
        // Simulate the shut-down state that `Drop` produces by taking the sender.
        pool.sender = None;
        assert!(pool.execute(|| {}).is_err());
    }
}