use aes::{
cipher::{
generic_array::{
typenum::{U16, U8},
GenericArray,
},
BlockEncrypt, KeyInit,
},
Aes128,
};
use ark_std::rand::{CryptoRng, Error, RngCore, SeedableRng};
use byteorder::{ByteOrder, LittleEndian};
use core::{mem, slice};
/// Size in bytes of one AES block.
const AES_BLK_SIZE: usize = 16;
/// Number of AES blocks generated (encrypted) per batch.
const PIPELINES_USIZE: usize = 8;
/// `PIPELINES_USIZE` as a counter increment; derived so the two can never disagree.
const PIPELINES_U128: u128 = PIPELINES_USIZE as u128;
/// Size in bytes of one batch of keystream.
const STATE_SIZE: usize = PIPELINES_USIZE * AES_BLK_SIZE;
/// Seed length in bytes; the seed is used directly as the AES-128 key.
pub const SEED_SIZE: usize = AES_BLK_SIZE;
/// Seed type accepted by `AesRng::from_seed`.
pub type RngSeed = [u8; SEED_SIZE];
/// One AES-128 block (16 bytes).
type Block128 = GenericArray<u8, U16>;
/// A batch of eight AES blocks; `AesRng` encrypts all eight with one `encrypt_blocks` call.
type Block128x8 = GenericArray<Block128, U8>;
/// Keystream buffer used by `AesRng`: eight 16-byte blocks plus the
/// bookkeeping needed to hand them out byte by byte.
pub struct AesRngState {
    /// Current batch. Holds plaintext counters until `AesRng` encrypts them in place.
    blocks: Block128x8,
    /// Counter value that the first block of the next batch will receive.
    next_index: u128,
    /// How many bytes of `blocks` have already been handed out.
    used_bytes: usize,
}

impl Default for AesRngState {
    /// Same as the initial state: plaintext counters `0..8`, nothing consumed.
    fn default() -> Self {
        Self::init()
    }
}
/// Builds the first batch of plaintext blocks.
///
/// For each `i` in `0..PIPELINES_USIZE`, block `i` holds the counter `i`
/// encoded as a little-endian `u128`.
fn create_init_state() -> Block128x8 {
    Block128x8::from_exact_iter((0..PIPELINES_USIZE).map(|i| {
        let mut counter_bytes = [0_u8; AES_BLK_SIZE];
        LittleEndian::write_u128(&mut counter_bytes, i as u128);
        GenericArray::clone_from_slice(&counter_bytes)
    }))
    .unwrap()
}
impl AesRngState {
fn as_mut_bytes(&mut self) -> &mut [u8] {
#[allow(unsafe_code)]
unsafe {
slice::from_raw_parts_mut(&mut self.blocks as *mut Block128x8 as *mut u8, STATE_SIZE)
}
}
fn init() -> Self {
AesRngState {
blocks: create_init_state(),
next_index: PIPELINES_U128,
used_bytes: 0,
}
}
fn next(&mut self) {
let counter = self.next_index;
let blocks_bytes = self.as_mut_bytes();
for i in 0..PIPELINES_USIZE {
LittleEndian::write_u128(
&mut blocks_bytes[i * AES_BLK_SIZE..(i + 1) * AES_BLK_SIZE],
counter + i as u128,
);
}
self.next_index += PIPELINES_U128;
self.used_bytes = 0;
}
}
/// Pseudo-random generator that encrypts a 128-bit counter with AES-128,
/// keyed by the seed, and hands out the encrypted blocks as output bytes.
pub struct AesRng {
    /// Current batch of (encrypted) counter blocks and the read position.
    state: AesRngState,
    /// AES-128 keyed with the generator's seed.
    cipher: Aes128,
}

impl SeedableRng for AesRng {
    type Seed = RngSeed;

