use crate::keccak_scalar::keccak256 as keccak256_scalar;
/// Sponge rate of Keccak-256 in bytes: (1600 - 2 * 256) / 8 = 136.
const RATE: usize = 136;

/// Iota-step round constants for the 24 rounds of Keccak-f[1600],
/// written with digit separators to make the sparse bit patterns readable.
const RC: [u64; 24] = [
    0x0000_0000_0000_0001,
    0x0000_0000_0000_8082,
    0x8000_0000_0000_808a,
    0x8000_0000_8000_8000,
    0x0000_0000_0000_808b,
    0x0000_0000_8000_0001,
    0x8000_0000_8000_8081,
    0x8000_0000_0000_8009,
    0x0000_0000_0000_008a,
    0x0000_0000_0000_0088,
    0x0000_0000_8000_8009,
    0x0000_0000_8000_000a,
    0x0000_0000_8000_808b,
    0x8000_0000_0000_008b,
    0x8000_0000_0000_8089,
    0x8000_0000_0000_8003,
    0x8000_0000_0000_8002,
    0x8000_0000_0000_0080,
    0x0000_0000_0000_800a,
    0x8000_0000_8000_000a,
    0x8000_0000_8000_8081,
    0x8000_0000_0000_8080,
    0x0000_0000_8000_0001,
    0x8000_0000_8000_8008,
];
#[cfg(target_arch = "x86_64")]
mod avx512 {
    //! Eight-way interleaved Keccak-256 built on AVX-512 intrinsics.
    //!
    //! The state is stored lane-major: `State[k]` holds Keccak lane `k` for
    //! all eight sponges, and 64-bit element `i` of each vector belongs to
    //! input stream `i`. Every `unsafe fn` in this module expects the CPU to
    //! support AVX-512F and AVX-512BW.

    // Nearly every statement below is an intrinsic call inside an `unsafe fn`
    // whose only precondition is CPU support; per-call `unsafe` blocks would
    // add noise without documenting anything new.
    #![allow(unsafe_op_in_unsafe_fn)]

    use super::{RATE, RC};
    use std::arch::x86_64::*;

    /// Vectorised Keccak-f[1600] state: 25 lanes, each holding eight streams.
    type State = [__m512i; 25];

    /// Number of 64-bit lanes covered by one rate-sized block.
    const RATE_LANES: usize = RATE / 8;

    /// Rotates every 64-bit element of `$v` left by the literal `$n`; a
    /// rotation by `0` expands to `$v` itself so no instruction is emitted.
    macro_rules! rol {
        ($v:expr, 0) => {
            $v
        };
        ($v:expr, $n:literal) => {
            _mm512_rol_epi64($v, $n)
        };
    }

    /// XOR of five vectors, used for the theta-step column parities.
    #[inline(always)]
    unsafe fn xor5(a: __m512i, b: __m512i, c: __m512i, d: __m512i, e: __m512i) -> __m512i {
        _mm512_xor_si512(
            _mm512_xor_si512(a, b),
            _mm512_xor_si512(_mm512_xor_si512(c, d), e),
        )
    }

