Skip to main content

hoist_client/
auth.rs

1//! Azure authentication providers
2
3use std::process::Command;
4use thiserror::Error;
5
6/// Authentication errors
7#[derive(Debug, Error)]
8pub enum AuthError {
9    #[error(
10        "Azure CLI not found. Please install it: https://docs.microsoft.com/cli/azure/install-azure-cli"
11    )]
12    AzCliNotFound,
13    #[error("Not logged in to Azure CLI. Run: az login")]
14    NotLoggedIn,
15    #[error("Failed to get access token: {0}")]
16    TokenError(String),
17    #[error("Missing environment variable: {0}")]
18    MissingEnvVar(String),
19    #[error("Authentication failed: {0}")]
20    AuthFailed(String),
21}
22
23/// Authentication provider trait
24pub trait AuthProvider: Send + Sync {
25    /// Get an access token for Azure Search
26    fn get_token(&self) -> Result<String, AuthError>;
27
28    /// Get the authentication method name
29    fn method_name(&self) -> &'static str;
30}
31
/// Authentication provider that shells out to the Azure CLI (`az`) and uses its
/// current signed-in session.
///
/// Derives `Debug` (like `EnvAuth`) so it can appear in diagnostics, and `Clone` because it
/// holds only a `&'static str`.
#[derive(Debug, Clone)]
pub struct AzCliAuth {
    /// Resource URI passed to `az account get-access-token --resource`.
    resource_scope: &'static str,
}
36
37impl AzCliAuth {
38    /// Create an auth provider for Azure Search
39    pub fn for_search() -> Self {
40        Self {
41            resource_scope: "https://search.azure.com",
42        }
43    }
44
45    /// Create an auth provider for Microsoft Foundry
46    pub fn for_foundry() -> Self {
47        Self {
48            resource_scope: "https://ai.azure.com",
49        }
50    }
51
52    /// Create a new auth provider (defaults to Search scope for backward compatibility)
53    pub fn new() -> Self {
54        Self::for_search()
55    }
56
57    /// Check if Azure CLI is available and logged in
58    pub fn check_status() -> Result<AuthStatus, AuthError> {
59        // Check if az CLI is installed
60        let version_output = Command::new("az").arg("--version").output();
61
62        if version_output.is_err() {
63            return Err(AuthError::AzCliNotFound);
64        }
65
66        // Check if logged in
67        let account_output = Command::new("az")
68            .args(["account", "show", "--output", "json"])
69            .output()
70            .map_err(|e| AuthError::TokenError(e.to_string()))?;
71
72        if !account_output.status.success() {
73            return Err(AuthError::NotLoggedIn);
74        }
75
76        // Parse account info
77        let account_json: serde_json::Value = serde_json::from_slice(&account_output.stdout)
78            .map_err(|e| AuthError::TokenError(e.to_string()))?;
79
80        Ok(AuthStatus {
81            logged_in: true,
82            user: account_json
83                .get("user")
84                .and_then(|u| u.get("name"))
85                .and_then(|n| n.as_str())
86                .map(String::from),
87            subscription: account_json
88                .get("name")
89                .and_then(|n| n.as_str())
90                .map(String::from),
91            subscription_id: account_json
92                .get("id")
93                .and_then(|i| i.as_str())
94                .map(String::from),
95        })
96    }
97
98    /// Get an access token for Azure Resource Manager (management.azure.com)
99    pub fn get_arm_token() -> Result<String, AuthError> {
100        let output = Command::new("az")
101            .args([
102                "account",
103                "get-access-token",
104                "--resource",
105                "https://management.azure.com",
106                "--query",
107                "accessToken",
108                "--output",
109                "tsv",
110            ])
111            .output()
112            .map_err(|e| AuthError::TokenError(e.to_string()))?;
113
114        if !output.status.success() {
115            let stderr = String::from_utf8_lossy(&output.stderr);
116            if stderr.contains("not logged in") || stderr.contains("AADSTS") {
117                return Err(AuthError::NotLoggedIn);
118            }
119            return Err(AuthError::TokenError(stderr.to_string()));
120        }
121
122        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
123        if token.is_empty() {
124            return Err(AuthError::TokenError(
125                "Empty ARM token received".to_string(),
126            ));
127        }
128
129        Ok(token)
130    }
131}
132
133impl Default for AzCliAuth {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl AuthProvider for AzCliAuth {
140    fn get_token(&self) -> Result<String, AuthError> {
141        let output = Command::new("az")
142            .args([
143                "account",
144                "get-access-token",
145                "--resource",
146                self.resource_scope,
147                "--query",
148                "accessToken",
149                "--output",
150                "tsv",
151            ])
152            .output()
153            .map_err(|e| AuthError::TokenError(e.to_string()))?;
154
155        if !output.status.success() {
156            let stderr = String::from_utf8_lossy(&output.stderr);
157            if stderr.contains("not logged in") || stderr.contains("AADSTS") {
158                return Err(AuthError::NotLoggedIn);
159            }
160            return Err(AuthError::TokenError(stderr.to_string()));
161        }
162
163        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
164        if token.is_empty() {
165            return Err(AuthError::TokenError("Empty token received".to_string()));
166        }
167
168        Ok(token)
169    }
170
171    fn method_name(&self) -> &'static str {
172        "Azure CLI"
173    }
174}
175
/// Service-principal authentication provider whose credentials come from
/// `AZURE_CLIENT_ID`, `AZURE_CLIENT_SECRET` and `AZURE_TENANT_ID`.
pub struct EnvAuth {
    client_id: String,
    /// Never printed: the `Debug` impl below redacts it.
    client_secret: String,
    tenant_id: String,
    /// Resource URI that requested tokens must target.
    resource_scope: &'static str,
}

