use crate::auth::{JwtClaims, Token};
use crate::credentials::Credentials;
use crate::{Result, get_token_with_client_and_body};
use arc_swap::ArcSwapOption;
use smpl_jwt::Jwt;
use std::sync::{Arc, Mutex};
use time::{Duration, OffsetDateTime};
/// Caches an access token obtained from the credentials' token endpoint and
/// refreshes it once its refresh deadline (expiry minus a buffer) has passed.
///
/// Intended to be shared between threads: the JWT signer sits behind a mutex
/// and the cached token is replaced atomically through an `ArcSwapOption`.
pub struct TokenFetcher {
    /// JWT used to sign the token request. The mutex serialises the in-place
    /// claim update that precedes every request.
    jwt: Arc<Mutex<Jwt<JwtClaims>>>,
    /// Credentials passed along with every token request.
    credentials: Credentials,
    /// Most recently fetched token plus its refresh deadline; `None` until
    /// the first successful fetch.
    token_state: ArcSwapOption<TokenState>,
    /// How long before a token's expiry it is considered due for refresh.
    refresh_buffer: Duration,
}
/// A cached token together with the instant after which it must be re-fetched.
struct TokenState {
    /// The token handed out to callers while it is still fresh.
    token: Token,
    /// Once the current time reaches this instant, the next fetch requests a new token.
    refresh_at: OffsetDateTime,
}
impl TokenFetcher {
    /// Creates a fetcher that refreshes its cached token
    /// `refresh_buffer_seconds` seconds before the token expires.
    ///
    /// No request is made here; the first [`TokenFetcher::fetch_token`] call
    /// fetches the initial token.
    pub fn new(
        jwt: Jwt<JwtClaims>,
        credentials: Credentials,
        refresh_buffer_seconds: i64,
    ) -> TokenFetcher {
        TokenFetcher::with_client(jwt, credentials, Duration::new(refresh_buffer_seconds, 0))
    }

    /// Creates a fetcher with the refresh buffer given as a [`Duration`].
    ///
    /// The cache starts empty, so the first [`TokenFetcher::fetch_token`]
    /// call always performs a request.
    pub fn with_client(
        jwt: Jwt<JwtClaims>,
        credentials: Credentials,
        refresh_buffer: Duration,
    ) -> TokenFetcher {
        TokenFetcher {
            jwt: Arc::new(Mutex::new(jwt)),
            credentials,
            token_state: ArcSwapOption::from(None),
            refresh_buffer,
        }
    }

    /// Returns the cached token, requesting a new one when nothing is cached
    /// yet or the cached token has reached its refresh deadline.
    ///
    /// There is no lock around the check-then-refresh sequence, so several
    /// threads that observe a stale token at the same time may each perform
    /// a refresh; the last one stored wins.
    ///
    /// # Errors
    /// Propagates any error from building the JWT or from the token request.
    pub fn fetch_token(&self) -> Result<Token> {
        // The load guard is a temporary scoped to this `if let`, so it is
        // released before the (potentially slow) refresh below.
        if let Some(state) = &*self.token_state.load() {
            if OffsetDateTime::now_utc() < state.refresh_at {
                return Ok(state.token.clone());
            }
        }
        self.get_token()
    }

    /// Requests a new token, caches it with its refresh deadline, and returns it.
    ///
    /// # Errors
    /// Propagates any error from building the JWT or from the token request;
    /// on error the previously cached state is left untouched.
    fn get_token(&self) -> Result<Token> {
        // Sampled before the request is sent, so request latency can only move
        // the computed refresh deadline earlier, never later.
        let now = OffsetDateTime::now_utc();
        let jwt_body = self.get_jwt_body(now)?;
        let token = get_token_with_client_and_body(jwt_body, &self.credentials)?;
        let expires_in = Duration::new(token.expires_in().into(), 0);
        // The server controls `expires_in`, so it may be shorter than the
        // refresh buffer. Such a token is still valid right now: instead of
        // panicking, hand it out and clamp the deadline to `now` so the next
        // `fetch_token` call requests a fresh one.
        let refresh_at = (now + (expires_in - self.refresh_buffer)).max(now);
        self.token_state.store(Some(Arc::new(TokenState {
            token: token.clone(),
            refresh_at,
        })));
        Ok(token)
    }

