use std::sync::Arc;
use azure_core::http::RequestContent;
use azure_identity::DeveloperToolsCredential;
use azure_security_keyvault_keys::KeyClient;
use azure_security_keyvault_keys::models::{EncryptionAlgorithm, KeyOperationParameters};
use tracing::{debug, instrument};
use url::Url;
use crate::encryption::{EncryptionError, KeyStoreProvider};
/// Name this provider reports through `KeyStoreProvider::provider_name`.
///
/// NOTE(review): presumably compared against the key store provider name recorded in the
/// column master key metadata — confirm against the caller.
const PROVIDER_NAME: &str = "AZURE_KEY_VAULT";
/// Key store provider backed by Azure Key Vault.
///
/// Keys are addressed by a CMK path URL (`https://<vault-host>/keys/<name>/<version>`). The
/// provider unwraps encrypted column encryption keys and signs/verifies data with those keys.
pub struct AzureKeyVaultProvider {
    // Credential handed to every `KeyClient` this provider creates.
    credential: Arc<DeveloperToolsCredential>,
}
impl AzureKeyVaultProvider {
    /// Creates a provider that authenticates with a [`DeveloperToolsCredential`] built with
    /// default options.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::ConfigurationError`] if the credential cannot be created.
    pub fn new() -> Result<Self, EncryptionError> {
        let credential = DeveloperToolsCredential::new(None).map_err(|e| {
            EncryptionError::ConfigurationError(format!("Failed to create Azure credential: {e}"))
        })?;
        Ok(Self { credential })
    }

    /// Creates a provider that reuses an already-constructed credential.
    #[must_use]
    pub fn with_credential(credential: Arc<DeveloperToolsCredential>) -> Self {
        Self { credential }
    }

    /// Splits a CMK path into `(vault_url, key_name, key_version)`.
    ///
    /// Expected form: `<scheme>://<host>[:<port>]/keys/<name>[/<version>]`.
    ///
    /// - `vault_url` is the scheme and host. The port is included only when it is not the
    ///   scheme's default.
    /// - `key_version` is `None` when the version segment is absent or empty (trailing `/`).
    /// - Path segments after the version are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] in any of these cases:
    /// - `cmk_path` is not an absolute URL;
    /// - the URL has no host;
    /// - the path does not start with `/keys/<name>`;
    /// - the key name is empty.
    fn parse_cmk_path(cmk_path: &str) -> Result<(String, String, Option<String>), EncryptionError> {
        let url = Url::parse(cmk_path).map_err(|e| {
            EncryptionError::CmkError(format!("Invalid CMK path '{cmk_path}': {e}"))
        })?;
        let host = url
            .host_str()
            .ok_or_else(|| EncryptionError::CmkError("CMK path missing host".into()))?;
        // `Url::port` is `None` for the scheme's default port, so the usual
        // `https://<vault>.vault.azure.net` form is unchanged. A non-default port must be kept,
        // otherwise requests are sent to a different endpoint than the one in the CMK path.
        let vault_url = match url.port() {
            Some(port) => format!("{}://{host}:{port}", url.scheme()),
            None => format!("{}://{host}", url.scheme()),
        };
        let segments: Vec<&str> = url.path_segments().map(|s| s.collect()).unwrap_or_default();
        // An empty name (`/keys/` or `/keys//<version>`) cannot address a key, so reject it here
        // instead of sending a malformed request to Key Vault.
        let key_name = match segments.as_slice() {
            ["keys", name, ..] if !name.is_empty() => (*name).to_string(),
            _ => {
                return Err(EncryptionError::CmkError(format!(
                    "Invalid CMK path format: expected /keys/<name>[/<version>], got '{}'",
                    url.path()
                )));
            }
        };
        let key_version = segments
            .get(2)
            .filter(|version| !version.is_empty())
            .map(|version| (*version).to_string());
        Ok((vault_url, key_name, key_version))
    }

