use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use rayon::prelude::*;
use rustywallet_keys::private_key::PrivateKey;
use secp256k1::{Secp256k1, SecretKey};
/// Bulk generator of random private keys.
///
/// Configure it with the builder-style methods (`parallel`, `chunk_size`) and
/// consume it with `generate`. Keys are produced from ChaCha20 RNGs seeded from
/// OS entropy. The work runs either on the current thread or split into chunks
/// on the rayon thread pool.
pub struct FastKeyGenerator {
    // Number of keys `generate` returns.
count: usize,
    // When true, `generate` splits the work into chunks processed with rayon.
parallel: bool,
    // Number of keys generated by each parallel task (ignored when `parallel` is false).
chunk_size: usize,
}
impl FastKeyGenerator {
    /// Creates a generator that will produce `count` keys.
    ///
    /// Parallel generation is enabled by default, with 10 000 keys per chunk.
    pub fn new(count: usize) -> Self {
        Self {
            count,
            parallel: true,
            chunk_size: 10_000,
        }
    }

    /// Enables or disables parallel (rayon) generation.
    pub fn parallel(mut self, enabled: bool) -> Self {
        self.parallel = enabled;
        self
    }

    /// Sets how many keys each parallel task generates.
    ///
    /// Only used when parallel generation is enabled. A size of `0` is treated
    /// as `1` rather than causing a division-by-zero panic.
    pub fn chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Generates `count` private keys, consuming the generator.
    ///
    /// # Panics
    ///
    /// Panics if the operating system's entropy source cannot be read
    /// (propagated from `ChaCha20Rng::from_entropy`).
    pub fn generate(self) -> Vec<PrivateKey> {
        if self.parallel {
            self.generate_parallel()
        } else {
            self.generate_sequential()
        }
    }

    /// Single-threaded path: one secp256k1 context and one RNG for the whole batch.
    fn generate_sequential(self) -> Vec<PrivateKey> {
        let secp = Secp256k1::new();
        let mut rng = ChaCha20Rng::from_entropy();
        (0..self.count)
            .map(|_| {
                // Only the secret half of the keypair is kept.
                let (secret_key, _) = secp.generate_keypair(&mut rng);
                secret_key_to_private_key(secret_key)
            })
            .collect()
    }

    /// Parallel path: splits `count` into chunks of at most `chunk_size` keys and
    /// generates each chunk on the rayon pool with its own RNG. Output order
    /// follows chunk order.
    fn generate_parallel(self) -> Vec<PrivateKey> {
        let count = self.count;
        // `div_ceil` panics on a zero divisor; treat a zero chunk size as one key per chunk.
        let chunk_size = self.chunk_size.max(1);
        let num_chunks = count.div_ceil(chunk_size);
        // Built once and shared by reference across workers instead of once per chunk.
        let secp = Secp256k1::new();
        (0..num_chunks)
            .into_par_iter()
            .flat_map(|chunk_idx| {
                // `chunk_idx < num_chunks` guarantees `start < count`, so neither the
                // multiplication nor the subtraction below can overflow, even for very
                // large chunk sizes.
                let start = chunk_idx * chunk_size;
                let chunk_len = chunk_size.min(count - start);
                // Each chunk gets an independent ChaCha20 stream seeded with a full
                // 256 bits of OS entropy, the same way the sequential path seeds its RNG.
                let mut rng = ChaCha20Rng::from_entropy();
                (0..chunk_len)
                    .map(|_| {
                        let (secret_key, _) = secp.generate_keypair(&mut rng);
                        secret_key_to_private_key(secret_key)
                    })
                    .collect::<Vec<_>>()
            })
            .collect()
    }
}
/// Converts a secp256k1 `SecretKey` into this crate's `PrivateKey` via its raw 32 bytes.
///
/// # Panics
///
/// Panics if `PrivateKey::from_bytes` rejects the bytes. A `SecretKey` has already
/// passed secp256k1's own validity check, so this indicates the two crates disagree
/// on what a valid scalar is.
fn secret_key_to_private_key(secret_key: SecretKey) -> PrivateKey {
    PrivateKey::from_bytes(secret_key.secret_bytes())
        .expect("SecretKey should always be valid")
}
/// Generator of sequential private keys: `start`, `start + step`, `start + 2*step`, ...
///
/// Every key in the output can be derived from the first one and the step, so
/// the generated keys must not be treated as independent secrets.
pub struct IncrementalKeyGenerator {
    // Big-endian bytes of the first key in the sequence (raw private-key material).
start: [u8; 32],
    // Number of iterations `generate` performs.
count: usize,
    // Amount added between consecutive keys.
step: u64,
}
impl IncrementalKeyGenerator {
    /// Creates a generator that starts at a freshly generated random key and
    /// yields `count` keys with a step of 1.
    pub fn new(count: usize) -> Self {
        Self::from_key(&PrivateKey::random(), count)
    }

