#![cfg(feature = "vault-kms")]
use std::time::Duration;
use testcontainers::runners::AsyncRunner;
use testcontainers::ContainerAsync;
use testcontainers_modules::hashicorp_vault::HashicorpVault;
use tokio::time::sleep;
/// End-to-end smoke test of `VaultKmsProvider` against a throwaway Vault container.
///
/// Steps exercised, in order:
/// 1. start a Vault container and mount the `transit` secrets engine over HTTP,
/// 2. create an `aes256-gcm96` transit key,
/// 3. build the provider and run its health check,
/// 4. generate a data key and decrypt its ciphertext back to the same plaintext,
/// 5. encrypt and decrypt an arbitrary byte string,
/// 6. read the current key version.
///
/// # Panics
/// Panics (failing the test) if any setup request or provider call fails, or if
/// any of the assertions below does not hold.
#[tokio::test]
async fn test_vault_kms_provider_basic_operations() {
    use rustberg::crypto::kms::{KeyManagementService, VaultKmsProvider};

    // Single source of truth for the transit key used throughout this test.
    const KEY_NAME: &str = "rustberg-test-key";

    let vault: ContainerAsync<HashicorpVault> = HashicorpVault::default()
        .start()
        .await
        .expect("Failed to start Vault container");
    let host_port = vault.get_host_port_ipv4(8200).await.unwrap();
    let vault_addr = format!("http://127.0.0.1:{}", host_port);
    // NOTE(review): presumably the root token configured by `HashicorpVault::default()` — confirm.
    let root_token = "myroot";

    // Fixed grace period before the first request.
    // NOTE(review): presumably lets the server finish initialising; a readiness
    // poll would be less timing-sensitive.
    sleep(Duration::from_secs(2)).await;

    let client = reqwest::Client::new();

    // Mount the transit engine. A 400 is tolerated alongside 2xx; presumably
    // this covers "already mounted" — confirm against the Vault API.
    let enable_response = client
        .post(format!("{}/v1/sys/mounts/transit", vault_addr))
        .header("X-Vault-Token", root_token)
        .json(&serde_json::json!({
            "type": "transit"
        }))
        .send()
        .await
        .expect("Failed to enable transit engine");
    assert!(
        enable_response.status().is_success() || enable_response.status().as_u16() == 400,
        "Failed to enable transit: {:?}",
        enable_response.text().await
    );

    // Create the key. Any 2xx counts as success; 204 is already included, so
    // no separate check for it is needed.
    let key_response = client
        .post(format!("{}/v1/transit/keys/{}", vault_addr, KEY_NAME))
        .header("X-Vault-Token", root_token)
        .json(&serde_json::json!({
            "type": "aes256-gcm96"
        }))
        .send()
        .await
        .expect("Failed to create transit key");
    assert!(
        key_response.status().is_success(),
        "Failed to create key: {:?}",
        key_response.text().await
    );

    let provider = VaultKmsProvider::new(&vault_addr, "transit", KEY_NAME, root_token.to_string())
        .await
        .expect("Failed to create VaultKmsProvider");
    provider.health_check().await.expect("Health check failed");

    // Data-key generation: plaintext length, non-empty ciphertext, sane version.
    let dek = provider
        .generate_data_key(KEY_NAME)
        .await
        .expect("Failed to generate data key");
    assert_eq!(
        dek.plaintext.expose().len(),
        32,
        "DEK should be 32 bytes (256 bits)"
    );
    assert!(
        !dek.ciphertext.is_empty(),
        "Encrypted DEK should not be empty"
    );
    assert!(dek.version >= 1, "Version should be at least 1");

    // The wrapped data key must decrypt back to the plaintext we were given.
    let decrypted = provider
        .decrypt_data_key(KEY_NAME, &dek.ciphertext)
        .await
        .expect("Failed to decrypt data key");
    assert_eq!(
        decrypted.expose(),
        dek.plaintext.expose(),
        "Decrypted key should match original"
    );

    // Explicit encrypt -> decrypt round trip with caller-supplied bytes.
    let test_plaintext = b"test-data-for-encryption";
    let encrypted = provider
        .encrypt_data_key(KEY_NAME, test_plaintext)
        .await
        .expect("Failed to encrypt data");
    assert!(!encrypted.is_empty(), "Encrypted data should not be empty");
    let decrypted_data = provider
        .decrypt_data_key(KEY_NAME, &encrypted)
        .await
        .expect("Failed to decrypt data");
    assert_eq!(
        decrypted_data.expose(),
        test_plaintext,
        "Decrypted data should match original"
    );

    let version = provider
        .current_version(KEY_NAME)
        .await
        .expect("Failed to get current version");
    assert!(version >= 1, "Version should be at least 1");

    println!("✅ All Vault KMS provider tests passed!");
}
/// Verifies master-key rotation through `VaultKmsProvider`.
///
/// The test encrypts a payload, rotates the key, and then checks that:
/// * the version returned by the rotation is greater than the initial version,
/// * ciphertext produced before the rotation still decrypts to the payload,
/// * ciphertext produced after the rotation contains a `:v<new_version>:` marker.
///
/// # Panics
/// Panics (failing the test) if Vault setup, any provider call, or any
/// assertion fails.
#[tokio::test]
async fn test_vault_kms_key_rotation() {
    use rustberg::crypto::kms::{KeyManagementService, VaultKmsProvider};

    const KEY_NAME: &str = "rotation-test-key";

    let vault: ContainerAsync<HashicorpVault> = HashicorpVault::default()
        .start()
        .await
        .expect("Failed to start Vault container");
    let host_port = vault.get_host_port_ipv4(8200).await.unwrap();
    let vault_addr = format!("http://127.0.0.1:{}", host_port);
    // NOTE(review): presumably the root token configured by `HashicorpVault::default()` — confirm.
    let root_token = "myroot";

    // Fixed grace period before the first request.
    sleep(Duration::from_secs(2)).await;

    let client = reqwest::Client::new();

    // Check the setup responses instead of discarding them, so a rejected
    // request fails here rather than as a confusing provider error later.
    let enable_response = client
        .post(format!("{}/v1/sys/mounts/transit", vault_addr))
        .header("X-Vault-Token", root_token)
        .json(&serde_json::json!({"type": "transit"}))
        .send()
        .await
        .expect("Failed to enable transit");
    assert!(
        enable_response.status().is_success() || enable_response.status().as_u16() == 400,
        "Failed to enable transit: {:?}",
        enable_response.text().await
    );

    let key_response = client
        .post(format!("{}/v1/transit/keys/{}", vault_addr, KEY_NAME))
        .header("X-Vault-Token", root_token)
        .json(&serde_json::json!({"type": "aes256-gcm96"}))
        .send()
        .await
        .expect("Failed to create key");
    assert!(
        key_response.status().is_success(),
        "Failed to create key: {:?}",
        key_response.text().await
    );

    let provider = VaultKmsProvider::new(&vault_addr, "transit", KEY_NAME, root_token.to_string())
        .await
        .expect("Failed to create provider");

    let initial_version = provider
        .current_version(KEY_NAME)
        .await
        .expect("Failed to get version");

    // Ciphertext created under the pre-rotation key version.
    let test_data = b"data-before-rotation";
    let encrypted_before = provider
        .encrypt_data_key(KEY_NAME, test_data)
        .await
        .expect("Failed to encrypt");

    let new_version = provider
        .rotate_master_key(KEY_NAME)
        .await
        .expect("Failed to rotate key");
    assert!(
        new_version > initial_version,
        "New version should be greater than initial"
    );

    // Old ciphertext must remain readable after the rotation.
    let decrypted = provider
        .decrypt_data_key(KEY_NAME, &encrypted_before)
        .await
        .expect("Failed to decrypt after rotation");
    assert_eq!(
        decrypted.expose(),
        test_data,
        "Should still decrypt data from old key version"
    );

    // New ciphertext is inspected as text for the version marker.
    // NOTE(review): this relies on the provider returning a textual ciphertext
    // that embeds `:v<N>:` — confirm against the provider implementation.
    let encrypted_after = provider
        .encrypt_data_key(KEY_NAME, test_data)
        .await
        .expect("Failed to encrypt after rotation");
    let ciphertext_str = String::from_utf8_lossy(&encrypted_after);
    assert!(
        ciphertext_str.contains(&format!(":v{}:", new_version)),
        "New encryption should use new key version"
    );

    println!("✅ Vault KMS key rotation tests passed!");
}
/// Checks that constructing a `VaultKmsProvider` against an address with no
/// Vault behind it yields an error instead of a usable provider.
///
/// # Panics
/// Panics (failing the test) if the constructor unexpectedly succeeds.
#[tokio::test]
async fn test_vault_kms_connection_error() {
    use rustberg::crypto::kms::VaultKmsProvider;

    // NOTE(review): presumably nothing listens on this local port — confirm for CI hosts.
    let unreachable_addr = "http://127.0.0.1:19999";

    let outcome = VaultKmsProvider::new(
        unreachable_addr,
        "transit",
        "test-key",
        String::from("fake-token"),
    )
    .await;

    match outcome {
        // The error is only printed; its exact contents are not asserted on.
        Err(err) => println!("Expected error: {}", err),
        Ok(_) => panic!("Should fail to connect to non-existent Vault"),
    }
}
/// Exercises a single shared `VaultKmsProvider` from ten concurrently spawned
/// tokio tasks.
///
/// Each task generates a data key, encrypts a task-specific payload, decrypts
/// it again, and asserts both the round trip and the 32-byte data-key length.
/// The test then awaits every task and fails if any of them panicked.
///
/// # Panics
/// Panics (failing the test) if Vault setup fails, a provider call fails, or
/// any spawned task does not complete successfully.
#[tokio::test]
async fn test_vault_kms_concurrent_operations() {
    use rustberg::crypto::kms::{KeyManagementService, VaultKmsProvider};
    use std::sync::Arc;

    const KEY_NAME: &str = "concurrent-test-key";

    let vault: ContainerAsync<HashicorpVault> = HashicorpVault::default()
        .start()
        .await
        .expect("Failed to start Vault container");
    let host_port = vault.get_host_port_ipv4(8200).await.unwrap();
    let vault_addr = format!("http://127.0.0.1:{}", host_port);
    // NOTE(review): presumably the root token configured by `HashicorpVault::default()` — confirm.
    let root_token = "myroot";

    // Fixed grace period before the first request.
    sleep(Duration::from_secs(2)).await;

    let client = reqwest::Client::new();

    // Check the setup responses instead of discarding them, so a rejected
    // request fails here rather than inside one of the spawned tasks.
    let enable_response = client
        .post(format!("{}/v1/sys/mounts/transit", vault_addr))
        .header("X-Vault-Token", root_token)
        .json(&serde_json::json!({"type": "transit"}))
        .send()
        .await
        .expect("Failed to enable transit");
    assert!(
        enable_response.status().is_success() || enable_response.status().as_u16() == 400,
        "Failed to enable transit: {:?}",
        enable_response.text().await
    );

    let key_response = client
        .post(format!("{}/v1/transit/keys/{}", vault_addr, KEY_NAME))
        .header("X-Vault-Token", root_token)
        .json(&serde_json::json!({"type": "aes256-gcm96"}))
        .send()
        .await
        .expect("Failed to create key");
    assert!(
        key_response.status().is_success(),
        "Failed to create key: {:?}",
        key_response.text().await
    );

    // One provider instance shared by reference count across all tasks.
    let provider = Arc::new(
        VaultKmsProvider::new(&vault_addr, "transit", KEY_NAME, root_token.to_string())
            .await
            .expect("Failed to create provider"),
    );

    let handles: Vec<_> = (0..10)
        .map(|i| {
            let provider = Arc::clone(&provider);
            tokio::spawn(async move {
                // Distinct payload per task so a mixed-up response would be detected.
                let data = format!("concurrent-data-{}", i);
                let dek = provider
                    .generate_data_key(KEY_NAME)
                    .await
                    .expect("Failed to generate data key");
                let encrypted = provider
                    .encrypt_data_key(KEY_NAME, data.as_bytes())
                    .await
                    .expect("Failed to encrypt");
                let decrypted = provider
                    .decrypt_data_key(KEY_NAME, &encrypted)
                    .await
                    .expect("Failed to decrypt");
                assert_eq!(decrypted.expose(), data.as_bytes());
                assert_eq!(dek.plaintext.expose().len(), 32);
                i
            })
        })
        .collect();

    // A panic inside a task surfaces as an `Err` from its join handle.
    for result in futures::future::join_all(handles).await {
        result.expect("Task failed");
    }

    println!("✅ Vault KMS concurrent operations tests passed!");
}