    /// Keys AES-128 with `seed` and immediately encrypts the initial counter
    /// batch, so the first output bytes are ready to read.
    #[inline]
    fn from_seed(seed: Self::Seed) -> Self {
        let cipher = Aes128::new(GenericArray::from_slice(&seed));
        let mut rng = Self {
            state: AesRngState::default(),
            cipher,
        };
        rng.init();
        rng
    }
}
impl AesRng {
    /// Encrypts the current batch of counter blocks in place with the keyed cipher.
    fn init(&mut self) {
        let Self { state, cipher } = self;
        cipher.encrypt_blocks(&mut state.blocks);
    }

    /// Advances to the next batch.
    ///
    /// Writes the next counters, resets the read position, then encrypts
    /// the fresh counters.
    fn next(&mut self) {
        self.state.next();
        self.init();
    }
}
impl RngCore for AesRng {
    /// Returns the next four unread output bytes as a little-endian `u32`.
    ///
    /// A value is never split across two batches: when fewer than four unread
    /// bytes remain, a fresh batch is generated first. When exactly four remain
    /// they are used, so no output is skipped at the batch boundary.
    fn next_u32(&mut self) -> u32 {
        const SIZE: usize = mem::size_of::<u32>();
        // `used_bytes <= STATE_SIZE` always holds, so this addition cannot overflow.
        if self.state.used_bytes + SIZE > STATE_SIZE {
            self.next();
        }
        let start = self.state.used_bytes;
        self.state.used_bytes += SIZE;
        LittleEndian::read_u32(&self.state.as_mut_bytes()[start..start + SIZE])
    }

    /// Returns the next eight unread output bytes as a little-endian `u64`.
    ///
    /// Uses the same batch-boundary rule as [`next_u32`](Self::next_u32).
    fn next_u64(&mut self) -> u64 {
        const SIZE: usize = mem::size_of::<u64>();
        if self.state.used_bytes + SIZE > STATE_SIZE {
            self.next();
        }
        let start = self.state.used_bytes;
        self.state.used_bytes += SIZE;
        LittleEndian::read_u64(&self.state.as_mut_bytes()[start..start + SIZE])
    }

    /// Fills `dest` with the next `dest.len()` output bytes, continuing
    /// seamlessly across batch boundaries.
    ///
    /// A new batch is generated only when more bytes are still needed and
    /// the current batch is exhausted.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        let mut written = 0;
        while written < dest.len() {
            if self.state.used_bytes == STATE_SIZE {
                self.next();
            }
            let start = self.state.used_bytes;
            let take = (STATE_SIZE - start).min(dest.len() - written);
            dest[written..written + take]
                .copy_from_slice(&self.state.as_mut_bytes()[start..start + take]);
            self.state.used_bytes += take;
            written += take;
        }
    }

    /// Infallible: delegates to [`fill_bytes`](Self::fill_bytes) and always returns `Ok`.
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}

/// Output is AES-128 encryption of a counter under the secret seed, so the
/// generator is marked as a cryptographic RNG.
impl CryptoRng for AesRng {}
#[cfg(test)]
mod tests {
    use super::*;

    const INFALLIBLE: &str = "AesRng::try_fill_bytes never fails";

    /// The first batch must equal the AES-128 encryptions of the little-endian
    /// counters `0..8` under the same key.
    #[test]
    fn test_prng_match_aes() {
        let seed = [0u8; SEED_SIZE];
        let key: Block128 = GenericArray::clone_from_slice(&seed);
        let cipher = Aes128::new(&key);
        // Block `i` holds counter `i` little-endian: byte 0 = i, all other bytes zero.
        let mut blocks = Block128x8::from_exact_iter((0..PIPELINES_USIZE).map(|i| {
            let mut block = Block128::default();
            block[0] = i as u8;
            block
        }))
        .unwrap();
        cipher.encrypt_blocks(&mut blocks);
        let mut rng = AesRng::from_seed(seed);
        let mut out = [0u8; STATE_SIZE];
        rng.try_fill_bytes(&mut out).expect(INFALLIBLE);
        assert_eq!(rng.state.blocks, blocks);
        // The bytes handed out are exactly the encrypted first batch.
        assert_eq!(&out[..], rng.state.as_mut_bytes());
    }

