use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use aes_gcm::aead::{Aead, KeyInit, OsRng};
use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
use arc_swap::ArcSwap;
use base64::Engine;
use futures::future::BoxFuture;
use secrecy::{ExposeSecret, Secret};
use serde_json::Value;
use smol_str::SmolStr;
const WIRE_PREFIX: &str = "enc:v1:";
const B64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD_NO_PAD;
/// Errors produced by [`CryptoVault`], wire-format parsing and record sealing.
#[derive(Debug)]
pub enum CryptoError {
    /// No usable data key is loaded for this id.
    UnknownKey(KeyId),
    /// Decryption found no key for this id — the state `CryptoVault::shred` leaves behind.
    Shredded(KeyId),
    /// AES-GCM failed: wrong key material or a modified ciphertext/tag.
    Aead,
    /// The input is not a well-formed `enc:v1:` encrypted-field value.
    WireFormat,
    /// Failure originating from the key-encryption-key source.
    Kek(Box<dyn std::error::Error + Send + Sync>),
    /// JSON (de)serialization of a field value failed.
    Codec(serde_json::Error),
}

impl fmt::Display for CryptoError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnknownKey(k) => write!(f, "no data key provisioned for `{k}`"),
            Self::Shredded(k) => write!(f, "data key `{k}` has been shredded — data is erased"),
            Self::Aead => write!(f, "AEAD failure: wrong key or tampered ciphertext"),
            Self::WireFormat => write!(f, "malformed encrypted-field wire format"),
            Self::Kek(e) => write!(f, "KEK source error: {e}"),
            Self::Codec(e) => write!(f, "field codec error: {e}"),
        }
    }
}

impl std::error::Error for CryptoError {
    /// Exposes the wrapped KEK / codec error so error-reporting chains
    /// (e.g. `anyhow`'s `{:#}` or a `source()` walk) can reach the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Kek(e) => Some(&**e),
            Self::Codec(e) => Some(e),
            Self::UnknownKey(_) | Self::Shredded(_) | Self::Aead | Self::WireFormat => None,
        }
    }
}
/// Identifier of a data key, e.g. `tenant:<id>` or `subject:<id>`.
///
/// The id is written verbatim into the encrypted-field wire format; it may itself
/// contain `:` because the wire parser splits from the right.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct KeyId(pub SmolStr);

impl KeyId {
    /// Wraps `id` as-is, without adding any namespace prefix.
    pub fn new(id: impl AsRef<str>) -> Self {
        let raw: &str = id.as_ref();
        KeyId(SmolStr::new(raw))
    }

    /// Builds the tenant-scoped id `tenant:<id>`.
    pub fn tenant(id: impl AsRef<str>) -> Self {
        let text = format!("tenant:{}", id.as_ref());
        KeyId(SmolStr::new(text))
    }

    /// Builds the subject-scoped id `subject:<id>`.
    pub fn subject(id: impl AsRef<str>) -> Self {
        let text = format!("subject:{}", id.as_ref());
        KeyId(SmolStr::new(text))
    }

    /// Borrows the id as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for KeyId {
    /// Renders the raw id with no quoting or decoration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// One version of a data key: a version number plus 32 bytes (256 bits) of
/// AES-256 key material.
///
/// The material lives in a `secrecy::Secret`, so reading it requires an explicit
/// `expose_secret()` call.
pub struct DataKey {
    version: u32,
    material: Secret<[u8; 32]>,
}

impl DataKey {
    /// Wraps raw key `material` as data-key version `version`.
    pub fn new(version: u32, material: [u8; 32]) -> Self {
        let material = Secret::new(material);
        DataKey { version, material }
    }

    /// The version number this key was provisioned under.
    pub fn version(&self) -> u32 {
        let DataKey { version, .. } = self;
        *version
    }
}
/// Initial keyring reported by a [`KekSource`]: each key id with its data-key versions.
pub type LoadedKeyring = Vec<(KeyId, Vec<DataKey>)>;

/// Backend owning the key-encryption keys; it loads, creates and destroys data keys.
///
/// In this module, `CryptoVault::bootstrap` calls `load_keyring`,
/// `ensure_key`/`rotate` call `provision`, and `shred` calls `destroy`.
pub trait KekSource: Send + Sync + 'static {
    /// Returns every known key id with its versions. Order within each list does not
    /// matter; the vault sorts them by version.
    fn load_keyring(&self) -> BoxFuture<'_, Result<LoadedKeyring, CryptoError>>;

