use std::sync::Arc;
use aws_config::{BehaviorVersion, Region};
use aws_sdk_kms::{
primitives::Blob,
types::{DataKeySpec, KeySpec},
Client as KmsClient,
};
use bytes::Bytes;
use tokio::runtime::{Handle, Runtime};
use arkhe_forge_core::event::RuntimeSignatureClass;
use arkhe_forge_core::pii::DekId;
use crate::crypto::Dek;
use crate::crypto_erasure::DekShredAttestation;
use super::kms_backend::{KekRef, KeyDeletionAttestation, KmsBackend, KmsError};
/// AWS KMS implementation of [`KmsBackend`].
///
/// The `KmsBackend` methods implemented in this file are synchronous, so every
/// SDK call is driven to completion on `executor` before returning.
pub struct AwsKmsBackend {
    /// SDK client used for every KMS request issued by this backend.
    client: KmsClient,
    /// Runtime (owned or borrowed) used to block on SDK futures.
    // Fields drop in declaration order: the client is dropped before the executor.
    executor: Executor,
}
/// Source of the tokio runtime that drives SDK futures for [`AwsKmsBackend`].
enum Executor {
    /// A runtime created by, and owned by, the backend (built in `AwsKmsBackend::new`).
    Owned(Arc<Runtime>),
    /// A handle to a runtime owned by the caller (supplied to `AwsKmsBackend::new_in_runtime`).
    Borrowed(Handle),
}
impl std::fmt::Debug for AwsKmsBackend {
    /// Prints only the type name followed by `{ .. }`.
    ///
    /// The client and executor are deliberately omitted so debug output never
    /// reveals SDK configuration or runtime details.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hand-written form of what `debug_struct("AwsKmsBackend").finish_non_exhaustive()`
        // emits for a struct with no listed fields.
        f.write_str("AwsKmsBackend { .. }")
    }
}
impl AwsKmsBackend {
pub fn new(region: impl Into<String>) -> Result<Self, KmsError> {
let runtime = Runtime::new()
.map_err(|e| KmsError::Backend(format!("tokio runtime init failed: {e}")))?;
let region = Region::new(region.into());
let client = runtime.block_on(async move {
let conf = aws_config::defaults(BehaviorVersion::latest())
.region(region)
.load()
.await;
KmsClient::new(&conf)
});
Ok(Self {
client,
executor: Executor::Owned(Arc::new(runtime)),
})
}
pub fn new_in_runtime(handle: Handle, client: KmsClient) -> Self {
Self {
client,
executor: Executor::Borrowed(handle),
}
}
fn block_on<F, T>(&self, fut: F) -> T
where
F: std::future::Future<Output = T>,
{
match &self.executor {
Executor::Owned(rt) => rt.block_on(fut),
Executor::Borrowed(handle) => {
if Handle::try_current().is_ok() {
tokio::task::block_in_place(|| handle.block_on(fut))
} else {
handle.block_on(fut)
}
}
}
}
}
/// Wraps any displayable AWS SDK error in [`KmsError::Backend`].
///
/// The message is the error's `Display` text prefixed with `aws-sdk-kms: `.
fn map_sdk_error<E>(e: E) -> KmsError
where
    E: std::fmt::Display,
{
    let detail = e.to_string();
    KmsError::Backend(format!("aws-sdk-kms: {detail}"))
}
impl KmsBackend for AwsKmsBackend {
    /// Requests a fresh AES-256 data key from KMS under `kek_ref`.
    ///
    /// Returns the plaintext material as a [`Dek`] together with a [`DekId`].
    /// The id is the first 16 bytes of a BLAKE3 keyed hash of the KMS ciphertext
    /// blob, keyed by `blake3::hash(b"arkhe-forge-aws-kms-dek-id")`. The ciphertext
    /// blob itself is not returned.
    ///
    /// # Errors
    /// - `KmsError::Backend` if the SDK call fails, or if KMS returns key
    ///   material that is not exactly 32 bytes.
    /// - `KmsError::UnwrapFailed` if the response lacks the plaintext or the
    ///   ciphertext blob.
    fn generate_dek(&self, kek_ref: &KekRef) -> Result<(DekId, Dek), KmsError> {
        self.block_on(async {
            let out = self
                .client
                .generate_data_key()
                .key_id(kek_ref.as_str())
                .key_spec(DataKeySpec::Aes256)
                .send()
                .await
                .map_err(map_sdk_error)?;
            let plaintext = out.plaintext.ok_or(KmsError::UnwrapFailed)?;
            let plaintext_bytes = plaintext.into_inner();
            if plaintext_bytes.len() != 32 {
                return Err(KmsError::Backend(format!(
                    "unexpected DEK length: {}",
                    plaintext_bytes.len()
                )));
            }
            let mut material = [0u8; 32];
            material.copy_from_slice(&plaintext_bytes);
            // NOTE(review): `plaintext_bytes` (and the stack copy in `material`) are
            // released without being wiped. Confirm whether key material should be
            // zeroized here as well as inside `Dek`.
            let ciphertext = out.ciphertext_blob.ok_or(KmsError::UnwrapFailed)?;
            let id_digest = blake3::keyed_hash(
                blake3::hash(b"arkhe-forge-aws-kms-dek-id").as_bytes(),
                ciphertext.as_ref(),
            );
            let mut id_bytes = [0u8; 16];
            id_bytes.copy_from_slice(&id_digest.as_bytes()[..16]);
            Ok((DekId(id_bytes), Dek::from_bytes(material)))
        })
    }

