use alloc::vec::Vec;
use digest::{
block_api::{Block, BlockSizeUser},
typenum::Unsigned,
Digest, FixedOutputReset, Output, Reset,
};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
use crate::DuplexSpongeInterface;
/// Sponge-like state built on top of a block-based [`Digest`] `D`.
///
/// Input is fed into `hasher`; squeezed output is produced one digest at a
/// time, and whatever part of a digest the caller did not ask for is kept in
/// `leftovers` for the next squeeze call.
#[derive(Clone)]
pub struct Hash<D>
where
    D: Digest + Clone + Reset + BlockSizeUser,
{
    /// Running hasher for the current absorb or squeeze phase.
    hasher: D,
    /// Chaining value mixed into `hasher` whenever a new phase starts.
    cv: Output<D>,
    /// Current phase of the state machine.
    mode: Mode,
    /// Digest bytes already computed but not yet returned to the caller.
    leftovers: Vec<u8>,
}
/// Duplex-sponge interface implemented on top of a plain hash function.
///
/// The state machine has three phases (see `Mode`): `Start`, `Absorb` and
/// `Squeeze`. Every phase change is bound to the chaining value `cv`, so the
/// output depends on everything that happened before.
impl<D: BlockSizeUser + Digest + Clone + FixedOutputReset> DuplexSpongeInterface for Hash<D> {
    type U = u8;

    /// Feeds `input` into the running hasher.
    ///
    /// Any squeeze phase in progress is closed first. When a fresh absorb
    /// phase begins, the absorb mask block and the chaining value are hashed
    /// before the input.
    fn absorb(&mut self, input: &[u8]) -> &mut Self {
        self.squeeze_end();
        if self.mode == Mode::Start {
            self.mode = Mode::Absorb;
            Digest::update(&mut self.hasher, Self::mask_absorb());
            Digest::update(&mut self.hasher, &self.cv);
        }
        Digest::update(&mut self.hasher, input);
        self
    }

    /// Replaces the chaining value with the double hash of the current
    /// hasher content and returns to `Mode::Start`.
    ///
    /// Ratcheting while in `Mode::Absorb` behaves as before. Ratcheting while
    /// in `Mode::Start` (for example twice in a row, or directly after a
    /// squeeze) is treated like an empty absorb followed by a ratchet, so the
    /// new chaining value still depends on the previous one.
    fn ratchet(&mut self) -> &mut Self {
        self.squeeze_end();
        if self.mode == Mode::Start {
            // In `Start` the hasher is always freshly reset. Finalizing it
            // as-is would overwrite `cv` with the constant H(H("")) and throw
            // away the whole transcript, so bind the previous chaining value
            // in first.
            Digest::update(&mut self.hasher, Self::mask_absorb());
            Digest::update(&mut self.hasher, &self.cv);
        }
        // Double hash; `finalize_reset` leaves the hasher empty for the next phase.
        self.cv = <D as Digest>::digest(self.hasher.finalize_reset());
        #[cfg(feature = "zeroize")]
        self.leftovers.zeroize();
        self.leftovers.clear();
        self.mode = Mode::Start;
        self
    }

    /// Fills `output` with pseudo-random bytes derived from the state.
    ///
    /// Block `i` of the output stream is the digest of the squeeze-phase
    /// hasher extended with `i` encoded as a big-endian 64-bit integer.
    /// Bytes of a block that do not fit into `output` are kept and returned
    /// by the next call. The phase transition happens even when `output` is
    /// empty.
    fn squeeze(&mut self, output: &mut [u8]) -> &mut Self {
        // Pending absorbed data is folded into the chaining value first.
        if self.mode == Mode::Absorb {
            self.ratchet();
        }
        if self.mode == Mode::Start {
            self.mode = Mode::Squeeze(0);
            Digest::update(&mut self.hasher, Self::mask_squeeze());
            Digest::update(&mut self.hasher, &self.cv);
        }

        // Hand out what is left of the previously computed block.
        let from_leftovers = usize::min(output.len(), self.leftovers.len());
        let (head, tail) = output.split_at_mut(from_leftovers);
        head.copy_from_slice(&self.leftovers[..from_leftovers]);
        self.leftovers.drain(..from_leftovers);

        let mut block_index = match self.mode {
            Mode::Squeeze(i) => i,
            Mode::Start | Mode::Absorb => unreachable!("squeeze phase was entered above"),
        };

        // `tail` is non-empty only if the leftovers were fully consumed, so
        // every new block starts from an empty `leftovers` buffer. Iterating
        // (rather than recursing per block) keeps stack usage constant.
        for chunk in tail.chunks_mut(Self::DIGEST_SIZE) {
            let mut block_hasher = self.hasher.clone();
            // Fixed-width counter: `usize::to_be_bytes` would hash 4 bytes on
            // 32-bit targets and 8 on 64-bit ones, giving different outputs.
            // `usize` never exceeds 64 bits, so the cast is lossless.
            Digest::update(&mut block_hasher, (block_index as u64).to_be_bytes());
            let digest = block_hasher.finalize();
            let chunk_len = chunk.len();
            chunk.copy_from_slice(&digest[..chunk_len]);
            // Only the last chunk can be short; keep its unused tail.
            self.leftovers.extend_from_slice(&digest[chunk_len..]);
            block_index += 1;
            self.mode = Mode::Squeeze(block_index);
        }
        self
    }
}
/// Phase of the hash-based sponge state machine.
///
/// A plain state tag: cheap to copy, comparable, and printable for debugging.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Mode {
    /// No phase is in progress.
    Start,
    /// Input is currently being fed into the hasher.
    Absorb,
    /// Output is being produced; the payload counts the digest blocks
    /// generated so far in this squeeze phase.
    Squeeze(usize),
}
/// Internal helpers: domain-separation mask blocks and squeeze-phase teardown.
impl<D: BlockSizeUser + Digest + Clone + Reset> Hash<D> {
    /// Block size of `D` in bytes.
    const BLOCK_SIZE: usize = D::BlockSize::USIZE;
    /// Output size of `D` in bytes.
    const DIGEST_SIZE: usize = D::OutputSize::USIZE;

