use std::{path::Path, time::Duration};
use async_mutex::Mutex as AsyncMutex;
use async_trait::async_trait;
use cts_common::WorkspaceId;
use miette::Diagnostic;
use serde_json::json;
use thiserror::Error;
use tracing::debug;
use url::Url;
use super::service_token::ServiceToken;
use crate::credentials::{
token_store::TokenStore, user_credentials::UserToken, AutoRefreshable, ClearTokenError,
Credentials, GetTokenError, TokenExpiry,
};
use crate::reqwest_client::create_client;
use crate::user_agent::get_user_agent;
/// Credentials that exchange ("federate") a user token obtained from `C` for a
/// workspace-scoped [`ServiceToken`], keeping the result in a [`TokenStore`].
pub struct ServiceUserCredentials<C>
where
    C: Credentials<Token = UserToken>,
{
    /// Source of the user token presented to the federation endpoint.
    user_credentials: C,
    /// Base URL of CTS; the federation endpoint is resolved against it.
    cts_base_url: Url,
    /// Workspace the service token is requested for.
    workspace_id: WorkspaceId,
    /// Cached service token. An async mutex because the token-refresh paths hold the
    /// guard across the federation request so that refreshes are serialized.
    token_store: AsyncMutex<TokenStore<ServiceToken>>,
    /// HTTP client used for the federation request.
    client: reqwest_middleware::ClientWithMiddleware,
}
/// Failures that can occur while federating a user token into a service token.
#[derive(Diagnostic, Error, Debug)]
pub enum AcquireTokenError {
    /// The wrapped user credentials could not produce a user token.
    #[error("Failed to acquire console token: {0}")]
    GetTokenError(#[from] GetTokenError),
    /// The federation request failed (e.g. it could not be sent, or the server
    /// answered with an error status).
    #[error("Failed to acquire token: {0}")]
    RequestFailed(Box<dyn std::error::Error + Send + Sync>),
    /// The federation response body could not be decoded as a token.
    #[error("Failed to parse json response: {0}")]
    BadResponse(Box<dyn std::error::Error + Send + Sync>),
}
impl<C: Credentials<Token = UserToken>> ServiceUserCredentials<C> {
    /// Creates service-user credentials backed by `user_credentials`.
    ///
    /// * `token_path` - passed to [`TokenStore::new`]; where service tokens are cached.
    /// * `user_credentials` - supplies the user token exchanged during federation.
    /// * `cts_base_url` - CTS base URL; `/api/federate` is resolved against it.
    /// * `workspace_id` - workspace id sent with every federation request.
    pub fn new(
        token_path: &Path,
        user_credentials: C,
        cts_base_url: &Url,
        workspace_id: WorkspaceId,
    ) -> Self {
        debug!(target: "service_user_credentials", token_path = %token_path.display(), "Creating credentials");
        Self {
            user_credentials,
            cts_base_url: cts_base_url.to_owned(),
            workspace_id,
            token_store: AsyncMutex::new(TokenStore::new(token_path)),
            client: create_client(),
        }
    }

    /// Exchanges the current user token for a fresh [`ServiceToken`] via `POST /api/federate`.
    ///
    /// The JSON body carries the user's access token and the workspace id. The user token
    /// is also sent as the `authorization` header.
    ///
    /// # Errors
    /// * [`AcquireTokenError::GetTokenError`] if the user token cannot be obtained.
    /// * [`AcquireTokenError::RequestFailed`] if the endpoint URL cannot be built, the
    ///   request cannot be sent, or the server answers with an error status.
    /// * [`AcquireTokenError::BadResponse`] if the response body cannot be decoded.
    async fn federate_token(&self) -> Result<ServiceToken, AcquireTokenError> {
        // The leading `/` makes this an absolute path, so any path already present on
        // `cts_base_url` is replaced rather than extended.
        // NOTE(review): confirm CTS is never served under a path prefix.
        // `join` can fail (e.g. when the base is a cannot-be-a-base URL such as `mailto:`);
        // report that as an error instead of panicking inside library code.
        let url = self
            .cts_base_url
            .join("/api/federate")
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?;
        let user_token = self.user_credentials.get_token().await?;
        debug!(target: "service_user_credentials", cts_url = %url, "Federating OIDC token for service token");
        // `json!` serializes each value by reference, so the workspace id needs no clone.
        let body = json!({
            "accessToken": user_token.access_token(),
            "workspaceId": self.workspace_id,
        });
        let token: ServiceToken = self
            .client
            .post(url)
            .json(&body)
            .header("authorization", user_token.as_header())
            .header("user-agent", get_user_agent())
            .send()
            .await
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?
            .error_for_status()
            .map_err(|e| AcquireTokenError::RequestFailed(Box::new(e)))?
            .json()
            .await
            .map_err(|e| AcquireTokenError::BadResponse(Box::new(e)))?;
        debug!(target: "service_user_credentials",
            expiry = %token.expiry,
            "Service Token Acquired",
        );
        Ok(token)
    }
}
#[async_trait]
impl<C: Credentials<Token = UserToken>> Credentials for ServiceUserCredentials<C> {
    type Token = ServiceToken;

    /// Returns a non-expired service token, federating a new one when the cache is empty
    /// or holds an expired token.
    ///
    /// The cache is first peeked under a short-lived lock. On a miss, the lock is taken
    /// again and held for the rest of the call. Only one caller federates at a time;
    /// callers queued behind it re-check the cache and reuse the token it stored.
    ///
    /// # Errors
    /// * [`GetTokenError::AcquireNewTokenFailed`] if federation fails.
    /// * [`GetTokenError::PersistTokenError`] if the new token cannot be stored.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        debug!(target: "service_user_credentials", "getting OIDC token (waiting for lock)");

