use std::sync::Arc;
use azure_core::http::RequestContent;
use azure_identity::DeveloperToolsCredential;
use azure_security_keyvault_keys::KeyClient;
use azure_security_keyvault_keys::models::{EncryptionAlgorithm, KeyOperationParameters};
use tracing::{debug, instrument};
use url::Url;
use crate::encryption::{EncryptionError, KeyStoreProvider};
/// Name this provider reports through `KeyStoreProvider::provider_name`.
const PROVIDER_NAME: &str = "AZURE_KEY_VAULT";

/// Key Vault / Managed HSM host suffixes trusted when no custom list is configured.
///
/// Every entry starts with `.`, so a host must end with the whole suffix to match.
/// For example, `myvault.vault.azure.net` matches, but `vault.azure.net.attacker.com`
/// does not.
const DEFAULT_TRUSTED_KEY_VAULT_SUFFIXES: &[&str] = &[
    // Azure public cloud
    ".vault.azure.net",
    ".vaultcore.azure.net",
    ".managedhsm.azure.net",
    // Azure China
    ".vault.azure.cn",
    ".managedhsm.azure.cn",
    // Azure US Government
    ".vault.usgovcloudapi.net",
    ".managedhsm.usgovcloudapi.net",
    // Azure Germany (legacy sovereign cloud)
    ".vault.microsoftazure.de",
];
/// Key store provider backed by Azure Key Vault (or Managed HSM) keys.
///
/// It unwraps column encryption keys and signs/verifies data with the key named by
/// a CMK path. Each path's host must match one of `trusted_host_suffixes` before any
/// request is sent.
pub struct AzureKeyVaultProvider {
    /// Credential used to authenticate every Key Vault request.
    credential: Arc<DeveloperToolsCredential>,
    /// Host suffixes that a CMK path's host must match.
    trusted_host_suffixes: Vec<String>,
}
/// Returns an owned copy of [`DEFAULT_TRUSTED_KEY_VAULT_SUFFIXES`].
fn default_trusted_suffixes() -> Vec<String> {
    DEFAULT_TRUSTED_KEY_VAULT_SUFFIXES
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
impl AzureKeyVaultProvider {
    /// Creates a provider that authenticates with [`DeveloperToolsCredential`] and
    /// trusts the built-in Key Vault host suffixes.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::ConfigurationError`] if the credential cannot be created.
    pub fn new() -> Result<Self, EncryptionError> {
        let credential = DeveloperToolsCredential::new(None).map_err(|e| {
            EncryptionError::ConfigurationError(format!("Failed to create Azure credential: {e}"))
        })?;
        Ok(Self::with_credential(credential))
    }

    /// Creates a provider from an existing credential, trusting the built-in Key Vault
    /// host suffixes.
    #[must_use]
    pub fn with_credential(credential: Arc<DeveloperToolsCredential>) -> Self {
        Self {
            credential,
            trusted_host_suffixes: default_trusted_suffixes(),
        }
    }

    /// Replaces the trusted Key Vault host suffixes (e.g. for private or custom clouds).
    ///
    /// Matching is ASCII case-insensitive:
    /// - A suffix that starts with `.` (e.g. `.vault.contoso.example`) matches any host
    ///   ending with it.
    /// - A suffix without a leading dot (e.g. `kv.contoso.example`) matches that exact
    ///   host or any subdomain of it. It never matches a host that merely shares trailing
    ///   characters (`evilkv.contoso.example`).
    /// - An empty suffix matches nothing.
    #[must_use]
    pub fn with_trusted_endpoints<I, S>(mut self, suffixes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.trusted_host_suffixes = suffixes.into_iter().map(Into::into).collect();
        self
    }

    /// Returns `true` if `host_lc` (already ASCII-lowercased) is covered by `suffix`.
    fn host_matches_suffix(host_lc: &str, suffix: &str) -> bool {
        let suffix = suffix.to_ascii_lowercase();
        if suffix.is_empty() {
            // A plain `ends_with("")` would trust every host.
            return false;
        }
        if suffix.starts_with('.') {
            // The leading dot already anchors the match on a DNS-label boundary.
            host_lc.ends_with(&suffix)
        } else {
            // Anchor on a label boundary ourselves so that "contoso.example" does not
            // match "evilcontoso.example".
            host_lc == suffix
                || matches!(
                    host_lc.strip_suffix(suffix.as_str()),
                    Some(prefix) if prefix.ends_with('.')
                )
        }
    }

    /// Splits a CMK path such as `https://<vault>/keys/<name>[/<version>]` into
    /// `(vault_url, key_name, key_version)`.
    ///
    /// The host must match one of `trusted_suffixes`. An explicit non-default port is
    /// kept in `vault_url`. An empty version segment (e.g. a trailing slash) yields
    /// `None` for the version.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] in these cases:
    /// - the path is not a URL, or is not `https`;
    /// - the host is missing or untrusted;
    /// - the path is not `/keys/<name>...` with a non-empty `<name>`.
    fn parse_cmk_path(
        cmk_path: &str,
        trusted_suffixes: &[String],
    ) -> Result<(String, String, Option<String>), EncryptionError> {
        let url = Url::parse(cmk_path).map_err(|e| {
            EncryptionError::CmkError(format!("Invalid CMK path '{cmk_path}': {e}"))
        })?;
        if url.scheme() != "https" {
            return Err(EncryptionError::CmkError(format!(
                "CMK path must use https, got scheme '{}' in '{cmk_path}'",
                url.scheme()
            )));
        }
        let host = url
            .host_str()
            .ok_or_else(|| EncryptionError::CmkError("CMK path missing host".into()))?;
        let host_lc = host.to_ascii_lowercase();
        let trusted = trusted_suffixes
            .iter()
            .any(|suffix| Self::host_matches_suffix(&host_lc, suffix));
        if !trusted {
            return Err(EncryptionError::CmkError(format!(
                "CMK host '{host}' is not a trusted Key Vault endpoint. The CMK path is \
                 supplied by the server; allowing an arbitrary host would let a malicious \
                 server redirect key operations and exfiltrate access tokens. Trusted \
                 suffixes: {trusted_suffixes:?}. For custom deployments use \
                 AzureKeyVaultProvider::with_trusted_endpoints."
            )));
        }
        // `Url::port` is `None` for the scheme default (443). Keep any explicit
        // non-default port instead of silently redirecting to 443; the host itself
        // has already been checked above.
        let vault_url = match url.port() {
            Some(port) => format!("{}://{host}:{port}", url.scheme()),
            None => format!("{}://{host}", url.scheme()),
        };
        let segments: Vec<&str> = url.path_segments().map(|s| s.collect()).unwrap_or_default();
        let (key_name, key_version) = match segments.as_slice() {
            // An empty name (`/keys/` or `/keys//v1`) would address the wrong route.
            ["keys", name, rest @ ..] if !name.is_empty() => {
                let version = rest
                    .first()
                    .filter(|v| !v.is_empty())
                    .map(|v| (*v).to_string());
                ((*name).to_string(), version)
            }
            _ => {
                return Err(EncryptionError::CmkError(format!(
                    "Invalid CMK path format: expected /keys/<name>[/<version>], got '{}'",
                    url.path()
                )));
            }
        };
        Ok((vault_url, key_name, key_version))
    }

    /// Builds a [`KeyClient`] for `vault_url` using this provider's credential.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] if the client cannot be constructed.
    fn create_client(&self, vault_url: &str) -> Result<KeyClient, EncryptionError> {
        KeyClient::new(vault_url, self.credential.clone(), None).map_err(|e| {
            EncryptionError::CmkError(format!("Failed to create Key Vault client: {e}"))
        })
    }
}
impl std::fmt::Debug for AzureKeyVaultProvider {
    /// Prints only the provider name; the credential and trusted suffixes are not
    /// printed, and the output ends with `..` to show that fields were omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AzureKeyVaultProvider");
        out.field("provider_name", &PROVIDER_NAME);
        out.finish_non_exhaustive()
    }
}
#[async_trait::async_trait]
impl KeyStoreProvider for AzureKeyVaultProvider {
fn provider_name(&self) -> &str {
PROVIDER_NAME
}
#[instrument(skip(self, encrypted_cek), fields(cmk_path = %cmk_path, algorithm = %algorithm))]
async fn decrypt_cek(
&self,
cmk_path: &str,
algorithm: &str,
encrypted_cek: &[u8],
) -> Result<Vec<u8>, EncryptionError> {
debug!("Decrypting CEK using Azure Key Vault");
let (vault_url, key_name, key_version) =
Self::parse_cmk_path(cmk_path, &self.trusted_host_suffixes)?;
let client = self.create_client(&vault_url)?;
let kv_algorithm = map_algorithm(algorithm)?;
let envelope = crate::cek_envelope::parse(encrypted_cek)?;
let digest = envelope.signed_digest();
let valid = self
.verify_signature(cmk_path, &digest, envelope.signature)
.await?;
if !valid {
return Err(EncryptionError::CekDecryptionFailed(
"CEK envelope signature verification failed".into(),
));
}
let parameters = KeyOperationParameters {
algorithm: Some(kv_algorithm),
value: Some(envelope.ciphertext.to_vec()),
..Default::default()
};
let version = key_version.ok_or_else(|| {
EncryptionError::CmkError(
"CMK path must include key version (e.g., /keys/<name>/<version>)".into(),
)
})?;
let request_content: RequestContent<KeyOperationParameters> =
parameters.try_into().map_err(|e| {
EncryptionError::CekDecryptionFailed(format!("Failed to create request: {e}"))
})?;
let result = client
.unwrap_key(&key_name, &version, request_content, None)
.await
.map_err(|e| {
EncryptionError::CekDecryptionFailed(format!("Key Vault unwrap failed: {e}"))
})?
.into_model()
.map_err(|e| {
EncryptionError::CekDecryptionFailed(format!("Failed to parse response: {e}"))
})?;
let decrypted = result.result.ok_or_else(|| {
EncryptionError::CekDecryptionFailed("Key Vault unwrap returned no result".into())
})?;
debug!("Successfully decrypted CEK using Azure Key Vault");
Ok(decrypted)
}
#[instrument(skip(self, data), fields(cmk_path = %cmk_path))]
async fn sign_data(&self, cmk_path: &str, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
debug!("Signing data using Azure Key Vault");
let (vault_url, key_name, key_version) =
Self::parse_cmk_path(cmk_path, &self.trusted_host_suffixes)?;
let client = self.create_client(&vault_url)?;
use azure_security_keyvault_keys::models::{
KeyClientSignOptions, SignParameters, SignatureAlgorithm,
};
let parameters = SignParameters {
algorithm: Some(SignatureAlgorithm::Rs256),
value: Some(data.to_vec()),
};
let version = key_version.ok_or_else(|| {
EncryptionError::CmkError("CMK path must include key version for sign operation".into())
})?;
let request_content: RequestContent<SignParameters> = parameters
.try_into()
.map_err(|e| EncryptionError::CmkError(format!("Failed to create request: {e}")))?;
let sign_options = KeyClientSignOptions {
key_version: Some(version),
..Default::default()
};
let result = client
.sign(&key_name, request_content, Some(sign_options))
.await
.map_err(|e| EncryptionError::CmkError(format!("Key Vault sign failed: {e}")))?
.into_model()
.map_err(|e| EncryptionError::CmkError(format!("Failed to parse response: {e}")))?;
let signature = result
.result
.ok_or_else(|| EncryptionError::CmkError("Key Vault sign returned no result".into()))?;
debug!("Successfully signed data using Azure Key Vault");
Ok(signature)
}
#[instrument(skip(self, data, signature), fields(cmk_path = %cmk_path))]
async fn verify_signature(
&self,
cmk_path: &str,
data: &[u8],
signature: &[u8],
) -> Result<bool, EncryptionError> {
debug!("Verifying signature using Azure Key Vault");
let (vault_url, key_name, key_version) =
Self::parse_cmk_path(cmk_path, &self.trusted_host_suffixes)?;
let client = self.create_client(&vault_url)?;
use azure_security_keyvault_keys::models::{SignatureAlgorithm, VerifyParameters};
let parameters = VerifyParameters {
algorithm: Some(SignatureAlgorithm::Rs256),
digest: Some(data.to_vec()),
signature: Some(signature.to_vec()),
};
let version = key_version.ok_or_else(|| {
EncryptionError::CmkError(
"CMK path must include key version for verify operation".into(),
)
})?;
let request_content: RequestContent<VerifyParameters> = parameters
.try_into()
.map_err(|e| EncryptionError::CmkError(format!("Failed to create request: {e}")))?;
let result = client
.verify(&key_name, &version, request_content, None)
.await
.map_err(|e| EncryptionError::CmkError(format!("Key Vault verify failed: {e}")))?
.into_model()
.map_err(|e| EncryptionError::CmkError(format!("Failed to parse response: {e}")))?;
let is_valid = result.value.unwrap_or(false);
debug!("Signature verification result: {}", is_valid);
Ok(is_valid)
}
}
/// Maps a key-encryption algorithm name to a Key Vault [`EncryptionAlgorithm`].
///
/// Matching is ASCII case-insensitive and accepts both `_` and `-` spellings:
/// `RSA_OAEP`/`RSA-OAEP`, `RSA_OAEP_256`/`RSA-OAEP-256`, and `RSA1_5`/`RSA-1_5`.
/// NOTE(review): RSA1_5 uses PKCS#1 v1.5 padding, which is weaker than OAEP; consider
/// whether it still needs to be accepted.
///
/// # Errors
///
/// Returns [`EncryptionError::ConfigurationError`] for any other name.
fn map_algorithm(algorithm: &str) -> Result<EncryptionAlgorithm, EncryptionError> {
    // ASCII-only upper-casing: algorithm identifiers are ASCII. Full Unicode case
    // mapping would also accept look-alikes such as "rſa_oaep" (U+017F uppercases to 'S').
    match algorithm.to_ascii_uppercase().as_str() {
        "RSA_OAEP" | "RSA-OAEP" => Ok(EncryptionAlgorithm::RsaOaep),
        "RSA_OAEP_256" | "RSA-OAEP-256" => Ok(EncryptionAlgorithm::RsaOaep256),
        "RSA1_5" | "RSA-1_5" => Ok(EncryptionAlgorithm::Rsa1_5),
        _ => Err(EncryptionError::ConfigurationError(format!(
            "Unsupported key encryption algorithm: {algorithm}. Expected RSA_OAEP, RSA_OAEP_256, or RSA1_5"
        ))),
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// The built-in trusted suffix list, as used by a default-constructed provider.
    fn trusted() -> Vec<String> {
        default_trusted_suffixes()
    }

    /// Parses `path` against the built-in trusted suffixes.
    fn parse_default(path: &str) -> Result<(String, String, Option<String>), EncryptionError> {
        AzureKeyVaultProvider::parse_cmk_path(path, &trusted())
    }

    #[test]
    fn test_parse_cmk_path() {
        // (path, expected version, expectation message)
        let well_formed: [(&str, Option<&str>, &str); 3] = [
            (
                "https://myvault.vault.azure.net/keys/mykey/abc123",
                Some("abc123"),
                "valid CMK path with version should parse",
            ),
            (
                "https://myvault.vault.azure.net/keys/mykey",
                None,
                "valid CMK path without version should parse",
            ),
            (
                "https://myvault.vault.azure.net/keys/mykey/",
                None,
                "valid CMK path with trailing slash should parse",
            ),
        ];
        for (path, expected_version, why) in well_formed {
            let (vault, name, version) = parse_default(path).expect(why);
            assert_eq!(vault, "https://myvault.vault.azure.net");
            assert_eq!(name, "mykey");
            assert_eq!(version.as_deref(), expected_version);
        }

        // Other trusted clouds / HSM endpoints.
        for path in [
            "https://myhsm.managedhsm.azure.net/keys/mykey",
            "https://myvault.vault.usgovcloudapi.net/keys/mykey",
        ] {
            assert!(parse_default(path).is_ok(), "{path} should be accepted");
        }
    }

    #[test]
    fn test_parse_cmk_path_invalid() {
        for path in [
            "not-a-url",
            "https://myvault.vault.azure.net/secrets/mysecret",
            "https://myvault.vault.azure.net/keys",
        ] {
            assert!(parse_default(path).is_err(), "{path} should be rejected");
        }
    }

    #[test]
    fn test_parse_cmk_path_rejects_untrusted_host() {
        let err = parse_default("https://attacker.example.com/keys/mykey")
            .expect_err("untrusted host must be rejected");
        assert!(err.to_string().contains("not a trusted Key Vault endpoint"));

        // A suffix-spoofing host, and a trusted host over plain http.
        for path in [
            "https://vault.azure.net.attacker.com/keys/mykey",
            "http://myvault.vault.azure.net/keys/mykey",
        ] {
            assert!(parse_default(path).is_err(), "{path} should be rejected");
        }
    }

    #[test]
    fn test_with_trusted_endpoints_override() {
        let custom = [".vault.contoso.example".to_string()];
        let parse_custom = |path: &str| AzureKeyVaultProvider::parse_cmk_path(path, &custom);
        assert!(parse_custom("https://kv1.vault.contoso.example/keys/mykey").is_ok());
        assert!(parse_custom("https://myvault.vault.azure.net/keys/mykey").is_err());
    }

    #[test]
    fn test_map_algorithm() {
        for (name, why) in [
            ("RSA_OAEP", "RSA_OAEP should be a valid algorithm"),
            ("RSA-OAEP", "RSA-OAEP should be a valid algorithm"),
            ("rsa_oaep", "lowercase rsa_oaep should be valid"),
        ] {
            assert!(matches!(
                map_algorithm(name).expect(why),
                EncryptionAlgorithm::RsaOaep
            ));
        }
        assert!(matches!(
            map_algorithm("RSA_OAEP_256").expect("RSA_OAEP_256 should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep256
        ));
        assert!(map_algorithm("UNKNOWN").is_err());
    }
}