#![cfg(feature = "aws-kms")]
use std::time::Duration;
use testcontainers::runners::AsyncRunner;
use testcontainers::ContainerAsync;
use testcontainers_modules::localstack::LocalStack;
use tokio::time::sleep;
/// Exercises the `AwsKmsProvider` surface against a LocalStack KMS: health
/// check, data-key generation, data-key round-trip decryption, and
/// encrypt/decrypt of caller-supplied bytes.
#[tokio::test]
async fn test_aws_kms_provider_basic_operations() {
    use rustberg::crypto::kms::{AwsKmsProvider, KeyManagementService};

    let localstack: ContainerAsync<LocalStack> = LocalStack::default()
        .start()
        .await
        .expect("Failed to start LocalStack container");
    let host_port = localstack
        .get_host_port_ipv4(4566)
        .await
        .expect("Failed to get LocalStack port");
    let endpoint_url = format!("http://127.0.0.1:{}", host_port);
    // Fixed grace period before talking to LocalStack.
    sleep(Duration::from_secs(3)).await;

    // Create exactly one symmetric key through the raw KMS JSON API. The
    // status is checked before the body is parsed so that a LocalStack error
    // is reported together with its response body instead of as a missing
    // `KeyId`.
    let client = reqwest::Client::new();
    let create_key_response = client
        .post(format!("{}/", endpoint_url))
        .header("Content-Type", "application/x-amz-json-1.1")
        .header("X-Amz-Target", "TrentService.CreateKey")
        .body(r#"{"Description": "Test key for Rustberg", "KeyUsage": "ENCRYPT_DECRYPT", "KeySpec": "SYMMETRIC_DEFAULT"}"#)
        .send()
        .await
        .expect("Failed to create KMS key");
    let status = create_key_response.status();
    let body = create_key_response
        .text()
        .await
        .expect("Failed to read CreateKey response body");
    assert!(
        status.is_success(),
        "Failed to create KMS key: {} {}",
        status,
        body
    );
    let key_response: serde_json::Value =
        serde_json::from_str(&body).expect("Failed to parse key response");
    let key_id = key_response["KeyMetadata"]["KeyId"]
        .as_str()
        .expect("No KeyId in response");
    println!("Created KMS key: {}", key_id);

    // Dummy credentials for LocalStack. Every test in this file sets the same
    // values and deliberately does not remove them afterwards: the test
    // harness runs tests concurrently in one process, so removing them here
    // could strip credentials from another test whose AWS client may not have
    // resolved them yet.
    std::env::set_var("AWS_ACCESS_KEY_ID", "test");
    std::env::set_var("AWS_SECRET_ACCESS_KEY", "test");
    std::env::set_var("AWS_DEFAULT_REGION", "us-east-1");

    let provider = AwsKmsProvider::new("us-east-1", key_id, Some(endpoint_url.clone()))
        .await
        .expect("Failed to create AwsKmsProvider");
    provider.health_check().await.expect("Health check failed");

    // Generated data key: 32-byte plaintext plus a non-empty wrapped form.
    let dek = provider
        .generate_data_key(key_id)
        .await
        .expect("Failed to generate data key");
    assert_eq!(
        dek.plaintext.expose().len(),
        32,
        "DEK should be 32 bytes (256 bits)"
    );
    assert!(
        !dek.ciphertext.is_empty(),
        "Encrypted DEK should not be empty"
    );
    assert!(dek.version >= 1, "Version should be at least 1");

    // Unwrapping the generated key must give back the same plaintext.
    let decrypted = provider
        .decrypt_data_key(key_id, &dek.ciphertext)
        .await
        .expect("Failed to decrypt data key");
    assert_eq!(
        decrypted.expose(),
        dek.plaintext.expose(),
        "Decrypted key should match original"
    );

    // Caller-supplied bytes must survive an encrypt/decrypt round trip.
    let test_plaintext = b"test-data-for-encryption-aws";
    let encrypted = provider
        .encrypt_data_key(key_id, test_plaintext)
        .await
        .expect("Failed to encrypt data");
    assert!(!encrypted.is_empty(), "Encrypted data should not be empty");
    let decrypted_data = provider
        .decrypt_data_key(key_id, &encrypted)
        .await
        .expect("Failed to decrypt data");
    assert_eq!(
        decrypted_data.expose(),
        test_plaintext,
        "Decrypted data should match original"
    );

    println!("✅ All AWS KMS provider tests passed!");
}
/// Round-trips a payload through `EncryptedEnvelope` backed by a LocalStack
/// KMS key. Checks that both the ciphertext and the wrapped DEK are populated
/// and that decryption restores the original bytes.
#[tokio::test]
async fn test_aws_kms_envelope_encryption() {
    use rustberg::crypto::kms::{AwsKmsProvider, EncryptedEnvelope};

    let localstack: ContainerAsync<LocalStack> = LocalStack::default()
        .start()
        .await
        .expect("Failed to start LocalStack container");
    let host_port = localstack
        .get_host_port_ipv4(4566)
        .await
        .expect("Failed to get LocalStack port");
    let endpoint_url = format!("http://127.0.0.1:{}", host_port);
    // Fixed grace period before talking to LocalStack.
    sleep(Duration::from_secs(3)).await;

    // Check the CreateKey status before parsing so that a LocalStack error is
    // reported together with its response body instead of as a missing
    // `KeyId`.
    let client = reqwest::Client::new();
    let create_key_response = client
        .post(format!("{}/", endpoint_url))
        .header("Content-Type", "application/x-amz-json-1.1")
        .header("X-Amz-Target", "TrentService.CreateKey")
        .body(r#"{"Description": "Envelope test key", "KeyUsage": "ENCRYPT_DECRYPT"}"#)
        .send()
        .await
        .expect("Failed to create KMS key");
    let status = create_key_response.status();
    let body = create_key_response
        .text()
        .await
        .expect("Failed to read CreateKey response body");
    assert!(
        status.is_success(),
        "Failed to create KMS key: {} {}",
        status,
        body
    );
    let key_response: serde_json::Value =
        serde_json::from_str(&body).expect("Failed to parse key response");
    let key_id = key_response["KeyMetadata"]["KeyId"]
        .as_str()
        .expect("No KeyId in response");

    // Dummy credentials for LocalStack. Every test in this file sets the same
    // values and deliberately does not remove them afterwards: the test
    // harness runs tests concurrently in one process, so removing them here
    // could strip credentials from another test whose AWS client may not have
    // resolved them yet.
    std::env::set_var("AWS_ACCESS_KEY_ID", "test");
    std::env::set_var("AWS_SECRET_ACCESS_KEY", "test");
    std::env::set_var("AWS_DEFAULT_REGION", "us-east-1");

    let provider = AwsKmsProvider::new("us-east-1", key_id, Some(endpoint_url))
        .await
        .expect("Failed to create AwsKmsProvider");

    let plaintext = b"Sensitive table metadata that needs encryption";
    let envelope = EncryptedEnvelope::encrypt(&provider, key_id, plaintext)
        .await
        .expect("Failed to encrypt envelope");
    assert!(!envelope.ciphertext.is_empty());
    assert!(!envelope.wrapped_dek.is_empty());

    let decrypted = envelope
        .decrypt(&provider)
        .await
        .expect("Failed to decrypt envelope");
    assert_eq!(decrypted, plaintext);

    println!("✅ AWS KMS envelope encryption test passed!");
}
/// Verifies that a key id unknown to LocalStack KMS is rejected.
///
/// The rejection may happen either when `AwsKmsProvider::new` is called or
/// on the first `generate_data_key` call; both count as a pass. What must not
/// happen is a successful data-key generation.
#[tokio::test]
async fn test_aws_kms_key_not_found() {
    use rustberg::crypto::kms::{AwsKmsProvider, KeyManagementService};

    const MISSING_KEY_ID: &str = "non-existent-key-12345";

    let localstack: ContainerAsync<LocalStack> = LocalStack::default()
        .start()
        .await
        .expect("Failed to start LocalStack container");
    let host_port = localstack
        .get_host_port_ipv4(4566)
        .await
        .expect("Failed to get LocalStack port");
    let endpoint_url = format!("http://127.0.0.1:{}", host_port);
    // Fixed grace period before talking to LocalStack.
    sleep(Duration::from_secs(3)).await;

    // Dummy credentials for LocalStack. Every test in this file sets the same
    // values and deliberately does not remove them afterwards: the test
    // harness runs tests concurrently in one process, so removing them here
    // could strip credentials from another test whose AWS client may not have
    // resolved them yet.
    std::env::set_var("AWS_ACCESS_KEY_ID", "test");
    std::env::set_var("AWS_SECRET_ACCESS_KEY", "test");
    std::env::set_var("AWS_DEFAULT_REGION", "us-east-1");

    match AwsKmsProvider::new("us-east-1", MISSING_KEY_ID, Some(endpoint_url)).await {
        // Rejected up front: acceptable, but log why for diagnosis.
        Err(e) => println!("Provider construction rejected missing key: {:?}", e),
        // Constructed lazily: the first real KMS call must then fail.
        Ok(provider) => {
            let gen_result = provider.generate_data_key(MISSING_KEY_ID).await;
            assert!(gen_result.is_err(), "Should fail with non-existent key");
        }
    }

    println!("✅ AWS KMS error handling test passed!");
}
#[tokio::test]
async fn test_aws_kms_concurrent_operations() {
use rustberg::crypto::kms::{AwsKmsProvider, KeyManagementService};
use std::sync::Arc;
let localstack: ContainerAsync<LocalStack> = LocalStack::default()
.start()
.await
.expect("Failed to start LocalStack container");
let host_port = localstack.get_host_port_ipv4(4566).await.unwrap();
let endpoint_url = format!("http://127.0.0.1:{}", host_port);
sleep(Duration::from_secs(3)).await;
let client = reqwest::Client::new();
let key_response: serde_json::Value = client
.post(format!("{}/", endpoint_url))
.header("Content-Type", "application/x-amz-json-1.1")
.header("X-Amz-Target", "TrentService.CreateKey")
.body(r#"{"Description": "Concurrent test key", "KeyUsage": "ENCRYPT_DECRYPT"}"#)
.send()
.await
.expect("Failed to create KMS key")
.json()
.await
.expect("Failed to parse key response");
let key_id = key_response["KeyMetadata"]["KeyId"]
.as_str()
.expect("No KeyId in response")
.to_string();
std::env::set_var("AWS_ACCESS_KEY_ID", "test");
std::env::set_var("AWS_SECRET_ACCESS_KEY", "test");
std::env::set_var("AWS_DEFAULT_REGION", "us-east-1");
let provider = Arc::new(
AwsKmsProvider::new("us-east-1", &key_id, Some(endpoint_url))
.await
.expect("Failed to create AwsKmsProvider"),
);
let mut handles = vec![];
for i in 0..10 {
let provider = provider.clone();
let key_id = key_id.clone();
handles.push(tokio::spawn(async move {
let dek = provider
.generate_data_key(&key_id)
.await
.unwrap_or_else(|_| panic!("Failed to generate DEK {}", i));
let decrypted = provider
.decrypt_data_key(&key_id, &dek.ciphertext)
.await
.unwrap_or_else(|_| panic!("Failed to decrypt DEK {}", i));
assert_eq!(decrypted.expose(), dek.plaintext.expose());
i
}));
}
for handle in handles {
handle.await.expect("Task panicked");
}
std::env::remove_var("AWS_ACCESS_KEY_ID");
std::env::remove_var("AWS_SECRET_ACCESS_KEY");
std::env::remove_var("AWS_DEFAULT_REGION");
println!("✅ AWS KMS concurrent operations test passed!");
}