graph_oauth/identity/credentials/device_code_credential.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::fmt::{Debug, Formatter};
4use std::ops::Add;
5use std::str::FromStr;
6use std::time::Duration;
7
8use graph_core::cache::{CacheStore, InMemoryCacheStore, TokenCache};
9use graph_core::identity::ForceTokenRefresh;
10use http::{HeaderMap, HeaderName, HeaderValue};
11use tracing::error;
12use url::Url;
13use uuid::Uuid;
14
15use crate::identity::{
16    AppConfig, Authority, AzureCloudInstance, DeviceAuthorizationResponse, PollDeviceCodeEvent,
17    PublicClientApplication, Token, TokenCredentialExecutor,
18};
19use crate::oauth_serializer::{AuthParameter, AuthSerializer};
20use graph_core::http::{
21    AsyncResponseConverterExt, HttpResponseExt, JsonHttpResponse, ResponseConverterExt,
22};
23use graph_error::{
24    AuthExecutionError, AuthExecutionResult, AuthTaskExecutionResult, AuthorizationFailure,
25    IdentityResult,
26};
27
28#[cfg(feature = "interactive-auth")]
29use {
30    crate::interactive::{HostOptions, UserEvents, WebViewAuth, WebViewOptions},
31    crate::tracing_targets::INTERACTIVE_AUTH,
32    graph_error::WebViewDeviceCodeError,
33    tao::{event_loop::EventLoopProxy, window::Window},
34    wry::{WebView, WebViewBuilder},
35};
36
/// Value sent as `grant_type` when redeeming a device code at the token endpoint
/// (RFC 8628, section 3.4).
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";
38
39credential_builder!(
40    DeviceCodeCredentialBuilder,
41    PublicClientApplication<DeviceCodeCredential>
42);
43
44/// The device authorization grant: allows users to sign in to input-constrained devices
45/// such as a smart TV, IoT device, or a printer. To enable this flow, the device has the
46/// user visit a webpage in a browser on another device to sign in. Once the user signs in,
47/// the device is able to get access tokens and refresh tokens as needed.
48///
49/// For more info on the protocol supported by the Microsoft Identity Platform see the
50/// [Microsoft identity platform and the OAuth 2.0 device authorization grant flow](https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code)
51#[derive(Clone)]
52pub struct DeviceCodeCredential {
53    pub(crate) app_config: AppConfig,
54    /// Required when requesting a new access token using a refresh token
55    /// The refresh token needed to make an access token request using a refresh token.
56    /// Do not include an authorization code when using a refresh token.
57    pub(crate) refresh_token: Option<String>,
58    /// Required.
59    /// The device_code returned in the device authorization request.
60    /// A device_code is a long string used to verify the session between the client and the authorization server.
61    /// The client uses this parameter to request the access token from the authorization server.
62    pub(crate) device_code: Option<String>,
63    token_cache: InMemoryCacheStore<Token>,
64}
65
66impl DeviceCodeCredential {
67    pub fn new<U: ToString, I: IntoIterator<Item = U>>(
68        client_id: impl AsRef<str>,
69        device_code: impl AsRef<str>,
70        scope: I,
71    ) -> DeviceCodeCredential {
72        DeviceCodeCredential {
73            app_config: AppConfig::builder(client_id.as_ref()).scope(scope).build(),
74            refresh_token: None,
75            device_code: Some(device_code.as_ref().to_owned()),
76            token_cache: Default::default(),
77        }
78    }
79
80    pub fn with_refresh_token<T: AsRef<str>>(&mut self, refresh_token: T) -> &mut Self {
81        self.refresh_token = Some(refresh_token.as_ref().to_owned());
82        self
83    }
84
85    pub fn with_device_code<T: AsRef<str>>(&mut self, device_code: T) -> &mut Self {
86        self.device_code = Some(device_code.as_ref().to_owned());
87        self
88    }
89
90    pub fn builder(client_id: impl AsRef<str>) -> DeviceCodeCredentialBuilder {
91        DeviceCodeCredentialBuilder::new(client_id.as_ref())
92    }
93
94    fn execute_cached_token_refresh(&mut self, cache_id: String) -> AuthExecutionResult<Token> {
95        let response = self.execute()?;
96
97        if !response.status().is_success() {
98            return Err(AuthExecutionError::silent_token_auth(
99                response.into_http_response()?,
100            ));
101        }
102
103        let new_token: Token = response.json()?;
104        self.token_cache.store(cache_id, new_token.clone());
105
106        if new_token.refresh_token.is_some() {
107            self.refresh_token = new_token.refresh_token.clone();
108        }
109
110        Ok(new_token)
111    }
112
113    async fn execute_cached_token_refresh_async(
114        &mut self,
115        cache_id: String,
116    ) -> AuthExecutionResult<Token> {
117        let response = self.execute_async().await?;
118
119        if !response.status().is_success() {
120            return Err(AuthExecutionError::silent_token_auth(
121                response.into_http_response_async().await?,
122            ));
123        }
124
125        let new_token: Token = response.json().await?;
126
127        if new_token.refresh_token.is_some() {
128            self.refresh_token = new_token.refresh_token.clone();
129        }
130
131        self.token_cache.store(cache_id, new_token.clone());
132        Ok(new_token)
133    }
134}
135
136impl Debug for DeviceCodeCredential {
137    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("DeviceCodeCredential")
139            .field("app_config", &self.app_config)
140            .finish()
141    }
142}
143
144#[async_trait]
145impl TokenCache for DeviceCodeCredential {
146    type Token = Token;
147
148    fn get_token_silent(&mut self) -> Result<Self::Token, AuthExecutionError> {
149        let cache_id = self.app_config.cache_id.to_string();
150
151        match self.app_config.force_token_refresh {
152            ForceTokenRefresh::Never => {
153                // Attempt to bypass a read on the token store by using previous
154                // refresh token stored outside of RwLock
155                if self.refresh_token.is_some() {
156                    if let Ok(token) = self.execute_cached_token_refresh(cache_id.clone()) {
157                        return Ok(token);
158                    }
159                }
160
161                if let Some(token) = self.token_cache.get(cache_id.as_str()) {
162                    if token.is_expired_sub(time::Duration::minutes(5)) {
163                        if let Some(refresh_token) = token.refresh_token.as_ref() {
164                            self.refresh_token = Some(refresh_token.to_owned());
165                        }
166
167                        self.execute_cached_token_refresh(cache_id)
168                    } else {
169                        Ok(token)
170                    }
171                } else {
172                    self.execute_cached_token_refresh(cache_id)
173                }
174            }
175            ForceTokenRefresh::Once | ForceTokenRefresh::Always => {
176                let token_result = self.execute_cached_token_refresh(cache_id);
177                if self.app_config.force_token_refresh == ForceTokenRefresh::Once {
178                    self.with_force_token_refresh(ForceTokenRefresh::Never);
179                }
180                token_result
181            }
182        }
183    }
184
185    async fn get_token_silent_async(&mut self) -> Result<Self::Token, AuthExecutionError> {
186        let cache_id = self.app_config.cache_id.to_string();
187
188        match self.app_config.force_token_refresh {
189            ForceTokenRefresh::Never => {
190                // Attempt to bypass a read on the token store by using previous
191                // refresh token stored outside of RwLock
192                if self.refresh_token.is_some() {
193                    if let Ok(token) = self
194                        .execute_cached_token_refresh_async(cache_id.clone())
195                        .await
196                    {
197                        return Ok(token);
198                    }
199                }
200
201                if let Some(old_token) = self.token_cache.get(cache_id.as_str()) {
202                    if old_token.is_expired_sub(time::Duration::minutes(5)) {
203                        if let Some(refresh_token) = old_token.refresh_token.as_ref() {
204                            self.refresh_token = Some(refresh_token.to_owned());
205                        }
206
207                        self.execute_cached_token_refresh_async(cache_id).await
208                    } else {
209                        Ok(old_token.clone())
210                    }
211                } else {
212                    self.execute_cached_token_refresh_async(cache_id).await
213                }
214            }
215            ForceTokenRefresh::Once | ForceTokenRefresh::Always => {
216                let token_result = self.execute_cached_token_refresh_async(cache_id).await;
217                if self.app_config.force_token_refresh == ForceTokenRefresh::Once {
218                    self.with_force_token_refresh(ForceTokenRefresh::Never);
219                }
220                token_result
221            }
222        }
223    }
224
225    fn with_force_token_refresh(&mut self, force_token_refresh: ForceTokenRefresh) {
226        self.app_config.force_token_refresh = force_token_refresh;
227    }
228}
229
230impl TokenCredentialExecutor for DeviceCodeCredential {
231    fn uri(&mut self) -> IdentityResult<Url> {
232        if self.device_code.is_none() && self.refresh_token.is_none() {
233            Ok(self
234                .azure_cloud_instance()
235                .device_code_uri(&self.authority())?)
236        } else {
237            Ok(self.azure_cloud_instance().token_uri(&self.authority())?)
238        }
239    }
240
241    fn form_urlencode(&mut self) -> IdentityResult<HashMap<String, String>> {
242        let mut serializer = AuthSerializer::new();
243        let client_id = self.app_config.client_id.to_string();
244        if client_id.is_empty() || self.app_config.client_id.is_nil() {
245            return AuthorizationFailure::result(AuthParameter::ClientId.alias());
246        }
247
248        serializer
249            .client_id(client_id.as_str())
250            .set_scope(self.app_config.scope.clone());
251
252        if let Some(refresh_token) = self.refresh_token.as_ref() {
253            if refresh_token.trim().is_empty() {
254                return AuthorizationFailure::msg_result(
255                    AuthParameter::RefreshToken.alias(),
256                    "Found empty string for refresh token",
257                );
258            }
259
260            serializer
261                .grant_type("refresh_token")
262                .device_code(refresh_token.as_ref());
263
264            return serializer.as_credential_map(
265                vec![],
266                vec![
267                    AuthParameter::ClientId,
268                    AuthParameter::RefreshToken,
269                    AuthParameter::Scope,
270                    AuthParameter::GrantType,
271                ],
272            );
273        } else if let Some(device_code) = self.device_code.as_ref() {
274            if device_code.trim().is_empty() {
275                return AuthorizationFailure::msg_result(
276                    AuthParameter::DeviceCode.alias(),
277                    "Found empty string for device code",
278                );
279            }
280
281            serializer
282                .grant_type(DEVICE_CODE_GRANT_TYPE)
283                .device_code(device_code.as_ref());
284
285            return serializer.as_credential_map(
286                vec![],
287                vec![
288                    AuthParameter::ClientId,
289                    AuthParameter::DeviceCode,
290                    AuthParameter::Scope,
291                    AuthParameter::GrantType,
292                ],
293            );
294        }
295
296        serializer.as_credential_map(vec![], vec![AuthParameter::ClientId, AuthParameter::Scope])
297    }
298
299    fn client_id(&self) -> &Uuid {
300        &self.app_config.client_id
301    }
302
303    fn authority(&self) -> Authority {
304        self.app_config.authority.clone()
305    }
306
307    fn azure_cloud_instance(&self) -> AzureCloudInstance {
308        self.app_config.azure_cloud_instance
309    }
310
311    fn app_config(&self) -> &AppConfig {
312        &self.app_config
313    }
314}
315
316#[derive(Clone)]
317pub struct DeviceCodeCredentialBuilder {
318    credential: DeviceCodeCredential,
319}
320
321impl DeviceCodeCredentialBuilder {
322    fn new(client_id: impl AsRef<str>) -> DeviceCodeCredentialBuilder {
323        DeviceCodeCredentialBuilder {
324            credential: DeviceCodeCredential {
325                app_config: AppConfig::new(client_id.as_ref()),
326                refresh_token: None,
327                device_code: None,
328                token_cache: Default::default(),
329            },
330        }
331    }
332
333    pub(crate) fn new_with_device_code(
334        device_code: impl AsRef<str>,
335        app_config: AppConfig,
336    ) -> DeviceCodeCredentialBuilder {
337        DeviceCodeCredentialBuilder {
338            credential: DeviceCodeCredential {
339                app_config,
340                refresh_token: None,
341                device_code: Some(device_code.as_ref().to_owned()),
342                token_cache: Default::default(),
343            },
344        }
345    }
346
347    pub fn with_device_code<T: AsRef<str>>(&mut self, device_code: T) -> &mut Self {
348        self.credential.device_code = Some(device_code.as_ref().to_owned());
349        self.credential.refresh_token = None;
350        self
351    }
352
353    pub fn with_refresh_token<T: AsRef<str>>(&mut self, refresh_token: T) -> &mut Self {
354        self.credential.device_code = None;
355        self.credential.refresh_token = Some(refresh_token.as_ref().to_owned());
356        self
357    }
358}
359
360#[derive(Debug)]
361pub struct DeviceCodePollingExecutor {
362    credential: DeviceCodeCredential,
363}
364
365impl DeviceCodePollingExecutor {
366    pub(crate) fn new_with_app_config(app_config: AppConfig) -> DeviceCodePollingExecutor {
367        DeviceCodePollingExecutor {
368            credential: DeviceCodeCredential {
369                app_config,
370                refresh_token: None,
371                device_code: None,
372                token_cache: Default::default(),
373            },
374        }
375    }
376
377    pub fn with_scope<T: ToString, I: IntoIterator<Item = T>>(mut self, scope: I) -> Self {
378        self.credential.app_config.scope = scope.into_iter().map(|s| s.to_string()).collect();
379        self
380    }
381
382    pub fn with_tenant(mut self, tenant_id: impl AsRef<str>) -> Self {
383        self.credential.app_config.tenant_id = Some(tenant_id.as_ref().to_owned());
384        self
385    }
386
387    pub fn poll(&mut self) -> AuthExecutionResult<std::sync::mpsc::Receiver<JsonHttpResponse>> {
388        let (sender, receiver) = std::sync::mpsc::channel();
389
390        let mut credential = self.credential.clone();
391        let response = credential.execute()?;
392
393        let http_response = response.into_http_response()?;
394        let json = http_response.json().unwrap();
395        let device_code_response: DeviceAuthorizationResponse = serde_json::from_value(json)?;
396
397        sender.send(http_response).unwrap();
398
399        let device_code = device_code_response.device_code;
400        let mut interval = Duration::from_secs(device_code_response.interval);
401        credential.with_device_code(device_code);
402
403        let _ = std::thread::spawn(move || {
404            loop {
405                // Wait the amount of seconds that interval is.
406                std::thread::sleep(interval);
407
408                let response = credential.execute().unwrap();
409                let http_response = response.into_http_response()?;
410                let status = http_response.status();
411
412                if status.is_success() {
413                    sender.send(http_response)?;
414                    break;
415                } else {
416                    let json = http_response.json().unwrap();
417                    let option_error = json["error"].as_str().map(|value| value.to_owned());
418                    sender.send(http_response)?;
419
420                    if let Some(error) = option_error {
421                        match PollDeviceCodeEvent::from_str(error.as_str()) {
422                            Ok(poll_device_code_type) => match poll_device_code_type {
423                                PollDeviceCodeEvent::AuthorizationPending
424                                | PollDeviceCodeEvent::BadVerificationCode => continue,
425                                PollDeviceCodeEvent::AuthorizationDeclined
426                                | PollDeviceCodeEvent::ExpiredToken
427                                | PollDeviceCodeEvent::AccessDenied => break,
428                                PollDeviceCodeEvent::SlowDown => {
429                                    interval = interval.add(Duration::from_secs(5));
430                                    continue;
431                                }
432                            },
433                            Err(_) => {
434                                error!(
435                                    target = "device_code_polling_executor",
436                                    "invalid PollDeviceCodeEvent"
437                                );
438                                break;
439                            }
440                        }
441                    } else {
442                        // Body should have error or we should bail.
443                        break;
444                    }
445                }
446            }
447            Ok::<(), anyhow::Error>(())
448        });
449
450        Ok(receiver)
451    }
452
453    pub async fn poll_async(
454        &mut self,
455        buffer: Option<usize>,
456    ) -> AuthTaskExecutionResult<tokio::sync::mpsc::Receiver<JsonHttpResponse>, JsonHttpResponse>
457    {
458        let (sender, receiver) = {
459            if let Some(buffer) = buffer {
460                tokio::sync::mpsc::channel(buffer)
461            } else {
462                tokio::sync::mpsc::channel(100)
463            }
464        };
465
466        let mut credential = self.credential.clone();
467        let response = credential.execute_async().await?;
468
469        let http_response = response.into_http_response_async().await?;
470        let json = http_response.json().unwrap();
471        let device_code_response: DeviceAuthorizationResponse =
472            serde_json::from_value(json).map_err(AuthExecutionError::from)?;
473
474        sender
475            .send_timeout(http_response, Duration::from_secs(60))
476            .await?;
477
478        let device_code = device_code_response.device_code;
479        let mut interval = Duration::from_secs(device_code_response.interval);
480        credential.with_device_code(device_code);
481
482        tokio::spawn(async move {
483            loop {
484                // Wait the amount of seconds that interval is.
485                tokio::time::sleep(interval).await;
486
487                let response = credential.execute_async().await?;
488                let http_response = response.into_http_response_async().await?;
489                let status = http_response.status();
490
491                if status.is_success() {
492                    sender
493                        .send_timeout(http_response, Duration::from_secs(60))
494                        .await?;
495                    break;
496                } else {
497                    let json = http_response.json().unwrap();
498                    let option_error = json["error"].as_str().map(|value| value.to_owned());
499                    sender
500                        .send_timeout(http_response, Duration::from_secs(60))
501                        .await?;
502
503                    if let Some(error) = option_error {
504                        match PollDeviceCodeEvent::from_str(error.as_str()) {
505                            Ok(poll_device_code_type) => match poll_device_code_type {
506                                PollDeviceCodeEvent::AuthorizationPending => continue,
507                                PollDeviceCodeEvent::AuthorizationDeclined => break,
508                                PollDeviceCodeEvent::BadVerificationCode => continue,
509                                PollDeviceCodeEvent::ExpiredToken => break,
510                                PollDeviceCodeEvent::AccessDenied => break,
511                                PollDeviceCodeEvent::SlowDown => {
512                                    // Should slow down is part of the openid connect spec and means that
513                                    // that we should wait longer between polling by the amount specified
514                                    // in the interval field of the device code.
515                                    interval = interval.add(Duration::from_secs(5));
516                                    continue;
517                                }
518                            },
519                            Err(_) => break,
520                        }
521                    } else {
522                        // Body should have error or we should bail.
523                        break;
524                    }
525                }
526            }
527            Ok::<(), anyhow::Error>(())
528        });
529
530        Ok(receiver)
531    }
532
533    #[cfg(feature = "interactive-auth")]
534    pub fn with_interactive_auth(
535        &mut self,
536        options: WebViewOptions,
537    ) -> AuthExecutionResult<(DeviceAuthorizationResponse, DeviceCodeInteractiveAuth)> {
538        let response = self.credential.execute()?;
539        let device_authorization_response: DeviceAuthorizationResponse = response.json()?;
540        self.credential
541            .with_device_code(device_authorization_response.device_code.clone());
542
543        Ok((
544            device_authorization_response.clone(),
545            DeviceCodeInteractiveAuth {
546                credential: self.credential.clone(),
547                interval: Duration::from_secs(device_authorization_response.interval),
548                verification_uri: device_authorization_response.verification_uri.clone(),
549                verification_uri_complete: device_authorization_response.verification_uri_complete,
550                options,
551            },
552        ))
553    }
554}
555
#[cfg(feature = "interactive-auth")]
pub(crate) mod internal {
    use super::*;

