1use hmac::{Hmac, Mac};
25use sha2::Sha256;
26use subtle::ConstantTimeEq;
27
28use crate::FORMAT_VERSION;
29use crate::error::KeyError;
30
/// HMAC instantiated with SHA-256; `derive` uses it to compute lookup keys.
type HmacSha256 = Hmac<Sha256>;

/// Context label that `derive` feeds into the MAC input after `FORMAT_VERSION` and a `/`.
const LOOKUP_CONTEXT: &str = "lookup";
36
/// Identifier attached to a piece of HMAC key material.
///
/// This is a thin newtype over `String`; the wrapped value is stored as given and is
/// not validated in any way.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyVersion(String);

impl KeyVersion {
    /// Wraps `value` as a key version, converting it into an owned `String`.
    #[must_use]
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        KeyVersion(owned)
    }

    /// Borrows the version identifier as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl core::fmt::Display for KeyVersion {
    /// Writes the raw identifier; width and fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.as_str())
    }
}
62
63#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
70pub struct LookupKey(String);
71
72impl LookupKey {
73 #[must_use]
75 pub fn as_str(&self) -> &str {
76 &self.0
77 }
78
79 #[must_use]
82 pub fn ct_eq(&self, other: &LookupKey) -> bool {
83 let a = self.0.as_bytes();
84 let b = other.0.as_bytes();
85 if a.len() != b.len() {
86 return false;
87 }
88 a.ct_eq(b).into()
89 }
90}
91
/// Category of secret a lookup key is derived for.
///
/// Each variant maps to a distinct [`label`](SecretDomain::label), which `derive`
/// mixes into the MAC input so that equal values in different domains yield
/// different lookup keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecretDomain {
    /// Labelled `"code"`.
    Code,
    /// Labelled `"session"`.
    Session,
    /// Labelled `"form_token"`.
    FormToken,
    /// Labelled `"flow_ticket"`.
    FlowTicket,
}

impl SecretDomain {
    /// Returns the fixed ASCII label for this domain.
    ///
    /// The match is kept exhaustive (no wildcard arm) so adding a variant forces a
    /// label to be chosen.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            Self::Code => "code",
            Self::Session => "session",
            Self::FormToken => "form_token",
            Self::FlowTicket => "flow_ticket",
        }
    }
}
119
120pub struct HmacKeyRef<'a> {
122 pub version: KeyVersion,
124 pub bytes: &'a [u8],
126}
127
128impl core::fmt::Debug for HmacKeyRef<'_> {
129 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
130 f.debug_struct("HmacKeyRef")
131 .field("version", &self.version)
132 .field("bytes", &"<redacted>")
133 .finish()
134 }
135}
136
/// Source of versioned HMAC keys.
///
/// `SecretHasher` calls [`active_hmac_key`](KeyProvider::active_hmac_key) when deriving
/// a new lookup key and [`hmac_key_by_version`](KeyProvider::hmac_key_by_version) when
/// re-deriving under a specific version.
pub trait KeyProvider {
    /// Returns the key that new lookup keys should be derived with.
    ///
    /// # Errors
    ///
    /// Implementation-defined; `StaticKeyProvider` returns
    /// `KeyError::MissingActiveKey` when no stored key matches its active version.
    fn active_hmac_key(&self) -> Result<HmacKeyRef<'_>, KeyError>;

