#![cfg(feature = "aws-kms")]
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use secure_data::kms::KeyProvider;
use secure_data::providers::aws_kms::AwsKmsKeyProvider;
use std::io::{Read, Write};
use std::net::TcpListener;
/// Plaintext data-encryption key that the mock KMS returns: 32 bytes, all `0x99`.
const TEST_DEK_BYTES: [u8; 32] = [0x99; 32];
/// Ciphertext blob used as the wrapped DEK in the tests: 32 bytes, all `0xAB`.
const TEST_WRAPPED_BYTES: [u8; 32] = [0xAB; 32];
/// Standard-alphabet base64 encoding of [`TEST_DEK_BYTES`].
///
/// This is the text the mock KMS puts into the `Plaintext` JSON field.
fn test_dek_b64() -> String {
    let raw: &[u8] = &TEST_DEK_BYTES;
    B64.encode(raw)
}
/// Standard-alphabet base64 encoding of [`TEST_WRAPPED_BYTES`].
///
/// This is the text the mock KMS puts into the `CiphertextBlob` JSON field.
fn test_wrapped_b64() -> String {
    let raw: &[u8] = &TEST_WRAPPED_BYTES;
    B64.encode(raw)
}
/// Builds the complete raw HTTP/1.1 response the mock server sends for a
/// `GenerateDataKey` call.
///
/// The JSON body contains:
/// - `Plaintext`: the base64-encoded test DEK;
/// - `CiphertextBlob`: the base64-encoded test wrapped key;
/// - `KeyId`: a fixed key ARN.
///
/// `Content-Length` is the byte length of that body. `Connection: close` tells
/// the client not to reuse the socket.
fn make_generate_response() -> String {
    let payload = format!(
        r#"{{"Plaintext":"{}","CiphertextBlob":"{}","KeyId":"arn:aws:kms:us-east-1:123456789012:key/test-key-id"}}"#,
        test_dek_b64(),
        test_wrapped_b64(),
    );
    let content_length = payload.len();
    format!(
        "HTTP/1.1 200 OK\r\nContent-Type: application/x-amz-json-1.1\r\nContent-Length: {content_length}\r\nConnection: close\r\n\r\n{payload}"
    )
}
/// Builds the complete raw HTTP/1.1 response the mock server sends for a
/// `Decrypt` call.
///
/// The JSON body contains the base64-encoded test DEK as `Plaintext` and a
/// fixed key ARN as `KeyId`. `Content-Length` is the byte length of that body.
fn make_decrypt_response() -> String {
    let payload = format!(
        r#"{{"Plaintext":"{}","KeyId":"arn:aws:kms:us-east-1:123456789012:key/test-key-id"}}"#,
        test_dek_b64(),
    );
    let content_length = payload.len();
    format!(
        "HTTP/1.1 200 OK\r\nContent-Type: application/x-amz-json-1.1\r\nContent-Length: {content_length}\r\nConnection: close\r\n\r\n{payload}"
    )
}
/// Starts a one-shot-per-connection HTTP mock server on an ephemeral loopback port.
///
/// A background thread accepts one connection per entry in `responses`, in order.
/// For each connection it reads the complete HTTP request, writes the matching
/// canned response verbatim, and closes the socket. The thread exits once
/// `responses` is exhausted or `accept` fails.
///
/// The whole request is drained before replying. Otherwise the server could
/// close the socket with unread request bytes still buffered, which makes the
/// kernel send a TCP RST. The client may then see "connection reset" instead of
/// the response.
///
/// Returns the bound port.
///
/// # Panics
/// Panics if the listener cannot be bound or its local address cannot be read.
fn start_mock_server(responses: Vec<String>) -> u16 {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind mock server");
    let port = listener.local_addr().unwrap().port();
    std::thread::spawn(move || {
        for response in responses {
            match listener.accept() {
                Ok((mut stream, _)) => {
                    // Bound the drain so a client that never finishes its request
                    // cannot wedge this thread forever.
                    let _ = stream.set_read_timeout(Some(std::time::Duration::from_secs(5)));
                    let _ = read_http_request(&mut stream);
                    let _ = stream.write_all(response.as_bytes());
                }
                Err(_) => break,
            }
        }
    });
    port
}

