use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Mutex;
use arkhe_forge_core::event::RuntimeSignatureClass;
use arkhe_forge_core::pii::DekId;
use crate::crypto::Dek;
use crate::crypto_erasure::DekShredAttestation;
/// Opaque reference to a key-encryption key (KEK) managed by a [`KmsBackend`].
///
/// The identifier (for example a cloud KMS key ARN) is stored verbatim. Equality and
/// hashing are by identifier, so a `KekRef` can key a `HashMap`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct KekRef {
    /// Backend-specific KEK identifier, kept exactly as supplied.
    id: String,
}

impl KekRef {
    /// Builds a reference from any string-like identifier, without validation or normalisation.
    #[inline]
    #[must_use]
    pub fn new<S>(id: S) -> Self
    where
        S: Into<String>,
    {
        let id: String = id.into();
        Self { id }
    }

    /// Borrows the identifier exactly as it was passed to [`KekRef::new`].
    #[inline]
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.id.as_str()
    }
}
/// Attestation returned by [`KmsBackend::delete_key`] when a KEK is destroyed.
///
/// This is an alias of [`DekShredAttestation`], so KMS key deletion and DEK shredding share
/// one attestation shape.
pub type KeyDeletionAttestation = DekShredAttestation;
/// Errors surfaced by [`KmsBackend`] operations.
///
/// The enum is `#[non_exhaustive]`, so backends may add failure modes without breaking
/// callers that match on it.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum KmsError {
    /// Transport-level failure talking to the KMS; carries a description.
    #[error("KMS transport error: {0}")]
    Transport(String),

    /// Authentication/authorization failure reported by the KMS; carries a description.
    #[error("KMS auth error: {0}")]
    Auth(String),

    /// No live KEK exists for the given identifier (carried as the payload).
    #[error("KEK not found: {0}")]
    KekNotFound(String),

    /// A wrapped DEK could not be unwrapped (malformed input or wrong KEK).
    #[error("wrapped DEK unwrap failed")]
    UnwrapFailed,

    /// The KEK named in the payload is already scheduled for deletion.
    #[error("KEK deletion already scheduled: {0}")]
    KekDeleting(String),

    /// The KMS request quota was exceeded.
    #[error("KMS quota exceeded")]
    QuotaExceeded,

    /// Any other backend-specific failure; carries a description.
    #[error("KMS backend error: {0}")]
    Backend(String),
}
/// Key-management operations for envelope encryption: DEKs are wrapped under KEKs, and
/// destroying a KEK yields a deletion attestation.
///
/// Implementors must be `Send + Sync` so one backend instance can be shared across threads.
pub trait KmsBackend: Sync + Send {
    /// Creates a new data-encryption key (DEK) under `kek_ref`, returning its id and material.
    fn generate_dek(&self, kek_ref: &KekRef) -> Result<(DekId, Dek), KmsError>;

    /// Wraps `dek` under the KEK `kek_ref`, returning the opaque wrapped bytes.
    fn wrap_dek(&self, dek: &Dek, kek_ref: &KekRef) -> Result<Bytes, KmsError>;

    /// Recovers a DEK from bytes produced by [`KmsBackend::wrap_dek`] under `kek_ref`.
    fn unwrap_dek(&self, wrapped: &[u8], kek_ref: &KekRef) -> Result<Dek, KmsError>;

    /// Destroys the KEK `kek_ref` and returns an attestation of the deletion.
    fn delete_key(&self, kek_ref: &KekRef) -> Result<KeyDeletionAttestation, KmsError>;

