use std::sync::Arc;
use std::time::Duration;
use azure_core::credentials::TokenCredential;
use azure_identity::{
ClientSecretCredential, ManagedIdentityCredential, ManagedIdentityCredentialOptions,
UserAssignedId,
};
use crate::AzureAdAuth;
use crate::error::AuthError;
use crate::provider::{AuthData, AuthMethod};
const AZURE_SQL_SCOPE: &str = "https://database.windows.net/.default";
#[derive(Clone)]
pub struct ManagedIdentityAuth {
credential: Arc<ManagedIdentityCredential>,
}
impl ManagedIdentityAuth {
pub fn system_assigned() -> Result<Self, AuthError> {
let credential = ManagedIdentityCredential::new(None)
.map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
Ok(Self { credential })
}
pub fn user_assigned_client_id(client_id: impl Into<String>) -> Result<Self, AuthError> {
let options = ManagedIdentityCredentialOptions {
user_assigned_id: Some(UserAssignedId::ClientId(client_id.into())),
..Default::default()
};
let credential = ManagedIdentityCredential::new(Some(options))
.map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
Ok(Self { credential })
}
pub fn user_assigned_resource_id(resource_id: impl Into<String>) -> Result<Self, AuthError> {
let options = ManagedIdentityCredentialOptions {
user_assigned_id: Some(UserAssignedId::ResourceId(resource_id.into())),
..Default::default()
};
let credential = ManagedIdentityCredential::new(Some(options))
.map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
Ok(Self { credential })
}
pub fn user_assigned_object_id(object_id: impl Into<String>) -> Result<Self, AuthError> {
let options = ManagedIdentityCredentialOptions {
user_assigned_id: Some(UserAssignedId::ObjectId(object_id.into())),
..Default::default()
};
let credential = ManagedIdentityCredential::new(Some(options))
.map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
Ok(Self { credential })
}
pub async fn get_token(&self) -> Result<String, AuthError> {
let token = self
.credential
.get_token(&[AZURE_SQL_SCOPE], None)
.await
.map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
Ok(token.token.secret().to_string())
}
pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
let token = self
.credential
.get_token(&[AZURE_SQL_SCOPE], None)
.await
.map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
let now = time::OffsetDateTime::now_utc();
let expires_in = if token.expires_on > now {
let diff = token.expires_on - now;
Some(Duration::from_secs(diff.whole_seconds().max(0) as u64))
} else {
None
};
Ok((token.token.secret().to_string(), expires_in))
}
pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
let (token, expires_in) = self.get_token_with_expiry().await?;
match expires_in {
Some(duration) => Ok(AzureAdAuth::with_token_expiring(token, duration)),
None => Ok(AzureAdAuth::with_token(token)),
}
}
}
impl std::fmt::Debug for ManagedIdentityAuth {
    /// Renders as `ManagedIdentityAuth { .. }`; the credential handle is intentionally omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("ManagedIdentityAuth");
        builder.finish_non_exhaustive()
    }
}
impl crate::provider::AsyncAuthProvider for ManagedIdentityAuth {
    /// Managed identity authentication is reported as Azure AD authentication.
    fn method(&self) -> AuthMethod {
        AuthMethod::AzureAd
    }

    /// Requests a fresh Azure SQL token and packages it as federated-auth data without a nonce.
    async fn authenticate_async(&self) -> Result<AuthData, AuthError> {
        self.get_token()
            .await
            .map(|token| AuthData::FedAuth { token, nonce: None })
    }

    /// Always `false`: this provider keeps no token of its own between calls, since every
    /// `authenticate_async` asks the credential again.
    /// NOTE(review): presumably any caching is left to the credential — confirm.
    fn needs_refresh(&self) -> bool {
        false
    }
}
/// Azure AD authentication for Azure SQL using a service principal (tenant, client ID and
/// client secret).
pub struct ServicePrincipalAuth {
    credential: Arc<ClientSecretCredential>,
}

impl ServicePrincipalAuth {
    /// Creates a service principal authenticator for the given tenant and application.
    ///
    /// The client secret is wrapped in an `azure_core` `Secret` before being handed to the
    /// credential.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::AzureIdentity`] if the underlying credential cannot be created.
    pub fn new(
        tenant_id: impl AsRef<str>,
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
    ) -> Result<Self, AuthError> {
        use azure_core::credentials::Secret;
        let secret = Secret::new(client_secret.into());
        let credential =
            ClientSecretCredential::new(tenant_id.as_ref(), client_id.into(), secret, None)
                .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
        Ok(Self { credential })
    }

    /// Requests an access token for the Azure SQL scope
    /// (`https://database.windows.net/.default`) and returns its secret value.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::AzureIdentity`] if the credential fails to produce a token.
    pub async fn get_token(&self) -> Result<String, AuthError> {
        let (token, _expires_in) = self.get_token_with_expiry().await?;
        Ok(token)
    }

