1use hmac::{Hmac, KeyInit, Mac};
25use sha2::Sha256;
26use subtle::ConstantTimeEq;
27
28use crate::FORMAT_VERSION;
29use crate::error::KeyError;
30
/// HMAC instantiated with SHA-256; the MAC used by `derive` for every lookup key.
type HmacSha256 = Hmac<Sha256>;

/// Context label fed into every derivation right after `FORMAT_VERSION` and `/`.
/// NOTE(review): presumably reserves the "lookup" purpose so other uses of the
/// same HMAC key cannot collide with lookup keys — confirm with other callers.
const LOOKUP_CONTEXT: &str = "lookup";
36
/// Label identifying one HMAC key held by a [`KeyProvider`].
///
/// Derived lookup keys are returned together with the version of the key that
/// produced them, and [`SecretHasher::lookup_key_with_version`] re-derives with
/// a specific version.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyVersion(String);

impl KeyVersion {
    /// Wraps `value` as a key version.
    ///
    /// No validation is performed: any string, including the empty string, is
    /// accepted.
    #[must_use]
    pub fn new(value: impl Into<String>) -> Self {
        Self(value.into())
    }

    /// Borrows the version label.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl core::fmt::Display for KeyVersion {
    /// Writes the bare label, honouring width/alignment/precision flags exactly
    /// like `str`'s `Display` does.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `pad` (unlike `write_str`) applies `{:>8}` / `{:<8}` / `{:.2}` flags,
        // so a version formats the same as the string it wraps.
        f.pad(&self.0)
    }
}
62
63#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
70pub struct LookupKey(String);
71
72impl LookupKey {
73 #[must_use]
75 pub fn as_str(&self) -> &str {
76 &self.0
77 }
78
79 #[must_use]
82 pub fn ct_eq(&self, other: &LookupKey) -> bool {
83 let a = self.0.as_bytes();
84 let b = other.0.as_bytes();
85 if a.len() != b.len() {
86 return false;
87 }
88 a.ct_eq(b).into()
89 }
90}
91
/// Namespace a secret belongs to.
///
/// The domain's [`label`](SecretDomain::label) is part of the HMAC input, so the
/// same value hashed in two different domains yields two different lookup keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecretDomain {
    /// Domain labelled `"code"`.
    Code,
    /// Domain labelled `"session"`.
    Session,
    /// Domain labelled `"form_token"`.
    FormToken,
    /// Domain labelled `"flow_ticket"`.
    FlowTicket,
}

impl SecretDomain {
    /// Returns the ASCII label mixed into the HMAC input for this domain.
    ///
    /// Changing a label changes every lookup key derived in that domain.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            Self::Code => "code",
            Self::Session => "session",
            Self::FormToken => "form_token",
            Self::FlowTicket => "flow_ticket",
        }
    }
}
119
120pub struct HmacKeyRef<'a> {
122 pub version: KeyVersion,
124 pub bytes: &'a [u8],
126}
127
128impl core::fmt::Debug for HmacKeyRef<'_> {
129 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
130 f.debug_struct("HmacKeyRef")
131 .field("version", &self.version)
132 .field("bytes", &"<redacted>")
133 .finish()
134 }
135}
136
/// Source of the HMAC keys used by [`SecretHasher`].
///
/// Every returned [`HmacKeyRef`] borrows from the provider (`'_` is tied to
/// `&self`), so it cannot outlive it.
pub trait KeyProvider {
    /// Returns the key that new lookup keys are derived with
    /// (used by [`SecretHasher::lookup_key`]).
    ///
    /// # Errors
    ///
    /// Implementation-defined. [`StaticKeyProvider`] returns
    /// [`KeyError::MissingActiveKey`] when no stored key carries its active version.
    fn active_hmac_key(&self) -> Result<HmacKeyRef<'_>, KeyError>;