    /// Applies the 24-round Keccak-f[1600] permutation to all eight states.
    ///
    /// Lanes are indexed `x + 5 * y`; the steps are fully unrolled over lanes.
    #[target_feature(enable = "avx512f,avx512bw")]
    unsafe fn permute(s: &mut State) {
        for &rc in RC.iter() {
            // theta: column parities C[x] and D[x] = C[x-1] ^ rot(C[x+1], 1).
            let c0 = xor5(s[0], s[5], s[10], s[15], s[20]);
            let c1 = xor5(s[1], s[6], s[11], s[16], s[21]);
            let c2 = xor5(s[2], s[7], s[12], s[17], s[22]);
            let c3 = xor5(s[3], s[8], s[13], s[18], s[23]);
            let c4 = xor5(s[4], s[9], s[14], s[19], s[24]);
            let d0 = _mm512_xor_si512(c4, _mm512_rol_epi64(c1, 1));
            let d1 = _mm512_xor_si512(c0, _mm512_rol_epi64(c2, 1));
            let d2 = _mm512_xor_si512(c1, _mm512_rol_epi64(c3, 1));
            let d3 = _mm512_xor_si512(c2, _mm512_rol_epi64(c4, 1));
            let d4 = _mm512_xor_si512(c3, _mm512_rol_epi64(c0, 1));
            s[0] = _mm512_xor_si512(s[0], d0);
            s[5] = _mm512_xor_si512(s[5], d0);
            s[10] = _mm512_xor_si512(s[10], d0);
            s[15] = _mm512_xor_si512(s[15], d0);
            s[20] = _mm512_xor_si512(s[20], d0);
            s[1] = _mm512_xor_si512(s[1], d1);
            s[6] = _mm512_xor_si512(s[6], d1);
            s[11] = _mm512_xor_si512(s[11], d1);
            s[16] = _mm512_xor_si512(s[16], d1);
            s[21] = _mm512_xor_si512(s[21], d1);
            s[2] = _mm512_xor_si512(s[2], d2);
            s[7] = _mm512_xor_si512(s[7], d2);
            s[12] = _mm512_xor_si512(s[12], d2);
            s[17] = _mm512_xor_si512(s[17], d2);
            s[22] = _mm512_xor_si512(s[22], d2);
            s[3] = _mm512_xor_si512(s[3], d3);
            s[8] = _mm512_xor_si512(s[8], d3);
            s[13] = _mm512_xor_si512(s[13], d3);
            s[18] = _mm512_xor_si512(s[18], d3);
            s[23] = _mm512_xor_si512(s[23], d3);
            s[4] = _mm512_xor_si512(s[4], d4);
            s[9] = _mm512_xor_si512(s[9], d4);
            s[14] = _mm512_xor_si512(s[14], d4);
            s[19] = _mm512_xor_si512(s[19], d4);
            s[24] = _mm512_xor_si512(s[24], d4);

            // rho + pi: B[y + 5 * ((2x + 3y) % 5)] = rot(A[x + 5y], r[x + 5y]).
            let b0 = rol!(s[0], 0);
            let b1 = rol!(s[6], 44);
            let b2 = rol!(s[12], 43);
            let b3 = rol!(s[18], 21);
            let b4 = rol!(s[24], 14);
            let b5 = rol!(s[3], 28);
            let b6 = rol!(s[9], 20);
            let b7 = rol!(s[10], 3);
            let b8 = rol!(s[16], 45);
            let b9 = rol!(s[22], 61);
            let b10 = rol!(s[1], 1);
            let b11 = rol!(s[7], 6);
            let b12 = rol!(s[13], 25);
            let b13 = rol!(s[19], 8);
            let b14 = rol!(s[20], 18);
            let b15 = rol!(s[4], 27);
            let b16 = rol!(s[5], 36);
            let b17 = rol!(s[11], 10);
            let b18 = rol!(s[17], 15);
            let b19 = rol!(s[23], 56);
            let b20 = rol!(s[2], 62);
            let b21 = rol!(s[8], 55);
            let b22 = rol!(s[14], 39);
            let b23 = rol!(s[15], 41);
            let b24 = rol!(s[21], 2);

            // chi: A[x] = B[x] ^ (!B[x+1] & B[x+2]) within each row;
            // `_mm512_andnot_si512(a, b)` computes `!a & b`.
            s[0] = _mm512_xor_si512(b0, _mm512_andnot_si512(b1, b2));
            s[1] = _mm512_xor_si512(b1, _mm512_andnot_si512(b2, b3));
            s[2] = _mm512_xor_si512(b2, _mm512_andnot_si512(b3, b4));
            s[3] = _mm512_xor_si512(b3, _mm512_andnot_si512(b4, b0));
            s[4] = _mm512_xor_si512(b4, _mm512_andnot_si512(b0, b1));
            s[5] = _mm512_xor_si512(b5, _mm512_andnot_si512(b6, b7));
            s[6] = _mm512_xor_si512(b6, _mm512_andnot_si512(b7, b8));
            s[7] = _mm512_xor_si512(b7, _mm512_andnot_si512(b8, b9));
            s[8] = _mm512_xor_si512(b8, _mm512_andnot_si512(b9, b5));
            s[9] = _mm512_xor_si512(b9, _mm512_andnot_si512(b5, b6));
            s[10] = _mm512_xor_si512(b10, _mm512_andnot_si512(b11, b12));
            s[11] = _mm512_xor_si512(b11, _mm512_andnot_si512(b12, b13));
            s[12] = _mm512_xor_si512(b12, _mm512_andnot_si512(b13, b14));
            s[13] = _mm512_xor_si512(b13, _mm512_andnot_si512(b14, b10));
            s[14] = _mm512_xor_si512(b14, _mm512_andnot_si512(b10, b11));
            s[15] = _mm512_xor_si512(b15, _mm512_andnot_si512(b16, b17));
            s[16] = _mm512_xor_si512(b16, _mm512_andnot_si512(b17, b18));
            s[17] = _mm512_xor_si512(b17, _mm512_andnot_si512(b18, b19));
            s[18] = _mm512_xor_si512(b18, _mm512_andnot_si512(b19, b15));
            s[19] = _mm512_xor_si512(b19, _mm512_andnot_si512(b15, b16));
            s[20] = _mm512_xor_si512(b20, _mm512_andnot_si512(b21, b22));
            s[21] = _mm512_xor_si512(b21, _mm512_andnot_si512(b22, b23));
            s[22] = _mm512_xor_si512(b22, _mm512_andnot_si512(b23, b24));
            s[23] = _mm512_xor_si512(b23, _mm512_andnot_si512(b24, b20));
            s[24] = _mm512_xor_si512(b24, _mm512_andnot_si512(b20, b21));

            // iota: the round constant only touches lane 0. `as i64` is a
            // bit-for-bit reinterpretation of the u64 constant.
            s[0] = _mm512_xor_si512(s[0], _mm512_set1_epi64(rc as i64));
        }
    }

