1use crate::error::{Error, Result};
2use crate::model::{StoredValue, Value};
3use crate::schema::FieldProtection;
4use aes_gcm::aead::Aead;
5use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
6use hmac::{Hmac, Mac};
7use rand::RngCore;
8use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey};
9use rsa::{Oaep, Pkcs1v15Sign, RsaPrivateKey, RsaPublicKey};
10use sha2::{Digest, Sha256};
11use std::collections::BTreeMap;
12
13type HmacSha256 = Hmac<Sha256>;
14
15pub type KeyVersion = u32;
21
22pub const KEYLESS: KeyVersion = 0;
25
pub(crate) mod b64 {
    //! Minimal standard-alphabet base64 codec (RFC 4648 §4) with `=` padding.

    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const PAD: char = '=';

    /// Encodes `data` as base64, always padding the output to a multiple of
    /// four characters.
    pub fn encode(data: &[u8]) -> String {
        let mut out = String::with_capacity(data.len().div_ceil(3) * 4);
        for chunk in data.chunks(3) {
            // Missing trailing bytes are treated as zero; their sextets are
            // replaced by padding below.
            let b = [
                chunk[0],
                chunk.get(1).copied().unwrap_or(0),
                chunk.get(2).copied().unwrap_or(0),
            ];
            let n = (u32::from(b[0]) << 16) | (u32::from(b[1]) << 8) | u32::from(b[2]);
            let sextet = |shift: u32| ALPHABET[((n >> shift) & 63) as usize] as char;
            out.push(sextet(18));
            out.push(sextet(12));
            out.push(if chunk.len() > 1 { sextet(6) } else { PAD });
            out.push(if chunk.len() > 2 { sextet(0) } else { PAD });
        }
        out
    }

