#![allow(dead_code)]
use super::address::Address;
use super::params::SlhDsaParams;
use arcanum_primitives::sha2::Sha256;
use core::marker::PhantomData;
use std::vec::Vec;
/// Tweakable hash functions and pseudorandom functions used by SLH-DSA (FIPS 205).
///
/// Every function is an associated function (no `self`), so an implementation is
/// selected purely at the type level together with the parameter set `P`.
pub trait SlhDsaHash<P: SlhDsaParams> {
    /// `H_msg(R, PK.seed, PK.root, M)`: hashes the message into `out_len` bytes.
    fn h_msg(
        r: &[u8],
        pk_seed: &[u8],
        pk_root: &[u8],
        m: &[u8],
        out_len: usize,
    ) -> Vec<u8>;

    /// `PRF(PK.seed, SK.seed, ADRS)`: derives a secret value for the position `adrs`.
    fn prf(pk_seed: &[u8], sk_seed: &[u8], adrs: &Address) -> Vec<u8>;

    /// `PRF_msg(SK.prf, opt_rand, M)`: derives the message randomizer (`R` in FIPS 205).
    fn prf_msg(sk_prf: &[u8], opt_rand: &[u8], m: &[u8]) -> Vec<u8>;

    /// `F(PK.seed, ADRS, M)`: tweakable hash of a single input block.
    fn f(pk_seed: &[u8], adrs: &Address, m: &[u8]) -> Vec<u8>;

    /// `H(PK.seed, ADRS, M1 || M2)`: tweakable hash of two inputs, taken in order.
    fn h(pk_seed: &[u8], adrs: &Address, m1: &[u8], m2: &[u8]) -> Vec<u8>;

    /// `T_l(PK.seed, ADRS, M)`: tweakable hash of a variable-length input.
    fn t_l(pk_seed: &[u8], adrs: &Address, m: &[u8]) -> Vec<u8>;
}
/// SLH-DSA hash suite built on SHA-256 (the `SLH-DSA-SHA2-*` parameter sets).
///
/// Zero-sized marker type: all functionality lives in associated functions, so
/// it is only ever used as a type parameter. The `SlhDsaParams` bound is placed
/// on the `impl` blocks that need it rather than on the struct definition.
pub struct Sha2Hash<P> {
    _params: PhantomData<P>,
}
impl<P: SlhDsaParams> Sha2Hash<P> {
    /// MGF1 mask generation function (RFC 8017, Appendix B.2.1) over SHA-256.
    ///
    /// Returns exactly `length` bytes: the prefix of
    /// `SHA-256(seed || BE32(0)) || SHA-256(seed || BE32(1)) || ...`.
    fn mgf1_sha256(seed: &[u8], length: usize) -> Vec<u8> {
        let mut mask: Vec<u8> = Vec::with_capacity(length);
        let mut block_index = 0u32;
        while mask.len() < length {
            let digest = {
                let mut sha = Sha256::new();
                sha.update(seed);
                sha.update(&block_index.to_be_bytes());
                sha.finalize()
            };
            // Only the last block can be partial; this never overshoots `length`.
            let wanted = (length - mask.len()).min(32);
            mask.extend_from_slice(&digest[..wanted]);
            block_index += 1;
        }
        mask
    }

    /// HMAC-SHA-256 (RFC 2104) of `data` under `key`.
    ///
    /// Keys longer than the 64-byte SHA-256 block are first hashed down to 32 bytes;
    /// shorter keys are zero-padded to a full block.
    fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
        const BLOCK_LEN: usize = 64;
        const INNER_PAD: u8 = 0x36;
        const OUTER_PAD: u8 = 0x5c;

        // K0 from RFC 2104: the block-sized key.
        // NOTE(review): this and the padded copies below are not wiped after use,
        // although `prf_msg` passes SK.prf as `key`; consider zeroizing.
        let mut key_block = [0u8; BLOCK_LEN];
        if key.len() <= BLOCK_LEN {
            key_block[..key.len()].copy_from_slice(key);
        } else {
            key_block[..32].copy_from_slice(&Sha256::hash(key));
        }

