use crate::{
core::{
actually_used_field::ActuallyUsedField,
circuits::boolean::{
boolean_value::Boolean,
byte::Byte,
sha3::{SHA3_256, SHA3_512},
},
},
utils::crypto::{rescue_desc::RescueArg, rescue_prime_hash::RescuePrimeHash},
};
/// Keyed-hash message authentication code over sequences of `T`.
///
/// `T` is the element type the underlying hash consumes and produces.
pub trait HMAC<T> {
    /// Computes the MAC of `message` under `key`, returning the outer hash's digest.
    ///
    /// Implementations may impose restrictions on the accepted key length
    /// (NOTE(review): check each implementor's `# Panics` section).
    fn digest(&self, key: Vec<T>, message: Vec<T>) -> Vec<T>;
}
/// HMAC construction over the Rescue-Prime sponge hash.
///
/// Unlike the byte-oriented SHA3 variants, this works on values of type `T`
/// derived from the field `F`.
#[derive(Debug)]
// Underscore-separated name kept to match the `HMAC_<hash>` family naming.
#[allow(non_camel_case_types)]
pub struct HMAC_RescuePrime<F: ActuallyUsedField, T: RescueArg<F>> {
    /// Rescue-Prime hasher used for both the inner and the outer hash.
    pub hasher: RescuePrimeHash<F, T>,
}
impl<F: ActuallyUsedField, T: RescueArg<F>> HMAC_RescuePrime<F, T> {
    /// Creates an HMAC instance backed by a fresh `RescuePrimeHash::new()`.
    pub fn new() -> Self {
        let hasher = RescuePrimeHash::new();
        Self { hasher }
    }
}
impl<F: ActuallyUsedField, T: RescueArg<F>> HMAC<T> for HMAC_RescuePrime<F, T> {
    /// Evaluates `H((K + opad) || H((K + ipad) || message))`.
    ///
    /// `K` is `key` zero-padded to the hasher's rate. The HMAC "XOR with pad"
    /// step is replaced by adding the pad constant with `+` on `T`.
    ///
    /// # Panics
    ///
    /// Panics unless `hasher.digest_len <= key.len() <= hasher.rate`.
    fn digest(&self, mut key: Vec<T>, message: Vec<T>) -> Vec<T> {
        let (digest_len, rate) = (self.hasher.digest_len, self.hasher.rate);
        assert!(
            digest_len <= key.len() && key.len() <= rate,
            "The length of the key is supposed to be at least the hash function's digest length and at most the hash function's rate (found key length: {}, digest length: {} and rate: {})",
            key.len(),
            digest_len,
            rate
        );

        // Pad constants, each built from 32 little-endian bytes of 0x36 / 0x5c.
        let ipad = T::from(F::from_le_bytes([0x36; 32]));
        let opad = T::from(F::from_le_bytes([0x5c; 32]));

        // Extend the key with zeros up to exactly one block (the sponge rate);
        // the assert above guarantees this never shrinks it.
        key.resize(rate, T::from(F::ZERO));

        // The padded key with `pad` added element-wise.
        let keyed = |pad: T| key.iter().map(|&k| k + pad).collect::<Vec<T>>();

        let mut inner_input = keyed(ipad);
        inner_input.extend(message);
        let inner_digest = self.hasher.digest(inner_input).to_vec();

        let mut outer_input = keyed(opad);
        outer_input.extend(inner_digest);
        self.hasher.digest(outer_input).to_vec()
    }
}
impl<F: ActuallyUsedField, T: RescueArg<F>> Default for HMAC_RescuePrime<F, T> {
fn default() -> Self {
Self::new()
}
}
/// Generates a byte-oriented HMAC struct named `$name` on top of the SHA3
/// hasher type `$sha3`. It also generates `new`, an `HMAC<Byte<B>>` impl and
/// `Default` for that struct.
macro_rules! impl_hmac_sha3 {
    ($name: ident, $sha3: ident) => {
        /// HMAC over a SHA3 hasher, operating on `Byte`s.
        /// Generated by `impl_hmac_sha3!`.
        #[derive(Clone, Debug)]
        // Underscore-separated name kept to match the `HMAC_<hash>` family naming.
        #[allow(non_camel_case_types)]
        pub struct $name {
            /// SHA3 hasher used for both the inner and the outer hash.
            pub hasher: $sha3,
        }
        impl $name {
            /// Creates an HMAC instance wrapping a freshly constructed hasher.
            pub fn new() -> Self {
                let hasher = $sha3::new();
                Self { hasher }
            }
        }
        impl<B: Boolean> HMAC<Byte<B>> for $name {
            /// Evaluates `H((K ^ opad) || H((K ^ ipad) || message))`.
            ///
            /// `K` is `key` zero-padded to the hasher's rate in bytes.
            ///
            /// # Panics
            ///
            /// Panics unless `digest_in_bytes() <= key.len() <= rate_in_bytes()`.
            fn digest(&self, mut key: Vec<Byte<B>>, message: Vec<Byte<B>>) -> Vec<Byte<B>> {
                let digest_len = self.hasher.digest_in_bytes();
                let rate = self.hasher.rate_in_bytes();
                assert!(
                    digest_len <= key.len() && key.len() <= rate,
                    "The length of the key is supposed to be at least the hash function's digest length and at most the hash function's rate (found key len: {}, digest length: {} and rate: {})",
                    key.len(),
                    digest_len,
                    rate
                );

                // Standard HMAC inner / outer pad bytes.
                let ipad = Byte::from(0x36);
                let opad = Byte::from(0x5C);

                // Extend the key with zero bytes up to one block; the assert
                // above guarantees this never shrinks it.
                key.resize(rate, Byte::from(0u8));

                // The padded key XOR-ed byte-wise with `pad`.
                let xored = |pad: Byte<B>| key.iter().map(|&k| k ^ pad).collect::<Vec<Byte<B>>>();

                let mut inner_input = xored(ipad);
                inner_input.extend(message);
                let inner_digest = self.hasher.digest(inner_input).to_vec();

                let mut outer_input = xored(opad);
                outer_input.extend(inner_digest);
                self.hasher.digest(outer_input).to_vec()
            }
        }
        impl Default for $name {
            /// Same as `new()`.
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
// Byte-oriented HMAC over SHA3-256.
impl_hmac_sha3!(HMAC_Sha3_256, SHA3_256);
// Byte-oriented HMAC over SHA3-512.
impl_hmac_sha3!(HMAC_Sha3_512, SHA3_512);