1use std::collections::HashMap;
57use std::fmt;
58use std::sync::Arc;
59
60use aes_gcm::aead::{Aead, KeyInit, OsRng};
61use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
62use arc_swap::ArcSwap;
63use base64::Engine;
64use futures::future::BoxFuture;
65use secrecy::{ExposeSecret, Secret};
66use serde_json::Value;
67use smol_str::SmolStr;
68
/// Prefix that marks a JSON string as an encrypted field (wire format v1).
/// Full form: `enc:v1:<key_id>:<key_version>:<base64 blob>`.
const WIRE_PREFIX: &str = "enc:v1:";
/// Base64 engine for the ciphertext blob: standard alphabet, no `=` padding.
/// The standard alphabet contains no `:`, so the encoded blob never collides
/// with the `:` separators of the wire format.
const B64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD_NO_PAD;
71
72#[derive(Debug)]
75pub enum CryptoError {
76 UnknownKey(KeyId),
78 Shredded(KeyId),
81 Aead,
83 WireFormat,
85 Kek(Box<dyn std::error::Error + Send + Sync>),
87 Codec(serde_json::Error),
89}
90
91impl fmt::Display for CryptoError {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match self {
94 Self::UnknownKey(k) => write!(f, "no data key provisioned for `{k}`"),
95 Self::Shredded(k) => write!(f, "data key `{k}` has been shredded — data is erased"),
96 Self::Aead => write!(f, "AEAD failure: wrong key or tampered ciphertext"),
97 Self::WireFormat => write!(f, "malformed encrypted-field wire format"),
98 Self::Kek(e) => write!(f, "KEK source error: {e}"),
99 Self::Codec(e) => write!(f, "field codec error: {e}"),
100 }
101 }
102}
103impl std::error::Error for CryptoError {}
104
105#[derive(Clone, Debug, PartialEq, Eq, Hash)]
110pub struct KeyId(pub SmolStr);
111
112impl KeyId {
113 pub fn new(id: impl AsRef<str>) -> Self {
114 Self(SmolStr::new(id.as_ref()))
115 }
116 pub fn tenant(id: impl AsRef<str>) -> Self {
117 Self(SmolStr::new(format!("tenant:{}", id.as_ref())))
118 }
119 pub fn subject(id: impl AsRef<str>) -> Self {
120 Self(SmolStr::new(format!("subject:{}", id.as_ref())))
121 }
122 pub fn as_str(&self) -> &str {
123 &self.0
124 }
125}
126
127impl fmt::Display for KeyId {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 f.write_str(&self.0)
130 }
131}
132
133pub struct DataKey {
136 version: u32,
137 material: Secret<[u8; 32]>,
138}
139
140impl DataKey {
141 pub fn new(version: u32, material: [u8; 32]) -> Self {
142 Self {
143 version,
144 material: Secret::new(material),
145 }
146 }
147 pub fn version(&self) -> u32 {
148 self.version
149 }
150}
151
/// Key ring as returned by [`KekSource::load_keyring`]: each known key id
/// paired with the data-key versions available for it. The versions need not
/// be ordered; `CryptoVault::bootstrap` sorts them by version.
pub type LoadedKeyring = Vec<(KeyId, Vec<DataKey>)>;

