use cts_common::{Crn, CtsServiceDiscovery, ServiceDiscovery, WorkspaceId};
use crate::access_key::AccessKey;
use crate::access_key_refresher::AccessKeyRefresher;
use crate::auto_refresh::AutoRefresh;
use crate::token_store::{NoStore, TokenStore};
use crate::{ensure_trailing_slash, AuthError, AuthStrategy, SecretToken, ServiceToken};
/// Authentication strategy that obtains [`ServiceToken`]s using an access key
/// and rejects any token whose workspace differs from the workspace named by
/// the CRN the strategy was built with.
///
/// `S` is the [`TokenStore`] given to the inner [`AutoRefresh`]; it defaults
/// to [`NoStore`].
pub struct AccessKeyStrategy<S = NoStore> {
    /// Token source: an [`AccessKeyRefresher`] wrapped in [`AutoRefresh`].
    inner: AutoRefresh<AccessKeyRefresher, S>,
    /// Workspace ID taken from the CRN; every returned token must match it.
    expected_workspace: WorkspaceId,
}
impl AccessKeyStrategy {
    /// Creates a strategy with default settings for the workspace in `workspace_crn`.
    ///
    /// Equivalent to `AccessKeyStrategy::builder(workspace_crn, access_key).build()`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`AccessKeyStrategyBuilder::build`].
    pub fn new(workspace_crn: Crn, access_key: AccessKey) -> Result<Self, AuthError> {
        let builder = Self::builder(workspace_crn, access_key);
        builder.build()
    }

    /// Starts configuring a strategy for the workspace in `workspace_crn`.
    ///
    /// The returned builder has no audience, no base-URL override, and uses the
    /// [`NoStore`] token store.
    pub fn builder(workspace_crn: Crn, access_key: AccessKey) -> AccessKeyStrategyBuilder {
        let secret = access_key.into_secret_token();
        AccessKeyStrategyBuilder {
            token_store: NoStore,
            base_url_override: None,
            audience: None,
            access_key: secret,
            workspace_crn,
        }
    }
}
impl<S: TokenStore> AuthStrategy for &AccessKeyStrategy<S> {
    /// Fetches a token from the inner auto-refresher and verifies that it was
    /// issued for the expected workspace.
    ///
    /// The check runs on every call, so a token served from the token store is
    /// subject to the same verification as a freshly fetched one.
    ///
    /// # Errors
    ///
    /// - Any error from the inner refresher, or from reading the token's workspace ID.
    /// - [`AuthError::WorkspaceMismatch`] when the token's workspace is not the
    ///   workspace from the strategy's CRN.
    async fn get_token(self) -> Result<ServiceToken, AuthError> {
        let token: ServiceToken = self.inner.get_token().await?;
        let expected_workspace = self.expected_workspace;
        let token_workspace = *token.workspace_id()?;

        if token_workspace == expected_workspace {
            Ok(token)
        } else {
            Err(AuthError::WorkspaceMismatch {
                expected_workspace,
                token_workspace,
            })
        }
    }
}
/// Builder for [`AccessKeyStrategy`], normally obtained from
/// [`AccessKeyStrategy::builder`].
///
/// `S` is the token store the finished strategy will use. It defaults to
/// [`NoStore`] and can be swapped with `with_token_store`.
pub struct AccessKeyStrategyBuilder<S = NoStore> {
    /// CRN supplying the expected workspace ID and the region.
    workspace_crn: Crn,
    /// The access key, already converted into a secret token.
    access_key: SecretToken,
    /// Optional audience forwarded to the refresher.
    audience: Option<String>,
    /// Explicit base URL; when set, `build` skips every other URL source.
    base_url_override: Option<url::Url>,
    /// Token store handed to the inner auto-refresher.
    token_store: S,
}
impl<S> AccessKeyStrategyBuilder<S> {
    /// Sets the audience passed to the access-key refresher.
    ///
    /// Consumes the builder and returns the updated one; discarding the result
    /// loses the setting, hence `#[must_use]`.
    #[must_use]
    pub fn audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Overrides the auth service base URL.
    ///
    /// When set, `build` uses this URL directly and skips both the environment
    /// override and service discovery. Only available in tests or with the
    /// `test-utils` feature.
    #[cfg(any(test, feature = "test-utils"))]
    #[must_use]
    pub fn base_url(mut self, url: url::Url) -> Self {
        self.base_url_override = Some(url);
        self
    }

