#![allow(unsafe_op_in_unsafe_fn, reason = "needs triage")]
#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
#![allow(clippy::cast_possible_wrap, reason = "needs triage")]
#![allow(clippy::cast_sign_loss, reason = "needs triage")]
#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]
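//! SSE2 backend for the ChaCha core. The 16-word state is kept in four
//! `__m128i` rows, and [`rounds`] produces `PAR_BLOCKS` keystream blocks per call.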
use crate::{Rounds, Variant};
#[cfg(feature = "rng")]
use crate::ChaChaCore;
#[cfg(feature = "cipher")]
use crate::{STATE_WORDS, chacha::Block};
#[cfg(feature = "cipher")]
use cipher::{
BlockSizeUser, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
consts::{U4, U64},
};
use core::marker::PhantomData;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
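/// Number of 64-byte blocks processed in parallel by this backend.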
const PAR_BLOCKS: usize = 4;
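/// Cipher entry point: loads the state into four `__m128i` rows, runs the
/// stream-cipher closure against a [`Backend`], then writes the updated block
/// counter back to `state[12]` (and `state[13]` when the variant uses a 64-bit counter).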
#[inline]
#[target_feature(enable = "sse2")]
#[cfg(feature = "cipher")]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
where
R: Rounds,
F: StreamCipherClosure<BlockSize = U64>,
V: Variant,
{
let state_ptr = state.as_ptr().cast::<__m128i>();
let mut backend = Backend::<R, V> {
v: [
_mm_loadu_si128(state_ptr.add(0)),
_mm_loadu_si128(state_ptr.add(1)),
_mm_loadu_si128(state_ptr.add(2)),
_mm_loadu_si128(state_ptr.add(3)),
],
_pd: PhantomData,
};
f.call(&mut backend);
state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
if size_of::<V::Counter>() == 8 {
// SSE2-only extraction of lane 1 (avoids the SSE4.1 `_mm_extract_epi32`).
state[13] = _mm_cvtsi128_si32(_mm_shuffle_epi32(backend.v[3], 0b_00_00_00_01)) as u32;
}
}
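/// SSE2 working state: the four 4-word rows of the ChaCha matrix.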
struct Backend<R: Rounds, V: Variant> {
v: [__m128i; 4],
_pd: PhantomData<(R, V)>,
}
#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
type BlockSize = U64;
}
#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
type ParBlocksSize = U4;
}
#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
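/// Writes one keystream block; only the first of the four blocks computed by
/// [`rounds`] is stored, and the counter advances by one.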
#[inline(always)]
fn gen_ks_block(&mut self, block: &mut Block) {
unsafe {
let res = rounds::<R, V>(&self.v);
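// Advance the block counter by one (32- or 64-bit, depending on the variant).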
self.v[3] = match size_of::<V::Counter>() {
4 => _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1)),
8 => _mm_add_epi64(self.v[3], _mm_set_epi64x(0, 1)),
_ => unreachable!(),
};
let block_ptr = block.as_mut_ptr().cast::<__m128i>();
for i in 0..4 {
_mm_storeu_si128(block_ptr.add(i), res[0][i]);
}
}
}
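/// Writes `PAR_BLOCKS` keystream blocks and advances the counter by `PAR_BLOCKS`.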
#[inline(always)]
fn gen_par_ks_blocks(&mut self, blocks: &mut cipher::ParBlocks<Self>) {
unsafe {
let res = rounds::<R, V>(&self.v);
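// Advance the block counter by the four blocks just produced.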
self.v[3] = match size_of::<V::Counter>() {
4 => _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32)),
8 => _mm_add_epi64(self.v[3], _mm_set_epi64x(0, PAR_BLOCKS as i64)),
_ => unreachable!(),
};
let blocks_ptr = blocks.as_mut_ptr().cast::<__m128i>();
for block in 0..PAR_BLOCKS {
for i in 0..4 {
// Each 64-byte block occupies four consecutive `__m128i` slots.
_mm_storeu_si128(blocks_ptr.add(block * 4 + i), res[block][i]);
}
}
}
}
}
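/// RNG entry point: fills `buffer` with four keystream blocks (256 bytes) and
/// writes the updated 64-bit counter back to `core.state[12]` and `core.state[13]`.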
#[inline]
#[target_feature(enable = "sse2")]
#[cfg(feature = "rng")]
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
where
R: Rounds,
V: Variant,
{
let state_ptr = core.state.as_ptr().cast::<__m128i>();
let mut backend = Backend::<R, V> {
v: [
_mm_loadu_si128(state_ptr.add(0)),
_mm_loadu_si128(state_ptr.add(1)),
_mm_loadu_si128(state_ptr.add(2)),
_mm_loadu_si128(state_ptr.add(3)),
],
_pd: PhantomData,
};
backend.gen_ks_blocks(buffer);
core.state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
// SSE2-only extraction of lane 1 (avoids the SSE4.1 `_mm_extract_epi32`).
core.state[13] = _mm_cvtsi128_si32(_mm_shuffle_epi32(backend.v[3], 0b_00_00_00_01)) as u32;
}
#[cfg(feature = "rng")]
impl<R: Rounds, V: Variant> Backend<R, V> {
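/// Fills the 64-word buffer with `PAR_BLOCKS` consecutive keystream blocks and
/// advances the 64-bit counter by `PAR_BLOCKS`.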
#[inline(always)]
fn gen_ks_blocks(&mut self, buffer: &mut [u32; 64]) {
const _: () = assert!(4 * PAR_BLOCKS * size_of::<__m128i>() == size_of::<[u32; 64]>());
unsafe {
let res = rounds::<R, V>(&self.v);
self.v[3] = _mm_add_epi64(self.v[3], _mm_set_epi64x(0, PAR_BLOCKS as i64));
let blocks_ptr = buffer.as_mut_ptr().cast::<__m128i>();
for block in 0..PAR_BLOCKS {
for i in 0..4 {
// Each 64-byte block occupies four consecutive `__m128i` slots.
_mm_storeu_si128(blocks_ptr.add(block * 4 + i), res[block][i]);
}
}
}
}
}
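/// Computes `PAR_BLOCKS` keystream blocks: block `i` starts from the input state
/// with its counter advanced by `i`, runs `R::COUNT` double rounds, then adds the
/// (counter-adjusted) input state back in as the final feed-forward.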
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rounds<R: Rounds, V: Variant>(v: &[__m128i; 4]) -> [[__m128i; 4]; PAR_BLOCKS] {
let mut res = [*v; PAR_BLOCKS];
for block in 1..PAR_BLOCKS {
res[block][3] = match size_of::<V::Counter>() {
4 => _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32)),
8 => _mm_add_epi64(res[block][3], _mm_set_epi64x(0, block as i64)),
_ => unreachable!(),
};
}
for _ in 0..R::COUNT {
double_quarter_round(&mut res);
}
for block in 0..PAR_BLOCKS {
for i in 0..3 {
res[block][i] = _mm_add_epi32(res[block][i], v[i]);
}
let ctr = match size_of::<V::Counter>() {
4 => _mm_add_epi32(v[3], _mm_set_epi32(0, 0, 0, block as i32)),
8 => _mm_add_epi64(v[3], _mm_set_epi64x(0, block as i64)),
_ => unreachable!(),
};
res[block][3] = _mm_add_epi32(res[block][3], ctr);
}
res
}
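/// One ChaCha double round: column quarter rounds, diagonalize, diagonal quarter
/// rounds, then shuffle back to column order.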
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn double_quarter_round(v: &mut [[__m128i; 4]; PAR_BLOCKS]) {
add_xor_rot(v);
rows_to_cols(v);
add_xor_rot(v);
cols_to_rows(v);
}
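/// Shuffles the rows so the next [`add_xor_rot`] operates on the diagonals:
/// the lanes of `c` rotate by one position, `d` by two, and `a` by three (`b` is left fixed).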
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rows_to_cols(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
for [a, _, c, d] in blocks.iter_mut() {
*c = _mm_shuffle_epi32(*c, 0b_00_11_10_01);
*d = _mm_shuffle_epi32(*d, 0b_01_00_11_10);
*a = _mm_shuffle_epi32(*a, 0b_10_01_00_11);
}
}
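/// Inverse lane rotation of [`rows_to_cols`], restoring column order after the
/// diagonal round.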
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn cols_to_rows(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
for [a, _, c, d] in blocks.iter_mut() {
*c = _mm_shuffle_epi32(*c, 0b_10_01_00_11);
*d = _mm_shuffle_epi32(*d, 0b_01_00_11_10);
*a = _mm_shuffle_epi32(*a, 0b_00_11_10_01);
}
}
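/// Four parallel ChaCha quarter rounds per block, one per 32-bit lane, with the
/// rotations emulated by paired shifts and xors.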
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn add_xor_rot(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
for [a, b, c, d] in blocks.iter_mut() {
// a += b; d ^= a; d <<<= 16;
*a = _mm_add_epi32(*a, *b);
*d = _mm_xor_si128(*d, *a);
*d = _mm_xor_si128(_mm_slli_epi32(*d, 16), _mm_srli_epi32(*d, 16));
// c += d; b ^= c; b <<<= 12;
*c = _mm_add_epi32(*c, *d);
*b = _mm_xor_si128(*b, *c);
*b = _mm_xor_si128(_mm_slli_epi32(*b, 12), _mm_srli_epi32(*b, 20));
// a += b; d ^= a; d <<<= 8;
*a = _mm_add_epi32(*a, *b);
*d = _mm_xor_si128(*d, *a);
*d = _mm_xor_si128(_mm_slli_epi32(*d, 8), _mm_srli_epi32(*d, 24));
// c += d; b ^= c; b <<<= 7;
*c = _mm_add_epi32(*c, *d);
*b = _mm_xor_si128(*b, *c);
*b = _mm_xor_si128(_mm_slli_epi32(*b, 7), _mm_srli_epi32(*b, 25));
}
}