use std::collections::HashMap;
use std::path::PathBuf;
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use async_trait::async_trait;
use rand::RngCore;
use zeroize::Zeroizing;
/// Length in bytes of a key-encryption key (KEK): an AES-256 key (256 bits).
const KEK_LEN: usize = 256 / 8;
/// Length in bytes of every data-encryption key (DEK) generated locally.
const DEK_LEN: usize = 256 / 8;
/// Length in bytes of the AES-GCM nonce (96 bits) prefixed to each locally wrapped DEK.
const WRAP_NONCE_LEN: usize = 96 / 8;
/// Length in bytes of the AES-GCM authentication tag (128 bits) appended by the cipher.
const WRAP_TAG_LEN: usize = 128 / 8;
/// Smallest well-formed local wrap: a nonce followed by a tag around an empty ciphertext.
const LOCAL_WRAP_MIN_LEN: usize = WRAP_TAG_LEN + WRAP_NONCE_LEN;
/// Errors produced by [`KmsBackend`] implementations and by [`LocalKms::open`].
#[derive(Debug, thiserror::Error)]
pub enum KmsError {
/// No KEK is registered under `key_id` (raised by `LocalKms` when generating or unwrapping).
#[error("KMS key id {key_id:?} not found in backend")]
KeyNotFound { key_id: String },
/// A `<key_id>.kek` file was found but could not be read.
#[error("KMS KEK file {path:?}: {source}")]
KekFileIo {
path: PathBuf,
source: std::io::Error,
},
/// A `.kek` file did not contain exactly `expected` raw key bytes.
#[error("KMS KEK file {path:?} must be exactly {expected} raw bytes; got {got}")]
KekBadLength {
path: PathBuf,
expected: usize,
got: usize,
},
/// The KEK directory could not be listed, or one of its entries could not be read.
#[error("KMS KEK directory {path:?}: {source}")]
KekDirIo {
path: PathBuf,
source: std::io::Error,
},
/// A locally wrapped DEK is shorter than nonce + tag, so it cannot possibly be valid.
#[error("KMS wrapped DEK too short ({got} bytes; need at least {min})")]
WrappedDekTooShort { got: usize, min: usize },
/// Authenticated decryption of a wrapped DEK failed (tampered blob, wrong KEK, or a
/// `key_id` that does not match the one bound in as associated data).
#[error("KMS unwrap failed (wrapped DEK auth tag mismatch for key_id {key_id:?})")]
UnwrapFailed { key_id: String },
/// A remote backend call failed or returned an incomplete response.
#[error("KMS backend unavailable: {message}")]
BackendUnavailable { message: String },
}
/// A data-encryption key sealed under a backend key, plus the id of that key.
///
/// `ciphertext` is opaque and backend-specific. For `LocalKms` it is
/// `nonce || AES-256-GCM ciphertext || tag`. For the AWS backend it is the KMS `CiphertextBlob`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WrappedDek {
// Identifier of the KEK / KMS key that sealed `ciphertext`; needed again to unwrap it.
pub key_id: String,
// Backend-specific sealed DEK bytes.
pub ciphertext: Vec<u8>,
}
/// Envelope-encryption backend: generates data-encryption keys (DEKs) and unwraps them again.
///
/// Plaintext DEKs are always returned in a [`Zeroizing`] buffer so they are wiped on drop.
#[async_trait]
pub trait KmsBackend: Send + Sync + std::fmt::Debug {
/// Generates a fresh DEK sealed under the key named `key_id`.
///
/// Returns the plaintext DEK together with its [`WrappedDek`] form.
async fn generate_dek(
&self,
key_id: &str,
) -> Result<(Zeroizing<Vec<u8>>, WrappedDek), KmsError>;
/// Recovers the plaintext DEK from a [`WrappedDek`] previously produced by `generate_dek`.
async fn decrypt_dek(&self, wrapped: &WrappedDek) -> Result<Zeroizing<Vec<u8>>, KmsError>;
}
/// [`KmsBackend`] that keeps its KEKs in process memory and wraps DEKs with AES-256-GCM.
///
/// KEKs are loaded from `<dir>/<key_id>.kek` files by [`LocalKms::open`], or supplied
/// directly via [`LocalKms::from_keks`].
pub struct LocalKms {
// Directory the KEKs were (nominally) loaded from; only reported in `Debug` output here.
dir: PathBuf,
// key_id -> raw 32-byte KEK.
keks: HashMap<String, [u8; KEK_LEN]>,
}
impl std::fmt::Debug for LocalKms {
    /// Renders the backend without any key material: only the directory, the number of
    /// loaded KEKs, and their ids (in map iteration order) are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_ids: Vec<&String> = self.keks.keys().collect();
        let mut out = f.debug_struct("LocalKms");
        out.field("dir", &self.dir);
        out.field("key_count", &self.keks.len());
        out.field("key_ids", &key_ids);
        out.finish()
    }
}
impl LocalKms {
pub fn open(dir: PathBuf) -> Result<Self, KmsError> {
let read_dir = std::fs::read_dir(&dir).map_err(|source| KmsError::KekDirIo {
path: dir.clone(),
source,
})?;
let mut keks = HashMap::new();
for entry in read_dir {
let entry = entry.map_err(|source| KmsError::KekDirIo {
path: dir.clone(),
source,
})?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) != Some("kek") {
continue;
}
let Some(stem) = path.file_stem().and_then(|s| s.to_str()) else {
continue;
};
let key_id = stem.to_string();
let bytes = std::fs::read(&path).map_err(|source| KmsError::KekFileIo {
path: path.clone(),
source,
})?;
if bytes.len() != KEK_LEN {
return Err(KmsError::KekBadLength {
path: path.clone(),
expected: KEK_LEN,
got: bytes.len(),
});
}
let mut k = [0u8; KEK_LEN];
k.copy_from_slice(&bytes);
keks.insert(key_id, k);
}
Ok(Self { dir, keks })
}
pub fn from_keks(dir: PathBuf, keks: HashMap<String, [u8; KEK_LEN]>) -> Self {
Self { dir, keks }
}
pub fn key_ids(&self) -> Vec<String> {
let mut ids: Vec<String> = self.keks.keys().cloned().collect();
ids.sort();
ids
}
fn kek(&self, key_id: &str) -> Result<&[u8; KEK_LEN], KmsError> {
self.keks.get(key_id).ok_or_else(|| KmsError::KeyNotFound {
key_id: key_id.to_string(),
})
}
}
#[async_trait]
impl KmsBackend for LocalKms {
    /// Draws a fresh random DEK and seals it with AES-256-GCM under the KEK named `key_id`.
    ///
    /// The wrapped layout is `nonce (WRAP_NONCE_LEN bytes) || ciphertext || tag`, with
    /// `key_id` bound in as associated data so the blob cannot be relabelled to another key.
    ///
    /// # Errors
    /// [`KmsError::KeyNotFound`] if no KEK is loaded under `key_id`.
    async fn generate_dek(
        &self,
        key_id: &str,
    ) -> Result<(Zeroizing<Vec<u8>>, WrappedDek), KmsError> {
        let kek = self.kek(key_id)?;
        let mut rng = rand::rngs::OsRng;

        // DEK first, then the nonce: both come straight from the OS CSPRNG.
        let mut dek: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0u8; DEK_LEN]);
        rng.fill_bytes(&mut dek);
        let mut nonce = [0u8; WRAP_NONCE_LEN];
        rng.fill_bytes(&mut nonce);

        let sealed = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(kek))
            .encrypt(
                Nonce::from_slice(&nonce),
                Payload {
                    msg: dek.as_slice(),
                    aad: key_id.as_bytes(),
                },
            )
            .expect("aes-gcm encrypt cannot fail with a 32-byte key");

        let mut ciphertext = Vec::with_capacity(WRAP_NONCE_LEN + sealed.len());
        ciphertext.extend_from_slice(&nonce);
        ciphertext.extend(sealed);

        let wrapped = WrappedDek {
            key_id: key_id.to_owned(),
            ciphertext,
        };
        Ok((dek, wrapped))
    }

    /// Splits off the nonce and authenticates/decrypts the rest under the named KEK.
    ///
    /// # Errors
    /// - [`KmsError::KeyNotFound`] if `wrapped.key_id` is not loaded.
    /// - [`KmsError::WrappedDekTooShort`] if the blob cannot hold a nonce and a tag.
    /// - [`KmsError::UnwrapFailed`] if the tag does not verify (tampering, wrong key id).
    async fn decrypt_dek(&self, wrapped: &WrappedDek) -> Result<Zeroizing<Vec<u8>>, KmsError> {
        let key_id = wrapped.key_id.as_str();
        let kek = self.kek(key_id)?;
        let blob = wrapped.ciphertext.as_slice();
        if blob.len() < LOCAL_WRAP_MIN_LEN {
            return Err(KmsError::WrappedDekTooShort {
                got: blob.len(),
                min: LOCAL_WRAP_MIN_LEN,
            });
        }
        let (nonce, sealed) = blob.split_at(WRAP_NONCE_LEN);
        Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(kek))
            .decrypt(
                Nonce::from_slice(nonce),
                Payload {
                    msg: sealed,
                    aad: key_id.as_bytes(),
                },
            )
            .map(Zeroizing::new)
            .map_err(|_| KmsError::UnwrapFailed {
                key_id: key_id.to_owned(),
            })
    }
}
#[cfg(feature = "aws-kms")]
pub mod aws {
    //! [`KmsBackend`] implementation that delegates to AWS KMS
    //! (`GenerateDataKey` / `Decrypt`) through the official Rust SDK.
    use super::{KmsBackend, KmsError, WrappedDek};
    use async_trait::async_trait;
    use zeroize::Zeroizing;

