Skip to main content

hoist_client/
auth.rs

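//! Azure authentication providers.
//!
//! Tokens come either from the Azure CLI's signed-in account ([`AzCliAuth`]) or
//! from service-principal credentials supplied through the `AZURE_CLIENT_ID`,
//! `AZURE_CLIENT_SECRET` and `AZURE_TENANT_ID` environment variables
//! ([`EnvAuth`]); [`get_auth_provider`] picks one of them.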
2
3use std::process::Command;
4use thiserror::Error;
5
6/// Authentication errors
7#[derive(Debug, Error)]
8pub enum AuthError {
9    #[error("Azure CLI not found. Please install it: https://docs.microsoft.com/cli/azure/install-azure-cli")]
10    AzCliNotFound,
11    #[error("Not logged in to Azure CLI. Run: az login")]
12    NotLoggedIn,
13    #[error("Failed to get access token: {0}")]
14    TokenError(String),
15    #[error("Missing environment variable: {0}")]
16    MissingEnvVar(String),
17    #[error("Authentication failed: {0}")]
18    AuthFailed(String),
19}
20
21/// Authentication provider trait
22pub trait AuthProvider: Send + Sync {
23    /// Get an access token for Azure Search
24    fn get_token(&self) -> Result<String, AuthError>;
25
26    /// Get the authentication method name
27    fn method_name(&self) -> &'static str;
28}
29
/// Authentication provider backed by the locally installed Azure CLI (`az`).
///
/// The type carries no state: every token request shells out to `az`, which
/// uses whatever account the CLI is currently signed in with.
#[derive(Debug, Clone, Copy)]
pub struct AzCliAuth;
32
33impl AzCliAuth {
34    pub fn new() -> Self {
35        Self
36    }
37
38    /// Check if Azure CLI is available and logged in
39    pub fn check_status() -> Result<AuthStatus, AuthError> {
40        // Check if az CLI is installed
41        let version_output = Command::new("az").arg("--version").output();
42
43        if version_output.is_err() {
44            return Err(AuthError::AzCliNotFound);
45        }
46
47        // Check if logged in
48        let account_output = Command::new("az")
49            .args(["account", "show", "--output", "json"])
50            .output()
51            .map_err(|e| AuthError::TokenError(e.to_string()))?;
52
53        if !account_output.status.success() {
54            return Err(AuthError::NotLoggedIn);
55        }
56
57        // Parse account info
58        let account_json: serde_json::Value = serde_json::from_slice(&account_output.stdout)
59            .map_err(|e| AuthError::TokenError(e.to_string()))?;
60
61        Ok(AuthStatus {
62            logged_in: true,
63            user: account_json
64                .get("user")
65                .and_then(|u| u.get("name"))
66                .and_then(|n| n.as_str())
67                .map(String::from),
68            subscription: account_json
69                .get("name")
70                .and_then(|n| n.as_str())
71                .map(String::from),
72            subscription_id: account_json
73                .get("id")
74                .and_then(|i| i.as_str())
75                .map(String::from),
76        })
77    }
78
79    /// Get an access token for Azure Resource Manager (management.azure.com)
80    pub fn get_arm_token() -> Result<String, AuthError> {
81        let output = Command::new("az")
82            .args([
83                "account",
84                "get-access-token",
85                "--resource",
86                "https://management.azure.com",
87                "--query",
88                "accessToken",
89                "--output",
90                "tsv",
91            ])
92            .output()
93            .map_err(|e| AuthError::TokenError(e.to_string()))?;
94
95        if !output.status.success() {
96            let stderr = String::from_utf8_lossy(&output.stderr);
97            if stderr.contains("not logged in") || stderr.contains("AADSTS") {
98                return Err(AuthError::NotLoggedIn);
99            }
100            return Err(AuthError::TokenError(stderr.to_string()));
101        }
102
103        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
104        if token.is_empty() {
105            return Err(AuthError::TokenError(
106                "Empty ARM token received".to_string(),
107            ));
108        }
109
110        Ok(token)
111    }
112}
113
114impl Default for AzCliAuth {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120impl AuthProvider for AzCliAuth {
121    fn get_token(&self) -> Result<String, AuthError> {
122        let output = Command::new("az")
123            .args([
124                "account",
125                "get-access-token",
126                "--resource",
127                "https://search.azure.com",
128                "--query",
129                "accessToken",
130                "--output",
131                "tsv",
132            ])
133            .output()
134            .map_err(|e| AuthError::TokenError(e.to_string()))?;
135
136        if !output.status.success() {
137            let stderr = String::from_utf8_lossy(&output.stderr);
138            if stderr.contains("not logged in") || stderr.contains("AADSTS") {
139                return Err(AuthError::NotLoggedIn);
140            }
141            return Err(AuthError::TokenError(stderr.to_string()));
142        }
143
144        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
145        if token.is_empty() {
146            return Err(AuthError::TokenError("Empty token received".to_string()));
147        }
148
149        Ok(token)
150    }
151
152    fn method_name(&self) -> &'static str {
153        "Azure CLI"
154    }
155}
156
/// Service-principal authentication provider configured from environment variables.
pub struct EnvAuth {
    /// Service principal client ID (`AZURE_CLIENT_ID`).
    client_id: String,
    /// Service principal client secret (`AZURE_CLIENT_SECRET`); never shown by `Debug`.
    client_secret: String,
    /// Tenant ID (`AZURE_TENANT_ID`).
    tenant_id: String,
}

// Hand-written rather than derived so the client secret cannot leak into logs
// or panic messages through `{:?}` formatting.
impl std::fmt::Debug for EnvAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnvAuth")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("tenant_id", &self.tenant_id)
            .finish()
    }
}
163
164impl EnvAuth {
165    /// Create from environment variables
166    pub fn from_env() -> Result<Self, AuthError> {
167        let client_id = std::env::var("AZURE_CLIENT_ID")
168            .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_ID".to_string()))?;
169        let client_secret = std::env::var("AZURE_CLIENT_SECRET")
170            .map_err(|_| AuthError::MissingEnvVar("AZURE_CLIENT_SECRET".to_string()))?;
171        let tenant_id = std::env::var("AZURE_TENANT_ID")
172            .map_err(|_| AuthError::MissingEnvVar("AZURE_TENANT_ID".to_string()))?;
173
174        Ok(Self {
175            client_id,
176            client_secret,
177            tenant_id,
178        })
179    }
180
181    /// Check if environment variables are set
182    pub fn is_configured() -> bool {
183        std::env::var("AZURE_CLIENT_ID").is_ok()
184            && std::env::var("AZURE_CLIENT_SECRET").is_ok()
185            && std::env::var("AZURE_TENANT_ID").is_ok()
186    }
187}
188
189impl AuthProvider for EnvAuth {
190    fn get_token(&self) -> Result<String, AuthError> {
191        // Use Azure CLI to get token with service principal
192        let output = Command::new("az")
193            .args([
194                "account",
195                "get-access-token",
196                "--resource",
197                "https://search.azure.com",
198                "--query",
199                "accessToken",
200                "--output",
201                "tsv",
202                "--tenant",
203                &self.tenant_id,
204                "--username",
205                &self.client_id,
206            ])
207            .env("AZURE_CLIENT_SECRET", &self.client_secret)
208            .output()
209            .map_err(|e| AuthError::TokenError(e.to_string()))?;
210
211        if !output.status.success() {
212            let stderr = String::from_utf8_lossy(&output.stderr);
213            return Err(AuthError::AuthFailed(stderr.to_string()));
214        }
215
216        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
217        Ok(token)
218    }
219
220    fn method_name(&self) -> &'static str {
221        "Environment Variables (Service Principal)"
222    }
223}
224
/// Azure CLI login state, as reported by `az account show`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthStatus {
    /// Whether a login session was found.
    pub logged_in: bool,
    /// Signed-in user or principal name (`user.name`), if reported.
    pub user: Option<String>,
    /// Display name of the active subscription (`name`), if reported.
    pub subscription: Option<String>,
    /// ID of the active subscription (`id`), if reported.
    pub subscription_id: Option<String>,
}
233
234/// Get the best available authentication provider
235pub fn get_auth_provider() -> Result<Box<dyn AuthProvider>, AuthError> {
236    // First try environment variables
237    if EnvAuth::is_configured() {
238        return Ok(Box::new(EnvAuth::from_env()?));
239    }
240
241    // Fall back to Azure CLI
242    AzCliAuth::check_status()?;
243    Ok(Box::new(AzCliAuth::new()))
244}