use crate::CryptoError;
use crate::encryption::{EncryptionKey, KEY_LENGTH, KeyEncryptionKey, WrappedKey};
use std::collections::HashMap;
use std::sync::Mutex;
use thiserror::Error;
/// Opaque, non-empty identifier of a root key held by a [`KmsProvider`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KmsKeyRef(String);

impl KmsKeyRef {
    /// Wraps `s` as a key reference.
    ///
    /// # Panics
    /// Panics with `"KmsKeyRef cannot be empty"` if `s` is empty. Use
    /// [`KmsKeyRef::try_new`] for input that has not already been validated.
    pub fn new(s: impl Into<String>) -> Self {
        Self::try_new(s).expect("KmsKeyRef cannot be empty")
    }

    /// Non-panicking counterpart of [`KmsKeyRef::new`].
    ///
    /// Returns `None` if `s` is empty, otherwise the wrapped reference.
    pub fn try_new(s: impl Into<String>) -> Option<Self> {
        let s: String = s.into();
        (!s.is_empty()).then(|| Self(s))
    }

    /// Borrows the reference as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
/// Opaque, non-empty sealed blob as returned by [`KmsProvider::seal`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SealedKey(Vec<u8>);

impl SealedKey {
    /// Wraps raw sealed bytes.
    ///
    /// # Panics
    /// Panics with `"SealedKey cannot be empty"` if `bytes` is empty. Prefer
    /// [`SealedKey::try_new`] for bytes read back from storage or the network.
    pub fn new(bytes: Vec<u8>) -> Self {
        Self::try_new(bytes).expect("SealedKey cannot be empty")
    }

    /// Non-panicking counterpart of [`SealedKey::new`].
    ///
    /// Returns `None` if `bytes` is empty, otherwise the wrapped blob.
    pub fn try_new(bytes: Vec<u8>) -> Option<Self> {
        (!bytes.is_empty()).then(|| Self(bytes))
    }

    /// Borrows the sealed bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Consumes the wrapper and returns the sealed bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        self.0
    }
}
/// Errors returned by [`KmsProvider`] operations and the [`KmsMasterKey`] helpers.
#[derive(Debug, Error)]
pub enum KmsError {
/// The requested root key is not provisioned. `InMemoryKms` passes the key
/// reference string as the payload.
#[error("KMS key not found: {0}")]
KeyNotFound(String),
/// Authentication with the KMS failed. Not produced by `InMemoryKms`;
/// presumably reserved for remote providers — TODO confirm.
#[error("KMS authentication failure: {0}")]
AuthError(String),
/// Network or I/O failure talking to the KMS. Not produced by `InMemoryKms`;
/// presumably reserved for remote providers — TODO confirm.
#[error("KMS network / IO error: {0}")]
Transport(String),
/// The KMS refused the request, e.g. `InMemoryKms::create_key` on a key
/// reference that already exists.
#[error("KMS rejected the operation: {0}")]
OperationDenied(String),
/// A buffer had the wrong length. Despite the message, this is also returned
/// for a wrong-length *plaintext* passed to `InMemoryKms::seal`, for a sealed
/// blob of the wrong size in `InMemoryKms::open`, and for an opened KEK that
/// is not `KEY_LENGTH` bytes in `KmsMasterKey::restore_sealed_kek`.
#[error("KMS returned malformed ciphertext")]
Malformed,
/// A lower-level crypto failure, converted automatically via `From<CryptoError>`
/// (e.g. by `?` in `InMemoryKms::open`, presumably from `WrappedKey::unwrap_key`).
#[error("crypto error: {0}")]
Crypto(#[from] CryptoError),
}
/// Backend that seals secret bytes under a named root key and later opens them
/// again. Implementations must be `Send + Sync`.
///
/// Implementations may restrict what they accept; for example, `InMemoryKms`
/// only seals plaintexts of exactly `KEY_LENGTH` bytes.
pub trait KmsProvider: Send + Sync {
/// Seals `plaintext` under the root key identified by `key_ref`.
fn seal(&self, key_ref: &KmsKeyRef, plaintext: &[u8]) -> Result<SealedKey, KmsError>;
/// Opens `sealed` with the root key identified by `key_ref` and returns the
/// recovered plaintext. The caller owns (and should wipe) the returned bytes.
fn open(&self, key_ref: &KmsKeyRef, sealed: &SealedKey) -> Result<Vec<u8>, KmsError>;
/// Static name identifying the backend, e.g. `"in-memory-kms"`.
fn provider_name(&self) -> &'static str;
}
pub struct KmsMasterKey<P: KmsProvider> {
provider: P,
key_ref: KmsKeyRef,
}
impl<P: KmsProvider> KmsMasterKey<P> {
    /// Binds `provider` to the root key identified by `key_ref`.
    pub fn new(provider: P, key_ref: KmsKeyRef) -> Self {
        Self { provider, key_ref }
    }

