use crate::config::BatchConfig;
use crate::error::BatchError;
use crate::fast_gen::FastKeyGenerator;
use crate::stream::KeyStream;
use rayon::prelude::*;
use rustywallet_keys::private_key::PrivateKey;
/// Builder-style front end for generating batches of [`PrivateKey`]s.
///
/// Every setting lives in the wrapped [`BatchConfig`]; the builder methods on
/// `BatchGenerator` mutate that config and hand the generator back.
#[derive(Debug, Clone)]
pub struct BatchGenerator {
    /// Settings read by `generate` / `generate_vec`.
    config: BatchConfig,
}

/// Same as [`BatchGenerator::new`]: a generator with the default [`BatchConfig`].
impl Default for BatchGenerator {
    fn default() -> Self {
        BatchGenerator::new()
    }
}
impl BatchGenerator {
pub fn new() -> Self {
Self {
config: BatchConfig::default(),
}
}
pub fn with_config(config: BatchConfig) -> Self {
Self { config }
}
pub fn count(mut self, count: usize) -> Self {
self.config.batch_size = count;
self
}
pub fn parallel(mut self) -> Self {
self.config.parallel = true;
self
}
pub fn threads(mut self, count: usize) -> Self {
self.config.thread_count = Some(count);
self.config.parallel = true;
self
}
pub fn simd(mut self) -> Self {
self.config.use_simd = true;
self
}
pub fn chunk_size(mut self, size: usize) -> Self {
self.config.chunk_size = size;
self
}
pub fn deterministic(mut self) -> Self {
self.config.deterministic_order = true;
self
}
pub fn generate(self) -> Result<KeyStream, BatchError> {
self.config.validate()?;
let count = self.config.batch_size;
let parallel = self.config.parallel;
if parallel {
self.generate_parallel_stream(count)
} else {
self.generate_sequential_stream(count)
}
}
pub fn generate_vec(self) -> Result<Vec<PrivateKey>, BatchError> {
self.config.validate()?;
let count = self.config.batch_size;
let parallel = self.config.parallel;
let keys = FastKeyGenerator::new(count)
.parallel(parallel)
.chunk_size(self.config.chunk_size)
.generate();
Ok(keys)
}
fn generate_sequential_stream(self, count: usize) -> Result<KeyStream, BatchError> {
let iter = (0..count).map(|_| Ok(PrivateKey::random()));
Ok(KeyStream::new(iter, Some(count)))
}
#[allow(dead_code)]
fn generate_sequential_vec(&self, count: usize) -> Result<Vec<PrivateKey>, BatchError> {
let keys: Vec<PrivateKey> = (0..count).map(|_| PrivateKey::random()).collect();
Ok(keys)
}
fn generate_parallel_stream(self, count: usize) -> Result<KeyStream, BatchError> {
let chunk_size = self.config.chunk_size;
let deterministic = self.config.deterministic_order;
let iter = ParallelChunkIterator::new(count, chunk_size, deterministic);
Ok(KeyStream::new(iter, Some(count)))
}
#[allow(dead_code)]
fn generate_parallel_vec(&self, count: usize) -> Result<Vec<PrivateKey>, BatchError> {
let keys: Vec<PrivateKey> = if self.config.deterministic_order {
(0..count)
.into_par_iter()
.map(|_| generate_single_key())
.collect()
} else {
(0..count)
.into_par_iter()
.map(|_| generate_single_key())
.collect()
};
Ok(keys)
}
}
/// Produces one random key; shared by the sequential and parallel paths.
fn generate_single_key() -> PrivateKey {
    PrivateKey::random()
}

/// Iterator that yields `total` keys, refilling an internal buffer one chunk at
/// a time with keys generated in parallel by rayon.
struct ParallelChunkIterator {
    /// Keys not yet generated (excludes those buffered in `current_chunk`).
    remaining: usize,
    /// Keys generated per refill; always at least 1.
    chunk_size: usize,
    /// Keys from the most recent chunk that have not been yielded yet.
    current_chunk: std::vec::IntoIter<PrivateKey>,
}

