1use std::process::Command;
4use thiserror::Error;
5
/// Errors produced while checking Azure credentials or requesting an access token.
#[derive(Debug, Error)]
pub enum AuthError {
    /// The `az` executable could not be launched.
    #[error("Azure CLI not found. Please install it: https://docs.microsoft.com/cli/azure/install-azure-cli")]
    AzCliNotFound,
    /// The Azure CLI ran but reported no usable login session.
    #[error("Not logged in to Azure CLI. Run: az login")]
    NotLoggedIn,
    /// A token (or account) request failed or produced unusable output; carries the detail text.
    #[error("Failed to get access token: {0}")]
    TokenError(String),
    /// A required environment variable is not set; carries the variable's name.
    #[error("Missing environment variable: {0}")]
    MissingEnvVar(String),
    /// The service-principal token request exited unsuccessfully; carries the CLI's stderr.
    #[error("Authentication failed: {0}")]
    AuthFailed(String),
}
20
/// A source of access tokens.
///
/// The `Send + Sync` supertraits let a boxed provider be moved to, and shared
/// between, threads.
pub trait AuthProvider: Send + Sync {
    /// Returns an access token as a string, or the [`AuthError`] describing why
    /// none could be obtained.
    fn get_token(&self) -> Result<String, AuthError>;

    /// Returns a fixed, human-readable name for this authentication method.
    fn method_name(&self) -> &'static str;
}
29
/// Token provider that shells out to the Azure CLI (`az`).
///
/// `Debug` and `Clone` are derived so this public type can be logged and
/// duplicated like the other public types in this module; it holds no secret,
/// only the resource the token is requested for.
#[derive(Debug, Clone)]
pub struct AzCliAuth {
    /// Value passed to `az account get-access-token --resource`.
    resource_scope: &'static str,
}
34
35impl AzCliAuth {
36 pub fn for_search() -> Self {
38 Self {
39 resource_scope: "https://search.azure.com",
40 }
41 }
42
43 pub fn for_foundry() -> Self {
45 Self {
46 resource_scope: "https://ai.azure.com",
47 }
48 }
49
50 pub fn new() -> Self {
52 Self::for_search()
53 }
54
55 pub fn check_status() -> Result<AuthStatus, AuthError> {
57 let version_output = Command::new("az").arg("--version").output();
59
60 if version_output.is_err() {
61 return Err(AuthError::AzCliNotFound);
62 }
63
64 let account_output = Command::new("az")
66 .args(["account", "show", "--output", "json"])
67 .output()
68 .map_err(|e| AuthError::TokenError(e.to_string()))?;
69
70 if !account_output.status.success() {
71 return Err(AuthError::NotLoggedIn);
72 }
73
74 let account_json: serde_json::Value = serde_json::from_slice(&account_output.stdout)
76 .map_err(|e| AuthError::TokenError(e.to_string()))?;
77
78 Ok(AuthStatus {
79 logged_in: true,
80 user: account_json
81 .get("user")
82 .and_then(|u| u.get("name"))
83 .and_then(|n| n.as_str())
84 .map(String::from),
85 subscription: account_json
86 .get("name")
87 .and_then(|n| n.as_str())
88 .map(String::from),
89 subscription_id: account_json
90 .get("id")
91 .and_then(|i| i.as_str())
92 .map(String::from),
93 })
94 }
95
96 pub fn get_arm_token() -> Result<String, AuthError> {
98 let output = Command::new("az")
99 .args([
100 "account",
101 "get-access-token",
102 "--resource",
103 "https://management.azure.com",
104 "--query",
105 "accessToken",
106 "--output",
107 "tsv",
108 ])
109 .output()
110 .map_err(|e| AuthError::TokenError(e.to_string()))?;
111
112 if !output.status.success() {
113 let stderr = String::from_utf8_lossy(&output.stderr);
114 if stderr.contains("not logged in") || stderr.contains("AADSTS") {
115 return Err(AuthError::NotLoggedIn);
116 }
117 return Err(AuthError::TokenError(stderr.to_string()));
118 }
119
120 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
121 if token.is_empty() {
122 return Err(AuthError::TokenError(
123 "Empty ARM token received".to_string(),
124 ));
125 }
126
127 Ok(token)
128 }
129}
130
131impl Default for AzCliAuth {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl AuthProvider for AzCliAuth {
138 fn get_token(&self) -> Result<String, AuthError> {
139 let output = Command::new("az")
140 .args([
141 "account",
142 "get-access-token",
143 "--resource",
144 self.resource_scope,
145 "--query",
146 "accessToken",
147 "--output",
148 "tsv",
149 ])
150 .output()
151 .map_err(|e| AuthError::TokenError(e.to_string()))?;
152
153 if !output.status.success() {
154 let stderr = String::from_utf8_lossy(&output.stderr);
155 if stderr.contains("not logged in") || stderr.contains("AADSTS") {
156 return Err(AuthError::NotLoggedIn);
157 }
158 return Err(AuthError::TokenError(stderr.to_string()));
159 }
160
161 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
162 if token.is_empty() {
163 return Err(AuthError::TokenError("Empty token received".to_string()));
164 }
165
166 Ok(token)
167 }
168
169 fn method_name(&self) -> &'static str {
170 "Azure CLI"
171 }
172}
173
/// Service-principal credentials read from the environment.
///
/// `Debug` is implemented by hand rather than derived so that the client
/// secret never appears in logs or panic messages.
pub struct EnvAuth {
    /// Value of `AZURE_CLIENT_ID`.
    client_id: String,
    /// Value of `AZURE_CLIENT_SECRET`; deliberately excluded from `Debug` output.
    client_secret: String,
    /// Value of `AZURE_TENANT_ID`.
    tenant_id: String,
    /// Resource the token is requested for.
    resource_scope: &'static str,
}

