use blake3;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::ct::ct_eq;
mod serde_bytes {
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(bytes)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
<Vec<u8>>::deserialize(deserializer)
}
}
/// Errors reported by the HMAC helpers in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HmacError {
    /// Key material did not have the required length.
    InvalidKeySize,
    /// A tag had an unexpected length.
    InvalidTagSize,
    /// A tag did not match the data it was meant to authenticate.
    VerificationFailed,
    /// Encoding or decoding failed; carries the underlying error's message.
    SerializationError(String),
}

impl std::fmt::Display for HmacError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed-text variants share one write; only the serialization variant
        // needs to interpolate its payload.
        let text = match self {
            HmacError::InvalidKeySize => "Invalid HMAC key size",
            HmacError::InvalidTagSize => "Invalid HMAC tag size",
            HmacError::VerificationFailed => "HMAC verification failed",
            HmacError::SerializationError(detail) => {
                return write!(f, "Serialization error: {detail}");
            }
        };
        f.write_str(text)
    }
}

impl std::error::Error for HmacError {}
/// Result alias whose error type is `HmacError`.
pub type HmacResult<T> = Result<T, HmacError>;
/// Required key length in bytes; `HmacKey::from_bytes` rejects any other length.
pub const HMAC_KEY_SIZE: usize = 32;
/// Length in bytes of a tag produced by `compute_hmac_sha256` (the SHA-256 output size).
pub const HMAC_SHA256_TAG_SIZE: usize = 32;
/// Length in bytes of a tag produced by `compute_hmac_blake3` (BLAKE3's default output size).
pub const HMAC_BLAKE3_TAG_SIZE: usize = 32;
#[derive(Clone, Zeroize, ZeroizeOnDrop, Serialize, Deserialize)]
pub struct HmacKey {
#[serde(with = "serde_bytes")]
key: Vec<u8>,
}
impl HmacKey {
    /// Generates a new random key of `HMAC_KEY_SIZE` bytes using `rand::rng()`.
    pub fn generate() -> Self {
        // `fill` is a method of `rand::Rng`, the trait imported here. `fill_bytes`
        // belongs to `rand::RngCore`, which is not in scope in this function.
        use rand::Rng as _;
        let mut key = vec![0u8; HMAC_KEY_SIZE];
        rand::rng().fill(&mut key[..]);
        Self { key }
    }

    /// Builds a key from raw bytes.
    ///
    /// # Errors
    /// Returns `HmacError::InvalidKeySize` unless `bytes` is exactly `HMAC_KEY_SIZE` long.
    pub fn from_bytes(bytes: &[u8]) -> HmacResult<Self> {
        if bytes.len() != HMAC_KEY_SIZE {
            return Err(HmacError::InvalidKeySize);
        }
        Ok(Self {
            key: bytes.to_vec(),
        })
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.key
    }

