1#![cfg_attr(not(feature = "keyvault"), allow(dead_code))]
8
9#[cfg(feature = "keyvault")]
10pub use client::KeyVaultClient;
11#[cfg(feature = "keyvault")]
12pub use client::{CredentialConfig, CredentialType};
13
14#[cfg(feature = "keyvault")]
15mod client {
16 use anyhow::{Context, Result};
17 use azure_core::credentials::{Secret as CoreSecret, TokenCredential};
18 use azure_identity::{ClientSecretCredential, DeveloperToolsCredential};
19 use azure_security_keyvault_secrets::{SecretClient, models::Secret as KeyVaultSecret};
20 use std::env;
21 use std::sync::Arc;
22 use tracing::{debug, warn};
23
/// Authentication strategy chosen for talking to KeyVault.
///
/// The enum carries no data, so it is `Copy` and `Hash` in addition to the
/// comparison traits; callers can pass it by value and use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CredentialType {
    /// Selected when tenant id, client id and client secret are all present.
    ClientSecret,
    /// Selected when no complete client-secret triple is configured and the
    /// fallback has not been disabled.
    DeveloperTools,
}
32
/// Inputs that decide how the KeyVault client authenticates.
///
/// `Debug` is implemented by hand (below) so that the client secret is never
/// written to logs or error output.
#[derive(Clone)]
pub struct CredentialConfig {
    /// Azure tenant id; required for client-secret authentication.
    pub tenant_id: Option<String>,
    /// Application (client) id; required for client-secret authentication.
    pub client_id: Option<String>,
    /// Application secret; required for client-secret authentication.
    pub client_secret: Option<String>,
    /// When `true` and no complete client-secret triple is present, credential
    /// resolution fails instead of falling back.
    pub disable_managed_identity: bool,
}

impl std::fmt::Debug for CredentialConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a secret is configured, never its contents.
        let redacted_secret = self.client_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("CredentialConfig")
            .field("tenant_id", &self.tenant_id)
            .field("client_id", &self.client_id)
            .field("client_secret", &redacted_secret)
            .field("disable_managed_identity", &self.disable_managed_identity)
            .finish()
    }
}
41
42 impl CredentialConfig {
43 pub fn from_env() -> Self {
45 let tenant_id = env::var("AZURE_TENANT_ID").ok();
46 let client_id = env::var("AZURE_CLIENT_ID").ok();
47 let client_secret = env::var("AZURE_CLIENT_SECRET").ok();
48 let disable_managed_identity =
49 env::var("AZURE_IDENTITY_DISABLE_MANAGED_IDENTITY_CREDENTIAL")
50 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
51 .unwrap_or(false);
52
53 Self {
54 tenant_id,
55 client_id,
56 client_secret,
57 disable_managed_identity,
58 }
59 }
60
61 pub fn resolve_credential_type(&self) -> Result<CredentialType> {
63 if self.tenant_id.is_some() && self.client_id.is_some() && self.client_secret.is_some()
64 {
65 Ok(CredentialType::ClientSecret)
66 } else if self.disable_managed_identity {
67 Err(anyhow::anyhow!(
68 "KeyVault enabled but no env credentials provided and managed identity disabled"
69 ))
70 } else {
71 Ok(CredentialType::DeveloperTools)
72 }
73 }
74
75 pub fn has_client_secret_credentials(&self) -> bool {
77 self.tenant_id.is_some() && self.client_id.is_some() && self.client_secret.is_some()
78 }
79
80 pub fn has_partial_credentials(&self) -> bool {
82 let count = [
83 self.tenant_id.is_some(),
84 self.client_id.is_some(),
85 self.client_secret.is_some(),
86 ]
87 .iter()
88 .filter(|&&v| v)
89 .count();
90 count > 0 && count < 3
91 }
92 }
93
94 pub struct KeyVaultClient {
96 client: SecretClient,
97 }
98
99 impl std::fmt::Debug for KeyVaultClient {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("KeyVaultClient")
102 .field("client", &"<SecretClient>")
103 .finish()
104 }
105 }
106
107 impl KeyVaultClient {
108 pub fn new(vault_url: &str) -> Result<Self> {
115 Self::with_config(vault_url, CredentialConfig::from_env())
116 }
117
118 pub fn with_config(vault_url: &str, config: CredentialConfig) -> Result<Self> {
120 if vault_url.trim().is_empty() {
121 return Err(anyhow::anyhow!("KeyVault URL cannot be empty"));
122 }
123
124 let credential = Self::create_credential(&config)?;
125 let client = SecretClient::new(vault_url, credential, None)?;
126
127 Ok(Self { client })
128 }
129
130 fn create_credential(config: &CredentialConfig) -> Result<Arc<dyn TokenCredential>> {
132 match config.resolve_credential_type()? {
133 CredentialType::ClientSecret => {
134 let tenant_id = config.tenant_id.as_ref().unwrap();
135 let client_id = config.client_id.clone().unwrap();
136 let client_secret = config.client_secret.clone().unwrap();
137
138 debug!(target: "keyvault", "using ClientSecretCredential from env");
139 Ok(ClientSecretCredential::new(
140 tenant_id,
141 client_id,
142 CoreSecret::new(client_secret),
143 None,
144 )?)
145 }
146 CredentialType::DeveloperTools => {
147 debug!(
148 target: "keyvault",
149 "using DeveloperToolsCredential (env creds not present)"
150 );
151 Ok(DeveloperToolsCredential::new(None)?)
152 }
153 }
154 }
155
156 pub async fn fetch_secret(&self, name: &str) -> Result<Option<String>> {
161 if name.trim().is_empty() {
162 return Ok(None);
163 }
164
165 let resp = self
166 .client
167 .get_secret(name, None)
168 .await
169 .with_context(|| format!("failed to fetch secret '{name}' from KeyVault"))?;
170
171 let body = resp.into_body();
172 let secret: KeyVaultSecret = body
173 .json()
174 .with_context(|| format!("invalid secret payload for '{name}'"))?;
175
176 if let Some(value) = secret.value {
177 debug!(target: "keyvault", secret_name = name, "fetched secret from KeyVault");
178 Ok(Some(value))
179 } else {
180 warn!(target: "keyvault", secret_name = name, "secret exists but has no value");
181 Ok(None)
182 }
183 }
184
185 pub async fn fetch_secrets(&self, names: &[&str]) -> Result<Vec<Option<String>>> {
190 use futures::future::try_join_all;
191
192 let futures = names.iter().map(|name| self.fetch_secret(name));
193 try_join_all(futures).await
194 }
195 }
196}
197
198#[cfg(all(test, feature = "keyvault"))]
199mod tests;
200
201#[cfg(not(feature = "keyvault"))]
202pub struct KeyVaultClient;
203
204#[cfg(not(feature = "keyvault"))]
205impl KeyVaultClient {
206 pub fn new(_vault_url: &str) -> anyhow::Result<Self> {
207 Err(anyhow::anyhow!(
208 "KeyVault feature not enabled. Rebuild with --features keyvault"
209 ))
210 }
211}