crate::ix!();
use core::mem::size_of;
/// Incremental SHA3-256 hasher built on the Keccak-f[1600] sponge.
///
/// Feed input with `write` and obtain the 32-byte digest with `finalize`.
/// `Clone` allows forking a hasher part-way through a message (for example to
/// hash several messages that share a common prefix without re-absorbing it).
#[derive(Clone, Debug, Default)]
pub struct SHA3_256 {
    /// Keccak-f[1600] state: 25 lanes of 64 bits.
    state: [u64; 25],
    /// Bytes of a lane that is not yet complete, waiting for more input.
    buffer: Sha3_256Buffer,
    /// Number of valid bytes at the start of `buffer` (kept below 8 by `write`).
    bufsize: u32,
    /// Index of the next rate lane of `state` to absorb into
    /// (wrapped back to 0 by `write` when it reaches `SHA3_256_RATE_BUFFERS`).
    pos: u32,
}

/// Staging buffer holding one 64-bit lane's worth of input bytes.
pub type Sha3_256Buffer = [u8; 8];

/// SHA3-256 rate in bits, i.e. how much input is absorbed per permutation
/// (1600 - 2 * 256).
pub const SHA3_256_RATE_BITS: usize = 1088;

/// Number of 8-byte lanes that make up the rate (1088 / 64 = 17).
pub const SHA3_256_RATE_BUFFERS: usize = SHA3_256_RATE_BITS / (8 * size_of::<Sha3_256Buffer>());

// The rate must be a whole number of lanes; otherwise the division above would
// silently truncate. Enforced at compile time with std only.
const _: () = assert!(SHA3_256_RATE_BITS % (8 * size_of::<Sha3_256Buffer>()) == 0);

/// Length in bytes of a SHA3-256 digest.
pub const SHA3_256_OUTPUT_SIZE: usize = 32;
impl SHA3_256 {
    /// Absorbs `data` into the sponge and returns `self` for chaining.
    ///
    /// Input is consumed as 8-byte little-endian lanes.
    /// - A lane left partially filled by an earlier call is completed first.
    /// - Whole lanes are then XORed straight from `data`.
    /// - Any trailing bytes (fewer than 8) are kept in `buffer` for the next
    ///   call or for `finalize`.
    ///
    /// The permutation runs every time `SHA3_256_RATE_BUFFERS` lanes have been
    /// absorbed.
    pub fn write(&mut self, data: &[u8]) -> &mut SHA3_256 {
        const LANE_BYTES: usize = size_of::<Sha3_256Buffer>();

        let mut input = data;
        let pending = self.bufsize as usize;

        // Top up a partially filled lane, but only if this call can complete it.
        if pending != 0 && pending + input.len() >= LANE_BYTES {
            let (head, tail) = input.split_at(LANE_BYTES - pending);
            self.buffer[pending..].copy_from_slice(head);
            input = tail;
            self.bufsize = 0;
            let completed = u64::from_le_bytes(self.buffer);
            self.absorb_lane(completed);
        }

        // Absorb every whole lane directly from the caller's slice.
        let lanes = input.chunks_exact(LANE_BYTES);
        let leftover = lanes.remainder();
        for chunk in lanes {
            self.absorb_lane(u64::from_le_bytes(chunk.try_into().unwrap()));
        }

        // Stash the sub-lane tail until more input (or finalize) arrives.
        if !leftover.is_empty() {
            let start = self.bufsize as usize;
            self.buffer[start..start + leftover.len()].copy_from_slice(leftover);
            self.bufsize += leftover.len() as u32;
        }

        self
    }

    /// XORs one lane into the state at `pos`, then advances `pos`.
    /// When the rate is full, it permutes the state and wraps `pos` back to 0.
    fn absorb_lane(&mut self, lane: u64) {
        self.state[self.pos as usize] ^= lane;
        self.pos += 1;
        if self.pos as usize == SHA3_256_RATE_BUFFERS {
            keccakf(&mut self.state);
            self.pos = 0;
        }
    }

