dataforge/multithreading/
thread_pool.rs

//! Thread pool management module.

use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use crate::error::{DataForgeError, Result};

/// A unit of work submitted to the pool: any one-shot closure that can be
/// moved to and executed on another thread.
type Job = Box<dyn FnOnce() + Send + 'static>;
8
/// A custom fixed-size thread pool that runs submitted jobs on worker threads.
pub struct ThreadPool {
    // Worker handles, retained so `Drop` can join every thread on shutdown.
    workers: Vec<Worker>,
    // Sending half of the job channel. Wrapped in `Option` so `Drop` can
    // `take()` and drop it, disconnecting the channel and letting workers exit.
    sender: Option<mpsc::Sender<Job>>,
}
14
15impl ThreadPool {
16    /// 创建新的线程池
17    pub fn new(size: usize) -> Result<ThreadPool> {
18        if size == 0 {
19            return Err(DataForgeError::validation("Thread pool size must be greater than 0"));
20        }
21
22        let (sender, receiver) = mpsc::channel();
23        let receiver = Arc::new(Mutex::new(receiver));
24        let mut workers = Vec::with_capacity(size);
25
26        for id in 0..size {
27            workers.push(Worker::new(id, Arc::clone(&receiver))?);
28        }
29
30        Ok(ThreadPool {
31            workers,
32            sender: Some(sender),
33        })
34    }
35
36    /// 执行任务
37    pub fn execute<F>(&self, f: F) -> Result<()>
38    where
39        F: FnOnce() + Send + 'static,
40    {
41        let job = Box::new(f);
42        
43        if let Some(sender) = &self.sender {
44            sender.send(job)
45                .map_err(|_| DataForgeError::generator("Failed to send job to thread pool"))?;
46        } else {
47            return Err(DataForgeError::generator("Thread pool has been shut down"));
48        }
49
50        Ok(())
51    }
52
53    /// 获取工作线程数量
54    pub fn size(&self) -> usize {
55        self.workers.len()
56    }
57}
58
59impl Drop for ThreadPool {
60    fn drop(&mut self) {
61        drop(self.sender.take());
62
63        for worker in &mut self.workers {
64            if let Some(thread) = worker.thread.take() {
65                thread.join().unwrap();
66            }
67        }
68    }
69}
70
/// A single worker thread owned by the pool.
struct Worker {
    // Worker index; also embedded in the spawned thread's name.
    #[allow(dead_code)]
    id: usize,
    // Join handle for the thread; `None` once it has been joined during
    // pool shutdown.
    thread: Option<thread::JoinHandle<()>>,
}
77
impl Worker {
    /// Spawns a named worker thread that repeatedly pulls jobs off the shared
    /// receiver until the channel disconnects (i.e. the pool is dropped).
    ///
    /// # Errors
    ///
    /// Returns a generator error when the OS fails to spawn the thread.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Result<Worker> {
        let thread = thread::Builder::new()
            .name(format!("dataforge-worker-{}", id))
            .spawn(move || loop {
                // The MutexGuard returned by `lock()` is a temporary inside
                // this `let` statement, so the lock is released before the
                // job runs. (A `while let` here would hold the lock for the
                // whole job, serializing the entire pool.)
                // NOTE(review): `unwrap()` will panic this worker if the
                // mutex was poisoned by another panicking worker.
                let message = receiver.lock().unwrap().recv();

                match message {
                    Ok(job) => {
                        job();
                    }
                    Err(_) => {
                        // Sender dropped: the pool is shutting down.
                        break;
                    }
                }
            })
            .map_err(|e| DataForgeError::generator(&format!("Failed to spawn worker thread: {}", e)))?;

        Ok(Worker {
            id,
            thread: Some(thread),
        })
    }
}
102
/// Snapshot of thread pool statistics.
///
/// NOTE(review): nothing in this module populates this struct — presumably
/// it is filled in elsewhere; verify against callers.
#[derive(Debug, Clone)]
pub struct ThreadPoolStats {
    /// Number of active threads.
    pub active_threads: usize,
    /// Total number of threads.
    pub total_threads: usize,
    /// Number of jobs waiting in the queue.
    pub queued_jobs: usize,
}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_thread_pool_creation() {
        let pool = ThreadPool::new(4);
        assert!(pool.is_ok());
        assert_eq!(pool.unwrap().size(), 4);
    }

    #[test]
    fn test_thread_pool_zero_size() {
        assert!(ThreadPool::new(0).is_err());
    }

    #[test]
    fn test_thread_pool_execution() {
        let pool = ThreadPool::new(2).unwrap();
        let counter = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let counter = Arc::clone(&counter);
            pool.execute(move || {
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        }

        // Dropping the pool closes the channel and joins every worker, which
        // guarantees all queued jobs have run. The previous 100 ms sleep was
        // timing-dependent and flaky on loaded CI machines.
        drop(pool);
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }
}
149}