        let padded_key = |pad: u8| -> [u8; BLOCK_LEN] {
            let mut out = [0u8; BLOCK_LEN];
            for (dst, src) in out.iter_mut().zip(key_block.iter()) {
                *dst = src ^ pad;
            }
            out
        };

        let inner_digest = {
            let mut sha = Sha256::new();
            sha.update(&padded_key(INNER_PAD));
            sha.update(data);
            sha.finalize()
        };

        let mut sha = Sha256::new();
        sha.update(&padded_key(OUTER_PAD));
        sha.update(&inner_digest);
        sha.finalize()
    }

    /// Keeps the first `P::N` bytes of a SHA-256 digest (FIPS 205 `Trunc_n`).
    ///
    /// Panics if `P::N` exceeds 32.
    #[inline]
    fn truncate(hash: &[u8; 32]) -> Vec<u8> {
        Vec::from(&hash[..P::N])
    }
}
impl<P: SlhDsaParams> Sha2Hash<P> {
    /// Shared core of `PRF`, `F`, `H` and `T_l` for the SHA-2 instantiation:
    /// `Trunc_n(SHA-256(PK.seed || toByte(0, 64 - n) || ADRS^c || inputs...))`.
    ///
    /// FIPS 205 §11.2 pads `PK.seed` with `64 - n` zero bytes for *every* `n`
    /// (48 bytes when `n = 16`) so that it fills one full SHA-256 block. The
    /// padding must be unconditional; omitting it for `n = 16` produces outputs
    /// that do not interoperate with the standard.
    fn tweakable_hash(pk_seed: &[u8], adrs: &Address, inputs: &[&[u8]]) -> Vec<u8> {
        const SHA256_BLOCK_BYTES: usize = 64;
        let zero_block = [0u8; SHA256_BLOCK_BYTES];
        let adrs_c = adrs.to_compressed();

        let mut hasher = Sha256::new();
        hasher.update(pk_seed);
        hasher.update(&zero_block[..SHA256_BLOCK_BYTES - P::N]);
        hasher.update(&adrs_c);
        for &input in inputs {
            hasher.update(input);
        }
        Self::truncate(&hasher.finalize())
    }
}

/// SHA-2 instantiation of the SLH-DSA hash functions (FIPS 205 §11.2.1).
///
/// NOTE(review): FIPS 205 §11.2.2 uses SHA-512 (HMAC-SHA-512, MGF1-SHA-512, and
/// `toByte(0, 128 - n)` padding for `H`/`T_l`) when `n` is 24 or 32. This
/// implementation uses SHA-256 throughout, so only `n = 16` parameter sets match
/// the standard. TODO: confirm which parameter sets this type is used with.
impl<P: SlhDsaParams> SlhDsaHash<P> for Sha2Hash<P> {
    /// `MGF1-SHA-256(R || PK.seed || SHA-256(R || PK.seed || PK.root || M), out_len)`.
    fn h_msg(r: &[u8], pk_seed: &[u8], pk_root: &[u8], m: &[u8], out_len: usize) -> Vec<u8> {
        let mut inner_hasher = Sha256::new();
        inner_hasher.update(r);
        inner_hasher.update(pk_seed);
        inner_hasher.update(pk_root);
        inner_hasher.update(m);
        let inner_hash = inner_hasher.finalize();

        let mut mgf_seed = Vec::with_capacity(r.len() + pk_seed.len() + 32);
        mgf_seed.extend_from_slice(r);
        mgf_seed.extend_from_slice(pk_seed);
        mgf_seed.extend_from_slice(&inner_hash);
        Self::mgf1_sha256(&mgf_seed, out_len)
    }

    /// `Trunc_n(SHA-256(PK.seed || toByte(0, 64 - n) || ADRS^c || SK.seed))`.
    fn prf(pk_seed: &[u8], sk_seed: &[u8], adrs: &Address) -> Vec<u8> {
        Self::tweakable_hash(pk_seed, adrs, &[sk_seed])
    }

    /// `Trunc_n(HMAC-SHA-256(SK.prf, opt_rand || M))`.
    fn prf_msg(sk_prf: &[u8], opt_rand: &[u8], m: &[u8]) -> Vec<u8> {
        let mut data = Vec::with_capacity(opt_rand.len() + m.len());
        data.extend_from_slice(opt_rand);
        data.extend_from_slice(m);
        Self::truncate(&Self::hmac_sha256(sk_prf, &data))
    }