impl ParallelChunkIterator {
    /// Creates an iterator that will yield exactly `total` keys.
    ///
    /// `_deterministic` is accepted for interface compatibility. Rayon's indexed
    /// `collect` already keeps chunk output in index order, so it needs no
    /// separate code path.
    fn new(total: usize, chunk_size: usize, _deterministic: bool) -> Self {
        Self {
            remaining: total,
            // A zero chunk size would produce an empty chunk on the first refill,
            // making `next` return `None` immediately and the stream end early.
            chunk_size: chunk_size.max(1),
            current_chunk: Vec::new().into_iter(),
        }
    }

    /// Generates the next chunk (at most `chunk_size` keys) in parallel and
    /// deducts it from `remaining`.
    fn generate_chunk(&mut self) -> Vec<PrivateKey> {
        let chunk_count = self.remaining.min(self.chunk_size);
        self.remaining -= chunk_count;
        (0..chunk_count)
            .into_par_iter()
            .map(|_| generate_single_key())
            .collect()
    }
}

impl Iterator for ParallelChunkIterator {
    type Item = Result<PrivateKey, BatchError>;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(key) = self.current_chunk.next() {
            return Some(Ok(key));
        }
        if self.remaining == 0 {
            return None;
        }
        // Buffer drained but keys remain: refill with a new parallel chunk.
        self.current_chunk = self.generate_chunk().into_iter();
        self.current_chunk.next().map(Ok)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: buffered keys plus keys still to be generated.
        let len = self.current_chunk.len() + self.remaining;
        (len, Some(len))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_sequential() {
        let keys = BatchGenerator::new().count(100).generate_vec().unwrap();
        assert_eq!(keys.len(), 100);
    }

    #[test]
    fn test_generate_parallel() {
        let keys = BatchGenerator::new()
            .count(1000)
            .parallel()
            .generate_vec()
            .unwrap();
        assert_eq!(keys.len(), 1000);
    }

    #[test]
    fn test_generate_stream() {
        let stream = BatchGenerator::new().count(100).generate().unwrap();
        let keys: Vec<_> = stream.collect();
        assert_eq!(keys.len(), 100);
        assert!(keys.iter().all(|r| r.is_ok()));
    }

    #[test]
    fn test_generate_parallel_stream() {
        let stream = BatchGenerator::new()
            .count(1000)
            .parallel()
            .chunk_size(100)
            .generate()
            .unwrap();
        let keys: Vec<_> = stream.collect();
        assert_eq!(keys.len(), 1000);
        assert!(keys.iter().all(|r| r.is_ok()));
    }

    #[test]
    fn test_parallel_stream_partial_final_chunk() {
        // 250 is not a multiple of 100, so the final chunk holds only 50 keys.
        let stream = BatchGenerator::new()
            .count(250)
            .parallel()
            .chunk_size(100)
            .generate()
            .unwrap();
        let keys: Vec<_> = stream.collect();
        assert_eq!(keys.len(), 250);
        assert!(keys.iter().all(|r| r.is_ok()));
    }

    #[test]
    fn test_keys_are_unique() {
        let keys = BatchGenerator::new()
            .count(1000)
            .parallel()
            .generate_vec()
            .unwrap();
        let hex_keys: std::collections::HashSet<_> = keys.iter().map(|k| k.to_hex()).collect();
        assert_eq!(hex_keys.len(), keys.len(), "All keys should be unique");
    }

    #[test]
    fn test_with_config() {
        let config = BatchConfig::fast();
        let generator = BatchGenerator::with_config(config);
        let keys = generator.count(500).generate_vec().unwrap();
        assert_eq!(keys.len(), 500);
    }

    #[test]
    fn test_deterministic_mode() {
        let keys = BatchGenerator::new()
            .count(100)
            .parallel()
            .deterministic()
            .generate_vec()
            .unwrap();
        assert_eq!(keys.len(), 100);
    }
}