    impl WebViewAuth for DeviceCodeCredential {
        /// Builds the sign-in web view. It opens `host_options.start_uri`, blocks file drops,
        /// and logs every navigation while allowing it.
        fn webview(
            host_options: HostOptions,
            window: &Window,
            _proxy: EventLoopProxy<UserEvents>,
        ) -> anyhow::Result<WebView> {
            let webview = WebViewBuilder::new(window)
                .with_url(host_options.start_uri.as_ref())
                // Disables file drop
                .with_file_drop_handler(|_| true)
                .with_navigation_handler(move |navigation_url| {
                    tracing::debug!(target: INTERACTIVE_AUTH, url = navigation_url.as_str());
                    true
                })
                .build()?;
            Ok(webview)
        }
    }
}
578
/// Interactive variant of the device code flow. It opens the verification page in a web view
/// and polls the token endpoint until the user has signed in.
#[cfg(feature = "interactive-auth")]
#[derive(Debug)]
pub struct DeviceCodeInteractiveAuth {
    /// Credential that already holds the device code to redeem.
    credential: DeviceCodeCredential,
    /// Delay before each polling request.
    interval: Duration,
    /// Verification page opened when `verification_uri_complete` is absent.
    verification_uri: String,
    /// Preferred verification page, opened instead of `verification_uri` when present.
    verification_uri_complete: Option<String>,
    /// Options passed to the web view.
    options: WebViewOptions,
}
588
// Presumably silences a dead-code warning for `new`, which has no caller in this file.
// NOTE(review): confirm whether other modules use it before removing this allow.
#[allow(dead_code)]
#[cfg(feature = "interactive-auth")]
impl DeviceCodeInteractiveAuth {
    /// Wraps `credential` (which should already hold the device code) with the polling
    /// interval and verification URLs from `device_authorization_response`.
    pub(crate) fn new(
        credential: DeviceCodeCredential,
        device_authorization_response: DeviceAuthorizationResponse,
        options: WebViewOptions,
    ) -> DeviceCodeInteractiveAuth {
        DeviceCodeInteractiveAuth {
            credential,
            interval: Duration::from_secs(device_authorization_response.interval),
            verification_uri: device_authorization_response.verification_uri.clone(),
            verification_uri_complete: device_authorization_response.verification_uri_complete,
            options,
        }
    }

