1use std::process::Command;
4use thiserror::Error;
5
/// Errors produced while probing the Azure CLI or obtaining an access token.
#[derive(Debug, Error)]
pub enum AuthError {
    /// The `az` executable could not be launched.
    #[error(
        "Azure CLI not found. Please install it: https://docs.microsoft.com/cli/azure/install-azure-cli"
    )]
    AzCliNotFound,
    /// The Azure CLI ran but reported no usable login session.
    #[error("Not logged in to Azure CLI. Run: az login")]
    NotLoggedIn,
    /// A CLI invocation failed or produced unusable output; carries the detail text.
    #[error("Failed to get access token: {0}")]
    TokenError(String),
    /// A required environment variable is unavailable; carries the variable name.
    #[error("Missing environment variable: {0}")]
    MissingEnvVar(String),
    /// The service-principal token command exited unsuccessfully; carries its stderr.
    #[error("Authentication failed: {0}")]
    AuthFailed(String),
}
22
/// A source of access tokens.
///
/// Implementors must be `Send + Sync`, so a provider can be boxed as
/// `Box<dyn AuthProvider>` and shared across threads.
pub trait AuthProvider: Send + Sync {
    /// Obtains an access token, or an [`AuthError`] describing why none is available.
    fn get_token(&self) -> Result<String, AuthError>;

    /// Returns a fixed, human-readable label identifying the authentication method.
    fn method_name(&self) -> &'static str;
}
31
/// Token provider that obtains tokens by invoking the Azure CLI (`az`).
///
/// Derives `Debug` and `Clone` so it can be logged and duplicated like the
/// other public types in this module; it holds no secret material.
#[derive(Debug, Clone)]
pub struct AzCliAuth {
    /// Resource URI passed to `az account get-access-token --resource`.
    resource_scope: &'static str,
}
36
37impl AzCliAuth {
38 pub fn for_search() -> Self {
40 Self {
41 resource_scope: "https://search.azure.com",
42 }
43 }
44
45 pub fn for_foundry() -> Self {
47 Self {
48 resource_scope: "https://ai.azure.com",
49 }
50 }
51
52 pub fn new() -> Self {
54 Self::for_search()
55 }
56
57 pub fn check_status() -> Result<AuthStatus, AuthError> {
59 let version_output = Command::new("az").arg("--version").output();
61
62 if version_output.is_err() {
63 return Err(AuthError::AzCliNotFound);
64 }
65
66 let account_output = Command::new("az")
68 .args(["account", "show", "--output", "json"])
69 .output()
70 .map_err(|e| AuthError::TokenError(e.to_string()))?;
71
72 if !account_output.status.success() {
73 return Err(AuthError::NotLoggedIn);
74 }
75
76 let account_json: serde_json::Value = serde_json::from_slice(&account_output.stdout)
78 .map_err(|e| AuthError::TokenError(e.to_string()))?;
79
80 Ok(AuthStatus {
81 logged_in: true,
82 user: account_json
83 .get("user")
84 .and_then(|u| u.get("name"))
85 .and_then(|n| n.as_str())
86 .map(String::from),
87 subscription: account_json
88 .get("name")
89 .and_then(|n| n.as_str())
90 .map(String::from),
91 subscription_id: account_json
92 .get("id")
93 .and_then(|i| i.as_str())
94 .map(String::from),
95 })
96 }
97
98 pub fn get_arm_token() -> Result<String, AuthError> {
100 let output = Command::new("az")
101 .args([
102 "account",
103 "get-access-token",
104 "--resource",
105 "https://management.azure.com",
106 "--query",
107 "accessToken",
108 "--output",
109 "tsv",
110 ])
111 .output()
112 .map_err(|e| AuthError::TokenError(e.to_string()))?;
113
114 if !output.status.success() {
115 let stderr = String::from_utf8_lossy(&output.stderr);
116 if stderr.contains("not logged in") || stderr.contains("AADSTS") {
117 return Err(AuthError::NotLoggedIn);
118 }
119 return Err(AuthError::TokenError(stderr.to_string()));
120 }
121
122 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
123 if token.is_empty() {
124 return Err(AuthError::TokenError(
125 "Empty ARM token received".to_string(),
126 ));
127 }
128
129 Ok(token)
130 }
131}
132
133impl Default for AzCliAuth {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl AuthProvider for AzCliAuth {
140 fn get_token(&self) -> Result<String, AuthError> {
141 let output = Command::new("az")
142 .args([
143 "account",
144 "get-access-token",
145 "--resource",
146 self.resource_scope,
147 "--query",
148 "accessToken",
149 "--output",
150 "tsv",
151 ])
152 .output()
153 .map_err(|e| AuthError::TokenError(e.to_string()))?;
154
155 if !output.status.success() {
156 let stderr = String::from_utf8_lossy(&output.stderr);
157 if stderr.contains("not logged in") || stderr.contains("AADSTS") {
158 return Err(AuthError::NotLoggedIn);
159 }
160 return Err(AuthError::TokenError(stderr.to_string()));
161 }
162
163 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
164 if token.is_empty() {
165 return Err(AuthError::TokenError("Empty token received".to_string()));
166 }
167
168 Ok(token)
169 }
170
171 fn method_name(&self) -> &'static str {
172 "Azure CLI"
173 }
174}
175
/// Service-principal credentials together with the resource scope to request.
///
/// `Debug` is implemented by hand rather than derived so that the client
/// secret never appears in logs, panic messages, or `{:?}` output.
pub struct EnvAuth {
    /// Application (client) identifier.
    client_id: String,
    /// Client secret; deliberately excluded from `Debug` output.
    client_secret: String,
    /// Tenant identifier.
    tenant_id: String,
    /// Resource URI passed to `az account get-access-token --resource`.
    resource_scope: &'static str,
}

