Skip to main content

hoist_client/
auth.rs

1//! Azure authentication providers
2
3use std::process::Command;
4use thiserror::Error;
5
6/// Authentication errors
7#[derive(Debug, Error)]
8pub enum AuthError {
9    #[error("Azure CLI not found. Please install it: https://docs.microsoft.com/cli/azure/install-azure-cli")]
10    AzCliNotFound,
11    #[error("Not logged in to Azure CLI. Run: az login")]
12    NotLoggedIn,
13    #[error("Failed to get access token: {0}")]
14    TokenError(String),
15    #[error("Missing environment variable: {0}")]
16    MissingEnvVar(String),
17    #[error("Authentication failed: {0}")]
18    AuthFailed(String),
19}
20
21/// Authentication provider trait
22pub trait AuthProvider: Send + Sync {
23    /// Get an access token for Azure Search
24    fn get_token(&self) -> Result<String, AuthError>;
25
26    /// Get the authentication method name
27    fn method_name(&self) -> &'static str;
28}
29
/// Authentication provider that obtains tokens by shelling out to the Azure CLI (`az`).
///
/// Derives `Debug` (like its sibling `EnvAuth`) and `Clone`; it only holds a `&'static str`,
/// so both are cheap and reveal nothing sensitive.
#[derive(Debug, Clone)]
pub struct AzCliAuth {
    /// Resource URI passed to `az account get-access-token --resource`.
    resource_scope: &'static str,
}
34
35impl AzCliAuth {
36    /// Create an auth provider for Azure Search
37    pub fn for_search() -> Self {
38        Self {
39            resource_scope: "https://search.azure.com",
40        }
41    }
42
43    /// Create an auth provider for Microsoft Foundry
44    pub fn for_foundry() -> Self {
45        Self {
46            resource_scope: "https://ai.azure.com",
47        }
48    }
49
50    /// Create a new auth provider (defaults to Search scope for backward compatibility)
51    pub fn new() -> Self {
52        Self::for_search()
53    }
54
55    /// Check if Azure CLI is available and logged in
56    pub fn check_status() -> Result<AuthStatus, AuthError> {
57        // Check if az CLI is installed
58        let version_output = Command::new("az").arg("--version").output();
59
60        if version_output.is_err() {
61            return Err(AuthError::AzCliNotFound);
62        }
63
64        // Check if logged in
65        let account_output = Command::new("az")
66            .args(["account", "show", "--output", "json"])
67            .output()
68            .map_err(|e| AuthError::TokenError(e.to_string()))?;
69
70        if !account_output.status.success() {
71            return Err(AuthError::NotLoggedIn);
72        }
73
74        // Parse account info
75        let account_json: serde_json::Value = serde_json::from_slice(&account_output.stdout)
76            .map_err(|e| AuthError::TokenError(e.to_string()))?;
77
78        Ok(AuthStatus {
79            logged_in: true,
80            user: account_json
81                .get("user")
82                .and_then(|u| u.get("name"))
83                .and_then(|n| n.as_str())
84                .map(String::from),
85            subscription: account_json
86                .get("name")
87                .and_then(|n| n.as_str())
88                .map(String::from),
89            subscription_id: account_json
90                .get("id")
91                .and_then(|i| i.as_str())
92                .map(String::from),
93        })
94    }
95
96    /// Get an access token for Azure Resource Manager (management.azure.com)
97    pub fn get_arm_token() -> Result<String, AuthError> {
98        let output = Command::new("az")
99            .args([
100                "account",
101                "get-access-token",
102                "--resource",
103                "https://management.azure.com",
104                "--query",
105                "accessToken",
106                "--output",
107                "tsv",
108            ])
109            .output()
110            .map_err(|e| AuthError::TokenError(e.to_string()))?;
111
112        if !output.status.success() {
113            let stderr = String::from_utf8_lossy(&output.stderr);
114            if stderr.contains("not logged in") || stderr.contains("AADSTS") {
115                return Err(AuthError::NotLoggedIn);
116            }
117            return Err(AuthError::TokenError(stderr.to_string()));
118        }
119
120        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
121        if token.is_empty() {
122            return Err(AuthError::TokenError(
123                "Empty ARM token received".to_string(),
124            ));
125        }
126
127        Ok(token)
128    }
129}
130
131impl Default for AzCliAuth {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl AuthProvider for AzCliAuth {
138    fn get_token(&self) -> Result<String, AuthError> {
139        let output = Command::new("az")
140            .args([
141                "account",
142                "get-access-token",
143                "--resource",
144                self.resource_scope,
145                "--query",
146                "accessToken",
147                "--output",
148                "tsv",
149            ])
150            .output()
151            .map_err(|e| AuthError::TokenError(e.to_string()))?;
152
153        if !output.status.success() {
154            let stderr = String::from_utf8_lossy(&output.stderr);
155            if stderr.contains("not logged in") || stderr.contains("AADSTS") {
156                return Err(AuthError::NotLoggedIn);
157            }
158            return Err(AuthError::TokenError(stderr.to_string()));
159        }
160
161        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
162        if token.is_empty() {
163            return Err(AuthError::TokenError("Empty token received".to_string()));
164        }
165
166        Ok(token)
167    }
168
169    fn method_name(&self) -> &'static str {
170        "Azure CLI"
171    }
172}
173
/// Service-principal authentication provider configured from environment variables.
///
/// `Debug` is implemented by hand instead of derived so the client secret can never end up
/// in logs or panic messages.
pub struct EnvAuth {
    client_id: String,
    client_secret: String,
    tenant_id: String,
    resource_scope: &'static str,
}