    /// Builds a Key Vault client for `vault_url` that authenticates with this provider's
    /// credential.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] if the client cannot be constructed.
    fn create_client(&self, vault_url: &str) -> Result<KeyClient, EncryptionError> {
        KeyClient::new(vault_url, self.credential.clone(), None).map_err(|e| {
            EncryptionError::CmkError(format!("Failed to create Key Vault client: {e}"))
        })
    }
}
impl std::fmt::Debug for AzureKeyVaultProvider {
    /// Formats the provider as `AzureKeyVaultProvider { provider_name: "AZURE_KEY_VAULT", .. }`.
    ///
    /// The credential is deliberately left out. `finish_non_exhaustive` marks that there are
    /// fields not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AzureKeyVaultProvider");
        builder.field("provider_name", &PROVIDER_NAME);
        builder.finish_non_exhaustive()
    }
}
#[async_trait::async_trait]
impl KeyStoreProvider for AzureKeyVaultProvider {
    /// Returns [`PROVIDER_NAME`].
    fn provider_name(&self) -> &str {
        PROVIDER_NAME
    }

    /// Decrypts (unwraps) an encrypted column encryption key with the Key Vault key named by
    /// `cmk_path`.
    ///
    /// `cmk_path` must include a key version (`.../keys/<name>/<version>`). Only the ciphertext
    /// extracted from `encrypted_cek` is sent to Key Vault's unwrap operation. Every check
    /// that needs no network access runs before a Key Vault client is constructed.
    ///
    /// # Errors
    ///
    /// - [`EncryptionError::CmkError`]: malformed CMK path, missing key version, or client
    ///   construction failure.
    /// - [`EncryptionError::ConfigurationError`]: unsupported `algorithm`.
    /// - [`EncryptionError::CekDecryptionFailed`]: malformed `encrypted_cek`, request
    ///   serialization failure, or a failed/empty unwrap.
    #[instrument(skip(self, encrypted_cek), fields(cmk_path = %cmk_path, algorithm = %algorithm))]
    async fn decrypt_cek(
        &self,
        cmk_path: &str,
        algorithm: &str,
        encrypted_cek: &[u8],
    ) -> Result<Vec<u8>, EncryptionError> {
        debug!("Decrypting CEK using Azure Key Vault");
        let (vault_url, key_name, key_version) = Self::parse_cmk_path(cmk_path)?;
        let version = key_version.ok_or_else(|| {
            EncryptionError::CmkError(
                "CMK path must include key version (e.g., /keys/<name>/<version>)".into(),
            )
        })?;
        let kv_algorithm = map_algorithm(algorithm)?;
        let ciphertext = parse_sql_server_encrypted_cek(encrypted_cek)?;
        // NOTE(review): only the ciphertext is taken from the envelope. Any signature it carries
        // is not verified here — confirm whether that check happens elsewhere.
        let parameters = KeyOperationParameters {
            algorithm: Some(kv_algorithm),
            value: Some(ciphertext.to_vec()),
            ..Default::default()
        };
        let request_content: RequestContent<KeyOperationParameters> =
            parameters.try_into().map_err(|e| {
                EncryptionError::CekDecryptionFailed(format!("Failed to create request: {e}"))
            })?;

        // All local validation passed; only now set up the Key Vault client.
        let client = self.create_client(&vault_url)?;
        let result = client
            .unwrap_key(&key_name, &version, request_content, None)
            .await
            .map_err(|e| {
                EncryptionError::CekDecryptionFailed(format!("Key Vault unwrap failed: {e}"))
            })?
            .into_model()
            .map_err(|e| {
                EncryptionError::CekDecryptionFailed(format!("Failed to parse response: {e}"))
            })?;
        let decrypted = result.result.ok_or_else(|| {
            EncryptionError::CekDecryptionFailed("Key Vault unwrap returned no result".into())
        })?;
        debug!("Successfully decrypted CEK using Azure Key Vault");
        Ok(decrypted)
    }