        // Peek at the cache; the guard is dropped at the end of this statement.
        let peeked = self.token_store.lock().await.get();
        match peeked {
            Some(token) if !token.is_expired() => {
                debug!(target: "service_user_credentials", "using cached OIDC token");
                return Ok(token);
            }
            Some(_) => {
                debug!(target: "service_user_credentials", "cached OIDC token is expired");
            }
            None => {}
        }

        // From here on the guard is held, so concurrent callers wait for this refresh.
        let mut store = self.token_store.lock().await;
        if let Some(token) = store.get().filter(|t| !t.is_expired()) {
            debug!(target: "service_user_credentials", "token already refreshed by another caller");
            return Ok(token);
        }

        debug!(target: "service_user_credentials", "fetching new OIDC token");
        let fresh = match self.federate_token().await {
            Ok(token) => token,
            Err(e) => return Err(GetTokenError::AcquireNewTokenFailed(Box::new(e))),
        };
        if let Err(e) = store.set(&fresh) {
            return Err(GetTokenError::PersistTokenError(Box::new(e)));
        }
        Ok(fresh)
    }

    /// Removes any cached service token from the token store.
    ///
    /// # Errors
    /// Returns [`ClearTokenError`] wrapping the store's failure.
    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        let mut store = self.token_store.lock().await;
        match store.clear() {
            Ok(cleared) => Ok(cleared),
            Err(e) => Err(ClearTokenError(Box::new(e))),
        }
    }
}
#[async_trait]
impl<C: Credentials<Token = UserToken>> AutoRefreshable for ServiceUserCredentials<C> {
    /// Federates a new service token when the cached one is missing or `should_refresh()`
    /// reports it is due. Returns the delay reported by the resulting token's
    /// `refresh_interval()`, presumably how long the caller waits before the next call.
    ///
    /// Never fails: federation or persistence errors are logged as warnings, and
    /// `min_refresh_interval()` is returned so the caller can retry soon.
    async fn refresh(&self) -> Duration {
        // Peek at the cache; the guard is dropped at the end of this statement.
        let peeked = self.token_store.lock().await.get();
        if let Some(token) = peeked.filter(|t| !t.should_refresh()) {
            debug!(target: "service_user_credentials", "Access token is still new");
            return token.refresh_interval();
        }

        // Hold the guard through federation so concurrent refreshes are serialized.
        let mut store = self.token_store.lock().await;
        if let Some(token) = store.get().filter(|t| !t.should_refresh()) {
            debug!(target: "service_user_credentials", "Access token already refreshed by another caller");
            return token.refresh_interval();
        }

        debug!(target: "service_user_credentials", "Access token is missing or close to expiry, refreshing");
        let fresh = match self.federate_token().await {
            Ok(token) => token,
            Err(err) => {
                tracing::warn!(
                    target: "service_user_credentials",
                    error = %err,
                    "Failed to refresh service token"
                );
                return Self::Token::min_refresh_interval();
            }
        };

        match store.set(&fresh) {
            Ok(_) => {
                debug!(target: "service_user_credentials", "Access token refreshed and saved to cache");
                fresh.refresh_interval()
            }
            Err(err) => {
                tracing::warn!(
                    target: "service_user_credentials",
                    error = %err,
                    "Failed to persist refreshed token"
                );
                Self::Token::min_refresh_interval()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::test_utils::CountingState;
    use crate::credentials::StaticCredentials;
    use std::sync::Arc;

    /// Mock `/api/federate` handler. It records how many requests are in flight and lingers
    /// for 100ms so that overlapping federation calls show up in the counters.
    async fn slow_federate(
        axum::extract::State(state): axum::extract::State<CountingState>,
    ) -> axum::Json<serde_json::Value> {
        state.enter();
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        state.exit();
        let body = serde_json::json!({
            "accessToken": "test-token",
            "expiry": 9999999999u64
        });
        axum::Json(body)
    }

    /// Concurrent `get_token` callers on a cold cache should trigger exactly one
    /// federation request, and never two at the same time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_service_user_get_token_serializes_federate() {
        // Local mock CTS server on an ephemeral port.
        let state = CountingState::new();
        let stats = state.clone();
        let router = axum::Router::new()
            .route("/api/federate", axum::routing::post(slow_federate))
            .with_state(state);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, router).await.unwrap();
        });

        // Credentials under test, using a throwaway token file.
        let tmp = tempfile::TempDir::new().unwrap();
        let base_url = Url::parse(&format!("http://{addr}")).unwrap();
        let workspace_id = cts_common::WorkspaceId::generate().unwrap();
        let user_creds = StaticCredentials::new(UserToken::new_from_raw(
            "test-refresh",
            "test-access",
            u64::MAX,
        ));
        let creds = Arc::new(ServiceUserCredentials::new(
            &tmp.path().join("token.json"),
            user_creds,
            &base_url,
            workspace_id,
        ));

        // Spawn every caller before awaiting any of them so they actually overlap.
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let creds = Arc::clone(&creds);
                tokio::spawn(async move { creds.get_token().await.unwrap() })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        let (peak, total) = (stats.peak(), stats.total());
        assert_eq!(
            peak, 1,
            "Expected serialized federate but peak concurrency was {peak}. \
             Concurrent federate calls waste resources and race with clear_token().",
        );
        assert_eq!(
            total, 1,
            "Expected exactly 1 federate request but got {total}. \
             Double-check pattern should let waiters use the refreshed token.",
        );
    }
}