aster/providers/
azureauth.rs1use chrono;
2use serde::Deserialize;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6
7#[derive(Debug, thiserror::Error)]
9pub enum AuthError {
10 #[error("Failed to load credentials: {0}")]
12 Credentials(String),
13
14 #[error("Token exchange failed: {0}")]
16 TokenExchange(String),
17}
18
/// An authentication token: its scheme (e.g. `"Bearer"`) plus the raw token text.
#[derive(Debug, Clone)]
pub struct AuthToken {
    /// Token scheme, such as `"Bearer"`.
    pub token_type: String,
    /// The secret token string itself.
    pub token_value: String,
}
27
/// The credential source an `AzureAuth` instance draws tokens from.
#[derive(Debug, Clone)]
pub enum AzureCredentials {
    /// A static API key, handed out verbatim as a `"Bearer"` token value.
    ApiKey(String),
    /// No key was supplied; tokens are requested from the Azure CLI.
    DefaultCredential,
}
36
37#[derive(Debug, Clone)]
39struct CachedToken {
40 token: AuthToken,
41 expires_at: Instant,
42}
43
44#[derive(Debug, Clone, Deserialize)]
46struct TokenResponse {
47 #[serde(rename = "accessToken")]
48 access_token: String,
49 #[serde(rename = "tokenType")]
50 token_type: String,
51 #[serde(rename = "expires_on")]
52 expires_on: u64,
53}
54
55#[derive(Debug)]
57pub struct AzureAuth {
58 credentials: AzureCredentials,
59 cached_token: Arc<RwLock<Option<CachedToken>>>,
60}
61
62impl AzureAuth {
63 pub fn new(api_key: Option<String>) -> Result<Self, AuthError> {
73 let credentials = match api_key {
74 Some(key) => AzureCredentials::ApiKey(key),
75 None => AzureCredentials::DefaultCredential,
76 };
77
78 Ok(Self {
79 credentials,
80 cached_token: Arc::new(RwLock::new(None)),
81 })
82 }
83
84 pub fn credential_type(&self) -> &AzureCredentials {
86 &self.credentials
87 }
88
89 pub async fn get_token(&self) -> Result<AuthToken, AuthError> {
102 match &self.credentials {
103 AzureCredentials::ApiKey(key) => Ok(AuthToken {
104 token_type: "Bearer".to_string(),
105 token_value: key.clone(),
106 }),
107 AzureCredentials::DefaultCredential => self.get_default_credential_token().await,
108 }
109 }
110
111 async fn get_default_credential_token(&self) -> Result<AuthToken, AuthError> {
112 if let Some(cached) = self.cached_token.read().await.as_ref() {
114 if cached.expires_at > Instant::now() {
115 return Ok(cached.token.clone());
116 }
117 }
118
119 let mut token_guard = self.cached_token.write().await;
121
122 if let Some(cached) = token_guard.as_ref() {
124 if cached.expires_at > Instant::now() {
125 return Ok(cached.token.clone());
126 }
127 }
128
129 let output = tokio::process::Command::new("az")
131 .args([
132 "account",
133 "get-access-token",
134 "--resource",
135 "https://cognitiveservices.azure.com",
136 ])
137 .output()
138 .await
139 .map_err(|e| AuthError::TokenExchange(format!("Failed to execute Azure CLI: {}", e)))?;
140
141 if !output.status.success() {
142 return Err(AuthError::TokenExchange(
143 String::from_utf8_lossy(&output.stderr).to_string(),
144 ));
145 }
146
147 let token_response: TokenResponse = serde_json::from_slice(&output.stdout)
148 .map_err(|e| AuthError::TokenExchange(format!("Invalid token response: {}", e)))?;
149
150 let auth_token = AuthToken {
151 token_type: token_response.token_type,
152 token_value: token_response.access_token,
153 };
154
155 let expires_at = Instant::now()
156 + Duration::from_secs(
157 token_response
158 .expires_on
159 .saturating_sub(chrono::Utc::now().timestamp() as u64)
160 .saturating_sub(30),
161 );
162
163 *token_guard = Some(CachedToken {
164 token: auth_token.clone(),
165 expires_at,
166 });
167
168 Ok(auth_token)
169 }
170}