1use std::sync::Arc;
58
59use azure_core::http::RequestContent;
60use azure_identity::DeveloperToolsCredential;
61use azure_security_keyvault_keys::KeyClient;
62use azure_security_keyvault_keys::models::{EncryptionAlgorithm, KeyOperationParameters};
63use tracing::{debug, instrument};
64use url::Url;
65
66use crate::encryption::{EncryptionError, KeyStoreProvider};
67
/// Name under which this key store provider identifies itself
/// (returned by `KeyStoreProvider::provider_name`).
const PROVIDER_NAME: &str = "AZURE_KEY_VAULT";

/// Host suffixes trusted by default for CMK paths.
///
/// Every entry starts with a dot, so a host must be a subdomain of the listed
/// zone (e.g. `myvault.vault.azure.net`) to be accepted.
const DEFAULT_TRUSTED_KEY_VAULT_SUFFIXES: &[&str] = &[
    // Azure public cloud
    ".vault.azure.net",
    ".vaultcore.azure.net",
    ".managedhsm.azure.net",
    // Azure China
    ".vault.azure.cn",
    ".managedhsm.azure.cn",
    // Azure US Government
    ".vault.usgovcloudapi.net",
    ".managedhsm.usgovcloudapi.net",
    // Azure Germany
    ".vault.microsoftazure.de",
];
90
91pub struct AzureKeyVaultProvider {
100 credential: Arc<DeveloperToolsCredential>,
102 trusted_host_suffixes: Vec<String>,
105}
106
107fn default_trusted_suffixes() -> Vec<String> {
108 DEFAULT_TRUSTED_KEY_VAULT_SUFFIXES
109 .iter()
110 .map(|s| (*s).to_string())
111 .collect()
112}
113
114impl AzureKeyVaultProvider {
115 pub fn new() -> Result<Self, EncryptionError> {
136 let credential = DeveloperToolsCredential::new(None).map_err(|e| {
137 EncryptionError::ConfigurationError(format!("Failed to create Azure credential: {e}"))
138 })?;
139 Ok(Self {
140 credential,
141 trusted_host_suffixes: default_trusted_suffixes(),
142 })
143 }
144
145 #[must_use]
158 pub fn with_credential(credential: Arc<DeveloperToolsCredential>) -> Self {
159 Self {
160 credential,
161 trusted_host_suffixes: default_trusted_suffixes(),
162 }
163 }
164
165 #[must_use]
176 pub fn with_trusted_endpoints<I, S>(mut self, suffixes: I) -> Self
177 where
178 I: IntoIterator<Item = S>,
179 S: Into<String>,
180 {
181 self.trusted_host_suffixes = suffixes.into_iter().map(Into::into).collect();
182 self
183 }
184
185 fn parse_cmk_path(
193 cmk_path: &str,
194 trusted_suffixes: &[String],
195 ) -> Result<(String, String, Option<String>), EncryptionError> {
196 let url = Url::parse(cmk_path).map_err(|e| {
197 EncryptionError::CmkError(format!("Invalid CMK path '{cmk_path}': {e}"))
198 })?;
199
200 if url.scheme() != "https" {
201 return Err(EncryptionError::CmkError(format!(
202 "CMK path must use https, got scheme '{}' in '{cmk_path}'",
203 url.scheme()
204 )));
205 }
206
207 let host = url
208 .host_str()
209 .ok_or_else(|| EncryptionError::CmkError("CMK path missing host".into()))?;
210 let host_lc = host.to_ascii_lowercase();
211 let trusted = trusted_suffixes
212 .iter()
213 .any(|suffix| host_lc.ends_with(&suffix.to_ascii_lowercase()));
214 if !trusted {
215 return Err(EncryptionError::CmkError(format!(
216 "CMK host '{host}' is not a trusted Key Vault endpoint. The CMK path is \
217 supplied by the server; allowing an arbitrary host would let a malicious \
218 server redirect key operations and exfiltrate access tokens. Trusted \
219 suffixes: {trusted_suffixes:?}. For custom deployments use \
220 AzureKeyVaultProvider::with_trusted_endpoints."
221 )));
222 }
223
224 let vault_url = format!("{}://{host}", url.scheme());
226
227 let segments: Vec<&str> = url.path_segments().map(|s| s.collect()).unwrap_or_default();
229
230 if segments.len() < 2 || segments[0] != "keys" {
231 return Err(EncryptionError::CmkError(format!(
232 "Invalid CMK path format: expected /keys/<name>[/<version>], got '{}'",
233 url.path()
234 )));
235 }
236
237 let key_name = segments[1].to_string();
238 let key_version = if segments.len() >= 3 && !segments[2].is_empty() {
239 Some(segments[2].to_string())
240 } else {
241 None
242 };
243
244 Ok((vault_url, key_name, key_version))
245 }
246
247 fn create_client(&self, vault_url: &str) -> Result<KeyClient, EncryptionError> {
249 KeyClient::new(vault_url, self.credential.clone(), None).map_err(|e| {
250 EncryptionError::CmkError(format!("Failed to create Key Vault client: {e}"))
251 })
252 }
253}
254
255impl std::fmt::Debug for AzureKeyVaultProvider {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 f.debug_struct("AzureKeyVaultProvider")
258 .field("provider_name", &PROVIDER_NAME)
259 .finish_non_exhaustive()
260 }
261}
262
263#[async_trait::async_trait]
264impl KeyStoreProvider for AzureKeyVaultProvider {
265 fn provider_name(&self) -> &str {
266 PROVIDER_NAME
267 }
268
269 #[instrument(skip(self, encrypted_cek), fields(cmk_path = %cmk_path, algorithm = %algorithm))]
270 async fn decrypt_cek(
271 &self,
272 cmk_path: &str,
273 algorithm: &str,
274 encrypted_cek: &[u8],
275 ) -> Result<Vec<u8>, EncryptionError> {
276 debug!("Decrypting CEK using Azure Key Vault");
277
278 let (vault_url, key_name, key_version) =
280 Self::parse_cmk_path(cmk_path, &self.trusted_host_suffixes)?;
281
282 let client = self.create_client(&vault_url)?;
284
285 let kv_algorithm = map_algorithm(algorithm)?;
287
288 let envelope = crate::cek_envelope::parse(encrypted_cek)?;
290
291 let digest = envelope.signed_digest();
294 let valid = self
295 .verify_signature(cmk_path, &digest, envelope.signature)
296 .await?;
297 if !valid {
298 return Err(EncryptionError::CekDecryptionFailed(
299 "CEK envelope signature verification failed".into(),
300 ));
301 }
302
303 let parameters = KeyOperationParameters {
305 algorithm: Some(kv_algorithm),
306 value: Some(envelope.ciphertext.to_vec()),
307 ..Default::default()
308 };
309
310 let version = key_version.ok_or_else(|| {
312 EncryptionError::CmkError(
313 "CMK path must include key version (e.g., /keys/<name>/<version>)".into(),
314 )
315 })?;
316
317 let request_content: RequestContent<KeyOperationParameters> =
319 parameters.try_into().map_err(|e| {
320 EncryptionError::CekDecryptionFailed(format!("Failed to create request: {e}"))
321 })?;
322
323 let result = client
325 .unwrap_key(&key_name, &version, request_content, None)
326 .await
327 .map_err(|e| {
328 EncryptionError::CekDecryptionFailed(format!("Key Vault unwrap failed: {e}"))
329 })?
330 .into_model()
331 .map_err(|e| {
332 EncryptionError::CekDecryptionFailed(format!("Failed to parse response: {e}"))
333 })?;
334
335 let decrypted = result.result.ok_or_else(|| {
337 EncryptionError::CekDecryptionFailed("Key Vault unwrap returned no result".into())
338 })?;
339
340 debug!("Successfully decrypted CEK using Azure Key Vault");
341 Ok(decrypted)
342 }
343
344 #[instrument(skip(self, data), fields(cmk_path = %cmk_path))]
345 async fn sign_data(&self, cmk_path: &str, data: &[u8]) -> Result<Vec<u8>, EncryptionError> {
346 debug!("Signing data using Azure Key Vault");
347
348 let (vault_url, key_name, key_version) =
350 Self::parse_cmk_path(cmk_path, &self.trusted_host_suffixes)?;
351
352 let client = self.create_client(&vault_url)?;
354
355 use azure_security_keyvault_keys::models::{
357 KeyClientSignOptions, SignParameters, SignatureAlgorithm,
358 };
359
360 let parameters = SignParameters {
361 algorithm: Some(SignatureAlgorithm::Rs256),
362 value: Some(data.to_vec()),
363 };
364
365 let version = key_version.ok_or_else(|| {
369 EncryptionError::CmkError("CMK path must include key version for sign operation".into())
370 })?;
371
372 let request_content: RequestContent<SignParameters> = parameters
373 .try_into()
374 .map_err(|e| EncryptionError::CmkError(format!("Failed to create request: {e}")))?;
375
376 let sign_options = KeyClientSignOptions {
377 key_version: Some(version),
378 ..Default::default()
379 };
380
381 let result = client
383 .sign(&key_name, request_content, Some(sign_options))
384 .await
385 .map_err(|e| EncryptionError::CmkError(format!("Key Vault sign failed: {e}")))?
386 .into_model()
387 .map_err(|e| EncryptionError::CmkError(format!("Failed to parse response: {e}")))?;
388
389 let signature = result
391 .result
392 .ok_or_else(|| EncryptionError::CmkError("Key Vault sign returned no result".into()))?;
393
394 debug!("Successfully signed data using Azure Key Vault");
395 Ok(signature)
396 }
397
398 #[instrument(skip(self, data, signature), fields(cmk_path = %cmk_path))]
399 async fn verify_signature(
400 &self,
401 cmk_path: &str,
402 data: &[u8],
403 signature: &[u8],
404 ) -> Result<bool, EncryptionError> {
405 debug!("Verifying signature using Azure Key Vault");
406
407 let (vault_url, key_name, key_version) =
409 Self::parse_cmk_path(cmk_path, &self.trusted_host_suffixes)?;
410
411 let client = self.create_client(&vault_url)?;
413
414 use azure_security_keyvault_keys::models::{SignatureAlgorithm, VerifyParameters};
416
417 let parameters = VerifyParameters {
418 algorithm: Some(SignatureAlgorithm::Rs256),
419 digest: Some(data.to_vec()),
420 signature: Some(signature.to_vec()),
421 };
422
423 let version = key_version.ok_or_else(|| {
425 EncryptionError::CmkError(
426 "CMK path must include key version for verify operation".into(),
427 )
428 })?;
429
430 let request_content: RequestContent<VerifyParameters> = parameters
431 .try_into()
432 .map_err(|e| EncryptionError::CmkError(format!("Failed to create request: {e}")))?;
433
434 let result = client
436 .verify(&key_name, &version, request_content, None)
437 .await
438 .map_err(|e| EncryptionError::CmkError(format!("Key Vault verify failed: {e}")))?
439 .into_model()
440 .map_err(|e| EncryptionError::CmkError(format!("Failed to parse response: {e}")))?;
441
442 let is_valid = result.value.unwrap_or(false);
445
446 debug!("Signature verification result: {}", is_valid);
447 Ok(is_valid)
448 }
449}
450
451fn map_algorithm(algorithm: &str) -> Result<EncryptionAlgorithm, EncryptionError> {
453 match algorithm.to_uppercase().as_str() {
454 "RSA_OAEP" | "RSA-OAEP" => Ok(EncryptionAlgorithm::RsaOaep),
455 "RSA_OAEP_256" | "RSA-OAEP-256" => Ok(EncryptionAlgorithm::RsaOaep256),
456 "RSA1_5" | "RSA-1_5" => Ok(EncryptionAlgorithm::Rsa1_5),
457 _ => Err(EncryptionError::ConfigurationError(format!(
458 "Unsupported key encryption algorithm: {algorithm}. Expected RSA_OAEP, RSA_OAEP_256, or RSA1_5"
459 ))),
460 }
461}
462
#[cfg(test)]
// Tests deliberately fail by panicking on unexpected results.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Default trust list, as used by providers built with `new`/`with_credential`.
    fn trusted() -> Vec<String> {
        default_trusted_suffixes()
    }

    #[test]
    fn test_parse_cmk_path() {
        let (vault, name, version) = AzureKeyVaultProvider::parse_cmk_path(
            "https://myvault.vault.azure.net/keys/mykey/abc123",
            &trusted(),
        )
        .expect("valid CMK path with version should parse");
        assert_eq!(vault, "https://myvault.vault.azure.net");
        assert_eq!(name, "mykey");
        assert_eq!(version, Some("abc123".to_string()));

        // No version segment.
        let (vault, name, version) = AzureKeyVaultProvider::parse_cmk_path(
            "https://myvault.vault.azure.net/keys/mykey",
            &trusted(),
        )
        .expect("valid CMK path without version should parse");
        assert_eq!(vault, "https://myvault.vault.azure.net");
        assert_eq!(name, "mykey");
        assert_eq!(version, None);

        // Trailing slash: the empty segment after the name means "no version".
        let (vault, name, version) = AzureKeyVaultProvider::parse_cmk_path(
            "https://myvault.vault.azure.net/keys/mykey/",
            &trusted(),
        )
        .expect("valid CMK path with trailing slash should parse");
        assert_eq!(vault, "https://myvault.vault.azure.net");
        assert_eq!(name, "mykey");
        assert_eq!(version, None);

        // Managed HSM and sovereign-cloud hosts are in the default trust list.
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://myhsm.managedhsm.azure.net/keys/mykey",
                &trusted(),
            )
            .is_ok()
        );
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://myvault.vault.usgovcloudapi.net/keys/mykey",
                &trusted(),
            )
            .is_ok()
        );
    }

    #[test]
    fn test_parse_cmk_path_invalid() {
        assert!(AzureKeyVaultProvider::parse_cmk_path("not-a-url", &trusted()).is_err());

        // Wrong collection (`secrets` instead of `keys`).
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://myvault.vault.azure.net/secrets/mysecret",
                &trusted(),
            )
            .is_err()
        );

        // Missing key name.
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://myvault.vault.azure.net/keys",
                &trusted(),
            )
            .is_err()
        );
    }

    #[test]
    fn test_parse_cmk_path_rejects_untrusted_host() {
        let err = AzureKeyVaultProvider::parse_cmk_path(
            "https://attacker.example.com/keys/mykey",
            &trusted(),
        )
        .expect_err("untrusted host must be rejected");
        assert!(err.to_string().contains("not a trusted Key Vault endpoint"));

        // A trusted zone used as a *prefix* of an attacker domain must not pass.
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://vault.azure.net.attacker.com/keys/mykey",
                &trusted(),
            )
            .is_err()
        );

        // Plain http is rejected even for a trusted host.
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "http://myvault.vault.azure.net/keys/mykey",
                &trusted(),
            )
            .is_err()
        );
    }

    #[test]
    fn test_with_trusted_endpoints_override() {
        let custom = vec![".vault.contoso.example".to_string()];
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://kv1.vault.contoso.example/keys/mykey",
                &custom,
            )
            .is_ok()
        );
        // A custom list replaces the defaults rather than extending them.
        assert!(
            AzureKeyVaultProvider::parse_cmk_path(
                "https://myvault.vault.azure.net/keys/mykey",
                &custom,
            )
            .is_err()
        );
    }

    #[test]
    fn test_map_algorithm() {
        assert!(matches!(
            map_algorithm("RSA_OAEP").expect("RSA_OAEP should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep
        ));
        assert!(matches!(
            map_algorithm("RSA-OAEP").expect("RSA-OAEP should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep
        ));
        assert!(matches!(
            map_algorithm("RSA_OAEP_256").expect("RSA_OAEP_256 should be a valid algorithm"),
            EncryptionAlgorithm::RsaOaep256
        ));
        // Matching is case-insensitive.
        assert!(matches!(
            map_algorithm("rsa_oaep").expect("lowercase rsa_oaep should be valid"),
            EncryptionAlgorithm::RsaOaep
        ));
        assert!(map_algorithm("UNKNOWN").is_err());
    }

    /// Live round trip: sign an envelope with Key Vault, then decrypt it.
    ///
    /// Requires `AZURE_KEYVAULT_CMK_PATH`, `AEKV_PLAIN_CEK_HEX` and
    /// `AEKV_WRAPPED_CEK_HEX`; without them the test prints a skip notice and
    /// returns.
    #[tokio::test]
    #[ignore = "Requires a live Azure Key Vault + az session (see env vars)"]
    async fn decrypt_cek_round_trips_through_live_key_vault() {
        use sha2::Digest;

        /// Decodes a hex string, panicking with a clear message on bad input.
        fn from_hex(s: &str) -> Vec<u8> {
            s.as_bytes()
                .chunks(2)
                .map(|pair| {
                    let pair = std::str::from_utf8(pair).expect("hex digits must be ASCII");
                    assert_eq!(pair.len(), 2, "hex string must have an even number of digits");
                    u8::from_str_radix(pair, 16).expect("valid hex")
                })
                .collect()
        }

        let (cmk_path, cek, ciphertext) = match (
            std::env::var("AZURE_KEYVAULT_CMK_PATH").ok(),
            std::env::var("AEKV_PLAIN_CEK_HEX").ok(),
            std::env::var("AEKV_WRAPPED_CEK_HEX").ok(),
        ) {
            (Some(p), Some(plain), Some(wrapped)) => (p, from_hex(&plain), from_hex(&wrapped)),
            _ => {
                // Make the skip visible: running with `--ignored` but without the
                // env vars would otherwise report a silent, vacuous pass.
                eprintln!(
                    "skipping live Key Vault test: set AZURE_KEYVAULT_CMK_PATH, \
                     AEKV_PLAIN_CEK_HEX and AEKV_WRAPPED_CEK_HEX to run it"
                );
                return;
            }
        };

        let provider = AzureKeyVaultProvider::new().expect("provider");

        let signed_portion = crate::cek_envelope::build_signed_portion(&cmk_path, &ciphertext);
        let digest: [u8; 32] = sha2::Sha256::digest(&signed_portion).into();
        let signature = provider
            .sign_data(&cmk_path, &digest)
            .await
            .expect("Key Vault RS256 sign");
        let mut envelope = signed_portion;
        envelope.extend_from_slice(&signature);

        let decrypted = provider
            .decrypt_cek(&cmk_path, "RSA_OAEP", &envelope)
            .await
            .expect("decrypt_cek must verify + unwrap via Key Vault");

        assert_eq!(
            decrypted, cek,
            "Key-Vault-unwrapped CEK must equal the original plaintext CEK"
        );
    }
}