use cts_common::{CtsServiceDiscovery, Region, ServiceDiscovery, WorkspaceId};
use crate::auto_refresh::AutoRefresh;
use crate::oidc_refresher::{OidcProvider, OidcRefresher};
use crate::token_store::{NoStore, TokenStore};
use crate::{ensure_trailing_slash, AuthError, AuthStrategy, ServiceToken};
/// Auth strategy that obtains CTS service tokens through OIDC federation and checks
/// that every returned token belongs to the workspace it was configured for.
///
/// Tokens come from an [`AutoRefresh`] wrapping an [`OidcRefresher`]. `S` is the
/// [`TokenStore`] handed to that wrapper and defaults to [`NoStore`].
pub struct OidcFederationStrategy<P, S = NoStore> {
    /// Token source built from the OIDC provider, CTS base URL and token store.
    inner: AutoRefresh<OidcRefresher<P>, S>,
    /// Workspace every returned token must carry. A token for any other workspace
    /// is rejected with `AuthError::WorkspaceMismatch`.
    expected_workspace: WorkspaceId,
}
impl<P> OidcFederationStrategy<P>
where
    P: OidcProvider,
{
    /// Creates a strategy for `workspace_id` in `region` using default settings.
    ///
    /// The defaults are no token store and no base-URL override. This is shorthand
    /// for `Self::builder(region, workspace_id, oidc_provider).build()`.
    ///
    /// # Errors
    /// Propagates any error returned by [`OidcFederationStrategyBuilder::build`].
    pub fn new(
        region: Region,
        workspace_id: WorkspaceId,
        oidc_provider: P,
    ) -> Result<Self, AuthError> {
        let builder = Self::builder(region, workspace_id, oidc_provider);
        builder.build()
    }

    /// Starts a builder for a strategy bound to `workspace_id` in `region`.
    ///
    /// The builder begins with no base-URL override and [`NoStore`] as its token store.
    /// Use `with_token_store` to persist tokens elsewhere.
    pub fn builder(
        region: Region,
        workspace_id: WorkspaceId,
        oidc_provider: P,
    ) -> OidcFederationStrategyBuilder<P> {
        let base_url_override = None;
        let token_store = NoStore;
        OidcFederationStrategyBuilder {
            region,
            workspace_id,
            oidc_provider,
            base_url_override,
            token_store,
        }
    }
}
impl<P: OidcProvider, S: TokenStore> AuthStrategy for &OidcFederationStrategy<P, S> {
    /// Fetches a token from the inner auto-refreshing source and verifies its workspace.
    ///
    /// The trait is implemented for a shared reference, so calling this only borrows
    /// the strategy.
    ///
    /// # Errors
    /// - Any error from the inner token source, or from reading the token's workspace.
    /// - `AuthError::WorkspaceMismatch` when the token's workspace differs from the one
    ///   the strategy was built for.
    async fn get_token(self) -> Result<ServiceToken, AuthError> {
        let token = self.inner.get_token().await?;
        let token_workspace = *token.workspace_id()?;
        let expected_workspace = self.expected_workspace;

        if token_workspace == expected_workspace {
            Ok(token)
        } else {
            Err(AuthError::WorkspaceMismatch {
                expected_workspace,
                token_workspace,
            })
        }
    }
}
/// Builder for [`OidcFederationStrategy`], obtained from [`OidcFederationStrategy::builder`].
///
/// `S` is the token-store type. It starts as [`NoStore`] and changes when
/// `with_token_store` is called.
pub struct OidcFederationStrategyBuilder<P, S = NoStore> {
    /// Region used to look up the CTS endpoint when no base URL is otherwise supplied.
    region: Region,
    /// Workspace passed to the refresher and enforced on every returned token.
    workspace_id: WorkspaceId,
    /// OIDC provider handed to [`OidcRefresher::new`].
    oidc_provider: P,
    /// Explicit CTS base URL. In this file it is only set by `base_url`, which is
    /// compiled in test builds and with the `test-utils` feature.
    base_url_override: Option<url::Url>,
    /// Token store given to the [`AutoRefresh`] wrapper.
    token_store: S,
}
impl<P, S> OidcFederationStrategyBuilder<P, S> {
    /// Forces the CTS base URL, skipping environment and region-based resolution.
    ///
    /// Only compiled for tests and when the `test-utils` feature is enabled.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn base_url(self, url: url::Url) -> Self {
        Self {
            base_url_override: Some(url),
            ..self
        }
    }

    /// Swaps in `store` as the token store.
    ///
    /// Returns a builder parameterised over the new store type. The previously
    /// configured store is dropped.
    pub fn with_token_store<T: TokenStore>(self, store: T) -> OidcFederationStrategyBuilder<P, T> {
        let Self {
            region,
            workspace_id,
            oidc_provider,
            base_url_override,
            ..
        } = self;
        OidcFederationStrategyBuilder {
            region,
            workspace_id,
            oidc_provider,
            base_url_override,
            token_store: store,
        }
    }
}
impl<P: OidcProvider, S: TokenStore> OidcFederationStrategyBuilder<P, S> {
    /// Consumes the builder and constructs the [`OidcFederationStrategy`].
    ///
    /// The CTS base URL is chosen in this priority order:
    /// 1. the explicit override set via `base_url` (test / `test-utils` builds only);
    /// 2. the URL returned by `crate::cts_base_url_from_env()`, if it yields one;
    /// 3. the service-discovery endpoint for the configured region.
    ///
    /// The chosen URL is passed through `ensure_trailing_slash` before it is given to
    /// the [`OidcRefresher`].
    ///
    /// # Errors
    /// Returns an error in either of these cases:
    /// - reading the environment override fails;
    /// - no override is present and the region's endpoint cannot be resolved.
    pub fn build(self) -> Result<OidcFederationStrategy<P, S>, AuthError> {
        let base_url = match self.base_url_override {
            Some(url) => url,
            // Consult service discovery only when there is no environment override.
            // Evaluating it eagerly (e.g. as an `unwrap_or` argument) would let a
            // discovery failure abort the build even though an override was available.
            None => match crate::cts_base_url_from_env()? {
                Some(url) => url,
                None => CtsServiceDiscovery::endpoint(self.region)?,
            },
        };
        // Keep a copy for the mismatch check; the original moves into the refresher.
        let expected_workspace = self.workspace_id;
        let refresher = OidcRefresher::new(
            self.oidc_provider,
            self.workspace_id,
            ensure_trailing_slash(base_url),
        );
        Ok(OidcFederationStrategy {
            inner: AutoRefresh::with_store(refresher, self.token_store),
            expected_workspace,
        })
    }
}
#[cfg(test)]
// Tests use unwrap/expect/panic freely for setup and assertions.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    use cts_common::Region;
    use mocktail::prelude::*;

    use super::*;
    use crate::oidc_refresher::OidcProviderFn;
    use crate::{InMemoryTokenStore, SecretToken, Token, TokenStore};

    /// Builds a JWT whose `workspace` claim is `workspace`.
    ///
    /// The token is issued now, expires in one hour, and is signed with a fixed test secret.
    fn jwt_with_workspace(workspace: &str) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_secs();
        let claims = serde_json::json!({
            "iss": "https://cts.example.com/",
            "sub": "CS|test-user",
            "aud": "test-audience",
            "iat": now,
            "exp": now + 3600,
            "workspace": workspace,
            "scope": "",
        });
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("JWT encode")
    }

    /// Starts a mock CTS server.
    ///
    /// The server answers `POST /api/authorise` with a JWT for `workspace` and an
    /// `expiry` of 3600.
    async fn start_mock_server_returning_jwt(workspace: &str) -> MockServer {
        let mut mocks = MockSet::new();
        let jwt = jwt_with_workspace(workspace);
        mocks.mock(move |when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({ "accessToken": jwt, "expiry": 3600 }));
        });
        let server =
            MockServer::new_http("oidc-federation-strategy-workspace-test").with_mocks(mocks);
        server.start().await.expect("mock server start");
        server
    }

    /// Region used by every test.
    ///
    /// The tests override the base URL, so the region itself is not used to reach a
    /// server.
    fn test_region() -> Region {
        Region::aws("ap-southeast-2").expect("region parses")
    }

    /// OIDC provider that always yields the same placeholder token string.
    fn provider() -> OidcProviderFn<impl Fn() -> std::future::Ready<Result<SecretToken, AuthError>>>
    {
        OidcProviderFn::new(|| {
            std::future::ready(Ok(SecretToken::new("header.payload.signature".to_string())))
        })
    }

    // Happy path: the server issues a token for the expected workspace, so it is returned.
    #[tokio::test]
    async fn returns_token_when_workspace_matches() {
        const WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(WS).await;
        let strategy =
            OidcFederationStrategy::builder(test_region(), WS.parse().unwrap(), provider())
                .base_url(server.url(""))
                .build()
                .expect("builder");
        let token = (&strategy).get_token().await.expect("get_token");
        assert_eq!(
            token.workspace_id().expect("workspace_id").as_str(),
            WS,
            "happy-path token should carry the expected workspace",
        );
    }

    // The server issues a token for another workspace. The error must report both
    // the expected workspace and the token's workspace.
    #[tokio::test]
    async fn errors_when_token_workspace_differs() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;
        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");
        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch");
        match err {
            AuthError::WorkspaceMismatch {
                expected_workspace,
                token_workspace,
            } => {
                assert_eq!(expected_workspace.as_str(), EXPECTED_WS);
                assert_eq!(token_workspace.as_str(), TOKEN_WS);
            }
            other => panic!("expected WorkspaceMismatch, got {other:?}"),
        }
    }

    // A non-JWT access token must surface as `InvalidToken`, not as a workspace
    // mismatch. Which layer raises it (the refresher or `workspace_id()`) is not
    // visible here.
    #[tokio::test]
    async fn errors_with_invalid_token_when_jwt_malformed() {
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({ "accessToken": "not-a-jwt", "expiry": 3600 }));
        });
        let server =
            MockServer::new_http("oidc-federation-strategy-malformed-test").with_mocks(mocks);
        server.start().await.expect("mock server start");
        let strategy = OidcFederationStrategy::builder(
            test_region(),
            "ZVATKW3VHMFG27DY".parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");
        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected invalid-token error");
        assert!(
            matches!(err, AuthError::InvalidToken(_)),
            "expected InvalidToken, got {err:?}",
        );
    }

    // The workspace check must also apply to tokens served from the token store.
    // The mock server only returns 500, so the mismatch is expected to come from
    // the pre-seeded, unexpired stored token rather than a fresh fetch.
    #[tokio::test]
    async fn rejects_stored_token_for_different_workspace() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "store must satisfy the request"}));
        });
        let server =
            MockServer::new_http("oidc-federation-strategy-store-mismatch-test").with_mocks(mocks);
        server.start().await.expect("mock server start");
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_secs();
        // Stored token for the wrong workspace, valid for another hour.
        let stored = Token {
            access_token: SecretToken::new(jwt_with_workspace(TOKEN_WS)),
            token_type: "Bearer".to_string(),
            expires_at: now + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let store = Arc::new(InMemoryTokenStore::new());
        store.save(&stored).await;
        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .with_token_store(Arc::clone(&store))
        .build()
        .expect("builder");
        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch from stored token");
        assert!(
            matches!(err, AuthError::WorkspaceMismatch { .. }),
            "expected WorkspaceMismatch, got {err:?}",
        );
    }

    // A mismatching token must be rejected on every call. A token cached by the
    // first call must not be returned as a success on the second.
    #[tokio::test]
    async fn errors_on_each_subsequent_get_token_call() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const EXPECTED_WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;
        let strategy = OidcFederationStrategy::builder(
            test_region(),
            EXPECTED_WS.parse().unwrap(),
            provider(),
        )
        .base_url(server.url(""))
        .build()
        .expect("builder");
        for call in 1..=2 {
            let err = match (&strategy).get_token().await {
                Ok(_) => panic!("call {call}: expected Err, got Ok"),
                Err(e) => e,
            };
            assert!(
                matches!(err, AuthError::WorkspaceMismatch { .. }),
                "call {call}: expected WorkspaceMismatch, got {err:?}",
            );
        }
    }
}