use hmac::{Hmac, KeyInit, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use crate::FORMAT_VERSION;
use crate::error::KeyError;
/// HMAC instantiated with SHA-256; the MAC used by `derive` to build lookup keys.
type HmacSha256 = Hmac<Sha256>;
/// Context label written into every lookup-key MAC input (after `FORMAT_VERSION/`),
/// presumably to keep these MACs distinct from other uses of the same key — confirm.
const LOOKUP_CONTEXT: &str = "lookup";
/// Identifier of one HMAC key within a key set.
///
/// [`SecretHasher::lookup_key`] returns it next to each derived lookup key so
/// the same key can be selected again via
/// [`SecretHasher::lookup_key_with_version`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyVersion(String);

impl KeyVersion {
    /// Wraps any string-like value as a key version.
    #[must_use]
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        KeyVersion(owned)
    }

    /// Borrows the version identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl core::fmt::Display for KeyVersion {
    /// Writes the raw identifier with no decoration.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Lowercase hex encoding of an HMAC-SHA256 lookup digest (64 characters when
/// produced by `derive`).
///
/// The derived `PartialEq` compares in variable time; prefer
/// [`LookupKey::ct_eq`] when one side may be attacker-influenced.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LookupKey(String);

impl LookupKey {
    /// Borrows the hex string.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Compares two lookup keys in constant time with respect to their contents.
    ///
    /// Lengths are compared first in variable time. Keys produced by `derive`
    /// always have the same length, so that check reveals nothing about them.
    #[must_use]
    pub fn ct_eq(&self, other: &LookupKey) -> bool {
        let (lhs, rhs) = (self.0.as_bytes(), other.0.as_bytes());
        lhs.len() == rhs.len() && bool::from(lhs.ct_eq(rhs))
    }
}
/// Namespace a secret belongs to.
///
/// Each domain has a distinct label that is mixed into lookup-key derivation,
/// so equal values in different domains yield different lookup keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecretDomain {
    Code,
    Session,
    FormToken,
    FlowTicket,
}

impl SecretDomain {
    /// Stable label for this domain, as written into the HMAC input.
    ///
    /// Changing a label changes every lookup key derived for that domain.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            Self::Code => "code",
            Self::Session => "session",
            Self::FormToken => "form_token",
            Self::FlowTicket => "flow_ticket",
        }
    }
}
/// Borrowed view of one HMAC key together with its version.
///
/// The `Debug` output never includes the key bytes.
pub struct HmacKeyRef<'a> {
    /// Version identifying this key.
    pub version: KeyVersion,
    /// Raw key material (borrowed, not owned).
    pub bytes: &'a [u8],
}

impl core::fmt::Debug for HmacKeyRef<'_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("HmacKeyRef");
        builder.field("version", &self.version);
        // Key material must never reach logs or panic messages.
        builder.field("bytes", &"<redacted>");
        builder.finish()
    }
}
/// Source of versioned HMAC keys used by [`SecretHasher`].
///
/// Returned [`HmacKeyRef`]s borrow their key bytes from the provider.
pub trait KeyProvider {
/// Returns the key to use for new derivations, tagged with its version.
///
/// # Errors
///
/// Implementation-defined; [`StaticKeyProvider`] returns
/// `KeyError::MissingActiveKey` when its active version has no stored key.
fn active_hmac_key(&self) -> Result<HmacKeyRef<'_>, KeyError>;
/// Returns the key stored under `version`, e.g. to re-derive a lookup key that
/// was produced before the active key was rotated.
///
/// # Errors
///
/// Implementation-defined; [`StaticKeyProvider`] returns
/// `KeyError::MissingKeyVersion` for an unknown version.
fn hmac_key_by_version(&self, version: &KeyVersion) -> Result<HmacKeyRef<'_>, KeyError>;
}
/// In-memory [`KeyProvider`] holding a fixed set of versioned HMAC keys.
///
/// Key bytes are kept as-is (not zeroized on drop).
#[derive(Clone)]
pub struct StaticKeyProvider {
    /// Version served by [`KeyProvider::active_hmac_key`].
    active_version: KeyVersion,
    /// `(version, key bytes)` pairs. When built via [`StaticKeyProvider::new`]
    /// the active key is the first entry, followed by `previous` in order.
    keys: Vec<(KeyVersion, Vec<u8>)>,
}

