1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::fmt::{Debug, Formatter};
4use std::ops::Add;
5use std::str::FromStr;
6use std::time::Duration;
7
8use graph_core::cache::{CacheStore, InMemoryCacheStore, TokenCache};
9use graph_core::identity::ForceTokenRefresh;
10use http::{HeaderMap, HeaderName, HeaderValue};
11use tracing::error;
12use url::Url;
13use uuid::Uuid;
14
15use crate::identity::{
16 AppConfig, Authority, AzureCloudInstance, DeviceAuthorizationResponse, PollDeviceCodeEvent,
17 PublicClientApplication, Token, TokenCredentialExecutor,
18};
19use crate::oauth_serializer::{AuthParameter, AuthSerializer};
20use graph_core::http::{
21 AsyncResponseConverterExt, HttpResponseExt, JsonHttpResponse, ResponseConverterExt,
22};
23use graph_error::{
24 AuthExecutionError, AuthExecutionResult, AuthTaskExecutionResult, AuthorizationFailure,
25 IdentityResult,
26};
27
28#[cfg(feature = "interactive-auth")]
29use {
30 crate::interactive::{HostOptions, UserEvents, WebViewAuth, WebViewOptions},
31 crate::tracing_targets::INTERACTIVE_AUTH,
32 graph_error::WebViewDeviceCodeError,
33 tao::{event_loop::EventLoopProxy, window::Window},
34 wry::{WebView, WebViewBuilder},
35};
36
/// OAuth 2.0 device authorization grant type (RFC 8628); sent as `grant_type`
/// when a device code is redeemed at the token endpoint.
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";