/// Hand-written instead of derived so the client secret can never end up in logs,
/// panic messages or `{:?}` output.
impl std::fmt::Debug for EnvAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnvAuth")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("tenant_id", &self.tenant_id)
            .field("resource_scope", &self.resource_scope)
            .finish()
    }
}
184
185impl EnvAuth {
186    /// Create from environment variables (defaults to Search scope)
187    pub fn from_env() -> Result<Self, AuthError> {
188        Self::from_env_for_scope("https://search.azure.com")
189    }
190
191    /// Create from environment variables for a specific resource scope
192    pub fn from_env_for_scope(scope: &'static str) -> Result<Self, AuthError> {
193        let client_id = std::env::var("AZURE_CLIENT_ID")
194            .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_ID".to_string()))?;
195        let client_secret = std::env::var("AZURE_CLIENT_SECRET")
196            .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_SECRET".to_string()))?;
197        let tenant_id = std::env::var("AZURE_TENANT_ID")
198            .map_err(|_| AuthError::MissingEnvVar("AZURE_TENANT_ID".to_string()))?;
199
200        Ok(Self {
201            client_id,
202            client_secret,
203            tenant_id,
204            resource_scope: scope,
205        })
206    }
207
208    /// Check if environment variables are set
209    pub fn is_configured() -> bool {
210        std::env::var("AZURE_CLIENT_ID").is_ok()
211            && std::env::var("AZURE_CLIENT_SECRET").is_ok()
212            && std::env::var("AZURE_TENANT_ID").is_ok()
213    }
214}
215
216impl AuthProvider for EnvAuth {
217    fn get_token(&self) -> Result<String, AuthError> {
218        // Use Azure CLI to get token with service principal
219        let output = Command::new("az")
220            .args([
221                "account",
222                "get-access-token",
223                "--resource",
224                self.resource_scope,
225                "--query",
226                "accessToken",
227                "--output",
228                "tsv",
229                "--tenant",
230                &self.tenant_id,
231                "--username",
232                &self.client_id,
233            ])
234            .env("AZURE_CLIENT_SECRET", &self.client_secret)
235            .output()
236            .map_err(|e| AuthError::TokenError(e.to_string()))?;
237
238        if !output.status.success() {
239            let stderr = String::from_utf8_lossy(&output.stderr);
240            return Err(AuthError::AuthFailed(stderr.to_string()));
241        }
242
243        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
244        Ok(token)
245    }
246
247    fn method_name(&self) -> &'static str {
248        "Environment Variables (Service Principal)"
249    }
250}
251
/// Summary of a signed-in Azure CLI account.
///
/// `AzCliAuth::check_status` fills this from the JSON printed by `az account show`.
#[derive(Clone, Debug)]
pub struct AuthStatus {
    /// Whether a signed-in session was found.
    pub logged_in: bool,
    /// Signed-in user or principal name (JSON `user.name`), when present.
    pub user: Option<String>,
    /// Display name of the active subscription (JSON `name`), when present.
    pub subscription: Option<String>,
    /// Identifier of the active subscription (JSON `id`), when present.
    pub subscription_id: Option<String>,
}
260
261/// Get the best available authentication provider for Search (backward compat)
262pub fn get_auth_provider() -> Result<Box<dyn AuthProvider>, AuthError> {
263    get_auth_provider_for_scope("https://search.azure.com")
264}
265
266/// Get the best available authentication provider for a specific service domain
267pub fn get_auth_provider_for(
268    domain: hoist_core::ServiceDomain,
269) -> Result<Box<dyn AuthProvider>, AuthError> {
270    let scope = match domain {
271        hoist_core::ServiceDomain::Search => "https://search.azure.com",
272        hoist_core::ServiceDomain::Foundry => "https://ai.azure.com",
273    };
274    get_auth_provider_for_scope(scope)
275}
276
277/// Get the best available authentication provider for a specific resource scope
278fn get_auth_provider_for_scope(scope: &'static str) -> Result<Box<dyn AuthProvider>, AuthError> {
279    // First try environment variables
280    if EnvAuth::is_configured() {
281        return Ok(Box::new(EnvAuth::from_env_for_scope(scope)?));
282    }
283
284    // Fall back to Azure CLI
285    AzCliAuth::check_status()?;
286    Ok(Box::new(AzCliAuth {
287        resource_scope: scope,
288    }))
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Env var tests must run serially since they share process-wide state.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_MUTEX`, recovering the guard if a previous test panicked while holding it.
    ///
    /// With `lock().unwrap()`, one failing env test poisons the mutex. Every later env test
    /// then fails with `PoisonError`, which hides the real failure.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_MUTEX.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// # Safety
    /// Must be called while holding ENV_MUTEX to avoid data races.
    unsafe fn clear_azure_env_vars() {
        // SAFETY: the caller holds ENV_MUTEX, so no other test in this module mutates the
        // environment concurrently.
        unsafe {
            std::env::remove_var("AZURE_CLIENT_ID");
            std::env::remove_var("AZURE_CLIENT_SECRET");
            std::env::remove_var("AZURE_TENANT_ID");
        }
    }

    /// # Safety
    /// Must be called while holding ENV_MUTEX to avoid data races.
    unsafe fn set_azure_env_vars() {
        // SAFETY: the caller holds ENV_MUTEX, so no other test in this module mutates the
        // environment concurrently.
        unsafe {
            std::env::set_var("AZURE_CLIENT_ID", "test-client-id");
            std::env::set_var("AZURE_CLIENT_SECRET", "test-client-secret");
            std::env::set_var("AZURE_TENANT_ID", "test-tenant-id");
        }
    }

    #[test]
    fn test_env_auth_from_env_success() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { set_azure_env_vars() };

        let result = EnvAuth::from_env();
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.client_id, "test-client-id");
        assert_eq!(auth.client_secret, "test-client-secret");
        assert_eq!(auth.tenant_id, "test-tenant-id");

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_from_env_missing_client_id() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe {
            clear_azure_env_vars();
            std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
            std::env::set_var("AZURE_TENANT_ID", "test-tenant");
        }

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_ID"));

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_from_env_missing_client_secret() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe {
            clear_azure_env_vars();
            std::env::set_var("AZURE_CLIENT_ID", "test-id");
            std::env::set_var("AZURE_TENANT_ID", "test-tenant");
        }

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_SECRET"));

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_from_env_missing_tenant_id() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe {
            clear_azure_env_vars();
            std::env::set_var("AZURE_CLIENT_ID", "test-id");
            std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        }

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_TENANT_ID"));

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_is_configured_all_set() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { set_azure_env_vars() };

        assert!(EnvAuth::is_configured());

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_is_configured_none_set() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };

        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_partial() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe {
            clear_azure_env_vars();
            std::env::set_var("AZURE_CLIENT_ID", "test-id");
            std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        }
        // AZURE_TENANT_ID intentionally missing

        assert!(!EnvAuth::is_configured());

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_method_name() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { set_azure_env_vars() };

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(
            auth.method_name(),
            "Environment Variables (Service Principal)"
        );

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_az_cli_auth_method_name() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.method_name(), "Azure CLI");
    }

    #[test]
    fn test_az_cli_auth_search_scope() {
        let auth = AzCliAuth::for_search();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_az_cli_auth_foundry_scope() {
        let auth = AzCliAuth::for_foundry();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_az_cli_auth_new_defaults_to_search() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_scope_foundry() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { set_azure_env_vars() };

        let result = EnvAuth::from_env_for_scope("https://ai.azure.com");
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_env_auth_from_env_default_scope_is_search() {
        let _lock = lock_env();
        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { set_azure_env_vars() };

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(auth.resource_scope, "https://search.azure.com");

        // SAFETY: ENV_MUTEX is held via `_lock`.
        unsafe { clear_azure_env_vars() };
    }

    #[test]
    fn test_auth_status_fields() {
        let status = AuthStatus {
            logged_in: true,
            user: Some("testuser@example.com".to_string()),
            subscription: Some("My Subscription".to_string()),
            subscription_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
        };

        assert!(status.logged_in);
        assert_eq!(status.user.as_deref(), Some("testuser@example.com"));
        assert_eq!(status.subscription.as_deref(), Some("My Subscription"));
        assert_eq!(
            status.subscription_id.as_deref(),
            Some("00000000-0000-0000-0000-000000000000")
        );
    }
}