impl StaticKeyProvider {
    /// Creates a provider whose active key is `active_key` under
    /// `active_version`. The keys in `previous` stay available for lookups
    /// by version.
    ///
    /// # Errors
    ///
    /// Returns [`KeyError::InvalidKeyMaterial`] in either of these cases:
    /// - Any key is empty, whether the active key or an entry of `previous`.
    ///   An empty HMAC key provides no secrecy.
    /// - The same version is listed more than once with different key bytes.
    ///   Version lookups only ever return the first match, so the other key
    ///   would be silently ignored.
    ///
    /// Exact repeats (same version, same bytes) are accepted.
    pub fn new(
        active_version: impl Into<String>,
        active_key: Vec<u8>,
        previous: Vec<(KeyVersion, Vec<u8>)>,
    ) -> Result<Self, KeyError> {
        let active_version = KeyVersion::new(active_version);
        let mut keys = Vec::with_capacity(previous.len() + 1);
        keys.push((active_version.clone(), active_key));
        keys.extend(previous);

        // Validate every entry, not just the active one. Key sets are tiny
        // (one active key plus a few rotated ones), so a quadratic scan for
        // conflicting versions is simpler than building a set.
        for (index, (version, bytes)) in keys.iter().enumerate() {
            let conflicts_with_earlier = keys[..index]
                .iter()
                .any(|(seen_version, seen_bytes)| seen_version == version && seen_bytes != bytes);
            if bytes.is_empty() || conflicts_with_earlier {
                return Err(KeyError::InvalidKeyMaterial);
            }
        }

        Ok(Self {
            active_version,
            keys,
        })
    }

