#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]
use core::fmt;
use rand_core::{
Infallible, SeedableRng, TryCryptoRng, TryRng,
block::{BlockRng, Generator},
};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::{
ChaChaCore, R8, R12, R20, Rounds, backends,
variants::{Legacy, Variant},
};
use cfg_if::cfg_if;
/// Seed type accepted by the generators in this module: 32 bytes.
pub type Seed = [u8; 32];

/// Serialized generator state: a 32-byte seed, an 8-byte stream id and the
/// low 9 bytes of the little-endian word position.
pub type SerializedRngState = [u8; 32 + 8 + 9];

/// Number of 32-bit words per block; the word position is the block counter
/// multiplied by this value plus the offset into the buffer.
pub(crate) const BLOCK_WORDS: u8 = 16;

/// Number of blocks held in the output buffer.
const BUF_BLOCKS: u8 = 4;

/// Length of the output buffer, in 32-bit words.
const BUFFER_SIZE: usize = BLOCK_WORDS as usize * BUF_BLOCKS as usize;
/// Seeding for the raw block generator.
///
/// The seed is handed to `new_internal` together with eight zero bytes; the
/// same second argument carries the stream id when a serialized state is
/// restored elsewhere in this file.
impl<R: Rounds, V: Variant> SeedableRng for ChaChaCore<R, V> {
    type Seed = Seed;