/// Backend that holds the key-encryption key (KEK) and manages the lifecycle
/// of data keys on behalf of `CryptoVault`.
///
/// All returned futures borrow only `&self` (the elided lifetime binds to
/// `self`), so implementations that need `id` inside the future must take an
/// owned copy of it first.
pub trait KekSource: Send + Sync + 'static {
    /// Loads every existing data key. Called once by `CryptoVault::bootstrap`.
    fn load_keyring(&self) -> BoxFuture<'_, Result<LoadedKeyring, CryptoError>>;

    /// Creates a new data-key version for `id`.
    ///
    /// Used both for first provisioning (`ensure_key`) and for rotation. The
    /// vault treats the highest version per id as the active encrypting key,
    /// so returned versions should increase for each id.
    fn provision(&self, id: &KeyId) -> BoxFuture<'_, Result<DataKey, CryptoError>>;

    /// Destroys the key material for `id`.
    ///
    /// `CryptoVault::shred` calls this before dropping the id from its ring and
    /// treats success as "subject data erased". Presumably this must be
    /// irreversible at the backend — NOTE(review): confirm per implementation.
    fn destroy(&self, id: &KeyId) -> BoxFuture<'_, Result<(), CryptoError>>;
}
173
174#[derive(Clone, Debug, PartialEq, Eq)]
180pub struct EncryptedField {
181 pub key_id: KeyId,
182 pub key_version: u32,
183 pub blob: Vec<u8>,
185}
186
187impl EncryptedField {
188 pub fn to_wire(&self) -> String {
189 format!(
190 "{WIRE_PREFIX}{}:{}:{}",
191 self.key_id,
192 self.key_version,
193 B64.encode(&self.blob)
194 )
195 }
196
197 pub fn from_wire(s: &str) -> Result<Self, CryptoError> {
200 let rest = s.strip_prefix(WIRE_PREFIX).ok_or(CryptoError::WireFormat)?;
201 let mut it = rest.rsplitn(3, ':');
202 let blob_b64 = it.next().ok_or(CryptoError::WireFormat)?;
203 let version = it
204 .next()
205 .and_then(|v| v.parse::<u32>().ok())
206 .ok_or(CryptoError::WireFormat)?;
207 let key_id = it
208 .next()
209 .filter(|k| !k.is_empty())
210 .ok_or(CryptoError::WireFormat)?;
211 let blob = B64.decode(blob_b64).map_err(|_| CryptoError::WireFormat)?;
212 if blob.len() < 12 + 16 {
213 return Err(CryptoError::WireFormat);
214 }
215 Ok(Self {
216 key_id: KeyId::new(key_id),
217 key_version: version,
218 blob,
219 })
220 }
221
222 pub fn is_wire(s: &str) -> bool {
224 s.starts_with(WIRE_PREFIX)
225 }
226}
227
228struct KeyRingSnapshot {
233 keys: HashMap<KeyId, Vec<DataKey>>,
235 epoch: u64,
236}
237
238impl KeyRingSnapshot {
239 fn active_key(&self, id: &KeyId) -> Option<&DataKey> {
240 self.keys.get(id).and_then(|v| v.last())
241 }
242 fn key_version(&self, id: &KeyId, version: u32) -> Option<&DataKey> {
243 self.keys
244 .get(id)
245 .and_then(|v| v.iter().find(|k| k.version == version))
246 }
247}
248
249pub struct CryptoVault {
255 ring: ArcSwap<KeyRingSnapshot>,
256 source: Arc<dyn KekSource>,
257 rebuild: tokio::sync::Mutex<()>,
260}
261
262impl CryptoVault {
263 pub async fn bootstrap(source: Arc<dyn KekSource>) -> Result<Self, CryptoError> {
265 let mut keys: HashMap<KeyId, Vec<DataKey>> = HashMap::new();
266 for (id, mut versions) in source.load_keyring().await? {
267 versions.sort_by_key(|k| k.version);
268 keys.insert(id, versions);
269 }
270 Ok(Self {
271 ring: ArcSwap::from_pointee(KeyRingSnapshot { keys, epoch: 0 }),
272 source,
273 rebuild: tokio::sync::Mutex::new(()),
274 })
275 }
276
277 pub fn encrypt(&self, key: &KeyId, plaintext: &[u8]) -> Result<EncryptedField, CryptoError> {
280 let ring = self.ring.load();
281 let dk = ring
282 .active_key(key)
283 .ok_or_else(|| CryptoError::UnknownKey(key.clone()))?;
284 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dk.material.expose_secret()));
285 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
286 let ct = cipher
287 .encrypt(&nonce, plaintext)
288 .map_err(|_| CryptoError::Aead)?;
289 let mut blob = Vec::with_capacity(12 + ct.len());
290 blob.extend_from_slice(&nonce);
291 blob.extend_from_slice(&ct);
292 Ok(EncryptedField {
293 key_id: key.clone(),
294 key_version: dk.version,
295 blob,
296 })
297 }
298
299 pub fn decrypt(&self, field: &EncryptedField) -> Result<Secret<Vec<u8>>, CryptoError> {
302 let ring = self.ring.load();
303 let dk = ring
304 .key_version(&field.key_id, field.key_version)
305 .ok_or_else(|| CryptoError::Shredded(field.key_id.clone()))?;
306 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dk.material.expose_secret()));
307 let (nonce, ct) = field.blob.split_at(12);
308 let pt = cipher
309 .decrypt(Nonce::from_slice(nonce), ct)
310 .map_err(|_| CryptoError::Aead)?;
311 Ok(Secret::new(pt))
312 }
313
314 pub fn has_key(&self, key: &KeyId) -> bool {
316 self.ring.load().active_key(key).is_some()
317 }
318
319 pub async fn ensure_key(&self, key: &KeyId) -> Result<(), CryptoError> {
322 if self.has_key(key) {
323 return Ok(());
324 }
325 let _g = self.rebuild.lock().await;
326 if self.has_key(key) {
327 return Ok(()); }
329 let dk = self.source.provision(key).await?;
330 self.swap_ring(|keys| {
331 keys.entry(key.clone()).or_default().push(dk);
332 });
333 Ok(())
334 }
335
336 pub async fn rotate(&self, key: &KeyId) -> Result<u32, CryptoError> {
340 let _g = self.rebuild.lock().await;
341 let dk = self.source.provision(key).await?;
342 let v = dk.version;
343 self.swap_ring(|keys| {
344 let versions = keys.entry(key.clone()).or_default();
345 versions.push(dk);
346 versions.sort_by_key(|k| k.version);
347 });
348 Ok(v)
349 }
350
351 pub async fn shred(&self, key: &KeyId) -> Result<(), CryptoError> {
356 let _g = self.rebuild.lock().await;
357 self.source.destroy(key).await?;
358 self.swap_ring(|keys| {
359 keys.remove(key);
360 });
361 tracing::info!(key = %key, "data key shredded — subject data erased");
362 Ok(())
363 }
364
365 fn swap_ring(&self, mutate: impl FnOnce(&mut HashMap<KeyId, Vec<DataKey>>)) {
367 let cur = self.ring.load();
368 let mut keys: HashMap<KeyId, Vec<DataKey>> = cur
371 .keys
372 .iter()
373 .map(|(k, vs)| {
374 (
375 k.clone(),
376 vs.iter()
377 .map(|d| DataKey::new(d.version, *d.material.expose_secret()))
378 .collect(),
379 )
380 })
381 .collect();
382 mutate(&mut keys);
383 self.ring.store(Arc::new(KeyRingSnapshot {
384 keys,
385 epoch: cur.epoch + 1,
386 }));
387 }
388}
389
/// One step of a compiled field path.
///
/// Either a literal object key, or `*`, which matches every array element or
/// every object value.
#[derive(Clone, Debug)]
enum Seg {
    Key(&'static str),
    Any,
}

/// Turns a dot-separated spec such as `"visits.*.diagnosis"` into path steps.
///
/// There is no escaping, so object keys that contain `.` cannot be addressed.
fn compile(spec: &'static str) -> Vec<Seg> {
    let mut steps = Vec::new();
    for part in spec.split('.') {
        let step = match part {
            "*" => Seg::Any,
            name => Seg::Key(name),
        };
        steps.push(step);
    }
    steps
}
405
406fn seal_at(
410 vault: &CryptoVault,
411 key: &KeyId,
412 v: &mut Value,
413 path: &[Seg],
414) -> Result<(), CryptoError> {
415 match path.split_first() {
416 None => {
417 let plain = serde_json::to_vec(v).map_err(CryptoError::Codec)?;
418 *v = Value::String(vault.encrypt(key, &plain)?.to_wire());
419 Ok(())
420 }
421 Some((Seg::Key(k), rest)) => match v.get_mut(*k) {
422 Some(child) => seal_at(vault, key, child, rest),
423 None => Ok(()), },
425 Some((Seg::Any, rest)) => {
426 match v {
427 Value::Array(items) => {
428 for item in items {
429 seal_at(vault, key, item, rest)?;
430 }
431 }
432 Value::Object(map) => {
433 for child in map.values_mut() {
434 seal_at(vault, key, child, rest)?;
435 }
436 }
437 _ => {}
438 }
439 Ok(())
440 }
441 }
442}
443
444fn unseal_at(vault: &CryptoVault, v: &mut Value, path: &[Seg]) -> Result<(), CryptoError> {
446 match path.split_first() {
447 None => {
448 let Value::String(s) = &*v else { return Ok(()) };
449 if !EncryptedField::is_wire(s) {
450 return Ok(()); }
452 let field = EncryptedField::from_wire(s)?;
453 let plain = vault.decrypt(&field)?;
454 *v = serde_json::from_slice(plain.expose_secret()).map_err(CryptoError::Codec)?;
455 Ok(())
456 }
457 Some((Seg::Key(k), rest)) => match v.get_mut(*k) {
458 Some(child) => unseal_at(vault, child, rest),
459 None => Ok(()),
460 },
461 Some((Seg::Any, rest)) => {
462 match v {
463 Value::Array(items) => {
464 for item in items {
465 unseal_at(vault, item, rest)?;
466 }
467 }
468 Value::Object(map) => {
469 for child in map.values_mut() {
470 unseal_at(vault, child, rest)?;
471 }
472 }
473 _ => {}
474 }
475 Ok(())
476 }
477 }
478}
479
480pub trait EncryptRecord: serde::Serialize + serde::de::DeserializeOwned {
492 const ENCRYPT_FIELDS: &'static [&'static str];
494 const KEY_ID: &'static str;
496
497 fn seal(&self, vault: &CryptoVault) -> Result<Value, CryptoError> {
498 self.seal_with_key(vault, &KeyId::new(Self::KEY_ID))
499 }
500
501 fn seal_with_key(&self, vault: &CryptoVault, key: &KeyId) -> Result<Value, CryptoError> {
502 let mut v = serde_json::to_value(self).map_err(CryptoError::Codec)?;
503 for spec in Self::ENCRYPT_FIELDS {
504 seal_at(vault, key, &mut v, &compile(spec))?;
505 }
506 Ok(v)
507 }
508
509 fn unseal(mut sealed: Value, vault: &CryptoVault) -> Result<Self, CryptoError> {
510 for spec in Self::ENCRYPT_FIELDS {
511 unseal_at(vault, &mut sealed, &compile(spec))?;
512 }
513 serde_json::from_value(sealed).map_err(CryptoError::Codec)
514 }
515}
516
// Unit tests for the vault, the wire format and record sealing. They use an
// in-process KEK stub, so no external key service is needed.
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};

    // Test-only `KekSource`. Key material is derived deterministically from
    // (id, version), and version counters are kept in memory.
    struct TestKek {
        // Ids passed to `destroy`; recorded but not asserted on by these tests.
        shredded: std::sync::Mutex<std::collections::HashSet<KeyId>>,
        // Highest version handed out so far per id. `provision` increments it,
        // so the first version of an id is 1.
        versions: std::sync::Mutex<HashMap<KeyId, u32>>,
    }

    impl TestKek {
        fn new() -> Self {
            Self {
                shredded: Default::default(),
                versions: Default::default(),
            }
        }
        // Key material = SHA-256("test-master" || id bytes || big-endian version).
        fn derive(id: &KeyId, version: u32) -> [u8; 32] {
            let mut h = Sha256::new();
            h.update(b"test-master");
            h.update(id.as_str().as_bytes());
            h.update(version.to_be_bytes());
            h.finalize().into()
        }
    }

    impl KekSource for TestKek {
        // Every test vault starts with an empty key ring.
        fn load_keyring(&self) -> BoxFuture<'_, Result<Vec<(KeyId, Vec<DataKey>)>, CryptoError>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        fn provision(&self, id: &KeyId) -> BoxFuture<'_, Result<DataKey, CryptoError>> {
            // The returned future may only borrow `self`, so it gets its own `id`.
            let id = id.clone();
            Box::pin(async move {
                // A std mutex is fine here: the guard is never held across an `.await`.
                let mut versions = self.versions.lock().unwrap();
                let v = versions.entry(id.clone()).or_insert(0);
                *v += 1;
                Ok(DataKey::new(*v, Self::derive(&id, *v)))
            })
        }
        fn destroy(&self, id: &KeyId) -> BoxFuture<'_, Result<(), CryptoError>> {
            let id = id.clone();
            Box::pin(async move {
                self.shredded.lock().unwrap().insert(id);
                Ok(())
            })
        }
    }

    // Fresh vault over an empty `TestKek`.
    async fn vault() -> CryptoVault {
        CryptoVault::bootstrap(Arc::new(TestKek::new()))
            .await
            .unwrap()
    }

    // encrypt -> to_wire -> from_wire -> decrypt restores the plaintext. The
    // parsed field equals the original, including a key id containing ':'.
    #[tokio::test]
    async fn roundtrip_and_wire_format() {
        let v = vault().await;
        let key = KeyId::tenant("acme");
        v.ensure_key(&key).await.unwrap();

        let sealed = v.encrypt(&key, b"4242-4242").unwrap();
        let wire = sealed.to_wire();
        assert!(EncryptedField::is_wire(&wire));

        let parsed = EncryptedField::from_wire(&wire).unwrap();
        assert_eq!(parsed, sealed);
        assert_eq!(parsed.key_id, key);
        let plain = v.decrypt(&parsed).unwrap();
        assert_eq!(plain.expose_secret().as_slice(), b"4242-4242");
    }

    // After rotation, version-1 ciphertext still decrypts and new encryptions
    // use version 2.
    #[tokio::test]
    async fn rotation_keeps_old_ciphertext_readable() {
        let v = vault().await;
        let key = KeyId::tenant("acme");
        v.ensure_key(&key).await.unwrap();

        let old = v.encrypt(&key, b"before-rotation").unwrap();
        let new_version = v.rotate(&key).await.unwrap();
        assert_eq!(new_version, 2);

        assert_eq!(
            // `old` was sealed under version 1, which must still be in the ring.
            v.decrypt(&old).unwrap().expose_secret().as_slice(),
            b"before-rotation"
        );
        assert_eq!(v.encrypt(&key, b"x").unwrap().key_version, 2);
    }

    // After shredding, existing ciphertext reports `Shredded` and new
    // encryptions under the same id report `UnknownKey`.
    #[tokio::test]
    async fn shred_makes_data_unrecoverable() {
        let v = vault().await;
        let key = KeyId::subject("user-42");
        v.ensure_key(&key).await.unwrap();
        let sealed = v.encrypt(&key, b"phi").unwrap();

        v.shred(&key).await.unwrap();
        assert!(matches!(v.decrypt(&sealed), Err(CryptoError::Shredded(_))));
        assert!(matches!(
            v.encrypt(&key, b"more"),
            Err(CryptoError::UnknownKey(_))
        ));
    }

    // Flipping the last byte of the blob must make AES-GCM authentication fail.
    #[tokio::test]
    async fn tampered_ciphertext_fails_aead() {
        let v = vault().await;
        let key = KeyId::tenant("acme");
        v.ensure_key(&key).await.unwrap();
        let mut sealed = v.encrypt(&key, b"secret").unwrap();
        *sealed.blob.last_mut().unwrap() ^= 0xFF;
        assert!(matches!(v.decrypt(&sealed), Err(CryptoError::Aead)));
    }

    // Record fixture: `ssn` and every visit's `diagnosis` are encrypted;
    // `name` and `year` stay in plaintext.
    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct Patient {
        name: String,
        ssn: String,
        visits: Vec<Visit>,
    }
    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct Visit {
        diagnosis: String,
        year: u32,
    }

    impl EncryptRecord for Patient {
        const ENCRYPT_FIELDS: &'static [&'static str] = &["ssn", "visits.*.diagnosis"];
        const KEY_ID: &'static str = "tenant:clinic";
    }

    // seal encrypts the listed paths (including the wildcard path), leaves
    // other fields readable, and unseal restores the original record.
    #[tokio::test]
    async fn record_seal_unseal_with_wildcards() {
        let v = vault().await;
        v.ensure_key(&KeyId::new("tenant:clinic")).await.unwrap();

        let p = Patient {
            name: "Jane".into(),
            ssn: "123-45-6789".into(),
            visits: vec![
                Visit {
                    diagnosis: "A".into(),
                    year: 2024,
                },
                Visit {
                    diagnosis: "B".into(),
                    year: 2025,
                },
            ],
        };

        let sealed = p.seal(&v).unwrap();
        assert!(EncryptedField::is_wire(sealed["ssn"].as_str().unwrap()));
        // The wildcard path reached the array elements.
        assert!(EncryptedField::is_wire(
            sealed["visits"][0]["diagnosis"].as_str().unwrap()
        ));
        assert_eq!(sealed["name"], "Jane");
        assert_eq!(sealed["visits"][1]["year"], 2025);

        let back = Patient::unseal(sealed, &v).unwrap();
        assert_eq!(back, p);
    }

    // Values stored before encryption was rolled out (no wire prefix) pass
    // through unseal unchanged.
    #[tokio::test]
    async fn unseal_tolerates_pre_rollout_plaintext() {
        let v = vault().await;
        v.ensure_key(&KeyId::new("tenant:clinic")).await.unwrap();
        let legacy = serde_json::json!({
            "name": "Old", "ssn": "raw-ssn", "visits": []
        });
        let p = Patient::unseal(legacy, &v).unwrap();
        assert_eq!(p.ssn, "raw-ssn");
    }
}