#![allow(clippy::uninit_assumed_init, invalid_value)]
use crate::rounds::*;
use crate::util::*;
use crate::variations::*;
use core::marker::PhantomData;
use core::mem::{MaybeUninit, transmute};
use core::ptr::copy_nonoverlapping;
/// Persistent ChaCha state: the two key rows plus the counter/nonce row.
///
/// The type parameters carry no data (they only appear in `PhantomData`); the
/// inherent `impl` requires `M: Machine`, `R: DoubleRounds` and `V: Variant`.
///
/// The layout is load-bearing: `#[repr(C)]` fixes the field order because the
/// `From<[..; SEED_LEN_*]>` impls transmute raw seed arrays into this struct and
/// `get_naked` reinterprets `&Self` as `&ChaChaNaked`.
#[repr(C)]
pub struct ChaChaCore<M, R, V> {
/// First key row; `new` fills it with `key[0..4]`.
row_b: Row,
/// Second key row; `new` fills it with `key[4..8]`.
row_c: Row,
/// Counter and nonce row; its interpretation (`u64x2` or `u32x4`) depends on `V::VAR`.
row_d: Row,
/// Zero-sized marker tying the state to its machine, round count and variant.
_phantom: PhantomData<(M, R, V)>,
}
impl<M, R, V> From<u8> for ChaChaCore<M, R, V> {
    /// Seeds the state by repeating `value` across all `SEED_LEN_U8` seed bytes.
    #[inline]
    fn from(value: u8) -> Self {
        let seed = [value; SEED_LEN_U8];
        Self::from(seed)
    }
}
impl<M, R, V> From<u32> for ChaChaCore<M, R, V> {
    /// Seeds the state by repeating `value` across all `SEED_LEN_U32` seed words.
    #[inline]
    fn from(value: u32) -> Self {
        let seed = [value; SEED_LEN_U32];
        Self::from(seed)
    }
}
impl<M, R, V> From<u64> for ChaChaCore<M, R, V> {
    /// Seeds the state by repeating `value` across all `SEED_LEN_U64` seed words.
    #[inline]
    fn from(value: u64) -> Self {
        let seed = [value; SEED_LEN_U64];
        Self::from(seed)
    }
}
impl<M, R, V> From<[u8; SEED_LEN_U8]> for ChaChaCore<M, R, V> {
    /// Reinterprets the seed bytes, in memory order, as the three state rows.
    #[inline]
    fn from(value: [u8; SEED_LEN_U8]) -> Self {
        // SAFETY: `transmute` refuses to compile unless source and destination have
        // the same size. The destination consists of `Row`s, which this file only
        // ever accesses as plain integer arrays, plus a zero-sized marker, so every
        // byte pattern of the seed is treated as a valid state.
        unsafe { transmute::<[u8; SEED_LEN_U8], Self>(value) }
    }
}
impl<M, R, V> From<[u32; SEED_LEN_U32]> for ChaChaCore<M, R, V> {
    /// Reinterprets the seed words, in memory order, as the three state rows.
    #[inline]
    fn from(value: [u32; SEED_LEN_U32]) -> Self {
        // SAFETY: `transmute` refuses to compile unless source and destination have
        // the same size. The destination consists of `Row`s, which this file only
        // ever accesses as plain integer arrays, plus a zero-sized marker, so every
        // bit pattern of the seed is treated as a valid state.
        unsafe { transmute::<[u32; SEED_LEN_U32], Self>(value) }
    }
}
impl<M, R, V> From<[u64; SEED_LEN_U64]> for ChaChaCore<M, R, V> {
    /// Reinterprets the seed words, in memory order, as the three state rows.
    #[inline]
    fn from(value: [u64; SEED_LEN_U64]) -> Self {
        // SAFETY: `transmute` refuses to compile unless source and destination have
        // the same size. The destination consists of `Row`s, which this file only
        // ever accesses as plain integer arrays, plus a zero-sized marker, so every
        // bit pattern of the seed is treated as a valid state.
        unsafe { transmute::<[u64; SEED_LEN_U64], Self>(value) }
    }
}
impl<M, R, V> ChaChaCore<M, R, V>
where
M: Machine,
R: DoubleRounds,
V: Variant,
{
pub fn new(key: [u32; 8], counter: u64, nonce: [u32; 3]) -> Self {
let row_b = Row {
u32x4: [key[0], key[1], key[2], key[3]],
};
let row_c = Row {
u32x4: [key[4], key[5], key[6], key[7]],
};
let row_d = match V::VAR {
Variants::Djb => {
let nonce = unsafe { transmute([nonce[0], nonce[1]]) };
Row {
u64x2: [counter, nonce],
}
}
Variants::Ietf => {
let counter = counter as u32;
Row {
u32x4: [counter, nonce[0], nonce[1], nonce[2]],
}
}
};
Self {
row_b,
row_c,
row_d,
_phantom: PhantomData,
}
}
#[inline]
pub fn get_counter(&self) -> u64 {
unsafe {
match V::VAR {
Variants::Djb => self.row_d.u64x2[0],
Variants::Ietf => self.row_d.u32x4[0] as u64,
}
}
}
#[inline]
pub fn set_counter(&mut self, new_counter: u64) {
unsafe {
match V::VAR {
Variants::Djb => self.row_d.u64x2[0] = new_counter,
Variants::Ietf => self.row_d.u32x4[0] = new_counter as u32,
}
}
}
#[inline(never)]
pub fn xor(&mut self, dst: &mut [u8]) {
self.slice::<true>(dst);
}
#[inline(never)]
pub fn fill(&mut self, dst: &mut [u8]) {
self.slice::<false>(dst);
}
#[inline]
fn slice<const XOR: bool>(&mut self, dst: &mut [u8]) {
let mut machine = M::new::<V>(self.get_naked());
dst.chunks_exact_mut(BUF_LEN_U8).for_each(|chunk| {
let buf: &mut [u8; BUF_LEN_U8] = chunk.try_into().unwrap();
self.chacha::<true, XOR>(&mut machine, buf)
});
let rem = dst.chunks_exact_mut(BUF_LEN_U8).into_remainder();
if !rem.is_empty() {
let mut buf: [u8; BUF_LEN_U8] = unsafe { MaybeUninit::uninit().assume_init() };
self.chacha::<false, XOR>(&mut machine, &mut buf);
unsafe {
copy_nonoverlapping(buf.as_ptr(), rem.as_mut_ptr(), rem.len());
}
let increment = rem.len().div_ceil(MATRIX_SIZE_U8);
unsafe {
match V::VAR {
Variants::Djb => {
self.row_d.u64x2[0] = self.row_d.u64x2[0].wrapping_add(increment as u64);
}
Variants::Ietf => {
self.row_d.u32x4[0] = self.row_d.u32x4[0].wrapping_add(increment as u32);
}
}
}
}
}
#[inline]
pub fn get_block_u64(&mut self) -> [u64; BUF_LEN_U64] {
let mut result = unsafe { MaybeUninit::uninit().assume_init() };
self.fill_block_u64(&mut result);
result
}
#[inline]
pub fn get_block(&mut self) -> [u8; BUF_LEN_U8] {
let mut result = unsafe { MaybeUninit::uninit().assume_init() };
self.fill_block(&mut result);
result
}
#[inline]
pub fn fill_block_u64(&mut self, buf: &mut [u64; BUF_LEN_U64]) {
let temp = unsafe { transmute(buf) };
self.chacha_once::<false>(temp);
}
#[inline]
pub fn fill_block(&mut self, buf: &mut [u8; BUF_LEN_U8]) {
self.chacha_once::<false>(buf);
}
#[inline]
pub fn xor_block(&mut self, buf: &mut [u8; BUF_LEN_U8]) {
self.chacha_once::<true>(buf);
}
#[inline(never)]
fn chacha_once<const XOR: bool>(&mut self, buf: &mut [u8; BUF_LEN_U8]) {
let mut machine = M::new::<V>(self.get_naked());
self.chacha::<false, XOR>(&mut machine, buf);
self.increment();
}
#[inline]
fn chacha<const INCREMENT: bool, const XOR: bool>(
&mut self,
machine: &mut M,
buf: &mut [u8; BUF_LEN_U8],
) {
let mut cur = machine.clone();
for _ in 0..R::COUNT {
cur.double_round();
}
let result = cur + machine.clone();
if XOR {
result.xor_result(buf);
} else {
result.fetch_result(buf);
}
if INCREMENT {
machine.increment::<V>();
self.increment();
}
}
#[inline]
fn increment(&mut self) {
unsafe {
match V::VAR {
Variants::Djb => {
self.row_d.u64x2[0] = self.row_d.u64x2[0].wrapping_add(DEPTH as u64);
}
Variants::Ietf => {
self.row_d.u32x4[0] = self.row_d.u32x4[0].wrapping_add(DEPTH as u32);
}
}
}
}
#[inline]
fn get_naked(&self) -> &ChaChaNaked {
const {
assert!(align_of::<Self>() == align_of::<ChaChaNaked>());
assert!(size_of::<Self>() == size_of::<ChaChaNaked>());
}
unsafe { transmute(self) }
}
}