use std::{path::Path, time::Duration};
use async_mutex::Mutex as AsyncMutex;
use async_trait::async_trait;
use miette::Diagnostic;
use serde_json::json;
use thiserror::Error;
use tracing::debug;
use url::Url;
use super::service_token::ServiceToken;
use crate::credentials::{
token_store::TokenStore, AutoRefreshable, ClearTokenError, Credentials, GetTokenError,
TokenExpiry,
};
use crate::reqwest_client::create_client;
use crate::user_agent::get_user_agent;
/// Credentials that exchange a service access key for a [`ServiceToken`] at the CTS
/// `/api/authorise` endpoint and cache the resulting token in a [`TokenStore`].
pub struct ServiceAccessKeyCredentials {
    /// Access key sent as `accessKey` in the authorise request body.
    access_key: String,
    /// Optional audience sent as `audience` in the authorise request body.
    audience: Option<String>,
    /// Base URL of the CTS; the authorise endpoint is resolved against it.
    cts_base_url: Url,
    /// Token cache. Every read, refresh and clear goes through this async mutex.
    token_store: AsyncMutex<TokenStore<ServiceToken>>,
    /// HTTP client used for requests to the CTS.
    client: reqwest_middleware::ClientWithMiddleware,
}
#[derive(Diagnostic, Error, Debug)]
pub enum AcquireTokenError {
#[error("Failed to acquire token: {0}")]
RequestFailed(Box<dyn std::error::Error + Sync + Send>),
#[error("Failed to parse json response: {0}")]
BadResponse(Box<dyn std::error::Error + Sync + Send>),
}
impl ServiceAccessKeyCredentials {
    /// Creates credentials that authorise `access_key` against the CTS at `cts_base_url`.
    ///
    /// The token is cached in a [`TokenStore`] at `token_path`. `audience`, when given, is
    /// forwarded with every authorise request.
    pub fn new(
        token_path: &Path,
        access_key: &str,
        cts_base_url: &Url,
        audience: Option<&str>,
    ) -> Self {
        Self {
            access_key: access_key.to_string(),
            audience: audience.map(|s| s.to_string()),
            cts_base_url: cts_base_url.to_owned(),
            token_store: AsyncMutex::new(TokenStore::new(token_path)),
            client: create_client(),
        }
    }

    /// POSTs the access key and audience to the CTS `/api/authorise` endpoint and decodes
    /// the JSON response into a [`ServiceToken`].
    ///
    /// The body is `{ "accessKey": ..., "audience": ... }`. `audience` is serialised as JSON
    /// `null` when none was configured.
    ///
    /// # Errors
    ///
    /// * [`AcquireTokenError::RequestFailed`] if the endpoint URL cannot be built from
    ///   `cts_base_url`, the request cannot be sent, or the response status is not a success.
    /// * [`AcquireTokenError::BadResponse`] if the response body cannot be decoded as a token.
    async fn authorise(&self) -> Result<ServiceToken, AcquireTokenError> {
        debug!(target: "service_access_key_credentials", "Authorising Access Token with CTS");
        // The leading `/` makes this an absolute path, so any path on `cts_base_url` is
        // replaced. Joining can still fail (e.g. for a cannot-be-a-base URL such as
        // `mailto:`). That failure comes from configuration, so report it rather than panic.
        let url = self
            .cts_base_url
            .join("/api/authorise")
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?;
        let token: ServiceToken = self
            .client
            .post(url)
            .json(&json!({ "accessKey": self.access_key, "audience": self.audience }))
            .header("user-agent", get_user_agent())
            .send()
            .await
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?
            // Turn 4xx/5xx into errors instead of trying to decode an error body as a token.
            .error_for_status()
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?
            .json()
            .await
            .map_err(|e| AcquireTokenError::BadResponse(Box::new(e)))?;
        debug!(
            target: "service_access_key_credentials",
            "Access Token Acquired - expiry(epoch seconds): {}",
            &token.expiry
        );
        Ok(token)
    }
}
#[async_trait]
impl Credentials for ServiceAccessKeyCredentials {
    type Token = ServiceToken;

    /// Returns the cached token if it has not expired. Otherwise it acquires a new token
    /// from the CTS and persists it.
    ///
    /// The store lock is held across the whole check-then-refresh sequence. Concurrent
    /// callers therefore queue behind a single `authorise` request and then reuse the token
    /// it stored, instead of each hitting the CTS.
    ///
    /// # Errors
    ///
    /// * [`GetTokenError::AcquireNewTokenFailed`] if the CTS authorise request fails.
    /// * [`GetTokenError::PersistTokenError`] if the new token cannot be written to the store.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        debug!(target: "service_access_key_credentials", "getting token (waiting for lock)");
        // One critical section covers both the cheap check and the refresh. A separate
        // lock/unlock pre-check takes the same mutex, so it would buy no concurrency and
        // would only read the store twice.
        let mut token_store = self.token_store.lock().await;
        debug!(target: "service_access_key_credentials", "getting token (got lock)");

