use crate::{
core::{
actually_used_field::ActuallyUsedField,
circuits::boolean::{boolean_value::Boolean, byte::Byte},
},
utils::crypto::{
hmac::{HMAC_RescuePrime, HMAC_Sha3_256, HMAC},
key::RESCUE_KEY_COUNT,
rescue_desc::RescueArg,
},
};
/// Extract-then-expand key-derivation interface.
///
/// `L` is the number of elements of type `T` in the derived output, and `T`
/// is the element type used for every input and output.
pub trait HKDF<const L: usize, T> {
    /// Extract step: combines `salt` and the input keying material `ikm`
    /// into an intermediate key.
    fn extract(&self, salt: Vec<T>, ikm: Vec<T>) -> Vec<T>;

    /// Expand step: turns the intermediate key `prk` and the context `info`
    /// into exactly `L` output elements.
    fn expand(&self, prk: Vec<T>, info: Vec<T>) -> [T; L];

    /// Full derivation: the result of `extract(salt, ikm)` is passed as the
    /// `prk` argument of `expand`, together with `info`.
    fn okm(&self, salt: Vec<T>, ikm: Vec<T>, info: Vec<T>) -> [T; L] {
        self.expand(self.extract(salt, ikm), info)
    }
}
/// HKDF implementation built on top of `HMAC_RescuePrime`.
///
/// `F` is the underlying field type and `T` the element type the HMAC
/// operates on (bounded by `RescueArg<F>`).
// The underscore in the type name does not follow UpperCamelCase, hence the
// lint allowance below.
#[allow(non_camel_case_types)]
pub struct HKDF_RescuePrime<F: ActuallyUsedField, T: RescueArg<F>> {
// HMAC instance used by both the extract and the expand step.
hmac: HMAC_RescuePrime<F, T>,
}
impl<F: ActuallyUsedField, T: RescueArg<F>> HKDF_RescuePrime<F, T> {
    /// Creates an instance whose HMAC comes from `HMAC_RescuePrime::new()`.
    pub fn new() -> Self {
        let hmac = HMAC_RescuePrime::new();
        Self { hmac }
    }
}
impl<F: ActuallyUsedField, T: RescueArg<F>> HKDF<RESCUE_KEY_COUNT, T> for HKDF_RescuePrime<F, T> {
    /// Extract step: digests `ikm` with `salt` passed as the first HMAC
    /// argument.
    ///
    /// An empty `salt` is replaced by `self.hmac.hasher.rate` copies of
    /// `T::from(F::ZERO)`.
    fn extract(&self, salt: Vec<T>, ikm: Vec<T>) -> Vec<T> {
        let effective_salt = if salt.is_empty() {
            vec![T::from(F::ZERO); self.hmac.hasher.rate]
        } else {
            salt
        };
        self.hmac.digest(effective_salt, ikm)
    }

    /// Expand step: appends `T::from(F::ONE)` to `info`, runs a single HMAC
    /// digest with `prk` as the first argument, and returns the result as an
    /// array of `RESCUE_KEY_COUNT` elements.
    ///
    /// # Panics
    ///
    /// Panics if the digest does not contain exactly `RESCUE_KEY_COUNT`
    /// elements.
    fn expand(&self, prk: Vec<T>, mut info: Vec<T>) -> [T; RESCUE_KEY_COUNT] {
        // NOTE(review): a single appended one-valued element looks like the
        // block counter of the first (and only) expand block; confirm.
        info.push(T::from(F::ONE));

        let digest = self.hmac.digest(prk, info);
        let converted: Result<[T; RESCUE_KEY_COUNT], Vec<T>> = digest.try_into();
        match converted {
            Ok(okm) => okm,
            Err(v) => panic!(
                "Expected a Vec of length {} (found {})",
                RESCUE_KEY_COUNT,
                v.len()
            ),
        }
    }
}
impl<F: ActuallyUsedField, T: RescueArg<F>> Default for HKDF_RescuePrime<F, T> {
fn default() -> Self {
Self::new()
}
}
// Generates an HKDF type named `$t` that wraps the HMAC type `$hmac` and
// operates on `Byte<B>` values. The generated type gets a `new` constructor,
// an `HKDF<L, Byte<B>>` implementation for every output length `L`
// (validated at runtime in `expand`), and a `Default` implementation.
macro_rules! impl_hkdf_sha3 {
    ($t: ident, $hmac: ident) => {
        /// HKDF type generated by `impl_hkdf_sha3!`, wrapping an HMAC instance.
        #[derive(Clone, Debug)]
        #[allow(non_camel_case_types)]
        pub struct $t {
            /// HMAC used by both the extract and the expand step.
            pub hmac: $hmac,
        }

        impl $t {
            /// Creates an instance whose HMAC comes from the HMAC type's `new()`.
            pub fn new() -> Self {
                let hmac = $hmac::new();
                Self { hmac }
            }
        }

        impl<const L: usize, B: Boolean> HKDF<L, Byte<B>> for $t {
            /// Extract step: digests `ikm` with `salt` passed as the first
            /// HMAC argument. An empty `salt` is replaced by
            /// `digest_in_bytes()` zero bytes.
            fn extract(&self, salt: Vec<Byte<B>>, ikm: Vec<Byte<B>>) -> Vec<Byte<B>> {
                let effective_salt = if salt.is_empty() {
                    vec![Byte::from(0); self.hmac.hasher.digest_in_bytes()]
                } else {
                    salt
                };
                self.hmac.digest(effective_salt, ikm)
            }

            /// Expand step: appends the byte `1` to `info`, runs a single
            /// HMAC digest with `prk` as the first argument, and returns the
            /// first `L` bytes of the result.
            ///
            /// # Panics
            ///
            /// Panics if `L` exceeds the hasher's `digest_in_bytes()`, or if
            /// the digest yields fewer than `L` bytes.
            fn expand(&self, prk: Vec<Byte<B>>, mut info: Vec<Byte<B>>) -> [Byte<B>; L] {
                // Only one digest is computed, so the output cannot be longer
                // than a single digest.
                let max_len = self.hmac.hasher.digest_in_bytes();
                assert!(
                    L <= max_len,
                    "we only support L at most {} (found {})",
                    max_len,
                    L
                );

                info.push(Byte::from(1u8));

                // Keep only the leading `L` bytes of the digest.
                let mut okm = self.hmac.digest(prk, info);
                okm.truncate(L);

                let converted: Result<[Byte<B>; L], Vec<Byte<B>>> = okm.try_into();
                match converted {
                    Ok(bytes) => bytes,
                    Err(v) => panic!("Expected a Vec of length {} (found {})", L, v.len()),
                }
            }
        }

        impl Default for $t {
            /// Same as calling `new()`.
            fn default() -> Self {
                $t::new()
            }
        }
    };
}
// Defines `HKDF_Sha3_256`, the HKDF type backed by `HMAC_Sha3_256`.
impl_hkdf_sha3!(HKDF_Sha3_256, HMAC_Sha3_256);