    /// Returns the backing provider's name (see [`KmsProvider::provider_name`]).
    pub fn provider_name(&self) -> &'static str {
        self.provider.provider_name()
    }

    /// Returns the reference of the root key this master key seals under.
    pub fn key_ref(&self) -> &KmsKeyRef {
        &self.key_ref
    }

    /// Seals `plaintext` under this root key, passing it straight to the provider.
    ///
    /// # Errors
    /// Propagates any error from the provider's `seal`.
    pub fn seal_raw(&self, plaintext: &[u8]) -> Result<SealedKey, KmsError> {
        self.provider.seal(&self.key_ref, plaintext)
    }

    /// Opens a blob sealed under this root key and returns the provider's
    /// plaintext unchanged. This method does not wipe it; the caller owns the
    /// returned secret.
    ///
    /// # Errors
    /// Propagates any error from the provider's `open`.
    pub fn open_raw(&self, sealed: &SealedKey) -> Result<Vec<u8>, KmsError> {
        self.provider.open(&self.key_ref, sealed)
    }

    /// Generates a random KEK and seals it under this root key.
    ///
    /// Returns the usable KEK together with its sealed form (for persistence).
    /// The temporary plaintext buffer is zeroed on both the success and the
    /// error path.
    ///
    /// # Errors
    /// Propagates any error from the provider's `seal`.
    pub fn generate_sealed_kek(&self) -> Result<(KeyEncryptionKey, SealedKey), KmsError> {
        let mut bytes = [0u8; KEY_LENGTH];
        crate::encryption::fill_random(&mut bytes);
        // Build the result before wiping, so a failed `seal` no longer returns
        // early with the raw key still sitting in `bytes`.
        let result = self
            .provider
            .seal(&self.key_ref, &bytes)
            .map(|sealed| (KeyEncryptionKey::from_bytes(&bytes), sealed));
        // Best-effort wipe: a plain store to a dead buffer may be elided by the
        // optimizer; a volatile wipe (e.g. the `zeroize` crate) is stronger.
        bytes.fill(0);
        result
    }

    /// Opens a sealed KEK and rebuilds the [`KeyEncryptionKey`].
    ///
    /// Both the opened plaintext and the intermediate array are zeroed before
    /// returning, whether or not the length check succeeds.
    ///
    /// # Errors
    /// Propagates any error from the provider's `open`. Returns
    /// [`KmsError::Malformed`] if the plaintext is not `KEY_LENGTH` bytes.
    pub fn restore_sealed_kek(&self, sealed: &SealedKey) -> Result<KeyEncryptionKey, KmsError> {
        let mut plaintext = self.provider.open(&self.key_ref, sealed)?;
        // Copy into a fixed-size array (or fail), then wipe the heap buffer
        // before doing anything else.
        let parsed = <[u8; KEY_LENGTH]>::try_from(plaintext.as_slice());
        plaintext.fill(0);
        let mut bytes = parsed.map_err(|_| KmsError::Malformed)?;
        let kek = KeyEncryptionKey::from_bytes(&bytes);
        bytes.fill(0);
        Ok(kek)
    }

