use std::future::Future;
use std::sync::Arc;
use cts_common::WorkspaceId;
use url::Url;
use web_time::{SystemTime, UNIX_EPOCH};
use crate::authorize_dto::AuthoriseResponse;
use crate::refresher::Refresher;
use crate::{http_client, AuthError, SecretToken, Token};
/// Source of OIDC tokens that [`OidcRefresher`] exchanges for an access token
/// via the `api/authorise` endpoint.
///
/// `fetch` is awaited on every federation attempt, so implementations should hand
/// back a currently-valid token rather than one cached past its lifetime.
///
/// On native targets the provider must be `Send + Sync` and the returned future
/// `Send`, so the provider can be shared and polled across threads.
///
/// # Errors
///
/// Returns an [`AuthError`] when no token can be obtained; [`OidcRefresher`]
/// propagates it to its caller unchanged.
#[cfg(not(target_arch = "wasm32"))]
pub trait OidcProvider: Send + Sync {
    /// Obtains an OIDC token to present to the authorisation endpoint.
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> + Send;
}
/// Source of OIDC tokens that [`OidcRefresher`] exchanges for an access token
/// via the `api/authorise` endpoint (`wasm32` variant).
///
/// Identical to the native trait except that neither the provider nor the future
/// returned by `fetch` is required to be `Send`/`Sync`.
///
/// # Errors
///
/// Returns an [`AuthError`] when no token can be obtained; [`OidcRefresher`]
/// propagates it to its caller unchanged.
#[cfg(target_arch = "wasm32")]
pub trait OidcProvider {
    /// Obtains an OIDC token to present to the authorisation endpoint.
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>>;
}
/// Adapter that lets a plain closure act as an [`OidcProvider`].
///
/// The closure takes no arguments and returns a future resolving to
/// `Result<SecretToken, AuthError>`. It is invoked once per `fetch` call.
pub struct OidcProviderFn<F> {
    /// Closure producing the token future.
    fetch: F,
}
impl<F> OidcProviderFn<F> {
pub fn new(fetch: F) -> Self {
Self { fetch }
}
}
#[cfg(not(target_arch = "wasm32"))]
impl<F, Fut> OidcProvider for OidcProviderFn<F>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Result<SecretToken, AuthError>> + Send,
{
    /// Calls the wrapped closure and hands back its future untouched.
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> + Send {
        let produce = &self.fetch;
        produce()
    }
}
#[cfg(target_arch = "wasm32")]
impl<F, Fut> OidcProvider for OidcProviderFn<F>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<SecretToken, AuthError>>,
{
    /// Calls the wrapped closure and hands back its future untouched
    /// (no `Send` requirement on `wasm32`).
    fn fetch(&self) -> impl Future<Output = Result<SecretToken, AuthError>> {
        let produce = &self.fetch;
        produce()
    }
}
/// [`Refresher`] that obtains access tokens by federating an OIDC token.
///
/// Each refresh fetches a token from `oidc_provider` and POSTs it, together with
/// `workspace_id`, to `api/authorise` resolved against `base_url`.
pub(crate) struct OidcRefresher<P> {
    /// Supplies the OIDC token presented on every federation request.
    oidc_provider: P,
    /// Workspace the access token is requested for.
    workspace_id: WorkspaceId,
    /// Service base URL; `api/authorise` is resolved against it with `Url::join`.
    /// NOTE: under `Url::join` rules, a base path without a trailing `/` has its
    /// last segment replaced rather than extended.
    base_url: Url,
    /// HTTP client used for the federation request.
    http_client: Arc<reqwest::Client>,
}
impl<P> OidcRefresher<P> {
    /// Creates a refresher that federates tokens from `oidc_provider` for
    /// `workspace_id` against the service rooted at `base_url`.
    ///
    /// Each refresher builds its own HTTP client via the crate's `http_client()`.
    pub(crate) fn new(oidc_provider: P, workspace_id: WorkspaceId, base_url: Url) -> Self {
        let client = Arc::new(http_client());
        Self {
            http_client: client,
            oidc_provider,
            workspace_id,
            base_url,
        }
    }
}
impl<P: OidcProvider> Refresher for OidcRefresher<P> {
    /// No stored credential is needed: every refresh starts from a freshly
    /// fetched OIDC token.
    type Credential = ();

    /// No-op: there is no refresh token or other credential to persist.
    fn save(&self, _token: &Token) {}

    /// Always `Some(())`: federation can be attempted regardless of the current
    /// token, because the provider supplies a new OIDC token each time.
    fn try_credential(&self, _token: Option<&mut Token>) -> Option<Self::Credential> {
        Some(())
    }

