use std::mem;
use aes::{
Aes128,
cipher::{BlockCipherEncrypt, KeyInit},
};
use rand::rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng};
use rand::{CryptoRng, Rng, RngCore, SeedableRng};
use crate::block::Block;
/// Random number generator that encrypts an incrementing 128-bit counter with
/// AES-128, buffered through [`BlockRng`]. Implements [`CryptoRng`].
///
/// Cloning copies the complete state (key, counter and buffer), so a clone
/// reproduces the original's future output. Use [`AesRng::fork`] to obtain a
/// generator keyed from this one's output instead.
#[derive(Debug, Clone)]
pub(crate) struct AesRng(BlockRng<AesRngCore>);
impl RngCore for AesRng {
#[inline]
fn next_u32(&mut self) -> u32 {
self.0.next_u32()
}
#[inline]
fn next_u64(&mut self) -> u64 {
self.0.next_u64()
}
#[inline]
fn fill_bytes(&mut self, dest: &mut [u8]) {
const BLOCK_SIZE: usize = mem::size_of::<aes::Block>();
let whole_blocks = dest.len() / BLOCK_SIZE;
let (block_bytes, rest_bytes) = dest.split_at_mut(whole_blocks * BLOCK_SIZE);
let blocks = bytemuck::cast_slice_mut::<_, aes::Block>(block_bytes);
for chunk in blocks.chunks_mut(AES_PAR_BLOCKS) {
for block in chunk.iter_mut() {
*block = aes::cipher::Array(self.0.core.state.to_le_bytes());
self.0.core.state += 1;
}
self.0.core.aes.encrypt_blocks(chunk);
}
self.0.fill_bytes(rest_bytes)
}
}
impl SeedableRng for AesRng {
type Seed = Block;
#[inline]
fn from_seed(seed: Self::Seed) -> Self {
AesRng(BlockRng::<AesRngCore>::from_seed(seed))
}
}
impl CryptoRng for AesRng {}
impl AesRng {
    /// Creates a generator keyed with a freshly drawn `rand::random` [`Block`].
    #[inline]
    pub(crate) fn new() -> Self {
        Self::from_seed(rand::random())
    }

    /// Derives a new generator keyed with a [`Block`] drawn from `self`.
    ///
    /// This consumes output from `self`, advancing its state.
    #[inline]
    pub(crate) fn fork(&mut self) -> Self {
        let child_seed: Block = self.random();
        Self::from_seed(child_seed)
    }
}
impl Default for AesRng {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Block-generating core behind [`AesRng`]: AES-128 applied to a
/// little-endian-encoded 128-bit counter.
#[derive(Clone)]
pub(crate) struct AesRngCore {
    /// Next counter value to encrypt; incremented once per produced block.
    state: u128,
    /// AES-128 cipher keyed with the seed given to `from_seed`.
    aes: Aes128,
}
impl std::fmt::Debug for AesRngCore {
    /// Prints only `AesRngCore {}`, so neither the cipher state nor the counter
    /// appears in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("AesRngCore {}")
    }
}
impl BlockRngCore for AesRngCore {
    type Item = u32;
    type Results = hidden::ParBlockWrapper;

    /// Refills `results` with `AES_PAR_BLOCKS` freshly encrypted counter blocks.
    ///
    /// The word buffer is reinterpreted as AES blocks. Each block receives the
    /// current counter (little-endian) before the counter is incremented, and
    /// the whole batch is then encrypted in place in one call.
    #[inline]
    fn generate(&mut self, results: &mut Self::Results) {
        let out: &mut [aes::Block] = bytemuck::cast_slice_mut(results.as_mut());
        for block in out.iter_mut() {
            *block = aes::cipher::Array(self.state.to_le_bytes());
            self.state += 1;
        }
        self.aes.encrypt_blocks(out);
    }
}
mod hidden {
    use crate::crypto::AES_PAR_BLOCKS;

    /// Words per batch: each 16-byte AES block is viewed as four `u32` words.
    const WORDS: usize = AES_PAR_BLOCKS * 4;