    /// Returns an owned copy of the raw key bytes.
    ///
    /// NOTE: the returned `Vec` is a plain copy and is not wiped on drop; the caller is
    /// responsible for zeroizing it if it must not linger in memory.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.key.clone()
    }

    /// Derives a key from a password using PBKDF2-HMAC-SHA256.
    ///
    /// The same `password`, `salt` and `iterations` always produce the same key.
    /// `iterations` is passed through unchanged. Choosing a sufficiently large work
    /// factor is the caller's responsibility.
    pub fn derive_from_password(password: &[u8], salt: &[u8], iterations: u32) -> Self {
        use pbkdf2::pbkdf2_hmac;
        let mut key = vec![0u8; HMAC_KEY_SIZE];
        pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut key);
        Self { key }
    }
}
impl std::fmt::Debug for HmacKey {
    /// Formats the key without exposing its bytes: the `key` field is always shown as a
    /// fixed placeholder string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "[REDACTED]";
        let mut out = f.debug_struct("HmacKey");
        out.field("key", &PLACEHOLDER);
        out.finish()
    }
}
/// A MAC tag produced by one of the `compute_*` functions in this module.
///
/// NOTE: the derived `PartialEq` is an ordinary byte comparison and is not guaranteed to
/// be constant-time. To check an untrusted tag, use `HmacTag::verify`, which goes through
/// `crate::ct::ct_eq`.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct HmacTag {
    /// Raw tag bytes; the length is not validated on construction.
    #[serde(with = "serde_bytes")]
    tag: Vec<u8>,
}
impl HmacTag {
    /// Wraps a copy of `bytes` as a tag. Any length is accepted.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let tag = Vec::from(bytes);
        Self { tag }
    }

    /// Borrows the raw tag bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.tag.as_slice()
    }

    /// Returns an owned copy of the raw tag bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.as_bytes().to_vec()
    }

    /// Compares this tag with `other` using `crate::ct::ct_eq` (the module's
    /// constant-time comparison helper, per its `ct` name — confirm in that module).
    pub fn verify(&self, other: &Self) -> bool {
        let matches = ct_eq(&self.tag, &other.tag);
        matches
    }
}
/// Computes HMAC-SHA256 of `message` under `key`; the tag is `HMAC_SHA256_TAG_SIZE` bytes.
pub fn compute_hmac_sha256(key: &HmacKey, message: &[u8]) -> HmacTag {
    use hmac::digest::KeyInit;
    use hmac::{Hmac, Mac};

    // `KeyInit` and `Mac` both provide `new_from_slice`, so the call is fully qualified.
    // HMAC accepts keys of any length, so this construction cannot fail.
    let mut state = <Hmac<Sha256> as KeyInit>::new_from_slice(key.as_bytes())
        .expect("HMAC can take key of any size");
    state.update(message);
    let digest = state.finalize().into_bytes();
    HmacTag::from_bytes(&digest[..])
}
/// Computes a BLAKE3 keyed-hash MAC of `message` under `key`.
///
/// # Panics
/// Panics with "key is 32 bytes" if the key is not exactly 32 bytes long.
pub fn compute_hmac_blake3(key: &HmacKey, message: &[u8]) -> HmacTag {
    let raw_key = <[u8; 32]>::try_from(key.as_bytes()).expect("key is 32 bytes");
    let digest = blake3::keyed_hash(&raw_key, message);
    HmacTag::from_bytes(digest.as_bytes())
}
/// Computes the module's default MAC, which is the BLAKE3 keyed hash
/// (`compute_hmac_blake3`).
pub fn compute_hmac(key: &HmacKey, message: &[u8]) -> HmacTag {
    let tag = compute_hmac_blake3(key, message);
    tag
}
/// Recomputes the default MAC over `message` and checks it against `tag` via
/// `HmacTag::verify`.
pub fn verify_hmac(key: &HmacKey, message: &[u8], tag: &HmacTag) -> bool {
    compute_hmac(key, message).verify(tag)
}
/// Recomputes HMAC-SHA256 over `message` and checks it against `tag` via `HmacTag::verify`.
pub fn verify_hmac_sha256(key: &HmacKey, message: &[u8], tag: &HmacTag) -> bool {
    compute_hmac_sha256(key, message).verify(tag)
}
/// Recomputes the BLAKE3 keyed-hash MAC over `message` and checks it against `tag` via
/// `HmacTag::verify`.
pub fn verify_hmac_blake3(key: &HmacKey, message: &[u8], tag: &HmacTag) -> bool {
    compute_hmac_blake3(key, message).verify(tag)
}
/// Computes a context-bound BLAKE3 MAC over `message`.
///
/// A per-context subkey is derived from the master key and `context` using BLAKE3's
/// `derive_key` mode. The message is then MACed with BLAKE3 keyed-hash under that subkey.
///
/// This replaces hashing `context || message` directly under the master key, which had
/// two forgery paths:
/// * Boundary ambiguity: `("CHIE:A", "Bxyz")` and `("CHIE:AB", "xyz")` produced the
///   same tag.
/// * Cross-protocol confusion: a plain `compute_hmac` tag over `context || data` (and any
///   tag made with an empty context) was also a valid tagged tag.
///
/// Both are now ruled out:
/// * The master key has a fixed 32-byte length, so `master || context` identifies the
///   context unambiguously.
/// * `derive_key` mode is domain-separated from keyed-hash mode inside BLAKE3.
///
/// NOTE: tags produced by the previous construction no longer verify.
///
/// # Panics
/// Panics with "key is 32 bytes" if the key is not exactly 32 bytes long.
pub fn compute_tagged_hmac(key: &HmacKey, context: &[u8], message: &[u8]) -> HmacTag {
    // BLAKE3 `derive_key` contexts must be hard-coded, application-unique strings.
    // Changing this label changes every tag produced by this function.
    const SUBKEY_LABEL: &str = "CHIE hmac compute_tagged_hmac per-context subkey v1";

    let master: &[u8; 32] = key.as_bytes().try_into().expect("key is 32 bytes");

    let mut deriver = blake3::Hasher::new_derive_key(SUBKEY_LABEL);
    deriver.update(master);
    deriver.update(context);
    let mut subkey: [u8; 32] = *deriver.finalize().as_bytes();

    let tag = HmacTag::from_bytes(blake3::keyed_hash(&subkey, message).as_bytes());
    // Best-effort wipe of the derived secret held in this stack copy.
    subkey.zeroize();
    tag
}
/// Recomputes the context-bound MAC over `message` and checks it against `tag` via
/// `HmacTag::verify`. A different `context` yields a different tag.
pub fn verify_tagged_hmac(key: &HmacKey, context: &[u8], message: &[u8], tag: &HmacTag) -> bool {
    compute_tagged_hmac(key, context, message).verify(tag)
}
/// A payload bundled with the MAC tag that authenticates it.
///
/// Built by `AuthenticatedMessage::new` / `new_tagged`. It is consumed by `verify` /
/// `verify_tagged`, which return the payload only if the tag checks out.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct AuthenticatedMessage {
    /// The authenticated payload.
    #[serde(with = "serde_bytes")]
    data: Vec<u8>,
    /// Tag over `data`, optionally bound to a context.
    tag: HmacTag,
}
impl AuthenticatedMessage {
    /// Pairs `data` with its default MAC (`compute_hmac`) under `key`.
    pub fn new(key: &HmacKey, data: Vec<u8>) -> Self {
        Self {
            tag: compute_hmac(key, &data),
            data,
        }
    }

