//! dataforge/multithreading/mod.rs — parallel data-generation utilities.

use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use serde_json::Value;

use crate::error::{DataForgeError, Result};

pub mod thread_pool;
pub mod safety;

pub use thread_pool::*;
pub use safety::*;
17pub struct ParallelGenerator {
19 thread_count: usize,
21 batch_size: usize,
23 generators: HashMap<String, Arc<dyn Fn() -> Value + Send + Sync>>,
25}
26
27impl ParallelGenerator {
28 pub fn new(thread_count: usize, batch_size: usize) -> Self {
30 Self {
31 thread_count,
32 batch_size,
33 generators: HashMap::new(),
34 }
35 }
36
37 pub fn register<F>(&mut self, name: String, generator: F)
39 where
40 F: Fn() -> Value + Send + Sync + 'static,
41 {
42 self.generators.insert(name, Arc::new(generator));
43 }
44
45 pub fn generate_parallel(&self, generator_name: &str, count: usize) -> Result<Vec<Value>> {
47 let generator = self.generators.get(generator_name)
48 .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
49
50 let pool = rayon::ThreadPoolBuilder::new()
52 .num_threads(self.thread_count)
53 .build()
54 .map_err(|e| DataForgeError::generator(&format!("Failed to create thread pool: {}", e)))?;
55
56 let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
57 let generator = Arc::clone(generator);
58
59 pool.install(|| {
60 (0..count).into_par_iter().for_each(|_| {
61 let value = generator();
62 results.lock().unwrap().push(value);
63 });
64 });
65
66 let results = Arc::try_unwrap(results)
67 .map_err(|_| DataForgeError::generator("Failed to unwrap results"))?
68 .into_inner()
69 .map_err(|_| DataForgeError::generator("Failed to acquire mutex"))?;
70
71 Ok(results)
72 }
73
74 pub fn generate_batched(&self, generator_name: &str, total_count: usize) -> Result<Vec<Value>> {
76 let generator = self.generators.get(generator_name)
77 .ok_or_else(|| DataForgeError::validation(&format!("Generator '{}' not found", generator_name)))?;
78
79 let batches: Vec<usize> = (0..total_count)
80 .step_by(self.batch_size)
81 .map(|start| {
82 let end = (start + self.batch_size).min(total_count);
83 end - start
84 })
85 .collect();
86
87 let generator = Arc::clone(generator);
88 let results: Result<Vec<Vec<Value>>> = batches
89 .into_par_iter()
90 .map(|batch_size| {
91 let mut batch_results = Vec::with_capacity(batch_size);
92 for _ in 0..batch_size {
93 batch_results.push(generator());
94 }
95 Ok(batch_results)
96 })
97 .collect();
98
99 let results = results?;
100 Ok(results.into_iter().flatten().collect())
101 }
102
103 pub fn thread_count(&self) -> usize {
105 self.thread_count
106 }
107
108 pub fn batch_size(&self) -> usize {
110 self.batch_size
111 }
112}
113
114impl Default for ParallelGenerator {
115 fn default() -> Self {
116 Self::new(num_cpus::get(), 1000)
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor should store the configured thread count and batch size.
    #[test]
    fn test_parallel_generator_creation() {
        let gen = ParallelGenerator::new(4, 100);
        assert_eq!((gen.thread_count(), gen.batch_size()), (4, 100));
    }

    /// A registered generator should be usable for parallel generation and
    /// produce exactly the requested number of values.
    #[test]
    fn test_register_generator() {
        let mut gen = ParallelGenerator::new(2, 50);
        gen.register("test".to_string(), || Value::String("test".to_string()));

        let values = gen
            .generate_parallel("test", 10)
            .expect("generation should succeed for a registered generator");
        assert_eq!(values.len(), 10);
    }

    /// Batched generation should cover the full requested count even when
    /// the total is not a multiple of the batch size (10 items, batches of 3).
    #[test]
    fn test_batched_generation() {
        let mut gen = ParallelGenerator::new(2, 3);
        gen.register("counter".to_string(), || {
            Value::Number(serde_json::Number::from(42))
        });

        let values = gen
            .generate_batched("counter", 10)
            .expect("batched generation should succeed");
        assert_eq!(values.len(), 10);
    }
}