rustywallet_batch/generator.rs
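//! Batch private key generation with optional rayon-based parallelism.
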
use crate::config::BatchConfig;
use crate::error::BatchError;
use crate::fast_gen::FastKeyGenerator;
use crate::stream::KeyStream;
use rayon::prelude::*;
use rustywallet_keys::private_key::PrivateKey;

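/// Builder-style generator for producing batches of private keys.
///
/// The builder methods configure a `BatchConfig`; `generate` yields a lazy
/// `KeyStream`, while `generate_vec` materializes a `Vec<PrivateKey>`.
///
/// Sketch of typical usage, mirroring the tests at the bottom of this file
/// (the import path here is an assumption and may differ depending on how
/// the crate re-exports this type):
///
/// ```ignore
/// use rustywallet_batch::generator::BatchGenerator;
///
/// let keys = BatchGenerator::new()
///     .count(1000)
///     .parallel()
///     .generate_vec()
///     .unwrap();
/// assert_eq!(keys.len(), 1000);
/// ```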
#[derive(Debug, Clone)]
pub struct BatchGenerator {
    config: BatchConfig,
}

impl Default for BatchGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchGenerator {
    /// Creates a generator with the default `BatchConfig`.
    pub fn new() -> Self {
        Self {
            config: BatchConfig::default(),
        }
    }

    /// Creates a generator from an existing `BatchConfig`.
    pub fn with_config(config: BatchConfig) -> Self {
        Self { config }
    }

    /// Sets the number of keys to generate (`batch_size`).
    pub fn count(mut self, count: usize) -> Self {
        self.config.batch_size = count;
        self
    }

    /// Enables parallel generation.
    pub fn parallel(mut self) -> Self {
        self.config.parallel = true;
        self
    }

    /// Sets an explicit thread count and enables parallel generation.
    pub fn threads(mut self, count: usize) -> Self {
        self.config.thread_count = Some(count);
        self.config.parallel = true;
        self
    }

    /// Enables the SIMD code path (`use_simd`).
    pub fn simd(mut self) -> Self {
        self.config.use_simd = true;
        self
    }

    /// Sets the chunk size used when generating in parallel.
    pub fn chunk_size(mut self, size: usize) -> Self {
        self.config.chunk_size = size;
        self
    }

    /// Requests deterministic key ordering (`deterministic_order`).
    pub fn deterministic(mut self) -> Self {
        self.config.deterministic_order = true;
        self
    }

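    /// Validates the configuration and returns a lazy `KeyStream`.
    ///
    /// Keys are produced on demand as the stream is consumed; use
    /// `generate_vec` when the whole batch is needed up front.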
    pub fn generate(self) -> Result<KeyStream, BatchError> {
        self.config.validate()?;

        let count = self.config.batch_size;
        let parallel = self.config.parallel;

        if parallel {
            self.generate_parallel_stream(count)
        } else {
            self.generate_sequential_stream(count)
        }
    }

    /// Validates the configuration and generates the whole batch eagerly,
    /// delegating to `FastKeyGenerator`.
    pub fn generate_vec(self) -> Result<Vec<PrivateKey>, BatchError> {
        self.config.validate()?;

        let count = self.config.batch_size;
        let parallel = self.config.parallel;

        let keys = FastKeyGenerator::new(count)
            .parallel(parallel)
            .chunk_size(self.config.chunk_size)
            .generate();

        Ok(keys)
    }

    fn generate_sequential_stream(self, count: usize) -> Result<KeyStream, BatchError> {
        let iter = (0..count).map(|_| Ok(PrivateKey::random()));
        Ok(KeyStream::new(iter, Some(count)))
    }

    #[allow(dead_code)]
    fn generate_sequential_vec(&self, count: usize) -> Result<Vec<PrivateKey>, BatchError> {
        let keys: Vec<PrivateKey> = (0..count).map(|_| PrivateKey::random()).collect();
        Ok(keys)
    }

    fn generate_parallel_stream(self, count: usize) -> Result<KeyStream, BatchError> {
        let chunk_size = self.config.chunk_size;
        let deterministic = self.config.deterministic_order;

        let iter = ParallelChunkIterator::new(count, chunk_size, deterministic);
        Ok(KeyStream::new(iter, Some(count)))
    }

    #[allow(dead_code)]
    fn generate_parallel_vec(&self, count: usize) -> Result<Vec<PrivateKey>, BatchError> {
        // NOTE: both branches are currently identical; collecting an indexed
        // rayon range into a Vec already preserves input order.
        let keys: Vec<PrivateKey> = if self.config.deterministic_order {
            (0..count)
                .into_par_iter()
                .map(|_| generate_single_key())
                .collect()
        } else {
            (0..count)
                .into_par_iter()
                .map(|_| generate_single_key())
                .collect()
        };

        Ok(keys)
    }
}

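/// Generates a single random private key (used from the rayon closures above
/// and below, which call it without capturing any state).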
fn generate_single_key() -> PrivateKey {
    PrivateKey::random()
}

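/// Iterator that generates keys in parallel, one chunk at a time.
///
/// Each refill produces up to `chunk_size` keys via rayon and then drains them
/// sequentially, so callers see an ordinary `Iterator` while the underlying
/// work is parallelized per chunk.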
struct ParallelChunkIterator {
    remaining: usize,
    chunk_size: usize,
    current_chunk: std::vec::IntoIter<PrivateKey>,
    deterministic: bool,
}

impl ParallelChunkIterator {
    fn new(total: usize, chunk_size: usize, deterministic: bool) -> Self {
        Self {
            remaining: total,
            chunk_size,
            current_chunk: Vec::new().into_iter(),
            deterministic,
        }
    }

    fn generate_chunk(&mut self) -> Vec<PrivateKey> {
        let chunk_count = self.remaining.min(self.chunk_size);
        self.remaining -= chunk_count;

        // As in `generate_parallel_vec`, both branches are currently identical:
        // rayon's indexed collect already yields the chunk in input order.
        if self.deterministic {
            (0..chunk_count)
                .into_par_iter()
                .map(|_| generate_single_key())
                .collect()
        } else {
            (0..chunk_count)
                .into_par_iter()
                .map(|_| generate_single_key())
                .collect()
        }
    }
}

impl Iterator for ParallelChunkIterator {
    type Item = Result<PrivateKey, BatchError>;

    fn next(&mut self) -> Option<Self::Item> {
        // Drain the current chunk first.
        if let Some(key) = self.current_chunk.next() {
            return Some(Ok(key));
        }

        // Refill from a freshly generated chunk, or finish.
        if self.remaining > 0 {
            let chunk = self.generate_chunk();
            self.current_chunk = chunk.into_iter();
            self.current_chunk.next().map(Ok)
        } else {
            None
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_sequential() {
        let keys = BatchGenerator::new()
            .count(100)
            .generate_vec()
            .unwrap();

        assert_eq!(keys.len(), 100);
    }

    #[test]
    fn test_generate_parallel() {
        let keys = BatchGenerator::new()
            .count(1000)
            .parallel()
            .generate_vec()
            .unwrap();

        assert_eq!(keys.len(), 1000);
    }

    #[test]
    fn test_generate_stream() {
        let stream = BatchGenerator::new()
            .count(100)
            .generate()
            .unwrap();

        let keys: Vec<_> = stream.collect();
        assert_eq!(keys.len(), 100);
        assert!(keys.iter().all(|r| r.is_ok()));
    }

    #[test]
    fn test_generate_parallel_stream() {
        let stream = BatchGenerator::new()
            .count(1000)
            .parallel()
            .chunk_size(100)
            .generate()
            .unwrap();

        let keys: Vec<_> = stream.collect();
        assert_eq!(keys.len(), 1000);
        assert!(keys.iter().all(|r| r.is_ok()));
    }

    #[test]
    fn test_keys_are_unique() {
        let keys = BatchGenerator::new()
            .count(1000)
            .parallel()
            .generate_vec()
            .unwrap();

        let hex_keys: std::collections::HashSet<_> = keys.iter().map(|k| k.to_hex()).collect();
        assert_eq!(hex_keys.len(), keys.len(), "All keys should be unique");
    }

    #[test]
    fn test_with_config() {
        let config = BatchConfig::fast();
        let generator = BatchGenerator::with_config(config);

        let keys = generator.count(500).generate_vec().unwrap();
        assert_eq!(keys.len(), 500);
    }

    #[test]
    fn test_deterministic_mode() {
        let keys = BatchGenerator::new()
            .count(100)
            .parallel()
            .deterministic()
            .generate_vec()
            .unwrap();

        assert_eq!(keys.len(), 100);
    }
}