use alloc::vec;
use alloc::vec::Vec;
use core::fmt;
use digest::block_api::{
AlgorithmName,
Block,
BlockSizeUser,
BufferKindUser,
Eager,
UpdateCore,
};
use digest::common::hazmat::{
DeserializeStateError,
SerializableState,
SerializedState,
};
use digest::consts::{
U16,
U32,
U136,
U168,
U400,
};
use digest::{
CollisionResistance,
ExtendableOutput,
HashMarker,
Reset,
Update,
XofReader,
};
#[cfg(feature = "parallelhash")]
use rayon::prelude::*;
use crate::cshake::{
CShake128,
CShake128Reader,
CShake256,
CShake256Reader,
};
use crate::shake::{
Shake128,
Shake256,
};
use crate::utils::{
MAX_SP800185_FIXED_OUTPUT_BYTES,
left_encode,
right_encode,
};
/// Streaming state for ParallelHash128.
///
/// Built with `new` (or `Default`), fed with `update`, and consumed by
/// `finalize`, `finalize_with_length` or `xof`.
#[derive(Clone)]
pub struct ParallelHash128 {
/// Outer `CShake128`, created with function name `b"ParallelHash"` and the
/// caller's customization string; it absorbs the encoded block size, the
/// per-block digests and the final length trailer.
inner: CShake128,
/// Bytes of the current incomplete block, kept until the block fills up or
/// the hash is finalized.
buf: Vec<u8>,
/// Number of block digests fed into `inner` so far.
n: u64,
/// Rate parameter (168 for this type); handed to the per-block hashing helper.
rate: usize,
/// Length in bytes of the blocks the input is split into.
blocksize: usize,
}
/// Streaming state for ParallelHash256.
///
/// Built with `new` (or `Default`), fed with `update`, and consumed by
/// `finalize`, `finalize_with_length` or `xof`.
#[derive(Clone)]
pub struct ParallelHash256 {
/// Outer `CShake256`, created with function name `b"ParallelHash"` and the
/// caller's customization string; it absorbs the encoded block size, the
/// per-block digests and the final length trailer.
inner: CShake256,
/// Bytes of the current incomplete block, kept until the block fills up or
/// the hash is finalized.
buf: Vec<u8>,
/// Number of block digests fed into `inner` so far.
n: u64,
/// Rate parameter (136 for this type); handed to the per-block hashing helper.
rate: usize,
/// Length in bytes of the blocks the input is split into.
blocksize: usize,
}
/// Output reader returned by `ParallelHash128::xof`.
///
/// Its `XofReader::read` implementation forwards to the wrapped reader.
#[derive(Clone)]
pub struct ParallelHash128Reader {
/// Reader of the finalized outer `CShake128`.
inner: CShake128Reader,
}
/// Output reader returned by `ParallelHash256::xof`.
///
/// Its `XofReader::read` implementation forwards to the wrapped reader.
#[derive(Clone)]
pub struct ParallelHash256Reader {
/// Reader of the finalized outer `CShake256`.
inner: CShake256Reader,
}
/// Generates the inherent API and the `digest` trait impls for one ParallelHash variant.
///
/// Arguments, in order:
/// * `$name` – the hasher struct to implement (`inner`, `buf`, `n`, `rate`, `blocksize` fields);
/// * `$inner_type` – the outer cSHAKE type stored in `inner`;
/// * `$reader_name` – the XOF reader struct to implement;
/// * `$inner_reader_type` – the reader type wrapped by `$reader_name` (not referenced by the
///   generated code, kept so existing invocations stay valid);
/// * `$shake_type` – the hash used for the individual blocks;
/// * `$rate` – type-level block size reported through `BlockSizeUser`;
/// * `$rate_expr` – the same rate as a `usize` expression, in bytes;
/// * `$alg_name` – string written by `AlgorithmName`.
macro_rules! impl_parallelhash {
    (
        $name:ident, $inner_type:ident, $reader_name:ident, $inner_reader_type:ident,
        $shake_type:ident, $rate:ident, $rate_expr:expr, $alg_name:expr
    ) => {
        impl $name {
            /// Size in bytes of the 1600-bit Keccak state. `KECCAK_STATE_BYTES - rate` is the
            /// sponge capacity in bytes.
            const KECCAK_STATE_BYTES: usize = 200;

            /// Creates a hasher with customization string `custom` that splits its input into
            /// blocks of `blocksize` bytes.
            ///
            /// # Panics
            ///
            /// Panics if `blocksize` is zero. SP 800-185 requires a positive block size, and a
            /// zero value would otherwise end in a division by zero or an endless loop inside
            /// `update`.
            pub fn new(custom: &[u8], blocksize: usize) -> Self {
                assert!(blocksize > 0, "ParallelHash block size must be non-zero");
                let mut hasher = Self {
                    inner: $inner_type::new_with_function_name(b"ParallelHash", custom),
                    buf: Vec::new(),
                    n: 0,
                    rate: $rate_expr,
                    blocksize,
                };
                hasher.init();
                hasher
            }

            /// Absorbs `left_encode(blocksize)`, the prefix of the outer hash input.
            fn init(&mut self) {
                let mut enc_buf = [0u8; 9];
                let encoded = left_encode(self.blocksize as u64, &mut enc_buf);
                Update::update(&mut self.inner, encoded);
            }

            /// Hashes a single block with the plain SHAKE function and returns its digest.
            ///
            /// SP 800-185 fixes the per-block digest at 256 bits for ParallelHash128 and
            /// 512 bits for ParallelHash256. Both equal the sponge capacity, so the length is
            /// derived as `KECCAK_STATE_BYTES - rate` (200 - 168 = 32, 200 - 136 = 64 bytes).
            fn hash_block(block: &[u8], rate: usize) -> Vec<u8> {
                let mut shake = $shake_type::default();
                Update::update(&mut shake, block);
                let mut output = vec![0u8; Self::KECCAK_STATE_BYTES - rate];
                ExtendableOutput::finalize_xof_into(shake, &mut output);
                output
            }

            /// Shared implementation of the inherent `update` and of `Update::update`.
            ///
            /// Completes a pending partial block first, then hashes every complete block found
            /// in `data`, and finally stashes the incomplete tail in `buf`.
            fn absorb(&mut self, data: &[u8]) {
                let mut pos = 0;
                if !self.buf.is_empty() {
                    // `buf` always holds fewer than `blocksize` bytes, so this cannot underflow.
                    let missing = self.blocksize - self.buf.len();
                    if data.len() < missing {
                        self.buf.extend_from_slice(data);
                        return;
                    }
                    self.buf.extend_from_slice(&data[..missing]);
                    let block_hash = Self::hash_block(&self.buf, self.rate);
                    Update::update(&mut self.inner, &block_hash);
                    self.buf.clear();
                    self.n += 1;
                    pos = missing;
                }

                let rest = &data[pos..];

                #[cfg(feature = "parallelhash")]
                let tail = {
                    let blocksize = self.blocksize;
                    let rate = self.rate;
                    let (full, tail) = rest.split_at(rest.len() - rest.len() % blocksize);
                    if !full.is_empty() {
                        // Block digests are computed in parallel but absorbed in input order.
                        let hashes: Vec<Vec<u8>> = full
                            .par_chunks(blocksize)
                            .map(|chunk| Self::hash_block(chunk, rate))
                            .collect();
                        for hash in &hashes {
                            Update::update(&mut self.inner, hash);
                            self.n += 1;
                        }
                    }
                    tail
                };

                #[cfg(not(feature = "parallelhash"))]
                let tail = {
                    let blocks = rest.chunks_exact(self.blocksize);
                    let tail = blocks.remainder();
                    for block in blocks {
                        let block_hash = Self::hash_block(block, self.rate);
                        Update::update(&mut self.inner, &block_hash);
                        self.n += 1;
                    }
                    tail
                };

                // Appending an empty tail is a no-op, so no emptiness check is needed.
                self.buf.extend_from_slice(tail);
            }

            /// Feeds `data` into the hasher. May be called any number of times.
            pub fn update(&mut self, data: &[u8]) {
                self.absorb(data);
            }

            /// Finalizes the hash and fills `output` with `output.len()` bytes.
            ///
            /// Returns `None` without writing anything when `output` is longer than
            /// `MAX_SP800185_FIXED_OUTPUT_BYTES`.
            pub fn finalize(mut self, output: &mut [u8]) -> Option<()> {
                if output.len() > MAX_SP800185_FIXED_OUTPUT_BYTES {
                    return None;
                }
                self.with_bitlength((output.len() * 8) as u64);
                ExtendableOutput::finalize_xof_into(self.inner, output);
                Some(())
            }

            /// Finalizes the hash and returns `output_len` bytes of output.
            ///
            /// Returns `None` when `output_len` exceeds `MAX_SP800185_FIXED_OUTPUT_BYTES`.
            pub fn finalize_with_length(mut self, output_len: usize) -> Option<Vec<u8>> {
                if output_len > MAX_SP800185_FIXED_OUTPUT_BYTES {
                    return None;
                }
                let mut output = vec![0u8; output_len];
                self.with_bitlength((output_len * 8) as u64);
                ExtendableOutput::finalize_xof_into(self.inner, &mut output);
                Some(output)
            }

            /// Finalizes in XOF mode (output length encoded as 0) and returns a reader for an
            /// arbitrary amount of output.
            pub fn xof(mut self) -> $reader_name {
                self.with_bitlength(0);
                $reader_name {
                    inner: ExtendableOutput::finalize_xof(self.inner),
                }
            }

            /// Hashes a pending partial block, if any, then absorbs the trailer
            /// `right_encode(n) || right_encode(bitlength)` into the outer hash.
            fn with_bitlength(&mut self, bitlength: u64) {
                if !self.buf.is_empty() {
                    let block_hash = Self::hash_block(&self.buf, self.rate);
                    Update::update(&mut self.inner, &block_hash);
                    self.buf.clear();
                    self.n += 1;
                }
                let mut enc_buf = [0u8; 9];
                let encoded = right_encode(self.n, &mut enc_buf);
                Update::update(&mut self.inner, encoded);
                let length_encoded = right_encode(bitlength, &mut enc_buf);
                Update::update(&mut self.inner, length_encoded);
            }
        }

        impl BlockSizeUser for $name {
            type BlockSize = $rate;
        }

        impl BufferKindUser for $name {
            type BufferKind = Eager;
        }

        impl HashMarker for $name {}

        impl Update for $name {
            /// Same behaviour as the inherent `update`.
            #[inline]
            fn update(&mut self, data: &[u8]) {
                self.absorb(data);
            }
        }

        impl UpdateCore for $name {
            /// Feeds every block through the regular streaming path, one after another.
            #[inline]
            fn update_blocks(&mut self, blocks: &[Block<Self>]) {
                for block in blocks {
                    self.absorb(block);
                }
            }
        }

        impl Reset for $name {
            /// Resets the outer hash, drops buffered input and the block counter, and
            /// re-absorbs the encoded block size.
            #[inline]
            fn reset(&mut self) {
                self.inner.reset();
                self.buf.clear();
                self.n = 0;
                self.init();
            }
        }

        impl AlgorithmName for $name {
            fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str($alg_name)
            }
        }

        impl fmt::Debug for $name {
            /// Prints only the type name; internal state is deliberately not shown.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(concat!(stringify!($name), " { ... }"))
            }
        }

        // NOTE(review): this is a marker impl only. Nothing in this macro wipes `buf`
        // (a plain `Vec<u8>` holding unhashed input) on drop -- confirm that the marker's
        // contract is really met, or add an explicit wipe.
        #[cfg(feature = "zeroize")]
        impl digest::zeroize::ZeroizeOnDrop for $name {}

        impl Default for $name {
            /// Empty customization string and a block size of 8192 bytes.
            fn default() -> Self {
                Self::new(b"", 8192)
            }
        }

        impl XofReader for $reader_name {
            fn read(&mut self, buf: &mut [u8]) {
                self.inner.read(buf);
            }
        }

        impl fmt::Debug for $reader_name {
            /// Prints only the type name; internal state is deliberately not shown.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(concat!(stringify!($reader_name), " { ... }"))
            }
        }
    };
}
// Generates the `ParallelHash128` API and trait impls: outer hash `CShake128`,
// XOF reader `ParallelHash128Reader` (wrapping `CShake128Reader`), per-block
// hash `Shake128`, block-size type `U168` with the matching rate value 168,
// and the algorithm name "ParallelHash128".
impl_parallelhash!(
ParallelHash128,
CShake128,
ParallelHash128Reader,
CShake128Reader,
Shake128,
U168,
168,
"ParallelHash128"
);
// Generates the `ParallelHash256` API and trait impls: outer hash `CShake256`,
// XOF reader `ParallelHash256Reader` (wrapping `CShake256Reader`), per-block
// hash `Shake256`, block-size type `U136` with the matching rate value 136,
// and the algorithm name "ParallelHash256".
impl_parallelhash!(
ParallelHash256,
CShake256,
ParallelHash256Reader,
CShake256Reader,
Shake256,
U136,
136,
"ParallelHash256"
);
// Declares `U16` as the collision-resistance level of `ParallelHash128`
// (presumably 16 bytes = 128 bits -- confirm the unit against the `digest` docs).
impl CollisionResistance for ParallelHash128 {
type CollisionResistance = U16;
}
// Declares `U32` as the collision-resistance level of `ParallelHash256`
// (presumably 32 bytes = 256 bits -- confirm the unit against the `digest` docs).
impl CollisionResistance for ParallelHash256 {
type CollisionResistance = U32;
}
/// State (de)serialization for `ParallelHash128`.
///
/// Only the wrapped `CShake128` state is written out. The pending partial block
/// (`buf`), the block counter (`n`) and the configured `blocksize` are NOT part of
/// the serialized form: `deserialize` always comes back with an empty buffer, a
/// zero counter and a block size of 8192.
///
/// NOTE(review): a hasher serialized with buffered input, with blocks already
/// counted, or with a non-default block size therefore does not round-trip.
/// Confirm whether callers rely on that before using this across such states.
impl SerializableState for ParallelHash128 {
    type SerializedStateSize = U400;