    /// `Trunc_n(SHA-256(PK.seed || toByte(0, 64 - n) || ADRS^c || M))`.
    fn f(pk_seed: &[u8], adrs: &Address, m: &[u8]) -> Vec<u8> {
        Self::tweakable_hash(pk_seed, adrs, &[m])
    }

    /// `Trunc_n(SHA-256(PK.seed || toByte(0, 64 - n) || ADRS^c || M1 || M2))`.
    fn h(pk_seed: &[u8], adrs: &Address, m1: &[u8], m2: &[u8]) -> Vec<u8> {
        Self::tweakable_hash(pk_seed, adrs, &[m1, m2])
    }

    /// `Trunc_n(SHA-256(PK.seed || toByte(0, 64 - n) || ADRS^c || M))` for any input length.
    fn t_l(pk_seed: &[u8], adrs: &Address, m: &[u8]) -> Vec<u8> {
        Self::tweakable_hash(pk_seed, adrs, &[m])
    }
}

#[cfg(test)]
mod sha2_tweakable_hash_tests {
    use super::*;
    use crate::slh_dsa::params::Sha2_128f;

    /// Recomputes `Trunc_n(SHA-256(PK.seed || toByte(0, 64 - n) || ADRS^c || M))`
    /// by hand, pinning the FIPS 205 §11.2.1 padding layout for n = 16.
    fn reference(pk_seed: &[u8], adrs: &Address, m: &[u8]) -> Vec<u8> {
        let adrs_c = adrs.to_compressed();
        let mut input = Vec::new();
        input.extend_from_slice(pk_seed);
        input.extend_from_slice(&vec![0u8; 64 - Sha2_128f::N]);
        input.extend_from_slice(&adrs_c);
        input.extend_from_slice(m);
        Sha256::hash(&input)[..Sha2_128f::N].to_vec()
    }

    #[test]
    fn prf_f_h_t_l_pad_pk_seed_to_a_full_block() {
        let pk_seed = [7u8; 16];
        let adrs = Address::wots_hash(0, 0, 0, 3, 1);
        let m1 = [1u8; 16];
        let m2 = [2u8; 16];
        let both = [m1, m2].concat();
        assert_eq!(
            Sha2Hash::<Sha2_128f>::f(&pk_seed, &adrs, &m1),
            reference(&pk_seed, &adrs, &m1)
        );
        assert_eq!(
            Sha2Hash::<Sha2_128f>::prf(&pk_seed, &m2, &adrs),
            reference(&pk_seed, &adrs, &m2)
        );
        assert_eq!(
            Sha2Hash::<Sha2_128f>::h(&pk_seed, &adrs, &m1, &m2),
            reference(&pk_seed, &adrs, &both)
        );
        assert_eq!(
            Sha2Hash::<Sha2_128f>::t_l(&pk_seed, &adrs, &both),
            reference(&pk_seed, &adrs, &both)
        );
    }
}
/// SLH-DSA hash suite built on SHAKE256 (the `SLH-DSA-SHAKE-*` parameter sets).
///
/// Zero-sized marker type. Its `SlhDsaHash` functions are currently stubs that
/// panic with `unimplemented!`. The `SlhDsaParams` bound is placed on the `impl`
/// blocks rather than on the struct definition.
pub struct ShakeHash<P> {
    _params: PhantomData<P>,
}
/// Placeholder SHAKE256 instantiation of the SLH-DSA hash functions.
///
/// # Panics
///
/// Every function panics via `unimplemented!`; parameters are intentionally unused.
impl<P: SlhDsaParams> SlhDsaHash<P> for ShakeHash<P> {
    fn h_msg(_r: &[u8], _pk_seed: &[u8], _pk_root: &[u8], _m: &[u8], _out_len: usize) -> Vec<u8> {
        unimplemented!("SHAKE H_msg not yet implemented - requires SHAKE256 in arcanum-primitives")
    }

    fn prf(_pk_seed: &[u8], _sk_seed: &[u8], _adrs: &Address) -> Vec<u8> {
        unimplemented!("SHAKE PRF not yet implemented")
    }