    /// Replaces the token store, changing the builder's store type to `T`.
    ///
    /// All other settings (CRN, access key, audience, base-URL override) are
    /// carried over unchanged.
    #[must_use]
    pub fn with_token_store<T: TokenStore>(self, store: T) -> AccessKeyStrategyBuilder<T> {
        // Struct-update syntax (`..self`) cannot change the generic parameter
        // on stable Rust, so every field is moved across explicitly.
        AccessKeyStrategyBuilder {
            workspace_crn: self.workspace_crn,
            access_key: self.access_key,
            audience: self.audience,
            base_url_override: self.base_url_override,
            token_store: store,
        }
    }
}
impl<S: TokenStore> AccessKeyStrategyBuilder<S> {
    /// Finishes configuration and constructs the [`AccessKeyStrategy`].
    ///
    /// The base URL is resolved in priority order:
    /// 1. the explicit override set via `base_url` (tests / `test-utils` only);
    /// 2. the URL returned by `crate::cts_base_url_from_env`, if it yields one;
    /// 3. `CtsServiceDiscovery::endpoint` for the CRN's region.
    ///
    /// Service discovery is consulted only when neither earlier source provides
    /// a URL. An environment override therefore works even for a region that
    /// discovery cannot resolve. A trailing slash is ensured on the chosen URL.
    ///
    /// # Errors
    ///
    /// Returns an error if `cts_base_url_from_env` fails, or if service
    /// discovery is needed and fails for the CRN's region.
    pub fn build(self) -> Result<AccessKeyStrategy<S>, AuthError> {
        let expected_workspace = self.workspace_crn.workspace_id;
        let region = self.workspace_crn.region;
        let base_url = match self.base_url_override {
            Some(url) => url,
            // Only fall back to discovery when the env lookup yields nothing.
            // Evaluating it eagerly would let a discovery error mask a valid
            // env override.
            None => match crate::cts_base_url_from_env()? {
                Some(url) => url,
                None => CtsServiceDiscovery::endpoint(region)?,
            },
        };
        let refresher = AccessKeyRefresher::new(
            self.access_key,
            ensure_trailing_slash(base_url),
            self.audience,
        );
        Ok(AccessKeyStrategy {
            inner: AutoRefresh::with_store(refresher, self.token_store),
            expected_workspace,
        })
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod workspace_verification_tests {
    //! Tests that `AccessKeyStrategy` only returns tokens whose `workspace`
    //! claim matches the workspace ID in its CRN. This holds whether the token
    //! comes from the (mocked) auth endpoint or from a pre-populated token store.

    use super::*;
    use mocktail::prelude::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Current Unix time in whole seconds.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock")
            .as_secs()
    }

    /// Builds a JWT, signed with a fixed test secret, whose `workspace` claim is
    /// `workspace` and whose `exp` is one hour from now.
    fn jwt_with_workspace(workspace: &str) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let now = unix_now_secs();
        let claims = serde_json::json!({
            "iss": "https://cts.example.com/",
            "sub": "CS|test-access-key",
            "aud": "test-audience",
            "iat": now,
            "exp": now + 3600,
            "workspace": workspace,
            "scope": "",
        });
        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .expect("JWT encode")
    }

    /// Starts a mock server whose `POST /api/authorise` returns a JWT for `workspace`.
    async fn start_mock_server_returning_jwt(workspace: &str) -> MockServer {
        let mut mocks = MockSet::new();
        let jwt = jwt_with_workspace(workspace);
        mocks.mock(move |when, then| {
            when.post().path("/api/authorise");
            then.json(serde_json::json!({
                "accessToken": jwt,
                "expiry": 3600,
            }));
        });
        let server = MockServer::new_http("access-key-strategy-workspace-test").with_mocks(mocks);
        server.start().await.expect("mock server start");
        server
    }

    /// Parses a CRN in the `ap-southeast-2.aws` region for `workspace`.
    fn crn_with_workspace(workspace: &str) -> Crn {
        let s = format!("crn:ap-southeast-2.aws:{workspace}");
        s.parse().expect("test CRN parses")
    }

    /// A syntactically valid access key for tests.
    fn test_access_key() -> AccessKey {
        "CSAKtestKeyId.testKeySecret"
            .parse()
            .expect("test access key parses")
    }