    /// Rotates from KEK `old` to KEK `new`.
    ///
    /// The exact semantics are backend-defined; the in-file mock only records the pair.
    fn rotate_kek(&self, old: &KekRef, new: &KekRef) -> Result<(), KmsError>;
}
/// Deterministic in-memory [`KmsBackend`] for tests; it provides no real security.
///
/// All state sits behind one `Mutex`. Every method recovers the guard from a poisoned lock
/// instead of panicking.
#[derive(Default, Debug)]
pub struct MockKmsBackend {
    /// Shared mutable state, guarded by a single lock.
    inner: Mutex<MockState>,
}
/// Mutable state of [`MockKmsBackend`], guarded by its mutex.
#[derive(Default, Debug)]
struct MockState {
    /// Live KEKs, keyed by reference.
    keks: HashMap<KekRef, MockKek>,
    /// Cached attestations of destroyed KEKs, replayed when `delete_key` is retried.
    destroyed: HashMap<KekRef, KeyDeletionAttestation>,
    /// Backend-wide count of generated DEKs (saturating); mixed into each DEK derivation.
    counter: u64,
    /// `(old, new)` pairs appended by each successful `rotate_kek`.
    rotations: Vec<(KekRef, KekRef)>,
    /// Log index that the next destruction attestation will receive.
    destruction_log_index: u64,
}

/// Key material of a single mock KEK.
#[derive(Debug, Clone)]
struct MockKek {
    /// 32-byte XOR pad derived from the KEK identifier; also used as the wrap marker.
    pad: [u8; 32],
}
impl MockKmsBackend {
    /// Creates a backend with no registered KEKs.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `kek_ref` as a live KEK.
    ///
    /// The 32-byte pad is computed as follows:
    /// - it is a keyed BLAKE3 hash of the identifier;
    /// - the hash key is the BLAKE3 hash of a fixed domain tag.
    ///
    /// As a result, the same identifier always yields the same pad, even across backend
    /// instances.
    ///
    /// The call is a no-op in two cases:
    /// - `kek_ref` is already live;
    /// - `kek_ref` has already been destroyed by [`KmsBackend::delete_key`].
    ///
    /// The second case matters because the pad depends only on the identifier.
    /// Re-registering a destroyed KEK would resurrect the exact same key material. DEKs
    /// wrapped before the deletion would then unwrap again, undoing the crypto-erasure.
    pub fn register_kek(&self, kek_ref: &KekRef) {
        let mut state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        if state.keks.contains_key(kek_ref) || state.destroyed.contains_key(kek_ref) {
            return;
        }
        let digest = blake3::keyed_hash(
            blake3::hash(b"arkhe-forge-mock-kms-kek-pad").as_bytes(),
            kek_ref.as_str().as_bytes(),
        );
        state.keks.insert(
            kek_ref.clone(),
            MockKek {
                pad: *digest.as_bytes(),
            },
        );
    }

    /// Returns how many distinct KEKs have been destroyed via `delete_key`.
    #[must_use]
    pub fn destroyed_count(&self) -> usize {
        let state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        state.destroyed.len()
    }

    /// Returns how many rotations successful `rotate_kek` calls have recorded.
    #[must_use]
    pub fn rotation_log_len(&self) -> usize {
        let state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        state.rotations.len()
    }

    /// XORs `dst` in place with `pad`, repeating the 32-byte pad as often as needed.
    fn xor_pad(dst: &mut [u8], pad: &[u8; 32]) {
        for (byte, p) in dst.iter_mut().zip(pad.iter().cycle()) {
            *byte ^= *p;
        }
    }
}
impl KmsBackend for MockKmsBackend {
    /// Derives a fresh DEK and its id from `kek_ref` and the backend-wide counter.
    ///
    /// Steps:
    /// - the counter is incremented first (saturating);
    /// - the DEK material is a BLAKE3 hash of its own domain tag, the KEK identifier and
    ///   the counter;
    /// - the DEK id is a BLAKE3 hash of a different domain tag over the same inputs,
    ///   truncated to its first 16 bytes.
    ///
    /// Output is therefore deterministic for a given call sequence.
    ///
    /// # Errors
    ///
    /// [`KmsError::KekNotFound`] if `kek_ref` is not a live KEK.
    fn generate_dek(&self, kek_ref: &KekRef) -> Result<(DekId, Dek), KmsError> {
        let mut state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        if !state.keks.contains_key(kek_ref) {
            return Err(KmsError::KekNotFound(kek_ref.as_str().to_string()));
        }
        state.counter = state.counter.saturating_add(1);
        let counter = state.counter.to_le_bytes();

        let mut h = blake3::Hasher::new();
        h.update(b"arkhe-forge-mock-kms-dek-material");
        h.update(kek_ref.as_str().as_bytes());
        h.update(&counter);
        let dek_material: [u8; 32] = *h.finalize().as_bytes();

        let mut id_hasher = blake3::Hasher::new();
        id_hasher.update(b"arkhe-forge-mock-kms-dek-id");
        id_hasher.update(kek_ref.as_str().as_bytes());
        id_hasher.update(&counter);
        let dek_id_full: [u8; 32] = *id_hasher.finalize().as_bytes();
        let mut dek_id_bytes = [0u8; 16];
        dek_id_bytes.copy_from_slice(&dek_id_full[..16]);

        Ok((DekId(dek_id_bytes), Dek::from_bytes(dek_material)))
    }