    /// Opens the verification page in a web view on a background thread, then blocks while
    /// polling the token endpoint.
    ///
    /// `verification_uri_complete` is preferred over `verification_uri` when present.
    ///
    /// # Errors
    /// Fails if the verification URL cannot be parsed, or with any error returned by
    /// `poll_internal`.
    pub fn poll(
        &mut self,
    ) -> Result<PublicClientApplication<DeviceCodeCredential>, WebViewDeviceCodeError> {
        let url = {
            if let Some(url_complete) = self.verification_uri_complete.as_ref() {
                Url::parse(url_complete).map_err(AuthorizationFailure::from)?
            } else {
                Url::parse(self.verification_uri.as_str()).map_err(AuthorizationFailure::from)?
            }
        };

        // `_receiver` (not `_`) keeps the channel open until this function returns.
        let (sender, _receiver) = std::sync::mpsc::channel();

        let options = self.options.clone();
        std::thread::spawn(move || {
            // Log a web view failure instead of panicking the helper thread. Polling below
            // continues either way.
            if let Err(err) = DeviceCodeCredential::run(url, vec![], options, sender) {
                tracing::error!(target: INTERACTIVE_AUTH, "{:?}", err);
            }
        });

        let credential = self.credential.clone();
        let interval = self.interval;
        DeviceCodeInteractiveAuth::poll_internal(interval, credential)
    }

