use zeroize::Zeroize;
use crate::shake::{i_shake256_extract, InnerShake256Context};
/// Buffered pseudo-random generator state.
///
/// Output bytes are produced in 512-byte batches by `prng_refill` and then
/// handed out from `buf` by the `prng_get_*` functions.
pub struct Prng {
/// Batch of generated output bytes; rewritten in full by `prng_refill`.
pub(crate) buf: [u8; 512],
/// Index into `buf` of the next byte to hand out; reset to 0 on every refill.
pub(crate) ptr: usize,
/// Generator state. Bytes `0..48` are read as twelve little-endian `u32`
/// words and bytes `48..56` as a little-endian `u64` block counter; the
/// code visible in this file never touches bytes `56..256`.
pub(crate) state: [u8; 256],
}
impl Prng {
pub fn new() -> Self {
Prng {
buf: [0u8; 512],
ptr: 0,
state: [0u8; 256],
}
}
}
impl Default for Prng {
fn default() -> Self {
Self::new()
}
}
impl Drop for Prng {
    /// Wipes the generator when it goes out of scope so that neither the
    /// internal state nor any buffered output bytes linger in memory.
    fn drop(&mut self) {
        self.ptr = 0;
        self.state.zeroize();
        self.buf.zeroize();
    }
}
/// Fallback seed source used when the `getrandom` feature is disabled.
///
/// No entropy source is compiled in, so `seed` is never written to.
///
/// Returns `true` only for an empty `seed` (a zero-byte request is trivially
/// satisfied, matching the behavior of the `getrandom`-backed variant) and
/// `false` for every non-empty request.
#[cfg(not(feature = "getrandom"))]
pub fn get_seed(seed: &mut [u8]) -> bool {
    // Nothing can be produced here; only a request for zero bytes succeeds.
    seed.is_empty()
}
/// Fills `seed` with bytes obtained from the `getrandom` crate.
///
/// Returns `true` when `seed` is empty (nothing to do) or when
/// `getrandom::getrandom` reports success, and `false` when it returns an
/// error.
#[cfg(feature = "getrandom")]
pub fn get_seed(seed: &mut [u8]) -> bool {
    // Short-circuit: an empty request never reaches the entropy source.
    seed.is_empty() || getrandom::getrandom(seed).is_ok()
}
/// Seeds `p` from the SHAKE256 context `src` and fills its output buffer.
///
/// 56 bytes are extracted from `src` and stored in `p.state[..56]` (the
/// 48 bytes of key words followed by the 8-byte block counter that
/// `prng_refill` reads). The temporary copy of the seed is wiped before the
/// first refill is performed.
pub fn prng_init(p: &mut Prng, src: &mut InnerShake256Context) {
    let mut seed = [0u8; 56];
    i_shake256_extract(src, &mut seed);

    let (seeded, _) = p.state.split_at_mut(seed.len());
    seeded.copy_from_slice(&seed);

    // Do not leave a second copy of the seed on the stack.
    seed.zeroize();

    prng_refill(p);
}
/// Constant words placed in the first four slots of every block state in
/// `prng_refill`: the ASCII string "expand 32-byte k" read as four
/// little-endian `u32` values.
const CW: [u32; 4] = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];
/// Regenerates all 512 bytes of `p.buf` and resets the read position.
///
/// Eight 64-byte blocks are computed with a ChaCha20-style block function
/// (ten double rounds followed by the feed-forward addition). Each block
/// starts from the constants `CW`, the twelve little-endian words held in
/// `p.state[..48]`, and a 64-bit block counter that is XORed into the last
/// two words. The counter is read from `p.state[48..56]`, is incremented by
/// one per block, and is written back advanced by 8.
///
/// The output is interleaved: word `v` of block `u` is stored little-endian
/// at byte offset `4 * u + 32 * v` of `p.buf`.
pub fn prng_refill(p: &mut Prng) {
    // Current value of the 64-bit little-endian block counter.
    let mut counter_bytes = [0u8; 8];
    counter_bytes.copy_from_slice(&p.state[48..56]);
    let cc = u64::from_le_bytes(counter_bytes);

    // The twelve state words shared by all eight blocks of this refill.
    let mut init_state = [0u32; 12];
    for (word, chunk) in init_state.iter_mut().zip(p.state[..48].chunks_exact(4)) {
        *word = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
    }

    for u in 0..8usize {
        let mut state = [0u32; 16];
        state[..4].copy_from_slice(&CW);
        state[4..].copy_from_slice(&init_state);

        // Mix the per-block counter into the last two words; the `as u32`
        // casts deliberately select its low and high halves.
        let counter = cc.wrapping_add(u as u64);
        state[14] ^= counter as u32;
        state[15] ^= (counter >> 32) as u32;

        // Snapshot of the input state for the feed-forward addition.
        let s0 = state;

        for _ in 0..10 {
            // Column round.
            quarter_round(&mut state, 0, 4, 8, 12);
            quarter_round(&mut state, 1, 5, 9, 13);
            quarter_round(&mut state, 2, 6, 10, 14);
            quarter_round(&mut state, 3, 7, 11, 15);
            // Diagonal round.
            quarter_round(&mut state, 0, 5, 10, 15);
            quarter_round(&mut state, 1, 6, 11, 12);
            quarter_round(&mut state, 2, 7, 8, 13);
            quarter_round(&mut state, 3, 4, 9, 14);
        }

        for (word, start) in state.iter_mut().zip(s0.iter()) {
            *word = word.wrapping_add(*start);
        }

        // Interleaved store. With u < 8 and v < 16 the largest offset is
        // 28 + 480 = 508, so every 4-byte window lies inside the 512-byte
        // buffer and ordinary bounds-checked slicing is sufficient.
        for (v, word) in state.iter().enumerate() {
            let off = (u << 2) + (v << 5);
            p.buf[off..off + 4].copy_from_slice(&word.to_le_bytes());
        }
    }

    // Eight blocks were produced: advance the stored counter accordingly.
    p.state[48..56].copy_from_slice(&cc.wrapping_add(8).to_le_bytes());
    p.ptr = 0;
}
/// Applies one ChaCha quarter round in place to the words of `state` at
/// indices `a`, `b`, `c` and `d`.
///
/// The quarter round is four add / xor / rotate steps with rotation amounts
/// 16, 12, 8 and 7. All additions wrap modulo 2^32.
///
/// # Panics
///
/// Panics if any index is 16 or greater.
#[inline(always)]
fn quarter_round(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
    /// `words[sum] += words[addend]`, then
    /// `words[mixed] = (words[mixed] ^ words[sum]) <<< rot`.
    #[inline(always)]
    fn half_step(words: &mut [u32; 16], sum: usize, addend: usize, mixed: usize, rot: u32) {
        words[sum] = words[sum].wrapping_add(words[addend]);
        words[mixed] = (words[mixed] ^ words[sum]).rotate_left(rot);
    }

    half_step(state, a, b, d, 16);
    half_step(state, c, d, b, 12);
    half_step(state, a, b, d, 8);
    half_step(state, c, d, b, 7);
}
/// Returns the next eight buffered bytes as a little-endian `u64`.
///
/// When the read position is at 503 or beyond (nine or fewer unread bytes
/// left), the buffer is refilled first and the leftover bytes are discarded,
/// so the value is then taken from the start of the fresh buffer.
#[inline]
pub fn prng_get_u64(p: &mut Prng) -> u64 {
    if p.ptr >= 512 - 9 {
        // `prng_refill` resets `p.ptr` to 0, so the read below starts at
        // the beginning of the new batch.
        prng_refill(p);
    }

    let start = p.ptr;
    let end = start + 8;
    p.ptr = end;

    let mut word = [0u8; 8];
    word.copy_from_slice(&p.buf[start..end]);
    u64::from_le_bytes(word)
}
/// Returns the next buffered byte, widened to `u32`.
///
/// The buffer is refilled as soon as its last byte has been consumed.
///
/// # Panics
///
/// Panics if `p.ptr` is not a valid index into the 512-byte buffer on entry.
#[inline]
pub fn prng_get_u8(p: &mut Prng) -> u32 {
    let byte = u32::from(p.buf[p.ptr]);
    p.ptr += 1;
    if p.ptr == 512 {
        prng_refill(p);
    }
    byte
}
/// Fills `dst` entirely with bytes taken from the generator.
///
/// Bytes are copied out of the internal buffer in runs; whenever the buffer
/// is exhausted it is refilled before copying continues. An empty `dst`
/// leaves the generator untouched.
pub fn prng_get_bytes(p: &mut Prng, dst: &mut [u8]) {
    let mut filled = 0;
    while filled < dst.len() {
        // Copy as much as the buffer still holds, capped at what is needed.
        let run = (512 - p.ptr).min(dst.len() - filled);
        let next = p.ptr + run;
        dst[filled..filled + run].copy_from_slice(&p.buf[p.ptr..next]);
        filled += run;
        p.ptr = next;
        if p.ptr == 512 {
            prng_refill(p);
        }
    }
}