    /// No-op: `try_credential` takes nothing out of the token, so nothing is restored.
    fn restore(&self, _token: &mut Token, _credential: Self::Credential) {}

    /// Fetches a current OIDC token and exchanges it at `api/authorise` for an
    /// access token.
    ///
    /// The returned [`Token`] is a `Bearer` token with no refresh token. Its
    /// `expires_at` is the current Unix time in seconds plus the server's `expiry`.
    ///
    /// # Errors
    ///
    /// - any error returned by the OIDC provider, unchanged;
    /// - an error if `api/authorise` cannot be joined onto `base_url`;
    /// - transport or JSON-decoding errors from the HTTP client;
    /// - `AuthError::Server("{status}: {body}")` for a non-success response.
    async fn refresh(&self, _credential: &Self::Credential) -> Result<Token, AuthError> {
        let oidc_token = self.oidc_provider.fetch().await?;
        let url = self.base_url.join("api/authorise")?;
        tracing::debug!(url = %url, "federating OIDC token");

        let resp = self
            .http_client
            .post(url)
            .json(&OidcAuthoriseRequest {
                oidc_token: oidc_token.as_str(),
                workspace_id: self.workspace_id.as_str(),
            })
            .send()
            .await?;

        let status = resp.status();
        if !status.is_success() {
            // Best-effort body read: an unreadable body must not hide the status code.
            let body = resp.text().await.unwrap_or_default();
            tracing::debug!(%status, %body, "OIDC federation failed");
            return Err(AuthError::Server(format!("{status}: {body}")));
        }

        let auth_resp: AuthoriseResponse = resp.json().await?;

        // A clock reading before the Unix epoch falls back to 0 instead of failing
        // the refresh.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        Ok(Token {
            access_token: auth_resp.access_token,
            token_type: "Bearer".to_string(),
            // `expiry` is server-controlled. Saturate so an absurdly large value cannot
            // overflow (a panic in debug builds, a wrap into the past in release).
            expires_at: now.saturating_add(auth_resp.expiry),
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        })
    }
}
/// JSON body POSTed to `api/authorise` when federating an OIDC token.
///
/// Serialized with camelCase keys: `{"oidcToken": ..., "workspaceId": ...}`.
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct OidcAuthoriseRequest<'a> {
    /// OIDC token obtained from the provider.
    oidc_token: &'a str,
    /// Workspace the access token is requested for.
    workspace_id: &'a str,
}
// Tests for `OidcRefresher`, driven through `AutoRefresh` against a mocktail HTTP server.
// `unwrap` is allowed here because a panic is simply a failing test.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    use mocktail::prelude::*;

    use super::*;
    use crate::auto_refresh::{AutoRefresh, AutoRefreshError};
    use crate::TokenStore;

    /// Workspace ID sent in every federation request made by these tests.
    const WORKSPACE_ID: &str = "ZVATKW3VHMFG27DY";

    /// Parses `WORKSPACE_ID` into a `WorkspaceId`.
    fn workspace_id() -> WorkspaceId {
        WORKSPACE_ID.parse().unwrap()
    }

    /// Builds a successful `/api/authorise` response body with the given access
    /// token and expiry.
    fn auth_response_json(access: &str, expiry: u64) -> serde_json::Value {
        serde_json::json!({ "accessToken": access, "expiry": expiry })
    }

    /// Starts a mock HTTP server serving `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("oidc-refresher-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Returns a provider yielding `jwt-0`, `jwt-1`, ... together with a shared
    /// counter of how many times it has been invoked.
    fn counting_provider() -> (Arc<AtomicUsize>, impl OidcProvider) {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);
        let provider = OidcProviderFn::new(move || {
            // Each call clones its own handle so the returned future can own it.
            let calls = Arc::clone(&calls_clone);
            async move {
                let n = calls.fetch_add(1, Ordering::SeqCst);
                Ok(SecretToken::new(format!("jwt-{n}")))
            }
        });
        (calls, provider)
    }

    /// Builds an `AutoRefresh` over an `OidcRefresher` pointed at `server`,
    /// backed by `crate::NoStore`.
    fn make_strategy<P: OidcProvider>(
        server: &MockServer,
        provider: P,
    ) -> AutoRefresh<OidcRefresher<P>> {
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        AutoRefresh::with_store(refresher, crate::NoStore)
    }

    /// Builds a bearer token named `access` that expires `expires_in_secs`
    /// seconds from now.
    fn make_token(access: &str, expires_in_secs: u64) -> Token {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in_secs,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// First `get_token` with no cached token federates exactly once.
    #[tokio::test]
    async fn test_initial_federation() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("cts-token", 3600));
        });
        let server = start_server(mocks).await;
        let (calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);
        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "cts-token");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "initial federation should invoke the OIDC provider once"
        );
    }

    /// The request body uses camelCase keys and carries only the two fields.
    #[test]
    fn test_request_serialization() {
        let body = serde_json::to_value(OidcAuthoriseRequest {
            oidc_token: "the-jwt",
            workspace_id: WORKSPACE_ID,
        })
        .unwrap();
        assert_eq!(
            body,
            serde_json::json!({ "oidcToken": "the-jwt", "workspaceId": WORKSPACE_ID }),
            "request body should carry exactly the OIDC token and workspace ID"
        );
    }

    /// A still-valid token is served from cache without another federation.
    #[tokio::test]
    async fn test_caches_token_after_initial_federation() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("cts-token", 3600));
        });
        let server = start_server(mocks).await;
        let (calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);
        assert_eq!(strategy.get_token().await.unwrap().as_str(), "cts-token");
        // Swap in a failing endpoint: any re-federation would now surface as an error.
        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "should not be called"}));
        });
        assert_eq!(strategy.get_token().await.unwrap().as_str(), "cts-token");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "cached token should be returned without re-federating"
        );
    }

    /// An expired stored token triggers a fresh federation.
    #[tokio::test]
    async fn test_re_federates_on_expiry() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("re-federated-token", 3600));
        });
        let server = start_server(mocks).await;
        let (calls, provider) = counting_provider();
        let store = Arc::new(crate::InMemoryTokenStore::new());
        // Seed the store with a token whose expiry is "now".
        store.save(&make_token("stale-cts-token", 0)).await;
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        let strategy = AutoRefresh::with_store(refresher, Arc::clone(&store));
        let token = strategy.get_token().await.unwrap();
        assert_eq!(
            token.as_str(),
            "re-federated-token",
            "expired cached token should trigger re-federation"
        );
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "re-federation should invoke the OIDC provider for a current JWT"
        );
    }

    /// An error from the OIDC provider is returned even though the endpoint would succeed.
    #[tokio::test]
    async fn test_oidc_provider_failure_propagates() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("unreachable", 3600));
        });
        let server = start_server(mocks).await;
        let provider = OidcProviderFn::new(|| async {
            Err::<SecretToken, _>(AuthError::Server("provider exploded".to_string()))
        });
        let strategy = make_strategy(&server, provider);
        let err = strategy.get_token().await.unwrap_err();
        assert!(
            matches!(err, AutoRefreshError::Auth(AuthError::Server(_))),
            "OIDC provider failure should surface as an auth error, got: {err:?}"
        );
    }

    /// A non-success response from `/api/authorise` becomes `AuthError::Server`.
    #[tokio::test]
    async fn test_server_rejection_propagates() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "workspace mismatch"}));
        });
        let server = start_server(mocks).await;
        let (_calls, provider) = counting_provider();
        let strategy = make_strategy(&server, provider);
        let err = strategy.get_token().await.unwrap_err();
        assert!(
            matches!(err, AutoRefreshError::Auth(AuthError::Server(_))),
            "a 500 from /api/authorise should surface as a server error, got: {err:?}"
        );
    }

    /// A fresh stored token is used on cold start. The endpoint always fails, so
    /// success proves no HTTP federation happened.
    #[tokio::test]
    async fn test_loads_token_from_store_on_cold_start_no_http() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "should not be called"}));
        });
        let server = start_server(mocks).await;
        let store = Arc::new(crate::InMemoryTokenStore::new());
        store.save(&make_token("from-store", 3600)).await;
        let (calls, provider) = counting_provider();
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        let strategy = AutoRefresh::with_store(refresher, Arc::clone(&store));
        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "from-store");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            0,
            "a fresh cached token should be used without invoking the OIDC provider"
        );
    }

    /// A newly federated token is written back to the store.
    #[tokio::test]
    async fn test_persists_token_to_store_after_federation() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(auth_response_json("freshly-federated", 3600));
        });
        let server = start_server(mocks).await;
        let store = Arc::new(crate::InMemoryTokenStore::new());
        let (_calls, provider) = counting_provider();
        let refresher = OidcRefresher::new(provider, workspace_id(), server.url(""));
        let strategy = AutoRefresh::with_store(refresher, Arc::clone(&store));
        let token = strategy.get_token().await.unwrap();
        assert_eq!(token.as_str(), "freshly-federated");
        let saved = store
            .load()
            .await
            .expect("store should hold a token after federation");
        assert_eq!(saved.access_token().as_str(), "freshly-federated");
    }
}