use crate::{
salsa20,
util::{slice_as_chunks_mut, xor},
};
use core::mem;
/// Number of 64-bit lanes (each stored as a `[lo, hi]` pair of `u32`s) per gather group.
const PWXSIMPLE: usize = 2;
/// Number of gather groups per pwxform block; each group does its own S-box lookups.
const PWXGATHER: usize = 4;
/// Number of rounds pwxform applies to a block.
const PWXROUNDS: usize = 6;
/// log2 of the number of `PWXSIMPLE * 8`-byte lookup slots selectable by `SMASK`.
const SWIDTH: usize = 8;

/// Byte size of one pwxform block: `PWXGATHER` groups of `PWXSIMPLE` 8-byte lanes.
const PWXBYTES: usize = PWXSIMPLE * PWXGATHER * 8;
/// Size of one pwxform block in `u32` words.
const PWXWORDS: usize = PWXBYTES / mem::size_of::<u32>();
/// Byte-offset mask that selects one of `1 << SWIDTH` slots, each `PWXSIMPLE * 8` bytes wide.
const SMASK: usize = (PWXSIMPLE * 8) * ((1 << SWIDTH) - 1);
/// Byte size of the S-box region: three boxes of `(1 << SWIDTH) * PWXSIMPLE * 8` bytes.
pub(crate) const SBYTES: usize = ((PWXSIMPLE * 8) << SWIDTH) * 3;
/// Size of the S-box region in `u32` words.
pub(crate) const SWORDS: usize = SBYTES / mem::size_of::<u32>();
/// Smallest `r` for which a `128 * r`-byte block holds at least one `PWXBYTES` block.
// The quotient is a tiny compile-time value, so narrowing to `u32` cannot truncate.
#[allow(clippy::cast_possible_truncation)]
pub(crate) const RMIN: u32 = PWXBYTES.div_ceil(128) as u32;
/// Mutable state for pwxform: three S-box views plus the write index into `s2`.
///
/// Each `[u32; 2]` entry is read/written as the (low, high) 32-bit halves of a
/// 64-bit value. `s0` and `s1` are used as lookup tables, `s2` is written at
/// index `w`, and the three views are rotated `(s0, s1, s2) <- (s2, s0, s1)`
/// at the end of every pwxform call.
pub(crate) struct PwxformCtx<'a> {
    /// Lookup S-box indexed from the low half of a group's first lane.
    pub(crate) s0: &'a mut [[u32; 2]],
    /// Lookup S-box indexed from the high half of a group's first lane.
    pub(crate) s1: &'a mut [[u32; 2]],
    /// S-box receiving the intermediate lane values of the middle rounds.
    pub(crate) s2: &'a mut [[u32; 2]],
    /// Index of the next `s2` entry to write.
    pub(crate) w: usize,
}
impl PwxformCtx<'_> {
pub(crate) fn blockmix_pwxform(&mut self, b: &mut [u32], r: usize) {
let (b, _b) = slice_as_chunks_mut::<_, PWXWORDS>(b);
assert_eq!(b.len(), 2 * r);
assert!(_b.is_empty());
let r1 = (128 * r) / PWXBYTES;
let mut x = b[r1 - 1];
#[allow(clippy::needless_range_loop)]
for i in 0..r1 {
if r1 > 1 {
xor(&mut x, &b[i]);
}
self.pwxform(&mut x);
b[i] = x;
}
let i = (r1 - 1) * PWXBYTES / 64;
salsa20::salsa20_2(&mut b[i]);
for i in (i + 1)..(2 * r) {
let (bim1, bi) = b[(i - 1)..i].split_at_mut(1);
let (bim1, bi) = (&bim1[0], &mut bi[0]);
xor(bi, bim1);
salsa20::salsa20_2(bi);
}
}
fn pwxform(&mut self, b: &mut [u32; 16]) {
let b = slice_as_chunks_mut::<_, 2>(slice_as_chunks_mut::<_, PWXSIMPLE>(b).0).0;
let mut w = self.w;
for i in 0..PWXROUNDS {
#[allow(clippy::needless_range_loop)]
for j in 0..PWXGATHER {
let mut xl: u32 = b[j][0][0];
let mut xh: u32 = b[j][0][1];
let p0 = &self.s0[(xl as usize & SMASK) / 8..];
let p1 = &self.s1[(xh as usize & SMASK) / 8..];
for k in 0..PWXSIMPLE {
let s0 = (u64::from(p0[k][1]) << 32).wrapping_add(u64::from(p0[k][0]));
let s1 = (u64::from(p1[k][1]) << 32).wrapping_add(u64::from(p1[k][0]));
xl = b[j][k][0];
xh = b[j][k][1];
let mut x = u64::from(xh).wrapping_mul(u64::from(xl));
x = x.wrapping_add(s0);
x ^= s1;
let x_lo = (x & 0xFFFF_FFFF) as u32;
let x_hi = ((x >> 32) & 0xFFFF_FFFF) as u32;
b[j][k][0] = x_lo;
b[j][k][1] = x_hi;
if i != 0 && i != (PWXROUNDS - 1) {
self.s2[w][0] = x_lo;
self.s2[w][1] = x_hi;
w += 1;
}
}
}
}
mem::swap(&mut self.s0, &mut self.s2);
mem::swap(&mut self.s1, &mut self.s2);
self.w = w & ((1 << SWIDTH) * PWXSIMPLE - 1);
}
}