    /// Pairs `data` with its context-bound MAC (`compute_tagged_hmac`) under `key`.
    pub fn new_tagged(key: &HmacKey, context: &[u8], data: Vec<u8>) -> Self {
        Self {
            tag: compute_tagged_hmac(key, context, &data),
            data,
        }
    }

    /// Consumes the message and returns its payload if the default MAC matches.
    ///
    /// # Errors
    /// `HmacError::VerificationFailed` if the tag does not match under `key`.
    pub fn verify(self, key: &HmacKey) -> HmacResult<Vec<u8>> {
        let Self { data, tag } = self;
        if !verify_hmac(key, &data, &tag) {
            return Err(HmacError::VerificationFailed);
        }
        Ok(data)
    }

    /// Consumes the message and returns its payload if the context-bound MAC matches.
    ///
    /// # Errors
    /// `HmacError::VerificationFailed` if the tag does not match for `key` and `context`.
    pub fn verify_tagged(self, key: &HmacKey, context: &[u8]) -> HmacResult<Vec<u8>> {
        let Self { data, tag } = self;
        if !verify_tagged_hmac(key, context, &data, &tag) {
            return Err(HmacError::VerificationFailed);
        }
        Ok(data)
    }

    /// Borrows the (not yet verified) payload.
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Borrows the attached tag.
    pub fn tag(&self) -> &HmacTag {
        &self.tag
    }

    /// Encodes the message with `crate::codec::encode`.
    ///
    /// # Errors
    /// `HmacError::SerializationError` carrying the codec's error text.
    pub fn to_bytes(&self) -> HmacResult<Vec<u8>> {
        crate::codec::encode(self).map_err(|err| HmacError::SerializationError(err.to_string()))
    }

