use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha20Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Pseudo-random generator algorithm recorded in an [`RngState`].
///
/// `ChaCha20` is currently the only variant and is also the [`Default`].
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Default, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[non_exhaustive]
pub enum RngAlgorithm {
    /// ChaCha20, realised in this crate as `rand_chacha::ChaCha20Rng`.
    #[default]
    ChaCha20,
}
/// Serializable description of a generator position: which algorithm, which
/// seed, which stream, and how far into that stream.
///
/// Turn it into a live generator with `RngState::into_chacha`, and capture a
/// live generator back with `RngState::snapshot`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[non_exhaustive]
pub struct RngState {
    /// Generator algorithm; the constructors in this file always set `ChaCha20`.
    pub algorithm: RngAlgorithm,
    /// 32-byte seed handed to `ChaCha20Rng::from_seed`.
    pub seed: [u8; 32],
    /// Stream selector handed to `ChaCha20Rng::set_stream`.
    pub stream: u64,
    /// Offset into the stream in the units of `ChaCha20Rng::set_word_pos` /
    /// `get_word_pos` (32-bit words per the rand_chacha docs).
    ///
    /// NOTE: rand_chacha documents this as a 68-bit quantity and ignores the
    /// upper bits, so values at or above 2^68 wrap when applied.
    pub word_pos: u128,
}
impl RngState {
    /// Creates a ChaCha20 state at the very start of stream 0 for `seed`.
    ///
    /// Equivalent to `RngState::from_parts(seed, 0, 0)`.
    #[must_use]
    pub fn from_seed(seed: [u8; 32]) -> Self {
        Self::from_parts(seed, 0, 0)
    }

    /// Creates a ChaCha20 state from an explicit seed, stream and word position.
    ///
    /// All three values are stored verbatim; they are only interpreted when the
    /// state is turned into a generator by [`RngState::into_chacha`].
    #[must_use]
    pub fn from_parts(seed: [u8; 32], stream: u64, word_pos: u128) -> Self {
        Self {
            algorithm: RngAlgorithm::ChaCha20,
            seed,
            stream,
            word_pos,
        }
    }

    /// Derives a child state deterministically from this state's stream and `salt`.
    ///
    /// The child keeps the parent's `algorithm` and `seed`, starts at word
    /// position 0, and uses the stream `parent.stream ^ mix`. Here `mix` is the
    /// first 8 bytes, read little-endian, of
    /// `SHA-256(parent.stream.to_le_bytes() || salt)`.
    ///
    /// Consequences of this construction:
    /// - the parent's `word_pos` has no influence on the child;
    /// - the seed is not hashed, so parents that differ only in `seed` yield
    ///   children with the same `stream` (but their own, different seeds).
    ///
    /// Reproducing a forked stream across builds depends on these exact hashed
    /// bytes and their order; changing them changes every child stream.
    #[must_use]
    pub fn fork(&self, salt: &[u8]) -> Self {
        let mut hasher = Sha256::new();
        hasher.update(self.stream.to_le_bytes());
        hasher.update(salt);
        let digest = hasher.finalize();
        let mut prefix = [0u8; 8];
        prefix.copy_from_slice(&digest[..8]);
        let mix = u64::from_le_bytes(prefix);
        Self {
            algorithm: self.algorithm,
            seed: self.seed,
            stream: self.stream ^ mix,
            word_pos: 0,
        }
    }

    /// Builds a generator positioned at this state.
    ///
    /// The generator is seeded with `seed`, switched to `stream`, and then moved
    /// to `word_pos`. Per the rand_chacha docs, only the low 68 bits of
    /// `word_pos` are significant.
    #[must_use]
    pub fn into_chacha(self) -> ChaCha20Rng {
        // Deliberately exhaustive (no wildcard arm): `#[non_exhaustive]` does not
        // apply inside the defining crate. Adding a new `RngAlgorithm` variant
        // therefore fails to compile here, instead of silently producing
        // ChaCha20 output for a state that claims another algorithm.
        match self.algorithm {
            RngAlgorithm::ChaCha20 => {
                let mut rng = ChaCha20Rng::from_seed(self.seed);
                rng.set_stream(self.stream);
                rng.set_word_pos(self.word_pos);
                rng
            }
        }
    }

