mssql_auth/
azure_identity_auth.rs1use std::sync::Arc;
43use std::time::Duration;
44
45use azure_core::credentials::TokenCredential;
46use azure_identity::{
47 ClientSecretCredential, ManagedIdentityCredential, ManagedIdentityCredentialOptions,
48 UserAssignedId,
49};
50
51use crate::AzureAdAuth;
52use crate::error::AuthError;
53use crate::provider::{AuthData, AuthMethod};
54
/// OAuth 2.0 scope requested for every Azure SQL Database access token.
///
/// The `/.default` suffix asks Microsoft Entra ID for the permissions statically
/// granted to the identity on the `https://database.windows.net` resource.
const AZURE_SQL_SCOPE: &str = "https://database.windows.net/.default";

58#[derive(Clone)]
63pub struct ManagedIdentityAuth {
64 credential: Arc<ManagedIdentityCredential>,
65}
66
67impl ManagedIdentityAuth {
68 pub fn system_assigned() -> Result<Self, AuthError> {
77 let credential = ManagedIdentityCredential::new(None)
78 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
79 Ok(Self { credential })
80 }
81
82 pub fn user_assigned_client_id(client_id: impl Into<String>) -> Result<Self, AuthError> {
94 let options = ManagedIdentityCredentialOptions {
95 user_assigned_id: Some(UserAssignedId::ClientId(client_id.into())),
96 ..Default::default()
97 };
98 let credential = ManagedIdentityCredential::new(Some(options))
99 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
100 Ok(Self { credential })
101 }
102
103 pub fn user_assigned_resource_id(resource_id: impl Into<String>) -> Result<Self, AuthError> {
113 let options = ManagedIdentityCredentialOptions {
114 user_assigned_id: Some(UserAssignedId::ResourceId(resource_id.into())),
115 ..Default::default()
116 };
117 let credential = ManagedIdentityCredential::new(Some(options))
118 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
119 Ok(Self { credential })
120 }
121
122 pub fn user_assigned_object_id(object_id: impl Into<String>) -> Result<Self, AuthError> {
132 let options = ManagedIdentityCredentialOptions {
133 user_assigned_id: Some(UserAssignedId::ObjectId(object_id.into())),
134 ..Default::default()
135 };
136 let credential = ManagedIdentityCredential::new(Some(options))
137 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
138 Ok(Self { credential })
139 }
140
141 pub async fn get_token(&self) -> Result<String, AuthError> {
147 let token = self
148 .credential
149 .get_token(&[AZURE_SQL_SCOPE], None)
150 .await
151 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
152 Ok(token.token.secret().to_string())
153 }
154
155 pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
161 let token = self
162 .credential
163 .get_token(&[AZURE_SQL_SCOPE], None)
164 .await
165 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
166
167 let now = time::OffsetDateTime::now_utc();
169 let expires_in = if token.expires_on > now {
170 let diff = token.expires_on - now;
171 Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
172 } else {
173 None
174 };
175
176 Ok((token.token.secret().to_string(), expires_in))
177 }
178
179 pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
185 let (token, expires_in) = self.get_token_with_expiry().await?;
186 match expires_in {
187 Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
188 None => Ok(AzureAdAuth::with_token(token)),
189 }
190 }
191}
192
193impl std::fmt::Debug for ManagedIdentityAuth {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("ManagedIdentityAuth")
196 .finish_non_exhaustive()
197 }
198}
199
200impl crate::provider::AsyncAuthProvider for ManagedIdentityAuth {
201 fn method(&self) -> AuthMethod {
202 AuthMethod::AzureAd
203 }
204
205 async fn authenticate_async(&self) -> Result<AuthData, AuthError> {
206 let token = self.get_token().await?;
207 Ok(AuthData::FedAuth { token, nonce: None })
208 }
209
210 fn needs_refresh(&self) -> bool {
211 false
213 }
214}
215
216pub struct ServicePrincipalAuth {
221 credential: Arc<ClientSecretCredential>,
222}
223
224impl ServicePrincipalAuth {
225 pub fn new(
237 tenant_id: impl AsRef<str>,
238 client_id: impl Into<String>,
239 client_secret: impl Into<String>,
240 ) -> Result<Self, AuthError> {
241 use azure_core::credentials::Secret;
242
243 let secret = Secret::new(client_secret.into());
244 let credential =
245 ClientSecretCredential::new(tenant_id.as_ref(), client_id.into(), secret, None)
246 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
247 Ok(Self { credential })
248 }
249
250 pub async fn get_token(&self) -> Result<String, AuthError> {
256 let token = self
257 .credential
258 .get_token(&[AZURE_SQL_SCOPE], None)
259 .await
260 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
261 Ok(token.token.secret().to_string())
262 }
263
264 pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
270 let token = self
271 .credential
272 .get_token(&[AZURE_SQL_SCOPE], None)
273 .await
274 .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
275
276 let now = time::OffsetDateTime::now_utc();
278 let expires_in = if token.expires_on > now {
279 let diff = token.expires_on - now;
280 Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
281 } else {
282 None
283 };
284
285 Ok((token.token.secret().to_string(), expires_in))
286 }
287
288 pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
294 let (token, expires_in) = self.get_token_with_expiry().await?;
295 match expires_in {
296 Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
297 None => Ok(AzureAdAuth::with_token(token)),
298 }
299 }
300}
301
302impl Clone for ServicePrincipalAuth {
303 fn clone(&self) -> Self {
304 Self {
305 credential: Arc::clone(&self.credential),
306 }
307 }
308}
309
310impl std::fmt::Debug for ServicePrincipalAuth {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 f.debug_struct("ServicePrincipalAuth")
313 .field("credential", &"[REDACTED]")
314 .finish()
315 }
316}
317
318impl crate::provider::AsyncAuthProvider for ServicePrincipalAuth {
319 fn method(&self) -> AuthMethod {
320 AuthMethod::AzureAd
321 }
322
323 async fn authenticate_async(&self) -> Result<AuthData, AuthError> {
324 let token = self.get_token().await?;
325 Ok(AuthData::FedAuth { token, nonce: None })
326 }
327
328 fn needs_refresh(&self) -> bool {
329 false
331 }
332}
333
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Live check: acquires a token via the host's system-assigned managed identity.
    #[tokio::test]
    #[ignore = "Requires Azure Managed Identity environment"]
    async fn test_managed_identity_system_assigned() {
        let access_token = ManagedIdentityAuth::system_assigned()
            .expect("Failed to create credential")
            .get_token()
            .await
            .expect("Failed to get token");
        assert!(!access_token.is_empty());
    }

    /// Live check: acquires a token for a service principal read from the environment.
    #[tokio::test]
    #[ignore = "Requires Azure Service Principal credentials"]
    async fn test_service_principal() {
        let required = |name: &str, missing: &str| std::env::var(name).expect(missing);
        let tenant_id = required("AZURE_TENANT_ID", "AZURE_TENANT_ID not set");
        let client_id = required("AZURE_CLIENT_ID", "AZURE_CLIENT_ID not set");
        let client_secret = required("AZURE_CLIENT_SECRET", "AZURE_CLIENT_SECRET not set");

        let access_token = ServicePrincipalAuth::new(tenant_id, client_id, client_secret)
            .expect("Failed to create credential")
            .get_token()
            .await
            .expect("Failed to get token");
        assert!(!access_token.is_empty());
    }

    /// Offline check that `Debug` output never leaks credential material.
    #[test]
    fn test_debug_redacts_credentials() {
        // Managed identity construction may fail off-Azure; only inspect it on success.
        if let Ok(managed) = ManagedIdentityAuth::system_assigned() {
            let rendered = format!("{managed:?}");
            assert!(rendered.contains("ManagedIdentityAuth"));
        }

        const SECRET: &str = "client-secret-must-not-appear-in-debug";
        let principal = ServicePrincipalAuth::new(
            "00000000-0000-0000-0000-000000000000",
            "11111111-1111-1111-1111-111111111111",
            SECRET,
        )
        .expect("constructing a service principal credential should not fail offline");

        let rendered = format!("{principal:?}");
        assert!(
            !rendered.contains(SECRET),
            "Debug output must not expose the client secret"
        );
        assert!(rendered.contains("[REDACTED]"));
    }
}