    /// Creates a new data-key version for `id`.
    ///
    /// NOTE(review): the vault treats the highest version as active, so a fresh
    /// version is presumably expected to exceed all existing ones — confirm with
    /// implementations.
    fn provision(&self, id: &KeyId) -> BoxFuture<'_, Result<DataKey, CryptoError>>;

    /// Destroys the key material for `id` at the source. Whether this is
    /// irreversible depends on the implementation.
    fn destroy(&self, id: &KeyId) -> BoxFuture<'_, Result<(), CryptoError>>;
}
/// A sealed value plus the coordinates of the key that sealed it.
///
/// Wire form: `enc:v1:<key_id>:<key_version>:<blob>`, where `<blob>` is the
/// unpadded standard-alphabet base64 of `nonce || ciphertext`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EncryptedField {
    pub key_id: KeyId,
    pub key_version: u32,
    pub blob: Vec<u8>,
}

impl EncryptedField {
    /// Length of the AES-GCM nonce at the front of `blob`.
    const NONCE_LEN: usize = 12;
    /// Length of the AES-GCM tag carried at the end of the ciphertext.
    const TAG_LEN: usize = 16;

    /// Renders the field in its `enc:v1:` wire form.
    pub fn to_wire(&self) -> String {
        let encoded_blob = B64.encode(&self.blob);
        format!(
            "{WIRE_PREFIX}{}:{}:{}",
            self.key_id, self.key_version, encoded_blob
        )
    }

    /// Parses the `enc:v1:` wire form produced by [`EncryptedField::to_wire`].
    ///
    /// The string is split from the right, so the key id may itself contain `:`.
    ///
    /// # Errors
    /// `CryptoError::WireFormat` for a missing prefix, missing or empty parts, a
    /// non-`u32` version, bad base64, or a blob too short to hold nonce and tag.
    pub fn from_wire(s: &str) -> Result<Self, CryptoError> {
        let Some(rest) = s.strip_prefix(WIRE_PREFIX) else {
            return Err(CryptoError::WireFormat);
        };
        let mut parts = rest.rsplitn(3, ':');
        let (Some(blob_b64), Some(version_text), Some(key_text)) =
            (parts.next(), parts.next(), parts.next())
        else {
            return Err(CryptoError::WireFormat);
        };
        let Ok(key_version) = version_text.parse::<u32>() else {
            return Err(CryptoError::WireFormat);
        };
        if key_text.is_empty() {
            return Err(CryptoError::WireFormat);
        }
        let Ok(blob) = B64.decode(blob_b64) else {
            return Err(CryptoError::WireFormat);
        };
        if blob.len() < Self::NONCE_LEN + Self::TAG_LEN {
            return Err(CryptoError::WireFormat);
        }
        Ok(EncryptedField {
            key_id: KeyId::new(key_text),
            key_version,
            blob,
        })
    }

    /// Cheap prefix test: does `s` look like an encrypted field? Does not validate.
    pub fn is_wire(s: &str) -> bool {
        s.strip_prefix(WIRE_PREFIX).is_some()
    }
}
/// Immutable view of the key ring, published as a whole through `ArcSwap`.
struct KeyRingSnapshot {
    /// Per key id, its versions; `CryptoVault` keeps each list sorted ascending
    /// by version, so the last entry is the newest.
    keys: HashMap<KeyId, Vec<DataKey>>,
    /// Incremented by one on every published rebuild.
    epoch: u64,
}

impl KeyRingSnapshot {
    /// The key new data is encrypted with: the last (newest) version for `id`.
    fn active_key(&self, id: &KeyId) -> Option<&DataKey> {
        let versions = self.keys.get(id)?;
        versions.last()
    }