    /// Fixed-size `u32` output buffer holding one batch of `AES_PAR_BLOCKS`
    /// encrypted blocks.
    ///
    /// `Default` is implemented by hand because the array can exceed the
    /// 32-element limit of the standard library's array `Default` impls.
    #[derive(Clone, Copy)]
    pub(crate) struct ParBlockWrapper([u32; WORDS]);

    impl Default for ParBlockWrapper {
        fn default() -> Self {
            ParBlockWrapper([0u32; WORDS])
        }
    }

    impl AsRef<[u32]> for ParBlockWrapper {
        fn as_ref(&self) -> &[u32] {
            self.0.as_slice()
        }
    }

    impl AsMut<[u32]> for ParBlockWrapper {
        fn as_mut(&mut self) -> &mut [u32] {
            self.0.as_mut_slice()
        }
    }
}
impl SeedableRng for AesRngCore {
    type Seed = Block;

    /// Keys AES-128 with `seed` and starts the counter at zero.
    #[inline]
    fn from_seed(seed: Self::Seed) -> Self {
        Self {
            state: 0,
            aes: Aes128::new(&seed.into()),
        }
    }
}

/// The block core is marked cryptographic, matching `CryptoRng` on `AesRng`.
impl CryptoBlockRng for AesRngCore {}
impl From<AesRngCore> for AesRng {
#[inline]
fn from(core: AesRngCore) -> Self {
AesRng(BlockRng::new(core))
}
}
/// Number of AES blocks encrypted per batch by `AesRngCore::generate` and
/// `AesRng::fill_bytes`.
///
/// The value is picked per target architecture: 9 on x86/x86_64, 21 on
/// aarch64 and 4 everywhere else. When built with the `aes` target feature,
/// the `aes_par_blocks_tests` module checks it against the `aes` backend's
/// `ParBlocksSize`.
pub(crate) const AES_PAR_BLOCKS: usize =
    if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
        9
    } else if cfg!(target_arch = "aarch64") {
        21
    } else {
        4
    };
#[cfg(test)]
mod tests {
    use super::*;

    /// Two consecutive batches drawn from one generator must not be equal.
    #[test]
    fn test_generate() {
        let mut rng = AesRng::new();
        let first: [Block; 8] = rng.random();
        let second: [Block; 8] = rng.random();
        assert_ne!(first, second);
    }
}
// Checks that `AES_PAR_BLOCKS` agrees with the parallel-block width
// (`ParBlocksSize`) of the backend the `aes` crate hands to an encryption
// closure. The module is only compiled for tests, not under Miri, and only
// when the `aes` target feature is enabled at compile time.
#[cfg(all(test, not(miri), target_feature = "aes"))]
mod aes_par_blocks_tests {
    use aes::{
        Aes128,
        cipher::{
            BlockCipherEncClosure, BlockCipherEncrypt, BlockSizeUser, KeyInit, ParBlocksSizeUser,
        },
    };
    use super::AES_PAR_BLOCKS;
    #[test]
    fn aes_par_block_size() {
        use aes::cipher::typenum::Unsigned;
        // Zero-sized probe passed to the cipher; it only inspects the backend's type.
        struct GetParBlockSize;
        // The closure has to declare the block size it works with: 16 bytes (U16).
        impl BlockSizeUser for GetParBlockSize {
            type BlockSize = aes::cipher::array::sizes::U16;
        }
        impl BlockCipherEncClosure for GetParBlockSize {
            // `encrypt_with_backend` is expected to call this with the backend it
            // selects; the assertion compares that backend's compile-time
            // `ParBlocksSize` with our per-architecture constant.
            fn call<B: aes::cipher::BlockCipherEncBackend<BlockSize = Self::BlockSize>>(
                self,
                _backend: &B,
            ) {
                assert_eq!(
                    AES_PAR_BLOCKS,
                    <<B as ParBlocksSizeUser>::ParBlocksSize as Unsigned>::USIZE,
                );
            }
        }
        // The key value is irrelevant here; a default (all-zero) key is enough
        // to obtain a cipher instance and reach its backend.
        let aes = Aes128::new(&Default::default());
        aes.encrypt_with_backend(GetParBlockSize);
    }
}