    /// Signs `data` with the Key Vault key named by `cmk_path`, using RS256.
    ///
    /// `data` is sent unchanged as the sign operation's value.
    /// NOTE(review): Key Vault's RS256 sign expects a pre-computed digest, so callers presumably
    /// pass a SHA-256 hash — confirm against the callers.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] in any of these cases:
    /// - the CMK path is malformed or lacks a version;
    /// - the request cannot be built;
    /// - the Key Vault call fails;
    /// - no signature is returned.
    #[instrument(skip(self, data), fields(cmk_path = %cmk_path))]
    async fn sign_data(&self, cmk_path: &str, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        use azure_security_keyvault_keys::models::{SignParameters, SignatureAlgorithm};

        debug!("Signing data using Azure Key Vault");
        let (vault_url, key_name, key_version) = Self::parse_cmk_path(cmk_path)?;
        let version = key_version.ok_or_else(|| {
            EncryptionError::CmkError("CMK path must include key version for sign operation".into())
        })?;
        let parameters = SignParameters {
            algorithm: Some(SignatureAlgorithm::Rs256),
            value: Some(data.to_vec()),
        };
        let request_content: RequestContent<SignParameters> = parameters
            .try_into()
            .map_err(|e| EncryptionError::CmkError(format!("Failed to create request: {e}")))?;

        // All local validation passed; only now set up the Key Vault client.
        let client = self.create_client(&vault_url)?;
        let result = client
            .sign(&key_name, &version, request_content, None)
            .await
            .map_err(|e| EncryptionError::CmkError(format!("Key Vault sign failed: {e}")))?
            .into_model()
            .map_err(|e| EncryptionError::CmkError(format!("Failed to parse response: {e}")))?;
        let signature = result
            .result
            .ok_or_else(|| EncryptionError::CmkError("Key Vault sign returned no result".into()))?;
        debug!("Successfully signed data using Azure Key Vault");
        Ok(signature)
    }