    /// Captures the current position of `rng` as a new state.
    ///
    /// `stream` and `word_pos` are read from `rng`; `algorithm` and `seed` are
    /// copied from `parent`. `rng` is assumed to have been built from `parent`
    /// (e.g. via `parent.clone().into_chacha()`), or at least to share its seed.
    /// This is not checked: a mismatched `parent` yields a state that does not
    /// resume `rng`.
    #[must_use]
    pub fn snapshot(rng: &ChaCha20Rng, parent: &RngState) -> Self {
        Self {
            algorithm: parent.algorithm,
            seed: parent.seed,
            stream: rng.get_stream(),
            word_pos: rng.get_word_pos(),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RngState`: construction, forking, conversion to
    //! `ChaCha20Rng`, snapshot/resume, and serde round trips.
    use super::*;
    use rand_chacha::rand_core::RngCore;

    /// A 32-byte seed with every byte set to `b`.
    fn seed_bytes(b: u8) -> [u8; 32] {
        [b; 32]
    }

    /// Draws `n` bytes from `rng` via `fill_bytes`.
    fn draw_n(rng: &mut ChaCha20Rng, n: usize) -> Vec<u8> {
        let mut buf = vec![0u8; n];
        rng.fill_bytes(&mut buf);
        buf
    }

    #[test]
    fn from_seed_defaults_stream_and_word_pos_to_zero() {
        let s = RngState::from_seed(seed_bytes(0x42));
        assert_eq!(s.stream, 0);
        assert_eq!(s.word_pos, 0);
        assert_eq!(s.algorithm, RngAlgorithm::ChaCha20);
        assert_eq!(s.seed, seed_bytes(0x42));
    }

    #[test]
    fn same_state_produces_identical_bytes() {
        let s = RngState::from_parts(seed_bytes(0x42), 7, 0);
        let mut a = s.clone().into_chacha();
        let mut b = s.into_chacha();
        assert_eq!(draw_n(&mut a, 1024), draw_n(&mut b, 1024));
    }

    #[test]
    fn fork_with_same_salt_is_deterministic() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"block-0");
        let c2 = parent.fork(b"block-0");
        assert_eq!(c1, c2);
    }

    #[test]
    fn fork_with_distinct_salts_produces_distinct_streams() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"block-0");
        let c2 = parent.fork(b"block-1");
        assert_ne!(c1.stream, c2.stream);
        let mut r1 = c1.into_chacha();
        let mut r2 = c2.into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }

    #[test]
    fn fork_resets_child_word_pos() {
        let parent = RngState::from_parts(seed_bytes(0x42), 100, 999);
        let child = parent.fork(b"any");
        assert_eq!(child.word_pos, 0);
        assert_eq!(child.seed, parent.seed);
    }

    #[test]
    fn fork_xors_with_mix_so_distinct_parents_produce_distinct_children_under_same_salt() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 1, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 2, 0);
        let c1 = p1.fork(b"same-salt");
        let c2 = p2.fork(b"same-salt");
        assert_ne!(c1.stream, c2.stream);
    }

    #[test]
    fn snapshot_round_trips_word_pos_and_stream() {
        // Nonzero stream so that the stream assertion below is meaningful.
        let s = RngState::from_parts(seed_bytes(0x42), 3, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 8192);
        let snap = RngState::snapshot(&r, &s);
        assert_eq!(snap.stream, r.get_stream());
        assert_eq!(snap.stream, 3);
        assert_eq!(snap.word_pos, r.get_word_pos());
        let mut original_continued = r;
        let mut resumed = snap.into_chacha();
        assert_eq!(
            draw_n(&mut original_continued, 1024),
            draw_n(&mut resumed, 1024)
        );
    }

    #[test]
    fn rngstate_serde_round_trip() {
        let s = RngState::from_parts(seed_bytes(0x42), 12345, 67890);
        let json = serde_json::to_string(&s).expect("serialize");
        let back: RngState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, s);
    }

    #[test]
    fn rng_algorithm_default_is_chacha20() {
        assert_eq!(RngAlgorithm::default(), RngAlgorithm::ChaCha20);
    }

    #[test]
    fn rng_algorithm_serde_round_trip() {
        let json = serde_json::to_string(&RngAlgorithm::ChaCha20).expect("serialize");
        let back: RngAlgorithm = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, RngAlgorithm::ChaCha20);
    }

    #[test]
    fn from_seed_equals_from_parts_with_zero_stream_and_zero_word_pos() {
        let a = RngState::from_seed(seed_bytes(0x99));
        let b = RngState::from_parts(seed_bytes(0x99), 0, 0);
        assert_eq!(a, b);
    }

    #[test]
    fn from_parts_preserves_all_four_fields() {
        let s = RngState::from_parts([0xab; 32], 9_999_999, 12_345_678_901_234_567_890_u128);
        assert_eq!(s.algorithm, RngAlgorithm::ChaCha20);
        assert_eq!(s.seed, [0xab; 32]);
        assert_eq!(s.stream, 9_999_999);
        assert_eq!(s.word_pos, 12_345_678_901_234_567_890_u128);
    }

    #[test]
    fn into_chacha_initializes_stream_and_word_pos_correctly() {
        let s = RngState::from_parts(seed_bytes(0x42), 1234, 5678);
        let rng = s.clone().into_chacha();
        assert_eq!(rng.get_stream(), s.stream);
        assert_eq!(rng.get_word_pos(), s.word_pos);
    }

    #[test]
    fn into_chacha_clones_so_state_is_unchanged_by_draws() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 4096);
        // Compare the whole state, not just `word_pos`.
        assert_eq!(s, RngState::from_parts(seed_bytes(0x42), 0, 0));
    }

    #[test]
    fn word_pos_is_measured_in_32_bit_words() {
        // 256 bytes = 64 words, so a state at word_pos 64 must continue exactly
        // where a fresh generator is after drawing 256 bytes.
        let mut from_start = RngState::from_parts(seed_bytes(0x42), 7, 0).into_chacha();
        let _ = draw_n(&mut from_start, 256);
        assert_eq!(from_start.get_word_pos(), 64);
        let mut seeked = RngState::from_parts(seed_bytes(0x42), 7, 64).into_chacha();
        assert_eq!(draw_n(&mut from_start, 256), draw_n(&mut seeked, 256));
    }

    #[test]
    fn fork_is_pure_under_parent_stream_and_seed() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 7, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 7, 999);
        assert_eq!(p1.fork(b"some-salt"), p2.fork(b"some-salt"));
    }

    #[test]
    fn fork_keys_off_parent_stream_not_just_seed() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let p2 = RngState::from_parts(seed_bytes(0x42), 1, 0);
        assert_ne!(p1.fork(b"x"), p2.fork(b"x"));
    }

    #[test]
    fn fork_keys_off_seed_not_just_stream() {
        let p1 = RngState::from_parts(seed_bytes(0x42), 5, 0);
        let p2 = RngState::from_parts(seed_bytes(0xbb), 5, 0);
        let c1 = p1.fork(b"x");
        let c2 = p2.fork(b"x");
        assert_eq!(c1.seed, p1.seed);
        assert_eq!(c2.seed, p2.seed);
        assert_eq!(c1.stream, c2.stream);
        let mut r1 = c1.into_chacha();
        let mut r2 = c2.into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }

    #[test]
    fn fork_with_empty_salt_is_deterministic() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"");
        let c2 = parent.fork(b"");
        assert_eq!(c1, c2);
    }

    #[test]
    fn fork_with_empty_salt_differs_from_fork_with_nonempty_salt() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c_empty = parent.fork(b"");
        let c_one = parent.fork(b"x");
        assert_ne!(c_empty.stream, c_one.stream);
    }

    #[test]
    fn fork_is_not_commutative_under_double_fork() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let g_ab = parent.fork(b"a").fork(b"b");
        let g_ba = parent.fork(b"b").fork(b"a");
        assert_ne!(g_ab.stream, g_ba.stream);
    }

    #[test]
    fn fork_chains_of_same_salt_strictly_descend_to_distinct_streams() {
        let p = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let g1 = p.fork(b"s");
        let g2 = g1.fork(b"s");
        let g3 = g2.fork(b"s");
        assert_ne!(p.stream, g1.stream);
        assert_ne!(g1.stream, g2.stream);
        assert_ne!(g2.stream, g3.stream);
        assert_ne!(p.stream, g2.stream);
        assert_ne!(p.stream, g3.stream);
        assert_ne!(g1.stream, g3.stream);
    }

    #[test]
    fn fork_long_salt_is_handled() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let long_salt = vec![0xab; 4096];
        let c1 = parent.fork(&long_salt);
        let c2 = parent.fork(&long_salt);
        assert_eq!(c1, c2);
    }

    #[test]
    fn fork_one_byte_salt_difference_changes_stream() {
        let parent = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let c1 = parent.fork(b"abcdefgh");
        let c2 = parent.fork(b"abcdefgi");
        assert_ne!(c1.stream, c2.stream);
    }

    #[test]
    fn fork_child_snapshot_resumes_on_child_stream() {
        let child = RngState::from_parts(seed_bytes(0x42), 0, 0).fork(b"child");
        let mut r = child.clone().into_chacha();
        let _ = draw_n(&mut r, 1000);
        let snap = RngState::snapshot(&r, &child);
        assert_eq!(snap.stream, child.stream);
        let mut resumed = snap.into_chacha();
        assert_eq!(draw_n(&mut r, 512), draw_n(&mut resumed, 512));
    }

    #[test]
    fn snapshot_at_word_pos_zero_equals_a_fresh_state() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let r = s.clone().into_chacha();
        let snap = RngState::snapshot(&r, &s);
        assert_eq!(snap, s);
    }

    #[test]
    fn snapshot_records_post_draw_word_pos() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 4096);
        let snap = RngState::snapshot(&r, &s);
        assert!(snap.word_pos > 0);
        assert_eq!(snap.word_pos, r.get_word_pos());
    }

    #[test]
    fn snapshot_can_be_round_tripped_through_serde() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 2048);
        let snap = RngState::snapshot(&r, &s);
        let json = serde_json::to_string(&snap).expect("serialize");
        let back: RngState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, snap);
    }

    #[test]
    fn snapshot_matches_set_word_pos_resumption() {
        let s = RngState::from_parts(seed_bytes(0x42), 0, 0);
        let mut r = s.clone().into_chacha();
        let _ = draw_n(&mut r, 12_345);
        let snap = RngState::snapshot(&r, &s);
        let mut r_resumed = snap.into_chacha();
        let cont = draw_n(&mut r, 4096);
        let resumed = draw_n(&mut r_resumed, 4096);
        assert_eq!(cont, resumed);
    }

    #[test]
    fn distinct_streams_give_independent_byte_sequences() {
        let seed = seed_bytes(0x42);
        let mut draws: Vec<Vec<u8>> = Vec::new();
        for stream in 0..4u64 {
            let mut r = RngState::from_parts(seed, stream, 0).into_chacha();
            draws.push(draw_n(&mut r, 1024));
        }
        for i in 0..4 {
            for j in (i + 1)..4 {
                assert_ne!(draws[i], draws[j], "stream {i} == stream {j}");
            }
        }
    }

    #[test]
    fn streams_with_distinct_seeds_are_independent() {
        let mut r1 = RngState::from_parts(seed_bytes(0xaa), 0, 0).into_chacha();
        let mut r2 = RngState::from_parts(seed_bytes(0xbb), 0, 0).into_chacha();
        assert_ne!(draw_n(&mut r1, 1024), draw_n(&mut r2, 1024));
    }
}