    /// Creates a generator whose first key is `key`, yielding `count` keys with
    /// a step of 1.
    pub fn from_key(key: &PrivateKey, count: usize) -> Self {
        let start = key.to_bytes();
        Self {
            start,
            count,
            step: 1,
        }
    }

    /// Sets the increment added between consecutive keys.
    ///
    /// A step of `0` makes every generated key identical to the first.
    pub fn step(self, step: u64) -> Self {
        Self { step, ..self }
    }

    /// Produces the keys `start, start + step, start + 2*step, ...`.
    ///
    /// Each successor comes from `add_to_bytes`, which resets the running value
    /// to `1` when the sum overflows 256 bits or is rejected by
    /// `PrivateKey::is_valid`. Any value that `PrivateKey::from_bytes` rejects is
    /// silently skipped, so the result may hold fewer than `count` keys.
    pub fn generate(self) -> Vec<PrivateKey> {
        let step = self.step;
        let mut cursor = self.start;
        let mut keys = Vec::with_capacity(self.count);
        keys.extend((0..self.count).filter_map(|_| {
            let candidate = PrivateKey::from_bytes(cursor).ok();
            add_to_bytes(&mut cursor, step);
            candidate
        }));
        keys
    }
}
/// Adds `value` in place to `bytes`, read as a 256-bit big-endian unsigned integer.
///
/// If the addition overflows 256 bits, or `PrivateKey::is_valid` rejects the sum,
/// `bytes` is reset to the scalar `1` (thirty-one zero bytes followed by `0x01`).
fn add_to_bytes(bytes: &mut [u8; 32], value: u64) {
    // Little-endian addend, so index `k` lines up with the k-th byte from the end of `bytes`.
    let addend = value.to_le_bytes();
    let mut carry = 0u64;
    for (offset, byte) in bytes.iter_mut().rev().enumerate() {
        // Beyond the addend's 8 bytes only the carry contributes; once it is zero
        // the remaining bytes are left unchanged.
        let digit = addend.get(offset).map_or(0, |&d| u64::from(d));
        let total = u64::from(*byte) + digit + carry;
        *byte = total as u8;
        carry = total >> 8;
    }
    let overflowed = carry != 0;
    if overflowed || !PrivateKey::is_valid(bytes) {
        let mut one = [0u8; 32];
        one[31] = 1;
        *bytes = one;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Instant;

    /// Number of distinct keys in `keys`, compared by their hex encoding.
    fn distinct_count(keys: &[PrivateKey]) -> usize {
        keys.iter().map(|k| k.to_hex()).collect::<HashSet<_>>().len()
    }

    /// Big-endian 32-byte encoding of a small scalar.
    fn scalar_bytes(value: u64) -> [u8; 32] {
        let mut bytes = [0u8; 32];
        bytes[24..].copy_from_slice(&value.to_be_bytes());
        bytes
    }

    /// Private key for a small non-zero scalar.
    fn key_from_u64(value: u64) -> PrivateKey {
        PrivateKey::from_bytes(scalar_bytes(value)).expect("small non-zero scalar is a valid key")
    }

    #[test]
    fn test_fast_generator_sequential() {
        let keys = FastKeyGenerator::new(1000).parallel(false).generate();
        assert_eq!(keys.len(), 1000);
        assert_eq!(distinct_count(&keys), 1000);
    }

    #[test]
    fn test_fast_generator_parallel() {
        let keys = FastKeyGenerator::new(10_000)
            .parallel(true)
            .chunk_size(1000)
            .generate();
        assert_eq!(keys.len(), 10_000);
        assert_eq!(distinct_count(&keys), 10_000);
    }

    #[test]
    fn test_fast_generator_parallel_uneven_chunks() {
        // 2 500 keys in chunks of 1 000 leaves a final, partial chunk of 500.
        let keys = FastKeyGenerator::new(2_500)
            .parallel(true)
            .chunk_size(1_000)
            .generate();
        assert_eq!(keys.len(), 2_500);
        assert_eq!(distinct_count(&keys), 2_500);
    }

    #[test]
    fn test_fast_generator_zero_count() {
        assert!(FastKeyGenerator::new(0).parallel(true).generate().is_empty());
        assert!(FastKeyGenerator::new(0).parallel(false).generate().is_empty());
    }

    #[test]
    fn test_incremental_generator() {
        let base = PrivateKey::from_hex(
            "0000000000000000000000000000000000000000000000000000000000000001",
        )
        .unwrap();
        let keys = IncrementalKeyGenerator::from_key(&base, 5).generate();
        assert_eq!(keys.len(), 5);
        assert_eq!(keys[0].to_hex(), "0000000000000000000000000000000000000000000000000000000000000001");
        assert_eq!(keys[1].to_hex(), "0000000000000000000000000000000000000000000000000000000000000002");
        assert_eq!(keys[2].to_hex(), "0000000000000000000000000000000000000000000000000000000000000003");
        for (i, key) in keys.iter().enumerate() {
            assert_eq!(key.to_bytes(), scalar_bytes(i as u64 + 1));
        }
    }

    #[test]
    fn test_incremental_generator_carries_across_bytes() {
        let keys = IncrementalKeyGenerator::from_key(&key_from_u64(0xff), 3).generate();
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0].to_bytes(), scalar_bytes(0xff));
        assert_eq!(keys[1].to_bytes(), scalar_bytes(0x100));
        assert_eq!(keys[2].to_bytes(), scalar_bytes(0x101));
    }

    #[test]
    fn test_incremental_generator_carries_past_low_word() {
        // 2^64 - 1 + 1 must carry out of the low eight bytes into byte 23.
        let keys = IncrementalKeyGenerator::from_key(&key_from_u64(u64::MAX), 2).generate();
        assert_eq!(keys.len(), 2);
        let mut expected = [0u8; 32];
        expected[23] = 1;
        assert_eq!(keys[1].to_bytes(), expected);
    }

    #[test]
    fn test_incremental_generator_custom_step() {
        let keys = IncrementalKeyGenerator::from_key(&key_from_u64(1), 3)
            .step(0x100)
            .generate();
        let got: Vec<[u8; 32]> = keys.iter().map(|k| k.to_bytes()).collect();
        assert_eq!(got, vec![scalar_bytes(1), scalar_bytes(0x101), scalar_bytes(0x201)]);
    }

    #[test]
    #[ignore = "wall-clock throughput assertion; flaky on debug builds and loaded CI - run with `cargo test -- --ignored`"]
    fn test_performance_comparison() {
        let count = 10_000;
        let start = Instant::now();
        let keys = FastKeyGenerator::new(count).parallel(true).generate();
        let fast_elapsed = start.elapsed();
        assert_eq!(keys.len(), count);
        let fast_rate = count as f64 / fast_elapsed.as_secs_f64();
        println!("Fast generator: {:.0} keys/sec", fast_rate);
        assert!(fast_rate > 500.0, "Fast generator should exceed 500 keys/sec in test mode");
    }
}