Skip to main content

mx_keyvault/
lib.rs

//! Azure `KeyVault` client for secure secret resolution.
//!
//! This crate provides a reusable client for fetching secrets from Azure `KeyVault`.
//! It supports Client Secret credentials (read from environment variables) and the
//! Developer Tools credential (Azure CLI, VS Code, etc.) as a fallback.
6
7#![cfg_attr(not(feature = "keyvault"), allow(dead_code))]
8
9#[cfg(feature = "keyvault")]
10pub use client::KeyVaultClient;
11#[cfg(feature = "keyvault")]
12pub use client::{CredentialConfig, CredentialType};
13
14#[cfg(feature = "keyvault")]
15mod client {
16    use anyhow::{Context, Result};
17    use azure_core::credentials::{Secret as CoreSecret, TokenCredential};
18    use azure_identity::{ClientSecretCredential, DeveloperToolsCredential};
19    use azure_security_keyvault_secrets::{SecretClient, models::Secret as KeyVaultSecret};
20    use std::env;
21    use std::sync::Arc;
22    use tracing::{debug, warn};
23
24    /// Represents the type of credential being used.
25    #[derive(Debug, Clone, PartialEq, Eq)]
26    pub enum CredentialType {
27        /// Client secret credentials from environment variables.
28        ClientSecret,
29        /// Developer tools credential (Azure CLI, VS Code, etc.).
30        DeveloperTools,
31    }
32
33    /// Configuration for credential resolution.
34    #[derive(Debug, Clone)]
35    pub struct CredentialConfig {
36        pub tenant_id: Option<String>,
37        pub client_id: Option<String>,
38        pub client_secret: Option<String>,
39        pub disable_managed_identity: bool,
40    }
41
42    impl CredentialConfig {
43        /// Creates a new credential config from environment variables.
44        pub fn from_env() -> Self {
45            let tenant_id = env::var("AZURE_TENANT_ID").ok();
46            let client_id = env::var("AZURE_CLIENT_ID").ok();
47            let client_secret = env::var("AZURE_CLIENT_SECRET").ok();
48            let disable_managed_identity =
49                env::var("AZURE_IDENTITY_DISABLE_MANAGED_IDENTITY_CREDENTIAL")
50                    .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
51                    .unwrap_or(false);
52
53            Self {
54                tenant_id,
55                client_id,
56                client_secret,
57                disable_managed_identity,
58            }
59        }
60
61        /// Determines which credential type should be used based on the configuration.
62        pub fn resolve_credential_type(&self) -> Result<CredentialType> {
63            if self.tenant_id.is_some() && self.client_id.is_some() && self.client_secret.is_some()
64            {
65                Ok(CredentialType::ClientSecret)
66            } else if self.disable_managed_identity {
67                Err(anyhow::anyhow!(
68                    "KeyVault enabled but no env credentials provided and managed identity disabled"
69                ))
70            } else {
71                Ok(CredentialType::DeveloperTools)
72            }
73        }
74
75        /// Checks if all client secret credentials are present.
76        pub fn has_client_secret_credentials(&self) -> bool {
77            self.tenant_id.is_some() && self.client_id.is_some() && self.client_secret.is_some()
78        }
79
80        /// Checks if any partial client secret credentials are present.
81        pub fn has_partial_credentials(&self) -> bool {
82            let count = [
83                self.tenant_id.is_some(),
84                self.client_id.is_some(),
85                self.client_secret.is_some(),
86            ]
87            .iter()
88            .filter(|&&v| v)
89            .count();
90            count > 0 && count < 3
91        }
92    }
93
94    /// Client for fetching secrets from Azure `KeyVault`.
95    pub struct KeyVaultClient {
96        client: SecretClient,
97    }
98
99    impl std::fmt::Debug for KeyVaultClient {
100        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101            f.debug_struct("KeyVaultClient")
102                .field("client", &"<SecretClient>")
103                .finish()
104        }
105    }
106
107    impl KeyVaultClient {
108        /// Creates a new `KeyVault` client for the specified vault URL.
109        ///
110        /// Credentials are determined in the following order:
111        /// 1. Environment variables: `AZURE_TENANT_ID`, `AZURE_CLIENT_ID`, `AZURE_CLIENT_SECRET`
112        /// 2. Developer Tools credential (Azure CLI, VS Code, etc.) unless disabled
113        ///    via `AZURE_IDENTITY_DISABLE_MANAGED_IDENTITY_CREDENTIAL=true`
114        pub fn new(vault_url: &str) -> Result<Self> {
115            Self::with_config(vault_url, CredentialConfig::from_env())
116        }
117
118        /// Creates a new `KeyVault` client with explicit credential configuration.
119        pub fn with_config(vault_url: &str, config: CredentialConfig) -> Result<Self> {
120            if vault_url.trim().is_empty() {
121                return Err(anyhow::anyhow!("KeyVault URL cannot be empty"));
122            }
123
124            let credential = Self::create_credential(&config)?;
125            let client = SecretClient::new(vault_url, credential, None)?;
126
127            Ok(Self { client })
128        }
129
130        /// Creates Azure credential based on configuration.
131        fn create_credential(config: &CredentialConfig) -> Result<Arc<dyn TokenCredential>> {
132            match config.resolve_credential_type()? {
133                CredentialType::ClientSecret => {
134                    let tenant_id = config.tenant_id.as_ref().unwrap();
135                    let client_id = config.client_id.clone().unwrap();
136                    let client_secret = config.client_secret.clone().unwrap();
137
138                    debug!(target: "keyvault", "using ClientSecretCredential from env");
139                    Ok(ClientSecretCredential::new(
140                        tenant_id,
141                        client_id,
142                        CoreSecret::new(client_secret),
143                        None,
144                    )?)
145                }
146                CredentialType::DeveloperTools => {
147                    debug!(
148                        target: "keyvault",
149                        "using DeveloperToolsCredential (env creds not present)"
150                    );
151                    Ok(DeveloperToolsCredential::new(None)?)
152                }
153            }
154        }
155
156        /// Fetches a secret from `KeyVault` by name.
157        ///
158        /// Returns `Ok(None)` if the secret name is empty.
159        /// Returns `Err` if the secret fetch fails.
160        pub async fn fetch_secret(&self, name: &str) -> Result<Option<String>> {
161            if name.trim().is_empty() {
162                return Ok(None);
163            }
164
165            let resp = self
166                .client
167                .get_secret(name, None)
168                .await
169                .with_context(|| format!("failed to fetch secret '{name}' from KeyVault"))?;
170
171            let body = resp.into_body();
172            let secret: KeyVaultSecret = body
173                .json()
174                .with_context(|| format!("invalid secret payload for '{name}'"))?;
175
176            if let Some(value) = secret.value {
177                debug!(target: "keyvault", secret_name = name, "fetched secret from KeyVault");
178                Ok(Some(value))
179            } else {
180                warn!(target: "keyvault", secret_name = name, "secret exists but has no value");
181                Ok(None)
182            }
183        }
184
185        /// Fetches multiple secrets from `KeyVault` in parallel.
186        ///
187        /// Returns a vector of `Option<String>` in the same order as the input names.
188        /// Empty names will result in `None` values.
189        pub async fn fetch_secrets(&self, names: &[&str]) -> Result<Vec<Option<String>>> {
190            use futures::future::try_join_all;
191
192            let futures = names.iter().map(|name| self.fetch_secret(name));
193            try_join_all(futures).await
194        }
195    }
196}
197
198#[cfg(all(test, feature = "keyvault"))]
199mod tests;
200
201#[cfg(not(feature = "keyvault"))]
202pub struct KeyVaultClient;
203
204#[cfg(not(feature = "keyvault"))]
205impl KeyVaultClient {
206    pub fn new(_vault_url: &str) -> anyhow::Result<Self> {
207        Err(anyhow::anyhow!(
208            "KeyVault feature not enabled. Rebuild with --features keyvault"
209        ))
210    }
211}