#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use blake3::Hasher as Blake3;
use sha2::{Digest, Sha256};
use crate::params::*;
use crate::primitives::{aes_compress, blake3_compress, sha256_compress};
/// Mask that reduces a mixed 64-bit value to a block index for the scratchpad.
/// NOTE(review): this yields every index in `0..BLOCKS_PER_SCRATCHPAD` only if
/// `BLOCKS_PER_SCRATCHPAD` is a power of two — confirm in `params`.
const ADDRESS_MASK: usize = BLOCKS_PER_SCRATCHPAD - 1;
/// 64-bit golden-ratio constant (floor(2^64 / φ)); multiplied by the chain index to derive a
/// distinct per-chain offset that is XORed into the nonce during scratchpad seeding.
const GOLDEN_RATIO: u64 = 0x9E3779B97F4A7C15;
/// Base offset into the three-primitive rotation for `chain`.
///
/// Computed as `(nonce + chain) mod 3` using wrapping 64-bit addition. The per-round primitive
/// is then `(base + round + 1) % 3` in `round_step_spec_compliant`.
#[inline(always)]
fn initial_primitive_index(nonce: u64, chain: usize) -> usize {
    let mixed = nonce.wrapping_add(chain as u64);
    (mixed % 3) as usize
}
pub struct UniversalHash {
scratchpads: Vec<Vec<u8>>,
chain_states: [[u8; 32]; CHAINS],
effective_nonce: u64,
}
impl UniversalHash {
pub fn new() -> Self {
Self {
scratchpads: vec![vec![0u8; SCRATCHPAD_SIZE]; CHAINS],
chain_states: [[0u8; 32]; CHAINS],
effective_nonce: 0,
}
}
pub fn hash(&mut self, input: &[u8]) -> [u8; 32] {
self.effective_nonce = extract_nonce(input);
self.init_scratchpads(input);
self.execute_rounds();
self.finalize()
}
fn init_scratchpads(&mut self, input: &[u8]) {
let nonce = self.effective_nonce;
let header_len = input.len().saturating_sub(8);
for (chain, state) in self.chain_states.iter_mut().enumerate() {
let offset = (chain as u64).wrapping_mul(GOLDEN_RATIO);
let modified_nonce = nonce ^ offset;
let mut hasher = Blake3::new();
hasher.update(&input[..header_len]);
hasher.update(&modified_nonce.to_le_bytes());
let hash = hasher.finalize();
let hash_bytes = hash.as_bytes();
state.copy_from_slice(hash_bytes);
let mut seed_array = [0u8; 32];
seed_array.copy_from_slice(hash_bytes);
fill_scratchpad_aes(&mut self.scratchpads[chain], &seed_array);
}
}
fn execute_rounds(&mut self) {
let nonce = self.effective_nonce;
for chain in 0..CHAINS {
let initial_primitive = initial_primitive_index(nonce, chain);
for round in 0..ROUNDS {
round_step_spec_compliant(
&mut self.scratchpads[chain],
&mut self.chain_states[chain],
initial_primitive,
round,
);
}
}
}
fn finalize(&self) -> [u8; 32] {
let mut combined = [0u8; 32];
for state in &self.chain_states {
for i in 0..32 {
combined[i] ^= state[i];
}
}
let sha_hash = Sha256::digest(combined);
let mut hasher = Blake3::new();
hasher.update(&sha_hash);
hasher.finalize().into()
}
}
/// Returns the 64-bit nonce for `input`.
///
/// If `input` is at least 8 bytes long, this is its last 8 bytes decoded as little-endian.
/// Otherwise it is the first 8 bytes of BLAKE3(`input`), also decoded as little-endian.
#[inline(always)]
fn extract_nonce(input: &[u8]) -> u64 {
    let raw: [u8; 8] = match input.len().checked_sub(8) {
        Some(start) => input[start..].try_into().unwrap(),
        None => blake3::hash(input).as_bytes()[..8].try_into().unwrap(),
    };
    u64::from_le_bytes(raw)
}
/// Fills the first `BLOCKS_PER_SCRATCHPAD * BLOCK_SIZE` bytes of `scratchpad` with AES output.
///
/// `seed[0..16]` is expanded into AES-128 round keys, and `seed[16..32]` is the starting state.
/// Each 16-byte lane, in order, receives the next value of the running state, which is advanced
/// by `aes_expand_block` before every lane.
///
/// # Panics
///
/// Panics if `scratchpad` is shorter than `BLOCKS_PER_SCRATCHPAD * BLOCK_SIZE` bytes.
fn fill_scratchpad_aes(scratchpad: &mut [u8], seed: &[u8; 32]) {
    use crate::primitives::{aes128_key_expand, aes_expand_block};
    // A lane is one 16-byte AES block. A BLOCK_SIZE that is not a whole number of lanes would
    // leave bytes unfilled, so reject it at compile time.
    const _: () = assert!(BLOCK_SIZE % 16 == 0, "BLOCK_SIZE must be a multiple of 16");
    let raw_key: [u8; 16] = seed[0..16].try_into().unwrap();
    let round_keys = aes128_key_expand(&raw_key);
    let mut state: [u8; 16] = seed[16..32].try_into().unwrap();
    // Walk the region lane by lane instead of hard-coding four lanes per block. This fills
    // every byte of every block for any BLOCK_SIZE, and gives byte-identical output when
    // BLOCK_SIZE == 64.
    for lane in scratchpad[..BLOCKS_PER_SCRATCHPAD * BLOCK_SIZE].chunks_exact_mut(16) {
        state = aes_expand_block(&state, &round_keys);
        lane.copy_from_slice(&state);
    }
}
/// Performs one round for a single chain.
///
/// 1. Reads the `BLOCK_SIZE`-byte block at the offset given by `compute_address`.
/// 2. Compresses it with `state` using primitive `(initial_primitive + round + 1) % 3`
///    (0 = AES, 1 = SHA-256, otherwise BLAKE3).
/// 3. Writes the 32-byte result over the first 32 bytes of that block.
/// 4. Stores the result as the new `state`.
///
/// # Panics
///
/// Panics if the computed block (or the 32-byte write-back) lies outside `scratchpad`.
#[inline(always)]
fn round_step_spec_compliant(
    scratchpad: &mut [u8],
    state: &mut [u8; 32],
    initial_primitive: usize,
    round: usize,
) {
    let addr = compute_address(state, round);
    // Bounds-checked slicing keeps this safe fn sound. An out-of-range address panics
    // instead of reading or writing past the end of the caller's buffer.
    let block: [u8; BLOCK_SIZE] = scratchpad[addr..addr + BLOCK_SIZE]
        .try_into()
        .expect("slice length equals BLOCK_SIZE");
    let primitive = (initial_primitive + round + 1) % 3;
    let new_state = match primitive {
        0 => aes_compress(state, &block),
        1 => sha256_compress(state, &block),
        _ => blake3_compress(state, &block),
    };
    scratchpad[addr..addr + 32].copy_from_slice(&new_state);
    *state = new_state;
}
/// Derives the byte offset of the scratchpad block visited in `round`.
///
/// Bytes `0..8` and `8..16` of `state` are decoded as little-endian `u64`s. They are XORed
/// together and with two functions of the round number. The result is masked with
/// `ADDRESS_MASK` and scaled by `BLOCK_SIZE`, so the offset is always a multiple of
/// `BLOCK_SIZE`.
#[inline(always)]
fn compute_address(state: &[u8; 32], round: usize) -> usize {
    const MIXING_CONSTANT: u64 = 0x517cc1b727220a95;
    // Decode with an explicit byte order rather than a native-endian raw read, so every
    // platform selects the same blocks. Native order would make big-endian hosts produce
    // different hashes. Little-endian matches this module's nonce encoding.
    let state_lo = u64::from_le_bytes(state[0..8].try_into().expect("8-byte slice"));
    let state_hi = u64::from_le_bytes(state[8..16].try_into().expect("8-byte slice"));
    let round_u64 = round as u64;
    let mixed =
        state_lo ^ state_hi ^ round_u64.rotate_left(13) ^ round_u64.wrapping_mul(MIXING_CONSTANT);
    ((mixed as usize) & ADDRESS_MASK) * BLOCK_SIZE
}
impl Default for UniversalHash {
fn default() -> Self {
Self::new()
}
}
/// One-shot convenience wrapper that hashes `input` with a freshly created [`UniversalHash`].
///
/// Every call allocates new scratchpads. To hash many inputs, keep one `UniversalHash` and
/// call its `hash` method instead.
pub fn hash(input: &[u8]) -> [u8; 32] {
    UniversalHash::new().hash(input)
}