    fn prf_msg(_sk_prf: &[u8], _opt_rand: &[u8], _m: &[u8]) -> Vec<u8> {
        unimplemented!("SHAKE PRF_msg not yet implemented")
    }

    fn f(_pk_seed: &[u8], _adrs: &Address, _m: &[u8]) -> Vec<u8> {
        unimplemented!("SHAKE F not yet implemented")
    }

    fn h(_pk_seed: &[u8], _adrs: &Address, _m1: &[u8], _m2: &[u8]) -> Vec<u8> {
        unimplemented!("SHAKE H not yet implemented")
    }

    fn t_l(_pk_seed: &[u8], _adrs: &Address, _m: &[u8]) -> Vec<u8> {
        unimplemented!("SHAKE T_l not yet implemented")
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the SHA-2 hash suite. Besides length/determinism checks,
    // the HMAC and MGF1 helpers are pinned to RFC vectors or their definitions,
    // and `prf_msg` / `h_msg` are checked against their FIPS 205 formulas.
    use super::*;
    use crate::slh_dsa::params::Sha2_128f;

    #[test]
    fn test_prf_returns_n_bytes() {
        let pk_seed = [0u8; 16];
        let sk_seed = [1u8; 16];
        let adrs = Address::new();
        let result = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed, &adrs);
        assert_eq!(result.len(), Sha2_128f::N);
    }