// Builder glue for `PublicClientApplication<DeviceCodeCredential>`. Builder methods that are
// not defined in `impl DeviceCodeCredentialBuilder` (e.g. `build`, used by the tests)
// presumably come from this macro — NOTE(review): confirm against the macro definition.
credential_builder!(
    DeviceCodeCredentialBuilder,
    PublicClientApplication<DeviceCodeCredential>
);
43
/// Credential for the OAuth 2.0 device code flow.
///
/// Holds the application configuration plus an optional device code and an
/// optional refresh token. When neither is set, requests target the device
/// authorization endpoint; otherwise they target the token endpoint, and a
/// refresh token takes precedence over a device code when building the form.
#[derive(Clone)]
pub struct DeviceCodeCredential {
    pub(crate) app_config: AppConfig,
    /// Refresh token used for silent renewal; preferred over `device_code`.
    pub(crate) refresh_token: Option<String>,
    /// Device code returned by the device authorization endpoint.
    pub(crate) device_code: Option<String>,
    /// Tokens stored under `app_config.cache_id`.
    token_cache: InMemoryCacheStore<Token>,
}
65
66impl DeviceCodeCredential {
67 pub fn new<U: ToString, I: IntoIterator<Item = U>>(
68 client_id: impl AsRef<str>,
69 device_code: impl AsRef<str>,
70 scope: I,
71 ) -> DeviceCodeCredential {
72 DeviceCodeCredential {
73 app_config: AppConfig::builder(client_id.as_ref()).scope(scope).build(),
74 refresh_token: None,
75 device_code: Some(device_code.as_ref().to_owned()),
76 token_cache: Default::default(),
77 }
78 }
79
80 pub fn with_refresh_token<T: AsRef<str>>(&mut self, refresh_token: T) -> &mut Self {
81 self.refresh_token = Some(refresh_token.as_ref().to_owned());
82 self
83 }
84
85 pub fn with_device_code<T: AsRef<str>>(&mut self, device_code: T) -> &mut Self {
86 self.device_code = Some(device_code.as_ref().to_owned());
87 self
88 }
89
90 pub fn builder(client_id: impl AsRef<str>) -> DeviceCodeCredentialBuilder {
91 DeviceCodeCredentialBuilder::new(client_id.as_ref())
92 }
93
94 fn execute_cached_token_refresh(&mut self, cache_id: String) -> AuthExecutionResult<Token> {
95 let response = self.execute()?;
96
97 if !response.status().is_success() {
98 return Err(AuthExecutionError::silent_token_auth(
99 response.into_http_response()?,
100 ));
101 }
102
103 let new_token: Token = response.json()?;
104 self.token_cache.store(cache_id, new_token.clone());
105
106 if new_token.refresh_token.is_some() {
107 self.refresh_token = new_token.refresh_token.clone();
108 }
109
110 Ok(new_token)
111 }
112
113 async fn execute_cached_token_refresh_async(
114 &mut self,
115 cache_id: String,
116 ) -> AuthExecutionResult<Token> {
117 let response = self.execute_async().await?;
118
119 if !response.status().is_success() {
120 return Err(AuthExecutionError::silent_token_auth(
121 response.into_http_response_async().await?,
122 ));
123 }
124
125 let new_token: Token = response.json().await?;
126
127 if new_token.refresh_token.is_some() {
128 self.refresh_token = new_token.refresh_token.clone();
129 }
130
131 self.token_cache.store(cache_id, new_token.clone());
132 Ok(new_token)
133 }
134}
135
136impl Debug for DeviceCodeCredential {
137 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct("DeviceCodeCredential")
139 .field("app_config", &self.app_config)
140 .finish()
141 }
142}
143
144#[async_trait]
145impl TokenCache for DeviceCodeCredential {
146 type Token = Token;
147
148 fn get_token_silent(&mut self) -> Result<Self::Token, AuthExecutionError> {
149 let cache_id = self.app_config.cache_id.to_string();
150
151 match self.app_config.force_token_refresh {
152 ForceTokenRefresh::Never => {
153 if self.refresh_token.is_some() {
156 if let Ok(token) = self.execute_cached_token_refresh(cache_id.clone()) {
157 return Ok(token);
158 }
159 }
160
161 if let Some(token) = self.token_cache.get(cache_id.as_str()) {
162 if token.is_expired_sub(time::Duration::minutes(5)) {
163 if let Some(refresh_token) = token.refresh_token.as_ref() {
164 self.refresh_token = Some(refresh_token.to_owned());
165 }
166
167 self.execute_cached_token_refresh(cache_id)
168 } else {
169 Ok(token)
170 }
171 } else {
172 self.execute_cached_token_refresh(cache_id)
173 }
174 }
175 ForceTokenRefresh::Once | ForceTokenRefresh::Always => {
176 let token_result = self.execute_cached_token_refresh(cache_id);
177 if self.app_config.force_token_refresh == ForceTokenRefresh::Once {
178 self.with_force_token_refresh(ForceTokenRefresh::Never);
179 }
180 token_result
181 }
182 }
183 }
184
185 async fn get_token_silent_async(&mut self) -> Result<Self::Token, AuthExecutionError> {
186 let cache_id = self.app_config.cache_id.to_string();
187
188 match self.app_config.force_token_refresh {
189 ForceTokenRefresh::Never => {
190 if self.refresh_token.is_some() {
193 if let Ok(token) = self
194 .execute_cached_token_refresh_async(cache_id.clone())
195 .await
196 {
197 return Ok(token);
198 }
199 }
200
201 if let Some(old_token) = self.token_cache.get(cache_id.as_str()) {
202 if old_token.is_expired_sub(time::Duration::minutes(5)) {
203 if let Some(refresh_token) = old_token.refresh_token.as_ref() {
204 self.refresh_token = Some(refresh_token.to_owned());
205 }
206
207 self.execute_cached_token_refresh_async(cache_id).await
208 } else {
209 Ok(old_token.clone())
210 }
211 } else {
212 self.execute_cached_token_refresh_async(cache_id).await
213 }
214 }
215 ForceTokenRefresh::Once | ForceTokenRefresh::Always => {
216 let token_result = self.execute_cached_token_refresh_async(cache_id).await;
217 if self.app_config.force_token_refresh == ForceTokenRefresh::Once {
218 self.with_force_token_refresh(ForceTokenRefresh::Never);
219 }
220 token_result
221 }
222 }
223 }
224
225 fn with_force_token_refresh(&mut self, force_token_refresh: ForceTokenRefresh) {
226 self.app_config.force_token_refresh = force_token_refresh;
227 }
228}
229
230impl TokenCredentialExecutor for DeviceCodeCredential {
231 fn uri(&mut self) -> IdentityResult<Url> {
232 if self.device_code.is_none() && self.refresh_token.is_none() {
233 Ok(self
234 .azure_cloud_instance()
235 .device_code_uri(&self.authority())?)
236 } else {
237 Ok(self.azure_cloud_instance().token_uri(&self.authority())?)
238 }
239 }
240
241 fn form_urlencode(&mut self) -> IdentityResult<HashMap<String, String>> {
242 let mut serializer = AuthSerializer::new();
243 let client_id = self.app_config.client_id.to_string();
244 if client_id.is_empty() || self.app_config.client_id.is_nil() {
245 return AuthorizationFailure::result(AuthParameter::ClientId.alias());
246 }
247
248 serializer
249 .client_id(client_id.as_str())
250 .set_scope(self.app_config.scope.clone());
251
252 if let Some(refresh_token) = self.refresh_token.as_ref() {
253 if refresh_token.trim().is_empty() {
254 return AuthorizationFailure::msg_result(
255 AuthParameter::RefreshToken.alias(),
256 "Found empty string for refresh token",
257 );
258 }
259
260 serializer
261 .grant_type("refresh_token")
262 .device_code(refresh_token.as_ref());
263
264 return serializer.as_credential_map(
265 vec![],
266 vec![
267 AuthParameter::ClientId,
268 AuthParameter::RefreshToken,
269 AuthParameter::Scope,
270 AuthParameter::GrantType,
271 ],
272 );
273 } else if let Some(device_code) = self.device_code.as_ref() {
274 if device_code.trim().is_empty() {
275 return AuthorizationFailure::msg_result(
276 AuthParameter::DeviceCode.alias(),
277 "Found empty string for device code",
278 );
279 }
280
281 serializer
282 .grant_type(DEVICE_CODE_GRANT_TYPE)
283 .device_code(device_code.as_ref());
284
285 return serializer.as_credential_map(
286 vec![],
287 vec![
288 AuthParameter::ClientId,
289 AuthParameter::DeviceCode,
290 AuthParameter::Scope,
291 AuthParameter::GrantType,
292 ],
293 );
294 }
295
296 serializer.as_credential_map(vec![], vec![AuthParameter::ClientId, AuthParameter::Scope])
297 }
298
299 fn client_id(&self) -> &Uuid {
300 &self.app_config.client_id
301 }
302
303 fn authority(&self) -> Authority {
304 self.app_config.authority.clone()
305 }
306
307 fn azure_cloud_instance(&self) -> AzureCloudInstance {
308 self.app_config.azure_cloud_instance
309 }
310
311 fn app_config(&self) -> &AppConfig {
312 &self.app_config
313 }
314}
315
/// Builder for [`DeviceCodeCredential`]; wraps the credential being configured.
#[derive(Clone)]
pub struct DeviceCodeCredentialBuilder {
    credential: DeviceCodeCredential,
}
320
321impl DeviceCodeCredentialBuilder {
322 fn new(client_id: impl AsRef<str>) -> DeviceCodeCredentialBuilder {
323 DeviceCodeCredentialBuilder {
324 credential: DeviceCodeCredential {
325 app_config: AppConfig::new(client_id.as_ref()),
326 refresh_token: None,
327 device_code: None,
328 token_cache: Default::default(),
329 },
330 }
331 }
332
333 pub(crate) fn new_with_device_code(
334 device_code: impl AsRef<str>,
335 app_config: AppConfig,
336 ) -> DeviceCodeCredentialBuilder {
337 DeviceCodeCredentialBuilder {
338 credential: DeviceCodeCredential {
339 app_config,
340 refresh_token: None,
341 device_code: Some(device_code.as_ref().to_owned()),
342 token_cache: Default::default(),
343 },
344 }
345 }
346
347 pub fn with_device_code<T: AsRef<str>>(&mut self, device_code: T) -> &mut Self {
348 self.credential.device_code = Some(device_code.as_ref().to_owned());
349 self.credential.refresh_token = None;
350 self
351 }
352
353 pub fn with_refresh_token<T: AsRef<str>>(&mut self, refresh_token: T) -> &mut Self {
354 self.credential.device_code = None;
355 self.credential.refresh_token = Some(refresh_token.as_ref().to_owned());
356 self
357 }
358}
359
/// Runs the device code flow: requests a device code, then polls the token
/// endpoint on a background thread (or tokio task), forwarding every response
/// over a channel.
#[derive(Debug)]
pub struct DeviceCodePollingExecutor {
    credential: DeviceCodeCredential,
}
364
365impl DeviceCodePollingExecutor {
366 pub(crate) fn new_with_app_config(app_config: AppConfig) -> DeviceCodePollingExecutor {
367 DeviceCodePollingExecutor {
368 credential: DeviceCodeCredential {
369 app_config,
370 refresh_token: None,
371 device_code: None,
372 token_cache: Default::default(),
373 },
374 }
375 }
376
377 pub fn with_scope<T: ToString, I: IntoIterator<Item = T>>(mut self, scope: I) -> Self {
378 self.credential.app_config.scope = scope.into_iter().map(|s| s.to_string()).collect();
379 self
380 }
381
382 pub fn with_tenant(mut self, tenant_id: impl AsRef<str>) -> Self {
383 self.credential.app_config.tenant_id = Some(tenant_id.as_ref().to_owned());
384 self
385 }
386
387 pub fn poll(&mut self) -> AuthExecutionResult<std::sync::mpsc::Receiver<JsonHttpResponse>> {
388 let (sender, receiver) = std::sync::mpsc::channel();
389
390 let mut credential = self.credential.clone();
391 let response = credential.execute()?;
392
393 let http_response = response.into_http_response()?;
394 let json = http_response.json().unwrap();
395 let device_code_response: DeviceAuthorizationResponse = serde_json::from_value(json)?;
396
397 sender.send(http_response).unwrap();
398
399 let device_code = device_code_response.device_code;
400 let mut interval = Duration::from_secs(device_code_response.interval);
401 credential.with_device_code(device_code);
402
403 let _ = std::thread::spawn(move || {
404 loop {
405 std::thread::sleep(interval);
407
408 let response = credential.execute().unwrap();
409 let http_response = response.into_http_response()?;
410 let status = http_response.status();
411
412 if status.is_success() {
413 sender.send(http_response)?;
414 break;
415 } else {
416 let json = http_response.json().unwrap();
417 let option_error = json["error"].as_str().map(|value| value.to_owned());
418 sender.send(http_response)?;
419
420 if let Some(error) = option_error {
421 match PollDeviceCodeEvent::from_str(error.as_str()) {
422 Ok(poll_device_code_type) => match poll_device_code_type {
423 PollDeviceCodeEvent::AuthorizationPending
424 | PollDeviceCodeEvent::BadVerificationCode => continue,
425 PollDeviceCodeEvent::AuthorizationDeclined
426 | PollDeviceCodeEvent::ExpiredToken
427 | PollDeviceCodeEvent::AccessDenied => break,
428 PollDeviceCodeEvent::SlowDown => {
429 interval = interval.add(Duration::from_secs(5));
430 continue;
431 }
432 },
433 Err(_) => {
434 error!(
435 target = "device_code_polling_executor",
436 "invalid PollDeviceCodeEvent"
437 );
438 break;
439 }
440 }
441 } else {
442 break;
444 }
445 }
446 }
447 Ok::<(), anyhow::Error>(())
448 });
449
450 Ok(receiver)
451 }
452
453 pub async fn poll_async(
454 &mut self,
455 buffer: Option<usize>,
456 ) -> AuthTaskExecutionResult<tokio::sync::mpsc::Receiver<JsonHttpResponse>, JsonHttpResponse>
457 {
458 let (sender, receiver) = {
459 if let Some(buffer) = buffer {
460 tokio::sync::mpsc::channel(buffer)
461 } else {
462 tokio::sync::mpsc::channel(100)
463 }
464 };
465
466 let mut credential = self.credential.clone();
467 let response = credential.execute_async().await?;
468
469 let http_response = response.into_http_response_async().await?;
470 let json = http_response.json().unwrap();
471 let device_code_response: DeviceAuthorizationResponse =
472 serde_json::from_value(json).map_err(AuthExecutionError::from)?;
473
474 sender
475 .send_timeout(http_response, Duration::from_secs(60))
476 .await?;
477
478 let device_code = device_code_response.device_code;
479 let mut interval = Duration::from_secs(device_code_response.interval);
480 credential.with_device_code(device_code);
481
482 tokio::spawn(async move {
483 loop {
484 tokio::time::sleep(interval).await;
486
487 let response = credential.execute_async().await?;
488 let http_response = response.into_http_response_async().await?;
489 let status = http_response.status();
490
491 if status.is_success() {
492 sender
493 .send_timeout(http_response, Duration::from_secs(60))
494 .await?;
495 break;
496 } else {
497 let json = http_response.json().unwrap();
498 let option_error = json["error"].as_str().map(|value| value.to_owned());
499 sender
500 .send_timeout(http_response, Duration::from_secs(60))
501 .await?;
502
503 if let Some(error) = option_error {
504 match PollDeviceCodeEvent::from_str(error.as_str()) {
505 Ok(poll_device_code_type) => match poll_device_code_type {
506 PollDeviceCodeEvent::AuthorizationPending => continue,
507 PollDeviceCodeEvent::AuthorizationDeclined => break,
508 PollDeviceCodeEvent::BadVerificationCode => continue,
509 PollDeviceCodeEvent::ExpiredToken => break,
510 PollDeviceCodeEvent::AccessDenied => break,
511 PollDeviceCodeEvent::SlowDown => {
512 interval = interval.add(Duration::from_secs(5));
516 continue;
517 }
518 },
519 Err(_) => break,
520 }
521 } else {
522 break;
524 }
525 }
526 }
527 Ok::<(), anyhow::Error>(())
528 });
529
530 Ok(receiver)
531 }
532
533 #[cfg(feature = "interactive-auth")]
534 pub fn with_interactive_auth(
535 &mut self,
536 options: WebViewOptions,
537 ) -> AuthExecutionResult<(DeviceAuthorizationResponse, DeviceCodeInteractiveAuth)> {
538 let response = self.credential.execute()?;
539 let device_authorization_response: DeviceAuthorizationResponse = response.json()?;
540 self.credential
541 .with_device_code(device_authorization_response.device_code.clone());
542
543 Ok((
544 device_authorization_response.clone(),
545 DeviceCodeInteractiveAuth {
546 credential: self.credential.clone(),
547 interval: Duration::from_secs(device_authorization_response.interval),
548 verification_uri: device_authorization_response.verification_uri.clone(),
549 verification_uri_complete: device_authorization_response.verification_uri_complete,
550 options,
551 },
552 ))
553 }
554}
555
#[cfg(feature = "interactive-auth")]
pub(crate) mod internal {
    use super::*;