    /// Creates a provider with a single active key and no previous keys.
    ///
    /// # Errors
    ///
    /// Returns [`KeyError::InvalidKeyMaterial`] if `key` is empty.
    pub fn single(version: impl Into<String>, key: Vec<u8>) -> Result<Self, KeyError> {
        Self::new(version, key, Vec::new())
    }
}
impl core::fmt::Debug for StaticKeyProvider {
    /// Formats the provider without exposing key bytes.
    ///
    /// Only the active version and the number of stored keys are shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let stored_key_count = self.keys.len();
        let mut builder = f.debug_struct("StaticKeyProvider");
        builder.field("active_version", &self.active_version);
        builder.field("key_versions", &stored_key_count);
        builder.field("keys", &"<redacted>");
        builder.finish()
    }
}
impl KeyProvider for StaticKeyProvider {
    /// Returns the first stored key whose version equals `active_version`.
    ///
    /// Fails with `KeyError::MissingActiveKey` when no entry carries the
    /// active version.
    fn active_hmac_key(&self) -> Result<HmacKeyRef<'_>, KeyError> {
        match self
            .keys
            .iter()
            .find(|(candidate, _)| candidate == &self.active_version)
        {
            Some((version, bytes)) => Ok(HmacKeyRef {
                version: version.clone(),
                bytes,
            }),
            None => Err(KeyError::MissingActiveKey),
        }
    }

    /// Returns the first stored key whose version equals `version`.
    ///
    /// Fails with `KeyError::MissingKeyVersion` when no entry matches.
    fn hmac_key_by_version(&self, version: &KeyVersion) -> Result<HmacKeyRef<'_>, KeyError> {
        let (stored_version, stored_bytes) = self
            .keys
            .iter()
            .find(|(candidate, _)| candidate == version)
            .ok_or(KeyError::MissingKeyVersion)?;
        Ok(HmacKeyRef {
            version: stored_version.clone(),
            bytes: stored_bytes,
        })
    }
}
#[derive(Debug, Clone)]
pub struct SecretHasher<K> {
key_provider: K,
}
impl<K: KeyProvider> SecretHasher<K> {
#[must_use]
pub fn new(key_provider: K) -> Self {
Self { key_provider }
}
#[must_use]
pub fn key_provider(&self) -> &K {
&self.key_provider
}
pub fn lookup_key(
&self,
domain: SecretDomain,
value: &str,
) -> Result<(LookupKey, KeyVersion), KeyError> {
let key = self.key_provider.active_hmac_key()?;
let lk = derive(key.bytes, domain, value);
Ok((lk, key.version))
}
pub fn lookup_key_with_version(
&self,
domain: SecretDomain,
value: &str,
version: &KeyVersion,
) -> Result<LookupKey, KeyError> {
let key = self.key_provider.hmac_key_by_version(version)?;
Ok(derive(key.bytes, domain, value))
}
}
/// Computes the lookup key for `value` under `domain` with the given HMAC key.
///
/// The MAC input is, in order:
/// 1. `FORMAT_VERSION`
/// 2. `/`
/// 3. `LOOKUP_CONTEXT`
/// 4. a NUL byte
/// 5. the domain label
/// 6. a NUL byte
/// 7. the UTF-8 bytes of `value`
///
/// Domain labels contain no NUL, so the separators keep fields from running
/// together. `value` comes last and needs no terminator. The SHA-256 tag is
/// returned as lowercase hex.
fn derive(key_bytes: &[u8], domain: SecretDomain, value: &str) -> LookupKey {
    const NUL: &[u8] = &[0u8];
    let message: [&[u8]; 7] = [
        FORMAT_VERSION.as_bytes(),
        b"/",
        LOOKUP_CONTEXT.as_bytes(),
        NUL,
        domain.label().as_bytes(),
        NUL,
        value.as_bytes(),
    ];
    let mut mac =
        HmacSha256::new_from_slice(key_bytes).expect("HMAC-SHA256 accepts any key length");
    for part in message {
        mac.update(part);
    }
    LookupKey(hex_lower(&mac.finalize().into_bytes()))
}
/// Encodes `bytes` as lowercase hexadecimal, high nibble first, two
/// characters per byte.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut encoded, &byte| {
            encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
            encoded
        })
}
// Unit tests: determinism, domain and key separation, fail-closed key handling,
// version round-trips, constant-time equality, and Debug redaction.
#[cfg(test)]
mod tests {
use super::*;
/// Shared fixture: a hasher over a single active key with version "v1".
fn hasher() -> SecretHasher<StaticKeyProvider> {
let kp = StaticKeyProvider::single("v1", b"super-secret-key-material".to_vec()).unwrap();
SecretHasher::new(kp)
}
/// Same domain, value and key give the same lookup key and version. The key
/// is 64 hex characters (32-byte SHA-256 tag).
#[test]
fn deterministic_same_inputs_same_key() {
let h = hasher();
let (a, va) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
let (b, vb) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
assert_eq!(a, b);
assert_eq!(va, vb);
assert_eq!(va.as_str(), "v1");
assert_eq!(a.as_str().len(), 64);
assert!(a.as_str().bytes().all(|c| c.is_ascii_hexdigit()));
}
/// Different values in the same domain produce different lookup keys.
#[test]
fn different_value_different_key() {
let h = hasher();
let (a, _) = h.lookup_key(SecretDomain::Code, "AAAAAAAA").unwrap();
let (b, _) = h.lookup_key(SecretDomain::Code, "BBBBBBBB").unwrap();
assert_ne!(a, b);
}
/// The same value hashed in each of the four domains yields four pairwise
/// distinct lookup keys.
#[test]
fn domain_separation_distinguishes_same_value() {
let h = hasher();
let (code, _) = h.lookup_key(SecretDomain::Code, "SAME").unwrap();
let (sess, _) = h.lookup_key(SecretDomain::Session, "SAME").unwrap();
let (form, _) = h.lookup_key(SecretDomain::FormToken, "SAME").unwrap();
let (flow, _) = h.lookup_key(SecretDomain::FlowTicket, "SAME").unwrap();
let all = [&code, &sess, &form, &flow];
for i in 0..all.len() {
for j in (i + 1)..all.len() {
assert_ne!(all[i], all[j], "domains {i},{j} collided");
}
}
}
/// Different key material under the same version label gives a different
/// lookup key.
#[test]
fn different_key_different_output() {
let h1 = SecretHasher::new(StaticKeyProvider::single("v1", b"key-one".to_vec()).unwrap());
let h2 = SecretHasher::new(StaticKeyProvider::single("v1", b"key-two".to_vec()).unwrap());
let (a, _) = h1.lookup_key(SecretDomain::Code, "X").unwrap();
let (b, _) = h2.lookup_key(SecretDomain::Code, "X").unwrap();
assert_ne!(a, b);
}
/// A provider whose active version has no stored key must error rather than
/// fall back to another key. The provider is built via a struct literal
/// because `new` cannot produce this state.
#[test]
fn missing_active_key_fails_closed() {
let kp = StaticKeyProvider {
active_version: KeyVersion::new("missing"),
keys: vec![(KeyVersion::new("v1"), b"k".to_vec())],
};
let h = SecretHasher::new(kp);
assert_eq!(
h.lookup_key(SecretDomain::Code, "X").unwrap_err(),
KeyError::MissingActiveKey
);
}
/// An empty active key is rejected when the provider is constructed.
#[test]
fn empty_key_rejected_at_construction() {
assert_eq!(
StaticKeyProvider::single("v1", Vec::new()).unwrap_err(),
KeyError::InvalidKeyMaterial
);
}
/// Covers active "v2" plus previous "v1":
/// - The active version is reported as "v2".
/// - A lookup pinned to "v1" uses the old key, so its result differs from
///   the active one.
/// - An unknown version ("v9") returns `MissingKeyVersion`.
#[test]
fn key_version_round_trip_validation() {
let kp = StaticKeyProvider::new(
"v2",
b"key-two".to_vec(),
vec![(KeyVersion::new("v1"), b"key-one".to_vec())],
)
.unwrap();
let h = SecretHasher::new(kp);
let (active, av) = h.lookup_key(SecretDomain::Session, "tok").unwrap();
assert_eq!(av.as_str(), "v2");
let v1 = KeyVersion::new("v1");
let prev = h
.lookup_key_with_version(SecretDomain::Session, "tok", &v1)
.unwrap();
assert_ne!(active, prev);
let missing = KeyVersion::new("v9");
assert_eq!(
h.lookup_key_with_version(SecretDomain::Session, "tok", &missing)
.unwrap_err(),
KeyError::MissingKeyVersion
);
}
/// `ct_eq` agrees with ordinary equality on equal and unequal lookup keys.
#[test]
fn lookup_key_ct_eq_matches_value_eq() {
let h = hasher();
let (a, _) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
let (b, _) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
let (c, _) = h.lookup_key(SecretDomain::Code, "DIFFEREN").unwrap();
assert!(a.ct_eq(&b));
assert!(!a.ct_eq(&c));
}
/// Neither the provider's nor the borrowed key reference's `Debug` output
/// contains the raw key bytes.
#[test]
fn key_material_redacted_in_debug() {
let kp = StaticKeyProvider::single("v1", b"secret-bytes".to_vec()).unwrap();
let dbg = format!("{kp:?}");
assert!(!dbg.contains("secret-bytes"), "key bytes leaked: {dbg}");
assert!(dbg.contains("<redacted>"));
let key = kp.active_hmac_key().unwrap();
let kdbg = format!("{key:?}");
assert!(!kdbg.contains("secret-bytes"));
}
}