impl std::fmt::Debug for EnvAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnvAuth")
            .field("client_id", &self.client_id)
            // Print a fixed placeholder instead of the secret.
            .field("client_secret", &"<redacted>")
            .field("tenant_id", &self.tenant_id)
            .field("resource_scope", &self.resource_scope)
            .finish()
    }
}
184
185impl EnvAuth {
186 pub fn from_env() -> Result<Self, AuthError> {
188 Self::from_env_for_scope("https://search.azure.com")
189 }
190
191 pub fn from_env_for_scope(scope: &'static str) -> Result<Self, AuthError> {
193 let client_id = std::env::var("AZURE_CLIENT_ID")
194 .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_ID".to_string()))?;
195 let client_secret = std::env::var("AZURE_CLIENT_SECRET")
196 .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_SECRET".to_string()))?;
197 let tenant_id = std::env::var("AZURE_TENANT_ID")
198 .map_err(|_| AuthError::MissingEnvVar("AZURE_TENANT_ID".to_string()))?;
199
200 Ok(Self {
201 client_id,
202 client_secret,
203 tenant_id,
204 resource_scope: scope,
205 })
206 }
207
208 pub fn is_configured() -> bool {
210 std::env::var("AZURE_CLIENT_ID").is_ok()
211 && std::env::var("AZURE_CLIENT_SECRET").is_ok()
212 && std::env::var("AZURE_TENANT_ID").is_ok()
213 }
214}
215
216impl AuthProvider for EnvAuth {
217 fn get_token(&self) -> Result<String, AuthError> {
218 let output = Command::new("az")
220 .args([
221 "account",
222 "get-access-token",
223 "--resource",
224 self.resource_scope,
225 "--query",
226 "accessToken",
227 "--output",
228 "tsv",
229 "--tenant",
230 &self.tenant_id,
231 "--username",
232 &self.client_id,
233 ])
234 .env("AZURE_CLIENT_SECRET", &self.client_secret)
235 .output()
236 .map_err(|e| AuthError::TokenError(e.to_string()))?;
237
238 if !output.status.success() {
239 let stderr = String::from_utf8_lossy(&output.stderr);
240 return Err(AuthError::AuthFailed(stderr.to_string()));
241 }
242
243 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
244 Ok(token)
245 }
246
247 fn method_name(&self) -> &'static str {
248 "Environment Variables (Service Principal)"
249 }
250}
251
/// Account details extracted from the JSON printed by `az account show`.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// Whether an account is active; `AzCliAuth::check_status` sets this to `true`
    /// whenever it returns a value.
    pub logged_in: bool,
    /// The `user.name` entry of the account JSON, when present and a string.
    pub user: Option<String>,
    /// The top-level `name` entry of the account JSON, when present and a string.
    pub subscription: Option<String>,
    /// The top-level `id` entry of the account JSON, when present and a string.
    pub subscription_id: Option<String>,
}
260
261pub fn get_auth_provider() -> Result<Box<dyn AuthProvider>, AuthError> {
263 get_auth_provider_for_scope("https://search.azure.com")
264}
265
266pub fn get_auth_provider_for(
268 domain: hoist_core::ServiceDomain,
269) -> Result<Box<dyn AuthProvider>, AuthError> {
270 let scope = match domain {
271 hoist_core::ServiceDomain::Search => "https://search.azure.com",
272 hoist_core::ServiceDomain::Foundry => "https://ai.azure.com",
273 };
274 get_auth_provider_for_scope(scope)
275}
276
277fn get_auth_provider_for_scope(scope: &'static str) -> Result<Box<dyn AuthProvider>, AuthError> {
279 if EnvAuth::is_configured() {
281 return Ok(Box::new(EnvAuth::from_env_for_scope(scope)?));
282 }
283
284 AzCliAuth::check_status()?;
286 Ok(Box::new(AzCliAuth {
287 resource_scope: scope,
288 }))
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test that reads or writes the process-wide `AZURE_*` variables.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Removes the three service-principal variables from the process environment.
    ///
    /// # Safety
    ///
    /// The caller must hold `ENV_MUTEX` so that no other test in this module
    /// accesses these variables concurrently.
    unsafe fn clear_azure_env_vars() {
        // SAFETY: exclusive access is guaranteed by the caller (see above).
        unsafe {
            std::env::remove_var("AZURE_CLIENT_ID");
            std::env::remove_var("AZURE_CLIENT_SECRET");
            std::env::remove_var("AZURE_TENANT_ID");
        }
    }

    /// Sets the three service-principal variables to fixed test values.
    ///
    /// # Safety
    ///
    /// The caller must hold `ENV_MUTEX` so that no other test in this module
    /// accesses these variables concurrently.
    unsafe fn set_azure_env_vars() {
        // SAFETY: exclusive access is guaranteed by the caller (see above).
        unsafe {
            std::env::set_var("AZURE_CLIENT_ID", "test-client-id");
            std::env::set_var("AZURE_CLIENT_SECRET", "test-client-secret");
            std::env::set_var("AZURE_TENANT_ID", "test-tenant-id");
        }
    }

    /// Holds `ENV_MUTEX` for the duration of a test and clears the `AZURE_*`
    /// variables both on acquisition and on drop.
    ///
    /// Cleaning up in `Drop` means a failing assertion cannot leak variables
    /// into later tests, and recovering from a poisoned lock means one failing
    /// test does not make every other environment test panic as well.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Takes the lock and starts from an environment with no `AZURE_*` variables.
        fn cleared() -> Self {
            // The mutex protects no data of its own, so a poisoned lock is safe to reuse.
            let lock = ENV_MUTEX.lock().unwrap_or_else(PoisonError::into_inner);
            // SAFETY: `lock` is held, so this thread has exclusive access.
            unsafe { clear_azure_env_vars() };
            Self { _lock: lock }
        }

        /// Takes the lock and sets all three variables to their test values.
        fn fully_configured() -> Self {
            let guard = Self::cleared();
            // SAFETY: `guard` holds the lock.
            unsafe { set_azure_env_vars() };
            guard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // SAFETY: the lock field is released only after this body returns.
            unsafe { clear_azure_env_vars() };
        }
    }

    #[test]
    fn test_env_auth_from_env_success() {
        let _env = EnvGuard::fully_configured();

        let result = EnvAuth::from_env();
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.client_id, "test-client-id");
        assert_eq!(auth.client_secret, "test-client-secret");
        assert_eq!(auth.tenant_id, "test-tenant-id");
    }

    #[test]
    fn test_env_auth_from_env_missing_client_id() {
        let _env = EnvGuard::cleared();
        // SAFETY: `_env` holds ENV_MUTEX.
        unsafe {
            std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
            std::env::set_var("AZURE_TENANT_ID", "test-tenant");
        }

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_ID"));
    }

    #[test]
    fn test_env_auth_from_env_missing_client_secret() {
        let _env = EnvGuard::cleared();
        // SAFETY: `_env` holds ENV_MUTEX.
        unsafe {
            std::env::set_var("AZURE_CLIENT_ID", "test-id");
            std::env::set_var("AZURE_TENANT_ID", "test-tenant");
        }

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_SECRET"));
    }

    #[test]
    fn test_env_auth_from_env_missing_tenant_id() {
        let _env = EnvGuard::cleared();
        // SAFETY: `_env` holds ENV_MUTEX.
        unsafe {
            std::env::set_var("AZURE_CLIENT_ID", "test-id");
            std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        }

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_TENANT_ID"));
    }

    #[test]
    fn test_env_auth_is_configured_all_set() {
        let _env = EnvGuard::fully_configured();

        assert!(EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_none_set() {
        let _env = EnvGuard::cleared();

        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_partial() {
        let _env = EnvGuard::cleared();
        // SAFETY: `_env` holds ENV_MUTEX.
        unsafe {
            std::env::set_var("AZURE_CLIENT_ID", "test-id");
            std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        }

        // The tenant id is still absent, so the configuration is incomplete.
        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_method_name() {
        let _env = EnvGuard::fully_configured();

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(
            auth.method_name(),
            "Environment Variables (Service Principal)"
        );
    }

    #[test]
    fn test_az_cli_auth_method_name() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.method_name(), "Azure CLI");
    }

    #[test]
    fn test_az_cli_auth_search_scope() {
        let auth = AzCliAuth::for_search();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_az_cli_auth_foundry_scope() {
        let auth = AzCliAuth::for_foundry();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_az_cli_auth_new_defaults_to_search() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_scope_foundry() {
        let _env = EnvGuard::fully_configured();

        let result = EnvAuth::from_env_for_scope("https://ai.azure.com");
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_default_scope_is_search() {
        let _env = EnvGuard::fully_configured();

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_auth_status_fields() {
        let status = AuthStatus {
            logged_in: true,
            user: Some("testuser@example.com".to_string()),
            subscription: Some("My Subscription".to_string()),
            subscription_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
        };

        assert!(status.logged_in);
        assert_eq!(status.user.as_deref(), Some("testuser@example.com"));
        assert_eq!(status.subscription.as_deref(), Some("My Subscription"));
        assert_eq!(
            status.subscription_id.as_deref(),
            Some("00000000-0000-0000-0000-000000000000")
        );
    }
}