    #[tokio::test]
    async fn returns_token_when_workspace_matches() {
        const WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(WS).await;
        let crn = crn_with_workspace(WS);
        let strategy = AccessKeyStrategy::builder(crn, test_access_key())
            .base_url(server.url(""))
            .build()
            .expect("builder");
        let token = (&strategy).get_token().await.expect("get_token");
        assert_eq!(
            token.workspace_id().expect("workspace_id").as_str(),
            WS,
            "happy-path token should carry the expected workspace",
        );
    }

    #[tokio::test]
    async fn errors_when_token_workspace_differs_from_crn() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const CRN_WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;
        let crn = crn_with_workspace(CRN_WS);
        let strategy = AccessKeyStrategy::builder(crn, test_access_key())
            .base_url(server.url(""))
            .build()
            .expect("builder");
        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch");
        match err {
            AuthError::WorkspaceMismatch {
                expected_workspace,
                token_workspace,
            } => {
                assert_eq!(expected_workspace.as_str(), CRN_WS);
                assert_eq!(token_workspace.as_str(), TOKEN_WS);
            }
            other => panic!("expected WorkspaceMismatch, got {other:?}"),
        }
        assert_eq!(
            AuthError::WorkspaceMismatch {
                expected_workspace: CRN_WS.parse().unwrap(),
                token_workspace: TOKEN_WS.parse().unwrap(),
            }
            .error_code(),
            "WORKSPACE_MISMATCH",
        );
    }

    #[tokio::test]
    async fn accepts_crn_with_service_name() {
        const WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(WS).await;
        let crn: Crn = format!("crn:ap-southeast-2.aws:{WS}:zerokms")
            .parse()
            .expect("CRN with service_name parses");
        let strategy = AccessKeyStrategy::builder(crn, test_access_key())
            .base_url(server.url(""))
            .build()
            .expect("CRN with service_name should construct a strategy");
        let token = (&strategy).get_token().await.expect("get_token");
        assert_eq!(
            token.workspace_id().expect("workspace_id").as_str(),
            WS,
            "service_name is ignored — verification still uses the workspace ID",
        );
    }

    #[tokio::test]
    async fn rejects_stored_token_for_different_workspace() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const CRN_WS: &str = "ZVATKW3VHMFG27DY";
        // The endpoint always fails, so any token returned must come from the store.
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/api/authorise");
            then.internal_server_error()
                .json(serde_json::json!({"error": "store must satisfy the request"}));
        });
        let server =
            MockServer::new_http("access-key-strategy-store-mismatch-test").with_mocks(mocks);
        server.start().await.expect("mock server start");
        let now = unix_now_secs();
        let stored = crate::Token {
            access_token: crate::SecretToken::new(jwt_with_workspace(TOKEN_WS)),
            token_type: "Bearer".to_string(),
            expires_at: now + 3600,
            refresh_token: None,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        let store = std::sync::Arc::new(crate::InMemoryTokenStore::new());
        store.save(&stored).await;
        let strategy = AccessKeyStrategy::builder(crn_with_workspace(CRN_WS), test_access_key())
            .base_url(server.url(""))
            .with_token_store(std::sync::Arc::clone(&store))
            .build()
            .expect("builder");
        let err = (&strategy)
            .get_token()
            .await
            .expect_err("expected mismatch from stored token");
        assert!(
            matches!(err, AuthError::WorkspaceMismatch { .. }),
            "expected WorkspaceMismatch, got {err:?}",
        );
    }

    #[tokio::test]
    async fn errors_on_each_subsequent_get_token_call() {
        const TOKEN_WS: &str = "AAAAAAAAAAAAAAAA";
        const CRN_WS: &str = "ZVATKW3VHMFG27DY";
        let server = start_mock_server_returning_jwt(TOKEN_WS).await;
        let crn = crn_with_workspace(CRN_WS);
        let strategy = AccessKeyStrategy::builder(crn, test_access_key())
            .base_url(server.url(""))
            .build()
            .expect("builder");
        for call in 1..=2 {
            let result = (&strategy).get_token().await;
            let err = match result {
                Ok(_) => panic!("call {call}: expected Err, got Ok"),
                Err(e) => e,
            };
            assert!(
                matches!(err, AuthError::WorkspaceMismatch { .. }),
                "call {call}: expected WorkspaceMismatch, got {err:?}",
            );
        }
    }
}