pub mod types;
pub use types::{Ciphertext, CiphertextV1};
use aes_gcm::{
Aes256Gcm,
aead::{Aead, KeyInit, Payload},
};
use async_trait::async_trait;
use dashmap::DashMap;
use rand::{RngCore, SeedableRng, rngs::StdRng};
use sha2::{Digest, Sha512};
use std::sync::Arc;
use tracing::debug;
use uuid::Uuid;
use zeroize::Zeroizing;
use crate::common::error::{Error, Result, StringError};
/// Key-management backend used by [`SecretsManager`] to wrap and unwrap per-record data keys.
///
/// `SecretsManager` passes the data key id's raw bytes as `aad` to both methods, so an
/// implementation that authenticates `aad` ties each wrapped key to its envelope.
/// NOTE(review): whether `aad` is actually enforced is up to the implementation — [`NoopKms`]
/// ignores it.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
pub trait Kms: Send + Sync {
    /// Wraps (encrypts) `data_key` under the key identified by `kms_key_id`.
    ///
    /// The returned bytes are what `SecretsManager` stores as the envelope's wrapped key.
    async fn encrypt_data_key(
        &self,
        kms_key_id: &str,
        data_key: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>>;
    /// Unwraps `encrypted_data_key`, as previously produced by `encrypt_data_key`.
    ///
    /// `SecretsManager` rejects any returned key that is not exactly 32 bytes long.
    async fn decrypt_data_key(
        &self,
        kms_key_id: &str,
        encrypted_data_key: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>>;
}
/// Pass-through [`Kms`] that performs no wrapping at all.
///
/// `encrypt_data_key` hands back the plaintext data key unchanged and `decrypt_data_key`
/// hands back its input unchanged; the key id and `aad` are ignored. Envelopes produced with
/// this backend therefore carry the raw data key and get no protection from it. The tests in
/// this file use it as a stand-in for a real KMS.
#[derive(Debug, Default, Clone)]
pub struct NoopKms;

#[async_trait]
impl Kms for NoopKms {
    async fn encrypt_data_key(
        &self,
        _key_id: &str,
        plaintext_key: &[u8],
        _context: &[u8],
    ) -> Result<Vec<u8>> {
        // Identity "wrap": the stored envelope holds the data key as-is.
        Ok(Vec::from(plaintext_key))
    }

    async fn decrypt_data_key(
        &self,
        _key_id: &str,
        wrapped_key: &[u8],
        _context: &[u8],
    ) -> Result<Vec<u8>> {
        // Identity "unwrap": mirrors `encrypt_data_key` above.
        let unwrapped = wrapped_key.to_owned();
        Ok(unwrapped)
    }
}
/// Envelope-encryption front end: seals payloads with per-call AES-256-GCM data keys that are
/// wrapped by a [`Kms`] backend.
pub struct SecretsManager {
    /// Backend used to wrap data keys on encrypt and unwrap them on decrypt.
    kms: Arc<dyn Kms>,
    /// KMS key id under which new data keys are wrapped; also fed into the payload AAD hash.
    master_key_id: String,
    /// Plaintext data keys recovered during decryption, keyed by data key id, so repeated
    /// decrypts of the same envelope can skip the KMS. Values are wiped on drop (`Zeroizing`).
    /// NOTE(review): no eviction is visible in this file, so the cache grows with the number
    /// of distinct data keys decrypted over the process lifetime.
    data_keys_cache: DashMap<Uuid, Zeroizing<[u8; 32]>>,
}
/// Debug output that shows only the master key id.
///
/// The KMS handle and the data-key cache are left out, so cached plaintext data keys are not
/// printed; `finish_non_exhaustive` renders a trailing `..` to signal the omitted fields.
impl std::fmt::Debug for SecretsManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SecretsManager");
        builder.field("master_key_id", &self.master_key_id);
        builder.finish_non_exhaustive()
    }
}
impl SecretsManager {
pub fn new(kms: Arc<dyn Kms>, master_key_id: String) -> Self {
Self {
kms,
master_key_id,
data_keys_cache: DashMap::new(),
}
}
pub fn master_key_id(&self) -> &str {
&self.master_key_id
}
pub async fn encrypt(&self, plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
let data_key_id = Uuid::new_v4();
let mut data_key = Zeroizing::new([0u8; 32]);
let mut nonce = [0u8; 12];
{
let mut r = StdRng::from_os_rng();
r.fill_bytes(data_key.as_mut());
r.fill_bytes(&mut nonce);
}
let encrypted_data_key = self
.kms
.encrypt_data_key(
&self.master_key_id,
data_key.as_ref(),
data_key_id.as_bytes(),
)
.await
.map_err(|e| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError(format!("KMS encrypt_data_key failed: {e}"))),
})?;
let aad_hash = build_aad_hash(aad, &encrypted_data_key, &self.master_key_id, data_key_id);
let cipher = Aes256Gcm::new(data_key.as_ref().into());
let encrypted_data = cipher
.encrypt(
&nonce.into(),
Payload {
msg: plaintext,
aad: &aad_hash,
},
)
.map_err(|e| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError(format!("AES-256-GCM encrypt failed: {e}"))),
})?;
let estimated = encrypted_data.len()
+ self.master_key_id.len()
+ 16
+ encrypted_data_key.len()
+ 12
+ 32;
let ct = Ciphertext::V1(CiphertextV1 {
kms_key_id: self.master_key_id.clone(),
data_key_id,
encrypted_data_key,
nonce,
encrypted_data,
});
let mut buf = Vec::with_capacity(estimated);
ciborium::into_writer(&ct, &mut buf).map_err(|e| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError(format!("CBOR encode failed: {e}"))),
})?;
Ok(buf)
}
pub async fn decrypt(&self, ciphertext_bytes: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
let ct: Ciphertext =
ciborium::from_reader(ciphertext_bytes).map_err(|e| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError(format!("CBOR decode failed: {e}"))),
})?;
let Ciphertext::V1(ct) = ct;
let data_key: Zeroizing<[u8; 32]> =
if let Some(cached) = self.data_keys_cache.get(&ct.data_key_id) {
debug!(data_key_id = %ct.data_key_id, "data key cache hit");
cached.clone()
} else {
debug!(data_key_id = %ct.data_key_id, "decrypting data key from KMS");
let raw = self
.kms
.decrypt_data_key(
&ct.kms_key_id,
&ct.encrypted_data_key,
ct.data_key_id.as_bytes(),
)
.await
.map_err(|e| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError(format!("KMS decrypt_data_key failed: {e}"))),
})?;
let arr: [u8; 32] = raw.try_into().map_err(|_| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError("decrypted data key is not 32 bytes".to_owned())),
})?;
let key = Zeroizing::new(arr);
self.data_keys_cache.insert(ct.data_key_id, key.clone());
key
};
let aad_hash = build_aad_hash(aad, &ct.encrypted_data_key, &ct.kms_key_id, ct.data_key_id);
let cipher = Aes256Gcm::new(data_key.as_ref().into());
let plaintext = cipher
.decrypt(
&ct.nonce.into(),
Payload {
msg: &ct.encrypted_data,
aad: &aad_hash,
},
)
.map_err(|e| Error::Generic {
store: "SecretsManager",
source: Box::new(StringError(format!("AES-256-GCM decrypt failed: {e}"))),
})?;
Ok(plaintext)
}
}
/// Derives the AES-GCM associated data for one envelope.
///
/// Returns the 64-byte SHA-512 digest of `caller_aad`, `encrypted_data_key`, the UTF-8 bytes
/// of `master_key_id` and the 16 bytes of `data_key_id`, hashed in that order. The payload
/// tag therefore covers the caller context and every envelope header field.
///
/// The fields are fed to the hasher back to back with no length prefixes. The encoding is
/// thus not unambiguous at field boundaries: moving trailing bytes of one variable-length
/// field onto the front of the next yields the same digest. This exact byte layout is part
/// of the stored ciphertext format. Changing it would make ciphertexts produced before the
/// change fail authentication.
fn build_aad_hash(
    caller_aad: &[u8],
    encrypted_data_key: &[u8],
    master_key_id: &str,
    data_key_id: Uuid,
) -> Vec<u8> {
    let fields: [&[u8]; 4] = [
        caller_aad,
        encrypted_data_key,
        master_key_id.as_bytes(),
        data_key_id.as_bytes(),
    ];
    let mut digest = Sha512::new();
    for field in fields {
        digest.update(field);
    }
    digest.finalize().to_vec()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager backed by the pass-through KMS, so envelopes carry the raw data key.
    fn manager_with_noop() -> SecretsManager {
        SecretsManager::new(Arc::new(NoopKms), "test-master-key".to_owned())
    }

    /// Decodes `bytes`, applies `tamper` to the V1 envelope, and re-encodes it.
    fn tamper_envelope(bytes: &[u8], tamper: impl FnOnce(&mut CiphertextV1)) -> Vec<u8> {
        let envelope: Ciphertext = ciborium::from_reader(bytes).expect("valid envelope");
        let Ciphertext::V1(mut inner) = envelope;
        tamper(&mut inner);
        let mut out = Vec::new();
        ciborium::into_writer(&Ciphertext::V1(inner), &mut out).expect("re-encode envelope");
        out
    }

    #[tokio::test]
    async fn noop_kms_encrypt_returns_key_unchanged() {
        let kms = NoopKms;
        let key = b"0123456789abcdef0123456789abcdef";
        let result = kms.encrypt_data_key("key-id", key, b"aad").await.unwrap();
        assert_eq!(result, key);
    }

    #[tokio::test]
    async fn noop_kms_decrypt_returns_ciphertext_unchanged() {
        let kms = NoopKms;
        let enc = vec![1u8, 2, 3];
        let result = kms.decrypt_data_key("key-id", &enc, b"aad").await.unwrap();
        assert_eq!(result, enc);
    }

    #[tokio::test]
    async fn master_key_id_returns_configured_id() {
        let mgr = manager_with_noop();
        assert_eq!(mgr.master_key_id(), "test-master-key");
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        let mgr = manager_with_noop();
        let plaintext = b"super-sensitive-api-key";
        let aad = b"user-id-42";
        let ct = mgr.encrypt(plaintext, aad).await.unwrap();
        let pt = mgr.decrypt(&ct, aad).await.unwrap();
        assert_eq!(pt, plaintext);
    }

    #[tokio::test]
    async fn ciphertext_decrypts_with_a_separate_manager() {
        // The cache is only an optimisation: a manager that never saw the key must still work.
        let writer = manager_with_noop();
        let reader = manager_with_noop();
        let ct = writer.encrypt(b"shared", b"aad").await.unwrap();
        assert_eq!(reader.decrypt(&ct, b"aad").await.unwrap(), b"shared");
    }

    #[tokio::test]
    async fn encrypt_produces_non_empty_ciphertext() {
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        assert!(!ct.is_empty());
    }

    #[tokio::test]
    async fn two_encryptions_of_same_plaintext_differ() {
        let mgr = manager_with_noop();
        let aad = b"aad";
        let ct1 = mgr.encrypt(b"same-value", aad).await.unwrap();
        let ct2 = mgr.encrypt(b"same-value", aad).await.unwrap();
        assert_ne!(ct1, ct2);
    }

    #[tokio::test]
    async fn encrypt_empty_plaintext_decrypts_correctly() {
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"", b"aad").await.unwrap();
        let pt = mgr.decrypt(&ct, b"aad").await.unwrap();
        assert_eq!(pt, b"");
    }

    #[tokio::test]
    async fn data_key_cache_is_populated_after_first_decrypt() {
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"value", b"aad").await.unwrap();
        assert!(mgr.data_keys_cache.is_empty());
        mgr.decrypt(&ct, b"aad").await.unwrap();
        assert_eq!(mgr.data_keys_cache.len(), 1);
        mgr.decrypt(&ct, b"aad").await.unwrap();
        assert_eq!(mgr.data_keys_cache.len(), 1);
    }

    #[tokio::test]
    async fn decrypt_with_wrong_aad_fails() {
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"secret", b"correct-aad").await.unwrap();
        let result = mgr.decrypt(&ct, b"wrong-aad").await;
        assert!(
            result.is_err(),
            "expected decryption to fail with wrong AAD"
        );
    }

    #[tokio::test]
    async fn decrypt_with_tampered_kms_key_id_fails() {
        // NoopKms ignores the key id, so only the AAD binding can catch this.
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        let tampered = tamper_envelope(&ct, |inner| {
            inner.kms_key_id = "other-master-key".to_owned();
        });
        assert!(mgr.decrypt(&tampered, b"aad").await.is_err());
    }

    #[tokio::test]
    async fn decrypt_with_tampered_data_key_id_fails() {
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        let tampered = tamper_envelope(&ct, |inner| {
            inner.data_key_id = Uuid::new_v4();
        });
        assert!(mgr.decrypt(&tampered, b"aad").await.is_err());
    }

    #[tokio::test]
    async fn decrypt_with_tampered_nonce_fails() {
        let mgr = manager_with_noop();
        let ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        let tampered = tamper_envelope(&ct, |inner| {
            inner.nonce[0] ^= 0x01;
        });
        assert!(mgr.decrypt(&tampered, b"aad").await.is_err());
    }

    #[tokio::test]
    async fn decrypt_corrupted_ciphertext_fails() {
        let mgr = manager_with_noop();
        let mut ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        if let Some(last) = ct.last_mut() {
            *last ^= 0xFF;
        }
        let result = mgr.decrypt(&ct, b"aad").await;
        assert!(result.is_err(), "expected failure on corrupted ciphertext");
    }

    #[tokio::test]
    async fn decrypt_empty_bytes_fails() {
        let mgr = manager_with_noop();
        let result = mgr.decrypt(&[], b"aad").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn kms_encrypt_failure_propagates() {
        let mut mock_kms = MockKms::new();
        mock_kms.expect_encrypt_data_key().returning(|_, _, _| {
            Err(Error::Generic {
                store: "MockKms",
                source: Box::new(StringError("KMS unavailable".to_owned())),
            })
        });
        let mgr = SecretsManager::new(Arc::new(mock_kms), "key-id".to_owned());
        let result = mgr.encrypt(b"secret", b"aad").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn kms_decrypt_failure_propagates() {
        let mut mock_kms = MockKms::new();
        mock_kms.expect_decrypt_data_key().returning(|_, _, _| {
            Err(Error::Generic {
                store: "MockKms",
                source: Box::new(StringError("KMS unavailable".to_owned())),
            })
        });
        mock_kms
            .expect_encrypt_data_key()
            .returning(|_, data_key, _| Ok(data_key.to_vec()));
        let mgr = SecretsManager::new(Arc::new(mock_kms), "key-id".to_owned());
        let ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        let result = mgr.decrypt(&ct, b"aad").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn kms_decrypt_returns_wrong_length_fails() {
        let mut mock_kms = MockKms::new();
        mock_kms
            .expect_encrypt_data_key()
            .returning(|_, data_key, _| Ok(data_key.to_vec()));
        mock_kms
            .expect_decrypt_data_key()
            .returning(|_, _, _| Ok(vec![1u8, 2, 3]));
        let mgr = SecretsManager::new(Arc::new(mock_kms), "key-id".to_owned());
        let ct = mgr.encrypt(b"secret", b"aad").await.unwrap();
        let result = mgr.decrypt(&ct, b"aad").await;
        assert!(result.is_err());
    }
}