use cryptography::StreamCipher;
use super::Rng;
/// Number of keystream bytes requested from the cipher on each refill.
///
/// Reads consume 4 (`u32`) or 8 (`u64`) bytes and refill once the unread tail
/// is too short. Keeping this a multiple of 8 means a run of only-`u32` or
/// only-`u64` reads uses each chunk exactly, with no bytes left over.
const CHUNK: usize = 64;
const _: () = assert!(CHUNK.is_multiple_of(8), "CHUNK must be a multiple of 8");
// Zero is also a multiple of 8, but an empty buffer would make every read
// slice out of bounds immediately after a refill.
const _: () = assert!(CHUNK != 0, "CHUNK must be non-zero");
pub struct StreamRng<C: StreamCipher> {
cipher: C,
buf: [u8; CHUNK],
pos: usize,
}
impl<C: StreamCipher> StreamRng<C> {
    /// Wraps `cipher` in a generator whose buffer starts out exhausted, so the
    /// very first read pulls a fresh chunk of keystream.
    pub fn new(cipher: C) -> Self {
        let buf = [0u8; CHUNK];
        // A cursor at the end of the buffer means "nothing left to read".
        let pos = CHUNK;
        Self { cipher, buf, pos }
    }

    /// Overwrites the buffer with the next `CHUNK` bytes from the cipher and
    /// rewinds the read cursor to the start.
    ///
    /// The buffer is cleared before `fill` runs, so stale bytes cannot leak
    /// into the result. NOTE(review): this presumably matters because `fill`
    /// XORs keystream into its argument; confirm against `StreamCipher::fill`.
    fn refill(&mut self) {
        // Slice `fill` (set every byte to 0), not the cipher's `fill`.
        self.buf.fill(0);
        self.cipher.fill(&mut self.buf);
        self.pos = 0;
    }
}
impl<C: StreamCipher> StreamRng<C> {
    /// Hands out the next `N` unread keystream bytes and advances the cursor.
    ///
    /// If fewer than `N` bytes remain in the current chunk, the buffer is
    /// refilled first. The short unread tail is dropped, not stitched onto the
    /// new chunk. Because `CHUNK` is a multiple of 8, that only happens when
    /// 4-byte and 8-byte reads are interleaved.
    fn next_bytes<const N: usize>(&mut self) -> [u8; N] {
        let needs_refill = self.pos + N > CHUNK;
        if needs_refill {
            self.refill();
        }
        let mut out = [0u8; N];
        out.copy_from_slice(&self.buf[self.pos..self.pos + N]);
        self.pos += N;
        out
    }
}

impl<C: StreamCipher> Rng for StreamRng<C> {
    /// Next 4 keystream bytes, interpreted as a little-endian `u32`.
    fn next_u32(&mut self) -> u32 {
        u32::from_le_bytes(self.next_bytes::<4>())
    }

    /// Next 8 keystream bytes, interpreted as a little-endian `u64`.
    fn next_u64(&mut self) -> u64 {
        u64::from_le_bytes(self.next_bytes::<8>())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cryptography::Rabbit;

    /// All-zero key used by RFC 4503 §A.2 Test Vector 1.
    const ZERO_KEY: [u8; 16] = [0u8; 16];
    /// All-zero IV used by RFC 4503 §A.2 Test Vector 1.
    const ZERO_IV: [u8; 8] = [0u8; 8];

    /// Builds a fresh generator over Rabbit keyed with the all-zero key and IV.
    fn zero_rng() -> StreamRng<Rabbit> {
        StreamRng::new(Rabbit::new(&ZERO_KEY, &ZERO_IV))
    }

    /// Returns the first `N` raw keystream bytes of a fresh all-zero Rabbit.
    ///
    /// Uses a single `fill` on a zeroed buffer, mirroring how `refill` primes
    /// the generator. Assumes `fill` continues the keystream across calls, as
    /// a stream cipher should.
    fn raw_keystream<const N: usize>() -> [u8; N] {
        let mut raw = [0u8; N];
        Rabbit::new(&ZERO_KEY, &ZERO_IV).fill(&mut raw);
        raw
    }

    #[test]
    fn stream_rng_kat_rfc4503() {
        let mut rng = zero_rng();
        assert_eq!(
            rng.next_u64(),
            0xd895_54f8_5e27_a7c6,
            "First u64 must match RFC 4503 §A.2 Test Vector 1"
        );
    }

    #[test]
    fn stream_rng_advances() {
        let mut rng = zero_rng();
        let a = rng.next_u64();
        let b = rng.next_u64();
        assert_ne!(a, b, "consecutive Rabbit words should differ");
    }

    #[test]
    fn stream_rng_crosses_chunk_boundary() {
        // CHUNK / 4 u32 reads consume exactly one chunk; the next read refills.
        let mut rng = zero_rng();
        for _ in 0..CHUNK / 4 {
            let _ = rng.next_u32();
        }
        let v = rng.next_u32();

        let mut rng2 = zero_rng();
        for _ in 0..CHUNK / 4 {
            let _ = rng2.next_u32();
        }
        assert_eq!(v, rng2.next_u32(), "post-boundary value must be deterministic");

        // The post-boundary word must be the very next 4 keystream bytes:
        // nothing may be skipped or repeated across the refill.
        let raw = raw_keystream::<{ 2 * CHUNK }>();
        let expected = u32::from_le_bytes(raw[CHUNK..CHUNK + 4].try_into().unwrap());
        assert_eq!(v, expected, "post-boundary value must continue the keystream");
    }

    #[test]
    fn stream_rng_u64_after_odd_u32_skips_tail() {
        // One u32 short of a full chunk leaves 4 unread bytes. A u64 does not
        // fit, so the generator refills and the 4-byte tail is skipped.
        let mut rng = zero_rng();
        for _ in 0..CHUNK / 4 - 1 {
            let _ = rng.next_u32();
        }
        let raw = raw_keystream::<{ 2 * CHUNK }>();
        let expected = u64::from_le_bytes(raw[CHUNK..CHUNK + 8].try_into().unwrap());
        assert_eq!(
            rng.next_u64(),
            expected,
            "u64 read must skip the short tail and start the next chunk"
        );
    }
}