use crate::Result;
use crate::credentials::{CacheableResource, EntityTag};
use crate::token::{CachedTokenProvider, Token, TokenProvider};
use http::Extensions;
use std::sync::Arc;
use tokio::sync::watch;
use tokio::time::{Duration, Instant, sleep};
/// How long before a token's expiry the background task normally refreshes it (4 minutes).
const NORMAL_REFRESH_SLACK: Duration = Duration::from_secs(4 * 60);
/// Delay used when a token is already close to expiry, and between retries after a
/// transient error.
const SHORT_REFRESH_SLACK: Duration = Duration::from_secs(10);
/// A cloneable cache of the most recent token produced by a [`TokenProvider`].
///
/// A background task (`refresh_task`) publishes tokens, or errors, on a `watch` channel.
/// Every clone of `TokenCache` holds a receiver for that channel. The channel holds `None`
/// until the first fetch completes.
#[derive(Debug, Clone)]
pub(crate) struct TokenCache {
    rx_token: watch::Receiver<Option<Result<(Token, EntityTag)>>>,
}

impl TokenCache {
    /// Creates the cache and spawns the background refresh task that drives `inner`.
    ///
    /// # Panics
    ///
    /// Calls `tokio::spawn`, which panics if invoked outside of a Tokio runtime.
    pub(crate) fn new<T>(inner: T) -> Self
    where
        T: TokenProvider + Send + Sync + 'static,
    {
        let (sender, receiver) = watch::channel::<Option<Result<(Token, EntityTag)>>>(None);
        tokio::spawn(refresh_task(Arc::new(inner), sender));
        Self { rx_token: receiver }
    }