    /// Stamps the JWT claims with `valid_from` and returns the signed JWT string.
    ///
    /// # Panics
    /// Panics if the JWT mutex is poisoned (a previous holder panicked).
    // NOTE(review): the lint is presumably triggered by the size of the crate's
    // `Error` type; confirm whether the allow should cover the other methods too.
    #[allow(clippy::result_large_err)]
    fn get_jwt_body(&self, valid_from: OffsetDateTime) -> Result<String> {
        let mut jwt = self.jwt.lock().unwrap();
        jwt.body_mut()
            .update(Some(valid_from.unix_timestamp()), None);
        let jwt_body = jwt.finalize()?;
        Ok(jwt_body)
    }
}
#[cfg(test)]
mod tests {
    use crate::auth::{JwtClaims, Token};
    use crate::credentials::Credentials;
    use crate::fetcher::TokenFetcher;
    use crate::scopes::Scope;
    use mockito::{self, mock};
    use smpl_jwt::Jwt;
    use std::thread;
    use std::time::Duration as StdDuration;

    /// Builds a JWT and credentials whose token endpoint is the mockito server.
    fn get_mocks() -> (Jwt<JwtClaims>, Credentials) {
        let token_url = mockito::server_url();
        let iss = "some_iss";
        let mut credentials =
            Credentials::from_file("dummy_credentials_file_for_tests.json").unwrap();
        credentials.token_uri = token_url.clone();
        let claims = JwtClaims::new(
            String::from(iss),
            &[Scope::DevStorageReadWrite],
            token_url,
            None,
            None,
        );
        let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
        (jwt, credentials)
    }

    /// Returns the expected `Token` together with the JSON body a token
    /// endpoint would respond with for it.
    fn token_json(access_token: &str, token_type: &str, expires_in: u32) -> (Token, String) {
        let json = serde_json::json!({
            "access_token": access_token,
            "token_type": token_type,
            "expires_in": expires_in
        });
        let token = serde_json::from_value(json.clone()).unwrap();
        (token, json.to_string())
    }

    #[test]
    fn basic_token_fetch() {
        let (jwt, credentials) = get_mocks();
        let refresh_buffer = 0;
        let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
        let (expected_token, json) = token_json("token", "Bearer", 1);
        let _mock = mock("POST", "/").with_status(200).with_body(json).create();
        let token = fetcher.fetch_token().unwrap();
        assert_eq!(expected_token, token);
    }

    #[test]
    fn basic_token_refresh() {
        let (jwt, credentials) = get_mocks();
        let refresh_buffer = 0;
        let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
        let expires_in = 1;
        let (_expected_token, json) = token_json("token", "Bearer", expires_in);
        let mock = mock("POST", "/")
            .with_status(200)
            .with_body(json)
            .expect(2)
            .create();
        fetcher.fetch_token().unwrap();
        thread::sleep(StdDuration::from_secs(expires_in.into()));
        fetcher.fetch_token().unwrap();
        mock.assert();
    }

    #[test]
    fn token_refresh_with_buffer() {
        let (jwt, credentials) = get_mocks();
        let refresh_buffer: u32 = 4;
        let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer.into());
        let expires_in: u32 = 5;
        let (_expected_token, json) = token_json("token", "Bearer", expires_in);
        let mock = mock("POST", "/")
            .with_status(200)
            .with_body(json)
            .expect(2)
            .create();
        fetcher.fetch_token().unwrap();
        // Waiting `expires_in - refresh_buffer` seconds reaches the refresh
        // deadline even though the token itself has not expired yet.
        let sleep_for = expires_in - refresh_buffer;
        thread::sleep(StdDuration::from_secs(sleep_for.into()));
        fetcher.fetch_token().unwrap();
        mock.assert();
    }

    #[test]
    fn doesnt_token_refresh_unnecessarily() {
        let (jwt, credentials) = get_mocks();
        let refresh_buffer = 0;
        let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
        // Long enough that the back-to-back fetches below cannot cross the
        // refresh deadline on a slow runner (a 1s lifetime made this flaky).
        let expires_in = 60;
        let (_expected_token, json) = token_json("token", "Bearer", expires_in);
        let mock = mock("POST", "/")
            .with_status(200)
            .with_body(json)
            .expect(1)
            .create();
        fetcher.fetch_token().unwrap();
        fetcher.fetch_token().unwrap();
        mock.assert();
    }

    #[test]
    fn is_send_and_sync() {
        let (jwt, credentials) = get_mocks();
        let refresh_buffer = 0;
        let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
        fn check(_: &(dyn Send + Sync)) {}
        check(&fetcher);
    }
}