    /// Maps one base64 alphabet byte to its 6-bit value.
    fn sextet_value(c: u8) -> Option<u32> {
        let v = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(v))
    }

    /// Decodes base64 text produced by [`encode`].
    ///
    /// Trailing padding is optional. Returns `None` in any of these cases:
    /// - a character lies outside the alphabet;
    /// - `=` appears anywhere but the end, or more than two `=` are present;
    /// - the data ends with a lone sextet, which cannot encode a whole byte
    ///   and only arises from truncated input.
    pub fn decode(s: &str) -> Option<Vec<u8>> {
        let data = s.trim_end_matches(PAD);
        let padding = s.len() - data.len();
        if padding > 2 || data.len() % 4 == 1 {
            return None;
        }
        let mut buf = Vec::with_capacity(data.len() / 4 * 3 + 2);
        let mut acc: u32 = 0;
        let mut bits = 0u8;
        for c in data.bytes() {
            // An interior '=' is not in the alphabet and is rejected here.
            acc = (acc << 6) | sextet_value(c)?;
            bits += 6;
            if bits >= 8 {
                bits -= 8;
                // Truncation keeps exactly the 8 bits just completed.
                buf.push((acc >> bits) as u8);
            }
        }
        Some(buf)
    }
}
75
76#[derive(Clone, Default)]
79struct RsaPair {
80 public: Option<RsaPublicKey>,
81 private: Option<RsaPrivateKey>,
82}
83
84#[derive(Clone, Default)]
105pub struct KeyRing {
106 rsa: BTreeMap<KeyVersion, RsaPair>,
107 macs: BTreeMap<KeyVersion, Vec<u8>>,
108}
109
110fn check_version(version: KeyVersion) -> Result<()> {
111 if version == KEYLESS {
112 return Err(Error::Crypto(
113 "key version 0 is reserved for keyless digests — versions start at 1".into(),
114 ));
115 }
116 Ok(())
117}
118
119impl KeyRing {
120 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn with_public_pem(self, pem: &str) -> Result<Self> {
125 self.with_public_pem_version(1, pem)
126 }
127
128 pub fn with_private_pem(self, pem: &str) -> Result<Self> {
129 self.with_private_pem_version(1, pem)
130 }
131
132 pub fn with_hmac_key(self, key: impl AsRef<[u8]>) -> Self {
135 self.with_hmac_key_version(1, key)
136 .expect("version 1 is always valid")
137 }
138
139 pub fn with_public_pem_version(mut self, version: KeyVersion, pem: &str) -> Result<Self> {
141 check_version(version)?;
142 self.rsa.entry(version).or_default().public =
143 Some(RsaPublicKey::from_public_key_pem(pem).map_err(|e| Error::Crypto(e.to_string()))?);
144 Ok(self)
145 }
146
147 pub fn with_private_pem_version(mut self, version: KeyVersion, pem: &str) -> Result<Self> {
150 check_version(version)?;
151 let private =
152 RsaPrivateKey::from_pkcs8_pem(pem).map_err(|e| Error::Crypto(e.to_string()))?;
153 let pair = self.rsa.entry(version).or_default();
154 if pair.public.is_none() {
155 pair.public = Some(private.to_public_key());
156 }
157 pair.private = Some(private);
158 Ok(self)
159 }
160
161 pub fn with_hmac_key_version(
163 mut self,
164 version: KeyVersion,
165 key: impl AsRef<[u8]>,
166 ) -> Result<Self> {
167 check_version(version)?;
168 self.macs.insert(version, key.as_ref().to_vec());
169 Ok(self)
170 }
171
172 pub fn with_generated_rsa(mut self, version: KeyVersion, bits: usize) -> Result<Self> {
175 check_version(version)?;
176 let private = RsaPrivateKey::new(&mut rand::thread_rng(), bits)
177 .map_err(|e| Error::Crypto(e.to_string()))?;
178 self.rsa.insert(
179 version,
180 RsaPair {
181 public: Some(private.to_public_key()),
182 private: Some(private),
183 },
184 );
185 Ok(self)
186 }
187
188 pub fn with_generated_hmac(mut self, version: KeyVersion) -> Result<Self> {
190 check_version(version)?;
191 let mut mac = vec![0u8; 32];
192 rand::thread_rng().fill_bytes(&mut mac);
193 self.macs.insert(version, mac);
194 Ok(self)
195 }
196
197 pub fn generate_ephemeral(bits: usize) -> Result<Self> {
200 Self::new()
201 .with_generated_rsa(1, bits)?
202 .with_generated_hmac(1)
203 }
204
205 pub fn active_hmac_version(&self) -> Option<KeyVersion> {
207 self.macs.keys().next_back().copied()
208 }
209
210 pub fn active_rsa_version(&self) -> Option<KeyVersion> {
212 self.rsa.keys().next_back().copied()
213 }
214
215 pub fn hmac_versions(&self) -> Vec<KeyVersion> {
218 self.macs.keys().rev().copied().collect()
219 }
220
221 pub fn has_rsa_private(&self, version: KeyVersion) -> bool {
224 self.rsa.get(&version).is_some_and(|p| p.private.is_some())
225 }
226
227 fn mac_of(&self, version: KeyVersion) -> Result<&[u8]> {
228 self.macs.get(&version).map(Vec::as_slice).ok_or_else(|| {
229 Error::Crypto(format!("HMAC key version {version} is not in the key ring"))
230 })
231 }
232
233 fn rsa_public_of(&self, version: KeyVersion) -> Result<&RsaPublicKey> {
234 self.rsa
235 .get(&version)
236 .and_then(|p| p.public.as_ref())
237 .ok_or_else(|| {
238 Error::Crypto(format!(
239 "RSA public key version {version} is not in the key ring"
240 ))
241 })
242 }
243
244 fn rsa_private_of(&self, version: KeyVersion) -> Result<&RsaPrivateKey> {
245 self.rsa
246 .get(&version)
247 .and_then(|p| p.private.as_ref())
248 .ok_or_else(|| {
249 Error::Crypto(format!(
250 "RSA private key version {version} is not in the key ring"
251 ))
252 })
253 }
254
255 pub fn hmac_hex(&self, data: &[u8]) -> Result<(KeyVersion, String)> {
259 let version = self.active_hmac_version().ok_or_else(|| {
260 Error::Crypto("HMAC field declared but no HMAC key configured".into())
261 })?;
262 Ok((version, self.hmac_hex_with(version, data)?))
263 }
264
265 pub fn hmac_hex_with(&self, version: KeyVersion, data: &[u8]) -> Result<String> {
268 let mut mac = <HmacSha256 as Mac>::new_from_slice(self.mac_of(version)?)
269 .map_err(|e| Error::Crypto(e.to_string()))?;
270 mac.update(data);
271 Ok(hex(&mac.finalize().into_bytes()))
272 }
273
274 pub fn index_token_digest(
285 &self,
286 field: &str,
287 protection: FieldProtection,
288 token: &str,
289 ) -> Result<(KeyVersion, String)> {
290 match protection {
291 FieldProtection::None | FieldProtection::Sha256 => Ok((
292 KEYLESS,
293 self.index_token_digest_with(KEYLESS, field, protection, token)?,
294 )),
295 FieldProtection::Hmac | FieldProtection::Rsa => {
296 let version = self.active_hmac_version().ok_or_else(|| {
297 Error::Crypto("HMAC field declared but no HMAC key configured".into())
298 })?;
299 Ok((
300 version,
301 self.index_token_digest_with(version, field, protection, token)?,
302 ))
303 }
304 }
305 }
306
307 pub fn index_token_digest_with(
311 &self,
312 version: KeyVersion,
313 field: &str,
314 protection: FieldProtection,
315 token: &str,
316 ) -> Result<String> {
317 let mut data = Vec::with_capacity(4 + field.len() + 1 + token.len());
318 data.extend_from_slice(b"idx:");
319 data.extend_from_slice(field.as_bytes());
320 data.push(0);
321 data.extend_from_slice(token.as_bytes());
322 match protection {
323 FieldProtection::None | FieldProtection::Sha256 => Ok(sha256_hex(&data)),
324 FieldProtection::Hmac | FieldProtection::Rsa => self.hmac_hex_with(version, &data),
325 }
326 }
327
328 pub fn protect(&self, value: &Value, protection: FieldProtection) -> Result<StoredValue> {
332 match protection {
333 FieldProtection::None => Ok(StoredValue::Plain(value.clone())),
334 FieldProtection::Sha256 => {
335 Ok(StoredValue::Sha256(sha256_hex(&value.canonical_bytes())))
336 }
337 FieldProtection::Hmac => {
338 let (key_version, digest) = self.hmac_hex(&value.canonical_bytes())?;
339 Ok(StoredValue::Hmac {
340 key_version,
341 digest,
342 })
343 }
344 FieldProtection::Rsa => {
345 let key_version = self.active_rsa_version().ok_or_else(|| {
346 Error::Crypto("RSA field declared but no public key configured".into())
347 })?;
348 let key = self.rsa_public_of(key_version)?;
349 let mut dek = [0u8; 32];
353 rand::thread_rng().fill_bytes(&mut dek);
354 let mut nonce = [0u8; 12];
355 rand::thread_rng().fill_bytes(&mut nonce);
356 let cipher = Aes256Gcm::new_from_slice(&dek).expect("32-byte key");
357 let ciphertext = cipher
358 .encrypt(
359 Nonce::from_slice(&nonce),
360 value.canonical_bytes().as_slice(),
361 )
362 .map_err(|e| Error::Crypto(e.to_string()))?;
363 let wrapped_key = key
364 .encrypt(&mut rand::thread_rng(), Oaep::new::<Sha256>(), &dek)
365 .map_err(|e| Error::Crypto(e.to_string()))?;
366 Ok(StoredValue::Rsa {
367 key_version,
368 wrapped_key: b64::encode(&wrapped_key),
369 nonce: b64::encode(&nonce),
370 ciphertext: b64::encode(&ciphertext),
371 })
372 }
373 }
374 }
375
376 pub fn can_sign(&self) -> bool {
380 self.active_rsa_version()
381 .is_some_and(|v| self.has_rsa_private(v))
382 }
383
384 pub fn sign(&self, data: &[u8]) -> Result<(KeyVersion, Vec<u8>)> {
390 let version = self
391 .active_rsa_version()
392 .ok_or_else(|| Error::Crypto("no private key configured for signing".into()))?;
393 let key = self.rsa_private_of(version)?;
394 let digest = Sha256::digest(data);
395 let sig = key
396 .sign(Pkcs1v15Sign::new::<Sha256>(), &digest)
397 .map_err(|e| Error::Crypto(e.to_string()))?;
398 Ok((version, sig))
399 }
400
401 pub fn verify_signature(
405 &self,
406 key_version: KeyVersion,
407 data: &[u8],
408 signature: &[u8],
409 ) -> Result<()> {
410 let key = self.rsa_public_of(key_version)?;
411 let digest = Sha256::digest(data);
412 key.verify(Pkcs1v15Sign::new::<Sha256>(), &digest, signature)
413 .map_err(|e| Error::Crypto(e.to_string()))
414 }
415
416 pub fn decrypt(&self, stored: &StoredValue) -> Result<Vec<u8>> {
419 let StoredValue::Rsa {
420 key_version,
421 wrapped_key,
422 nonce,
423 ciphertext,
424 } = stored
425 else {
426 return Err(Error::Crypto("value is not RSA-encrypted".into()));
427 };
428 let key = self.rsa_private_of(*key_version)?;
429 let bad_b64 = || Error::Crypto("invalid base64".into());
430 let dek = key
431 .decrypt(
432 Oaep::new::<Sha256>(),
433 &b64::decode(wrapped_key).ok_or_else(bad_b64)?,
434 )
435 .map_err(|e| Error::Crypto(e.to_string()))?;
436 let cipher = Aes256Gcm::new_from_slice(&dek).map_err(|e| Error::Crypto(e.to_string()))?;
437 let nonce = b64::decode(nonce).ok_or_else(bad_b64)?;
438 if nonce.len() != 12 {
439 return Err(Error::Crypto("invalid nonce length".into()));
440 }
441 cipher
442 .decrypt(
443 Nonce::from_slice(&nonce),
444 b64::decode(ciphertext).ok_or_else(bad_b64)?.as_slice(),
445 )
446 .map_err(|e| Error::Crypto(e.to_string()))
447 }
448
449 pub fn rewrap(&self, stored: &StoredValue) -> Result<StoredValue> {
455 let StoredValue::Rsa {
456 key_version,
457 wrapped_key,
458 nonce,
459 ciphertext,
460 } = stored
461 else {
462 return Err(Error::Crypto("value is not RSA-encrypted".into()));
463 };
464 let active = self
465 .active_rsa_version()
466 .ok_or_else(|| Error::Crypto("no RSA key configured to re-wrap to".into()))?;
467 if *key_version == active {
468 return Ok(stored.clone());
469 }
470 let old = self.rsa_private_of(*key_version)?;
471 let bad_b64 = || Error::Crypto("invalid base64".into());
472 let dek = old
473 .decrypt(
474 Oaep::new::<Sha256>(),
475 &b64::decode(wrapped_key).ok_or_else(bad_b64)?,
476 )
477 .map_err(|e| Error::Crypto(e.to_string()))?;
478 let wrapped = self
479 .rsa_public_of(active)?
480 .encrypt(&mut rand::thread_rng(), Oaep::new::<Sha256>(), &dek)
481 .map_err(|e| Error::Crypto(e.to_string()))?;
482 Ok(StoredValue::Rsa {
483 key_version: active,
484 wrapped_key: b64::encode(&wrapped),
485 nonce: nonce.clone(),
486 ciphertext: ciphertext.clone(),
487 })
488 }
489}
490
491pub fn sha256_hex(data: &[u8]) -> String {
492 hex(&Sha256::digest(data))
493}
494
/// Encodes `digest` as lowercase hexadecimal, two characters per byte.
///
/// Uses a nibble lookup table rather than `format!` per byte, which avoided
/// allocating a temporary `String` for every input byte.
pub(crate) fn hex(digest: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(digest.len() * 2);
    for &b in digest {
        s.push(DIGITS[usize::from(b >> 4)] as char);
        s.push(DIGITS[usize::from(b & 0x0f)] as char);
    }
    s
}
502
/// Decodes a hexadecimal string (either case) into bytes.
///
/// Returns `None` for an odd number of bytes or any character that is not a
/// hex digit. An explicit digit check is used because `u8::from_str_radix`
/// also accepts a leading `+`, which let inputs like `"+f"` decode.
pub(crate) fn hex_decode(s: &str) -> Option<Vec<u8>> {
    /// Value of one ASCII hex digit.
    fn nibble(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c - b'0'),
            b'a'..=b'f' => Some(c - b'a' + 10),
            b'A'..=b'F' => Some(c - b'A' + 10),
            _ => None,
        }
    }

    let pairs = s.as_bytes().chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect()
}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn b64_roundtrip() {
        for data in [&b""[..], b"a", b"ab", b"abc", b"hello world!"] {
            assert_eq!(b64::decode(&b64::encode(data)).unwrap(), data);
        }
    }

    #[test]
    fn b64_matches_rfc4648_vectors() {
        assert_eq!(b64::encode(b"f"), "Zg==");
        assert_eq!(b64::encode(b"fo"), "Zm8=");
        assert_eq!(b64::encode(b"foo"), "Zm9v");
        assert_eq!(b64::encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn hex_roundtrip() {
        let data: [u8; 4] = [0x00, 0x0f, 0xa5, 0xff];
        assert_eq!(hex(&data), "000fa5ff");
        assert_eq!(hex_decode("000fa5ff").unwrap(), data);
        assert_eq!(hex_decode("000FA5FF").unwrap(), data);
        assert!(hex_decode("abc").is_none());
    }

    #[test]
    fn sha256_hex_matches_known_vector() {
        assert_eq!(
            sha256_hex(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn hmac_matches_rfc4231_case_2() {
        let ring = KeyRing::new().with_hmac_key(b"Jefe");
        assert_eq!(
            ring.hmac_hex_with(1, b"what do ya want for nothing?")
                .unwrap(),
            "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
        );
    }

    #[test]
    fn rsa_roundtrip() {
        let ring = KeyRing::generate_ephemeral(2048).unwrap();
        let v = Value::Text("secret-name".repeat(40));
        let stored = ring.protect(&v, FieldProtection::Rsa).unwrap();
        assert_eq!(ring.decrypt(&stored).unwrap(), v.canonical_bytes());
    }

    #[test]
    fn rsa_ciphertext_tampering_is_rejected() {
        let ring = KeyRing::generate_ephemeral(2048).unwrap();
        let stored = ring
            .protect(&Value::Text("secret".into()), FieldProtection::Rsa)
            .unwrap();
        let StoredValue::Rsa {
            key_version,
            wrapped_key,
            nonce,
            ciphertext,
        } = stored
        else {
            unreachable!()
        };
        let mut ct = b64::decode(&ciphertext).unwrap();
        ct[0] ^= 1;
        let tampered = StoredValue::Rsa {
            key_version,
            wrapped_key,
            nonce,
            ciphertext: b64::encode(&ct),
        };
        assert!(
            ring.decrypt(&tampered).is_err(),
            "GCM must reject a flipped bit"
        );
    }

    #[test]
    fn rewrap_moves_to_active_key_and_still_decrypts() {
        let old = KeyRing::new().with_generated_rsa(1, 1024).unwrap();
        let value = Value::Text("rotate-me".into());
        let stored = old.protect(&value, FieldProtection::Rsa).unwrap();

        let rotated = old.with_generated_rsa(2, 1024).unwrap();
        let rewrapped = rotated.rewrap(&stored).unwrap();
        let StoredValue::Rsa { key_version, .. } = &rewrapped else {
            unreachable!()
        };
        assert_eq!(*key_version, 2);
        assert_eq!(rotated.decrypt(&rewrapped).unwrap(), value.canonical_bytes());
        // Already on the active version: re-wrapping is a no-op.
        assert_eq!(rotated.rewrap(&rewrapped).unwrap(), rewrapped);
    }

    #[test]
    fn signatures_verify_and_detect_tampering() {
        let ring = KeyRing::new().with_generated_rsa(1, 1024).unwrap();
        assert!(ring.can_sign());
        let (version, sig) = ring.sign(b"payload").unwrap();
        assert_eq!(version, 1);
        assert!(ring.verify_signature(version, b"payload", &sig).is_ok());
        assert!(ring.verify_signature(version, b"payloaD", &sig).is_err());
        assert!(!KeyRing::new().can_sign());
    }

    #[test]
    fn version_zero_is_rejected() {
        assert!(KeyRing::new()
            .with_hmac_key_version(KEYLESS, b"k")
            .is_err());
        assert!(KeyRing::new().with_generated_hmac(KEYLESS).is_err());
    }

    #[test]
    fn highest_hmac_version_is_active() {
        let ring = KeyRing::new()
            .with_hmac_key_version(1, b"old")
            .unwrap()
            .with_hmac_key_version(3, b"new")
            .unwrap();
        assert_eq!(ring.active_hmac_version(), Some(3));
        assert_eq!(ring.hmac_versions(), vec![3, 1]);
        let (version, digest) = ring.hmac_hex(b"data").unwrap();
        assert_eq!(version, 3);
        assert_eq!(digest, ring.hmac_hex_with(3, b"data").unwrap());
        assert_ne!(digest, ring.hmac_hex_with(1, b"data").unwrap());
    }

    #[test]
    fn index_tokens_depend_on_field_and_protection() {
        let ring = KeyRing::new().with_hmac_key(b"k");
        let (version, by_name) = ring
            .index_token_digest("name", FieldProtection::Hmac, "bob")
            .unwrap();
        assert_eq!(version, 1);
        let (_, by_email) = ring
            .index_token_digest("email", FieldProtection::Hmac, "bob")
            .unwrap();
        assert_ne!(by_name, by_email);
        let (version, _) = ring
            .index_token_digest("name", FieldProtection::Sha256, "bob")
            .unwrap();
        assert_eq!(version, KEYLESS);
    }

    #[test]
    fn sha256_is_deterministic_and_keyless() {
        let ring = KeyRing::new();
        let a = ring
            .protect(&Value::Text("x".into()), FieldProtection::Sha256)
            .unwrap();
        let b = ring
            .protect(&Value::Text("x".into()), FieldProtection::Sha256)
            .unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn hmac_is_deterministic_per_key_only() {
        let ring = KeyRing::new().with_hmac_key(b"key-1");
        let a = ring
            .protect(&Value::Text("x".into()), FieldProtection::Hmac)
            .unwrap();
        let b = ring
            .protect(&Value::Text("x".into()), FieldProtection::Hmac)
            .unwrap();
        assert_eq!(a, b);

        let other = KeyRing::new().with_hmac_key(b"key-2");
        let c = other
            .protect(&Value::Text("x".into()), FieldProtection::Hmac)
            .unwrap();
        assert_ne!(a, c, "different keys must produce different digests");

        assert!(KeyRing::new()
            .protect(&Value::Text("x".into()), FieldProtection::Hmac)
            .is_err());
    }
}