Skip to main content

mssql_auth/
azure_identity_auth.rs

//! Azure Identity authentication providers.
//!
//! This module authenticates to Azure SQL Database using access tokens acquired through
//! the `azure_identity` crate. It supports:
//!
//! - **Managed Identity**: For Azure VMs, App Service, Container Instances, and AKS
//! - **Service Principal**: For application-based authentication with client credentials
//!
//! ## Example: Managed Identity (System-Assigned)
//!
//! ```rust,ignore
//! use mssql_auth::ManagedIdentityAuth;
//!
//! // System-assigned managed identity (default)
//! let auth = ManagedIdentityAuth::system_assigned()?;
//! let token = auth.get_token().await?;
//! ```
//!
//! ## Example: Managed Identity (User-Assigned)
//!
//! ```rust,ignore
//! use mssql_auth::ManagedIdentityAuth;
//!
//! // User-assigned managed identity selected by client ID
//! let auth = ManagedIdentityAuth::user_assigned_client_id("your-client-id")?;
//! let token = auth.get_token().await?;
//! ```
//!
//! ## Example: Service Principal
//!
//! ```rust,ignore
//! use mssql_auth::ServicePrincipalAuth;
//!
//! let auth = ServicePrincipalAuth::new(
//!     "your-tenant-id",
//!     "your-client-id",
//!     "your-client-secret",
//! )?;
//! let token = auth.get_token().await?;
//! ```