    /// Re-seals a KEK currently sealed under `self` so that it is sealed under
    /// `new_root` instead.
    ///
    /// The plaintext KEK exists only transiently and is zeroed before returning.
    ///
    /// # Errors
    /// Propagates any error from the old root's `open` or the new root's
    /// `seal`. Returns [`KmsError::Malformed`] if the opened plaintext is not
    /// `KEY_LENGTH` bytes. This matches [`Self::restore_sealed_kek`], so a
    /// non-KEK blob is never propagated to the new root.
    pub fn rotate_kek(
        &self,
        sealed_under_old: &SealedKey,
        new_root: &KmsMasterKey<P>,
    ) -> Result<SealedKey, KmsError> {
        let mut kek_bytes = self.provider.open(&self.key_ref, sealed_under_old)?;
        let resealed = if kek_bytes.len() == KEY_LENGTH {
            new_root.provider.seal(&new_root.key_ref, &kek_bytes)
        } else {
            Err(KmsError::Malformed)
        };
        kek_bytes.fill(0);
        resealed
    }
}
pub struct InMemoryKms {
keys: Mutex<HashMap<KmsKeyRef, EncryptionKey>>,
}
impl Default for InMemoryKms {
fn default() -> Self {
Self::new()
}
}
impl InMemoryKms {
pub fn new() -> Self {
Self {
keys: Mutex::new(HashMap::new()),
}
}
pub fn create_key(&self, key_ref: KmsKeyRef) -> Result<(), KmsError> {
let mut keys = self.keys.lock().expect("InMemoryKms mutex poisoned");
if keys.contains_key(&key_ref) {
return Err(KmsError::OperationDenied(format!(
"key already exists: {}",
key_ref.as_str()
)));
}
keys.insert(key_ref, EncryptionKey::generate());
Ok(())
}
}
impl KmsProvider for InMemoryKms {
    /// Wraps a `KEY_LENGTH`-byte secret under the root key named by `key_ref`.
    ///
    /// # Errors
    /// - [`KmsError::KeyNotFound`] if `key_ref` was never provisioned.
    /// - [`KmsError::Malformed`] if `plaintext` is not exactly `KEY_LENGTH` bytes.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    fn seal(&self, key_ref: &KmsKeyRef, plaintext: &[u8]) -> Result<SealedKey, KmsError> {
        let keys = self.keys.lock().expect("InMemoryKms mutex poisoned");
        let key = keys
            .get(key_ref)
            .ok_or_else(|| KmsError::KeyNotFound(key_ref.as_str().to_string()))?;
        // NOTE(review): `Malformed` displays as "malformed ciphertext", which is
        // misleading for a bad *input*; kept for compatibility with callers.
        let mut bytes: [u8; KEY_LENGTH] = plaintext.try_into().map_err(|_| KmsError::Malformed)?;
        let sealed = SealedKey::new(WrappedKey::new(key, &bytes).to_bytes().to_vec());
        // Wipe the stack copy of the secret once it has been wrapped.
        bytes.fill(0);
        Ok(sealed)
    }

    /// Unwraps a blob sealed under the root key named by `key_ref`.
    ///
    /// # Errors
    /// - [`KmsError::KeyNotFound`] if `key_ref` was never provisioned.
    /// - [`KmsError::Malformed`] if `sealed` is not exactly `WRAPPED_KEY_LENGTH` bytes.
    /// - Any error from `WrappedKey::unwrap_key`, converted via `?`
    ///   (presumably [`KmsError::Crypto`]).
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    fn open(&self, key_ref: &KmsKeyRef, sealed: &SealedKey) -> Result<Vec<u8>, KmsError> {
        let keys = self.keys.lock().expect("InMemoryKms mutex poisoned");
        let key = keys
            .get(key_ref)
            .ok_or_else(|| KmsError::KeyNotFound(key_ref.as_str().to_string()))?;
        let bytes: [u8; crate::encryption::WRAPPED_KEY_LENGTH] = sealed
            .as_bytes()
            .try_into()
            .map_err(|_| KmsError::Malformed)?;
        let wrapped = WrappedKey::from_bytes(&bytes);
        // NOTE(review): the intermediate value returned by `unwrap_key` is not
        // wiped here; its type is defined elsewhere — confirm it zeroes on drop.
        let plaintext = wrapped.unwrap_key(key)?;
        Ok(plaintext.to_vec())
    }

    /// Returns `"in-memory-kms"`.
    fn provider_name(&self) -> &'static str {
        "in-memory-kms"
    }
}
/// Currently empty; presumably reserved for an AWS KMS-backed `KmsProvider` — TODO confirm.
pub mod aws_kms_integration {}
/// Currently empty; presumably reserved for a Google Cloud KMS-backed `KmsProvider` — TODO confirm.
pub mod gcp_kms_integration {}
/// Currently empty; presumably reserved for an Azure Key Vault-backed `KmsProvider` — TODO confirm.
pub mod azure_key_vault_integration {}
#[cfg(test)]
mod tests {
    // Unit tests for the KMS abstraction, exercised through `InMemoryKms`.
    use super::*;

    /// Builds an in-memory KMS with a single provisioned root key.
    fn fresh_kms() -> (InMemoryKms, KmsKeyRef) {
        let kms = InMemoryKms::new();
        let key_ref = KmsKeyRef::new("test-root-key");
        kms.create_key(key_ref.clone()).unwrap();
        (kms, key_ref)
    }

    #[test]
    fn seal_then_open_round_trips() {
        let (kms, key_ref) = fresh_kms();
        let plaintext = [0x42u8; KEY_LENGTH];
        let sealed = kms.seal(&key_ref, &plaintext).unwrap();
        let recovered = kms.open(&key_ref, &sealed).unwrap();
        assert_eq!(recovered.as_slice(), plaintext.as_slice());
    }