    /// AWS KMS backend wrapping an `aws_sdk_kms::Client`.
    pub struct AwsKms {
        client: aws_sdk_kms::Client,
    }

    impl std::fmt::Debug for AwsKms {
        // Prints no fields; the SDK client is not rendered.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("AwsKms").finish()
        }
    }

    impl AwsKms {
        /// Wraps an already-configured KMS client.
        pub fn new(client: aws_sdk_kms::Client) -> Self {
            Self { client }
        }

        /// Builds a client from the SDK's default configuration chain
        /// (`aws_config::load_defaults` with the latest behavior version).
        pub async fn from_default_env() -> Self {
            let cfg = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
            let client = aws_sdk_kms::Client::new(&cfg);
            Self { client }
        }
    }

    #[async_trait]
    impl KmsBackend for AwsKms {
        /// Calls `GenerateDataKey` with an AES-256 key spec.
        ///
        /// The returned [`WrappedDek::key_id`] is the key id echoed by KMS when present
        /// (falling back to the caller's `key_id`).
        ///
        /// # Errors
        /// [`KmsError::BackendUnavailable`] if the call fails or the response lacks
        /// `Plaintext` or `CiphertextBlob`.
        async fn generate_dek(
            &self,
            key_id: &str,
        ) -> Result<(Zeroizing<Vec<u8>>, WrappedDek), KmsError> {
            let resp = self
                .client
                .generate_data_key()
                .key_id(key_id)
                .key_spec(aws_sdk_kms::types::DataKeySpec::Aes256)
                .send()
                .await
                .map_err(|e| KmsError::BackendUnavailable {
                    message: format!("GenerateDataKey({key_id}): {e}"),
                })?;
            let dek_vec = resp
                .plaintext
                .ok_or_else(|| KmsError::BackendUnavailable {
                    message: format!("GenerateDataKey({key_id}): missing Plaintext in response"),
                })?
                .into_inner();
            // Wrap immediately so the plaintext is wiped even if a later check fails.
            let dek = Zeroizing::new(dek_vec);
            let ciphertext = resp
                .ciphertext_blob
                .ok_or_else(|| KmsError::BackendUnavailable {
                    message: format!(
                        "GenerateDataKey({key_id}): missing CiphertextBlob in response"
                    ),
                })?
                .into_inner();
            let stored_id = resp.key_id.unwrap_or_else(|| key_id.to_string());
            Ok((
                dek,
                WrappedDek {
                    key_id: stored_id,
                    ciphertext,
                },
            ))
        }

        /// Calls `Decrypt` on the stored blob, pinning the expected key id.
        ///
        /// # Errors
        /// - [`KmsError::UnwrapFailed`] if KMS rejects the blob itself
        ///   (`InvalidCiphertextException` / `IncorrectKeyException`); retrying won't help.
        /// - [`KmsError::BackendUnavailable`] for every other failure, or if the response
        ///   lacks `Plaintext`.
        async fn decrypt_dek(&self, wrapped: &WrappedDek) -> Result<Zeroizing<Vec<u8>>, KmsError> {
            let resp = self
                .client
                .decrypt()
                .ciphertext_blob(aws_sdk_kms::primitives::Blob::new(
                    wrapped.ciphertext.clone(),
                ))
                .key_id(&wrapped.key_id)
                .send()
                .await
                .map_err(|e| {
                    // A corrupted blob or one sealed under a different key is an integrity
                    // failure (same class as a local auth-tag mismatch), not an outage.
                    let bad_blob = matches!(
                        e.as_service_error(),
                        Some(err) if err.is_invalid_ciphertext_exception()
                            || err.is_incorrect_key_exception()
                    );
                    if bad_blob {
                        KmsError::UnwrapFailed {
                            key_id: wrapped.key_id.clone(),
                        }
                    } else {
                        KmsError::BackendUnavailable {
                            message: format!("Decrypt({}): {e}", wrapped.key_id),
                        }
                    }
                })?;
            let dek_vec = resp
                .plaintext
                .ok_or_else(|| KmsError::BackendUnavailable {
                    message: format!("Decrypt({}): missing Plaintext in response", wrapped.key_id),
                })?
                .into_inner();
            Ok(Zeroizing::new(dek_vec))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::Path;
    use tempfile::TempDir;

    /// Writes `bytes` to `<dir>/<name>.kek`, the layout `LocalKms::open` scans for.
    fn write_kek(dir: &Path, name: &str, bytes: &[u8]) {
        std::fs::write(dir.join(format!("{name}.kek")), bytes).unwrap();
    }

    #[tokio::test]
    async fn open_empty_dir_is_ok() {
        let tmp = TempDir::new().unwrap();
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        assert!(kms.key_ids().is_empty());
        let err = kms.generate_dek("missing").await.unwrap_err();
        assert!(
            matches!(err, KmsError::KeyNotFound { ref key_id } if key_id == "missing"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn open_loads_kek_files_and_skips_others() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "alpha", &[1u8; KEK_LEN]);
        write_kek(tmp.path(), "beta", &[2u8; KEK_LEN]);
        std::fs::write(tmp.path().join("README"), b"hello").unwrap();
        // Extension is "bak", not "kek", so this must be ignored despite the wrong length.
        std::fs::write(tmp.path().join("alpha.kek.bak"), [9u8; 99]).unwrap();
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let ids = kms.key_ids();
        assert_eq!(ids, vec!["alpha".to_string(), "beta".to_string()]);
    }

    #[tokio::test]
    async fn open_rejects_truncated_kek_file() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "short", &[7u8; KEK_LEN - 1]);
        let err = LocalKms::open(tmp.path().to_path_buf()).unwrap_err();
        assert!(
            matches!(
                err,
                KmsError::KekBadLength { expected, got, .. } if expected == KEK_LEN && got == KEK_LEN - 1
            ),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn generate_then_decrypt_roundtrip() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "main", &[42u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (dek, wrapped) = kms.generate_dek("main").await.unwrap();
        assert_eq!(dek.len(), DEK_LEN);
        assert_eq!(wrapped.key_id, "main");
        assert_eq!(
            wrapped.ciphertext.len(),
            WRAP_NONCE_LEN + DEK_LEN + WRAP_TAG_LEN
        );
        let unwrapped = kms.decrypt_dek(&wrapped).await.unwrap();
        assert_eq!(unwrapped, dek);
    }

    #[tokio::test]
    async fn generate_uses_random_dek_and_nonce() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "k", &[5u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (dek1, w1) = kms.generate_dek("k").await.unwrap();
        let (dek2, w2) = kms.generate_dek("k").await.unwrap();
        assert_ne!(dek1, dek2, "DEK must be random per call");
        assert_ne!(
            w1.ciphertext, w2.ciphertext,
            "wrap nonce must be random per call"
        );
    }

    #[tokio::test]
    async fn decrypt_unknown_key_id_errors() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "real", &[1u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let bogus = WrappedDek {
            key_id: "phantom".to_string(),
            ciphertext: vec![0u8; LOCAL_WRAP_MIN_LEN + DEK_LEN],
        };
        let err = kms.decrypt_dek(&bogus).await.unwrap_err();
        assert!(
            matches!(err, KmsError::KeyNotFound { ref key_id } if key_id == "phantom"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn decrypt_tampered_ciphertext_fails_unwrap() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "k", &[3u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (_dek, mut wrapped) = kms.generate_dek("k").await.unwrap();
        let mid = wrapped.ciphertext.len() / 2;
        wrapped.ciphertext[mid] ^= 0xFF;
        let err = kms.decrypt_dek(&wrapped).await.unwrap_err();
        assert!(
            matches!(err, KmsError::UnwrapFailed { ref key_id } if key_id == "k"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn decrypt_short_ciphertext_errors() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "k", &[8u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let bogus = WrappedDek {
            key_id: "k".to_string(),
            ciphertext: vec![0u8; 5],
        };
        let err = kms.decrypt_dek(&bogus).await.unwrap_err();
        assert!(
            matches!(err, KmsError::WrappedDekTooShort { got: 5, .. }),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn decrypt_wrong_key_id_aad_fails_unwrap() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "alpha", &[1u8; KEK_LEN]);
        write_kek(tmp.path(), "beta", &[2u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (_dek, wrapped) = kms.generate_dek("alpha").await.unwrap();
        let forged = WrappedDek {
            key_id: "beta".to_string(),
            ciphertext: wrapped.ciphertext.clone(),
        };
        let err = kms.decrypt_dek(&forged).await.unwrap_err();
        assert!(
            matches!(err, KmsError::UnwrapFailed { ref key_id } if key_id == "beta"),
            "got {err:?}"
        );
    }

    #[tokio::test]
    async fn from_keks_constructor_works() {
        let mut keks = HashMap::new();
        keks.insert("inline".to_string(), [9u8; KEK_LEN]);
        let kms = LocalKms::from_keks(PathBuf::from("/tmp/none"), keks);
        let (dek, wrapped) = kms.generate_dek("inline").await.unwrap();
        assert_eq!(wrapped.key_id, "inline");
        let back = kms.decrypt_dek(&wrapped).await.unwrap();
        assert_eq!(back, dek, "in-memory KEKs must round-trip the DEK");
    }

    #[test]
    fn debug_output_omits_key_material() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "secretive", &[0xABu8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let rendered = format!("{kms:?}");
        assert!(
            rendered.contains("secretive"),
            "key ids should be listed: {rendered}"
        );
        // 0xAB == 171; a leaked `[u8; 32]` would render as "[171, 171, ...]".
        assert!(
            !rendered.contains("171, 171"),
            "KEK bytes must not be printed: {rendered}"
        );
    }

    #[cfg(feature = "aws-kms")]
    #[tokio::test]
    #[ignore = "requires AWS credentials and a real KMS key (set S4_KMS_TEST_KEY_ID)"]
    async fn aws_kms_roundtrip() {
        let key_id = std::env::var("S4_KMS_TEST_KEY_ID")
            .expect("S4_KMS_TEST_KEY_ID env var required (real AWS KMS key ARN or alias)");
        let kms = super::aws::AwsKms::from_default_env().await;
        let (plaintext_dek, wrapped) = kms
            .generate_dek(&key_id)
            .await
            .expect("generate_dek should succeed against real KMS");
        assert_eq!(
            plaintext_dek.len(),
            DEK_LEN,
            "DEK should be 32 bytes (AES-256)"
        );
        assert_ne!(
            wrapped.ciphertext, *plaintext_dek,
            "wrapped DEK must NOT equal plaintext DEK"
        );
        let unwrapped = kms
            .decrypt_dek(&wrapped)
            .await
            .expect("decrypt_dek should succeed");
        assert_eq!(*unwrapped, *plaintext_dek, "round-trip DEK must byte-equal");
        assert!(
            wrapped.key_id.starts_with("arn:aws:kms:") || wrapped.key_id == key_id,
            "wrapped key_id should be canonical ARN or original input: {}",
            wrapped.key_id
        );
    }

    #[cfg(feature = "aws-kms")]
    #[tokio::test]
    #[ignore = "requires AWS credentials (no specific key needed; uses a synthetic bogus ARN)"]
    async fn aws_kms_unwrap_unknown_arn_fails() {
        let kms = super::aws::AwsKms::from_default_env().await;
        let bogus = WrappedDek {
            key_id: "arn:aws:kms:us-east-1:000000000000:key/00000000-0000-0000-0000-000000000000"
                .to_string(),
            ciphertext: vec![0u8; 100],
        };
        let err = kms
            .decrypt_dek(&bogus)
            .await
            .expect_err("decrypt with bogus ciphertext must fail");
        assert!(
            matches!(
                err,
                KmsError::BackendUnavailable { .. } | KmsError::UnwrapFailed { .. }
            ),
            "expected BackendUnavailable or UnwrapFailed, got {err:?}"
        );
    }

    #[tokio::test]
    async fn local_kms_generate_dek_returns_zeroizing() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "z", &[7u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        // The type ascription is the assertion: the plaintext DEK must be `Zeroizing`.
        let (dek, _wrapped): (Zeroizing<Vec<u8>>, WrappedDek) =
            kms.generate_dek("z").await.unwrap();
        assert_eq!(dek.len(), DEK_LEN);
        let _slice: &[u8] = &dek;
    }

    #[tokio::test]
    async fn local_kms_decrypt_dek_returns_zeroizing() {
        let tmp = TempDir::new().unwrap();
        write_kek(tmp.path(), "z", &[11u8; KEK_LEN]);
        let kms = LocalKms::open(tmp.path().to_path_buf()).unwrap();
        let (dek_in, wrapped) = kms.generate_dek("z").await.unwrap();
        let dek_out: Zeroizing<Vec<u8>> = kms.decrypt_dek(&wrapped).await.unwrap();
        assert_eq!(dek_out.len(), DEK_LEN);
        assert_eq!(&*dek_out, &*dek_in);
    }

    /// Smoke test only: exercises `Zeroizing` construction, mutation, and drop. Actually
    /// observing the wipe would require reading freed memory, which safe code cannot do.
    #[tokio::test]
    async fn dek_zeroized_on_drop_smoke() {
        let mut z: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0u8; DEK_LEN]);
        for (i, b) in z.iter_mut().enumerate() {
            *b = (i as u8).wrapping_add(1);
        }
        assert_eq!(z[0], 1);
        assert_eq!(z[DEK_LEN - 1], DEK_LEN as u8);
        drop(z);
    }
}