    /// Serializes the inner `CShake128` state only.
    fn serialize(&self) -> SerializedState<Self> {
        let Self { inner, .. } = self;
        inner.serialize()
    }

    /// Rebuilds a hasher around the deserialized `CShake128`, with every other field
    /// set to a fixed value (empty `buf`, `n = 0`, `rate = 168`, `blocksize = 8192`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `CShake128::deserialize`.
    fn deserialize(
        serialized_state: &SerializedState<Self>,
    ) -> Result<Self, DeserializeStateError> {
        CShake128::deserialize(serialized_state).map(|inner| Self {
            inner,
            buf: Vec::new(),
            n: 0,
            rate: 168,
            blocksize: 8192,
        })
    }
}
/// State (de)serialization for `ParallelHash256`.
///
/// Only the wrapped `CShake256` state is written out. The pending partial block
/// (`buf`), the block counter (`n`) and the configured `blocksize` are NOT part of
/// the serialized form: `deserialize` always comes back with an empty buffer, a
/// zero counter and a block size of 8192.
///
/// NOTE(review): a hasher serialized with buffered input, with blocks already
/// counted, or with a non-default block size therefore does not round-trip.
/// Confirm whether callers rely on that before using this across such states.
impl SerializableState for ParallelHash256 {
    type SerializedStateSize = U400;