    /// The exact `version` of `id`, if it is still loaded.
    fn key_version(&self, id: &KeyId, version: u32) -> Option<&DataKey> {
        let versions = self.keys.get(id)?;
        versions.iter().find(|candidate| candidate.version == version)
    }
}
pub struct CryptoVault {
ring: ArcSwap<KeyRingSnapshot>,
source: Arc<dyn KekSource>,
rebuild: tokio::sync::Mutex<()>,
}
impl CryptoVault {
pub async fn bootstrap(source: Arc<dyn KekSource>) -> Result<Self, CryptoError> {
let mut keys: HashMap<KeyId, Vec<DataKey>> = HashMap::new();
for (id, mut versions) in source.load_keyring().await? {
versions.sort_by_key(|k| k.version);
keys.insert(id, versions);
}
Ok(Self {
ring: ArcSwap::from_pointee(KeyRingSnapshot { keys, epoch: 0 }),
source,
rebuild: tokio::sync::Mutex::new(()),
})
}
pub fn encrypt(&self, key: &KeyId, plaintext: &[u8]) -> Result<EncryptedField, CryptoError> {
let ring = self.ring.load();
let dk = ring
.active_key(key)
.ok_or_else(|| CryptoError::UnknownKey(key.clone()))?;
let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dk.material.expose_secret()));
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ct = cipher
.encrypt(&nonce, plaintext)
.map_err(|_| CryptoError::Aead)?;
let mut blob = Vec::with_capacity(12 + ct.len());
blob.extend_from_slice(&nonce);
blob.extend_from_slice(&ct);
Ok(EncryptedField {
key_id: key.clone(),
key_version: dk.version,
blob,
})
}
pub fn decrypt(&self, field: &EncryptedField) -> Result<Secret<Vec<u8>>, CryptoError> {
let ring = self.ring.load();
let dk = ring
.key_version(&field.key_id, field.key_version)
.ok_or_else(|| CryptoError::Shredded(field.key_id.clone()))?;
let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dk.material.expose_secret()));
let (nonce, ct) = field.blob.split_at(12);
let pt = cipher
.decrypt(Nonce::from_slice(nonce), ct)
.map_err(|_| CryptoError::Aead)?;
Ok(Secret::new(pt))
}
pub fn has_key(&self, key: &KeyId) -> bool {
self.ring.load().active_key(key).is_some()
}
pub async fn ensure_key(&self, key: &KeyId) -> Result<(), CryptoError> {
if self.has_key(key) {
return Ok(());
}
let _g = self.rebuild.lock().await;
if self.has_key(key) {
return Ok(()); }
let dk = self.source.provision(key).await?;
self.swap_ring(|keys| {
keys.entry(key.clone()).or_default().push(dk);
});
Ok(())
}
pub async fn rotate(&self, key: &KeyId) -> Result<u32, CryptoError> {
let _g = self.rebuild.lock().await;
let dk = self.source.provision(key).await?;
let v = dk.version;
self.swap_ring(|keys| {
let versions = keys.entry(key.clone()).or_default();
versions.push(dk);
versions.sort_by_key(|k| k.version);
});
Ok(v)
}
pub async fn shred(&self, key: &KeyId) -> Result<(), CryptoError> {
let _g = self.rebuild.lock().await;
self.source.destroy(key).await?;
self.swap_ring(|keys| {
keys.remove(key);
});
tracing::info!(key = %key, "data key shredded — subject data erased");
Ok(())
}
fn swap_ring(&self, mutate: impl FnOnce(&mut HashMap<KeyId, Vec<DataKey>>)) {
let cur = self.ring.load();
let mut keys: HashMap<KeyId, Vec<DataKey>> = cur
.keys
.iter()
.map(|(k, vs)| {
(
k.clone(),
vs.iter()
.map(|d| DataKey::new(d.version, *d.material.expose_secret()))
.collect(),
)
})
.collect();
mutate(&mut keys);
self.ring.store(Arc::new(KeyRingSnapshot {
keys,
epoch: cur.epoch + 1,
}));
}
}
/// One segment of a compiled field path.
#[derive(Clone, Debug)]
enum Seg {
    /// Look up the member with this name (missing members end the path silently).
    Key(&'static str),
    /// Fan out over every array element or object value (`*` in a spec).
    Any,
}

/// Compiles a dotted field spec such as `"visits.*.diagnosis"` into path segments.
///
/// Splitting is purely on `.`; a segment equal to `*` becomes [`Seg::Any`], and every
/// other segment (including an empty one) is a literal key.
fn compile(spec: &'static str) -> Vec<Seg> {
    let mut segs = Vec::new();
    for part in spec.split('.') {
        let seg = match part {
            "*" => Seg::Any,
            name => Seg::Key(name),
        };
        segs.push(seg);
    }
    segs
}
/// Follows `path` from `v` and calls `leaf` on every value the path reaches.
///
/// `Seg::Key` descends into the named member and stops quietly when it is absent.
/// `Seg::Any` fans out over array elements or object values and stops on any other
/// value. The first error from `leaf` aborts the walk.
fn walk_path<F>(v: &mut Value, path: &[Seg], leaf: &mut F) -> Result<(), CryptoError>
where
    F: FnMut(&mut Value) -> Result<(), CryptoError>,
{
    let Some((head, rest)) = path.split_first() else {
        return leaf(v);
    };
    match head {
        Seg::Key(name) => {
            if let Some(child) = v.get_mut(*name) {
                walk_path(child, rest, leaf)?;
            }
        }
        Seg::Any => match v {
            Value::Array(items) => {
                for item in items.iter_mut() {
                    walk_path(item, rest, leaf)?;
                }
            }
            Value::Object(map) => {
                for child in map.values_mut() {
                    walk_path(child, rest, leaf)?;
                }
            }
            _ => {}
        },
    }
    Ok(())
}

/// Replaces every value reached by `path` with its encrypted wire string under `key`.
///
/// The value's JSON encoding is what gets encrypted, so numbers, objects and `null`
/// come back with their original type on unseal.
fn seal_at(
    vault: &CryptoVault,
    key: &KeyId,
    v: &mut Value,
    path: &[Seg],
) -> Result<(), CryptoError> {
    walk_path(v, path, &mut |slot: &mut Value| -> Result<(), CryptoError> {
        let encoded = serde_json::to_vec(&*slot).map_err(CryptoError::Codec)?;
        let sealed = vault.encrypt(key, &encoded)?;
        *slot = Value::String(sealed.to_wire());
        Ok(())
    })
}

/// Decrypts every wire string reached by `path` back into its original JSON value.
///
/// Values that are not strings, or strings without the wire prefix, are left as-is.
/// This lets records written before encryption was enabled still load.
fn unseal_at(vault: &CryptoVault, v: &mut Value, path: &[Seg]) -> Result<(), CryptoError> {
    walk_path(v, path, &mut |slot: &mut Value| -> Result<(), CryptoError> {
        let field = match &*slot {
            Value::String(s) if EncryptedField::is_wire(s) => EncryptedField::from_wire(s)?,
            _ => return Ok(()),
        };
        let plain = vault.decrypt(&field)?;
        *slot = serde_json::from_slice(plain.expose_secret()).map_err(CryptoError::Codec)?;
        Ok(())
    })
}
/// A serde record some of whose fields are stored encrypted.
///
/// `ENCRYPT_FIELDS` are dotted paths into the record's JSON form. A `*` segment
/// matches every array element or object value, and paths that do not resolve are
/// skipped.
pub trait EncryptRecord: serde::Serialize + serde::de::DeserializeOwned {
    /// Dotted paths of the fields to encrypt, e.g. `"visits.*.diagnosis"`.
    const ENCRYPT_FIELDS: &'static [&'static str];
    /// Data-key id used by [`EncryptRecord::seal`].
    const KEY_ID: &'static str;