    /// XORs one rate-sized block per stream into the state, then permutes.
    ///
    /// `blocks[i]` is absorbed into stream `i` (64-bit element `i`).
    #[inline]
    #[target_feature(enable = "avx512f,avx512bw")]
    unsafe fn absorb_block(state: &mut State, blocks: [&[u8; RATE]; 8]) {
        for lane in 0..RATE_LANES {
            let off = lane * 8;
            let word = |stream: usize| -> i64 {
                i64::from_le_bytes(blocks[stream][off..off + 8].try_into().unwrap())
            };
            // `_mm512_set_epi64` takes its arguments from the highest element
            // down, so stream 7 is passed first and stream 0 lands in element 0.
            let v = _mm512_set_epi64(
                word(7),
                word(6),
                word(5),
                word(4),
                word(3),
                word(2),
                word(1),
                word(0),
            );
            state[lane] = _mm512_xor_si512(state[lane], v);
        }
        permute(state);
    }

    /// Extracts the 32-byte digest (first four lanes, little-endian) of each
    /// of the eight streams.
    #[inline]
    #[target_feature(enable = "avx512f,avx512bw")]
    unsafe fn squeeze(state: &State) -> [[u8; 32]; 8] {
        let mut out = [[0u8; 32]; 8];
        let mut tmp = [0u64; 8];
        for (lane, &v) in state[..4].iter().enumerate() {
            _mm512_storeu_si512(tmp.as_mut_ptr() as *mut __m512i, v);
            let off = lane * 8;
            for (digest, word) in out.iter_mut().zip(tmp.iter()) {
                digest[off..off + 8].copy_from_slice(&word.to_le_bytes());
            }
        }
        out
    }

    /// Writes Keccak `pad10*1` padding (domain byte `0x01`) into `buf`.
    ///
    /// `buf[..offset]` holds the message tail and is left untouched; every
    /// byte from `offset` on is overwritten. When `offset == RATE - 1` the
    /// start and end markers share the final byte, which becomes `0x81`.
    ///
    /// # Panics
    /// Panics if `offset >= RATE`.
    fn apply_padding(buf: &mut [u8; RATE], offset: usize) {
        debug_assert!(offset < RATE, "padding offset {offset} must be below the rate");
        buf[offset..].fill(0);
        buf[offset] = 0x01;
        // OR rather than assign, so a shared final byte keeps its 0x01 bit.
        buf[RATE - 1] |= 0x80;
    }

