use core::{mem::MaybeUninit, ptr};
use sha3_asm::{Buffer, SHA3_absorb, SHA3_squeeze};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Capacity, in bytes, of the partial-block buffer inside `Sha3State`.
///
/// This equals the sponge rate for `BITS = 128`: `(1600 - 2 * 128) / 8 = 168` bytes.
/// Instantiations with a smaller `BITS` have a larger rate and would not fit.
const MAX_BUFSZ: usize = (1600 - 2 * 128) / 8;
/// Incremental Keccak sponge state used for the SHA-3 family.
///
/// `BITS` sets the digest length (`BITS / 8` bytes) and the capacity (`2 * BITS` bits).
/// `PAD` is the padding byte written directly after the buffered input at finalization.
#[allow(non_snake_case)] // permits the upper-case state field name `A`
#[derive(Clone)]
pub(crate) struct Sha3State<const BITS: usize, const PAD: u8> {
    /// Keccak state passed to the `sha3_asm` absorb/squeeze routines.
    A: Buffer,
    /// Number of pending bytes currently stored at the front of `buf`.
    bufsz: usize,
    /// Pending partial input block; bytes at or past `bufsz` may be uninitialized.
    buf: [MaybeUninit<u8>; MAX_BUFSZ],
}
impl<const BITS: usize, const PAD: u8> Default for Sha3State<BITS, PAD> {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
/// With the `zeroize` feature, wipes the Keccak state and the input buffer on drop.
#[cfg(feature = "zeroize")]
impl<const BITS: usize, const PAD: u8> Drop for Sha3State<BITS, PAD> {
    fn drop(&mut self) {
        let Self { A: state, buf, .. } = self;
        state.zeroize();
        buf.zeroize();
    }
}

/// Marker impl: the `Drop` impl in this block zeroizes the state array and the input buffer.
#[cfg(feature = "zeroize")]
impl<const BITS: usize, const PAD: u8> ZeroizeOnDrop for Sha3State<BITS, PAD> {}
impl<const BITS: usize, const PAD: u8> Sha3State<BITS, PAD> {
const OUT_SIZE: usize = BITS / 8;
const BLOCK_SIZE: usize = (1600 - BITS * 2) / 8;
#[inline(always)]
pub(crate) fn new() -> Self {
Self { A: [0; 25], bufsz: 0, buf: unsafe { MaybeUninit::uninit().assume_init() } }
}
#[inline(always)]
pub(crate) fn reset(&mut self) {
self.A = [0; 25];
self.bufsz = 0;
}
#[inline]
pub(crate) unsafe fn update(&mut self, mut inp: *const u8, mut len: usize) {
let bsz: usize = Self::BLOCK_SIZE;
if len == 0 {
return;
}
let num = self.bufsz;
let mut rem;
if num != 0 {
rem = bsz - num;
if len < rem {
memcpy(self.buf().add(num), inp, len);
self.bufsz += len;
return;
}
memcpy(self.buf().add(num), inp, rem);
inp = inp.add(rem);
len -= rem;
SHA3_absorb(&mut self.A, self.buf(), bsz, bsz);
self.bufsz = 0;
}
rem = if len >= bsz { SHA3_absorb(&mut self.A, inp, len, bsz) } else { len };
if rem > 0 {
memcpy(self.buf(), inp.add(len).sub(rem), rem);
self.bufsz = rem;
}
}
#[inline]
pub(crate) unsafe fn finalize(&mut self, out: *mut u8) {
let bsz: usize = Self::BLOCK_SIZE;
let num = self.bufsz;
memset(self.buf().add(num), 0, bsz - num);
*self.buf().add(num) = PAD;
*self.buf().add(bsz - 1) |= 0x80;
SHA3_absorb(&mut self.A, self.buf(), bsz, bsz);
SHA3_squeeze(&mut self.A, out, Self::OUT_SIZE, bsz);
}
#[inline(always)]
fn buf(&mut self) -> *mut u8 {
self.buf.as_mut_ptr().cast()
}
}
/// Copies `count` bytes from `src` to `dst`, taking arguments in C `memcpy` order
/// (destination first).
///
/// # Safety
///
/// `src` must be valid for reads and `dst` valid for writes of `count` bytes, and the two
/// ranges must not overlap.
#[inline(always)]
unsafe fn memcpy(dst: *mut u8, src: *const u8, count: usize) {
    src.copy_to_nonoverlapping(dst, count);
}
/// Sets `count` bytes starting at `dst` to `val`, like C's `memset`.
///
/// # Safety
///
/// `dst` must be valid for writes of `count` bytes.
#[inline(always)]
unsafe fn memset(dst: *mut u8, val: u8, count: usize) {
    ptr::write_bytes::<u8>(dst, val, count)
}