impl std::fmt::Debug for EnvAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnvAuth")
            .field("client_id", &self.client_id)
            .field("client_secret", &format_args!("<redacted>"))
            .field("tenant_id", &self.tenant_id)
            .field("resource_scope", &self.resource_scope)
            .finish()
    }
}
182
183impl EnvAuth {
184    /// Create from environment variables (defaults to Search scope)
185    pub fn from_env() -> Result<Self, AuthError> {
186        Self::from_env_for_scope("https://search.azure.com")
187    }
188
189    /// Create from environment variables for a specific resource scope
190    pub fn from_env_for_scope(scope: &'static str) -> Result<Self, AuthError> {
191        let client_id = std::env::var("AZURE_CLIENT_ID")
192            .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_ID".to_string()))?;
193        let client_secret = std::env::var("AZURE_CLIENT_SECRET")
194            .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_SECRET".to_string()))?;
195        let tenant_id = std::env::var("AZURE_TENANT_ID")
196            .map_err(|_| AuthError::MissingEnvVar("AZURE_TENANT_ID".to_string()))?;
197
198        Ok(Self {
199            client_id,
200            client_secret,
201            tenant_id,
202            resource_scope: scope,
203        })
204    }
205
206    /// Check if environment variables are set
207    pub fn is_configured() -> bool {
208        std::env::var("AZURE_CLIENT_ID").is_ok()
209            && std::env::var("AZURE_CLIENT_SECRET").is_ok()
210            && std::env::var("AZURE_TENANT_ID").is_ok()
211    }
212}
213
214impl AuthProvider for EnvAuth {
215    fn get_token(&self) -> Result<String, AuthError> {
216        // Use Azure CLI to get token with service principal
217        let output = Command::new("az")
218            .args([
219                "account",
220                "get-access-token",
221                "--resource",
222                self.resource_scope,
223                "--query",
224                "accessToken",
225                "--output",
226                "tsv",
227                "--tenant",
228                &self.tenant_id,
229                "--username",
230                &self.client_id,
231            ])
232            .env("AZURE_CLIENT_SECRET", &self.client_secret)
233            .output()
234            .map_err(|e| AuthError::TokenError(e.to_string()))?;
235
236        if !output.status.success() {
237            let stderr = String::from_utf8_lossy(&output.stderr);
238            return Err(AuthError::AuthFailed(stderr.to_string()));
239        }
240
241        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
242        Ok(token)
243    }
244
245    fn method_name(&self) -> &'static str {
246        "Environment Variables (Service Principal)"
247    }
248}
249
/// Snapshot of the Azure CLI login state, built from `az account show` output.
#[derive(Clone, Debug)]
pub struct AuthStatus {
    /// Whether an active CLI login was found.
    pub logged_in: bool,
    /// Signed-in account name (`user.name`), when reported.
    pub user: Option<String>,
    /// Display name of the active subscription (`name`), when reported.
    pub subscription: Option<String>,
    /// Identifier of the active subscription (`id`), when reported.
    pub subscription_id: Option<String>,
}
258
259/// Get the best available authentication provider for Search (backward compat)
260pub fn get_auth_provider() -> Result<Box<dyn AuthProvider>, AuthError> {
261    get_auth_provider_for_scope("https://search.azure.com")
262}
263
264/// Get the best available authentication provider for a specific service domain
265pub fn get_auth_provider_for(
266    domain: hoist_core::ServiceDomain,
267) -> Result<Box<dyn AuthProvider>, AuthError> {
268    let scope = match domain {
269        hoist_core::ServiceDomain::Search => "https://search.azure.com",
270        hoist_core::ServiceDomain::Foundry => "https://ai.azure.com",
271    };
272    get_auth_provider_for_scope(scope)
273}
274
275/// Get the best available authentication provider for a specific resource scope
276fn get_auth_provider_for_scope(scope: &'static str) -> Result<Box<dyn AuthProvider>, AuthError> {
277    // First try environment variables
278    if EnvAuth::is_configured() {
279        return Ok(Box::new(EnvAuth::from_env_for_scope(scope)?));
280    }
281
282    // Fall back to Azure CLI
283    AzCliAuth::check_status()?;
284    Ok(Box::new(AzCliAuth {
285        resource_scope: scope,
286    }))
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    // Env var tests must run serially since they share process-wide state.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Exclusive access to the process environment for the duration of one test.
    ///
    /// Clears the Azure variables when dropped, including while unwinding from a failed
    /// assertion, so a failing test cannot leak variables into later tests.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // `drop` runs before the `_lock` field is released, so cleanup stays serialized.
            clear_azure_env_vars();
        }
    }

    /// Take the environment lock, tolerating poisoning left by an earlier failed test.
    fn lock_env() -> EnvGuard {
        // The mutex guards `()`, so there is no inconsistent state to protect. Recovering
        // the guard keeps one failure from cascading into PoisonError panics elsewhere.
        let lock = ENV_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        EnvGuard { _lock: lock }
    }

    fn clear_azure_env_vars() {
        std::env::remove_var("AZURE_CLIENT_ID");
        std::env::remove_var("AZURE_CLIENT_SECRET");
        std::env::remove_var("AZURE_TENANT_ID");
    }

    fn set_azure_env_vars() {
        std::env::set_var("AZURE_CLIENT_ID", "test-client-id");
        std::env::set_var("AZURE_CLIENT_SECRET", "test-client-secret");
        std::env::set_var("AZURE_TENANT_ID", "test-tenant-id");
    }

    #[test]
    fn test_env_auth_from_env_success() {
        let _env = lock_env();
        set_azure_env_vars();

        let result = EnvAuth::from_env();
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.client_id, "test-client-id");
        assert_eq!(auth.client_secret, "test-client-secret");
        assert_eq!(auth.tenant_id, "test-tenant-id");
    }

    #[test]
    fn test_env_auth_from_env_missing_client_id() {
        let _env = lock_env();
        clear_azure_env_vars();
        std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        std::env::set_var("AZURE_TENANT_ID", "test-tenant");

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_ID"));
    }

    #[test]
    fn test_env_auth_from_env_missing_client_secret() {
        let _env = lock_env();
        clear_azure_env_vars();
        std::env::set_var("AZURE_CLIENT_ID", "test-id");
        std::env::set_var("AZURE_TENANT_ID", "test-tenant");

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_CLIENT_SECRET"));
    }

    #[test]
    fn test_env_auth_from_env_missing_tenant_id() {
        let _env = lock_env();
        clear_azure_env_vars();
        std::env::set_var("AZURE_CLIENT_ID", "test-id");
        std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");

        let result = EnvAuth::from_env();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AuthError::MissingEnvVar(ref v) if v == "AZURE_TENANT_ID"));
    }

    #[test]
    fn test_env_auth_is_configured_all_set() {
        let _env = lock_env();
        set_azure_env_vars();

        assert!(EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_none_set() {
        let _env = lock_env();
        clear_azure_env_vars();

        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_is_configured_partial() {
        let _env = lock_env();
        clear_azure_env_vars();
        std::env::set_var("AZURE_CLIENT_ID", "test-id");
        std::env::set_var("AZURE_CLIENT_SECRET", "test-secret");
        // AZURE_TENANT_ID intentionally missing

        assert!(!EnvAuth::is_configured());
    }

    #[test]
    fn test_env_auth_method_name() {
        let _env = lock_env();
        set_azure_env_vars();

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(
            auth.method_name(),
            "Environment Variables (Service Principal)"
        );
    }

    #[test]
    fn test_az_cli_auth_method_name() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.method_name(), "Azure CLI");
    }

    #[test]
    fn test_az_cli_auth_search_scope() {
        let auth = AzCliAuth::for_search();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_az_cli_auth_foundry_scope() {
        let auth = AzCliAuth::for_foundry();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_az_cli_auth_new_defaults_to_search() {
        let auth = AzCliAuth::new();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_scope_foundry() {
        let _env = lock_env();
        set_azure_env_vars();

        let result = EnvAuth::from_env_for_scope("https://ai.azure.com");
        assert!(result.is_ok());
        let auth = result.unwrap();
        assert_eq!(auth.resource_scope, "https://ai.azure.com");
    }

    #[test]
    fn test_env_auth_from_env_default_scope_is_search() {
        let _env = lock_env();
        set_azure_env_vars();

        let auth = EnvAuth::from_env().unwrap();
        assert_eq!(auth.resource_scope, "https://search.azure.com");
    }

    #[test]
    fn test_auth_status_fields() {
        let status = AuthStatus {
            logged_in: true,
            user: Some("testuser@example.com".to_string()),
            subscription: Some("My Subscription".to_string()),
            subscription_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
        };

        assert!(status.logged_in);
        assert_eq!(status.user.as_deref(), Some("testuser@example.com"));
        assert_eq!(status.subscription.as_deref(), Some("My Subscription"));
        assert_eq!(
            status.subscription_id.as_deref(),
            Some("00000000-0000-0000-0000-000000000000")
        );
    }
}