    /// Verifies an RS256 `signature` over `data` with the Key Vault key named by `cmk_path`.
    ///
    /// `data` is sent as the verify operation's `digest` field. A missing verdict in the
    /// response is treated as `false`.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] in any of these cases:
    /// - the CMK path is malformed or lacks a version;
    /// - the request cannot be built;
    /// - the Key Vault call fails.
    #[instrument(skip(self, data, signature), fields(cmk_path = %cmk_path))]
    async fn verify_signature(
        &self,
        cmk_path: &str,
        data: &[u8],
        signature: &[u8],
    ) -> Result<bool, EncryptionError> {
        use azure_security_keyvault_keys::models::{SignatureAlgorithm, VerifyParameters};

        debug!("Verifying signature using Azure Key Vault");
        let (vault_url, key_name, key_version) = Self::parse_cmk_path(cmk_path)?;
        let version = key_version.ok_or_else(|| {
            EncryptionError::CmkError(
                "CMK path must include key version for verify operation".into(),
            )
        })?;
        let parameters = VerifyParameters {
            algorithm: Some(SignatureAlgorithm::Rs256),
            digest: Some(data.to_vec()),
            signature: Some(signature.to_vec()),
        };
        let request_content: RequestContent<VerifyParameters> = parameters
            .try_into()
            .map_err(|e| EncryptionError::CmkError(format!("Failed to create request: {e}")))?;

        // All local validation passed; only now set up the Key Vault client.
        let client = self.create_client(&vault_url)?;
        let result = client
            .verify(&key_name, &version, request_content, None)
            .await
            .map_err(|e| EncryptionError::CmkError(format!("Key Vault verify failed: {e}")))?
            .into_model()
            .map_err(|e| EncryptionError::CmkError(format!("Failed to parse response: {e}")))?;
        let is_valid = result.value.unwrap_or(false);
        debug!("Signature verification result: {is_valid}");
        Ok(is_valid)
    }
}
/// Maps a key-encryption algorithm name to the Key Vault [`EncryptionAlgorithm`].
///
/// Matching is ASCII case-insensitive. Both underscore (`RSA_OAEP`) and hyphenated
/// (`RSA-OAEP`) spellings are accepted.
///
/// # Errors
///
/// Returns [`EncryptionError::ConfigurationError`] for any other name.
fn map_algorithm(algorithm: &str) -> Result<EncryptionAlgorithm, EncryptionError> {
    // Algorithm identifiers are pure ASCII. ASCII-only case folding stops non-ASCII look-alikes
    // from aliasing a valid name (e.g. `str::to_uppercase` maps `ſ` to `S`).
    match algorithm.to_ascii_uppercase().as_str() {
        "RSA_OAEP" | "RSA-OAEP" => Ok(EncryptionAlgorithm::RsaOaep),
        "RSA_OAEP_256" | "RSA-OAEP-256" => Ok(EncryptionAlgorithm::RsaOaep256),
        "RSA1_5" | "RSA-1_5" => Ok(EncryptionAlgorithm::Rsa1_5),
        _ => Err(EncryptionError::ConfigurationError(format!(
            "Unsupported key encryption algorithm: {algorithm}. Expected RSA_OAEP, RSA_OAEP_256, or RSA1_5"
        ))),
    }
}
/// Extracts the RSA ciphertext from a SQL Server Always Encrypted encrypted-CEK envelope.
///
/// The envelope layout is (integers little-endian):
///
/// ```text
/// version        1 byte, must be 0x01
/// key_path_len   u16
/// ciphertext_len u16
/// key_path       key_path_len bytes (informational only, skipped)
/// ciphertext     ciphertext_len bytes   <- returned
/// signature      remaining bytes        (not inspected here)
/// ```
///
/// Both lengths sit in the fixed 5-byte header, *before* the key path.
///
/// # Errors
///
/// Returns [`EncryptionError::CekDecryptionFailed`] in any of these cases:
/// - `data` is shorter than the 5-byte header;
/// - the version byte is not `0x01`;
/// - the declared ciphertext length is zero;
/// - the key path or ciphertext runs past the end of `data`.
fn parse_sql_server_encrypted_cek(data: &[u8]) -> Result<&[u8], EncryptionError> {
    // Fixed header: version (1) + key path length (2) + ciphertext length (2).
    let (version, key_path_len, ciphertext_len, body) = match data {
        [version, kp_lo, kp_hi, ct_lo, ct_hi, body @ ..] => (
            *version,
            usize::from(u16::from_le_bytes([*kp_lo, *kp_hi])),
            usize::from(u16::from_le_bytes([*ct_lo, *ct_hi])),
            body,
        ),
        _ => {
            return Err(EncryptionError::CekDecryptionFailed(
                "Encrypted CEK too short".into(),
            ));
        }
    };
    if version != 0x01 {
        return Err(EncryptionError::CekDecryptionFailed(format!(
            "Invalid CEK version: expected 0x01, got {version:#04x}"
        )));
    }
    // An RSA ciphertext is never empty; sending one to Key Vault would only fail remotely.
    if ciphertext_len == 0 {
        return Err(EncryptionError::CekDecryptionFailed(
            "Encrypted CEK contains no ciphertext".into(),
        ));
    }
    // The key path exists only for troubleshooting, so it is skipped without validation.
    let after_key_path = body.get(key_path_len..).ok_or_else(|| {
        EncryptionError::CekDecryptionFailed(
            "Encrypted CEK truncated: key path extends past end of data".into(),
        )
    })?;
    after_key_path.get(..ciphertext_len).ok_or_else(|| {
        EncryptionError::CekDecryptionFailed(format!(
            "Encrypted CEK truncated: expected {} bytes of ciphertext, got {}",
            ciphertext_len,
            after_key_path.len()
        ))
    })
}
#[cfg(test)]
// Tests work on known-good fixtures, so panicking on an unexpected `Err`/`None` is intended.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Serializes an encrypted-CEK envelope in SQL Server's layout:
    /// `version | key_path_len | ciphertext_len | key_path (UTF-16LE) | ciphertext | signature`.
    fn build_encrypted_cek(key_path: &str, ciphertext: &[u8], signature: &[u8]) -> Vec<u8> {
        let key_path_utf16: Vec<u8> = key_path
            .encode_utf16()
            .flat_map(|c| c.to_le_bytes())
            .collect();
        let mut data = vec![0x01];
        // Fixture sizes are tiny, so these u16 casts cannot truncate.
        data.extend_from_slice(&(key_path_utf16.len() as u16).to_le_bytes());
        data.extend_from_slice(&(ciphertext.len() as u16).to_le_bytes());
        data.extend_from_slice(&key_path_utf16);
        data.extend_from_slice(ciphertext);
        data.extend_from_slice(signature);
        data
    }

    #[test]
    fn test_parse_cmk_path() {
        let (vault, name, version) = AzureKeyVaultProvider::parse_cmk_path(
            "https://myvault.vault.azure.net/keys/mykey/abc123",
        )
        .expect("valid CMK path with version should parse");
        assert_eq!(vault, "https://myvault.vault.azure.net");
        assert_eq!(name, "mykey");
        assert_eq!(version, Some("abc123".to_string()));
        let (vault, name, version) =
            AzureKeyVaultProvider::parse_cmk_path("https://myvault.vault.azure.net/keys/mykey")
                .expect("valid CMK path without version should parse");
        assert_eq!(vault, "https://myvault.vault.azure.net");
        assert_eq!(name, "mykey");
        assert_eq!(version, None);
        let (vault, name, version) =
            AzureKeyVaultProvider::parse_cmk_path("https://myvault.vault.azure.net/keys/mykey/")
                .expect("valid CMK path with trailing slash should parse");
        assert_eq!(vault, "https://myvault.vault.azure.net");
        assert_eq!(name, "mykey");
        assert_eq!(version, None);
    }

    #[test]
    fn test_parse_cmk_path_invalid() {
        assert!(AzureKeyVaultProvider::parse_cmk_path("not-a-url").is_err());
        assert!(
            AzureKeyVaultProvider::parse_cmk_path("https://vault.azure.net/secrets/mysecret")
                .is_err()
        );
        assert!(AzureKeyVaultProvider::parse_cmk_path("https://vault.azure.net/keys").is_err());
    }

    #[test]
    fn test_map_algorithm() {
        assert!(matches!(
            map_algorithm("RSA_OAEP").expect("RSA_OAEP should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep
        ));
        assert!(matches!(
            map_algorithm("RSA-OAEP").expect("RSA-OAEP should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep
        ));
        assert!(matches!(
            map_algorithm("RSA_OAEP_256").expect("RSA_OAEP_256 should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep256
        ));
        assert!(matches!(
            map_algorithm("rsa_oaep").expect("lowercase rsa_oaep should be valid"),
            EncryptionAlgorithm::RsaOaep
        ));
        assert!(map_algorithm("UNKNOWN").is_err());
    }

    #[test]
    fn test_parse_sql_server_encrypted_cek() {
        let ciphertext = [0xAB, 0xCD, 0xEF];
        let signature = [0x11, 0x22, 0x33];
        let data = build_encrypted_cek(
            "https://myvault.vault.azure.net/keys/mykey/abc123",
            &ciphertext,
            &signature,
        );
        let parsed =
            parse_sql_server_encrypted_cek(&data).expect("valid encrypted CEK should parse");
        // Only the ciphertext is returned: neither the key path nor the signature leaks in.
        assert_eq!(parsed, &ciphertext[..]);

        let data = build_encrypted_cek("", &ciphertext, &[]);
        let parsed = parse_sql_server_encrypted_cek(&data)
            .expect("envelope with empty key path and no signature should parse");
        assert_eq!(parsed, &ciphertext[..]);
    }

    #[test]
    fn test_parse_sql_server_encrypted_cek_invalid() {
        // Shorter than the 5-byte fixed header.
        assert!(parse_sql_server_encrypted_cek(&[0x01, 0x00]).is_err());
        // Unsupported version byte.
        assert!(parse_sql_server_encrypted_cek(&[0x02, 0x00, 0x00, 0x00, 0x00]).is_err());
        // Zero-length ciphertext.
        assert!(parse_sql_server_encrypted_cek(&[0x01, 0x00, 0x00, 0x00, 0x00]).is_err());
        // Key path length (16) runs past the end of the data.
        assert!(parse_sql_server_encrypted_cek(&[0x01, 0x10, 0x00, 0x01, 0x00, 0xAA]).is_err());
        // Ciphertext length (4) runs past the end of the data.
        assert!(parse_sql_server_encrypted_cek(&[0x01, 0x00, 0x00, 0x04, 0x00, 0xAA]).is_err());
    }
}