use base64::{engine::general_purpose::STANDARD as B64, Engine};
use zeroize::Zeroizing;
use crate::error::DataError;
use crate::kms::{DataKey, KeyAlias, WrappedDataKey};
/// `KeyProvider` implementation backed by AWS Key Management Service.
///
/// Data keys are generated with the KMS `GenerateDataKey` operation and unwrapped with the
/// KMS `Decrypt` operation, both issued through the wrapped SDK client.
pub struct AwsKmsKeyProvider {
    /// AWS SDK KMS client used for every request; each request future owns its own clone.
    client: aws_sdk_kms::Client,
}
impl AwsKmsKeyProvider {
    /// Creates a provider from the default AWS configuration chain.
    ///
    /// Region, credentials and endpoint are resolved by `aws_config::defaults` using the
    /// latest behavior version.
    pub async fn new() -> Self {
        let sdk_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
            .load()
            .await;
        Self::with_client(aws_sdk_kms::Client::new(&sdk_config))
    }

    /// Creates a provider that sends KMS requests to `endpoint_url`.
    ///
    /// The region is fixed to `us-east-1` and static placeholder credentials are installed.
    /// NOTE(review): the hard-coded `test-*` credentials suggest this is meant for local
    /// KMS emulators / tests rather than real AWS accounts — confirm with callers.
    pub async fn with_endpoint(endpoint_url: impl Into<String>) -> Self {
        let placeholder_credentials = aws_sdk_kms::config::Credentials::new(
            "test-key-id",
            "test-secret-key",
            None,
            None,
            "test",
        );
        let sdk_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
            .endpoint_url(endpoint_url.into())
            .region(aws_config::Region::new("us-east-1"))
            .credentials_provider(placeholder_credentials)
            .load()
            .await;
        Self::with_client(aws_sdk_kms::Client::new(&sdk_config))
    }

    /// Wraps an already-configured KMS client without any further setup.
    #[must_use]
    pub fn with_client(client: aws_sdk_kms::Client) -> Self {
        Self { client }
    }
}
// Marker impl of the crate-private `Sealed` trait. NOTE(review): presumably `Sealed` is a
// supertrait of `KeyProvider` that restricts implementations to this crate — confirm in
// `crate::kms`.
impl crate::kms::private::Sealed for AwsKmsKeyProvider {}
impl crate::kms::KeyProvider for AwsKmsKeyProvider {
    /// Requests a new AES-256 data key from KMS `GenerateDataKey` under `alias`.
    ///
    /// The alias's `Display` form is sent as the KMS `KeyId`. The future resolves to
    /// `(plaintext, wrapped, version)`:
    /// - `plaintext`: the data key, owned by a `Zeroizing` buffer;
    /// - `wrapped`: the KMS `CiphertextBlob` for that key;
    /// - `version`: always `"1"`. This provider does not derive a version from KMS.
    ///
    /// # Errors
    /// - SDK/transport/service failures are converted by `map_kms_error`.
    /// - `DataError::ProviderUnavailable` if the response lacks `Plaintext` or
    ///   `CiphertextBlob`.
    fn generate_data_key(
        &self,
        alias: &KeyAlias,
    ) -> impl std::future::Future<Output = Result<(DataKey, WrappedDataKey, String), DataError>> + Send
    {
        // Clone/own everything the future needs so it does not borrow `self` or `alias`.
        let client = self.client.clone();
        let key_id = alias.to_string();
        async move {
            let resp = client
                .generate_data_key()
                .key_id(&key_id)
                .key_spec(aws_sdk_kms::types::DataKeySpec::Aes256)
                .send()
                .await
                .map_err(|e| map_kms_error(e.into()))?;

            // Move the plaintext `Blob` out of the response and hand its buffer directly to
            // `Zeroizing`, so the key bytes are wiped on drop. Copying via
            // `as_ref().to_vec()` would leave the original bytes inside `resp`, which is
            // freed without zeroing. (This reduces copies; it cannot scrub buffers the SDK
            // used while decoding the response.)
            let dek = Zeroizing::new(
                resp.plaintext
                    .ok_or_else(|| DataError::ProviderUnavailable {
                        provider: "aws-kms".to_string(),
                        reason: "GenerateDataKey response missing Plaintext".to_string(),
                    })?
                    .into_inner(),
            );
            let wrapped = resp
                .ciphertext_blob
                .ok_or_else(|| DataError::ProviderUnavailable {
                    provider: "aws-kms".to_string(),
                    reason: "GenerateDataKey response missing CiphertextBlob".to_string(),
                })?
                .into_inner();
            // Fixed placeholder version; no KMS key version/ID is recorded here.
            let version = "1".to_string();
            Ok((dek, wrapped, version))
        }
    }

