#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::kk_mix;
/// Deterministic byte generator built on the `kk_mix` KDF.
///
/// Both fields are wiped through `zeroize` when the generator is dropped, so
/// the chaining state does not linger in freed memory.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct KkRng {
    /// Number of output requests since construction or the last reseed;
    /// passed to the KDF alongside `state` on every request.
    counter: u64,
    /// 32-byte chaining state; replaced after every output request.
    state: [u8; 32],
}
impl KkRng {
pub fn new(seed: &[u8]) -> Self {
Self {
state: kk_mix::kk_hash(seed),
counter: 0,
}
}
pub fn next_bytes(&mut self, len: usize) -> Vec<u8> {
let mut combined = kk_mix::kk_kdf(
&self.state,
&self.counter.to_le_bytes(),
b"KK-RNG",
len + 32,
);
let output = combined[..len].to_vec();
self.state.copy_from_slice(&combined[len..len + 32]);
self.counter = self.counter.wrapping_add(1);
combined.zeroize();
output
}
pub fn fill_bytes(&mut self, dest: &mut [u8]) {
let bytes = self.next_bytes(dest.len());
dest.copy_from_slice(&bytes);
}
pub fn next_u64(&mut self) -> u64 {
let bytes = self.next_bytes(8);
u64::from_le_bytes(bytes[..8].try_into().unwrap())
}
pub fn reseed(&mut self, additional_seed: &[u8]) {
let mut material = Vec::with_capacity(32 + additional_seed.len());
material.extend_from_slice(&self.state);
material.extend_from_slice(additional_seed);
self.state = kk_mix::kk_hash(&material);
self.counter = 0;
material.zeroize();
}
}
#[cfg(feature = "std")]
use std::sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
};
/// A fixed set of independently seeded [`KkRng`] generators that can be
/// shared across threads.
///
/// Each generator sits behind its own `Mutex`. Single requests are spread
/// across generators round-robin through an atomic cursor.
#[cfg(feature = "std")]
pub struct KkRngPool {
    /// Round-robin cursor; reduced modulo `generators.len()` on use.
    next: AtomicUsize,
    /// Generator `i` is seeded from `seed` followed by `i` as a
    /// little-endian `u64`.
    generators: Vec<Mutex<KkRng>>,
}
#[cfg(feature = "std")]
impl KkRngPool {
    /// Builds a pool of `num_generators` generators derived from `seed`.
    ///
    /// Generator `i` is seeded with `seed` followed by `i` as a little-endian
    /// `u64`. This separates the generators from each other while keeping the
    /// whole pool reproducible from `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `num_generators` is zero.
    pub fn new(seed: &[u8], num_generators: usize) -> Self {
        assert!(
            num_generators > 0,
            "KkRngPool requires at least 1 generator"
        );
        let generators = (0..num_generators)
            .map(|i| {
                let mut domain_seed = Vec::with_capacity(seed.len() + 8);
                domain_seed.extend_from_slice(seed);
                domain_seed.extend_from_slice(&(i as u64).to_le_bytes());
                let rng = KkRng::new(&domain_seed);
                // The derived seed is as sensitive as `seed` itself.
                domain_seed.zeroize();
                Mutex::new(rng)
            })
            .collect();
        Self {
            generators,
            next: AtomicUsize::new(0),
        }
    }

    /// Returns the number of generators in the pool.
    pub fn num_generators(&self) -> usize {
        self.generators.len()
    }

    /// Returns `len` bytes from the next generator in round-robin order.
    ///
    /// The cursor increment only needs to be atomic, so `Ordering::Relaxed`
    /// is sufficient. Under concurrent use, which generator a given call
    /// lands on depends on thread interleaving.
    ///
    /// # Panics
    ///
    /// Panics if the selected generator's mutex is poisoned.
    pub fn next_bytes(&self, len: usize) -> Vec<u8> {
        let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.generators.len();
        let mut rng = self.generators[idx]
            .lock()
            .expect("KkRngPool: poisoned mutex");
        rng.next_bytes(len)
    }