    /// Encrypts the raw DEK bytes under `kek_ref` with KMS `Encrypt` and returns
    /// the ciphertext blob.
    ///
    /// # Errors
    /// - `KmsError::Backend` if the SDK call fails.
    /// - `KmsError::UnwrapFailed` if the response has no ciphertext blob.
    fn wrap_dek(&self, dek: &Dek, kek_ref: &KekRef) -> Result<Bytes, KmsError> {
        self.block_on(async {
            let out = self
                .client
                .encrypt()
                .key_id(kek_ref.as_str())
                .plaintext(Blob::new(dek.as_bytes().to_vec()))
                .send()
                .await
                .map_err(map_sdk_error)?;
            let blob = out.ciphertext_blob.ok_or(KmsError::UnwrapFailed)?;
            // Take ownership of the blob's buffer instead of copying it.
            Ok(Bytes::from(blob.into_inner()))
        })
    }

    /// Decrypts `wrapped` with KMS `Decrypt` under `kek_ref` and rebuilds the DEK.
    ///
    /// # Errors
    /// - `KmsError::Backend` if the SDK call fails.
    /// - `KmsError::UnwrapFailed` if the response has no plaintext, or the
    ///   plaintext is not exactly 32 bytes.
    fn unwrap_dek(&self, wrapped: &[u8], kek_ref: &KekRef) -> Result<Dek, KmsError> {
        self.block_on(async {
            let out = self
                .client
                .decrypt()
                .key_id(kek_ref.as_str())
                .ciphertext_blob(Blob::new(wrapped.to_vec()))
                .send()
                .await
                .map_err(map_sdk_error)?;
            let plaintext = out.plaintext.ok_or(KmsError::UnwrapFailed)?;
            let bytes = plaintext.into_inner();
            if bytes.len() != 32 {
                return Err(KmsError::UnwrapFailed);
            }
            let mut material = [0u8; 32];
            material.copy_from_slice(&bytes);
            Ok(Dek::from_bytes(material))
        })
    }

    /// Schedules deletion of `kek_ref` with a 7-day pending window and returns
    /// an attestation.
    ///
    /// The attestation bytes are
    /// `BLAKE3("arkhe-forge-aws-kms-delete-attestation" || key_id || deletion_secs_le)`.
    /// Missing response fields contribute an empty key id or a timestamp of 0.
    /// `log_index` carries the deletion date in Unix seconds. It is `None` when
    /// KMS omits the date or reports one before the epoch.
    ///
    /// # Errors
    /// Returns `KmsError::Backend` if the SDK call fails.
    fn delete_key(&self, kek_ref: &KekRef) -> Result<KeyDeletionAttestation, KmsError> {
        self.block_on(async {
            let out = self
                .client
                .schedule_key_deletion()
                .key_id(kek_ref.as_str())
                // Presumably chosen as the shortest pending window KMS accepts.
                .pending_window_in_days(7)
                .send()
                .await
                .map_err(map_sdk_error)?;
            let key_id = out.key_id.unwrap_or_default();
            let deletion_secs = out.deletion_date.map(|d| d.secs());
            let mut h = blake3::Hasher::new();
            h.update(b"arkhe-forge-aws-kms-delete-attestation");
            h.update(key_id.as_bytes());
            // A missing date still hashes as 0 so the payload format is unchanged.
            h.update(&deletion_secs.unwrap_or_default().to_le_bytes());
            let payload: [u8; 32] = *h.finalize().as_bytes();
            Ok(DekShredAttestation {
                // NOTE(review): the payload is a BLAKE3 digest, not an Ed25519
                // signature. Confirm whether a better-fitting class exists.
                attestation_class: RuntimeSignatureClass::Ed25519,
                attestation_bytes: Bytes::copy_from_slice(&payload),
                // Do not fabricate `Some(0)` for a missing date, and do not let
                // an `as` cast wrap a negative timestamp into a huge index.
                log_index: deletion_secs.and_then(|secs| u64::try_from(secs).ok()),
            })
        })
    }