    /// De-interleaves the vector state into eight independent scalar states.
    ///
    /// Element `i` of the result is the 25-lane state of stream `i`.
    #[target_feature(enable = "avx512f,avx512bw")]
    unsafe fn transpose_to_scalar(state: &State) -> [[u64; 25]; 8] {
        let mut out = [[0u64; 25]; 8];
        let mut tmp = [0u64; 8];
        for (lane, &v) in state.iter().enumerate() {
            _mm512_storeu_si512(tmp.as_mut_ptr() as *mut __m512i, v);
            for (scalar, &word) in out.iter_mut().zip(tmp.iter()) {
                scalar[lane] = word;
            }
        }
        out
    }

    /// XORs a `RATE`-byte block, read as little-endian u64 words, into the
    /// first `RATE_LANES` lanes of a scalar state.
    fn xor_block_into(state: &mut [u64; 25], block: &[u8]) {
        debug_assert_eq!(block.len(), RATE);
        for (lane, word) in state.iter_mut().zip(block.chunks_exact(8)) {
            *lane ^= u64::from_le_bytes(word.try_into().unwrap());
        }
    }

    /// Continues a single sponge from `state`: absorbs `remaining`, pads,
    /// and returns the Keccak-256 digest.
    fn finish_scalar(mut state: [u64; 25], remaining: &[u8]) -> [u8; 32] {
        let mut blocks = remaining.chunks_exact(RATE);
        for block in &mut blocks {
            xor_block_into(&mut state, block);
            keccak_f1600_scalar(&mut state);
        }

        // The final block always exists. It is pure padding when `remaining`
        // is a multiple of RATE (including empty).
        let tail = blocks.remainder();
        let mut last = [0u8; RATE];
        last[..tail.len()].copy_from_slice(tail);
        apply_padding(&mut last, tail.len());
        xor_block_into(&mut state, &last);
        keccak_f1600_scalar(&mut state);

        let mut digest = [0u8; 32];
        for (bytes, lane) in digest.chunks_exact_mut(8).zip(state.iter()) {
            bytes.copy_from_slice(&lane.to_le_bytes());
        }
        digest
    }

    /// Portable 24-round Keccak-f[1600] permutation on a single state.
    fn keccak_f1600_scalar(state: &mut [u64; 25]) {
        /// Rho rotation offsets, indexed by `x + 5 * y`.
        const ROTATIONS: [u32; 25] = [
            0, 1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61,
            56, 14,
        ];
        for &rc in RC.iter() {
            // theta
            let mut c = [0u64; 5];
            for x in 0..5 {
                c[x] = state[x] ^ state[x + 5] ^ state[x + 10] ^ state[x + 15] ^ state[x + 20];
            }
            let mut d = [0u64; 5];
            for x in 0..5 {
                d[x] = c[(x + 4) % 5] ^ c[(x + 1) % 5].rotate_left(1);
            }
            for y in 0..5 {
                for x in 0..5 {
                    state[x + 5 * y] ^= d[x];
                }
            }
            // rho + pi
            let mut b = [0u64; 25];
            for y in 0..5usize {
                for x in 0..5usize {
                    b[y + 5 * ((2 * x + 3 * y) % 5)] =
                        state[x + 5 * y].rotate_left(ROTATIONS[x + 5 * y]);
                }
            }
            // chi
            for y in 0..5 {
                for x in 0..5 {
                    state[x + 5 * y] =
                        b[x + 5 * y] ^ ((!b[(x + 1) % 5 + 5 * y]) & b[(x + 2) % 5 + 5 * y]);
                }
            }
            // iota
            state[0] ^= rc;
        }
    }

