Skip to main content

azure_identity_helpers/
devicecode_credentials.rs

1use crate::{cache::TokenCache, device_code::start, refresh_token::exchange};
2use async_lock::Mutex;
3use azure_core::{
4    credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
5    error::{Error, ErrorKind},
6};
7use futures::stream::StreamExt;
8use std::{collections::BTreeMap, str, sync::Arc, time::Duration};
9use time::OffsetDateTime;
10use tracing::debug;
11
#[derive(Debug)]
/// Enables authentication to an Azure Client using a Device Code workflow.
///
/// Tokens are served through a [`TokenCache`]. Refresh tokens returned by the identity
/// service are remembered per scope list, so a later request for the same scopes is
/// first attempted with a refresh-token exchange before the user is prompted again.
pub struct DeviceCodeCredential {
    /// Tenant identifier passed to both the device code and refresh token requests.
    tenant_id: String,
    /// Application (client) identifier passed to both the device code and refresh
    /// token requests.
    client_id: String,
    /// Cache consulted by `TokenCredential::get_token` before a new token is acquired.
    cache: TokenCache,
    /// Most recent refresh token for each scope list. The key is the scope list exactly
    /// as the caller supplied it, so the same scopes in a different order map to a
    /// separate entry.
    refresh_tokens: Mutex<BTreeMap<Vec<String>, Secret>>,
}
20
21impl DeviceCodeCredential {
22    /// Create a new `DeviceCodeCredential` with the specified tenant ID, client ID, and options.
23    pub fn new<T, C>(tenant_id: T, client_id: C) -> azure_core::Result<Arc<Self>>
24    where
25        T: Into<String>,
26        C: Into<String>,
27    {
28        Ok(Arc::new(Self {
29            tenant_id: tenant_id.into(),
30            client_id: client_id.into(),
31            cache: TokenCache::new(),
32            refresh_tokens: Mutex::new(BTreeMap::new()),
33        }))
34    }
35
36    async fn get_access_token(
37        &self,
38        scopes: &[&str],
39        _options: Option<TokenRequestOptions<'_>>,
40    ) -> azure_core::Result<AccessToken> {
41        let scopes_owned = scopes.iter().map(ToString::to_string).collect::<Vec<_>>();
42        let mut refresh_tokens = self.refresh_tokens.lock().await;
43        if let Some(refresh_token) = refresh_tokens.remove(&scopes_owned) {
44            let response = exchange(
45                self.tenant_id.as_str(),
46                &self.client_id,
47                None,
48                &refresh_token,
49            )
50            .await?;
51            let token = AccessToken {
52                token: response.access_token().to_owned(),
53                expires_on: convert_expires_in(response.expires_in()),
54            };
55            refresh_tokens.insert(scopes_owned, response.refresh_token().to_owned());
56            return Ok(token);
57        }
58
59        let flow = start(self.tenant_id.clone(), self.client_id.as_str(), scopes).await?;
60
61        eprintln!("{}", flow.message());
62
63        let mut stream = flow.stream();
64        let mut last_error: Option<Error> = None;
65        let auth = loop {
66            let Some(response) = stream.next().await else {
67                // The polling stream ended without yielding a successful
68                // authorization. Surface the most recent error from the
69                // server (e.g. `expired_token`, `access_denied`) instead
70                // of a generic message — that's almost always what the
71                // caller actually needs to see.
72                return Err(last_error.unwrap_or_else(|| {
73                    Error::with_message(
74                        ErrorKind::Credential,
75                        "device code did not return a response",
76                    )
77                }));
78            };
79            match response {
80                Ok(auth) => break auth,
81                Err(err) => {
82                    debug!("device code poll returned error, will continue if recoverable: {err}");
83                    last_error = Some(err);
84                }
85            }
86        };
87
88        let token = AccessToken {
89            token: auth.access_token().to_owned(),
90            expires_on: convert_expires_in(auth.expires_in),
91        };
92
93        if let Some(refresh_token) = auth.refresh_token() {
94            refresh_tokens.insert(scopes_owned, refresh_token.to_owned());
95        }
96        Ok(token)
97    }
98}
99
100#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
101#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
102impl TokenCredential for DeviceCodeCredential {
103    async fn get_token(
104        &self,
105        scopes: &[&str],
106        options: Option<TokenRequestOptions<'_>>,
107    ) -> azure_core::Result<AccessToken> {
108        self.cache
109            .get_token(scopes, options, |s, o| self.get_access_token(s, o))
110            .await
111    }
112}
113
114fn convert_expires_in(seconds: u64) -> OffsetDateTime {
115    OffsetDateTime::now_utc() + Duration::new(seconds, 0)
116}