use crate::auth::{JwtClaims, Token};
use crate::credentials::Credentials;
use crate::{get_token_with_client, Result};
use arc_swap::ArcSwapOption;
use reqwest::Client;
use smpl_jwt::Jwt;
use std::sync::Arc;
use time::{Duration, OffsetDateTime};
/// Fetches tokens and caches the most recent one until it is due for refresh.
///
/// A cached token is handed out by `fetch_token` until its refresh deadline
/// is reached; after that a new token is requested via
/// `get_token_with_client`.
pub struct TokenFetcher {
/// HTTP client handed to `get_token_with_client` for every token request.
client: Client,
/// JWT handed to `get_token_with_client` for every token request.
jwt: Jwt<JwtClaims>,
/// Credentials handed to `get_token_with_client` for every token request.
credentials: Credentials,
/// Most recently fetched token together with its refresh deadline.
/// Starts out empty (`None`) until the first successful fetch.
token_state: ArcSwapOption<TokenState>,
/// Amount of time subtracted from a token's `expires_in` when computing the
/// refresh deadline, i.e. how early before expiry a new token is requested.
refresh_buffer: Duration,
}
/// A cached token plus the point in time at which it should be replaced.
struct TokenState {
/// The token returned by the most recent successful request.
token: Token,
/// Once the current UTC time is at or past this instant, `fetch_token`
/// requests a new token instead of returning the cached one.
refresh_at: OffsetDateTime,
}
impl TokenFetcher {
    /// Creates a fetcher that talks to the token endpoint through a newly
    /// constructed [`Client`].
    ///
    /// `refresh_buffer` is how long before a token's expiry the fetcher starts
    /// requesting a replacement. The cache starts out empty, so the first call
    /// to [`TokenFetcher::fetch_token`] always performs a request.
    pub fn new(
        jwt: Jwt<JwtClaims>,
        credentials: Credentials,
        refresh_buffer: Duration,
    ) -> TokenFetcher {
        TokenFetcher::with_client(Client::new(), jwt, credentials, refresh_buffer)
    }

    /// Creates a fetcher that issues its token requests through `client`.
    ///
    /// `refresh_buffer` is how long before a token's expiry the fetcher starts
    /// requesting a replacement. The cache starts out empty, so the first call
    /// to [`TokenFetcher::fetch_token`] always performs a request.
    pub fn with_client(
        client: Client,
        jwt: Jwt<JwtClaims>,
        credentials: Credentials,
        refresh_buffer: Duration,
    ) -> TokenFetcher {
        TokenFetcher {
            client,
            jwt,
            credentials,
            token_state: ArcSwapOption::from(None),
            refresh_buffer,
        }
    }

    /// Returns the cached token while it is still fresh, otherwise requests a
    /// new one, caches it and returns it.
    ///
    /// There is no coordination between concurrent callers: each caller that
    /// observes an empty or stale cache issues its own request.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `get_token_with_client`. The cache is
    /// left untouched in that case.
    pub async fn fetch_token(&self) -> Result<Token> {
        // The cache lookup lives in a synchronous helper so that the loaded
        // snapshot is dropped before the (potentially slow) request is awaited.
        match self.cached_token() {
            Some(token) => Ok(token),
            None => self.get_token().await,
        }
    }

    /// Returns a clone of the cached token if one exists and its refresh
    /// deadline has not been reached yet.
    fn cached_token(&self) -> Option<Token> {
        let snapshot = self.token_state.load();
        match &*snapshot {
            Some(state) if OffsetDateTime::now_utc() < state.refresh_at => {
                Some(state.token.clone())
            }
            _ => None,
        }
    }

    /// Requests a new token, stores it in the cache and returns it.
    ///
    /// The refresh deadline is `expires_in - refresh_buffer` after the moment
    /// the request was started. If the granted lifetime does not exceed the
    /// refresh buffer, the token is still returned but is cached as already
    /// stale, so the next [`TokenFetcher::fetch_token`] call requests again.
    async fn get_token(&self) -> Result<Token> {
        // Taken before the request so that network latency can only make the
        // refresh happen earlier, never later, than intended.
        let requested_at = OffsetDateTime::now_utc();
        let token = get_token_with_client(&self.client, &self.jwt, &self.credentials).await?;
        let expires_in = Duration::new(token.expires_in().into(), 0);

        // `expires_in` is chosen by the remote server; a short-lived token must
        // not bring down the caller, so clamp instead of panicking.
        let refresh_at = if expires_in > self.refresh_buffer {
            requested_at + (expires_in - self.refresh_buffer)
        } else {
            requested_at
        };

        self.token_state.store(Some(Arc::new(TokenState {
            token: token.clone(),
            refresh_at,
        })));
        Ok(token)
    }
}
#[cfg(test)]
mod tests {
    use crate::auth::{JwtClaims, Token};
    use crate::credentials::Credentials;
    use crate::fetcher::TokenFetcher;
    use crate::scopes::Scope;
    use mockito::{self, mock};
    use smpl_jwt::Jwt;
    use std::thread;
    use std::time::Duration as StdDuration;
    use time::Duration;