    /// Hashes eight inputs with Keccak-256. The permutations are shared
    /// across all eight streams for as long as they stay in lock-step.
    ///
    /// # Safety
    /// The CPU must support AVX-512F and AVX-512BW.
    #[target_feature(enable = "avx512f,avx512bw")]
    pub(super) unsafe fn keccak256_batch_impl(inputs: [&[u8]; 8]) -> [[u8; 32]; 8] {
        let mut state: State = [_mm512_setzero_si512(); 25];

        // Every stream has at least `shared_full_blocks` full blocks, so these
        // are absorbed for all eight lanes at once.
        let min_len = inputs.iter().map(|s| s.len()).min().unwrap_or(0);
        let shared_full_blocks = min_len / RATE;
        for b in 0..shared_full_blocks {
            let off = b * RATE;
            absorb_block(
                &mut state,
                std::array::from_fn(|i| inputs[i][off..off + RATE].try_into().unwrap()),
            );
        }

        let rems: [&[u8]; 8] = std::array::from_fn(|i| &inputs[i][shared_full_blocks * RATE..]);

        if rems.iter().all(|r| r.len() < RATE) {
            // Each stream needs exactly one more (padded) block. Lanes are
            // padded independently, so differing tail lengths can still share
            // the final vector permutation.
            let mut pads = [[0u8; RATE]; 8];
            for (pad, rem) in pads.iter_mut().zip(rems.iter().copied()) {
                pad[..rem.len()].copy_from_slice(rem);
                apply_padding(pad, rem.len());
            }
            absorb_block(&mut state, std::array::from_fn(|i| &pads[i]));
            squeeze(&state)
        } else {
            // Streams need different numbers of further blocks. Each one is
            // finished with the scalar permutation, starting from its current
            // lane state.
            let lanes = transpose_to_scalar(&state);
            std::array::from_fn(|i| finish_scalar(lanes[i], rems[i]))
        }
    }
}
/// Computes Keccak-256 digests of eight inputs at once; output `i` is the
/// digest of `inputs[i]`.
///
/// On x86_64 CPUs that report AVX-512F and AVX-512BW, the eight sponges run
/// in parallel SIMD lanes. Otherwise, including on every non-x86_64 target,
/// each input is hashed with the scalar implementation.
pub fn keccak256_batch(inputs: [&[u8]; 8]) -> [[u8; 32]; 8] {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
            // SAFETY: `keccak256_batch_impl` only requires AVX-512F/BW, which
            // was confirmed at runtime just above.
            return unsafe { avx512::keccak256_batch_impl(inputs) };
        }
    }
    std::array::from_fn(|i| keccak256_scalar(inputs[i]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keccak_scalar::keccak256 as keccak256_scalar;

    /// Renders `bytes` as lower-case hex for readable assertion output.
    fn hex(bytes: &[u8]) -> String {
        bytes
            .iter()
            .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
                acc.push_str(&format!("{byte:02x}"));
                acc
            })
    }

    /// Eight empty inputs must all hash to the well-known Keccak-256("") value.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_batch_matches_scalar_empty() {
        const EMPTY_KECCAK256: &str =
            "c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470";
        let empties: [&[u8]; 8] = [b""; 8];
        keccak256_batch(empties)
            .iter()
            .for_each(|digest| assert_eq!(hex(digest), EMPTY_KECCAK256));
    }

    /// Mixed lengths, including ones around the 136-byte rate, must match
    /// the scalar implementation stream by stream.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_batch_matches_scalar_various() {
        let msgs: [&[u8]; 8] = [
            b"",
            b"abc",
            b"hello, world",
            b"Transfer(address,address,uint256)",
            &[0u8; 64],
            &[0u8; 136],
            &[0u8; 137],
            b"Ethereum",
        ];
        let batch = keccak256_batch(msgs);
        for (i, (digest, msg)) in batch.iter().zip(msgs.iter().copied()).enumerate() {
            let expected = keccak256_scalar(msg);
            assert_eq!(
                digest,
                &expected,
                "stream {i} mismatch: got {} expected {}",
                hex(digest),
                hex(&expected)
            );
        }
    }

    /// Eight equal-length (64-byte) inputs with distinct contents; this takes
    /// the uniform-length path when AVX-512 is available.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_batch_uniform_64bytes() {
        let bufs: [[u8; 64]; 8] = std::array::from_fn(|i| [i as u8; 64]);
        let inputs: [&[u8]; 8] = std::array::from_fn(|i| &bufs[i][..]);
        let batch = keccak256_batch(inputs);
        for (i, (digest, input)) in batch.iter().zip(inputs.iter().copied()).enumerate() {
            assert_eq!(digest, &keccak256_scalar(input), "stream {i} mismatch");
        }
    }
}