        if let Some(cached_token) = token_store.get() {
            if !cached_token.is_expired() {
                debug!(target: "service_access_key_credentials", "using cached token");
                return Ok(cached_token);
            }
            debug!(target: "service_access_key_credentials", "cached token is expired");
        }

        debug!(target: "service_access_key_credentials", "fetching new token from CTS");
        let new_token = self
            .authorise()
            .await
            .map_err(|e| GetTokenError::AcquireNewTokenFailed(Box::new(e)))?;
        token_store
            .set(&new_token)
            .map_err(|e| GetTokenError::PersistTokenError(Box::new(e)))?;
        Ok(new_token)
    }

    /// Removes the cached token from the store. It takes the store lock, so it cannot
    /// interleave with an in-progress refresh.
    ///
    /// # Errors
    ///
    /// Returns [`ClearTokenError`] wrapping the store's error if clearing fails.
    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        debug!(target: "service_access_key_credentials", "clearing token");
        let mut token_store = self.token_store.lock().await;
        token_store
            .clear()
            .map_err(|e| ClearTokenError(Box::new(e)))
    }
}
#[async_trait]
impl AutoRefreshable for ServiceAccessKeyCredentials {
    /// Refreshes the cached token if it is missing or `should_refresh()` says so, and
    /// returns how long to wait before the next call.
    ///
    /// The wait is the token's `refresh_interval()` when a fresh token is already cached,
    /// or when a new one was acquired and persisted. If authorising or persisting fails,
    /// the failure is logged at `warn` and `min_refresh_interval()` is returned. A token
    /// that was fetched but not persisted is discarded.
    async fn refresh(&self) -> Duration {
        // Held for the whole check-then-refresh sequence. A separate pre-check would take
        // the same mutex and only read the store a second time.
        let mut token_store = self.token_store.lock().await;

        if let Some(cached_token) = token_store.get() {
            if !cached_token.should_refresh() {
                debug!(target: "service_access_key_credentials", "Access token is still new");
                return cached_token.refresh_interval();
            }
        }

        debug!(target: "service_access_key_credentials", "Access token is missing or close to expiry, refreshing");
        match self.authorise().await {
            Ok(new_token) => match token_store.set(&new_token) {
                Ok(_) => {
                    debug!(target: "service_access_key_credentials", "Access token refreshed and saved to disk");
                    return new_token.refresh_interval();
                }
                Err(err) => {
                    tracing::warn!(
                        target: "service_access_key_credentials",
                        error = %err,
                        "Failed to persist refreshed token"
                    );
                }
            },
            Err(err) => {
                tracing::warn!(
                    target: "service_access_key_credentials",
                    error = %err,
                    "Failed to refresh access key token"
                );
            }
        }
        Self::Token::min_refresh_interval()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::test_utils::CountingState;
    use std::sync::Arc;

    /// Mock `/api/authorise` handler. It records entry and exit in the shared counter
    /// around a 100 ms delay, so overlapping requests show up as peak concurrency > 1.
    /// It then returns a token that never expires in practice.
    async fn slow_authorise(
        axum::extract::State(counter): axum::extract::State<CountingState>,
    ) -> axum::Json<serde_json::Value> {
        counter.enter();
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        counter.exit();
        let body = serde_json::json!({
            "accessToken": "test-token",
            "expiry": 9999999999u64
        });
        axum::Json(body)
    }

    /// Five concurrent `get_token` calls against an empty store must trigger exactly one
    /// authorise request, and never more than one in flight at a time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_service_access_key_get_token_serializes_authorise() {
        let counter = CountingState::new();
        let router = axum::Router::new()
            .route("/api/authorise", axum::routing::post(slow_authorise))
            .with_state(counter.clone());

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        tokio::spawn(async move { axum::serve(listener, router).await.unwrap() });

        let tmp_dir = tempfile::TempDir::new().unwrap();
        let cts_url = Url::parse(&format!("http://{server_addr}")).unwrap();
        let shared = Arc::new(ServiceAccessKeyCredentials::new(
            &tmp_dir.path().join("token.json"),
            "test-access-key",
            &cts_url,
            None,
        ));

        let tasks: Vec<_> = (0..5)
            .map(|_| {
                let creds = Arc::clone(&shared);
                tokio::spawn(async move { creds.get_token().await.unwrap() })
            })
            .collect();
        for task in tasks {
            task.await.unwrap();
        }

        let (peak, total) = (counter.peak(), counter.total());
        assert_eq!(
            peak, 1,
            "Expected serialized authorise but peak concurrency was {peak}. \
             Concurrent authorise calls waste resources and race with clear_token().",
        );
        assert_eq!(
            total, 1,
            "Expected exactly 1 authorise request but got {total}. \
             Double-check pattern should let waiters use the refreshed token.",
        );
    }
}