42use std::sync::Arc;
43use std::time::Duration;
44
45use azure_core::credentials::TokenCredential;
46use azure_identity::{
47    ClientSecretCredential, ManagedIdentityCredential, ManagedIdentityCredentialOptions,
48    UserAssignedId,
49};
50
51use crate::AzureAdAuth;
52use crate::error::AuthError;
53use crate::provider::{AuthData, AuthMethod};
54
/// OAuth 2.0 scope passed on every token request made by the providers in this module:
/// the `.default` scope of the Azure SQL Database resource.
const AZURE_SQL_SCOPE: &str = "https://database.windows.net/.default";
57
58/// Managed Identity authentication provider.
59///
60/// Uses Azure Managed Identity to acquire access tokens for Azure SQL Database.
61/// This works on Azure VMs, App Service, Container Instances, and AKS.
62#[derive(Clone)]
63pub struct ManagedIdentityAuth {
64    credential: Arc<ManagedIdentityCredential>,
65}
66
67impl ManagedIdentityAuth {
68    /// Create authentication using system-assigned managed identity.
69    ///
70    /// This is the simplest form - uses the identity assigned to the Azure resource
71    /// (VM, App Service, etc.) that the code is running on.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the managed identity credential cannot be created.
76    pub fn system_assigned() -> Result<Self, AuthError> {
77        let credential = ManagedIdentityCredential::new(None)
78            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
79        Ok(Self { credential })
80    }
81
82    /// Create authentication using a user-assigned managed identity by client ID.
83    ///
84    /// Use this when you have multiple managed identities and need to specify which one to use.
85    ///
86    /// # Arguments
87    ///
88    /// * `client_id` - The client ID of the user-assigned managed identity
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if the managed identity credential cannot be created.
93    pub fn user_assigned_client_id(client_id: impl Into<String>) -> Result<Self, AuthError> {
94        let options = ManagedIdentityCredentialOptions {
95            user_assigned_id: Some(UserAssignedId::ClientId(client_id.into())),
96            ..Default::default()
97        };
98        let credential = ManagedIdentityCredential::new(Some(options))
99            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
100        Ok(Self { credential })
101    }
102
103    /// Create authentication using a user-assigned managed identity by resource ID.
104    ///
105    /// # Arguments
106    ///
107    /// * `resource_id` - The Azure resource ID of the user-assigned managed identity
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if the managed identity credential cannot be created.
112    pub fn user_assigned_resource_id(resource_id: impl Into<String>) -> Result<Self, AuthError> {
113        let options = ManagedIdentityCredentialOptions {
114            user_assigned_id: Some(UserAssignedId::ResourceId(resource_id.into())),
115            ..Default::default()
116        };
117        let credential = ManagedIdentityCredential::new(Some(options))
118            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
119        Ok(Self { credential })
120    }
121
122    /// Create authentication using a user-assigned managed identity by object ID.
123    ///
124    /// # Arguments
125    ///
126    /// * `object_id` - The object ID of the user-assigned managed identity
127    ///
128    /// # Errors
129    ///
130    /// Returns an error if the managed identity credential cannot be created.
131    pub fn user_assigned_object_id(object_id: impl Into<String>) -> Result<Self, AuthError> {
132        let options = ManagedIdentityCredentialOptions {
133            user_assigned_id: Some(UserAssignedId::ObjectId(object_id.into())),
134            ..Default::default()
135        };
136        let credential = ManagedIdentityCredential::new(Some(options))
137            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
138        Ok(Self { credential })
139    }
140
141    /// Get an access token for Azure SQL Database.
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if token acquisition fails.
146    pub async fn get_token(&self) -> Result<String, AuthError> {
147        let token = self
148            .credential
149            .get_token(&[AZURE_SQL_SCOPE], None)
150            .await
151            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
152        Ok(token.token.secret().to_string())
153    }
154
155    /// Get an access token with expiration information.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if token acquisition fails.
160    pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
161        let token = self
162            .credential
163            .get_token(&[AZURE_SQL_SCOPE], None)
164            .await
165            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
166
167        // Calculate time until expiration
168        let now = time::OffsetDateTime::now_utc();
169        let expires_in = if token.expires_on > now {
170            let diff = token.expires_on - now;
171            Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
172        } else {
173            None
174        };
175
176        Ok((token.token.secret().to_string(), expires_in))
177    }
178
179    /// Convert to an `AzureAdAuth` provider with an acquired token.
180    ///
181    /// # Errors
182    ///
183    /// Returns an error if token acquisition fails.
184    pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
185        let (token, expires_in) = self.get_token_with_expiry().await?;
186        match expires_in {
187            Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
188            None => Ok(AzureAdAuth::with_token(token)),
189        }
190    }
191}
192
193impl std::fmt::Debug for ManagedIdentityAuth {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("ManagedIdentityAuth")
196            .finish_non_exhaustive()
197    }
198}
199
200impl crate::provider::AsyncAuthProvider for ManagedIdentityAuth {
201    fn method(&self) -> AuthMethod {
202        AuthMethod::AzureAd
203    }
204
205    async fn authenticate_async(&self) -> Result<AuthData, AuthError> {
206        let token = self.get_token().await?;
207        Ok(AuthData::FedAuth { token, nonce: None })
208    }
209
210    fn needs_refresh(&self) -> bool {
211        // Managed identity tokens are acquired fresh each time
212        false
213    }
214}
215
216/// Service Principal authentication provider.
217///
218/// Uses Azure Service Principal (application credentials) to acquire access tokens.
219/// This is suitable for server-to-server authentication where no user is present.
220pub struct ServicePrincipalAuth {
221    credential: Arc<ClientSecretCredential>,
222}
223
224impl ServicePrincipalAuth {
225    /// Create a new Service Principal authenticator.
226    ///
227    /// # Arguments
228    ///
229    /// * `tenant_id` - The Azure AD tenant ID
230    /// * `client_id` - The application (client) ID
231    /// * `client_secret` - The client secret
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if the credential cannot be created.
236    pub fn new(
237        tenant_id: impl AsRef<str>,
238        client_id: impl Into<String>,
239        client_secret: impl Into<String>,
240    ) -> Result<Self, AuthError> {
241        use azure_core::credentials::Secret;
242
243        let secret = Secret::new(client_secret.into());
244        let credential =
245            ClientSecretCredential::new(tenant_id.as_ref(), client_id.into(), secret, None)
246                .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
247        Ok(Self { credential })
248    }
249
250    /// Get an access token for Azure SQL Database.
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if token acquisition fails.
255    pub async fn get_token(&self) -> Result<String, AuthError> {
256        let token = self
257            .credential
258            .get_token(&[AZURE_SQL_SCOPE], None)
259            .await
260            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
261        Ok(token.token.secret().to_string())
262    }
263
264    /// Get an access token with expiration information.
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if token acquisition fails.
269    pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
270        let token = self
271            .credential
272            .get_token(&[AZURE_SQL_SCOPE], None)
273            .await
274            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
275
276        // Calculate time until expiration
277        let now = time::OffsetDateTime::now_utc();
278        let expires_in = if token.expires_on > now {
279            let diff = token.expires_on - now;
280            Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
281        } else {
282            None
283        };
284
285        Ok((token.token.secret().to_string(), expires_in))
286    }
287
288    /// Convert to an `AzureAdAuth` provider with an acquired token.
289    ///
290    /// # Errors
291    ///
292    /// Returns an error if token acquisition fails.
293    pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
294        let (token, expires_in) = self.get_token_with_expiry().await?;
295        match expires_in {
296            Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
297            None => Ok(AzureAdAuth::with_token(token)),
298        }
299    }
300}
301
302impl Clone for ServicePrincipalAuth {
303    fn clone(&self) -> Self {
304        Self {
305            credential: Arc::clone(&self.credential),
306        }
307    }
308}
309
310impl std::fmt::Debug for ServicePrincipalAuth {
311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        f.debug_struct("ServicePrincipalAuth")
313            .field("credential", &"[REDACTED]")
314            .finish()
315    }
316}
317
318impl crate::provider::AsyncAuthProvider for ServicePrincipalAuth {
319    fn method(&self) -> AuthMethod {
320        AuthMethod::AzureAd
321    }
322
323    async fn authenticate_async(&self) -> Result<AuthData, AuthError> {
324        let token = self.get_token().await?;
325        Ok(AuthData::FedAuth { token, nonce: None })
326    }
327
328    fn needs_refresh(&self) -> bool {
329        // Service principal tokens are acquired fresh each time
330        false
331    }
332}
333
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // The live tests below need real Azure credentials, so they are `#[ignore]`d by
    // default. Run them explicitly with:
    //   cargo test --features azure-identity -- --ignored

    #[tokio::test]
    #[ignore = "Requires Azure Managed Identity environment"]
    async fn test_managed_identity_system_assigned() {
        let token = ManagedIdentityAuth::system_assigned()
            .expect("Failed to create credential")
            .get_token()
            .await
            .expect("Failed to get token");
        assert!(!token.is_empty());
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal credentials"]
    async fn test_service_principal() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let client_secret =
            std::env::var("AZURE_CLIENT_SECRET").expect("AZURE_CLIENT_SECRET not set");

        let token = ServicePrincipalAuth::new(tenant_id, client_id, client_secret)
            .expect("Failed to create credential")
            .get_token()
            .await
            .expect("Failed to get token");
        assert!(!token.is_empty());
    }

    #[test]
    fn test_debug_redacts_credentials() {
        // Creating the managed identity credential is allowed to fail here; the Debug
        // output is only checked when construction succeeds.
        if let Ok(managed) = ManagedIdentityAuth::system_assigned() {
            assert!(format!("{managed:?}").contains("ManagedIdentityAuth"));
        }

        // The client secret must never appear in the Debug output.
        const SECRET: &str = "client-secret-must-not-appear-in-debug";
        let principal = ServicePrincipalAuth::new(
            "00000000-0000-0000-0000-000000000000",
            "11111111-1111-1111-1111-111111111111",
            SECRET,
        )
        .expect("constructing a service principal credential should not fail offline");
        let rendered = format!("{principal:?}");
        assert!(
            !rendered.contains(SECRET),
            "Debug output must not expose the client secret"
        );
        assert!(rendered.contains("[REDACTED]"));
    }
}