    /// Builds a JWT and credentials whose token URI points at the mockito
    /// server, so token requests made in these tests hit the mock endpoint.
    fn get_mocks() -> (Jwt<JwtClaims>, Credentials) {
        let server = mockito::server_url();
        let mut credentials =
            Credentials::from_file("dummy_credentials_file_for_tests.json").unwrap();
        credentials.token_uri = server.clone();
        let claims = JwtClaims::new(
            "some_iss".to_string(),
            &Scope::DevStorageReadWrite,
            server,
            None,
            None,
        );
        let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
        (jwt, credentials)
    }

    /// Returns a token response both as the deserialized `Token` and as the
    /// JSON text to serve from the mock endpoint.
    fn token_json(access_token: &str, token_type: &str, expires_in: u32) -> (Token, String) {
        let payload = serde_json::json!({
            "access_token": access_token,
            "token_type": token_type,
            "expires_in": expires_in
        });
        let body = payload.to_string();
        let token = serde_json::from_value(payload).unwrap();
        (token, body)
    }

    /// Registers a `POST /` mock that answers 200 with `body` and expects to
    /// be hit exactly `hits` times when asserted.
    fn token_endpoint(body: String, hits: usize) -> mockito::Mock {
        mock("POST", "/")
            .with_status(200)
            .with_body(body)
            .expect(hits)
            .create()
    }

    /// The first fetch returns exactly the token served by the endpoint.
    #[tokio::test]
    async fn basic_token_fetch() {
        let (jwt, credentials) = get_mocks();
        let fetcher = TokenFetcher::new(jwt, credentials, Duration::new(0, 0));

        let (expected, body) = token_json("token", "Bearer", 1);
        let _endpoint = mock("POST", "/").with_status(200).with_body(body).create();

        let fetched = fetcher.fetch_token().await.unwrap();
        assert_eq!(expected, fetched);
    }

    /// With no refresh buffer, a second fetch after the token's full lifetime
    /// has elapsed triggers a second request.
    #[tokio::test]
    async fn basic_token_refresh() {
        let (jwt, credentials) = get_mocks();
        let fetcher = TokenFetcher::new(jwt, credentials, Duration::new(0, 0));

        let lifetime_secs: u32 = 1;
        let (_token, body) = token_json("token", "Bearer", lifetime_secs);
        let endpoint = token_endpoint(body, 2);

        fetcher.fetch_token().await.unwrap();
        thread::sleep(StdDuration::from_secs(u64::from(lifetime_secs)));
        fetcher.fetch_token().await.unwrap();

        endpoint.assert();
    }

    /// With a refresh buffer, the second request happens once
    /// `expires_in - refresh_buffer` has elapsed, before the token expires.
    #[tokio::test]
    async fn token_refresh_with_buffer() {
        let (jwt, credentials) = get_mocks();
        let buffer_secs: i64 = 4;
        let fetcher = TokenFetcher::new(jwt, credentials, Duration::new(buffer_secs, 0));

        let lifetime_secs: u32 = 5;
        let (_token, body) = token_json("token", "Bearer", lifetime_secs);
        let endpoint = token_endpoint(body, 2);

        fetcher.fetch_token().await.unwrap();
        let wait_secs = lifetime_secs - (buffer_secs as u32);
        thread::sleep(StdDuration::from_secs(u64::from(wait_secs)));
        fetcher.fetch_token().await.unwrap();

        endpoint.assert();
    }

    /// Two back-to-back fetches within the token's lifetime only hit the
    /// endpoint once.
    #[tokio::test]
    async fn doesnt_token_refresh_unnecessarily() {
        let (jwt, credentials) = get_mocks();
        let fetcher = TokenFetcher::new(jwt, credentials, Duration::new(0, 0));

        let lifetime_secs: u32 = 1;
        let (_token, body) = token_json("token", "Bearer", lifetime_secs);
        let endpoint = token_endpoint(body, 1);

        fetcher.fetch_token().await.unwrap();
        fetcher.fetch_token().await.unwrap();

        endpoint.assert();
    }
}