    /// Pads the absorbed message, runs the final permutation, and writes the
    /// 32-byte digest into `output`.
    ///
    /// The hasher is not reset afterwards; call `reset` before reusing it.
    ///
    /// # Panics
    /// Panics if `output.len() != SHA3_256_OUTPUT_SIZE`.
    pub fn finalize(&mut self, output: &mut [u8]) -> &mut SHA3_256 {
        assert_eq!(output.len(), SHA3_256_OUTPUT_SIZE);

        // Zero the unused tail of the pending lane, then place the 0x06 padding
        // byte right after the message bytes.
        let used = self.bufsize as usize;
        self.buffer[used..].fill(0);
        self.buffer[used] ^= 0x06;
        self.state[self.pos as usize] ^= u64::from_le_bytes(self.buffer);

        // The closing padding bit is the top bit of the last rate lane.
        self.state[SHA3_256_RATE_BUFFERS - 1] ^= 0x8000_0000_0000_0000;
        keccakf(&mut self.state);

        // Squeeze: the digest is the first four lanes in little-endian order.
        for (dst, lane) in output.chunks_exact_mut(8).zip(self.state.iter()) {
            dst.copy_from_slice(&lane.to_le_bytes());
        }
        self
    }

    /// Returns the hasher to its freshly constructed state.
    ///
    /// `buffer` is left as-is: with `bufsize` at 0 its old contents are never
    /// read before being overwritten.
    pub fn reset(&mut self) -> &mut SHA3_256 {
        self.state = [0; 25];
        self.pos = 0;
        self.bufsize = 0;
        self
    }

    /// Finalizes and returns the digest as a fixed-size array.
    pub fn finalize_fixed(&mut self) -> [u8; SHA3_256_OUTPUT_SIZE] {
        let mut digest = [0u8; SHA3_256_OUTPUT_SIZE];
        self.finalize(&mut digest);
        digest
    }

    /// Alias for `finalize`; same contract and panics.
    pub fn finalize_into(&mut self, output: &mut [u8]) -> &mut Self {
        self.finalize(output)
    }
}
/// Rotates `x` left by `n` bits.
///
/// `n` is taken modulo 64, so a negative count rotates right
/// (e.g. `rotl(x, -1)` equals a right rotation by one bit).
#[inline(always)]
pub fn rotl(x: u64, n: i32) -> u64 {
    let amount = n.rem_euclid(64) as u32;
    u64::rotate_left(x, amount)
}
/// Applies the 24-round Keccak-f[1600] permutation to `st` in place.
///
/// Lane indices `x, x + 5, x + 10, x + 15, x + 20` form column `x`, which the
/// theta step reduces to a parity word. Each round applies theta, the combined
/// rho+pi lane rotation/relocation, chi, and finally iota (XOR of the round
/// constant into lane 0).
pub fn keccakf(st: &mut [u64; 25]) {
    // Iota round constants, one per round.
    const ROUND_CONSTANTS: [u64; 24] = [
        0x0000_0000_0000_0001, 0x0000_0000_0000_8082, 0x8000_0000_0000_808a, 0x8000_0000_8000_8000,
        0x0000_0000_0000_808b, 0x0000_0000_8000_0001, 0x8000_0000_8000_8081, 0x8000_0000_0000_8009,
        0x0000_0000_0000_008a, 0x0000_0000_0000_0088, 0x0000_0000_8000_8009, 0x0000_0000_8000_000a,
        0x0000_0000_8000_808b, 0x8000_0000_0000_008b, 0x8000_0000_0000_8089, 0x8000_0000_0000_8003,
        0x8000_0000_0000_8002, 0x8000_0000_0000_0080, 0x0000_0000_0000_800a, 0x8000_0000_8000_000a,
        0x8000_0000_8000_8081, 0x8000_0000_0000_8080, 0x0000_0000_8000_0001, 0x8000_0000_8000_8008,
    ];
    // Rho rotation amounts, in the order the pi walk below visits lanes.
    const RHO_OFFSETS: [u32; 24] = [
        1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14, 27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44,
    ];
    // Pi destinations: the walk starts with the value of lane 1 and drops each
    // rotated value into the next lane listed here, carrying the displaced one.
    const PI_LANES: [usize; 24] = [
        10, 7, 11, 17, 18, 3, 5, 16, 8, 21, 24, 4, 15, 23, 19, 13, 12, 2, 20, 14, 22, 9, 6, 1,
    ];

    for &round_constant in ROUND_CONSTANTS.iter() {
        // Theta: XOR each lane with the parities of its two neighbouring columns.
        let mut parity = [0u64; 5];
        for (x, p) in parity.iter_mut().enumerate() {
            *p = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20];
        }
        for x in 0..5 {
            let d = parity[(x + 4) % 5] ^ parity[(x + 1) % 5].rotate_left(1);
            for row in (0..25).step_by(5) {
                st[row + x] ^= d;
            }
        }

        // Rho + Pi: rotate each lane and move it to its new position.
        let mut carried = st[1];
        for (&dst, &rot) in PI_LANES.iter().zip(RHO_OFFSETS.iter()) {
            let displaced = st[dst];
            st[dst] = carried.rotate_left(rot);
            carried = displaced;
        }

        // Chi: non-linear mix within each row of five lanes.
        for row in (0..25).step_by(5) {
            let r = [st[row], st[row + 1], st[row + 2], st[row + 3], st[row + 4]];
            for x in 0..5 {
                st[row + x] = r[x] ^ (!r[(x + 1) % 5] & r[(x + 2) % 5]);
            }
        }

        // Iota: break symmetry with the round constant.
        st[0] ^= round_constant;
    }
}
#[cfg(test)]
mod sha3_tests {
    // Tests for SHA3_256 / keccakf: FIPS 202 known-answer vectors, internal
    // state invariants, and differential checks against the `sha3` crate.
    use super::*;
    use sha3::{Digest as _, Sha3_256 as RefSha3_256};