impl std::fmt::Debug for EnvAuth {
    /// Formats every field except `client_secret`, which is printed as a
    /// fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnvAuth")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("tenant_id", &self.tenant_id)
            .field("resource_scope", &self.resource_scope)
            .finish()
    }
}
182
183impl EnvAuth {
184 pub fn from_env() -> Result<Self, AuthError> {
186 Self::from_env_for_scope("https://search.azure.com")
187 }
188
189 pub fn from_env_for_scope(scope: &'static str) -> Result<Self, AuthError> {
191 let client_id = std::env::var("AZURE_CLIENT_ID")
192 .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_ID".to_string()))?;
193 let client_secret = std::env::var("AZURE_CLIENT_SECRET")
194 .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_SECRET".to_string()))?;
195 let tenant_id = std::env::var("AZURE_TENANT_ID")
196 .map_err(|_| AuthError::MissingEnvVar("AZURE_TENANT_ID".to_string()))?;
197
198 Ok(Self {
199 client_id,
200 client_secret,
201 tenant_id,
202 resource_scope: scope,
203 })
204 }
205
206 pub fn is_configured() -> bool {
208 std::env::var("AZURE_CLIENT_ID").is_ok()
209 && std::env::var("AZURE_CLIENT_SECRET").is_ok()
210 && std::env::var("AZURE_TENANT_ID").is_ok()
211 }
212}
213
214impl AuthProvider for EnvAuth {
215 fn get_token(&self) -> Result<String, AuthError> {
216 let output = Command::new("az")
218 .args([
219 "account",
220 "get-access-token",
221 "--resource",
222 self.resource_scope,
223 "--query",
224 "accessToken",
225 "--output",
226 "tsv",
227 "--tenant",
228 &self.tenant_id,
229 "--username",
230 &self.client_id,
231 ])
232 .env("AZURE_CLIENT_SECRET", &self.client_secret)
233 .output()
234 .map_err(|e| AuthError::TokenError(e.to_string()))?;
235
236 if !output.status.success() {
237 let stderr = String::from_utf8_lossy(&output.stderr);
238 return Err(AuthError::AuthFailed(stderr.to_string()));
239 }
240
241 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
242 Ok(token)
243 }
244
245 fn method_name(&self) -> &'static str {
246 "Environment Variables (Service Principal)"
247 }
248}
249
/// Snapshot of the Azure CLI login state, as produced by `AzCliAuth::check_status`.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// Whether a login session exists; `check_status` only ever builds this with `true`.
    pub logged_in: bool,
    /// `user.name` from the `az account show` JSON, if present and a string.
    pub user: Option<String>,
    /// Top-level `name` from the `az account show` JSON, if present and a string.
    pub subscription: Option<String>,
    /// Top-level `id` from the `az account show` JSON, if present and a string.
    pub subscription_id: Option<String>,
}
258
259pub fn get_auth_provider() -> Result<Box<dyn AuthProvider>, AuthError> {
261 get_auth_provider_for_scope("https://search.azure.com")
262}
263
264pub fn get_auth_provider_for(
266 domain: hoist_core::ServiceDomain,
267) -> Result<Box<dyn AuthProvider>, AuthError> {
268 let scope = match domain {
269 hoist_core::ServiceDomain::Search => "https://search.azure.com",
270 hoist_core::ServiceDomain::Foundry => "https://ai.azure.com",
271 };
272 get_auth_provider_for_scope(scope)
273}
274
275fn get_auth_provider_for_scope(scope: &'static str) -> Result<Box<dyn AuthProvider>, AuthError> {
277 if EnvAuth::is_configured() {
279 return Ok(Box::new(EnvAuth::from_env_for_scope(scope)?));
280 }
281
282 AzCliAuth::check_status()?;
284 Ok(Box::new(AzCliAuth {
285 resource_scope: scope,
286 }))
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test that reads or writes the process-wide environment.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Removes the three service-principal variables from the environment.
    fn clear_azure_env_vars() {
        std::env::remove_var("AZURE_CLIENT_ID");
        std::env::remove_var("AZURE_CLIENT_SECRET");
        std::env::remove_var("AZURE_TENANT_ID");
    }

    /// Sets the three service-principal variables to fixed test values.
    fn set_azure_env_vars() {
        std::env::set_var("AZURE_CLIENT_ID", "test-client-id");
        std::env::set_var("AZURE_CLIENT_SECRET", "test-client-secret");
        std::env::set_var("AZURE_TENANT_ID", "test-tenant-id");
    }

    /// Holds `ENV_MUTEX` for the duration of a test and clears the Azure
    /// variables when dropped, including when the test panics.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Takes the lock and starts from an environment with none of the
        /// Azure variables set.
        ///
        /// A poisoned mutex is recovered rather than unwrapped: the guarded
        /// value is `()`, so poisoning carries no state, and unwrapping would
        /// turn one failing test into a failure of every later env test.
        fn cleared() -> Self {
            let lock = ENV_MUTEX.lock().unwrap_or_else(PoisonError::into_inner);
            clear_azure_env_vars();
            Self { _lock: lock }
        }

        /// Takes the lock and sets all three Azure variables.
        fn with_all_vars() -> Self {
            let guard = Self::cleared();
            set_azure_env_vars();
            guard
        }
    }

    impl Drop for EnvGuard {
        // Runs before the lock field is released, so the next test never
        // observes leftovers from this one.
        fn drop(&mut self) {
            clear_azure_env_vars();
        }
    }

    #[test]
    fn test_env_auth_from_env_success() {
        let _env = EnvGuard::with_all_vars();

        let result = EnvAuth::from_env();
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.client_id, "test-client-id");
        assert_eq!(auth.client_secret, "test-client-secret");
        assert_eq!(auth.tenant_id, "test-tenant-id");
    }

    #[test]
    fn test_env_auth_from_env_missing_client_id() {
        let _env = EnvGuard::cleared();
        std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        std::env::set_var("AZURE_TENANT_ID", "test-tenant");

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_ID"));
    }

    #[test]
    fn test_env_auth_from_env_missing_client_secret() {
        let _env = EnvGuard::cleared();
        std::env::set_var("AZURE_CLIENT_ID", "test-id");
        std::env::set_var("AZURE_TENANT_ID", "test-tenant");

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_SECRET"));
    }

    #[test]
    fn test_env_auth_from_env_missing_tenant_id() {
        let _env = EnvGuard::cleared();
        std::env::set_var("AZURE_CLIENT_ID", "test-id");
        std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_TENANT_ID"));
    }

    #[test]
    fn test_env_auth_is_configured_all_set() {
        let _env = EnvGuard::with_all_vars();

        assert!(EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_none_set() {
        let _env = EnvGuard::cleared();

        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_partial() {
        let _env = EnvGuard::cleared();
        std::env::set_var("AZURE_CLIENT_ID", "test-id");
        std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");

        // The tenant ID is still missing, so the configuration is incomplete.
        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_method_name() {
        let _env = EnvGuard::with_all_vars();

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(
            auth.method_name(),
            "Environment Variables (Service Principal)"
        );
    }

    #[test]
    fn test_az_cli_auth_method_name() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.method_name(), "Azure CLI");
    }

    #[test]
    fn test_az_cli_auth_search_scope() {
        let auth = AzCliAuth::for_search();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_az_cli_auth_foundry_scope() {
        let auth = AzCliAuth::for_foundry();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_az_cli_auth_new_defaults_to_search() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_scope_foundry() {
        let _env = EnvGuard::with_all_vars();

        let result = EnvAuth::from_env_for_scope("https://ai.azure.com");
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_default_scope_is_search() {
        let _env = EnvGuard::with_all_vars();

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_auth_status_fields() {
        let status = AuthStatus {
            logged_in: true,
            user: Some("testuser@example.com".to_string()),
            subscription: Some("My Subscription".to_string()),
            subscription_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
        };

        assert!(status.logged_in);
        assert_eq!(status.user.as_deref(), Some("testuser@example.com"));
        assert_eq!(status.subscription.as_deref(), Some("My Subscription"));
        assert_eq!(
            status.subscription_id.as_deref(),
            Some("00000000-0000-0000-0000-000000000000")
        );
    }
}