pub(super) mod xmd;
pub(super) mod xof;
use core::num::NonZero;
use digest::{Digest, ExtendableOutput, Update, XofReader};
use elliptic_curve::Error;
use elliptic_curve::array::{Array, ArraySize};
use xmd::ExpandMsgXmdError;
use xof::ExpandMsgXofError;
/// Salt prefix fed to the hash ahead of the DST fragments when an oversized DST is
/// reduced to a fixed-size value (see `Domain::xof` and `Domain::xmd`).
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";
/// Largest total DST length, in bytes, that is kept as-is; anything longer is hashed.
const MAX_DST_LEN: usize = 255;
/// Interface for schemes that expand a message and a domain separation tag (DST)
/// into an [`Expander`] able to produce output bytes.
///
/// NOTE(review): `K` is not referenced by any item of this trait; presumably it is a
/// security-level marker chosen by implementors — confirm against the implementations.
pub trait ExpandMsg<K> {
/// Hash type associated with the scheme. No bounds are placed on it here.
type Hash;
/// Expander returned by [`ExpandMsg::expand_message`]; it may borrow from the DST
/// for the lifetime `'dst`.
type Expander<'dst>: Expander + Sized;
/// Error returned when expansion cannot be set up.
type Error: core::error::Error;
/// Creates an expander for the given message and DST.
///
/// * `msg` - the message, supplied as a sequence of byte slices.
/// * `dst` - the domain separation tag, supplied as a sequence of byte slices.
/// * `len_in_bytes` - requested output length; the type rules out zero.
///
/// # Errors
///
/// Returns `Self::Error` as defined by the implementor.
fn expand_message<'dst>(
msg: &[&[u8]],
dst: &'dst [&[u8]],
len_in_bytes: NonZero<u16>,
) -> Result<Self::Expander<'dst>, Self::Error>;
}
/// Stateful source of expanded output bytes, as produced by an `ExpandMsg` implementation.
pub trait Expander {
/// Writes expanded output into `okm`.
///
/// NOTE(review): the `usize` in the `Ok` value is presumably the number of bytes
/// written into `okm` — confirm against the implementors.
///
/// # Errors
///
/// Returns `elliptic_curve::Error` on failure; the exact conditions are defined by
/// the implementor.
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error>;
}
/// Domain separation tag (DST) in one of two representations.
///
/// `L` is the byte length of the hashed representation.
#[derive(Debug)]
pub(crate) enum Domain<'a, L: ArraySize> {
/// Fixed-size value obtained by hashing an oversized DST
/// (built by `Domain::xof` / `Domain::xmd` when the DST exceeds `MAX_DST_LEN` bytes).
Hashed(Array<u8, L>),
/// The caller's DST fragments, borrowed unchanged.
Array(&'a [&'a [u8]]),
}
impl<'a, L: ArraySize> Domain<'a, L> {
    /// Builds a domain from `dst` using the extendable-output function `X`.
    ///
    /// The DST is the sequence of fragments in `dst`; only their combined length matters
    /// for the decision below:
    ///
    /// * combined length `0` is rejected;
    /// * combined length up to `MAX_DST_LEN` borrows the fragments as [`Domain::Array`];
    /// * anything longer is replaced by `L` bytes of XOF output computed over
    ///   `OVERSIZE_DST_SALT` followed by every fragment, stored as [`Domain::Hashed`].
    ///
    /// # Errors
    ///
    /// * `ExpandMsgXofError::EmptyDst` when the fragments contain no bytes at all.
    /// * `ExpandMsgXofError::DstSecurityLevel` when the DST is oversized and `L` does not
    ///   fit in a `u8`.
    pub fn xof<X>(dst: &'a [&'a [u8]]) -> Result<Self, ExpandMsgXofError>
    where
        X: Default + ExtendableOutput + Update,
    {
        let total: usize = dst.iter().map(|part| part.len()).sum();
        match total {
            0 => Err(ExpandMsgXofError::EmptyDst),
            1..=MAX_DST_LEN => Ok(Self::Array(dst)),
            // Oversized from here on: the reduced DST length must still fit in one byte.
            _ if L::USIZE > usize::from(u8::MAX) => Err(ExpandMsgXofError::DstSecurityLevel),
            _ => {
                let mut reduced = Array::<u8, L>::default();
                let mut hasher = X::default();
                hasher.update(OVERSIZE_DST_SALT);
                dst.iter().for_each(|part| hasher.update(part));
                hasher.finalize_xof().read(&mut reduced);
                Ok(Self::Hashed(reduced))
            }
        }
    }

    /// Builds a domain from `dst` using the fixed-output hash `X`, whose output size is `L`.
    ///
    /// Follows the same three-way split as [`Domain::xof`]: empty DSTs are rejected, DSTs of
    /// at most `MAX_DST_LEN` bytes are borrowed as [`Domain::Array`], and longer DSTs are
    /// replaced by the digest of `OVERSIZE_DST_SALT` followed by every fragment.
    ///
    /// # Errors
    ///
    /// * `ExpandMsgXmdError::EmptyDst` when the fragments contain no bytes at all.
    /// * `ExpandMsgXmdError::DstHash` when the DST is oversized and the digest length `L`
    ///   does not fit in a `u8`.
    pub fn xmd<X>(dst: &'a [&'a [u8]]) -> Result<Self, ExpandMsgXmdError>
    where
        X: Digest<OutputSize = L>,
    {
        let total: usize = dst.iter().map(|part| part.len()).sum();
        match total {
            0 => Err(ExpandMsgXmdError::EmptyDst),
            1..=MAX_DST_LEN => Ok(Self::Array(dst)),
            // Oversized from here on: the digest length must still fit in one byte.
            _ if L::USIZE > usize::from(u8::MAX) => Err(ExpandMsgXmdError::DstHash),
            _ => {
                let mut hasher = X::new();
                hasher.update(OVERSIZE_DST_SALT);
                dst.iter().for_each(|part| hasher.update(part));
                Ok(Self::Hashed(hasher.finalize()))
            }
        }
    }

    /// Feeds the domain bytes into `hash`: the stored array for [`Domain::Hashed`], or each
    /// borrowed fragment in order for [`Domain::Array`].
    pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
        match self {
            Self::Hashed(reduced) => hash.update(reduced),
            Self::Array(parts) => parts.iter().for_each(|part| hash.update(part)),
        }
    }

    /// Returns the domain length in bytes as a `u8`.
    ///
    /// # Panics
    ///
    /// Panics with `"length overflow"` if the fragments of a [`Domain::Array`] add up to
    /// more than `u8::MAX` bytes. Values built through [`Domain::xof`] or [`Domain::xmd`]
    /// only use that variant for at most `MAX_DST_LEN` bytes.
    pub fn len(&self) -> u8 {
        match self {
            Self::Hashed(_) => L::U8,
            Self::Array(parts) => {
                let total: usize = parts.iter().map(|part| part.len()).sum();
                u8::try_from(total).expect("length overflow")
            }
        }
    }

    /// Test helper: asserts that the domain bytes equal `bytes`.
    #[cfg(test)]
    pub fn assert(&self, bytes: &[u8]) {
        let flattened = match self {
            Self::Hashed(reduced) => reduced.to_vec(),
            Self::Array(parts) => parts
                .iter()
                .flat_map(|part| part.iter().copied())
                .collect(),
        };
        assert_eq!(flattened, bytes);
    }

    /// Test helper: asserts that `bytes` is the domain bytes followed by one trailing byte
    /// equal to [`Domain::len`].
    #[cfg(test)]
    pub fn assert_dst(&self, bytes: &[u8]) {
        let flattened = match self {
            Self::Hashed(reduced) => reduced.to_vec(),
            Self::Array(parts) => parts
                .iter()
                .flat_map(|part| part.iter().copied())
                .collect(),
        };
        // Index of the trailing length byte; everything before it is the DST itself.
        let last = bytes.len() - 1;
        assert_eq!(flattened, &bytes[..last]);
        assert_eq!(self.len(), bytes[last]);
    }
}