    #[test]
    fn open_with_unknown_key_ref_fails() {
        let (kms, _) = fresh_kms();
        let sealed = SealedKey::new(vec![0u8; crate::encryption::WRAPPED_KEY_LENGTH]);
        let other = KmsKeyRef::new("unprovisioned");
        let err = kms.open(&other, &sealed).unwrap_err();
        assert!(matches!(err, KmsError::KeyNotFound(_)));
    }

    #[test]
    fn seal_with_unknown_key_ref_fails() {
        let (kms, _) = fresh_kms();
        let other = KmsKeyRef::new("unprovisioned");
        let err = kms.seal(&other, &[0u8; KEY_LENGTH]).unwrap_err();
        assert!(matches!(err, KmsError::KeyNotFound(_)));
    }

    #[test]
    fn seal_rejects_wrong_length_plaintext() {
        let (kms, key_ref) = fresh_kms();
        let err = kms.seal(&key_ref, &[0u8; KEY_LENGTH + 1]).unwrap_err();
        assert!(matches!(err, KmsError::Malformed));
    }

    #[test]
    fn open_rejects_wrong_length_sealed_key() {
        let (kms, key_ref) = fresh_kms();
        let sealed = SealedKey::new(vec![0u8; crate::encryption::WRAPPED_KEY_LENGTH + 1]);
        let err = kms.open(&key_ref, &sealed).unwrap_err();
        assert!(matches!(err, KmsError::Malformed));
    }

    #[test]
    fn create_key_twice_rejects_duplicate() {
        let (kms, key_ref) = fresh_kms();
        let err = kms.create_key(key_ref).unwrap_err();
        assert!(matches!(err, KmsError::OperationDenied(_)));
    }

    #[test]
    fn kms_master_key_seals_and_restores_tenant_kek() {
        let kms = InMemoryKms::new();
        let key_ref = KmsKeyRef::new("tenant-root");
        kms.create_key(key_ref.clone()).unwrap();
        let master = KmsMasterKey::new(kms, key_ref);
        let (kek, sealed) = master.generate_sealed_kek().unwrap();
        let kek_bytes = kek.to_bytes();
        drop(kek);
        let restored = master.restore_sealed_kek(&sealed).unwrap();
        assert_eq!(restored.to_bytes(), kek_bytes);
        assert_eq!(master.provider_name(), "in-memory-kms");
    }

    /// Provider that lets several `KmsMasterKey`s share one `InMemoryKms`, so a
    /// KEK can be rotated between two root keys held by the same backend.
    struct ArcKms(std::sync::Arc<InMemoryKms>);

    impl KmsProvider for ArcKms {
        fn seal(&self, key_ref: &KmsKeyRef, plaintext: &[u8]) -> Result<SealedKey, KmsError> {
            self.0.seal(key_ref, plaintext)
        }

        fn open(&self, key_ref: &KmsKeyRef, sealed: &SealedKey) -> Result<Vec<u8>, KmsError> {
            self.0.open(key_ref, sealed)
        }

        fn provider_name(&self) -> &'static str {
            "arc-in-memory-kms"
        }
    }

    #[test]
    fn kek_rotation_to_new_root_re_seals() {
        let kms = std::sync::Arc::new(InMemoryKms::new());
        let old_ref = KmsKeyRef::new("root-old");
        let new_ref = KmsKeyRef::new("root-new");
        kms.create_key(old_ref.clone()).unwrap();
        kms.create_key(new_ref.clone()).unwrap();
        let old_master = KmsMasterKey::new(ArcKms(kms.clone()), old_ref);
        let new_master = KmsMasterKey::new(ArcKms(kms.clone()), new_ref);
        let (kek, sealed_old) = old_master.generate_sealed_kek().unwrap();
        let kek_bytes = kek.to_bytes();
        drop(kek);
        let sealed_new = old_master.rotate_kek(&sealed_old, &new_master).unwrap();
        let restored = new_master.restore_sealed_kek(&sealed_new).unwrap();
        assert_eq!(restored.to_bytes(), kek_bytes);
    }

    #[test]
    #[should_panic(expected = "KmsKeyRef cannot be empty")]
    fn empty_key_ref_panics() {
        let _ = KmsKeyRef::new("");
    }

    #[test]
    #[should_panic(expected = "SealedKey cannot be empty")]
    fn empty_sealed_key_panics() {
        let _ = SealedKey::new(vec![]);
    }
}