    /// Fills `dest` in parallel, one contiguous chunk per generator.
    ///
    /// `dest` is split into chunks of `ceil(dest.len() / num_generators())`
    /// bytes; the last chunk may be shorter, and there may be fewer chunks
    /// than generators. Chunk `i` is filled by generator `i`, so the result is
    /// deterministic for a given pool state. The round-robin cursor used by
    /// [`KkRngPool::next_bytes`] is not advanced. An empty `dest` is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if a generator's mutex is poisoned.
    pub fn fill_bytes_parallel(&self, dest: &mut [u8]) {
        use rayon::prelude::*;
        if dest.is_empty() {
            return;
        }
        // Ceiling division yields at most `generators.len()` chunks, so every
        // chunk index is a valid generator index.
        let chunk_size = dest.len().div_ceil(self.generators.len());
        dest.par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(i, chunk)| {
                let mut rng = self.generators[i]
                    .lock()
                    .expect("KkRngPool: poisoned mutex");
                rng.fill_bytes(chunk);
            });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deterministic_output() {
        let mut rng1 = KkRng::new(b"test-seed-12345");
        let mut rng2 = KkRng::new(b"test-seed-12345");
        assert_eq!(rng1.next_bytes(64), rng2.next_bytes(64));
        assert_eq!(rng1.next_bytes(32), rng2.next_bytes(32));
    }

    #[test]
    fn different_seeds_differ() {
        let mut rng1 = KkRng::new(b"seed-alpha");
        let mut rng2 = KkRng::new(b"seed-beta");
        assert_ne!(rng1.next_bytes(32), rng2.next_bytes(32));
    }

    #[test]
    fn successive_calls_differ() {
        let mut rng = KkRng::new(b"counter-test");
        let a = rng.next_bytes(32);
        let b = rng.next_bytes(32);
        assert_ne!(a, b);
    }

    #[test]
    fn fill_bytes_matches_next_bytes() {
        let mut rng1 = KkRng::new(b"fill-test");
        let mut rng2 = KkRng::new(b"fill-test");
        let expected = rng1.next_bytes(48);
        let mut buf = [0u8; 48];
        rng2.fill_bytes(&mut buf);
        assert_eq!(&buf[..], &expected[..]);
        // Both generators must also have advanced identically.
        assert_eq!(rng1.next_bytes(16), rng2.next_bytes(16));
    }

    #[test]
    fn next_u64_deterministic() {
        let mut rng1 = KkRng::new(b"u64-test");
        let mut rng2 = KkRng::new(b"u64-test");
        assert_eq!(rng1.next_u64(), rng2.next_u64());
    }

    #[test]
    fn next_u64_is_little_endian_of_next_bytes() {
        let mut rng1 = KkRng::new(b"u64-layout");
        let mut rng2 = KkRng::new(b"u64-layout");
        let bytes = rng2.next_bytes(8);
        let expected = u64::from_le_bytes(bytes.as_slice().try_into().unwrap());
        assert_eq!(rng1.next_u64(), expected);
        assert_eq!(rng1.next_bytes(16), rng2.next_bytes(16));
    }

    #[test]
    fn reseed_changes_output() {
        let mut rng1 = KkRng::new(b"reseed-test");
        let mut rng2 = KkRng::new(b"reseed-test");
        rng1.reseed(b"extra-entropy");
        assert_ne!(rng1.next_bytes(32), rng2.next_bytes(32));
    }

    #[test]
    fn reseed_is_deterministic() {
        let mut rng1 = KkRng::new(b"reseed-det");
        let mut rng2 = KkRng::new(b"reseed-det");
        rng1.reseed(b"more-material");
        rng2.reseed(b"more-material");
        assert_eq!(rng1.next_bytes(32), rng2.next_bytes(32));
    }

    #[test]
    fn large_output() {
        let mut rng = KkRng::new(b"large-test");
        let out = rng.next_bytes(1024);
        assert_eq!(out.len(), 1024);
        assert!(out.iter().any(|&b| b != 0));
    }

    #[test]
    fn zero_length_output() {
        let mut rng = KkRng::new(b"zero-test");
        let out = rng.next_bytes(0);
        assert!(out.is_empty());
    }

    #[cfg(feature = "std")]
    mod pool_tests {
        use super::super::{KkRng, KkRngPool};

        /// Builds the standalone generator equivalent to pool member `index`.
        fn member(seed: &[u8], index: u64) -> KkRng {
            let mut domain_seed = seed.to_vec();
            domain_seed.extend_from_slice(&index.to_le_bytes());
            KkRng::new(&domain_seed)
        }

        #[test]
        fn pool_deterministic() {
            let pool1 = KkRngPool::new(b"pool-seed", 4);
            let pool2 = KkRngPool::new(b"pool-seed", 4);
            for _ in 0..8 {
                assert_eq!(pool1.next_bytes(64), pool2.next_bytes(64));
            }
        }

