dataforge/multithreading/
mod.rs

1//! 多线程处理模块
2//! 
3//! 提供多线程并行数据生成的支持,提升数据生成的速度和效率
4
5use rayon::prelude::*;
6use std::sync::{Arc, Mutex};
7use std::collections::HashMap;
8use serde_json::Value;
9use crate::error::{DataForgeError, Result};
10
11pub mod thread_pool;
12pub mod safety;
13
14pub use thread_pool::*;
15pub use safety::*;
16
17/// 并行数据生成器
18pub struct ParallelGenerator {
19    /// 线程池大小
20    thread_count: usize,
21    /// 批次大小
22    batch_size: usize,
23    /// 生成器函数
24    generators: HashMap<String, Arc<dyn Fn() -> Value + Send + Sync>>,
25}
26
27impl ParallelGenerator {
28    /// 创建新的并行生成器
29    pub fn new(thread_count: usize, batch_size: usize) -> Self {
30        Self {
31            thread_count,
32            batch_size,
33            generators: HashMap::new(),
34        }
35    }
36
37    /// 注册生成器
38    pub fn register<F>(&mut self, name: String, generator: F)
39    where
40        F: Fn() -> Value + Send + Sync + 'static,
41    {
42        self.generators.insert(name, Arc::new(generator));
43    }
44
45    /// 并行生成数据
46    pub fn generate_parallel(&self, generator_name: &str, count: usize) -> Result<Vec<Value>> {
47        let generator = self.generators.get(generator_name)
48            .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
49
50        // 配置rayon线程池
51        let pool = rayon::ThreadPoolBuilder::new()
52            .num_threads(self.thread_count)
53            .build()
54            .map_err(|e| DataForgeError::generator(&format!("Failed to create thread pool: {}", e)))?;
55
56        let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
57        let generator = Arc::clone(generator);
58
59        pool.install(|| {
60            (0..count).into_par_iter().for_each(|_| {
61                let value = generator();
62                results.lock().unwrap().push(value);
63            });
64        });
65
66        let results = Arc::try_unwrap(results)
67            .map_err(|_| DataForgeError::generator("Failed to unwrap results"))?
68            .into_inner()
69            .map_err(|_| DataForgeError::generator("Failed to acquire mutex"))?;
70
71        Ok(results)
72    }
73
74    /// 批量并行生成数据
75    pub fn generate_batched(&self, generator_name: &str, total_count: usize) -> Result<Vec<Value>> {
76        let generator = self.generators.get(generator_name)
77            .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
78
79        let batches: Vec<usize> = (0..total_count)
80            .step_by(self.batch_size)
81            .map(|start| {
82                let end = (start + self.batch_size).min(total_count);
83                end - start
84            })
85            .collect();
86
87        let generator = Arc::clone(generator);
88        let results: Result<Vec<Vec<Value>>> = batches
89            .into_par_iter()
90            .map(|batch_size| {
91                let mut batch_results = Vec::with_capacity(batch_size);
92                for _ in 0..batch_size {
93                    batch_results.push(generator());
94                }
95                Ok(batch_results)
96            })
97            .collect();
98
99        let results = results?;
100        Ok(results.into_iter().flatten().collect())
101    }
102
103    /// 获取线程数
104    pub fn thread_count(&self) -> usize {
105        self.thread_count
106    }
107
108    /// 获取批次大小
109    pub fn batch_size(&self) -> usize {
110        self.batch_size
111    }
112}
113
114impl Default for ParallelGenerator {
115    fn default() -> Self {
116        Self::new(num_cpus::get(), 1000)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores both settings unchanged.
    #[test]
    fn test_parallel_generator_creation() {
        let generator = ParallelGenerator::new(4, 100);
        assert_eq!(generator.thread_count(), 4);
        assert_eq!(generator.batch_size(), 100);
    }

    /// A registered generator is found by name and invoked once per requested value.
    #[test]
    fn test_register_generator() {
        let mut generator = ParallelGenerator::new(2, 50);
        generator.register("test".to_string(), || Value::String("test".to_string()));

        let results = generator.generate_parallel("test", 10);
        assert!(results.is_ok());

        // Check the content as well as the length, so dropped or foreign
        // values would be caught.
        let values = results.unwrap();
        let expected = Value::String("test".to_string());
        assert_eq!(values.len(), 10);
        assert!(values.iter().all(|value| *value == expected));
    }

    /// Batched generation yields exactly the requested count even when the
    /// total is not a multiple of the batch size (10 values, batches of 3).
    #[test]
    fn test_batched_generation() {
        let mut generator = ParallelGenerator::new(2, 3);
        generator.register("counter".to_string(), || {
            Value::Number(serde_json::Number::from(42))
        });

        let results = generator.generate_batched("counter", 10);
        assert!(results.is_ok());

        let values = results.unwrap();
        let expected = Value::Number(serde_json::Number::from(42));
        assert_eq!(values.len(), 10);
        assert!(values.iter().all(|value| *value == expected));
    }

    /// Requesting zero values succeeds and produces an empty result.
    #[test]
    fn test_zero_count_yields_empty_result() {
        let mut generator = ParallelGenerator::new(2, 3);
        generator.register("test".to_string(), || Value::Null);

        let parallel = generator.generate_parallel("test", 0);
        assert!(parallel.is_ok());
        assert!(parallel.unwrap().is_empty());

        let batched = generator.generate_batched("test", 0);
        assert!(batched.is_ok());
        assert!(batched.unwrap().is_empty());
    }

    /// Asking for a generator that was never registered is reported as an error.
    #[test]
    fn test_unknown_generator_returns_error() {
        let generator = ParallelGenerator::new(2, 3);

        assert!(generator.generate_parallel("missing", 5).is_err());
        assert!(generator.generate_batched("missing", 5).is_err());
    }
}
152}