    /// Wraps `dek` as `pad || (dek XOR pad)`.
    ///
    /// The leading bytes are the KEK pad in the clear. They act as the marker that
    /// `unwrap_dek` checks to reject the wrong KEK. The scheme offers no confidentiality; it
    /// only models the wrap/unwrap round trip.
    ///
    /// # Errors
    ///
    /// [`KmsError::KekNotFound`] if `kek_ref` is not a live KEK.
    fn wrap_dek(&self, dek: &Dek, kek_ref: &KekRef) -> Result<Bytes, KmsError> {
        let state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        let kek = state
            .keks
            .get(kek_ref)
            .ok_or_else(|| KmsError::KekNotFound(kek_ref.as_str().to_string()))?;
        let dek_bytes = dek.as_bytes();
        let mut buf = Vec::with_capacity(kek.pad.len() + dek_bytes.len());
        buf.extend_from_slice(&kek.pad);
        buf.extend_from_slice(dek_bytes);
        // Mask only the DEK half; the marker half must remain the raw pad.
        Self::xor_pad(&mut buf[kek.pad.len()..], &kek.pad);
        Ok(Bytes::from(buf))
    }

    /// Inverts `wrap_dek`: verifies the pad marker, then XORs the pad back out.
    ///
    /// # Errors
    ///
    /// - [`KmsError::UnwrapFailed`] if either:
    ///   - `wrapped` is not exactly 64 bytes (this is checked before the KEK lookup), or
    ///   - its marker does not match `kek_ref`'s pad.
    /// - [`KmsError::KekNotFound`] if `kek_ref` is not a live KEK (including destroyed KEKs).
    fn unwrap_dek(&self, wrapped: &[u8], kek_ref: &KekRef) -> Result<Dek, KmsError> {
        if wrapped.len() != 64 {
            return Err(KmsError::UnwrapFailed);
        }
        let state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        let kek = state
            .keks
            .get(kek_ref)
            .ok_or_else(|| KmsError::KekNotFound(kek_ref.as_str().to_string()))?;
        let (marker, masked) = wrapped.split_at(32);
        if marker != kek.pad {
            return Err(KmsError::UnwrapFailed);
        }
        let mut body = [0u8; 32];
        body.copy_from_slice(masked);
        Self::xor_pad(&mut body, &kek.pad);
        Ok(Dek::from_bytes(body))
    }

    /// Destroys `kek_ref` and returns a deletion attestation; retries are idempotent.
    ///
    /// The first successful call does three things:
    /// - removes the KEK;
    /// - assigns it the next destruction log index;
    /// - caches an attestation whose payload is a BLAKE3 hash over a domain tag, the
    ///   identifier and that index.
    ///
    /// Later calls for the same KEK return the cached attestation unchanged.
    ///
    /// # Errors
    ///
    /// [`KmsError::KekNotFound`] if `kek_ref` was never registered.
    fn delete_key(&self, kek_ref: &KekRef) -> Result<KeyDeletionAttestation, KmsError> {
        let mut state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        let cached = state.destroyed.get(kek_ref).cloned();
        if let Some(cached) = cached {
            // A destroyed KEK must never be live. If it was re-registered after the first
            // deletion, drop it again so the replayed attestation stays truthful.
            state.keks.remove(kek_ref);
            return Ok(cached);
        }
        if state.keks.remove(kek_ref).is_none() {
            return Err(KmsError::KekNotFound(kek_ref.as_str().to_string()));
        }
        let log_index = state.destruction_log_index;
        state.destruction_log_index = log_index.saturating_add(1);

        let mut h = blake3::Hasher::new();
        h.update(b"arkhe-forge-mock-kms-delete-attestation");
        h.update(kek_ref.as_str().as_bytes());
        h.update(&log_index.to_le_bytes());
        let payload: [u8; 32] = *h.finalize().as_bytes();

        let attestation = DekShredAttestation {
            attestation_class: RuntimeSignatureClass::Ed25519,
            attestation_bytes: Bytes::copy_from_slice(&payload),
            log_index: Some(log_index),
        };
        state.destroyed.insert(kek_ref.clone(), attestation.clone());
        Ok(attestation)
    }

