use std::sync::Arc;
use anyhow::Context;
use chrono::{Duration, Utc};
use tokio::sync::Mutex;
use tracing::debug;
use super::{Token, TokenProvider};
/// A [`TokenProvider`] wrapper that keeps the most recently obtained token
/// and only asks the wrapped provider for a new one when the kept token is
/// due for renewal.
///
/// Cloning yields a handle that shares the same cached token, because the
/// cache lives behind an `Arc`.
#[derive(Clone)]
pub struct RenewingTokenProvider<T> {
    /// Wrapped provider that actually produces tokens.
    provider: T,
    /// Cached token, guarded by an async mutex and shared between clones.
    token: Arc<Mutex<Arc<Token>>>,
}
impl<T: TokenProvider> RenewingTokenProvider<T> {
pub fn new(provider: T) -> Self {
Self {
provider,
token: Default::default(),
}
}
}
#[async_trait::async_trait]
impl<T: TokenProvider> TokenProvider for RenewingTokenProvider<T> {
    /// Returns the cached token, renewing it first when it expires within the
    /// next ten minutes (or has already expired).
    ///
    /// The mutex guard is held across the call to the wrapped provider, so
    /// other callers of this instance wait for the renewal to finish instead
    /// of issuing their own requests.
    ///
    /// # Errors
    ///
    /// Returns the wrapped provider's error, with the context
    /// `"failed to renew token"`, when renewal fails. The cached token is left
    /// untouched in that case.
    async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
        let mut cached = self.token.lock().await;

        // Take a single snapshot of the clock so the logged value is the one
        // the renewal decision was based on.
        let now = Utc::now();
        let renewal_margin = Duration::minutes(10);

        // Compare `now + margin` against the expiry rather than computing
        // `expiry - margin`: chrono's `DateTime` arithmetic panics when the
        // result is out of range, and the expiry comes from the wrapped
        // provider, whereas adding ten minutes to the current time cannot
        // realistically overflow.
        if now + renewal_margin >= cached.expiry {
            debug!(expiry = %cached.expiry, now = %now, "renewing token");
            *cached = self
                .provider
                .get_token()
                .await
                .context("failed to renew token")?;
            debug!("successfully renewed token");
        }

        Ok(Arc::clone(&cached))
    }

    /// Discards the cached token by replacing it with a default-constructed
    /// one.
    ///
    /// Only this wrapper's cache is reset; the wrapped provider is not
    /// notified.
    // NOTE(review): this relies on `Token::default()` being treated as due for
    // renewal by `get_token` — confirm against the `Token` definition.
    async fn invalidate_token(&self) {
        *self.token.lock().await = Default::default();
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::token_provider::MockTokenProvider;

    /// A freshly constructed provider must fetch a token on first use and then
    /// serve the same token from its cache: the mock accepts exactly one call.
    #[tokio::test]
    async fn renewing_provider_renews_expired() {
        let new_token = Arc::new(Token {
            token: "dummy token".into(),
            expiry: Utc::now() + Duration::minutes(30),
        });
        let new_token_return = new_token.clone();

        let mut mock_provider = MockTokenProvider::new();
        mock_provider
            .expect_get_token()
            .times(1)
            .returning(move || Ok(new_token_return.clone()));

        let renewing = RenewingTokenProvider::new(mock_provider);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
    }

    /// A token that expires within the renewal margin must be replaced on the
    /// next request, and the replacement must then be served from the cache.
    #[tokio::test]
    async fn renewing_provider_renews_almost_expired() {
        let shortly_expiring_token = Arc::new(Token {
            token: "dummy token".into(),
            expiry: Utc::now() + Duration::minutes(1),
        });
        let new_token = Arc::new(Token {
            token: "dummy token".into(),
            expiry: Utc::now() + Duration::minutes(30),
        });
        // The mock closures take ownership of their own handles so the
        // originals remain available for the assertions below.
        let first_return = Arc::clone(&shortly_expiring_token);
        let second_return = Arc::clone(&new_token);

        let mut mock_provider = MockTokenProvider::new();
        let mut seq = mockall::Sequence::new();
        mock_provider
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move || Ok(first_return.clone()));
        mock_provider
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move || Ok(second_return.clone()));

        let renewing = RenewingTokenProvider::new(mock_provider);
        // First call: cache holds the default token, so the short-lived token
        // is fetched.
        assert_eq!(
            renewing.get_token().await.unwrap(),
            shortly_expiring_token
        );
        // Second call: the short-lived token is inside the renewal margin, so
        // it is replaced.
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
        // Third call: the replacement is still valid and comes from the cache.
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
    }
}