Skip to main content

aster/providers/
azureauth.rs

1use chrono;
2use serde::Deserialize;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6
7/// Represents errors that can occur during Azure authentication.
8#[derive(Debug, thiserror::Error)]
9pub enum AuthError {
10    /// Error when loading credentials from the filesystem or environment
11    #[error("Failed to load credentials: {0}")]
12    Credentials(String),
13
14    /// Error during token exchange
15    #[error("Token exchange failed: {0}")]
16    TokenExchange(String),
17}
18
/// An authentication token together with its scheme.
///
/// `Debug` is implemented by hand so the secret in `token_value` (an access token or an
/// API key) never ends up in logs or panic messages; only the token type is shown.
#[derive(Clone)]
pub struct AuthToken {
    /// The token scheme (e.g., "Bearer")
    pub token_type: String,
    /// The secret token value
    pub token_value: String,
}

impl std::fmt::Debug for AuthToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthToken")
            .field("token_type", &self.token_type)
            .field("token_value", &"<redacted>")
            .finish()
    }
}
27
/// The kinds of Azure credentials this provider can authenticate with.
///
/// `Debug` is implemented by hand so an API key is never printed. That also keeps it out
/// of the derived `Debug` output of any type that embeds these credentials.
#[derive(Clone)]
pub enum AzureCredentials {
    /// Authenticate with a static API key
    ApiKey(String),
    /// Azure credential chain based authentication (tokens come from the Azure CLI)
    DefaultCredential,
}

impl std::fmt::Debug for AzureCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"<redacted>").finish(),
            Self::DefaultCredential => f.write_str("DefaultCredential"),
        }
    }
}
36
37/// Holds a cached token and its expiration time.
38#[derive(Debug, Clone)]
39struct CachedToken {
40    token: AuthToken,
41    expires_at: Instant,
42}
43
44/// Response from Azure token endpoint
45#[derive(Debug, Clone, Deserialize)]
46struct TokenResponse {
47    #[serde(rename = "accessToken")]
48    access_token: String,
49    #[serde(rename = "tokenType")]
50    token_type: String,
51    #[serde(rename = "expires_on")]
52    expires_on: u64,
53}
54
55/// Azure authentication handler that manages credentials and token caching.
56#[derive(Debug)]
57pub struct AzureAuth {
58    credentials: AzureCredentials,
59    cached_token: Arc<RwLock<Option<CachedToken>>>,
60}
61
62impl AzureAuth {
63    /// Creates a new Azure authentication handler.
64    ///
65    /// Initializes the authentication handler by:
66    /// 1. Loading credentials from environment
67    /// 2. Setting up an HTTP client for token requests
68    /// 3. Initializing the token cache
69    ///
70    /// # Returns
71    /// * `Result<Self, AuthError>` - A new AzureAuth instance or an error if initialization fails
72    pub fn new(api_key: Option<String>) -> Result<Self, AuthError> {
73        let credentials = match api_key {
74            Some(key) => AzureCredentials::ApiKey(key),
75            None => AzureCredentials::DefaultCredential,
76        };
77
78        Ok(Self {
79            credentials,
80            cached_token: Arc::new(RwLock::new(None)),
81        })
82    }
83
84    /// Returns the type of credentials being used.
85    pub fn credential_type(&self) -> &AzureCredentials {
86        &self.credentials
87    }
88
89    /// Retrieves a valid authentication token.
90    ///
91    /// This method implements an efficient token management strategy:
92    /// 1. For API key auth, returns the API key directly
93    /// 2. For Azure credential chain:
94    ///    a. Checks the cache for a valid token
95    ///    b. Returns the cached token if not expired
96    ///    c. Obtains a new token if needed or expired
97    ///    d. Uses double-checked locking for thread safety
98    ///
99    /// # Returns
100    /// * `Result<AuthToken, AuthError>` - A valid authentication token or an error
101    pub async fn get_token(&self) -> Result<AuthToken, AuthError> {
102        match &self.credentials {
103            AzureCredentials::ApiKey(key) => Ok(AuthToken {
104                token_type: "Bearer".to_string(),
105                token_value: key.clone(),
106            }),
107            AzureCredentials::DefaultCredential => self.get_default_credential_token().await,
108        }
109    }
110
111    async fn get_default_credential_token(&self) -> Result<AuthToken, AuthError> {
112        // Try read lock first for better concurrency
113        if let Some(cached) = self.cached_token.read().await.as_ref() {
114            if cached.expires_at > Instant::now() {
115                return Ok(cached.token.clone());
116            }
117        }
118
119        // Take write lock only if needed
120        let mut token_guard = self.cached_token.write().await;
121
122        // Double-check expiration after acquiring write lock
123        if let Some(cached) = token_guard.as_ref() {
124            if cached.expires_at > Instant::now() {
125                return Ok(cached.token.clone());
126            }
127        }
128
129        // Get new token using Azure CLI credential
130        let output = tokio::process::Command::new("az")
131            .args([
132                "account",
133                "get-access-token",
134                "--resource",
135                "https://cognitiveservices.azure.com",
136            ])
137            .output()
138            .await
139            .map_err(|e| AuthError::TokenExchange(format!("Failed to execute Azure CLI: {}", e)))?;
140
141        if !output.status.success() {
142            return Err(AuthError::TokenExchange(
143                String::from_utf8_lossy(&output.stderr).to_string(),
144            ));
145        }
146
147        let token_response: TokenResponse = serde_json::from_slice(&output.stdout)
148            .map_err(|e| AuthError::TokenExchange(format!("Invalid token response: {}", e)))?;
149
150        let auth_token = AuthToken {
151            token_type: token_response.token_type,
152            token_value: token_response.access_token,
153        };
154
155        let expires_at = Instant::now()
156            + Duration::from_secs(
157                token_response
158                    .expires_on
159                    .saturating_sub(chrono::Utc::now().timestamp() as u64)
160                    .saturating_sub(30),
161            );
162
163        *token_guard = Some(CachedToken {
164            token: auth_token.clone(),
165            expires_at,
166        });
167
168        Ok(auth_token)
169    }
170}