use async_lock::RwLock;
use azure_core::credentials::{AccessToken, TokenRequestOptions};
use azure_core::time::{Duration, OffsetDateTime};
use std::collections::HashMap;
use std::future::Future;
use tracing::trace;
/// Cache of access tokens guarded by an async read-write lock.
///
/// Entries are keyed by the owned list of scope strings a token was requested for,
/// so the same scopes in the same order map to the same cached token.
#[derive(Debug)]
pub(crate) struct TokenCache(RwLock<HashMap<Vec<String>, AccessToken>>);
impl TokenCache {
    /// Creates an empty token cache.
    pub(crate) fn new() -> Self {
        let tokens = HashMap::new();
        Self(RwLock::new(tokens))
    }

    /// Returns a token for `scopes`, reusing the cached one unless it is due for refresh.
    ///
    /// The cache is consulted under the read lock first. If that does not yield a usable
    /// token, the write lock is taken and the cache is checked once more before `callback`
    /// is invoked with `scopes` and `options`. The token it produces is stored under the
    /// scopes key and returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `callback`; the cache is left untouched in that case.
    pub(crate) async fn get_token<'a, C, F>(
        &self,
        scopes: &'a [&'a str],
        options: Option<TokenRequestOptions<'a>>,
        callback: C,
    ) -> azure_core::Result<AccessToken>
    where
        C: FnOnce(&'a [&'a str], Option<TokenRequestOptions<'a>>) -> F + Send,
        F: Future<Output = azure_core::Result<AccessToken>> + Send,
    {
        // The map is keyed by owned strings, so build the lookup key once up front.
        let key: Vec<String> = scopes.iter().map(|scope| scope.to_string()).collect();

        // Fast path: a shared read lock is enough when the cached token is still usable.
        {
            let cache = self.0.read().await;
            if let Some(cached) = cache.get(&key).filter(|token| !should_refresh(token)) {
                trace!("returning cached token");
                return Ok(cached.clone());
            }
        }

        // Slow path: the read guard is gone by now, so the write lock can be acquired.
        let mut cache = self.0.write().await;

        // Another caller may have stored a new token while this one waited for the lock.
        if let Some(cached) = cache.get(&key).filter(|token| !should_refresh(token)) {
            trace!("returning token that was updated while waiting on write lock");
            return Ok(cached.clone());
        }

        trace!("token cache miss");
        // The write guard stays alive across this await, so other callers wait for the result.
        let fresh = callback(scopes, options).await?;
        cache.insert(key, fresh.clone());
        Ok(fresh)
    }
}
impl Default for TokenCache {
fn default() -> Self {
TokenCache::new()
}
}
/// Reports whether `token` has expired or will expire within the next 300 seconds.
fn should_refresh(token: &AccessToken) -> bool {
    // Any token whose expiry falls at or before this instant counts as due for refresh.
    let refresh_deadline = OffsetDateTime::now_utc() + Duration::seconds(300);
    refresh_deadline >= token.expires_on
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_lock::Mutex;
    use azure_core::{
        credentials::Secret,
        time::{Duration, OffsetDateTime},
    };

    /// Test double whose issued tokens encode the requested scopes and the call number,
    /// which lets a test tell a cached token apart from a newly issued one.
    #[derive(Debug)]
    struct MockCredential {
        /// Template token; its secret and expiry are reused for every issued token.
        token: AccessToken,
        /// How many times `get_token` has been called so far.
        get_token_call_count: Mutex<usize>,
    }

    impl MockCredential {
        /// Wraps `token` as the template and starts the call counter at zero.
        fn new(token: AccessToken) -> Self {
            let get_token_call_count = Mutex::new(0);
            Self {
                token,
                get_token_call_count,
            }
        }

        /// Bumps the call counter and returns a token whose secret is
        /// `<scopes joined by space>-<template secret>:<call number>`.
        async fn get_token(
            &self,
            scopes: &[&str],
            _: Option<TokenRequestOptions<'_>>,
        ) -> azure_core::Result<AccessToken> {
            let call_number = {
                let mut calls = self.get_token_call_count.lock().await;
                *calls += 1;
                *calls
            };
            let secret = format!(
                "{}-{}:{}",
                scopes.join(" "),
                self.token.token.secret(),
                call_number
            );
            Ok(AccessToken {
                token: Secret::new(secret),
                expires_on: self.token.expires_on,
            })
        }
    }

    const STORAGE_TOKEN_SCOPE: &str = "https://storage.azure.com/";
    const IOTHUB_TOKEN_SCOPE: &str = "https://iothubs.azure.net";

    /// Tokens are cached per scope list: repeated requests for the same scopes reuse
    /// the first token, while a different scope list triggers a new credential call.
    #[tokio::test]
    async fn test_get_token_different_resources() -> azure_core::Result<()> {
        let storage_scopes: &[&str] = &[STORAGE_TOKEN_SCOPE];
        let iothub_scopes: &[&str] = &[IOTHUB_TOKEN_SCOPE];
        let secret = "test-token";

        // Expiry is an hour away, so cached tokens are not due for refresh.
        let credential = MockCredential::new(AccessToken::new(
            Secret::new(secret),
            OffsetDateTime::now_utc() + Duration::seconds(3600),
        ));
        let cache = TokenCache::new();

        // Both storage requests must observe the credential's first call.
        let expected_storage = format!("{}-{}:1", storage_scopes.join(" "), secret);
        for _ in 0..2 {
            let token = cache
                .get_token(storage_scopes, None, |s, o| credential.get_token(s, o))
                .await?;
            assert_eq!(token.token.secret(), expected_storage);
        }

        // Both IoT Hub requests must observe the credential's second call.
        let expected_iothub = format!("{}-{}:2", iothub_scopes.join(" "), secret);
        for _ in 0..2 {
            let token = cache
                .get_token(iothub_scopes, None, |s, o| credential.get_token(s, o))
                .await?;
            assert_eq!(token.token.secret(), expected_iothub);
        }

        Ok(())
    }

    /// A token that is already at its expiry is never served from the cache: every
    /// request reaches the credential, so the call number grows with each request.
    #[tokio::test]
    async fn test_refresh_expired_token() -> azure_core::Result<()> {
        let scopes: &[&str] = &[STORAGE_TOKEN_SCOPE];
        let secret = "test-token";

        // Expiry is "now", which is always inside the refresh window.
        let credential =
            MockCredential::new(AccessToken::new(Secret::new(secret), OffsetDateTime::now_utc()));
        let cache = TokenCache::new();

        for expected_call in 1..=4 {
            let token = cache
                .get_token(scopes, None, |s, o| credential.get_token(s, o))
                .await?;
            let expected = format!("{}-{}:{}", scopes.join(" "), secret, expected_call);
            assert_eq!(token.token.secret(), expected);
        }

        Ok(())
    }
}