    /// Checks that both `old` and `new` resolve to KMS keys via `DescribeKey`.
    ///
    /// No DEKs are re-wrapped and no key state is changed here.
    ///
    /// NOTE(review): DescribeKey presumably also succeeds for disabled or
    /// pending-deletion keys. Confirm whether the key state should be checked.
    ///
    /// # Errors
    /// - `KmsError::KekNotFound` if KMS reports the key as not found or the
    ///   reference as an invalid ARN.
    /// - `KmsError::Backend` for any other failure (network, throttling,
    ///   access denied, ...).
    fn rotate_kek(&self, old: &KekRef, new: &KekRef) -> Result<(), KmsError> {
        self.block_on(async {
            for k in [old, new] {
                self.client
                    .describe_key()
                    .key_id(k.as_str())
                    .send()
                    .await
                    .map_err(|e| {
                        // Only report "not found" when KMS actually said so. A
                        // transient or permission error must not look like a
                        // missing key.
                        let unresolved = matches!(
                            e.as_service_error(),
                            Some(se) if se.is_not_found_exception() || se.is_invalid_arn_exception()
                        );
                        if unresolved {
                            KmsError::KekNotFound(format!("{}: {e}", k.as_str()))
                        } else {
                            map_sdk_error(e)
                        }
                    })?;
            }
            Ok(())
        })
    }
}
// NOTE(review): this constant is not referenced anywhere in the visible code. It
// appears to exist only so the `KeySpec` import counts as used. Confirm, and
// consider removing both together.
#[allow(dead_code)]
const _KEY_SPEC: KeySpec = KeySpec::SymmetricDefault;
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Loads a `us-east-1` KMS client on the supplied runtime.
    fn us_east_client(runtime: &Runtime) -> KmsClient {
        runtime.block_on(async {
            let shared = aws_config::defaults(BehaviorVersion::latest())
                .region(Region::new("us-east-1"))
                .load()
                .await;
            KmsClient::new(&shared)
        })
    }

    #[test]
    fn constructor_owns_runtime() {
        let owned = AwsKmsBackend::new("us-east-1").expect("runtime built");
        assert!(matches!(owned.executor, Executor::Owned(_)));
    }

    #[test]
    fn constructor_accepts_explicit_handle() {
        let runtime = Runtime::new().unwrap();
        let borrowed_handle = runtime.handle().clone();
        let kms = us_east_client(&runtime);
        let borrowed = AwsKmsBackend::new_in_runtime(borrowed_handle, kms);
        assert!(matches!(borrowed.executor, Executor::Borrowed(_)));
        // Release the backend (and its handle) before tearing down the runtime.
        drop(borrowed);
        drop(runtime);
    }

    #[test]
    fn debug_does_not_leak_internals() {
        let owned = AwsKmsBackend::new("us-east-1").expect("runtime built");
        let rendered = format!("{:?}", owned);
        assert!(rendered.contains("AwsKmsBackend"));
        for internal in ["Owned", "Borrowed"] {
            assert!(!rendered.contains(internal));
        }
    }

    #[test]
    fn map_sdk_error_wraps_display_in_backend_variant() {
        match map_sdk_error("boom") {
            KmsError::Backend(msg) => assert!(msg.contains("boom")),
            _ => panic!("expected KmsError::Backend"),
        }
    }
}