    impl WebViewAuth for DeviceCodeCredential {
        /// Builds the webview that opens `host_options.start_uri`.
        ///
        /// File drops and all navigations are allowed; each navigated URL is logged
        /// at debug level under the interactive-auth target.
        fn webview(
            host_options: HostOptions,
            window: &Window,
            _proxy: EventLoopProxy<UserEvents>,
        ) -> anyhow::Result<WebView> {
            let webview = WebViewBuilder::new(window)
                .with_url(host_options.start_uri.as_ref())
                .with_file_drop_handler(|_| true)
                .with_navigation_handler(move |uri| {
                    tracing::debug!(target: INTERACTIVE_AUTH, url = uri.as_str());
                    true
                })
                .build()?;
            Ok(webview)
        }
    }
}
578
/// State for the interactive (webview) device code flow: the credential holding
/// the device code, the polling interval, and the verification URI(s) to display.
#[cfg(feature = "interactive-auth")]
#[derive(Debug)]
pub struct DeviceCodeInteractiveAuth {
    credential: DeviceCodeCredential,
    /// Initial delay between token endpoint polls.
    interval: Duration,
    verification_uri: String,
    /// Opened instead of `verification_uri` when present.
    verification_uri_complete: Option<String>,
    options: WebViewOptions,
}
588
#[allow(dead_code)]
#[cfg(feature = "interactive-auth")]
impl DeviceCodeInteractiveAuth {
    /// Creates the flow state from a credential that already holds the device code
    /// and the device authorization response that produced it.
    pub(crate) fn new(
        credential: DeviceCodeCredential,
        device_authorization_response: DeviceAuthorizationResponse,
        options: WebViewOptions,
    ) -> DeviceCodeInteractiveAuth {
        DeviceCodeInteractiveAuth {
            credential,
            interval: Duration::from_secs(device_authorization_response.interval),
            verification_uri: device_authorization_response.verification_uri,
            verification_uri_complete: device_authorization_response.verification_uri_complete,
            options,
        }
    }