    /// Unwraps `wrapped` with KMS `Decrypt` and returns the plaintext data key.
    ///
    /// `_alias` and `_version` are not used: the request sets no `KeyId`. Per the KMS API
    /// reference, `KeyId` is optional for symmetric ciphertexts because the ciphertext blob
    /// carries the key metadata.
    /// NOTE(review): pinning `KeyId` would reject blobs wrapped under an unexpected key, but
    /// would also break decryption after an alias is retargeted to a different key.
    ///
    /// # Errors
    /// - SDK/transport/service failures are converted by `map_kms_error`.
    /// - `DataError::ProviderUnavailable` if the response lacks `Plaintext`.
    fn unwrap_data_key(
        &self,
        wrapped: &WrappedDataKey,
        _alias: &KeyAlias,
        _version: &str,
    ) -> impl std::future::Future<Output = Result<DataKey, DataError>> + Send {
        let client = self.client.clone();
        let ciphertext_blob = aws_sdk_kms::primitives::Blob::new(wrapped.clone());
        async move {
            let resp = client
                .decrypt()
                .ciphertext_blob(ciphertext_blob)
                .send()
                .await
                .map_err(|e| map_kms_error(e.into()))?;
            // Move (not copy) the plaintext buffer into `Zeroizing`; see `generate_data_key`.
            let plaintext = resp
                .plaintext
                .ok_or_else(|| DataError::ProviderUnavailable {
                    provider: "aws-kms".to_string(),
                    reason: "Decrypt response missing Plaintext".to_string(),
                })?;
            Ok(Zeroizing::new(plaintext.into_inner()))
        }
    }
}
/// Maps an AWS SDK error to a provider-level `DataError` for the `aws-kms` provider.
///
/// The top-level `Display` of an AWS `SdkError` is only a category label ("service error",
/// "dispatch failure", ...). The service error code (e.g. `AccessDeniedException`) is
/// reported by errors further down the `source()` chain. The message is therefore built from
/// the whole chain, outermost first and joined with `": "`. Code matching then sees the real
/// cause, and `reason` carries it to the caller.
///
/// Returns `DataError::ProviderAuthError` when the message contains a known
/// authentication/authorization error code, otherwise `DataError::ProviderUnavailable`.
fn map_kms_error(e: Box<dyn std::error::Error + Send + Sync>) -> DataError {
    // Codes indicating bad, expired or insufficient credentials rather than an outage.
    const AUTH_ERROR_CODES: &[&str] = &[
        "AccessDeniedException",
        "InvalidSignatureException",
        "AuthFailure",
        "UnauthorizedOperation",
        "UnrecognizedClientException",
        "ExpiredTokenException",
    ];

    let msg = error_chain_message(&*e);
    if AUTH_ERROR_CODES.iter().any(|code| msg.contains(code)) {
        return DataError::ProviderAuthError {
            provider: "aws-kms".to_string(),
            reason: msg,
        };
    }
    DataError::ProviderUnavailable {
        provider: "aws-kms".to_string(),
        reason: msg,
    }
}

/// Renders `err` followed by every error in its `source()` chain, joined with `": "`.
fn error_chain_message(err: &(dyn std::error::Error + 'static)) -> String {
    std::iter::successors(Some(err), |cause| cause.source())
        .map(|cause| cause.to_string())
        .collect::<Vec<_>>()
        .join(": ")
}
/// Decodes `value` as standard-alphabet base64 using the `B64` engine.
///
/// A decode failure becomes `DataError::ProviderUnavailable` for the `aws-kms` provider,
/// with the decoder's message included in `reason`.
// NOTE(review): nothing in this file calls this helper, hence the `dead_code` allowance;
// confirm whether it is still needed or can be removed.
#[allow(dead_code)]
fn decode_b64(value: &str) -> Result<Vec<u8>, DataError> {
    match B64.decode(value) {
        Ok(bytes) => Ok(bytes),
        Err(err) => Err(DataError::ProviderUnavailable {
            provider: "aws-kms".to_string(),
            reason: format!("base64 decode error: {err}"),
        }),
    }
}