    /// Returns the key labelled `version`
    /// (used by [`SecretHasher::lookup_key_with_version`]).
    ///
    /// # Errors
    ///
    /// Implementation-defined. [`StaticKeyProvider`] returns
    /// [`KeyError::MissingKeyVersion`] when no stored key has that version.
    fn hmac_key_by_version(&self, version: &KeyVersion) -> Result<HmacKeyRef<'_>, KeyError>;

    /// Returns every key the provider holds.
    ///
    /// [`SecretHasher::lookup_key_candidates`] derives one candidate per key, in
    /// the order returned here. [`StaticKeyProvider`] lists its active key first,
    /// then previous keys in the order they were supplied.
    ///
    /// # Errors
    ///
    /// Implementation-defined. [`StaticKeyProvider`] returns
    /// [`KeyError::MissingActiveKey`] if it holds no keys at all.
    fn all_hmac_keys(&self) -> Result<Vec<HmacKeyRef<'_>>, KeyError>;
}
161
162#[derive(Clone)]
169pub struct StaticKeyProvider {
170 active_version: KeyVersion,
171 keys: Vec<(KeyVersion, Vec<u8>)>,
172}
173
174impl StaticKeyProvider {
175 pub fn new(
180 active_version: impl Into<String>,
181 active_key: Vec<u8>,
182 previous: Vec<(KeyVersion, Vec<u8>)>,
183 ) -> Result<Self, KeyError> {
184 if active_key.is_empty() {
185 return Err(KeyError::InvalidKeyMaterial);
186 }
187 let active_version = KeyVersion::new(active_version);
188 let mut keys = Vec::with_capacity(previous.len() + 1);
189 keys.push((active_version.clone(), active_key));
190 keys.extend(previous);
191 Ok(Self {
192 active_version,
193 keys,
194 })
195 }
196
197 pub fn single(version: impl Into<String>, key: Vec<u8>) -> Result<Self, KeyError> {
202 Self::new(version, key, Vec::new())
203 }
204}
205
206impl core::fmt::Debug for StaticKeyProvider {
207 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
208 f.debug_struct("StaticKeyProvider")
209 .field("active_version", &self.active_version)
210 .field("key_versions", &self.keys.len())
211 .field("keys", &"<redacted>")
212 .finish()
213 }
214}
215
216impl KeyProvider for StaticKeyProvider {
217 fn active_hmac_key(&self) -> Result<HmacKeyRef<'_>, KeyError> {
218 self.keys
219 .iter()
220 .find(|(v, _)| *v == self.active_version)
221 .map(|(v, k)| HmacKeyRef {
222 version: v.clone(),
223 bytes: k,
224 })
225 .ok_or(KeyError::MissingActiveKey)
226 }
227
228 fn hmac_key_by_version(&self, version: &KeyVersion) -> Result<HmacKeyRef<'_>, KeyError> {
229 self.keys
230 .iter()
231 .find(|(v, _)| v == version)
232 .map(|(v, k)| HmacKeyRef {
233 version: v.clone(),
234 bytes: k,
235 })
236 .ok_or(KeyError::MissingKeyVersion)
237 }
238
239 fn all_hmac_keys(&self) -> Result<Vec<HmacKeyRef<'_>>, KeyError> {
240 if self.keys.is_empty() {
241 return Err(KeyError::MissingActiveKey);
242 }
243 Ok(self
244 .keys
245 .iter()
246 .map(|(v, k)| HmacKeyRef {
247 version: v.clone(),
248 bytes: k,
249 })
250 .collect())
251 }
252}
253
254#[derive(Debug, Clone)]
256pub struct SecretHasher<K> {
257 key_provider: K,
258}
259
260impl<K: KeyProvider> SecretHasher<K> {
261 #[must_use]
263 pub fn new(key_provider: K) -> Self {
264 Self { key_provider }
265 }
266
267 #[must_use]
269 pub fn key_provider(&self) -> &K {
270 &self.key_provider
271 }
272
273 pub fn lookup_key(
279 &self,
280 domain: SecretDomain,
281 value: &str,
282 ) -> Result<(LookupKey, KeyVersion), KeyError> {
283 let key = self.key_provider.active_hmac_key()?;
284 let lk = derive(key.bytes, domain, value);
285 Ok((lk, key.version))
286 }
287
288 pub fn lookup_key_candidates(
296 &self,
297 domain: SecretDomain,
298 value: &str,
299 ) -> Result<Vec<(LookupKey, KeyVersion)>, KeyError> {
300 let keys = self.key_provider.all_hmac_keys()?;
301 Ok(keys
302 .into_iter()
303 .map(|k| {
304 let lk = derive(k.bytes, domain, value);
305 (lk, k.version)
306 })
307 .collect())
308 }
309
310 pub fn lookup_key_with_version(
317 &self,
318 domain: SecretDomain,
319 value: &str,
320 version: &KeyVersion,
321 ) -> Result<LookupKey, KeyError> {
322 let key = self.key_provider.hmac_key_by_version(version)?;
323 Ok(derive(key.bytes, domain, value))
324 }
325}
326
327fn derive(key_bytes: &[u8], domain: SecretDomain, value: &str) -> LookupKey {
331 let mut mac =
334 HmacSha256::new_from_slice(key_bytes).expect("HMAC-SHA256 accepts any key length");
335 mac.update(FORMAT_VERSION.as_bytes());
336 mac.update(b"/");
337 mac.update(LOOKUP_CONTEXT.as_bytes());
338 mac.update(&[0u8]);
339 mac.update(domain.label().as_bytes());
340 mac.update(&[0u8]);
341 mac.update(value.as_bytes());
342 let digest = mac.finalize().into_bytes();
343 LookupKey(hex_lower(&digest))
344}
345
/// Encodes `bytes` as lowercase hexadecimal: two characters per byte, high
/// nibble first.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(2 * bytes.len());
    encoded.extend(bytes.iter().flat_map(|&byte| {
        // Each nibble is at most 0x0f, so indexing DIGITS cannot go out of bounds.
        [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]].map(char::from)
    }));
    encoded
}
357
#[cfg(test)]
mod tests {
    use super::*;