    /// Builds one hasher block that begins with `start`, ends with `end`,
    /// and is zero everywhere in between.
    ///
    /// `start` and `end` together must be shorter than a block (checked in
    /// debug builds only).
    fn pad_block(start: &[u8], end: &[u8]) -> Block<D> {
        debug_assert!(start.len() + end.len() < Self::BLOCK_SIZE);
        let mut mask = Block::<D>::default();
        mask[..start.len()].copy_from_slice(start);
        mask[Self::BLOCK_SIZE - end.len()..].copy_from_slice(end);
        mask
    }

    /// Mask block hashed at the start of an absorb phase (last byte `0x00`).
    fn mask_absorb() -> Block<D> {
        Self::pad_block(&[], &[0x00])
    }

    /// Mask block hashed at the start of a squeeze phase (last byte `0x01`).
    fn mask_squeeze() -> Block<D> {
        Self::pad_block(&[], &[0x01])
    }

    /// Mask block hashed when a squeeze phase is closed (last byte `0x02`).
    fn mask_squeeze_end() -> Block<D> {
        Self::pad_block(&[], &[0x02])
    }

    /// Closes a squeeze phase, if one is in progress; otherwise does nothing.
    ///
    /// The hasher is reset and the chaining value is replaced by the digest of
    /// the end mask, the old chaining value, and the number of bytes that were
    /// actually handed out, encoded as a big-endian 64-bit integer. Unreleased
    /// leftover bytes are discarded and the mode returns to `Mode::Start`.
    fn squeeze_end(&mut self) {
        if let Mode::Squeeze(count) = self.mode {
            Digest::reset(&mut self.hasher);
            // Blocks generated minus the bytes never released to the caller.
            let byte_count = count * Self::DIGEST_SIZE - self.leftovers.len();
            let mut squeeze_hasher = D::new();
            Digest::update(&mut squeeze_hasher, Self::mask_squeeze_end());
            Digest::update(&mut squeeze_hasher, &self.cv);
            // Fixed-width encoding: `usize::to_be_bytes` would hash 4 bytes on
            // 32-bit targets and 8 on 64-bit ones. `usize` never exceeds
            // 64 bits, so the cast is lossless.
            Digest::update(&mut squeeze_hasher, (byte_count as u64).to_be_bytes());
            self.cv = Digest::finalize(squeeze_hasher);
            self.mode = Mode::Start;
            // The discarded bytes are unreleased output; wipe them the same
            // way `ratchet` does before dropping them.
            #[cfg(feature = "zeroize")]
            self.leftovers.zeroize();
            self.leftovers.clear();
        }
    }
}
/// Wipes all state-dependent data and returns the object to its initial phase.
#[cfg(feature = "zeroize")]
impl<D: Clone + Digest + Reset + BlockSizeUser> Zeroize for Hash<D> {
    fn zeroize(&mut self) {
        self.cv.zeroize();
        // Leftovers are squeeze output that was computed but never released;
        // they depend on the state just like `cv` and must be wiped as well.
        self.leftovers.zeroize();
        Digest::reset(&mut self.hasher);
        // The hasher is empty again, so the mode must say so too; otherwise a
        // later absorb or squeeze would skip hashing the mask and chaining value.
        self.mode = Mode::Start;
    }
}
/// With the `zeroize` feature enabled, the state is wiped when it goes out of scope.
#[cfg(feature = "zeroize")]
impl<D> Drop for Hash<D>
where
    D: Clone + Digest + Reset + BlockSizeUser,
{
    fn drop(&mut self) {
        Zeroize::zeroize(self);
    }
}
/// The initial state: a fresh hasher, a default (all-zero) chaining value,
/// `Mode::Start`, and no leftover output.
impl<D> Default for Hash<D>
where
    D: BlockSizeUser + Digest + Clone + FixedOutputReset,
{
    fn default() -> Self {
        let hasher = D::new();
        let cv = Output::<D>::default();
        Self {
            hasher,
            cv,
            mode: Mode::Start,
            leftovers: Vec::new(),
        }
    }
}
/// Known-answer test for `Hash<Sha256>`.
///
/// Checks that splitting absorbs and squeezes across calls does not change
/// the output, and pins one long mixed absorb/ratchet/squeeze sequence.
#[cfg(all(test, feature = "sha2"))]
#[test]
fn test_shosha() {
    // Vector 1: input absorbed in two pieces, output squeezed in two halves.
    {
        let expected = b"\xEB\xE4\xEF\x29\xE1\x8A\xA5\x41\x37\xED\xD8\x9C\x23\xF8\
                         \xBF\xEA\xC2\x73\x1C\x9F\x67\x5D\xA2\x0E\x7C\x67\xD5\xAD\
                         \x68\xD7\xEE\x2D\x40\xA4\x52\x32\xB5\x99\x55\x2D\x46\xB5\
                         \x20\x08\x2F\xB2\x70\x59\x71\xF0\x7B\x31\x58\xB0\x72\xB6\
                         \x3A\xB0\x93\x4A\x05\xE6\xAF\x64";
        let mut sponge = Hash::<sha2::Sha256>::default();
        let mut out = [0u8; 64];
        sponge
            .absorb(b"asd")
            .ratchet()
            .absorb(b"asd")
            .absorb(b"asd");
        let (first_half, second_half) = out.split_at_mut(32);
        sponge.squeeze(first_half);
        sponge.squeeze(second_half);
        assert_eq!(&out, expected);
    }

    // Vector 2: same data absorbed and squeezed in one call each, one byte longer.
    {
        let expected = b"\xEB\xE4\xEF\x29\xE1\x8A\xA5\x41\x37\xED\xD8\x9C\x23\xF8\
                         \xBF\xEA\xC2\x73\x1C\x9F\x67\x5D\xA2\x0E\x7C\x67\xD5\xAD\
                         \x68\xD7\xEE\x2D\x40\xA4\x52\x32\xB5\x99\x55\x2D\x46\xB5\
                         \x20\x08\x2F\xB2\x70\x59\x71\xF0\x7B\x31\x58\xB0\x72\xB6\
                         \x3A\xB0\x93\x4A\x05\xE6\xAF\x64\x48";
        let mut sponge = Hash::<sha2::Sha256>::default();
        let mut out = [0u8; 65];
        sponge.absorb(b"asd").ratchet().absorb(b"asdasd");
        sponge.squeeze(&mut out);
        assert_eq!(&out, expected);
    }

    // Vector 3: absorb and squeeze lengths around the block-size boundaries.
    {
        let expected = b"\x0D\xDE\xEA\x97\x3F\x32\x10\xF7\x72\x5A\x3C\xDB\x24\x73\
                         \xF8\x73\xAE\xAB\x8F\xEB\x32\xB8\x0D\xEE\x67\xF0\xCD\xE7\
                         \x95\x4E\x92\x9A\x4E\x78\x7A\xEF\xEE\x6D\xBE\x91\xD3\xFF\
                         \xF1\x62\x1A\xAB\x8D\x0D\x29\x19\x4F\x8A\xF9\x86\xD6\xF3\
                         \x57\xAD\xD0\x15\x0D\xF7\xD9";
        let mut sponge = Hash::<sha2::Sha256>::default();
        let mut out = [0u8; 150];
        let zeros = [0u8; 129];

        sponge.absorb(b"").ratchet();
        sponge.absorb(b"abc").ratchet();
        for len in [63, 64, 65, 127, 128, 129] {
            sponge.absorb(&zeros[..len]).ratchet();
        }

        // Each of these squeezes is explicitly closed before the next one.
        for len in [63, 64, 65, 127, 128] {
            sponge.squeeze(&mut out[..len]);
            sponge.squeeze_end();
        }
        sponge.squeeze(&mut out[..129]);
        assert_eq!(out[0], 0xd0);

        sponge.absorb(b"def").ratchet();
        sponge.squeeze(&mut out[..63]);
        assert_eq!(&out[..63], expected);
    }
}