    /// Polls the token endpoint every `interval` until sign-in completes or fails.
    ///
    /// On success, the token is stored in the credential's cache under
    /// `app_config.cache_id` and the credential is wrapped in a `PublicClientApplication`.
    /// `AuthorizationPending` and `BadVerificationCode` keep polling; `SlowDown` adds
    /// 5 seconds to the interval.
    ///
    /// # Errors
    /// - A failed request or an unreadable success body returns the underlying auth error.
    /// - A success response without a body returns `DeviceCodePollingError`.
    /// - `AuthorizationDeclined`, `ExpiredToken`, `AccessDenied`, an unrecognized error code,
    ///   or an error response without a readable `error` field also returns
    ///   `DeviceCodePollingError`, carrying the HTTP response.
    pub(crate) fn poll_internal(
        mut interval: Duration,
        mut credential: DeviceCodeCredential,
    ) -> Result<PublicClientApplication<DeviceCodeCredential>, WebViewDeviceCodeError> {
        loop {
            // Wait the amount of seconds that interval is.
            std::thread::sleep(interval);

            // Request failures are returned to the caller instead of panicking.
            let response = credential
                .execute()
                .map_err(|err| Box::new(AuthExecutionError::from(err)))?;
            let http_response = response.into_http_response().map_err(Box::new)?;
            let status = http_response.status();

            if status.is_success() {
                return if let Some(json) = http_response.json() {
                    let token: Token = serde_json::from_value(json)
                        .map_err(|err| Box::new(AuthExecutionError::from(err)))?;
                    let cache_id = credential.app_config.cache_id.clone();
                    credential.token_cache.store(cache_id, token);
                    Ok(PublicClientApplication::from(credential))
                } else {
                    Err(WebViewDeviceCodeError::DeviceCodePollingError(
                        http_response,
                    ))
                };
            }

            // A missing or non-JSON body is treated like a body without an `error` field
            // rather than panicking.
            let option_error = http_response
                .json()
                .and_then(|json| json["error"].as_str().map(|value| value.to_owned()));

            if let Some(error) = option_error {
                match PollDeviceCodeEvent::from_str(error.as_str()) {
                    Ok(poll_device_code_type) => match poll_device_code_type {
                        PollDeviceCodeEvent::AuthorizationPending
                        | PollDeviceCodeEvent::BadVerificationCode => continue,
                        PollDeviceCodeEvent::SlowDown => {
                            // RFC 8628 section 3.5: increase the interval by 5 seconds.
                            interval = interval.add(Duration::from_secs(5));
                            continue;
                        }
                        PollDeviceCodeEvent::AuthorizationDeclined
                        | PollDeviceCodeEvent::ExpiredToken
                        | PollDeviceCodeEvent::AccessDenied => {
                            return Err(WebViewDeviceCodeError::DeviceCodePollingError(
                                http_response,
                            ));
                        }
                    },
                    Err(_) => {
                        return Err(WebViewDeviceCodeError::DeviceCodePollingError(
                            http_response,
                        ));
                    }
                }
            } else {
                // Body should have error or we should bail.
                return Err(WebViewDeviceCodeError::DeviceCodePollingError(
                    http_response,
                ));
            }
        }
    }
}
690
#[cfg(test)]
mod test {
    use super::*;

    /// Building the form with no scope configured is expected to fail.
    /// NOTE(review): "CLIENT_ID" is not a UUID, so the failure may come from the client-id
    /// check rather than the missing scope; confirm which check triggers the panic.
    #[test]
    #[should_panic]
    fn no_scope() {
        let mut credential = DeviceCodeCredential::builder("CLIENT_ID").build();
        let form = credential.form_urlencode();
        let _form = form.unwrap();
    }
}