    /// Records a rotation from `old` to `new` after checking that both KEKs are live.
    ///
    /// The mock re-wraps nothing; it only appends `(old, new)` to its rotation log.
    ///
    /// # Errors
    ///
    /// [`KmsError::KekNotFound`] naming the first of `old`, `new` that is not live.
    fn rotate_kek(&self, old: &KekRef, new: &KekRef) -> Result<(), KmsError> {
        let mut state = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        for kek in [old, new] {
            if !state.keks.contains_key(kek) {
                return Err(KmsError::KekNotFound(kek.as_str().to_string()));
            }
        }
        state.rotations.push((old.clone(), new.clone()));
        Ok(())
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a fresh mock backend with a single registered KEK named `id`.
    fn mock_with_kek(id: &str) -> (MockKmsBackend, KekRef) {
        let b = MockKmsBackend::new();
        let k = KekRef::new(id);
        b.register_kek(&k);
        (b, k)
    }

    #[test]
    fn generate_dek_without_kek_errors() {
        let b = MockKmsBackend::new();
        let k = KekRef::new("ghost");
        let err = b.generate_dek(&k).unwrap_err();
        assert!(matches!(err, KmsError::KekNotFound(_)));
    }

    #[test]
    fn generate_dek_roundtrips_wrap_unwrap() {
        let (b, k) = mock_with_kek("kek-1");
        let (dek_id, dek) = b.generate_dek(&k).unwrap();
        let wrapped = b.wrap_dek(&dek, &k).unwrap();
        assert_eq!(wrapped.len(), 64);
        let unwrapped = b.unwrap_dek(&wrapped, &k).unwrap();
        assert_eq!(dek.as_bytes(), unwrapped.as_bytes());
        assert_ne!(dek_id.0, [0u8; 16]);
    }

    #[test]
    fn successive_deks_differ() {
        let (b, k) = mock_with_kek("kek-seq");
        let (id1, dek1) = b.generate_dek(&k).unwrap();
        let (id2, dek2) = b.generate_dek(&k).unwrap();
        assert_ne!(id1, id2);
        assert_ne!(dek1.as_bytes(), dek2.as_bytes());
    }

    #[test]
    fn wrap_dek_without_kek_errors() {
        let (b, k) = mock_with_kek("kek-real");
        let (_id, dek) = b.generate_dek(&k).unwrap();
        let err = b.wrap_dek(&dek, &KekRef::new("ghost")).unwrap_err();
        assert!(matches!(err, KmsError::KekNotFound(_)));
    }

    #[test]
    fn unwrap_under_wrong_kek_fails() {
        let (b, k1) = mock_with_kek("kek-1");
        let k2 = KekRef::new("kek-2");
        b.register_kek(&k2);
        let (_id, dek) = b.generate_dek(&k1).unwrap();
        let wrapped = b.wrap_dek(&dek, &k1).unwrap();
        let err = b.unwrap_dek(&wrapped, &k2).unwrap_err();
        assert!(matches!(err, KmsError::UnwrapFailed));
    }

    #[test]
    fn unwrap_rejects_malformed_length() {
        let (b, k) = mock_with_kek("kek-len");
        for len in [0usize, 32, 63, 65] {
            let err = b.unwrap_dek(&vec![0u8; len], &k).unwrap_err();
            assert!(matches!(err, KmsError::UnwrapFailed), "len = {len}");
        }
    }

    #[test]
    fn unwrap_after_delete_fails() {
        let (b, k) = mock_with_kek("kek-shred");
        let (_id, dek) = b.generate_dek(&k).unwrap();
        let wrapped = b.wrap_dek(&dek, &k).unwrap();
        b.delete_key(&k).unwrap();
        let err = b.unwrap_dek(&wrapped, &k).unwrap_err();
        assert!(matches!(err, KmsError::KekNotFound(_)));
    }

    #[test]
    fn delete_key_idempotent_across_retries() {
        let (b, k) = mock_with_kek("kek-delete");
        let first = b.delete_key(&k).unwrap();
        let second = b.delete_key(&k).unwrap();
        assert_eq!(first, second);
        assert_eq!(b.destroyed_count(), 1);
        // The destroyed KEK can no longer be used.
        assert!(matches!(
            b.generate_dek(&k).unwrap_err(),
            KmsError::KekNotFound(_)
        ));
    }

    #[test]
    fn delete_key_assigns_sequential_log_indices() {
        let b = MockKmsBackend::new();
        let k1 = KekRef::new("kek-a");
        let k2 = KekRef::new("kek-b");
        b.register_kek(&k1);
        b.register_kek(&k2);
        assert_eq!(b.delete_key(&k1).unwrap().log_index, Some(0));
        assert_eq!(b.delete_key(&k2).unwrap().log_index, Some(1));
        assert_eq!(b.destroyed_count(), 2);
    }

    #[test]
    fn delete_key_unknown_errors() {
        let b = MockKmsBackend::new();
        let k = KekRef::new("never-existed");
        let err = b.delete_key(&k).unwrap_err();
        assert!(matches!(err, KmsError::KekNotFound(_)));
    }

    #[test]
    fn rotate_kek_records_log_entry() {
        let (b, k1) = mock_with_kek("kek-old");
        let k2 = KekRef::new("kek-new");
        b.register_kek(&k2);
        assert!(b.rotate_kek(&k1, &k2).is_ok());
        assert_eq!(b.rotation_log_len(), 1);
    }

    #[test]
    fn rotate_kek_unknown_old_errors() {
        let b = MockKmsBackend::new();
        let k1 = KekRef::new("never-1");
        let k2 = KekRef::new("never-2");
        b.register_kek(&k2);
        let err = b.rotate_kek(&k1, &k2).unwrap_err();
        assert!(matches!(&err, KmsError::KekNotFound(id) if id == "never-1"));
        assert_eq!(b.rotation_log_len(), 0);
    }

    #[test]
    fn rotate_kek_unknown_new_errors() {
        let (b, k1) = mock_with_kek("kek-old");
        let k2 = KekRef::new("never-2");
        let err = b.rotate_kek(&k1, &k2).unwrap_err();
        assert!(matches!(&err, KmsError::KekNotFound(id) if id == "never-2"));
        assert_eq!(b.rotation_log_len(), 0);
    }

    #[test]
    fn generate_dek_is_deterministic_under_fixed_kek_seed() {
        let b1 = MockKmsBackend::new();
        let b2 = MockKmsBackend::new();
        let k = KekRef::new("det-kek");
        b1.register_kek(&k);
        b2.register_kek(&k);
        let (id1, dek1) = b1.generate_dek(&k).unwrap();
        let (id2, dek2) = b2.generate_dek(&k).unwrap();
        assert_eq!(id1, id2);
        assert_eq!(dek1.as_bytes(), dek2.as_bytes());
    }

    #[test]
    fn kek_ref_preserves_identity() {
        let k = KekRef::new("arn:aws:kms:eu-central-1:123:key/abc");
        assert_eq!(k.as_str(), "arn:aws:kms:eu-central-1:123:key/abc");
        let cloned = k.clone();
        assert_eq!(k, cloned);
    }

    #[test]
    fn duplicate_register_kek_is_noop() {
        let (b, k) = mock_with_kek("kek-dup");
        let (_id, dek) = b.generate_dek(&k).unwrap();
        let wrapped = b.wrap_dek(&dek, &k).unwrap();
        b.register_kek(&k);
        b.register_kek(&k);
        // Material wrapped before the duplicate registrations must still unwrap.
        let unwrapped = b.unwrap_dek(&wrapped, &k).unwrap();
        assert_eq!(dek.as_bytes(), unwrapped.as_bytes());
        assert_eq!(b.destroyed_count(), 0);
    }
}