    /// Serializes `self` to JSON and encrypts `ENCRYPT_FIELDS` under `KEY_ID`.
    fn seal(&self, vault: &CryptoVault) -> Result<Value, CryptoError> {
        self.seal_with_key(vault, &KeyId::new(Self::KEY_ID))
    }

    /// Like [`EncryptRecord::seal`] but with an explicit data key (e.g. per-subject).
    /// Paths are sealed in declaration order.
    fn seal_with_key(&self, vault: &CryptoVault, key: &KeyId) -> Result<Value, CryptoError> {
        let mut v = serde_json::to_value(self).map_err(CryptoError::Codec)?;
        for spec in Self::ENCRYPT_FIELDS {
            seal_at(vault, key, &mut v, &compile(spec))?;
        }
        Ok(v)
    }

    /// Inverse of sealing: decrypts the encrypted paths and deserializes the record.
    ///
    /// Paths are processed in *reverse* declaration order. When specs overlap (e.g.
    /// `"visits.*.diagnosis"` declared before `"visits"`), the outer value was sealed
    /// last and wraps the already-sealed inner fields, so it must be opened first.
    /// With forward order the inner fields would remain ciphertext strings and be
    /// deserialized as if they were plaintext. Non-overlapping specs are unaffected
    /// by the order.
    ///
    /// Plaintext values without the wire prefix are accepted unchanged.
    fn unseal(mut sealed: Value, vault: &CryptoVault) -> Result<Self, CryptoError> {
        for spec in Self::ENCRYPT_FIELDS.iter().rev() {
            unseal_at(vault, &mut sealed, &compile(spec))?;
        }
        serde_json::from_value(sealed).map_err(CryptoError::Codec)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};