    // Hasher with a single active key labelled "v1", shared by most tests.
    fn hasher() -> SecretHasher<StaticKeyProvider> {
        let kp = StaticKeyProvider::single("v1", b"super-secret-key-material".to_vec()).unwrap();
        SecretHasher::new(kp)
    }

    #[test]
    fn deterministic_same_inputs_same_key() {
        let h = hasher();
        let (a, va) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
        let (b, vb) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
        assert_eq!(a, b);
        assert_eq!(va, vb);
        assert_eq!(va.as_str(), "v1");
        // A 32-byte SHA-256 tag renders as 64 hex characters.
        assert_eq!(a.as_str().len(), 64);
        assert!(a.as_str().bytes().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn different_value_different_key() {
        let h = hasher();
        let (a, _) = h.lookup_key(SecretDomain::Code, "AAAAAAAA").unwrap();
        let (b, _) = h.lookup_key(SecretDomain::Code, "BBBBBBBB").unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn domain_separation_distinguishes_same_value() {
        let h = hasher();
        let (code, _) = h.lookup_key(SecretDomain::Code, "SAME").unwrap();
        let (sess, _) = h.lookup_key(SecretDomain::Session, "SAME").unwrap();
        let (form, _) = h.lookup_key(SecretDomain::FormToken, "SAME").unwrap();
        let (flow, _) = h.lookup_key(SecretDomain::FlowTicket, "SAME").unwrap();
        let all = [&code, &sess, &form, &flow];
        // Every pair of domains must yield distinct keys for the same value.
        for i in 0..all.len() {
            for j in (i + 1)..all.len() {
                assert_ne!(all[i], all[j], "domains {i},{j} collided");
            }
        }
    }

    #[test]
    fn different_key_different_output() {
        let h1 = SecretHasher::new(StaticKeyProvider::single("v1", b"key-one".to_vec()).unwrap());
        let h2 = SecretHasher::new(StaticKeyProvider::single("v1", b"key-two".to_vec()).unwrap());
        let (a, _) = h1.lookup_key(SecretDomain::Code, "X").unwrap();
        let (b, _) = h2.lookup_key(SecretDomain::Code, "X").unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn missing_active_key_fails_closed() {
        let kp = StaticKeyProvider {
            // Built directly (bypassing `new`) so no stored key carries the active version.
            active_version: KeyVersion::new("missing"),
            keys: vec![(KeyVersion::new("v1"), b"k".to_vec())],
        };
        let h = SecretHasher::new(kp);
        assert_eq!(
            h.lookup_key(SecretDomain::Code, "X").unwrap_err(),
            KeyError::MissingActiveKey
        );
    }

    #[test]
    fn empty_key_rejected_at_construction() {
        assert_eq!(
            StaticKeyProvider::single("v1", Vec::new()).unwrap_err(),
            KeyError::InvalidKeyMaterial
        );
    }

    #[test]
    fn key_version_round_trip_validation() {
        let kp = StaticKeyProvider::new(
            // Active key "v2", with "v1" retained as a previous key.
            "v2",
            b"key-two".to_vec(),
            vec![(KeyVersion::new("v1"), b"key-one".to_vec())],
        )
        .unwrap();
        let h = SecretHasher::new(kp);
        let (active, av) = h.lookup_key(SecretDomain::Session, "tok").unwrap();
        assert_eq!(av.as_str(), "v2");
        let v1 = KeyVersion::new("v1");
        let prev = h
            .lookup_key_with_version(SecretDomain::Session, "tok", &v1)
            .unwrap();
        // Different key material must give a different lookup key.
        assert_ne!(active, prev);
        let missing = KeyVersion::new("v9");
        // An unknown version is an error, never a fallback to another key.
        assert_eq!(
            h.lookup_key_with_version(SecretDomain::Session, "tok", &missing)
                .unwrap_err(),
            KeyError::MissingKeyVersion
        );
    }

    #[test]
    fn candidates_cover_every_key_active_first() {
        let kp = StaticKeyProvider::new(
            "v2",
            b"key-two".to_vec(),
            vec![(KeyVersion::new("v1"), b"key-one".to_vec())],
        )
        .unwrap();
        let h = SecretHasher::new(kp);
        let candidates = h
            .lookup_key_candidates(SecretDomain::Code, "ABCD2345")
            .unwrap();
        assert_eq!(candidates.len(), 2);
        // The first candidate is exactly what `lookup_key` produces with the active key.
        let active = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
        assert_eq!(candidates[0], active);
        // Each later candidate is reproducible from its recorded version.
        let v1 = KeyVersion::new("v1");
        let prev = h
            .lookup_key_with_version(SecretDomain::Code, "ABCD2345", &v1)
            .unwrap();
        assert_eq!(candidates[1], (prev, v1));
    }

    #[test]
    fn hex_lower_encodes_each_byte_as_two_lowercase_digits() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xa5, 0xff]), "000fa5ff");
    }

    #[test]
    fn lookup_key_ct_eq_matches_value_eq() {
        let h = hasher();
        let (a, _) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
        let (b, _) = h.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
        let (c, _) = h.lookup_key(SecretDomain::Code, "DIFFEREN").unwrap();
        assert!(a.ct_eq(&b));
        assert!(!a.ct_eq(&c));
    }

    #[test]
    fn key_material_redacted_in_debug() {
        let kp = StaticKeyProvider::single("v1", b"secret-bytes".to_vec()).unwrap();
        let dbg = format!("{kp:?}");
        assert!(!dbg.contains("secret-bytes"), "key bytes leaked: {dbg}");
        assert!(dbg.contains("<redacted>"));
        let key = kp.active_hmac_key().unwrap();
        let kdbg = format!("{key:?}");
        assert!(!kdbg.contains("secret-bytes"));
    }
}