    /// Opens the verification page in a webview on a separate thread, then blocks
    /// while polling the token endpoint (see `poll_internal`).
    ///
    /// # Errors
    /// Fails if the verification URI cannot be parsed, or as described for `poll_internal`.
    pub fn poll(
        &mut self,
    ) -> Result<PublicClientApplication<DeviceCodeCredential>, WebViewDeviceCodeError> {
        // Prefer the complete URI (which presumably embeds the user code) when available.
        let verification_uri = self
            .verification_uri_complete
            .as_deref()
            .unwrap_or(self.verification_uri.as_str());
        let url = Url::parse(verification_uri).map_err(AuthorizationFailure::from)?;

        // `_receiver` is a named binding, so the receiving end stays alive until this
        // function returns.
        let (sender, _receiver) = std::sync::mpsc::channel();

        let options = self.options.clone();
        std::thread::spawn(move || {
            DeviceCodeCredential::run(url, vec![], options, sender).unwrap();
        });

        let credential = self.credential.clone();
        let interval = self.interval;
        DeviceCodeInteractiveAuth::poll_internal(interval, credential)
    }

    /// Polls the token endpoint until sign-in completes or fails.
    ///
    /// Sleeps `interval` before each request; `SlowDown` adds five seconds. On success
    /// the token is cached under the credential's `cache_id` and the credential is
    /// wrapped in a [`PublicClientApplication`].
    ///
    /// # Errors
    /// Request failures are returned as errors rather than panicking. A
    /// `DeviceCodePollingError` carrying the response is returned for:
    /// - terminal polling errors;
    /// - unrecognized or missing `error` codes (including non-JSON error bodies);
    /// - successful responses without a JSON body.
    pub(crate) fn poll_internal(
        mut interval: Duration,
        mut credential: DeviceCodeCredential,
    ) -> Result<PublicClientApplication<DeviceCodeCredential>, WebViewDeviceCodeError> {
        loop {
            std::thread::sleep(interval);

            let response = credential.execute().map_err(Box::new)?;
            let http_response = response.into_http_response().map_err(Box::new)?;

            if http_response.status().is_success() {
                return match http_response.json() {
                    Some(json) => {
                        let token: Token = serde_json::from_value(json)
                            .map_err(|err| Box::new(AuthExecutionError::from(err)))?;
                        let cache_id = credential.app_config.cache_id.clone();
                        credential.token_cache.store(cache_id, token);
                        Ok(PublicClientApplication::from(credential))
                    }
                    None => Err(WebViewDeviceCodeError::DeviceCodePollingError(
                        http_response,
                    )),
                };
            }

            // A non-JSON body, a missing `error` field, or an unknown code all yield `None`.
            let event = http_response
                .json()
                .and_then(|json| json["error"].as_str().map(str::to_owned))
                .and_then(|error| PollDeviceCodeEvent::from_str(error.as_str()).ok());

            match event {
                Some(PollDeviceCodeEvent::AuthorizationPending)
                | Some(PollDeviceCodeEvent::BadVerificationCode) => {}
                Some(PollDeviceCodeEvent::SlowDown) => {
                    interval = interval.add(Duration::from_secs(5));
                }
                Some(PollDeviceCodeEvent::AuthorizationDeclined)
                | Some(PollDeviceCodeEvent::ExpiredToken)
                | Some(PollDeviceCodeEvent::AccessDenied)
                | None => {
                    return Err(WebViewDeviceCodeError::DeviceCodePollingError(
                        http_response,
                    ));
                }
            }
        }
    }
}
690
#[cfg(test)]
mod test {
    use super::*;

    // No scope is configured and "CLIENT_ID" is not a UUID, so `form_urlencode` is
    // expected to return an error and the `unwrap` to panic.
    #[test]
    #[should_panic]
    fn no_scope() {
        DeviceCodeCredential::builder("CLIENT_ID")
            .build()
            .form_urlencode()
            .unwrap();
    }
}