    #[inline]
    fn from_seed(seed: Self::Seed) -> Self {
        let zero_stream = [0u8; 8];
        Self::new_internal(&seed, &zero_stream)
    }
}
/// Block generator implementation for `ChaChaCore`.
///
/// Each call to `generate` fills a buffer of `BUFFER_SIZE` 32-bit words using
/// exactly one backend. The backend is chosen by `cfg` flags and, on
/// x86/x86_64 when no backend is forced, by the tokens stored in
/// `self.tokens`.
impl<R: Rounds, V: Variant> Generator for ChaChaCore<R, V> {
/// One buffer's worth of output words.
type Output = [u32; BUFFER_SIZE];
/// Fills `buffer` with output words from the selected backend.
fn generate(&mut self, buffer: &mut [u32; BUFFER_SIZE]) {
cfg_if! {
// The software backend was explicitly requested at build time.
if #[cfg(chacha20_backend = "soft")] {
backends::soft::Backend(self).gen_ks_blocks(buffer);
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
cfg_if! {
// Both the "avx2" and the "avx512" settings use the AVX2 RNG routine here.
if #[cfg(any(chacha20_backend = "avx2", chacha20_backend = "avx512"))] {
// SAFETY: NOTE(review) - presumably forcing this backend is only valid when
// the target CPU supports AVX2; the contract of `rng_inner` is not visible
// here, confirm against `backends::avx2`.
unsafe {
backends::avx2::rng_inner::<R, V>(self, buffer);
}
} else if #[cfg(chacha20_backend = "sse2")] {
// SAFETY: NOTE(review) - presumably forcing this backend is only valid when
// the target CPU supports SSE2; confirm against `backends::sse2`.
unsafe {
backends::sse2::rng_inner::<R, V>(self, buffer);
}
} else {
// No backend forced: choose at run time from the tokens stored on the core.
// With `chacha20_avx512` the tuple has an extra leading token that is unused here.
#[cfg(chacha20_avx512)]
let (_avx512_token, avx2_token, sse2_token) = self.tokens;
#[cfg(not(chacha20_avx512))]
let (avx2_token, sse2_token) = self.tokens;
// Preference order: AVX2, then SSE2, then the software fallback.
if avx2_token.get() {
// SAFETY: NOTE(review) - this branch runs only when `avx2_token.get()` is
// true; presumably the token reports CPU support for AVX2 - confirm.
unsafe {
backends::avx2::rng_inner::<R, V>(self, buffer);
}
} else if sse2_token.get() {
// SAFETY: NOTE(review) - this branch runs only when `sse2_token.get()` is
// true; presumably the token reports CPU support for SSE2 - confirm.
unsafe {
backends::sse2::rng_inner::<R, V>(self, buffer);
}
} else {
backends::soft::Backend(self).gen_ks_blocks(buffer);
}
}
}
} else if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
// SAFETY: this branch is compiled only when `target_feature = "neon"` is
// enabled for an aarch64 target. NOTE(review): confirm that this is the only
// precondition of `backends::neon::rng_inner`.
unsafe {
backends::neon::rng_inner::<R, V>(self, buffer);
}
} else {
// Any other target uses the software backend.
backends::soft::Backend(self).gen_ks_blocks(buffer);
}
}
}
/// With the `zeroize` feature enabled, wipes the output buffer passed in.
///
/// NOTE(review): presumably invoked by the buffering wrapper when it is
/// dropped - confirm against the `Generator` trait documentation.
#[cfg(feature = "zeroize")]
fn drop(&mut self, output: &mut Self::Output) {
output.zeroize();
}
}
macro_rules! impl_chacha_rng {
($Rng:ident, $rounds:ident) => {
/// Random number generator that buffers the output of a `ChaChaCore`
/// (parameterised by the macro's `$rounds` argument and the `Legacy` variant)
/// through a `BlockRng`.
///
/// Usage sketch (`seed` is a 32-byte `Seed`; the `SeedableRng` trait must be
/// in scope for `from_seed`):
///
/// ```ignore
#[doc = concat!("use chacha20::", stringify!($Rng), ";")]
#[doc = concat!("let mut rng = ", stringify!($Rng), "::from_seed(seed);")]
/// ```
pub struct $Rng {
// Buffered block generator; the seed, stream and position accessors below
// read and write its `state` words directly.
core: BlockRng<ChaChaCore<$rounds, Legacy>>,
}
/// Builds the generator from a 32-byte seed, passing eight zero bytes as the
/// second argument of `new_internal`.
impl SeedableRng for $Rng {
    type Seed = Seed;

    #[inline]
    fn from_seed(seed: Self::Seed) -> Self {
        Self {
            core: BlockRng::new(ChaChaCore::new_internal(&seed, &[0u8; 8])),
        }
    }
}
/// `TryRng` implementation with an `Infallible` error type: every method
/// forwards to the buffered core and wraps the result in `Ok`.
impl TryRng for $Rng {
    type Error = Infallible;

    /// Returns the next 32-bit word from the buffered core.
    #[inline]
    fn try_next_u32(&mut self) -> Result<u32, Infallible> {
        let word = self.core.next_word();
        Ok(word)
    }

    /// Returns the next 64-bit value, as produced by `next_u64_from_u32`.
    #[inline]
    fn try_next_u64(&mut self) -> Result<u64, Infallible> {
        let value = self.core.next_u64_from_u32();
        Ok(value)
    }

    /// Fills `dest` with output bytes from the buffered core.
    #[inline]
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Infallible> {
        self.core.fill_bytes(dest);
        Ok(())
    }
}
// Marker impl with no items: declares the generator to be a `TryCryptoRng`.
impl TryCryptoRng for $Rng {}
// Marker impl with no items, compiled only with the `zeroize` feature.
// NOTE(review): the wiping itself is not implemented here; presumably it is
// provided by the wrapped core and the `Generator::drop` hook - confirm.
#[cfg(feature = "zeroize")]
impl ZeroizeOnDrop for $Rng {}
/// Two generators are equal when their seed, stream and word position all
/// match; the comparisons run in that order and stop at the first mismatch.
impl PartialEq<$Rng> for $Rng {
    fn eq(&self, rhs: &$Rng) -> bool {
        if self.get_seed() != rhs.get_seed() {
            return false;
        }
        if self.get_stream() != rhs.get_stream() {
            return false;
        }
        self.get_word_pos() == rhs.get_word_pos()
    }
}

impl Eq for $Rng {}
/// Opaque `Debug` output: prints only the type name followed by `{ ... }`,
/// never any of the generator's internal state.
impl fmt::Debug for $Rng {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(concat!(stringify!($Rng), " {{ ... }}")))
    }
}
impl $Rng {
/// Returns the current position in the output stream, counted in 32-bit
/// words and truncated to 68 bits.
///
/// The 64-bit block counter is read from state words 12 (low half) and 13
/// (high half). When part of the buffer has already been handed out, the
/// counter is first stepped back by `BUF_BLOCKS` (with wrap-around) so that
/// it refers to the start of the buffered data.
#[inline]
#[must_use]
pub fn get_word_pos(&self) -> u128 {
    // Keeps the low 68 bits of the result.
    const WORD_POS_MASK: u128 = (1 << 68) - 1;

    let state = &self.core.core.state;
    let buffered = self.core.word_offset();
    let raw_counter = (u64::from(state[13]) << 32) | u64::from(state[12]);
    let base_block = if buffered == 0 {
        raw_counter
    } else {
        raw_counter.wrapping_sub(u64::from(BUF_BLOCKS))
    };
    let position = u128::from(base_block) * u128::from(BLOCK_WORDS) + buffered as u128;
    position & WORD_POS_MASK
}
/// Moves the generator to `word_offset`, a position counted in 32-bit words.
///
/// The position is split into a block counter (`word_offset / BLOCK_WORDS`),
/// whose low 64 bits are stored in state words 12 and 13, and an offset
/// inside the block, which is passed to `reset_and_skip`. Counter bits above
/// the 64th are discarded by the truncating casts.
#[inline]
pub fn set_word_pos(&mut self, word_offset: u128) {
    let words_per_block = u128::from(BLOCK_WORDS);
    let block = word_offset / words_per_block;
    let skip = (word_offset % words_per_block) as usize;

    self.core.core.state[12] = block as u32;
    self.core.core.state[13] = (block >> 32) as u32;
    self.core.reset_and_skip(skip);
}
/// Sets the block position to `block_pos`.
///
/// The wrapper is reset with a skip of 0 first, and then the block position
/// is written to the core.
// NOTE(review): presumably `reset_and_skip(0)` discards buffered output
// without generating a new buffer - confirm against `BlockRng`.
#[inline]
#[allow(unused)]
pub fn set_block_pos(&mut self, block_pos: u64) {
self.core.reset_and_skip(0);
self.core.core.set_block_pos(block_pos);
}
#[inline]
#[allow(unused)]
#[must_use]
pub fn get_block_pos(&self) -> u64 {
let counter = self.core.core.get_block_pos();
let offset = self.core.word_offset();
if offset != 0 {
counter - u64::from(BUF_BLOCKS) + offset as u64 / 16
} else {
counter
}
}
/// Sets the 64-bit stream id and restarts output from block 0.
///
/// The low half of `stream` goes into state word 14 and the high half into
/// state word 15; afterwards `set_block_pos(0)` is called.
#[inline]
pub fn set_stream(&mut self, stream: u64) {
    let low = stream as u32;
    let high = (stream >> 32) as u32;

    let state = &mut self.core.core.state;
    state[14] = low;
    state[15] = high;

    self.set_block_pos(0);
}
/// Returns the 64-bit stream id: state word 14 supplies the low 32 bits and
/// state word 15 the high 32 bits.
#[inline]
#[must_use]
pub fn get_stream(&self) -> u64 {
    let state = &self.core.core.state;
    let low = u64::from(state[14]);
    let high = u64::from(state[15]);
    (high << 32) | low
}
/// Returns the 32-byte seed, rebuilt from state words 4 through 11 with each
/// word written out in little-endian byte order.
#[inline]
#[must_use]
pub fn get_seed(&self) -> [u8; 32] {
    let mut bytes = [0u8; 32];
    bytes
        .chunks_exact_mut(4)
        .zip(self.core.core.state[4..12].iter())
        .for_each(|(chunk, word)| chunk.copy_from_slice(&word.to_le_bytes()));
    bytes
}
/// Serializes the generator into 49 bytes.
///
/// Layout: bytes `0..32` hold the seed, bytes `32..40` the stream id in
/// little-endian order, and bytes `40..49` the low nine bytes of the
/// little-endian word position. In debug builds the remaining seven bytes of
/// the word position are asserted to be zero.
#[inline]
pub fn serialize_state(&self) -> SerializedRngState {
    let seed_bytes = self.get_seed();
    let stream_bytes = self.get_stream().to_le_bytes();
    let pos_bytes = self.get_word_pos().to_le_bytes();

    let mut out = [0u8; 49];
    out[..32].copy_from_slice(&seed_bytes);
    out[32..40].copy_from_slice(&stream_bytes);
    out[40..].copy_from_slice(&pos_bytes[..9]);
    debug_assert_eq!(&pos_bytes[9..], &[0u8; 7]);
    out
}
/// Rebuilds a generator from the 49-byte form written by `serialize_state`.
///
/// Bytes `0..32` are the seed and bytes `32..40` the stream id, both passed
/// to `new_internal`. Bytes `40..49` are the low nine bytes of the
/// little-endian word position, which is zero-extended to a `u128` and
/// applied with `set_word_pos`.
#[inline]
pub fn deserialize_state(state: &SerializedRngState) -> Self {
    let seed: &[u8; 32] = state[..32]
        .try_into()
        .expect("seed.len() is equal to 32");
    let stream: &[u8; 8] = state[32..40]
        .try_into()
        .expect("stream.len() is equal to 8");

    // Zero-extend the nine stored bytes to a full little-endian u128.
    let mut padded_pos = [0u8; 16];
    padded_pos[..9].copy_from_slice(&state[40..]);

    let mut rng = Self {
        core: BlockRng::new(ChaChaCore::new_internal(seed, stream)),
    };
    rng.set_word_pos(u128::from_le_bytes(padded_pos));
    rng
}
}
};
}
// Concrete generators: each invocation defines the named type as a wrapper
// around `BlockRng<ChaChaCore<$rounds, Legacy>>` together with its trait impls.
impl_chacha_rng!(ChaCha8Rng, R8);
impl_chacha_rng!(ChaCha12Rng, R12);
impl_chacha_rng!(ChaCha20Rng, R20);