    /// Decodes a message with `crate::codec::decode`. The tag is not checked here.
    ///
    /// # Errors
    /// `HmacError::SerializationError` carrying the codec's error text.
    pub fn from_bytes(bytes: &[u8]) -> HmacResult<Self> {
        crate::codec::decode(bytes).map_err(|err| HmacError::SerializationError(err.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default MAC round-trips and rejects a different message.
    #[test]
    fn test_hmac_basic() {
        let k = HmacKey::generate();
        let msg: &[u8] = b"Hello, CHIE!";
        let t = compute_hmac(&k, msg);
        assert!(verify_hmac(&k, msg, &t));
        assert!(!verify_hmac(&k, b"Wrong message", &t));
    }

    /// HMAC-SHA256 verifies and has the documented tag length.
    #[test]
    fn test_hmac_sha256() {
        let k = HmacKey::generate();
        let msg: &[u8] = b"Test message";
        let t = compute_hmac_sha256(&k, msg);
        assert!(verify_hmac_sha256(&k, msg, &t));
        assert_eq!(t.as_bytes().len(), HMAC_SHA256_TAG_SIZE);
    }

    /// BLAKE3 keyed hash verifies and has the documented tag length.
    #[test]
    fn test_hmac_blake3() {
        let k = HmacKey::generate();
        let msg: &[u8] = b"Test message";
        let t = compute_hmac_blake3(&k, msg);
        assert!(verify_hmac_blake3(&k, msg, &t));
        assert_eq!(t.as_bytes().len(), HMAC_BLAKE3_TAG_SIZE);
    }

    /// A tagged MAC only verifies under the context it was made with.
    #[test]
    fn test_tagged_hmac() {
        let k = HmacKey::generate();
        let ctx: &[u8] = b"CHIE:BandwidthProof";
        let msg: &[u8] = b"1234567890";
        let t = compute_tagged_hmac(&k, ctx, msg);
        assert!(verify_tagged_hmac(&k, ctx, msg, &t));
        assert!(!verify_tagged_hmac(&k, b"wrong", msg, &t));
    }

    /// Verifying with the signing key returns the original payload.
    #[test]
    fn test_authenticated_message() {
        let k = HmacKey::generate();
        let payload = b"Secret data".to_vec();
        let recovered = AuthenticatedMessage::new(&k, payload.clone())
            .verify(&k)
            .unwrap();
        assert_eq!(recovered, payload);
    }

    /// Verifying with a different key is rejected.
    #[test]
    fn test_authenticated_message_fails() {
        let (signer, other) = (HmacKey::generate(), HmacKey::generate());
        let sealed = AuthenticatedMessage::new(&signer, b"Secret data".to_vec());
        assert!(sealed.verify(&other).is_err());
    }

    /// Tagged messages round-trip under the same context.
    #[test]
    fn test_tagged_authenticated_message() {
        let k = HmacKey::generate();
        let ctx: &[u8] = b"CHIE:Chunk";
        let payload = b"Chunk data".to_vec();
        let recovered = AuthenticatedMessage::new_tagged(&k, ctx, payload.clone())
            .verify_tagged(&k, ctx)
            .unwrap();
        assert_eq!(recovered, payload);
    }

    /// A correctly sized byte slice becomes a key with identical bytes.
    #[test]
    fn test_hmac_key_from_bytes() {
        let raw = [42u8; HMAC_KEY_SIZE];
        let k = HmacKey::from_bytes(&raw).unwrap();
        assert_eq!(k.as_bytes(), &raw);
    }

    /// A short byte slice is rejected.
    #[test]
    fn test_hmac_key_invalid_size() {
        let short = [42u8; 16];
        assert!(HmacKey::from_bytes(&short).is_err());
    }

    /// Password derivation is deterministic and salt-sensitive.
    #[test]
    fn test_hmac_key_derive_from_password() {
        let pw: &[u8] = b"my secret password";
        let salt: &[u8] = b"unique salt";
        let first = HmacKey::derive_from_password(pw, salt, 10000);
        let again = HmacKey::derive_from_password(pw, salt, 10000);
        assert_eq!(first.as_bytes(), again.as_bytes());
        let salted_differently = HmacKey::derive_from_password(pw, b"different salt", 10000);
        assert_ne!(first.as_bytes(), salted_differently.as_bytes());
    }

    /// Encoding then decoding preserves both payload and tag.
    #[test]
    fn test_serialization() {
        let k = HmacKey::generate();
        let original = AuthenticatedMessage::new(&k, b"Test data".to_vec());
        let encoded = original.to_bytes().unwrap();
        let decoded = AuthenticatedMessage::from_bytes(&encoded).unwrap();
        assert_eq!(original.data, decoded.data);
        assert_eq!(original.tag, decoded.tag);
    }

    /// Equal tags verify; flipping a single bit makes verification fail.
    #[test]
    fn test_constant_time_verification() {
        let k = HmacKey::generate();
        let msg: &[u8] = b"Test message";
        let first = compute_hmac(&k, msg);
        let second = compute_hmac(&k, msg);
        assert!(first.verify(&second));
        let mut tampered = first.clone();
        tampered.tag[0] ^= 1;
        assert!(!first.verify(&tampered));
    }
}