    #[test]
    fn test_prng_vector1() {
        let seed = [0u8; SEED_SIZE];
        let mut rng = AesRng::from_seed(seed);
        let mut out = [0u8; 16];
        for _ in 0..129 {
            rng.try_fill_bytes(&mut out).expect(INFALLIBLE);
        }
        let expected: [u8; 16] = [
            58, 215, 142, 114, 108, 30, 192, 43, 126, 191, 233, 43, 35, 217, 236, 52,
        ];
        assert_eq!(expected, out);
    }

    #[test]
    fn test_prng_vector2() {
        let seed = [0u8; SEED_SIZE];
        let mut rng = AesRng::from_seed(seed);
        let mut out = [0u8; 16];
        for _ in 0..17 {
            rng.try_fill_bytes(&mut out).expect(INFALLIBLE);
        }
        let expected: [u8; 16] = [
            245, 86, 155, 58, 182, 166, 209, 30, 253, 225, 191, 10, 100, 198, 133, 74,
        ];
        assert_eq!(expected, out);
    }

    #[test]
    fn test_prng_used_bytes() {
        let seed = [1u8; SEED_SIZE];
        let mut rng: AesRng = AesRng::from_seed(seed);
        let mut out = [0u8; 16 * 8];
        rng.try_fill_bytes(&mut out).expect(INFALLIBLE);
        assert_eq!(rng.state.used_bytes, 16 * 8);
        let _ = rng.next_u32();
        assert_eq!(rng.state.used_bytes, 4);
    }

    /// Same seed gives the same stream; a different seed gives a different one.
    #[test]
    fn test_seeded_prng() {
        let seed = [2u8; SEED_SIZE];
        let mut rng_a: AesRng = AesRng::from_seed(seed);
        let mut rng_b: AesRng = AesRng::from_seed(seed);
        // 64 * 12 bytes spans several batches.
        for _ in 0..64 {
            assert_eq!(rng_a.next_u32(), rng_b.next_u32());
            assert_eq!(rng_a.next_u64(), rng_b.next_u64());
        }
        let mut rng_other = AesRng::from_seed([3u8; SEED_SIZE]);
        let mut rng_same = AesRng::from_seed(seed);
        assert_ne!(rng_other.next_u64(), rng_same.next_u64());
    }

    /// Filling in several pieces must yield the same bytes as one large fill.
    #[test]
    fn test_fill_bytes_split_matches_single() {
        let seed = [5u8; SEED_SIZE];
        let mut whole_rng = AesRng::from_seed(seed);
        let mut whole = [0u8; 300];
        whole_rng.try_fill_bytes(&mut whole).expect(INFALLIBLE);

        let mut split_rng = AesRng::from_seed(seed);
        let mut split = [0u8; 300];
        let (head, rest) = split.split_at_mut(1);
        let (middle, tail) = rest.split_at_mut(127);
        split_rng.try_fill_bytes(head).expect(INFALLIBLE);
        split_rng.try_fill_bytes(middle).expect(INFALLIBLE);
        split_rng.try_fill_bytes(tail).expect(INFALLIBLE);
        split_rng.try_fill_bytes(&mut []).expect(INFALLIBLE);

        assert_eq!(&whole[..], &split[..]);
    }

    /// On a fresh generator `next_u32` returns the first four output bytes.
    #[test]
    fn test_next_u32_matches_fill_bytes_prefix() {
        let seed = [4u8; SEED_SIZE];
        let mut words = AesRng::from_seed(seed);
        let mut bytes = AesRng::from_seed(seed);
        let mut buf = [0u8; 4];
        bytes.try_fill_bytes(&mut buf).expect(INFALLIBLE);
        assert_eq!(words.next_u32(), u32::from_le_bytes(buf));
    }
}