    /// One-shot digest with our implementation.
    fn digest_ours(data: &[u8]) -> [u8; SHA3_256_OUTPUT_SIZE] {
        let mut h = SHA3_256::default();
        h.write(data);
        let mut out = [0u8; SHA3_256_OUTPUT_SIZE];
        h.finalize(&mut out);
        out
    }

    /// Digest with our implementation, feeding `chunks` through separate writes.
    fn digest_ours_chunked(chunks: &[&[u8]]) -> [u8; SHA3_256_OUTPUT_SIZE] {
        let mut h = SHA3_256::default();
        for c in chunks {
            h.write(c);
        }
        let mut out = [0u8; SHA3_256_OUTPUT_SIZE];
        h.finalize(&mut out);
        out
    }

    /// Digest from the reference `sha3` crate.
    fn digest_ref(data: &[u8]) -> [u8; 32] {
        let mut r = RefSha3_256::new();
        r.update(data);
        let res = r.finalize();
        let mut out = [0u8; 32];
        out.copy_from_slice(&res[..]);
        out
    }

    /// Lowercase hex encoding, for readable assertion messages.
    fn to_hex(bytes: &[u8]) -> String {
        use core::fmt::Write;
        let mut s = String::with_capacity(bytes.len() * 2);
        for &b in bytes {
            let _ = write!(&mut s, "{:02x}", b);
        }
        s
    }

    /// Tiny deterministic PRNG so randomized tests are reproducible.
    struct XorShift64(u64);

    impl XorShift64 {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        fn fill_bytes(&mut self, buf: &mut [u8]) {
            for chunk in buf.chunks_mut(8) {
                let v = self.next_u64().to_le_bytes();
                let n = chunk.len();
                chunk.copy_from_slice(&v[..n]);
            }
        }

        /// Value in `start..end` (end exclusive).
        fn gen_range_usize(&mut self, start: usize, end: usize) -> usize {
            start + (self.next_u64() as usize % (end - start))
        }
    }

    #[traced_test]
    fn kat_empty_string() {
        let got = digest_ours(b"");
        let expected_hex = "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a";
        assert_eq!(to_hex(&got), expected_hex);
        assert_eq!(got, digest_ref(b""));
    }

    #[traced_test]
    fn kat_abc() {
        let got = digest_ours(b"abc");
        let expected_hex = "3a985da74fe225b2045c172d6bd390bd855f086e3e9d525b46bfe24511431532";
        assert_eq!(to_hex(&got), expected_hex);
        assert_eq!(got, digest_ref(b"abc"));
    }

    #[traced_test]
    fn state_after_small_write() {
        let mut h = SHA3_256::default();
        h.write(&[1, 2, 3]);
        assert_eq!(h.pos, 0);
        assert_eq!(h.bufsize, 3);
        assert!(h.state.iter().all(|&x| x == 0));
    }

    #[traced_test]
    fn state_after_exact_lane_write() {
        let mut h = SHA3_256::default();
        let lane_bytes = [1u8, 2, 3, 4, 5, 6, 7, 8];
        h.write(&lane_bytes);
        assert_eq!(h.bufsize, 0);
        assert_eq!(h.pos, 1);
        assert_eq!(h.state[0], u64::from_le_bytes(lane_bytes));
        assert!(h.state[1..].iter().all(|&x| x == 0));
    }

    #[traced_test]
    fn pos_wraps_after_full_rate_block() {
        let mut h = SHA3_256::default();
        let block = vec![0u8; 136];
        h.write(&block);
        assert_eq!(h.bufsize, 0);
        assert_eq!(h.pos, 0);
    }

    #[test]
    #[should_panic]
    fn finalize_panics_on_wrong_output_len() {
        let mut h = SHA3_256::default();
        h.write(b"hello");
        let mut out = [0u8; 31];
        h.finalize(&mut out);
    }