    #[test]
    fn test_prf_is_deterministic() {
        let pk_seed = [0u8; 16];
        let sk_seed = [1u8; 16];
        let adrs = Address::wots_hash(0, 0, 0, 0, 0);
        let result1 = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed, &adrs);
        let result2 = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed, &adrs);
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_prf_different_address_different_output() {
        let pk_seed = [0u8; 16];
        let sk_seed = [1u8; 16];
        let adrs1 = Address::wots_hash(0, 0, 0, 0, 0);
        let adrs2 = Address::wots_hash(0, 0, 0, 0, 1);
        let result1 = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed, &adrs1);
        let result2 = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed, &adrs2);
        assert_ne!(result1, result2);
    }

    #[test]
    fn test_prf_different_seed_different_output() {
        let pk_seed = [0u8; 16];
        let sk_seed1 = [1u8; 16];
        let sk_seed2 = [2u8; 16];
        let adrs = Address::new();
        let result1 = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed1, &adrs);
        let result2 = Sha2Hash::<Sha2_128f>::prf(&pk_seed, &sk_seed2, &adrs);
        assert_ne!(result1, result2);
    }

    #[test]
    fn test_prf_msg_returns_n_bytes() {
        let sk_prf = [0u8; 16];
        let opt_rand = [1u8; 16];
        let message = b"test message";
        let result = Sha2Hash::<Sha2_128f>::prf_msg(&sk_prf, &opt_rand, message);
        assert_eq!(result.len(), Sha2_128f::N);
    }

    #[test]
    fn test_prf_msg_deterministic() {
        let sk_prf = [0u8; 16];
        let opt_rand = [1u8; 16];
        let message = b"test message";
        let result1 = Sha2Hash::<Sha2_128f>::prf_msg(&sk_prf, &opt_rand, message);
        let result2 = Sha2Hash::<Sha2_128f>::prf_msg(&sk_prf, &opt_rand, message);
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_prf_msg_different_message_different_output() {
        let sk_prf = [0u8; 16];
        let opt_rand = [1u8; 16];
        let result1 = Sha2Hash::<Sha2_128f>::prf_msg(&sk_prf, &opt_rand, b"message 1");
        let result2 = Sha2Hash::<Sha2_128f>::prf_msg(&sk_prf, &opt_rand, b"message 2");
        assert_ne!(result1, result2);
    }

    #[test]
    fn test_prf_msg_is_truncated_hmac() {
        // PRF_msg = Trunc_n(HMAC-SHA-256(SK.prf, opt_rand || M)).
        let sk_prf = [3u8; 16];
        let opt_rand = [4u8; 16];
        let m = b"abc";
        let mut data = opt_rand.to_vec();
        data.extend_from_slice(m);
        let mac = Sha2Hash::<Sha2_128f>::hmac_sha256(&sk_prf, &data);
        let result = Sha2Hash::<Sha2_128f>::prf_msg(&sk_prf, &opt_rand, m);
        assert_eq!(result, mac[..Sha2_128f::N].to_vec());
    }

    #[test]
    fn test_f_returns_n_bytes() {
        let pk_seed = [0u8; 16];
        let adrs = Address::new();
        let m = [2u8; 16];
        let result = Sha2Hash::<Sha2_128f>::f(&pk_seed, &adrs, &m);
        assert_eq!(result.len(), Sha2_128f::N);
    }

    #[test]
    fn test_f_deterministic() {
        let pk_seed = [0u8; 16];
        let adrs = Address::wots_hash(0, 0, 0, 5, 0);
        let m = [2u8; 16];
        let result1 = Sha2Hash::<Sha2_128f>::f(&pk_seed, &adrs, &m);
        let result2 = Sha2Hash::<Sha2_128f>::f(&pk_seed, &adrs, &m);
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_f_different_input_different_output() {
        let pk_seed = [0u8; 16];
        let adrs = Address::new();
        let m1 = [1u8; 16];
        let m2 = [2u8; 16];
        let result1 = Sha2Hash::<Sha2_128f>::f(&pk_seed, &adrs, &m1);
        let result2 = Sha2Hash::<Sha2_128f>::f(&pk_seed, &adrs, &m2);
        assert_ne!(result1, result2);
    }

    #[test]
    fn test_h_returns_n_bytes() {
        let pk_seed = [0u8; 16];
        let adrs = Address::new();
        let m1 = [1u8; 16];
        let m2 = [2u8; 16];
        let result = Sha2Hash::<Sha2_128f>::h(&pk_seed, &adrs, &m1, &m2);
        assert_eq!(result.len(), Sha2_128f::N);
    }

    #[test]
    fn test_h_not_commutative() {
        let pk_seed = [0u8; 16];
        let adrs = Address::tree(0, 0, 1, 0);
        let m1 = [1u8; 16];
        let m2 = [2u8; 16];
        let result1 = Sha2Hash::<Sha2_128f>::h(&pk_seed, &adrs, &m1, &m2);
        let result2 = Sha2Hash::<Sha2_128f>::h(&pk_seed, &adrs, &m2, &m1);
        assert_ne!(result1, result2);
    }

    #[test]
    fn test_t_l_returns_n_bytes() {
        let pk_seed = [0u8; 16];
        let adrs = Address::wots_pk(0, 0, 0);
        let m = vec![0u8; 35 * 16];
        let result = Sha2Hash::<Sha2_128f>::t_l(&pk_seed, &adrs, &m);
        assert_eq!(result.len(), Sha2_128f::N);
    }

    #[test]
    fn test_h_msg_returns_requested_length() {
        let r = [0u8; 16];
        let pk_seed = [1u8; 16];
        let pk_root = [2u8; 16];
        let m = b"test message";
        let result = Sha2Hash::<Sha2_128f>::h_msg(&r, &pk_seed, &pk_root, m, 50);
        assert_eq!(result.len(), 50);
    }

    #[test]
    fn test_h_msg_deterministic() {
        let r = [0u8; 16];
        let pk_seed = [1u8; 16];
        let pk_root = [2u8; 16];
        let m = b"test message";
        let result1 = Sha2Hash::<Sha2_128f>::h_msg(&r, &pk_seed, &pk_root, m, 32);
        let result2 = Sha2Hash::<Sha2_128f>::h_msg(&r, &pk_seed, &pk_root, m, 32);
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_h_msg_different_r_different_output() {
        let r1 = [0u8; 16];
        let r2 = [1u8; 16];
        let pk_seed = [1u8; 16];
        let pk_root = [2u8; 16];
        let m = b"test message";
        let result1 = Sha2Hash::<Sha2_128f>::h_msg(&r1, &pk_seed, &pk_root, m, 32);
        let result2 = Sha2Hash::<Sha2_128f>::h_msg(&r2, &pk_seed, &pk_root, m, 32);
        assert_ne!(result1, result2);
    }

    #[test]
    fn test_h_msg_matches_definition() {
        // H_msg = MGF1-SHA-256(R || PK.seed || SHA-256(R || PK.seed || PK.root || M), len).
        let r = [5u8; 16];
        let pk_seed = [6u8; 16];
        let pk_root = [7u8; 16];
        let m = b"definition check";
        let inner_input = [&r[..], &pk_seed[..], &pk_root[..], &m[..]].concat();
        let inner = Sha256::hash(&inner_input);
        let mgf_seed = [&r[..], &pk_seed[..], &inner[..]].concat();
        let expected = Sha2Hash::<Sha2_128f>::mgf1_sha256(&mgf_seed, 50);
        let result = Sha2Hash::<Sha2_128f>::h_msg(&r, &pk_seed, &pk_root, m, 50);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_hmac_sha256_known_vector() {
        let key = [0x0bu8; 20];
        let data = b"Hi There";
        let expected = [
            0xb0, 0x34, 0x4c, 0x61, 0xd8, 0xdb, 0x38, 0x53, 0x5c, 0xa8, 0xaf, 0xce, 0xaf, 0x0b,
            0xf1, 0x2b, 0x88, 0x1d, 0xc2, 0x00, 0xc9, 0x83, 0x3d, 0xa7, 0x26, 0xe9, 0x37, 0x6c,
            0x2e, 0x32, 0xcf, 0xf7,
        ];
        let result = Sha2Hash::<Sha2_128f>::hmac_sha256(&key, data);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_hmac_sha256_rfc4231_case2_short_key() {
        let key = b"Jefe";
        let data = b"what do ya want for nothing?";
        let expected = [
            0x5b, 0xdc, 0xc1, 0x46, 0xbf, 0x60, 0x75, 0x4e, 0x6a, 0x04, 0x24, 0x26, 0x08, 0x95,
            0x75, 0xc7, 0x5a, 0x00, 0x3f, 0x08, 0x9d, 0x27, 0x39, 0x83, 0x9d, 0xec, 0x58, 0xb9,
            0x64, 0xec, 0x38, 0x43,
        ];
        let result = Sha2Hash::<Sha2_128f>::hmac_sha256(key, data);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_hmac_sha256_rfc4231_case6_key_longer_than_block() {
        // Exercises the branch that hashes keys longer than the 64-byte block.
        let key = [0xaau8; 131];
        let data = b"Test Using Larger Than Block-Size Key - Hash Key First";
        let expected = [
            0x60, 0xe4, 0x31, 0x59, 0x1e, 0xe0, 0xb6, 0x7f, 0x0d, 0x8a, 0x26, 0xaa, 0xcb, 0xf5,
            0xb7, 0x7f, 0x8e, 0x0b, 0xc6, 0x21, 0x37, 0x28, 0xc5, 0x14, 0x05, 0x46, 0x04, 0x0f,
            0x0e, 0xe3, 0x7f, 0x54,
        ];
        let result = Sha2Hash::<Sha2_128f>::hmac_sha256(&key, data);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_mgf1_output_length() {
        let seed = b"test seed";
        let out1 = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 10);
        assert_eq!(out1.len(), 10);
        let out2 = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 100);
        assert_eq!(out2.len(), 100);
        let out3 = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 0);
        assert_eq!(out3.len(), 0);
    }

    #[test]
    fn test_mgf1_deterministic() {
        let seed = b"test seed";
        let out1 = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 64);
        let out2 = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 64);
        assert_eq!(out1, out2);
    }

    #[test]
    fn test_mgf1_matches_counter_construction() {
        // MGF1 output is SHA-256(seed || BE32(0)) || SHA-256(seed || BE32(1)) || ...
        let seed = b"test seed";
        let out = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 40);
        let mut block0_input = seed.to_vec();
        block0_input.extend_from_slice(&0u32.to_be_bytes());
        let mut block1_input = seed.to_vec();
        block1_input.extend_from_slice(&1u32.to_be_bytes());
        assert_eq!(&out[..32], &Sha256::hash(&block0_input)[..]);
        assert_eq!(&out[32..], &Sha256::hash(&block1_input)[..8]);
    }

    #[test]
    fn test_mgf1_shorter_output_is_prefix_of_longer() {
        let seed = b"prefix seed";
        let short = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 10);
        let long = Sha2Hash::<Sha2_128f>::mgf1_sha256(seed, 100);
        assert_eq!(&long[..10], &short[..]);
    }
}