use core::fmt;
use rand_core::{CryptoRng, RngCore};
use crate::{bindings::*, GenerateSecret};
/// Shortest key, in bytes, that `Key::try_from` accepts (a 128-bit floor).
const MIN_KEY_SIZE: usize = 128 / 8;
/// Longest key, in bytes, that `Key::try_from` accepts; also the fixed
/// capacity of a `Key`'s backing buffer. 64 bytes is BLAKE2b's maximum key
/// length (RFC 7693).
const MAX_KEY_SIZE: usize = 512 / 8;
/// A BLAKE2b secret key between `MIN_KEY_SIZE` and `MAX_KEY_SIZE` bytes long.
///
/// The key material lives in `buf[..len]`; the rest of `buf` is zero-filled
/// by the constructor.
///
/// NOTE(review): deriving `Copy` on secret material lets copies spread
/// silently and makes a wiping `Drop` impossible. Removing it would be a
/// breaking change, so it is kept here.
#[derive(Clone, Copy)]
pub struct Key {
    /// Fixed-capacity storage; only the first `len` bytes are key material.
    buf: [u8; MAX_KEY_SIZE],
    /// Number of meaningful bytes at the start of `buf`.
    len: usize,
}
impl Key {
    /// Borrows the key material: the first `len` bytes of the backing
    /// buffer, excluding the unused tail.
    fn expose(&self) -> &[u8] {
        let end = self.len;
        &self.buf[..end]
    }
}
impl GenerateSecret for Key {
    /// Fills a 32-byte key from `rng`, which the `CryptoRng` bound marks as
    /// a cryptographically secure generator.
    fn generate<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> Self {
        let mut material = [0u8; 32];
        rng.fill_bytes(&mut material);
        material.into()
    }
}
/// Error returned when key material is shorter than `MIN_KEY_SIZE` or longer
/// than `MAX_KEY_SIZE` bytes.
///
/// A field-less marker, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidKeySize;
impl fmt::Display for InvalidKeySize {
    /// Writes a fixed, human-readable message; width/fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Invalid key size.")
    }
}
impl TryFrom<&[u8]> for Key {
    type Error = InvalidKeySize;

    /// Copies `value` into a new [`Key`].
    ///
    /// # Errors
    ///
    /// Returns [`InvalidKeySize`] if `value` is shorter than `MIN_KEY_SIZE`
    /// or longer than `MAX_KEY_SIZE` bytes.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        let len = value.len();
        if !(MIN_KEY_SIZE..=MAX_KEY_SIZE).contains(&len) {
            return Err(InvalidKeySize);
        }
        // Start from zeros so the unused tail of the fixed buffer is always
        // deterministic.
        let mut buf = [0u8; MAX_KEY_SIZE];
        buf[..len].copy_from_slice(value);
        Ok(Self { buf, len })
    }
}
/// Infallible construction from a 32-byte array.
///
/// 32 lies within `MIN_KEY_SIZE..=MAX_KEY_SIZE`, so the checked conversion
/// below cannot fail.
impl From<[u8; 32]> for Key {
    fn from(key: [u8; 32]) -> Self {
        match Self::try_from(&key[..]) {
            Ok(built) => built,
            Err(_) => unreachable!(),
        }
    }
}
/// Infallible construction from a 64-byte array.
///
/// 64 equals `MAX_KEY_SIZE`, which the checked conversion accepts, so the
/// error arm is impossible.
impl From<[u8; 64]> for Key {
    fn from(key: [u8; 64]) -> Self {
        match Self::try_from(&key[..]) {
            Ok(built) => built,
            Err(_) => unreachable!(),
        }
    }
}
/// Incremental BLAKE2b hashing state that yields a `HASH_SIZE`-byte digest
/// from `finish`.
///
/// Wraps the C `crypto_blake2b_ctx` from the generated bindings (presumably
/// Monocypher's — confirm against `bindings`).
pub struct Context<const HASH_SIZE: usize>(crypto_blake2b_ctx);
impl<const HASH_SIZE: usize> Context<HASH_SIZE> {
    /// Initializes a state for `HASH_SIZE`-byte digests, keyed with `key`.
    /// An empty `key` gives an unkeyed hash (see the `Default` impl).
    ///
    /// # Panics
    ///
    /// Panics if `HASH_SIZE` is 0 or greater than 64.
    fn new(key: &[u8]) -> Self {
        assert!(HASH_SIZE > 0 && HASH_SIZE <= 64);
        let mut slot = core::mem::MaybeUninit::zeroed();
        // SAFETY: `slot` is writable, zero-filled storage for one context.
        // `key` is valid for reads of `key.len()` bytes. The init call fills
        // the context, so `assume_init` afterwards observes initialized data.
        // NOTE(review): assumes `key.len() <= 64` (BLAKE2b's key limit). The
        // callers in this file pass a `Key` (at most MAX_KEY_SIZE bytes) or
        // an empty slice.
        let state = unsafe {
            crypto_blake2b_general_init(slot.as_mut_ptr(), HASH_SIZE, key.as_ptr(), key.len());
            slot.assume_init()
        };
        Self(state)
    }