    /// Returns the most recently published token together with its entity tag.
    ///
    /// If nothing has been published yet, this waits for the first value. If the published
    /// token's `expires_at` is in the past, this waits for the next value. A published error
    /// is returned unchanged.
    async fn latest_token_and_entity_tag(&self) -> Result<(Token, EntityTag)> {
        let mut receiver = self.rx_token.clone();
        // Copy the value out in its own statement so the channel's read guard is released
        // before any `.await` below.
        let snapshot = receiver.borrow_and_update().clone();
        match snapshot {
            Some(Err(err)) => Err(err),
            Some(Ok((token, tag))) => {
                let expired = matches!(token.expires_at, Some(at) if at < Instant::now());
                if expired {
                    wait_for_next_token(receiver).await
                } else {
                    Ok((token, tag))
                }
            }
            None => wait_for_next_token(receiver).await,
        }
    }
}
#[async_trait::async_trait]
impl CachedTokenProvider for TokenCache {
    /// Returns the cached token, or `NotModified` if the caller already holds it.
    ///
    /// A caller signals which token it already has by putting that token's [`EntityTag`]
    /// into `extensions`. If the tag equals the current token's tag, no token is returned.
    async fn token(&self, extensions: Extensions) -> Result<CacheableResource<Token>> {
        let (data, entity_tag) = self.latest_token_and_entity_tag().await?;
        let caller_is_current = extensions
            .get::<EntityTag>()
            .is_some_and(|known| entity_tag.eq(known));
        if caller_is_current {
            return Ok(CacheableResource::NotModified);
        }
        Ok(CacheableResource::New { entity_tag, data })
    }
}
async fn wait_for_next_token(
mut rx_token: watch::Receiver<Option<Result<(Token, EntityTag)>>>,
) -> Result<(Token, EntityTag)> {
rx_token.changed().await.unwrap();
let token_result = rx_token.borrow().clone();
token_result.expect("There should always be a token or error in the channel after changed()")
}
/// Background loop that keeps the watch channel populated with the latest token.
///
/// Each iteration does three things:
/// 1. Fetches a token from `token_provider`.
/// 2. Tags a successful result with a fresh [`EntityTag`].
/// 3. Publishes the result (token or error) on `tx_token`.
///
/// What happens next depends on that result:
///
/// * Token expires in more than `NORMAL_REFRESH_SLACK`: sleep until `NORMAL_REFRESH_SLACK`
///   before the expiry, then refresh.
/// * Token expires in more than `SHORT_REFRESH_SLACK` but not more than
///   `NORMAL_REFRESH_SLACK`: sleep `SHORT_REFRESH_SLACK`, then refresh.
/// * Token expires within `SHORT_REFRESH_SLACK`, or has already expired: refresh immediately.
/// * Token has no expiry: stop; the published token is never refreshed.
/// * Transient error: sleep `SHORT_REFRESH_SLACK`, then retry.
/// * Non-transient error: stop, leaving the error in the channel.
///
/// The loop also stops when publishing fails. That happens once every receiver, i.e. every
/// `TokenCache` clone, has been dropped.
///
/// NOTE(review): a published error replaces the previous value even if that token has not
/// expired yet. Callers therefore see the error until the next successful refresh.
async fn refresh_task<T>(
    token_provider: Arc<T>,
    tx_token: watch::Sender<Option<Result<(Token, EntityTag)>>>,
) where
    T: TokenProvider + Send + Sync + 'static,
{
    loop {
        let token_result = token_provider.token().await;
        // Capture everything the scheduling decision needs before `token_result` is moved
        // into the channel: the expiry on success, whether the error is transient on failure.
        let outcome = token_result
            .as_ref()
            .map(|token| token.expires_at)
            .map_err(|err| err.is_transient());
        let tagged = token_result.map(|token| (token, EntityTag::new()));
        if tx_token.send(Some(tagged)).is_err() {
            // Every receiver is gone, so no `TokenCache` can observe this or any later token.
            // Without this check the task (and the provider it owns) would live forever.
            break;
        }
        match outcome {
            Ok(Some(expiry)) => {
                // `checked_duration_since` yields `None` when the token has already expired;
                // in that case fall through and refresh immediately.
                if let Some(remaining) = expiry.checked_duration_since(Instant::now()) {
                    if remaining > NORMAL_REFRESH_SLACK {
                        sleep(remaining - NORMAL_REFRESH_SLACK).await;
                    } else if remaining > SHORT_REFRESH_SLACK {
                        sleep(SHORT_REFRESH_SLACK).await;
                    }
                    // Otherwise the token is about to expire: refresh immediately.
                }
            }
            // No expiry: the published token never needs refreshing.
            Ok(None) => break,
            Err(is_transient) => {
                if !is_transient {
                    // Retrying cannot fix a permanent error.
                    break;
                }
                sleep(SHORT_REFRESH_SLACK).await;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors;
    use crate::token::tests::MockTokenProvider;
    use google_cloud_gax::error::CredentialsError;
    use std::ops::{Add, Sub};
    use std::sync::{Arc, Mutex};
    use tokio::time::{Duration, Instant};

    // Lifetime used for most test tokens (one hour).
    static TOKEN_VALID_DURATION: Duration = Duration::from_secs(3600);

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    // Extracts the token from a `CacheableResource::New`. A `NotModified` response is mapped
    // to an error so tests can use `?`.
    fn get_cached_token(cache: CacheableResource<Token>) -> Result<Token> {
        match cache {
            CacheableResource::New { data, .. } => Ok(data),
            CacheableResource::NotModified => Err(CredentialsError::from_msg(
                false,
                "Expecting token to be present.",
            )),
        }
    }

    // A token without expiry is fetched once (the mock expects a single call) and then served
    // from the cache. Passing back the entity tag of the current token yields `NotModified`.
    #[tokio::test]
    async fn initial_token_success() -> TestResult {
        let expected = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };
        let expected_clone = expected.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(expected_clone));
        let cache = TokenCache::new(mock);
        let mut extensions = Extensions::new();
        let cached_token = cache.token(extensions.clone()).await.unwrap();
        let (actual, entity_tag) = match cached_token {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        assert_eq!(actual, expected);
        // Without an entity tag in the extensions, the cached token is returned again.
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual, expected);
        // With the current token's entity tag, the cache reports `NotModified`.
        extensions.insert(entity_tag);
        let cached_token = cache.token(extensions).await?;
        match cached_token {
            // NOTE(review): this arm fires when `NotModified` was expected; the panic message
            // reads as if the opposite were the expectation.
            CacheableResource::New { .. } => unreachable!("expecting new headers"),
            CacheableResource::NotModified => CacheableResource::<Token>::NotModified,
        };
        Ok(())
    }

    // A non-retryable provider error is returned to every caller. The mock expects a single
    // call, since the refresh task stops on non-transient errors.
    #[tokio::test]
    async fn initial_token_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .returning(|| Err(errors::non_retryable_from_str("fail")));
        let cache = TokenCache::new(mock);
        let result = cache.token(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
        let result = cache.token(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    // With the clock paused, advancing past the initial token's expiry makes the cache serve
    // the refreshed token.
    #[tokio::test(start_paused = true)]
    async fn expired_token_success() -> TestResult {
        let now = Instant::now();
        let initial = Token {
            token: "initial-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + TOKEN_VALID_DURATION),
            metadata: None,
        };
        let initial_clone = initial.clone();
        let refresh = Token {
            token: "refreshed-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 2 * TOKEN_VALID_DURATION),
            metadata: None,
        };
        let refresh_clone = refresh.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(initial_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(refresh_clone));
        let cache = TokenCache::new(mock);
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual, initial);
        let sleep = TOKEN_VALID_DURATION.add(Duration::from_secs(100));
        tokio::time::advance(sleep).await;
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual, refresh);
        Ok(())
    }

    // With the clock paused, a non-retryable error on the refresh after expiry is surfaced to
    // callers.
    #[tokio::test(start_paused = true)]
    async fn expired_token_failure() -> TestResult {
        let now = Instant::now();
        let initial = Token {
            token: "initial-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + TOKEN_VALID_DURATION),
            metadata: None,
        };
        let initial_clone = initial.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(initial_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));
        let cache = TokenCache::new(mock);
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual, initial);
        let sleep = TOKEN_VALID_DURATION.add(Duration::from_secs(100));
        tokio::time::advance(sleep).await;
        let result = cache.token(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
        Ok(())
    }

    // Many concurrent callers on a multi-threaded runtime all receive the same still-valid
    // cached token. The mock expects only one provider call.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn token_cache_multiple_requests_existing_valid_token() -> TestResult {
        let now = Instant::now();
        let token = Token {
            token: "initial-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token_clone = token.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token_clone));
        let cache = TokenCache::new(mock);
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual, token);
        let tasks = (0..1000)
            .map(|_| {
                let cache_clone = cache.clone();
                tokio::spawn(async move { cache_clone.token(Extensions::new()).await })
            })
            .collect::<Vec<_>>();
        for task in tasks {
            let actual = task.await??;
            assert_eq!(get_cached_token(actual)?, token);
        }
        Ok(())
    }

    // Drives `refresh_task` directly. token1 already expires at `now`, so the task refreshes
    // without sleeping and token2 is published quickly (checked in real time).
    #[tokio::test]
    async fn refresh_task_expired_token_loop() {
        let now = Instant::now();
        let token1 = Token {
            token: "token1".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now),
            metadata: None,
        };
        let token1_clone = token1.clone();
        let token2 = Token {
            token: "token2".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token2_clone = token2.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token1_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token2_clone));
        let (tx, mut rx) = watch::channel::<Option<Result<(Token, EntityTag)>>>(None);
        tokio::spawn(async move {
            refresh_task(Arc::new(mock), tx).await;
        });
        sleep(Duration::from_millis(100)).await;
        rx.changed().await.unwrap();
        assert!(Instant::now() <= now + Duration::from_millis(500));
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token2.clone());
    }

    // Drives `refresh_task` with a paused clock. token1 is still current after 120s. It is
    // replaced by token2 once time is within NORMAL_REFRESH_SLACK of token1's expiry, but
    // before the expiry itself. token3 later replaces token2 the same way.
    #[tokio::test(start_paused = true)]
    async fn refresh_task_loop() {
        let now = Instant::now();
        let token1 = Token {
            token: "token1".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token1_clone = token1.clone();
        let token2 = Token {
            token: "token2".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 2 * TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token2_clone = token2.clone();
        let token3 = Token {
            token: "token3".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 3 * TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token3_clone = token3.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token1_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token2_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token3_clone));
        let (tx, mut rx) = watch::channel::<Option<Result<(Token, EntityTag)>>>(None);
        let actual = rx.borrow().clone();
        assert!(actual.is_none(), "{actual:?}");
        tokio::spawn(async move {
            refresh_task(Arc::new(mock), tx).await;
        });
        rx.changed().await.unwrap();
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token1.clone());
        let sleep = Duration::from_secs(120);
        tokio::time::advance(sleep).await;
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token1.clone());
        tokio::time::advance(TOKEN_VALID_DURATION.sub(Duration::from_secs(300))).await;
        rx.changed().await.unwrap();
        assert!(Instant::now() < now + TOKEN_VALID_DURATION);
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token2);
        let sleep = TOKEN_VALID_DURATION.add(Duration::from_secs(500));
        tokio::time::advance(sleep).await;
        rx.changed().await.unwrap();
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token3);
    }

    // token1 expires in 120s, which is within NORMAL_REFRESH_SLACK, so the task re-fetches in
    // SHORT_REFRESH_SLACK steps. The provider first returns token1 again (republished, so
    // `changed()` fires) and then token2.
    #[tokio::test(start_paused = true)]
    async fn refresh_task_loop_less_expiry() {
        let now = Instant::now();
        let token1 = Token {
            token: "token1".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + Duration::from_secs(120)),
            metadata: None,
        };
        let token1_clone = token1.clone();
        let token1_clone2 = token1.clone();
        let token2 = Token {
            token: "token2".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 2 * TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token2_clone = token2.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token1_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token1_clone2));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token2_clone));
        let (tx, mut rx) = watch::channel::<Option<Result<(Token, EntityTag)>>>(None);
        let actual = rx.borrow().clone();
        assert!(actual.is_none(), "{actual:?}");
        tokio::spawn(async move {
            refresh_task(Arc::new(mock), tx).await;
        });
        rx.changed().await.unwrap();
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token1);
        tokio::time::advance(Duration::from_secs(10)).await;
        assert!(Instant::now() < now + Duration::from_secs(11));
        rx.changed().await.unwrap();
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token1);
        tokio::time::advance(Duration::from_secs(100)).await;
        rx.changed().await.unwrap();
        assert!(Instant::now() < now + Duration::from_secs(111));
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token2);
    }

    // token1 expires at 3 * NORMAL_REFRESH_SLACK. It must still be published after one and
    // two slack periods, and be replaced by token2 after four (past its expiry).
    #[tokio::test(start_paused = true)]
    async fn refresh_task_loop_long_expiry_waits_long_time_before_refresh() {
        let now = Instant::now();
        let token1 = Token {
            token: "token1".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 3 * NORMAL_REFRESH_SLACK),
            metadata: None,
        };
        let token1_clone = token1.clone();
        let token2 = Token {
            token: "token2".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 2 * TOKEN_VALID_DURATION),
            metadata: None,
        };
        let token2_clone = token2.clone();
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token1_clone));
        mock.expect_token()
            .times(1)
            .return_once(|| Ok(token2_clone));
        let (tx, mut rx) = watch::channel::<Option<Result<(Token, EntityTag)>>>(None);
        let actual = rx.borrow().clone();
        assert!(actual.is_none(), "{actual:?}");
        tokio::spawn(async move {
            refresh_task(Arc::new(mock), tx).await;
        });
        rx.changed().await.unwrap();
        tokio::time::advance(NORMAL_REFRESH_SLACK).await;
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token1);
        tokio::time::advance(NORMAL_REFRESH_SLACK).await;
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token1);
        tokio::time::advance(2 * NORMAL_REFRESH_SLACK).await;
        let (actual, ..) = rx.borrow().clone().unwrap().unwrap();
        assert_eq!(actual, token2);
    }

    // After expiry, the refresh hits a transient error that is surfaced to callers. The task
    // keeps running, and once SHORT_REFRESH_SLACK (plus margin) has passed, the retry yields
    // token-2.
    #[tokio::test(start_paused = true)]
    async fn refresh_task_sleeps_on_transient_error_and_recovers_on_next_loop() -> TestResult {
        let now = Instant::now();
        let token = Token {
            token: "token-1".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + TOKEN_VALID_DURATION),
            metadata: None,
        };
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(move || Ok(token.clone()));
        mock.expect_token()
            .times(1)
            .return_once(|| Err(CredentialsError::from_msg(true, "transient error")));
        let token = Token {
            token: "token-2".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(now + 2 * TOKEN_VALID_DURATION),
            metadata: None,
        };
        mock.expect_token()
            .times(1)
            .return_once(move || Ok(token.clone()));
        let cache = TokenCache::new(mock);
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual.token, "token-1");
        let sleep = TOKEN_VALID_DURATION.add(Duration::from_secs(10));
        tokio::time::advance(sleep).await;
        let result = cache.token(Extensions::new()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("transient error"));
        tokio::time::advance(SHORT_REFRESH_SLACK.add(Duration::from_secs(10))).await;
        tokio::task::yield_now().await;
        let actual = get_cached_token(cache.token(Extensions::new()).await.unwrap())?;
        assert_eq!(actual.token, "token-2");
        Ok(())
    }

    // Hand-written provider: each call sleeps 50ms, increments a shared call counter, and
    // returns a clone of a fixed result. Clones share the counter through the `Arc`.
    #[derive(Clone, Debug)]
    struct FakeTokenProvider {
        result: Result<Token>,
        calls: Arc<Mutex<i32>>,
    }

    impl FakeTokenProvider {
        pub fn new(result: Result<Token>) -> Self {
            FakeTokenProvider {
                result,
                calls: Arc::new(Mutex::new(0)),
            }
        }

        // Number of `token()` calls made so far, across all clones.
        pub fn calls(&self) -> i32 {
            *self.calls.lock().unwrap()
        }
    }

    #[async_trait::async_trait]
    impl TokenProvider for FakeTokenProvider {
        async fn token(&self) -> Result<Token> {
            sleep(Duration::from_millis(50)).await;
            *self.calls.lock().unwrap() += 1;
            self.result.clone()
        }
    }

    // 100 concurrent callers arrive before any token exists. All must receive the token, and
    // the provider must be called far fewer times than there are callers.
    // The token's expiry is set at creation time, so the cache treats it as expired.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn no_initial_token_thundering_herd_success() -> TestResult {
        let token = Token {
            token: "delayed-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: Some(Instant::now()),
            metadata: None,
        };
        let tp = FakeTokenProvider::new(Ok(token.clone()));
        let cache = TokenCache::new(tp.clone());
        let tasks = (0..100)
            .map(|_| {
                let cache_clone = cache.clone();
                tokio::spawn(async move { cache_clone.token(Extensions::new()).await })
            })
            .collect::<Vec<_>>();
        for task in tasks {
            let actual = task.await??;
            assert_eq!(get_cached_token(actual)?, token);
        }
        let calls = tp.calls();
        assert!(calls < 10, "calls to inner token provider: {calls}");
        Ok(())
    }

    // Same herd as above, but the provider fails permanently. Every caller must receive the
    // shared error, again with few provider calls.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn no_initial_token_thundering_herd_failure_shares_error() -> TestResult {
        let err = Err(errors::non_retryable_from_str("epic fail"));
        let tp = FakeTokenProvider::new(err);
        let cache = TokenCache::new(tp.clone());
        let tasks = (0..100)
            .map(|_| {
                let cache_clone = cache.clone();
                tokio::spawn(async move { cache_clone.token(Extensions::new()).await })
            })
            .collect::<Vec<_>>();
        for task in tasks {
            let actual = task.await?;
            assert!(actual.is_err(), "{actual:?}");
            let e = format!("{}", actual.unwrap_err());
            assert!(e.contains("epic fail"), "{e}");
        }
        let calls = tp.calls();
        assert!(calls < 10, "calls to inner token provider: {calls}");
        Ok(())
    }

    // The derived `Debug` output names the struct and its field.
    #[tokio::test]
    async fn debug_token_cache() {
        let mut mock_provider = MockTokenProvider::new();
        mock_provider.expect_token().return_const(Ok(Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        }));
        let cache = TokenCache::new(mock_provider);
        let debug_output = format!("{cache:?}");
        assert!(debug_output.contains("TokenCache"));
        assert!(debug_output.contains("rx_token"));
    }
}