azure_identity_helpers/
devicecode_credentials.rs1use crate::{cache::TokenCache, device_code::start, refresh_token::exchange};
2use async_lock::Mutex;
3use azure_core::{
4 credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
5 error::{Error, ErrorKind},
6};
7use futures::stream::StreamExt;
8use std::{collections::BTreeMap, str, sync::Arc, time::Duration};
9use time::OffsetDateTime;
10use tracing::debug;
11
12#[derive(Debug)]
13pub struct DeviceCodeCredential {
15 tenant_id: String,
16 client_id: String,
17 cache: TokenCache,
18 refresh_tokens: Mutex<BTreeMap<Vec<String>, Secret>>,
19}
20
21impl DeviceCodeCredential {
22 pub fn new<T, C>(tenant_id: T, client_id: C) -> azure_core::Result<Arc<Self>>
24 where
25 T: Into<String>,
26 C: Into<String>,
27 {
28 Ok(Arc::new(Self {
29 tenant_id: tenant_id.into(),
30 client_id: client_id.into(),
31 cache: TokenCache::new(),
32 refresh_tokens: Mutex::new(BTreeMap::new()),
33 }))
34 }
35
36 async fn get_access_token(
37 &self,
38 scopes: &[&str],
39 _options: Option<TokenRequestOptions<'_>>,
40 ) -> azure_core::Result<AccessToken> {
41 let scopes_owned = scopes.iter().map(ToString::to_string).collect::<Vec<_>>();
42 let mut refresh_tokens = self.refresh_tokens.lock().await;
43 if let Some(refresh_token) = refresh_tokens.remove(&scopes_owned) {
44 let response = exchange(
45 self.tenant_id.as_str(),
46 &self.client_id,
47 None,
48 &refresh_token,
49 )
50 .await?;
51 let token = AccessToken {
52 token: response.access_token().to_owned(),
53 expires_on: convert_expires_in(response.expires_in()),
54 };
55 refresh_tokens.insert(scopes_owned, response.refresh_token().to_owned());
56 return Ok(token);
57 }
58
59 let flow = start(self.tenant_id.clone(), self.client_id.as_str(), scopes).await?;
60
61 eprintln!("{}", flow.message());
62
63 let mut stream = flow.stream();
64 let mut last_error: Option<Error> = None;
65 let auth = loop {
66 let Some(response) = stream.next().await else {
67 return Err(last_error.unwrap_or_else(|| {
73 Error::with_message(
74 ErrorKind::Credential,
75 "device code did not return a response",
76 )
77 }));
78 };
79 match response {
80 Ok(auth) => break auth,
81 Err(err) => {
82 debug!("device code poll returned error, will continue if recoverable: {err}");
83 last_error = Some(err);
84 }
85 }
86 };
87
88 let token = AccessToken {
89 token: auth.access_token().to_owned(),
90 expires_on: convert_expires_in(auth.expires_in),
91 };
92
93 if let Some(refresh_token) = auth.refresh_token() {
94 refresh_tokens.insert(scopes_owned, refresh_token.to_owned());
95 }
96 Ok(token)
97 }
98}
99
100#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
101#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
102impl TokenCredential for DeviceCodeCredential {
103 async fn get_token(
104 &self,
105 scopes: &[&str],
106 options: Option<TokenRequestOptions<'_>>,
107 ) -> azure_core::Result<AccessToken> {
108 self.cache
109 .get_token(scopes, options, |s, o| self.get_access_token(s, o))
110 .await
111 }
112}
113
114fn convert_expires_in(seconds: u64) -> OffsetDateTime {
115 OffsetDateTime::now_utc() + Duration::new(seconds, 0)
116}