    /// Absorbs `message` into the running hash; may be called any number of
    /// times before `finish`.
    pub fn update(&mut self, message: &[u8]) {
        let (data, data_len) = (message.as_ptr(), message.len());
        // SAFETY: `self.0` was initialized in `new`, and `data` is valid for
        // reads of `data_len` bytes for the duration of the call.
        unsafe { crypto_blake2b_update(&mut self.0, data, data_len) };
    }

    /// Consumes the state and returns the `HASH_SIZE`-byte digest.
    pub fn finish(mut self) -> [u8; HASH_SIZE] {
        let mut digest = [0u8; HASH_SIZE];
        // SAFETY: `digest` has room for exactly `HASH_SIZE` bytes, the output
        // length this state was initialized with in `new`.
        unsafe { crypto_blake2b_final(&mut self.0, digest.as_mut_ptr()) };
        digest
    }
}
impl<const HASH_SIZE: usize> From<&Key> for Context<HASH_SIZE> {
    /// Starts a keyed hash using the material held in `key`.
    fn from(key: &Key) -> Self {
        let key_bytes = key.expose();
        Self::new(key_bytes)
    }
}
impl<const HASH_SIZE: usize> Default for Context<HASH_SIZE> {
    /// Starts an unkeyed hash (zero-length key).
    fn default() -> Self {
        let no_key: &[u8] = &[];
        Self::new(no_key)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly generated key produces a keyed digest that differs from the
    /// unkeyed digest of the same message.
    #[test]
    fn generate_key() {
        let key = Key::generate(&mut rand_core::OsRng);
        let mut keyed = Context::<32>::from(&key);
        keyed.update(b"test");
        let mut unkeyed = Context::<32>::default();
        unkeyed.update(b"test");
        // A match here has negligible probability for a random key.
        assert_ne!(keyed.finish(), unkeyed.finish());
    }

    /// `Key::try_from` accepts exactly MIN_KEY_SIZE..=MAX_KEY_SIZE bytes, and
    /// `expose` returns the bytes it was given.
    #[test]
    fn key_size_bounds() {
        assert!(Key::try_from(&[0u8; MIN_KEY_SIZE - 1][..]).is_err());
        assert!(Key::try_from(&[0u8; MAX_KEY_SIZE + 1][..]).is_err());
        let shortest = Key::try_from(&[1u8; MIN_KEY_SIZE][..]).unwrap();
        assert_eq!(shortest.expose(), &[1u8; MIN_KEY_SIZE][..]);
        let longest = Key::try_from(&[2u8; MAX_KEY_SIZE][..]).unwrap();
        assert_eq!(longest.expose(), &[2u8; MAX_KEY_SIZE][..]);
    }

    /// Checks every `blake2b` entry of the known-answer file.
    #[test]
    fn test_vectors() {
        let test_vectors = std::fs::read_to_string("test_vectors/blake2.json").unwrap();
        let parsed = json::parse(&test_vectors).unwrap();
        let mut checked = 0usize;
        for tv in parsed.members().filter(|tv| tv["hash"] == "blake2b") {
            let key = hex::decode(tv["key"].as_str().unwrap()).unwrap();
            let input = hex::decode(tv["in"].as_str().unwrap()).unwrap();
            let output = hex::decode(tv["out"].as_str().unwrap()).unwrap();
            let mut ctx: Context<64> = if key.is_empty() {
                Context::default()
            } else {
                let key = Key::try_from(key.as_slice()).unwrap();
                Context::from(&key)
            };
            ctx.update(&input);
            assert_eq!(output.as_slice(), ctx.finish().as_slice());
            checked += 1;
        }
        // Without this, a change in the file's "hash" labels would make the
        // filter match nothing, and the test would pass vacuously.
        assert!(checked > 0, "no blake2b test vectors found");
    }
}