    /// Requests an access token for the Azure SQL scope together with its remaining lifetime.
    ///
    /// The lifetime is truncated to whole seconds. A token whose expiry is not in the future
    /// is reported as `Some(Duration::ZERO)` rather than `None`, so callers never mistake an
    /// expired token for one without an expiry.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::AzureIdentity`] if the credential fails to produce a token.
    pub async fn get_token_with_expiry(&self) -> Result<(String, Option<Duration>), AuthError> {
        let token = self
            .credential
            .get_token(&[AZURE_SQL_SCOPE], None)
            .await
            .map_err(|e| AuthError::AzureIdentity(e.to_string()))?;
        let remaining = token.expires_on - time::OffsetDateTime::now_utc();
        // A negative remaining time (already expired) saturates to zero instead of being dropped.
        let secs = u64::try_from(remaining.whole_seconds()).unwrap_or(0);
        Ok((token.token.secret().to_string(), Some(Duration::from_secs(secs))))
    }

    /// Requests a token and wraps it in an [`AzureAdAuth`], carrying its expiry when known.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::AzureIdentity`] if the credential fails to produce a token.
    pub async fn to_azure_ad_auth(&self) -> Result<AzureAdAuth, AuthError> {
        let (token, expires_in) = self.get_token_with_expiry().await?;
        Ok(match expires_in {
            Some(duration) => AzureAdAuth::with_token_expiring(token, duration),
            None => AzureAdAuth::with_token(token),
        })
    }
}
impl Clone for ServicePrincipalAuth {
    /// Returns a handle that shares the same underlying credential; only the `Arc` count grows.
    fn clone(&self) -> Self {
        let credential = Arc::clone(&self.credential);
        Self { credential }
    }
}
impl std::fmt::Debug for ServicePrincipalAuth {
    /// Prints the struct name with a placeholder in place of the credential, so the client
    /// secret held by the credential is never formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut out = f.debug_struct("ServicePrincipalAuth");
        out.field("credential", &REDACTED);
        out.finish()
    }
}
impl crate::provider::AsyncAuthProvider for ServicePrincipalAuth {
    /// Service principal authentication is reported as Azure AD authentication.
    fn method(&self) -> AuthMethod {
        AuthMethod::AzureAd
    }

    /// Requests a fresh Azure SQL token and returns it as federated-auth data with no nonce.
    async fn authenticate_async(&self) -> Result<AuthData, AuthError> {
        let fed_auth = AuthData::FedAuth {
            token: self.get_token().await?,
            nonce: None,
        };
        Ok(fed_auth)
    }

    /// Always `false`: no token is stored on this provider, because each
    /// `authenticate_async` call requests one from the credential.
    /// NOTE(review): presumably any caching is left to the credential — confirm.
    fn needs_refresh(&self) -> bool {
        false
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "Requires Azure Managed Identity environment"]
    async fn test_managed_identity_system_assigned() {
        let auth = ManagedIdentityAuth::system_assigned().expect("Failed to create credential");
        let token = auth.get_token().await.expect("Failed to get token");
        assert!(!token.is_empty());
    }

    #[tokio::test]
    #[ignore = "Requires Azure Service Principal credentials"]
    async fn test_service_principal() {
        let tenant_id = std::env::var("AZURE_TENANT_ID").expect("AZURE_TENANT_ID not set");
        let client_id = std::env::var("AZURE_CLIENT_ID").expect("AZURE_CLIENT_ID not set");
        let client_secret =
            std::env::var("AZURE_CLIENT_SECRET").expect("AZURE_CLIENT_SECRET not set");
        let auth = ServicePrincipalAuth::new(tenant_id, client_id, client_secret)
            .expect("Failed to create credential");
        let token = auth.get_token().await.expect("Failed to get token");
        assert!(!token.is_empty());
    }

    #[test]
    fn test_debug_redacts_credentials() {
        // Construction may fail outside a managed identity environment; only check Debug if it
        // succeeds.
        if let Ok(auth) = ManagedIdentityAuth::system_assigned() {
            let debug = format!("{auth:?}");
            assert!(debug.contains("ManagedIdentityAuth"));
        }
    }

    #[test]
    fn test_service_principal_debug_redacts_secret() {
        // ServicePrincipalAuth is the type that actually holds a secret, so verify that the
        // secret never shows up in its Debug output. Construction is tolerated to fail, as in
        // the test above.
        let secret = "not-a-real-client-secret-value";
        if let Ok(auth) = ServicePrincipalAuth::new(
            "00000000-0000-0000-0000-000000000000",
            "00000000-0000-0000-0000-000000000001",
            secret,
        ) {
            let debug = format!("{auth:?}");
            assert!(debug.contains("ServicePrincipalAuth"));
            assert!(debug.contains("[REDACTED]"));
            assert!(!debug.contains(secret));
        }
    }
}