        #[test]
        fn pool_domain_separation() {
            let pool_a = KkRngPool::new(b"sep-seed", 2);
            let pool_b = KkRngPool::new(b"sep-seed", 2);
            let from_gen0 = pool_a.next_bytes(32);
            let _ = pool_b.next_bytes(32);
            let from_gen1 = pool_b.next_bytes(32);
            assert_ne!(from_gen0, from_gen1);
        }

        #[test]
        fn pool_round_robin() {
            let pool = KkRngPool::new(b"rr-seed", 3);
            let a0 = pool.next_bytes(32);
            let _a1 = pool.next_bytes(32);
            let _a2 = pool.next_bytes(32);
            let a3 = pool.next_bytes(32);
            assert_ne!(a0, a3);
        }

        #[test]
        fn pool_round_robin_matches_members() {
            let pool = KkRngPool::new(b"rr-members", 3);
            let mut members: Vec<KkRng> = (0..3).map(|i| member(b"rr-members", i)).collect();
            for call in 0..7 {
                let expected = members[call % 3].next_bytes(32);
                assert_eq!(pool.next_bytes(32), expected);
            }
        }

        #[test]
        fn pool_single_generator() {
            let pool = KkRngPool::new(b"single", 1);
            let mut bare = member(b"single", 0);
            for _ in 0..4 {
                assert_eq!(pool.next_bytes(128), bare.next_bytes(128));
            }
        }

        #[test]
        fn fill_bytes_parallel_deterministic() {
            let pool1 = KkRngPool::new(b"fill-par", 4);
            let pool2 = KkRngPool::new(b"fill-par", 4);
            let mut buf1 = vec![0u8; 8192];
            let mut buf2 = vec![0u8; 8192];
            pool1.fill_bytes_parallel(&mut buf1);
            pool2.fill_bytes_parallel(&mut buf2);
            assert_eq!(buf1, buf2);
        }

        #[test]
        fn fill_bytes_parallel_chunk_layout() {
            // 10 bytes over 4 generators -> chunk size 3 -> chunks 3, 3, 3, 1.
            let pool = KkRngPool::new(b"layout-par", 4);
            let mut buf = [0u8; 10];
            pool.fill_bytes_parallel(&mut buf);
            let mut expected = Vec::new();
            for (i, len) in [3usize, 3, 3, 1].into_iter().enumerate() {
                expected.extend(member(b"layout-par", i as u64).next_bytes(len));
            }
            assert_eq!(&buf[..], &expected[..]);
        }

        #[test]
        fn fill_bytes_parallel_fewer_chunks_than_generators() {
            // 5 bytes over 4 generators -> chunk size 2 -> chunks 2, 2, 1;
            // generator 3 is never touched.
            let pool = KkRngPool::new(b"few-par", 4);
            let mut buf = [0u8; 5];
            pool.fill_bytes_parallel(&mut buf);
            let mut expected = Vec::new();
            for (i, len) in [2usize, 2, 1].into_iter().enumerate() {
                expected.extend(member(b"few-par", i as u64).next_bytes(len));
            }
            assert_eq!(&buf[..], &expected[..]);
        }

        #[test]
        fn fill_bytes_parallel_leaves_cursor() {
            // Chunks of 2 and 2: generator 0 is advanced once, and the
            // round-robin cursor still points at generator 0 afterwards.
            let pool = KkRngPool::new(b"cursor-par", 2);
            let mut buf = [0u8; 4];
            pool.fill_bytes_parallel(&mut buf);
            let mut first = member(b"cursor-par", 0);
            let _ = first.next_bytes(2);
            assert_eq!(pool.next_bytes(16), first.next_bytes(16));
        }

        #[test]
        fn fill_bytes_parallel_nonzero() {
            let pool = KkRngPool::new(b"nz-par", 4);
            let mut buf = vec![0u8; 4096];
            pool.fill_bytes_parallel(&mut buf);
            assert!(buf.iter().any(|&b| b != 0));
        }

        #[test]
        fn fill_bytes_parallel_empty() {
            let pool = KkRngPool::new(b"empty-par", 4);
            let mut buf: Vec<u8> = vec![];
            pool.fill_bytes_parallel(&mut buf);
            assert!(buf.is_empty());
        }

        #[test]
        fn num_generators_correct() {
            for n in [1, 4, 16, 32] {
                let pool = KkRngPool::new(b"count", n);
                assert_eq!(pool.num_generators(), n);
            }
        }

        #[test]
        #[should_panic(expected = "at least 1")]
        fn pool_zero_generators_panics() {
            let _ = KkRngPool::new(b"zero", 0);
        }
    }
}