use tracing::trace;
use crate::core::cipher::Cipher;
use crate::error::{CipherError, Result};
/// Cipher backend that talks to AWS KMS.
///
/// Only compiled when the `aws` feature is enabled.
#[cfg(feature = "aws")]
#[allow(dead_code)]
pub struct AwsKms {
    /// Identifier sent as the `key_id` of KMS encrypt requests.
    key_id: String,
}
#[cfg(feature = "aws")]
impl AwsKms {
    /// Creates a backend bound to the given KMS key identifier.
    ///
    /// The identifier is stored as-is; no validation or network access
    /// happens here.
    #[allow(dead_code)]
    pub fn new(key_id: String) -> Self {
        AwsKms { key_id }
    }
}
/// [`Cipher`] implementation backed by AWS KMS.
///
/// Every call builds a fresh single-threaded Tokio runtime and a fresh KMS
/// client from the default AWS configuration, then blocks on one request.
///
/// # Panics
///
/// Per Tokio's documentation, `Runtime::block_on` panics when invoked from
/// within an asynchronous execution context, so `encrypt` / `decrypt` are
/// presumably unsafe to call directly from async code.
/// NOTE(review): confirm callers only use this from synchronous contexts
/// (or via `spawn_blocking`).
#[cfg(feature = "aws")]
impl Cipher for AwsKms {
    /// Unused by this backend: `encrypt` ignores its recipient list.
    type Recipient = String;
    /// No local identity is needed; `decrypt` ignores this argument.
    type Identity = ();

    /// Returns the static backend name, `"aws-kms"`.
    fn name(&self) -> &'static str {
        "aws-kms"
    }

    /// Encrypts `plaintext` with KMS under this backend's key id.
    ///
    /// The ciphertext blob returned by KMS is encoded with the standard
    /// base64 alphabet. `_recipients` is ignored.
    ///
    /// # Errors
    ///
    /// Returns `CipherError::EncryptionFailed` when the runtime cannot be
    /// created, when the KMS request fails, or when the response carries no
    /// ciphertext blob.
    fn encrypt(&self, plaintext: &str, _recipients: &[String]) -> Result<String> {
        use ::base64::Engine;

        trace!(
            key_id = %self.key_id,
            plaintext_len = plaintext.len(),
            "encrypting with AWS KMS"
        );

        let rt = Self::build_runtime().map_err(|e| {
            CipherError::EncryptionFailed(format!("failed to create runtime: {}", e))
        })?;

        rt.block_on(async {
            let client = Self::default_client().await;
            let result = client
                .encrypt()
                .key_id(&self.key_id)
                .plaintext(aws_sdk_kms::primitives::Blob::new(plaintext.as_bytes()))
                .send()
                .await
                .map_err(|e| {
                    CipherError::EncryptionFailed(format!("KMS encrypt failed: {}", e))
                })?;

            let blob = result
                .ciphertext_blob()
                .ok_or_else(|| CipherError::EncryptionFailed("no ciphertext returned".into()))?;

            let encoded = ::base64::engine::general_purpose::STANDARD.encode(blob.as_ref());
            trace!(ciphertext_len = encoded.len(), "encrypted with AWS KMS");
            Ok(encoded)
        })
    }

    /// Decrypts a base64-encoded KMS ciphertext into a UTF-8 string.
    ///
    /// The input is base64-decoded before any runtime or network work is
    /// started, so malformed input fails fast. `_identity` is ignored.
    ///
    /// # Errors
    ///
    /// Returns `CipherError::DecryptionFailed` when the input is not valid
    /// standard base64, when the runtime cannot be created, when the KMS
    /// request fails, when the response carries no plaintext, or when the
    /// plaintext is not valid UTF-8.
    fn decrypt(&self, ciphertext: &str, _identity: &()) -> Result<String> {
        use ::base64::Engine;

        trace!(ciphertext_len = ciphertext.len(), "decrypting with AWS KMS");

        let blob = ::base64::engine::general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|e| CipherError::DecryptionFailed(format!("invalid base64: {}", e)))?;

        let rt = Self::build_runtime().map_err(|e| {
            CipherError::DecryptionFailed(format!("failed to create runtime: {}", e))
        })?;

        rt.block_on(async {
            let client = Self::default_client().await;
            // NOTE(review): the request does not set a key id, so the
            // configured `key_id` is not enforced on decrypt. The KMS Decrypt
            // reference says the key id is only required for asymmetric keys
            // but recommended in general; confirm omitting it is intended.
            let result = client
                .decrypt()
                .ciphertext_blob(aws_sdk_kms::primitives::Blob::new(blob))
                .send()
                .await
                .map_err(|e| {
                    CipherError::DecryptionFailed(format!("KMS decrypt failed: {}", e))
                })?;

            let plaintext_blob = result
                .plaintext()
                .ok_or_else(|| CipherError::DecryptionFailed("no plaintext returned".into()))?;

            let plaintext = String::from_utf8(plaintext_blob.as_ref().to_vec())
                .map_err(|e| CipherError::DecryptionFailed(format!("UTF-8 error: {}", e)))?;

            trace!(plaintext_len = plaintext.len(), "decrypted with AWS KMS");
            Ok(plaintext)
        })
    }
}

/// Private plumbing shared by `encrypt` and `decrypt`.
#[cfg(feature = "aws")]
impl AwsKms {
    /// Builds the single-threaded Tokio runtime that drives one KMS request.
    fn build_runtime() -> std::io::Result<tokio::runtime::Runtime> {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
    }

    /// Creates a KMS client from the default AWS configuration.
    async fn default_client() -> aws_sdk_kms::Client {
        let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
        aws_sdk_kms::Client::new(&config)
    }
}