/// Reads one HTTP/1.1 request from `reader` and returns its raw bytes.
///
/// The request is the header section up to `\r\n\r\n`, plus a body of
/// `Content-Length` bytes when that header is present. Reading also stops
/// early on EOF or on any read error other than `Interrupted`.
fn read_http_request<R: Read>(reader: &mut R) -> Vec<u8> {
    let mut request = Vec::new();
    let mut chunk = [0u8; 8192];
    // Total request length, known once the header terminator has been seen.
    let mut expected_len: Option<usize> = None;
    loop {
        if let Some(total) = expected_len {
            if request.len() >= total {
                break;
            }
        }
        let read = match reader.read(&mut chunk) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(_) => break,
        };
        request.extend_from_slice(&chunk[..read]);
        if expected_len.is_none() {
            if let Some(head_len) = http_header_end(&request) {
                expected_len = Some(head_len + http_content_length(&request[..head_len]));
            }
        }
    }
    request
}

/// Returns the offset just past the `\r\n\r\n` header terminator, if present.
fn http_header_end(buf: &[u8]) -> Option<usize> {
    buf.windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|pos| pos + 4)
}

/// Parses a case-insensitive `Content-Length` header from `head`.
///
/// Returns 0 when the header is absent or unparsable.
fn http_content_length(head: &[u8]) -> usize {
    String::from_utf8_lossy(head)
        .lines()
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse().ok())
        .unwrap_or(0)
}
/// `generate_data_key` against a mock KMS endpoint returns the DEK from the
/// canned response, plus a non-empty wrapped key and version.
#[tokio::test]
async fn test_aws_kms_generate_data_key_happy_path() {
    // Serve exactly one canned GenerateDataKey reply and point the provider at it.
    let port = start_mock_server(vec![make_generate_response()]);
    let endpoint = format!("http://127.0.0.1:{port}");
    let provider = AwsKmsKeyProvider::with_endpoint(endpoint).await;

    let (dek, wrapped, version) = provider
        .generate_data_key("alias/my-key")
        .await
        .expect("generate_data_key must succeed");

    assert_eq!(dek.as_slice(), &TEST_DEK_BYTES, "DEK must match mock response");
    assert!(!wrapped.is_empty(), "wrapped key must not be empty");
    assert!(!version.is_empty(), "version must not be empty");
}
/// `unwrap_data_key` against a mock KMS endpoint returns the plaintext DEK
/// carried in the canned `Decrypt` response.
#[tokio::test]
async fn test_aws_kms_unwrap_data_key_happy_path() {
    // One canned Decrypt reply is enough for a single unwrap call.
    let port = start_mock_server(vec![make_decrypt_response()]);
    let endpoint = format!("http://127.0.0.1:{port}");
    let provider = AwsKmsKeyProvider::with_endpoint(endpoint).await;

    let ciphertext = TEST_WRAPPED_BYTES.to_vec();
    let dek = provider
        .unwrap_data_key(&ciphertext, "alias/my-key", "1")
        .await
        .expect("unwrap_data_key must succeed");

    assert_eq!(dek.as_slice(), &TEST_DEK_BYTES, "unwrapped DEK must match");
}
/// A provider pointed at a loopback port with no listener reports
/// `DataError::ProviderUnavailable`.
#[tokio::test]
async fn test_aws_kms_provider_unavailable() {
    use secure_data::error::DataError;

    // Grab a free ephemeral port. The listener is dropped at the end of this
    // scope, so nothing is listening on the port by the time the provider
    // connects.
    // NOTE(review): another process could in principle claim the port in between.
    let port = {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        listener.local_addr().unwrap().port()
    };

    let endpoint = format!("http://127.0.0.1:{port}");
    let provider = AwsKmsKeyProvider::with_endpoint(endpoint).await;
    let result = provider.generate_data_key("alias/my-key").await;

    let is_unavailable = matches!(result, Err(DataError::ProviderUnavailable { .. }));
    assert!(is_unavailable, "expected ProviderUnavailable, got: {result:?}");
}