    /// Serializes the inner `CShake256` state only.
    fn serialize(&self) -> SerializedState<Self> {
        let Self { inner, .. } = self;
        inner.serialize()
    }

    /// Rebuilds a hasher around the deserialized `CShake256`, with every other field
    /// set to a fixed value (empty `buf`, `n = 0`, `rate = 136`, `blocksize = 8192`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `CShake256::deserialize`.
    fn deserialize(
        serialized_state: &SerializedState<Self>,
    ) -> Result<Self, DeserializeStateError> {
        CShake256::deserialize(serialized_state).map(|inner| Self {
            inner,
            buf: Vec::new(),
            n: 0,
            rate: 136,
            blocksize: 8192,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a fixed-length ParallelHash128 digest is produced and is not all zeroes.
    #[test]
    fn test_parallelhash128_basic() {
        let custom = b"custom";
        let data = b"test_data";
        let mut parallelhash = ParallelHash128::new(custom, 16);
        parallelhash.update(data);
        let mut output = [0u8; 32];
        parallelhash.finalize(&mut output).unwrap();
        assert_ne!(output, [0u8; 32]);
    }

    /// Smoke test: a fixed-length ParallelHash256 digest is produced and is not all zeroes.
    #[test]
    fn test_parallelhash256_basic() {
        let custom = b"custom";
        let data = b"test_data";
        let mut parallelhash = ParallelHash256::new(custom, 16);
        parallelhash.update(data);
        let mut output = [0u8; 64];
        parallelhash.finalize(&mut output).unwrap();
        assert_ne!(output, [0u8; 64]);
    }

    /// Smoke test for the XOF reader path.
    #[test]
    fn test_parallelhash_xof() {
        let custom = b"custom";
        let data = b"test_data";
        let mut parallelhash = ParallelHash128::new(custom, 16);
        parallelhash.update(data);
        let mut reader = parallelhash.xof();
        let mut output = [0u8; 100];
        reader.read(&mut output);
        assert_ne!(output, [0u8; 100]);
    }

    /// The block size is part of the hashed input, so changing it must change the digest.
    #[test]
    fn test_parallelhash_different_block_sizes() {
        let custom = b"custom";
        let data = b"test_data_that_is_long_enough";
        let mut parallelhash1 = ParallelHash128::new(custom, 8);
        parallelhash1.update(data);
        let mut output1 = [0u8; 32];
        parallelhash1.finalize(&mut output1).unwrap();
        let mut parallelhash2 = ParallelHash128::new(custom, 16);
        parallelhash2.update(data);
        let mut output2 = [0u8; 32];
        parallelhash2.finalize(&mut output2).unwrap();
        assert_ne!(output1, output2);
    }

    /// Different customization strings must give different digests.
    #[test]
    fn test_parallelhash_different_customs() {
        let data = b"test_data";
        let mut parallelhash1 = ParallelHash128::new(b"custom1", 16);
        parallelhash1.update(data);
        let mut output1 = [0u8; 32];
        parallelhash1.finalize(&mut output1).unwrap();
        let mut parallelhash2 = ParallelHash128::new(b"custom2", 16);
        parallelhash2.update(data);
        let mut output2 = [0u8; 32];
        parallelhash2.finalize(&mut output2).unwrap();
        assert_ne!(output1, output2);
    }

    /// Exercises the rayon-backed path on 1 MiB of input (128 blocks of 8192 bytes).
    #[test]
    #[cfg(feature = "parallelhash")]
    fn test_parallelhash_performance_comparison() {
        let large_data: Vec<u8> = (0..1024 * 1024).map(|i| (i % 256) as u8).collect();
        let custom = b"performance_test";
        let block_size = 8192;
        let mut parallelhash = ParallelHash128::new(custom, block_size);
        parallelhash.update(&large_data);
        let mut output = [0u8; 64];
        parallelhash.finalize(&mut output).unwrap();
        assert_ne!(output, [0u8; 64]);
    }

    /// A hasher stays usable after `reset`.
    #[test]
    fn test_parallelhash_reset() {
        let custom = b"custom";
        let data = b"test_data";
        let mut parallelhash = ParallelHash128::new(custom, 16);
        parallelhash.update(data);
        parallelhash.reset();
        parallelhash.update(data);
        let mut output = [0u8; 32];
        parallelhash.finalize(&mut output).unwrap();
        assert_ne!(output, [0u8; 32]);
    }

    /// `Default` yields a working hasher and `finalize_with_length` honours the length.
    #[test]
    fn test_parallelhash_default() {
        let parallelhash = ParallelHash128::default();
        let data = b"test_data";
        let mut hasher = parallelhash;
        hasher.update(data);
        let result = hasher.finalize_with_length(32).unwrap();
        assert_eq!(result.len(), 32);
    }

    /// A deserialized hasher accepts more input and finalizes.
    ///
    /// This only checks that the round trip yields a usable hasher; it does not check
    /// that the resumed digest equals the digest of an uninterrupted run.
    #[test]
    fn test_parallelhash_serialization() {
        let custom = b"custom";
        let data = b"test_data";
        let mut parallelhash = ParallelHash128::new(custom, 16);
        parallelhash.update(data);
        let serialized = parallelhash.serialize();
        let mut parallelhash2 = ParallelHash128::deserialize(&serialized).unwrap();
        parallelhash2.update(b"more_data");
        let mut output = [0u8; 32];
        parallelhash2.finalize(&mut output).unwrap();
        assert_ne!(output, [0u8; 32]);
    }

    /// Requesting one byte more than the fixed-output cap is rejected.
    #[test]
    fn test_parallelhash_finalize_with_length_rejects_over_cap() {
        let mut h = ParallelHash128::new(b"", 16);
        h.update(b"x");
        assert!(
            h.finalize_with_length(MAX_SP800185_FIXED_OUTPUT_BYTES + 1)
                .is_none()
        );
    }

    /// An output buffer one byte longer than the fixed-output cap is rejected.
    #[test]
    fn test_parallelhash_finalize_rejects_over_cap_buffer() {
        let mut h = ParallelHash128::new(b"", 16);
        h.update(b"x");
        let mut out = vec![0u8; MAX_SP800185_FIXED_OUTPUT_BYTES + 1];
        assert!(h.finalize(&mut out).is_none());
    }

    /// Feeding the input in two pieces, split at every possible position, must give the
    /// same digest as feeding it in one call. Covers the partial-block buffering paths.
    #[test]
    fn test_parallelhash_split_update_matches_one_shot() {
        let data: Vec<u8> = (0..100u8).collect();
        let mut one_shot = ParallelHash128::new(b"custom", 16);
        one_shot.update(&data);
        let expected = one_shot.finalize_with_length(32).unwrap();

        for split in 0..=data.len() {
            let (head, tail) = data.split_at(split);
            let mut hasher = ParallelHash128::new(b"custom", 16);
            hasher.update(head);
            hasher.update(tail);
            assert_eq!(
                hasher.finalize_with_length(32).unwrap(),
                expected,
                "split at {}",
                split
            );
        }
    }

    /// `finalize` and `finalize_with_length` must agree for the same output length.
    #[test]
    fn test_parallelhash_finalize_matches_finalize_with_length() {
        let mut hasher = ParallelHash256::new(b"custom", 16);
        hasher.update(b"test_data");
        let twin = hasher.clone();
        let mut output = [0u8; 48];
        hasher.finalize(&mut output).unwrap();
        assert_eq!(twin.finalize_with_length(48).unwrap(), output);
    }
}