server/auth/
azure_ad.rs

1use super::provider::{AuthProvider, AuthToken};
2use super::types::{AuthType, AzureAdAuthConfig};
3use crate::service_bus_manager::ServiceBusError;
4use async_trait::async_trait;
5use serde::Deserialize;
6
/// Information required to complete an Azure AD Device Code Flow authentication.
///
/// Contains the device code, user code, and verification URL that the user needs
/// to complete the authentication process on a separate device or browser.
///
/// The `Debug` output redacts `device_code`: that value is what the token endpoint
/// accepts when polling for the access token, so it should not end up in logs.
#[derive(Clone)]
pub struct DeviceCodeFlowInfo {
    /// Device-specific code sent to the token endpoint while polling (kept secret)
    pub device_code: String,
    /// Short user code that the user enters on the verification page
    pub user_code: String,
    /// URL where the user should go to enter the user code
    pub verification_uri: String,
    /// Time in seconds until the device code expires
    pub expires_in: u64,
    /// Recommended polling interval in seconds
    pub interval: u64,
    /// Human-readable message with authentication instructions
    pub message: String,
}

impl std::fmt::Debug for DeviceCodeFlowInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceCodeFlowInfo")
            .field("device_code", &"<redacted>")
            .field("user_code", &self.user_code)
            .field("verification_uri", &self.verification_uri)
            .field("expires_in", &self.expires_in)
            .field("interval", &self.interval)
            .field("message", &self.message)
            .finish()
    }
}
26
27/// Authentication provider for Azure Active Directory authentication flows.
28///
29/// Supports both Device Code Flow (for interactive scenarios) and Client Credentials Flow
30/// (for service-to-service authentication). This provider handles the complete OAuth 2.0
31/// authentication process with Azure AD.
32///
33/// # Supported Flows
34///
35/// - **Device Code Flow** - Interactive authentication where users enter a code on a separate device
36/// - **Client Credentials Flow** - Service principal authentication using client ID and secret
37///
38/// # Examples
39///
40/// ```no_run
41/// use quetty_server::auth::{AzureAdProvider, AzureAdAuthConfig};
42///
43/// let config = AzureAdAuthConfig {
44///     auth_method: "device_code".to_string(),
45///     tenant_id: Some("your-tenant-id".to_string()),
46///     client_id: Some("your-client-id".to_string()),
47///     ..Default::default()
48/// };
49///
50/// let client = reqwest::Client::new();
51/// let provider = AzureAdProvider::new(config, client)?;
52/// let token = provider.authenticate().await?;
53/// ```
54#[derive(Clone)]
55pub struct AzureAdProvider {
56    config: AzureAdAuthConfig,
57    http_client: reqwest::Client,
58}
59
60#[derive(Deserialize)]
61struct TokenResponse {
62    access_token: String,
63    token_type: String,
64    expires_in: u64,
65}
66
67#[derive(Deserialize)]
68struct DeviceCodeResponse {
69    device_code: String,
70    user_code: String,
71    verification_uri: String,
72    expires_in: u64,
73    interval: u64,
74    message: String,
75}
76
77#[derive(Deserialize)]
78struct ErrorResponse {
79    error: String,
80    error_description: Option<String>,
81}
82
83impl AzureAdProvider {
84    /// Creates a new AzureAdProvider with the specified configuration and HTTP client.
85    ///
86    /// # Arguments
87    ///
88    /// * `config` - Azure AD authentication configuration
89    /// * `http_client` - HTTP client for making authentication requests
90    ///
91    /// # Returns
92    ///
93    /// A configured AzureAdProvider ready for authentication
94    ///
95    /// # Examples
96    ///
97    /// ```no_run
98    /// use quetty_server::auth::{AzureAdProvider, AzureAdAuthConfig};
99    ///
100    /// let config = AzureAdAuthConfig::default();
101    /// let client = reqwest::Client::new();
102    /// let provider = AzureAdProvider::new(config, client)?;
103    /// ```
104    pub fn new(
105        config: AzureAdAuthConfig,
106        http_client: reqwest::Client,
107    ) -> Result<Self, ServiceBusError> {
108        Ok(Self {
109            config,
110            http_client,
111        })
112    }
113
114    /// Gets the configured authentication flow type.
115    ///
116    /// # Returns
117    ///
118    /// The authentication method string ("device_code" or "client_secret")
119    pub fn flow_type(&self) -> &str {
120        &self.config.auth_method
121    }
122
123    fn authority_host(&self) -> &str {
124        self.config
125            .authority_host
126            .as_deref()
127            .unwrap_or("https://login.microsoftonline.com")
128    }
129
130    fn scope(&self) -> &str {
131        self.config
132            .scope
133            .as_deref()
134            .unwrap_or("https://management.azure.com/.default")
135    }
136
137    fn tenant_id(&self) -> Result<&str, ServiceBusError> {
138        self.config.tenant_id.as_deref().ok_or_else(|| {
139            ServiceBusError::ConfigurationError("Azure AD tenant_id is required".to_string())
140        })
141    }
142
143    fn client_id(&self) -> Result<&str, ServiceBusError> {
144        self.config.client_id.as_deref().ok_or_else(|| {
145            ServiceBusError::ConfigurationError("Azure AD client_id is required".to_string())
146        })
147    }
148
149    async fn device_code_flow(&self) -> Result<AuthToken, ServiceBusError> {
150        // For device code flow, we need to start it and poll separately
151        // This method will start the flow and immediately poll
152        let device_info = self.start_device_code_flow().await?;
153
154        // Log the device code info (without sensitive data)
155        log::info!("Device code authentication initiated - awaiting user action");
156
157        // Poll for the token
158        self.poll_device_code_token(&device_info).await
159    }
160
161    async fn client_credentials_flow(&self) -> Result<AuthToken, ServiceBusError> {
162        let client_secret = self.config.client_secret.as_deref().ok_or_else(|| {
163            ServiceBusError::ConfigurationError(
164                "Client secret is required for client credentials flow".to_string(),
165            )
166        })?;
167
168        let token_url = format!(
169            "{}/{}/oauth2/v2.0/token",
170            self.authority_host(),
171            self.tenant_id()?
172        );
173
174        let params = [
175            ("grant_type", "client_credentials"),
176            ("client_id", self.client_id()?),
177            ("client_secret", client_secret),
178            ("scope", self.scope()),
179        ];
180
181        log::info!("Client credentials authentication initiated");
182
183        let response = self
184            .http_client
185            .post(&token_url)
186            .form(&params)
187            .send()
188            .await
189            .map_err(|e| {
190                ServiceBusError::AuthenticationError(format!(
191                    "Failed to authenticate with client credentials: {e}"
192                ))
193            })?;
194
195        if !response.status().is_success() {
196            let error_info = response
197                .json::<ErrorResponse>()
198                .await
199                .unwrap_or(ErrorResponse {
200                    error: "unknown_error".to_string(),
201                    error_description: Some("Failed to parse error response".to_string()),
202                });
203
204            let user_friendly_message = match error_info.error.as_str() {
205                "invalid_client" => {
206                    "Invalid client credentials. Please check your client ID and client secret."
207                }
208                "invalid_request" => {
209                    "Invalid authentication request. Please verify your configuration."
210                }
211                "unauthorized_client" => {
212                    "This application is not authorized for client credentials flow. Please check Azure AD configuration."
213                }
214                "access_denied" => {
215                    "Access denied. Please ensure the application has sufficient permissions."
216                }
217                "invalid_scope" => {
218                    "Invalid scope specified. Please check the requested permissions."
219                }
220                _ => error_info
221                    .error_description
222                    .as_deref()
223                    .unwrap_or(&error_info.error),
224            };
225
226            return Err(ServiceBusError::AuthenticationError(format!(
227                "Client credentials authentication failed: {user_friendly_message}"
228            )));
229        }
230
231        let token_response: TokenResponse = response.json().await.map_err(|e| {
232            ServiceBusError::AuthenticationError(format!("Failed to parse token response: {e}"))
233        })?;
234
235        log::info!("Client credentials authentication successful");
236
237        Ok(AuthToken {
238            token: token_response.access_token,
239            token_type: token_response.token_type,
240            expires_in_secs: Some(token_response.expires_in),
241        })
242    }
243
244    /// Initiates a Device Code Flow authentication process.
245    ///
246    /// This method starts the device code flow by requesting a device code from Azure AD.
247    /// The returned information should be displayed to the user so they can complete
248    /// authentication on a separate device or browser.
249    ///
250    /// # Returns
251    ///
252    /// [`DeviceCodeFlowInfo`] containing the user code, verification URL, and other details
253    ///
254    /// # Errors
255    ///
256    /// Returns [`ServiceBusError::AuthenticationError`] if:
257    /// - The device code request fails
258    /// - Invalid client configuration
259    /// - Network connectivity issues
260    ///
261    /// # Examples
262    ///
263    /// ```no_run
264    /// use quetty_server::auth::AzureAdProvider;
265    ///
266    /// let provider = AzureAdProvider::new(config, client)?;
267    /// let device_info = provider.start_device_code_flow().await?;
268    ///
269    /// println!("Go to: {}", device_info.verification_uri);
270    /// println!("Enter code: {}", device_info.user_code);
271    /// ```
272    pub async fn start_device_code_flow(&self) -> Result<DeviceCodeFlowInfo, ServiceBusError> {
273        let device_code_url = format!(
274            "{}/{}/oauth2/v2.0/devicecode",
275            self.authority_host(),
276            self.tenant_id()?
277        );
278
279        let params = [("client_id", self.client_id()?), ("scope", self.scope())];
280
281        let device_response = self
282            .http_client
283            .post(&device_code_url)
284            .form(&params)
285            .send()
286            .await
287            .map_err(|e| {
288                ServiceBusError::AuthenticationError(format!(
289                    "Failed to initiate device code flow: {e}"
290                ))
291            })?;
292
293        // Check if the response is successful
294        if !device_response.status().is_success() {
295            // Try to parse error response
296            let error_info =
297                device_response
298                    .json::<ErrorResponse>()
299                    .await
300                    .unwrap_or(ErrorResponse {
301                        error: "unknown_error".to_string(),
302                        error_description: Some("Failed to parse error response".to_string()),
303                    });
304
305            let user_friendly_message = match error_info.error.as_str() {
306                "invalid_client" => {
307                    "Invalid client configuration. Please check your Azure AD app registration and ensure 'Allow public client flows' is enabled."
308                }
309                "invalid_request" => {
310                    "Invalid authentication request. Please check your client ID and tenant ID."
311                }
312                "unauthorized_client" => {
313                    "This application is not authorized for device code flow. Please check Azure AD configuration."
314                }
315                "access_denied" => {
316                    "Access denied. Please ensure you have the necessary permissions."
317                }
318                "expired_token" => "Authentication expired. Please try again.",
319                _ => error_info
320                    .error_description
321                    .as_deref()
322                    .unwrap_or(&error_info.error),
323            };
324
325            return Err(ServiceBusError::AuthenticationError(format!(
326                "Authentication failed: {user_friendly_message}"
327            )));
328        }
329
330        let device_code: DeviceCodeResponse = device_response.json().await.map_err(|e| {
331            ServiceBusError::AuthenticationError(format!(
332                "Failed to parse device code response: {e}"
333            ))
334        })?;
335
336        Ok(DeviceCodeFlowInfo {
337            device_code: device_code.device_code,
338            user_code: device_code.user_code,
339            verification_uri: device_code.verification_uri,
340            expires_in: device_code.expires_in,
341            interval: device_code.interval,
342            message: device_code.message,
343        })
344    }
345
346    /// Polls Azure AD for completion of device code authentication.
347    ///
348    /// This method continuously polls Azure AD to check if the user has completed
349    /// the device code authentication process. It handles all the standard OAuth 2.0
350    /// device flow polling logic including backoff and error handling.
351    ///
352    /// # Arguments
353    ///
354    /// * `device_info` - Device code information from [`start_device_code_flow`]
355    ///
356    /// # Returns
357    ///
358    /// An [`AuthToken`] when authentication is successfully completed
359    ///
360    /// # Errors
361    ///
362    /// Returns [`ServiceBusError::AuthenticationError`] if:
363    /// - Authentication times out or expires
364    /// - User denies access
365    /// - Network errors during polling
366    ///
367    /// # Examples
368    ///
369    /// ```no_run
370    /// use quetty_server::auth::AzureAdProvider;
371    ///
372    /// let provider = AzureAdProvider::new(config, client)?;
373    /// let device_info = provider.start_device_code_flow().await?;
374    ///
375    /// // Display info to user...
376    ///
377    /// let token = provider.poll_device_code_token(&device_info).await?;
378    /// ```
379    pub async fn poll_device_code_token(
380        &self,
381        device_info: &DeviceCodeFlowInfo,
382    ) -> Result<AuthToken, ServiceBusError> {
383        let token_url = format!(
384            "{}/{}/oauth2/v2.0/token",
385            self.authority_host(),
386            self.tenant_id()?
387        );
388
389        let mut interval = std::time::Duration::from_secs(device_info.interval);
390        let timeout = std::time::Duration::from_secs(device_info.expires_in);
391        let start = std::time::Instant::now();
392
393        loop {
394            if start.elapsed() > timeout {
395                return Err(ServiceBusError::AuthenticationError(
396                    "Authentication timed out. The device code has expired. Please restart the authentication process.".to_string()
397                ));
398            }
399
400            tokio::time::sleep(interval).await;
401
402            let mut params = vec![
403                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
404                ("client_id", self.client_id()?),
405                ("device_code", device_info.device_code.as_str()),
406            ];
407
408            // Include client_secret if configured (for confidential clients)
409            if let Some(client_secret) = self.config.client_secret.as_deref() {
410                params.push(("client_secret", client_secret));
411            }
412
413            let response = self
414                .http_client
415                .post(&token_url)
416                .form(&params)
417                .send()
418                .await
419                .map_err(|e| {
420                    ServiceBusError::AuthenticationError(format!("Failed to poll for token: {e}"))
421                })?;
422
423            if response.status().is_success() {
424                let token_response: TokenResponse = response.json().await.map_err(|e| {
425                    ServiceBusError::AuthenticationError(format!(
426                        "Failed to parse token response: {e}"
427                    ))
428                })?;
429
430                return Ok(AuthToken {
431                    token: token_response.access_token,
432                    token_type: token_response.token_type,
433                    expires_in_secs: Some(token_response.expires_in),
434                });
435            }
436
437            let error_response: serde_json::Value = response.json().await.unwrap_or_default();
438
439            if let Some(error) = error_response["error"].as_str() {
440                match error {
441                    "authorization_pending" => {
442                        log::debug!("Waiting for user to complete authentication");
443                        continue;
444                    }
445                    "slow_down" => {
446                        log::debug!("Polling too frequently, increasing interval");
447                        interval += std::time::Duration::from_secs(5);
448                        continue;
449                    }
450                    "expired_token" => {
451                        return Err(ServiceBusError::AuthenticationError(
452                            "The device code has expired. Please restart the authentication process.".to_string()
453                        ));
454                    }
455                    "access_denied" => {
456                        return Err(ServiceBusError::AuthenticationError(
457                            "Access was denied. Please ensure you have the necessary permissions."
458                                .to_string(),
459                        ));
460                    }
461                    _ => {
462                        let error_desc = error_response["error_description"]
463                            .as_str()
464                            .unwrap_or("Unknown error occurred");
465                        return Err(ServiceBusError::AuthenticationError(format!(
466                            "Authentication failed: {error} - {error_desc}"
467                        )));
468                    }
469                }
470            }
471        }
472    }
473}
474
475#[async_trait]
476impl AuthProvider for AzureAdProvider {
477    /// Authenticates using the configured Azure AD authentication flow.
478    ///
479    /// Automatically selects the appropriate authentication method based on the
480    /// configuration (device_code or client_secret) and handles the complete
481    /// OAuth 2.0 flow including error handling and token retrieval.
482    ///
483    /// # Returns
484    ///
485    /// An [`AuthToken`] containing the Azure AD access token and metadata
486    ///
487    /// # Errors
488    ///
489    /// Returns [`ServiceBusError`] if:
490    /// - Authentication method is not supported
491    /// - Authentication flow fails
492    /// - Network connectivity issues
493    /// - Invalid credentials or configuration
494    async fn authenticate(&self) -> Result<AuthToken, ServiceBusError> {
495        match self.config.auth_method.as_str() {
496            "device_code" => self.device_code_flow().await,
497            "client_secret" => self.client_credentials_flow().await,
498            _ => Err(ServiceBusError::ConfigurationError(format!(
499                "Unsupported auth method: {}",
500                self.config.auth_method
501            ))),
502        }
503    }
504
505    /// Returns the authentication type for this provider.
506    ///
507    /// # Returns
508    ///
509    /// [`AuthType::AzureAd`] indicating Azure Active Directory authentication
510    fn auth_type(&self) -> AuthType {
511        AuthType::AzureAd
512    }
513}