    /// In-memory `KekSource` for tests: derives key material deterministically and
    /// records destroyed ids instead of talking to a real key service.
    struct TestKek {
        // Ids passed to `destroy`; recorded only, not asserted on below.
        shredded: std::sync::Mutex<std::collections::HashSet<KeyId>>,
        // Last version handed out per key id; `provision` increments it.
        versions: std::sync::Mutex<HashMap<KeyId, u32>>,
    }
    impl TestKek {
        fn new() -> Self {
            Self {
                shredded: Default::default(),
                versions: Default::default(),
            }
        }
        // Deterministic key material: SHA-256 over a fixed label, the key id and the
        // big-endian version, so the bytes depend on both id and version.
        fn derive(id: &KeyId, version: u32) -> [u8; 32] {
            let mut h = Sha256::new();
            h.update(b"test-master");
            h.update(id.as_str().as_bytes());
            h.update(version.to_be_bytes());
            h.finalize().into()
        }
    }
    impl KekSource for TestKek {
        // Every test vault starts with an empty keyring.
        fn load_keyring(&self) -> BoxFuture<'_, Result<Vec<(KeyId, Vec<DataKey>)>, CryptoError>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        // Versions start at 1 and grow by one per call for the same id.
        fn provision(&self, id: &KeyId) -> BoxFuture<'_, Result<DataKey, CryptoError>> {
            // Owned copy: the returned future may only borrow `self`, not `id`.
            let id = id.clone();
            Box::pin(async move {
                // std Mutex is fine here: the guard is never held across an `.await`.
                let mut versions = self.versions.lock().unwrap();
                let v = versions.entry(id.clone()).or_insert(0);
                *v += 1;
                Ok(DataKey::new(*v, Self::derive(&id, *v)))
            })
        }
        fn destroy(&self, id: &KeyId) -> BoxFuture<'_, Result<(), CryptoError>> {
            let id = id.clone();
            Box::pin(async move {
                self.shredded.lock().unwrap().insert(id);
                Ok(())
            })
        }
    }
    /// Fresh vault over a new `TestKek`, with no keys loaded.
    async fn vault() -> CryptoVault {
        CryptoVault::bootstrap(Arc::new(TestKek::new()))
            .await
            .unwrap()
    }
    // encrypt -> to_wire -> from_wire -> decrypt must round-trip; the key id
    // `tenant:acme` contains ':' and must survive wire parsing.
    #[tokio::test]
    async fn roundtrip_and_wire_format() {
        let v = vault().await;
        let key = KeyId::tenant("acme");
        v.ensure_key(&key).await.unwrap();
        let sealed = v.encrypt(&key, b"4242-4242").unwrap();
        let wire = sealed.to_wire();
        assert!(EncryptedField::is_wire(&wire));
        let parsed = EncryptedField::from_wire(&wire).unwrap();
        assert_eq!(parsed, sealed);
        assert_eq!(parsed.key_id, key);
        let plain = v.decrypt(&parsed).unwrap();
        assert_eq!(plain.expose_secret().as_slice(), b"4242-4242");
    }
    // After rotation, data sealed under v1 still decrypts and new data uses v2.
    #[tokio::test]
    async fn rotation_keeps_old_ciphertext_readable() {
        let v = vault().await;
        let key = KeyId::tenant("acme");
        v.ensure_key(&key).await.unwrap();
        let old = v.encrypt(&key, b"before-rotation").unwrap();
        let new_version = v.rotate(&key).await.unwrap();
        assert_eq!(new_version, 2);
        assert_eq!(
            v.decrypt(&old).unwrap().expose_secret().as_slice(),
            b"before-rotation"
        );
        assert_eq!(v.encrypt(&key, b"x").unwrap().key_version, 2);
    }
    // Shredding removes the key: old ciphertext reports `Shredded`, new writes `UnknownKey`.
    #[tokio::test]
    async fn shred_makes_data_unrecoverable() {
        let v = vault().await;
        let key = KeyId::subject("user-42");
        v.ensure_key(&key).await.unwrap();
        let sealed = v.encrypt(&key, b"phi").unwrap();
        v.shred(&key).await.unwrap();
        assert!(matches!(v.decrypt(&sealed), Err(CryptoError::Shredded(_))));
        assert!(matches!(
            v.encrypt(&key, b"more"),
            Err(CryptoError::UnknownKey(_))
        ));
    }
    // Flipping bits in the trailing byte (part of the GCM tag) must fail authentication.
    #[tokio::test]
    async fn tampered_ciphertext_fails_aead() {
        let v = vault().await;
        let key = KeyId::tenant("acme");
        v.ensure_key(&key).await.unwrap();
        let mut sealed = v.encrypt(&key, b"secret").unwrap();
        *sealed.blob.last_mut().unwrap() ^= 0xFF;
        assert!(matches!(v.decrypt(&sealed), Err(CryptoError::Aead)));
    }
    // Fixture record: `ssn` and every visit's `diagnosis` are encrypted.
    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct Patient {
        name: String,
        ssn: String,
        visits: Vec<Visit>,
    }
    #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
    struct Visit {
        diagnosis: String,
        year: u32,
    }
    impl EncryptRecord for Patient {
        const ENCRYPT_FIELDS: &'static [&'static str] = &["ssn", "visits.*.diagnosis"];
        const KEY_ID: &'static str = "tenant:clinic";
    }
    // Listed paths (including the `*` wildcard) are sealed, others stay plaintext,
    // and unseal restores the original record.
    #[tokio::test]
    async fn record_seal_unseal_with_wildcards() {
        let v = vault().await;
        v.ensure_key(&KeyId::new("tenant:clinic")).await.unwrap();
        let p = Patient {
            name: "Jane".into(),
            ssn: "123-45-6789".into(),
            visits: vec![
                Visit {
                    diagnosis: "A".into(),
                    year: 2024,
                },
                Visit {
                    diagnosis: "B".into(),
                    year: 2025,
                },
            ],
        };
        let sealed = p.seal(&v).unwrap();
        assert!(EncryptedField::is_wire(sealed["ssn"].as_str().unwrap()));
        assert!(EncryptedField::is_wire(
            sealed["visits"][0]["diagnosis"].as_str().unwrap()
        ));
        assert_eq!(sealed["name"], "Jane");
        assert_eq!(sealed["visits"][1]["year"], 2025);
        let back = Patient::unseal(sealed, &v).unwrap();
        assert_eq!(back, p);
    }
    // Records stored before encryption was enabled (plain strings) still unseal.
    #[tokio::test]
    async fn unseal_tolerates_pre_rollout_plaintext() {
        let v = vault().await;
        v.ensure_key(&KeyId::new("tenant:clinic")).await.unwrap();
        let legacy = serde_json::json!({
            "name": "Old", "ssn": "raw-ssn", "visits": []
        });
        let p = Patient::unseal(legacy, &v).unwrap();
        assert_eq!(p.ssn, "raw-ssn");
    }
}