    /// Returns the key registered under `version`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; `StaticKeyProvider` returns
    /// `KeyError::MissingKeyVersion` when `version` is unknown.
    fn hmac_key_by_version(&self, version: &KeyVersion) -> Result<HmacKeyRef<'_>, KeyError>;
}
155
156#[derive(Clone)]
163pub struct StaticKeyProvider {
164 active_version: KeyVersion,
165 keys: Vec<(KeyVersion, Vec<u8>)>,
166}
167
168impl StaticKeyProvider {
169 pub fn new(
174 active_version: impl Into<String>,
175 active_key: Vec<u8>,
176 previous: Vec<(KeyVersion, Vec<u8>)>,
177 ) -> Result<Self, KeyError> {
178 if active_key.is_empty() {
179 return Err(KeyError::InvalidKeyMaterial);
180 }
181 let active_version = KeyVersion::new(active_version);
182 let mut keys = Vec::with_capacity(previous.len() + 1);
183 keys.push((active_version.clone(), active_key));
184 keys.extend(previous);
185 Ok(Self {
186 active_version,
187 keys,
188 })
189 }
190
191 pub fn single(version: impl Into<String>, key: Vec<u8>) -> Result<Self, KeyError> {
196 Self::new(version, key, Vec::new())
197 }
198}
199
200impl core::fmt::Debug for StaticKeyProvider {
201 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
202 f.debug_struct("StaticKeyProvider")
203 .field("active_version", &self.active_version)
204 .field("key_versions", &self.keys.len())
205 .field("keys", &"<redacted>")
206 .finish()
207 }
208}
209
210impl KeyProvider for StaticKeyProvider {
211 fn active_hmac_key(&self) -> Result<HmacKeyRef<'_>, KeyError> {
212 self.keys
213 .iter()
214 .find(|(v, _)| *v == self.active_version)
215 .map(|(v, k)| HmacKeyRef {
216 version: v.clone(),
217 bytes: k,
218 })
219 .ok_or(KeyError::MissingActiveKey)
220 }
221
222 fn hmac_key_by_version(&self, version: &KeyVersion) -> Result<HmacKeyRef<'_>, KeyError> {
223 self.keys
224 .iter()
225 .find(|(v, _)| v == version)
226 .map(|(v, k)| HmacKeyRef {
227 version: v.clone(),
228 bytes: k,
229 })
230 .ok_or(KeyError::MissingKeyVersion)
231 }
232}
233
234#[derive(Debug, Clone)]
236pub struct SecretHasher<K> {
237 key_provider: K,
238}
239
240impl<K: KeyProvider> SecretHasher<K> {
241 #[must_use]
243 pub fn new(key_provider: K) -> Self {
244 Self { key_provider }
245 }
246
247 #[must_use]
249 pub fn key_provider(&self) -> &K {
250 &self.key_provider
251 }
252
253 pub fn lookup_key(
259 &self,
260 domain: SecretDomain,
261 value: &str,
262 ) -> Result<(LookupKey, KeyVersion), KeyError> {
263 let key = self.key_provider.active_hmac_key()?;
264 let lk = derive(key.bytes, domain, value);
265 Ok((lk, key.version))
266 }
267
268 pub fn lookup_key_with_version(
275 &self,
276 domain: SecretDomain,
277 value: &str,
278 version: &KeyVersion,
279 ) -> Result<LookupKey, KeyError> {
280 let key = self.key_provider.hmac_key_by_version(version)?;
281 Ok(derive(key.bytes, domain, value))
282 }
283}
284
285fn derive(key_bytes: &[u8], domain: SecretDomain, value: &str) -> LookupKey {
289 let mut mac =
292 HmacSha256::new_from_slice(key_bytes).expect("HMAC-SHA256 accepts any key length");
293 mac.update(FORMAT_VERSION.as_bytes());
294 mac.update(b"/");
295 mac.update(LOOKUP_CONTEXT.as_bytes());
296 mac.update(&[0u8]);
297 mac.update(domain.label().as_bytes());
298 mac.update(&[0u8]);
299 mac.update(value.as_bytes());
300 let digest = mac.finalize().into_bytes();
301 LookupKey(hex_lower(&digest))
302}
303
/// Encodes `bytes` as lowercase hexadecimal, two characters per input byte.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(
        bytes
            .iter()
            // High nibble first, then low nibble.
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| char::from(DIGITS[usize::from(nibble)])),
    );
    encoded
}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Hasher over a single `v1` key, shared by most tests below.
    fn hasher() -> SecretHasher<StaticKeyProvider> {
        let provider =
            StaticKeyProvider::single("v1", b"super-secret-key-material".to_vec()).unwrap();
        SecretHasher::new(provider)
    }

    #[test]
    fn deterministic_same_inputs_same_key() {
        let subject = hasher();
        let (first_key, first_version) =
            subject.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();
        let (second_key, second_version) =
            subject.lookup_key(SecretDomain::Code, "ABCD2345").unwrap();

        assert_eq!(first_key, second_key);
        assert_eq!(first_version, second_version);
        assert_eq!(first_version.as_str(), "v1");
        // The key is a 64-character hex string.
        let rendered = first_key.as_str();
        assert_eq!(rendered.len(), 64);
        assert!(rendered.bytes().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn different_value_different_key() {
        let subject = hasher();
        let (for_a, _) = subject.lookup_key(SecretDomain::Code, "AAAAAAAA").unwrap();
        let (for_b, _) = subject.lookup_key(SecretDomain::Code, "BBBBBBBB").unwrap();
        assert_ne!(for_a, for_b);
    }

    #[test]
    fn domain_separation_distinguishes_same_value() {
        let subject = hasher();
        let keys = [
            SecretDomain::Code,
            SecretDomain::Session,
            SecretDomain::FormToken,
            SecretDomain::FlowTicket,
        ]
        .map(|domain| subject.lookup_key(domain, "SAME").unwrap().0);

        // Every unordered pair of domains must produce distinct keys.
        for (i, left) in keys.iter().enumerate() {
            for (j, right) in keys.iter().enumerate().skip(i + 1) {
                assert_ne!(left, right, "domains {i},{j} collided");
            }
        }
    }

    #[test]
    fn different_key_different_output() {
        let derive_with = |key: &[u8]| {
            SecretHasher::new(StaticKeyProvider::single("v1", key.to_vec()).unwrap())
                .lookup_key(SecretDomain::Code, "X")
                .unwrap()
                .0
        };
        assert_ne!(derive_with(b"key-one"), derive_with(b"key-two"));
    }

    #[test]
    fn missing_active_key_fails_closed() {
        // Built by struct literal to get an active version that no stored key carries.
        let provider = StaticKeyProvider {
            active_version: KeyVersion::new("missing"),
            keys: vec![(KeyVersion::new("v1"), b"k".to_vec())],
        };
        let outcome = SecretHasher::new(provider).lookup_key(SecretDomain::Code, "X");
        assert_eq!(outcome.unwrap_err(), KeyError::MissingActiveKey);
    }

    #[test]
    fn empty_key_rejected_at_construction() {
        let outcome = StaticKeyProvider::single("v1", Vec::new());
        assert_eq!(outcome.unwrap_err(), KeyError::InvalidKeyMaterial);
    }

    #[test]
    fn key_version_round_trip_validation() {
        let previous = vec![(KeyVersion::new("v1"), b"key-one".to_vec())];
        let provider = StaticKeyProvider::new("v2", b"key-two".to_vec(), previous).unwrap();
        let subject = SecretHasher::new(provider);

        let (under_active, active_version) =
            subject.lookup_key(SecretDomain::Session, "tok").unwrap();
        assert_eq!(active_version.as_str(), "v2");

        let under_previous = subject
            .lookup_key_with_version(SecretDomain::Session, "tok", &KeyVersion::new("v1"))
            .unwrap();
        assert_ne!(under_active, under_previous);

        let unknown = subject.lookup_key_with_version(
            SecretDomain::Session,
            "tok",
            &KeyVersion::new("v9"),
        );
        assert_eq!(unknown.unwrap_err(), KeyError::MissingKeyVersion);
    }

    #[test]
    fn lookup_key_ct_eq_matches_value_eq() {
        let subject = hasher();
        let code_key = |value: &str| subject.lookup_key(SecretDomain::Code, value).unwrap().0;

        let original = code_key("ABCD2345");
        let same = code_key("ABCD2345");
        let other = code_key("DIFFEREN");

        assert!(original.ct_eq(&same));
        assert!(!original.ct_eq(&other));
    }

    #[test]
    fn key_material_redacted_in_debug() {
        let provider = StaticKeyProvider::single("v1", b"secret-bytes".to_vec()).unwrap();

        let dbg = format!("{provider:?}");
        assert!(!dbg.contains("secret-bytes"), "key bytes leaked: {dbg}");
        assert!(dbg.contains("<redacted>"));

        let key_ref = provider.active_hmac_key().unwrap();
        let key_ref_dbg = format!("{key_ref:?}");
        assert!(!key_ref_dbg.contains("secret-bytes"));
    }
}