1use super::provider::{AuthProvider, AuthToken};
2use super::types::{AuthType, AzureAdAuthConfig};
3use crate::service_bus_manager::ServiceBusError;
4use async_trait::async_trait;
5use serde::Deserialize;
6
/// Information returned by the authority when a device code sign-in starts.
///
/// The user completes sign-in elsewhere using `user_code` at
/// `verification_uri` (or by following `message`), while this process polls
/// the token endpoint with `device_code`.
#[derive(Debug, Clone)]
pub struct DeviceCodeFlowInfo {
    /// Opaque code sent back to the token endpoint on every poll.
    pub device_code: String,
    /// Code the user enters at `verification_uri`.
    pub user_code: String,
    /// URL where the user completes sign-in.
    pub verification_uri: String,
    /// Lifetime of the device code in seconds; used as the polling deadline.
    pub expires_in: u64,
    /// Seconds to wait between polling attempts.
    pub interval: u64,
    /// Human-readable sign-in instructions supplied by the authority.
    pub message: String,
}
26
27#[derive(Clone)]
55pub struct AzureAdProvider {
56 config: AzureAdAuthConfig,
57 http_client: reqwest::Client,
58}
59
60#[derive(Deserialize)]
61struct TokenResponse {
62 access_token: String,
63 token_type: String,
64 expires_in: u64,
65}
66
67#[derive(Deserialize)]
68struct DeviceCodeResponse {
69 device_code: String,
70 user_code: String,
71 verification_uri: String,
72 expires_in: u64,
73 interval: u64,
74 message: String,
75}
76
77#[derive(Deserialize)]
78struct ErrorResponse {
79 error: String,
80 error_description: Option<String>,
81}
82
83impl AzureAdProvider {
84 pub fn new(
105 config: AzureAdAuthConfig,
106 http_client: reqwest::Client,
107 ) -> Result<Self, ServiceBusError> {
108 Ok(Self {
109 config,
110 http_client,
111 })
112 }
113
114 pub fn flow_type(&self) -> &str {
120 &self.config.auth_method
121 }
122
123 fn authority_host(&self) -> &str {
124 self.config
125 .authority_host
126 .as_deref()
127 .unwrap_or("https://login.microsoftonline.com")
128 }
129
130 fn scope(&self) -> &str {
131 self.config
132 .scope
133 .as_deref()
134 .unwrap_or("https://management.azure.com/.default")
135 }
136
137 fn tenant_id(&self) -> Result<&str, ServiceBusError> {
138 self.config.tenant_id.as_deref().ok_or_else(|| {
139 ServiceBusError::ConfigurationError("Azure AD tenant_id is required".to_string())
140 })
141 }
142
143 fn client_id(&self) -> Result<&str, ServiceBusError> {
144 self.config.client_id.as_deref().ok_or_else(|| {
145 ServiceBusError::ConfigurationError("Azure AD client_id is required".to_string())
146 })
147 }
148
149 async fn device_code_flow(&self) -> Result<AuthToken, ServiceBusError> {
150 let device_info = self.start_device_code_flow().await?;
153
154 log::info!("Device code authentication initiated - awaiting user action");
156
157 self.poll_device_code_token(&device_info).await
159 }
160
161 async fn client_credentials_flow(&self) -> Result<AuthToken, ServiceBusError> {
162 let client_secret = self.config.client_secret.as_deref().ok_or_else(|| {
163 ServiceBusError::ConfigurationError(
164 "Client secret is required for client credentials flow".to_string(),
165 )
166 })?;
167
168 let token_url = format!(
169 "{}/{}/oauth2/v2.0/token",
170 self.authority_host(),
171 self.tenant_id()?
172 );
173
174 let params = [
175 ("grant_type", "client_credentials"),
176 ("client_id", self.client_id()?),
177 ("client_secret", client_secret),
178 ("scope", self.scope()),
179 ];
180
181 log::info!("Client credentials authentication initiated");
182
183 let response = self
184 .http_client
185 .post(&token_url)
186 .form(¶ms)
187 .send()
188 .await
189 .map_err(|e| {
190 ServiceBusError::AuthenticationError(format!(
191 "Failed to authenticate with client credentials: {e}"
192 ))
193 })?;
194
195 if !response.status().is_success() {
196 let error_info = response
197 .json::<ErrorResponse>()
198 .await
199 .unwrap_or(ErrorResponse {
200 error: "unknown_error".to_string(),
201 error_description: Some("Failed to parse error response".to_string()),
202 });
203
204 let user_friendly_message = match error_info.error.as_str() {
205 "invalid_client" => {
206 "Invalid client credentials. Please check your client ID and client secret."
207 }
208 "invalid_request" => {
209 "Invalid authentication request. Please verify your configuration."
210 }
211 "unauthorized_client" => {
212 "This application is not authorized for client credentials flow. Please check Azure AD configuration."
213 }
214 "access_denied" => {
215 "Access denied. Please ensure the application has sufficient permissions."
216 }
217 "invalid_scope" => {
218 "Invalid scope specified. Please check the requested permissions."
219 }
220 _ => error_info
221 .error_description
222 .as_deref()
223 .unwrap_or(&error_info.error),
224 };
225
226 return Err(ServiceBusError::AuthenticationError(format!(
227 "Client credentials authentication failed: {user_friendly_message}"
228 )));
229 }
230
231 let token_response: TokenResponse = response.json().await.map_err(|e| {
232 ServiceBusError::AuthenticationError(format!("Failed to parse token response: {e}"))
233 })?;
234
235 log::info!("Client credentials authentication successful");
236
237 Ok(AuthToken {
238 token: token_response.access_token,
239 token_type: token_response.token_type,
240 expires_in_secs: Some(token_response.expires_in),
241 })
242 }
243
244 pub async fn start_device_code_flow(&self) -> Result<DeviceCodeFlowInfo, ServiceBusError> {
273 let device_code_url = format!(
274 "{}/{}/oauth2/v2.0/devicecode",
275 self.authority_host(),
276 self.tenant_id()?
277 );
278
279 let params = [("client_id", self.client_id()?), ("scope", self.scope())];
280
281 let device_response = self
282 .http_client
283 .post(&device_code_url)
284 .form(¶ms)
285 .send()
286 .await
287 .map_err(|e| {
288 ServiceBusError::AuthenticationError(format!(
289 "Failed to initiate device code flow: {e}"
290 ))
291 })?;
292
293 if !device_response.status().is_success() {
295 let error_info =
297 device_response
298 .json::<ErrorResponse>()
299 .await
300 .unwrap_or(ErrorResponse {
301 error: "unknown_error".to_string(),
302 error_description: Some("Failed to parse error response".to_string()),
303 });
304
305 let user_friendly_message = match error_info.error.as_str() {
306 "invalid_client" => {
307 "Invalid client configuration. Please check your Azure AD app registration and ensure 'Allow public client flows' is enabled."
308 }
309 "invalid_request" => {
310 "Invalid authentication request. Please check your client ID and tenant ID."
311 }
312 "unauthorized_client" => {
313 "This application is not authorized for device code flow. Please check Azure AD configuration."
314 }
315 "access_denied" => {
316 "Access denied. Please ensure you have the necessary permissions."
317 }
318 "expired_token" => "Authentication expired. Please try again.",
319 _ => error_info
320 .error_description
321 .as_deref()
322 .unwrap_or(&error_info.error),
323 };
324
325 return Err(ServiceBusError::AuthenticationError(format!(
326 "Authentication failed: {user_friendly_message}"
327 )));
328 }
329
330 let device_code: DeviceCodeResponse = device_response.json().await.map_err(|e| {
331 ServiceBusError::AuthenticationError(format!(
332 "Failed to parse device code response: {e}"
333 ))
334 })?;
335
336 Ok(DeviceCodeFlowInfo {
337 device_code: device_code.device_code,
338 user_code: device_code.user_code,
339 verification_uri: device_code.verification_uri,
340 expires_in: device_code.expires_in,
341 interval: device_code.interval,
342 message: device_code.message,
343 })
344 }
345
346 pub async fn poll_device_code_token(
380 &self,
381 device_info: &DeviceCodeFlowInfo,
382 ) -> Result<AuthToken, ServiceBusError> {
383 let token_url = format!(
384 "{}/{}/oauth2/v2.0/token",
385 self.authority_host(),
386 self.tenant_id()?
387 );
388
389 let mut interval = std::time::Duration::from_secs(device_info.interval);
390 let timeout = std::time::Duration::from_secs(device_info.expires_in);
391 let start = std::time::Instant::now();
392
393 loop {
394 if start.elapsed() > timeout {
395 return Err(ServiceBusError::AuthenticationError(
396 "Authentication timed out. The device code has expired. Please restart the authentication process.".to_string()
397 ));
398 }
399
400 tokio::time::sleep(interval).await;
401
402 let mut params = vec![
403 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
404 ("client_id", self.client_id()?),
405 ("device_code", device_info.device_code.as_str()),
406 ];
407
408 if let Some(client_secret) = self.config.client_secret.as_deref() {
410 params.push(("client_secret", client_secret));
411 }
412
413 let response = self
414 .http_client
415 .post(&token_url)
416 .form(¶ms)
417 .send()
418 .await
419 .map_err(|e| {
420 ServiceBusError::AuthenticationError(format!("Failed to poll for token: {e}"))
421 })?;
422
423 if response.status().is_success() {
424 let token_response: TokenResponse = response.json().await.map_err(|e| {
425 ServiceBusError::AuthenticationError(format!(
426 "Failed to parse token response: {e}"
427 ))
428 })?;
429
430 return Ok(AuthToken {
431 token: token_response.access_token,
432 token_type: token_response.token_type,
433 expires_in_secs: Some(token_response.expires_in),
434 });
435 }
436
437 let error_response: serde_json::Value = response.json().await.unwrap_or_default();
438
439 if let Some(error) = error_response["error"].as_str() {
440 match error {
441 "authorization_pending" => {
442 log::debug!("Waiting for user to complete authentication");
443 continue;
444 }
445 "slow_down" => {
446 log::debug!("Polling too frequently, increasing interval");
447 interval += std::time::Duration::from_secs(5);
448 continue;
449 }
450 "expired_token" => {
451 return Err(ServiceBusError::AuthenticationError(
452 "The device code has expired. Please restart the authentication process.".to_string()
453 ));
454 }
455 "access_denied" => {
456 return Err(ServiceBusError::AuthenticationError(
457 "Access was denied. Please ensure you have the necessary permissions."
458 .to_string(),
459 ));
460 }
461 _ => {
462 let error_desc = error_response["error_description"]
463 .as_str()
464 .unwrap_or("Unknown error occurred");
465 return Err(ServiceBusError::AuthenticationError(format!(
466 "Authentication failed: {error} - {error_desc}"
467 )));
468 }
469 }
470 }
471 }
472 }
473}
474
475#[async_trait]
476impl AuthProvider for AzureAdProvider {
477 async fn authenticate(&self) -> Result<AuthToken, ServiceBusError> {
495 match self.config.auth_method.as_str() {
496 "device_code" => self.device_code_flow().await,
497 "client_secret" => self.client_credentials_flow().await,
498 _ => Err(ServiceBusError::ConfigurationError(format!(
499 "Unsupported auth method: {}",
500 self.config.auth_method
501 ))),
502 }
503 }
504
505 fn auth_type(&self) -> AuthType {
511 AuthType::AzureAd
512 }
513}