    #[traced_test]
    fn reset_zeroes_everything() {
        let mut h = SHA3_256::default();
        h.write(b"some bytes");
        let mut out = [0u8; 32];
        h.finalize(&mut out);
        h.reset();
        assert_eq!(h.bufsize, 0);
        assert_eq!(h.pos, 0);
        assert!(h.state.iter().all(|&x| x == 0));

        // The *reset* hasher (not a fresh one) must now hash exactly like a
        // newly constructed hasher, despite stale bytes left in `buffer`.
        h.write(b"again");
        let mut after_reset = [0u8; 32];
        h.finalize(&mut after_reset);
        assert_eq!(after_reset, digest_ours(b"again"));
        assert_eq!(after_reset, digest_ref(b"again"));
    }

    #[traced_test]
    fn reset_mid_stream_discards_pending_bytes() {
        let mut h = SHA3_256::default();
        // 141 = one full rate block (136) plus a 5-byte partial lane.
        h.write(&[0xAB; 141]);
        h.reset();
        h.write(b"abc");
        let mut out = [0u8; SHA3_256_OUTPUT_SIZE];
        h.finalize(&mut out);
        assert_eq!(out, digest_ref(b"abc"));
    }

    #[traced_test]
    fn incremental_equals_one_shot_across_boundaries() {
        for len in 0..=272 {
            let mut data = vec![0u8; len];
            for (i, b) in data.iter_mut().enumerate() {
                *b = (i as u8).wrapping_mul(31).wrapping_add(7);
            }
            let one_shot = digest_ours(&data);
            assert_eq!(one_shot, digest_ref(&data), "ref mismatch at len={}", len);
            for split in 0..=len {
                let a = &data[..split];
                let b = &data[split..];
                let inc = digest_ours_chunked(&[a, b]);
                assert_eq!(
                    one_shot, inc,
                    "mismatch at len={}, split={} (one_shot={}, inc={})",
                    len, split, to_hex(&one_shot), to_hex(&inc)
                );
            }
        }
    }

    #[traced_test]
    fn many_chunk_patterns_match_reference() {
        let patterns: &[&[usize]] = &[
            &[1; 137],
            &[7; 20],
            &[8; 17],
            &[8; 34],
            &[135, 1],
            &[1, 135],
            &[136, 1],
            &[1, 136],
            &[3, 5, 8, 13, 21, 34, 55, 89],
        ];
        let max_total = patterns
            .iter()
            .map(|p| p.iter().sum::<usize>())
            .max()
            .unwrap_or(0);
        let mut msg = vec![0u8; max_total];
        for (i, b) in msg.iter_mut().enumerate() {
            *b = (i as u8).wrapping_mul(97).wrapping_add(11);
        }
        for (pi, pat) in patterns.iter().enumerate() {
            let total: usize = pat.iter().sum();
            let mut idx = 0usize;
            let mut chunks: Vec<&[u8]> = Vec::with_capacity(pat.len());
            for &sz in *pat {
                chunks.push(&msg[idx..idx + sz]);
                idx += sz;
            }
            let inc = digest_ours_chunked(&chunks);
            let ref_hash = digest_ref(&msg[..total]);
            assert_eq!(
                inc, ref_hash,
                "pattern {} failed: ours={}, ref={}",
                pi, to_hex(&inc), to_hex(&ref_hash)
            );
        }
    }

    #[traced_test]
    fn randomized_inputs_and_chunkings_match_reference() {
        let mut rng = XorShift64::new(0x5A17_EC7A_9B1B_D3C5);
        for _case in 0..256 {
            let len = rng.gen_range_usize(0, 8 * 1024);
            let mut data = vec![0u8; len];
            rng.fill_bytes(&mut data);
            let mut chunks: Vec<&[u8]> = Vec::new();
            let mut i = 0usize;
            while i < len {
                let remain = len - i;
                let max_chunk = core::cmp::min(256, remain);
                let sz = 1 + rng.gen_range_usize(0, max_chunk);
                chunks.push(&data[i..i + sz]);
                i += sz;
            }
            let ours = digest_ours_chunked(&chunks);
            let reference = digest_ref(&data);
            assert_eq!(ours, reference, "ours={}, ref={}", to_hex(&ours), to_hex(&reference));
        }
    }

    #[traced_test]
    fn million_a_matches_reference() {
        let data = vec![b'a'; 1_000_000];
        let ours = digest_ours(&data);
        let reference = digest_ref(&data);
        assert_eq!(ours, reference, "ours={}, ref={}", to_hex(&ours), to_hex(&reference));
    }
}