use cts_common::{Crn, CtsServiceDiscovery, Region, ServiceDiscovery};
use tracing::warn;
use stack_profile::ProfileStore;
use crate::auto_refresh::AutoRefresh;
use crate::oauth_refresher::OAuthRefresher;
use crate::{ensure_trailing_slash, AuthError, AuthStrategy, ServiceToken, Token};
/// An [`AuthStrategy`] backed by an OAuth token held in an
/// `AutoRefresh<OAuthRefresher>`.
///
/// Create one with [`OAuthStrategy::with_token`] or
/// [`OAuthStrategy::with_profile`], then call [`OAuthStrategyBuilder::build`].
pub struct OAuthStrategy {
    /// Workspace CRN derived from the token at build time; `None` when it
    /// could not be determined.
    crn: Option<Crn>,
    /// Token holder that `get_token` delegates to.
    inner: AutoRefresh<OAuthRefresher>,
}
impl OAuthStrategy {
    /// Starts building a strategy from an already-obtained `token`.
    ///
    /// At build time the token is stamped with `region` and `client_id`
    /// before it is handed to the refresher.
    pub fn with_token(
        region: Region,
        client_id: impl Into<String>,
        token: Token,
    ) -> OAuthStrategyBuilder {
        let source = OAuthTokenSource::Token {
            region,
            client_id: client_id.into(),
            token,
        };
        OAuthStrategyBuilder {
            source,
            base_url_override: None,
        }
    }

    /// Starts building a strategy whose token is loaded from the current
    /// workspace of `store` when the builder is built.
    pub fn with_profile(store: ProfileStore) -> OAuthStrategyBuilder {
        let source = OAuthTokenSource::Store(store);
        OAuthStrategyBuilder {
            source,
            base_url_override: None,
        }
    }

    /// Returns the workspace CRN extracted from the token at build time,
    /// or `None` if it could not be determined.
    pub fn workspace_crn(&self) -> Option<&Crn> {
        self.crn.as_ref()
    }
}
impl AuthStrategy for &OAuthStrategy {
    /// Fetches a service token from the inner `AutoRefresh`, converting its
    /// error into [`AuthError`] via `?`.
    async fn get_token(self) -> Result<ServiceToken, AuthError> {
        let service_token = self.inner.get_token().await?;
        Ok(service_token)
    }
}
/// Where an [`OAuthStrategyBuilder`] obtains its initial token.
enum OAuthTokenSource {
    /// A token supplied directly by the caller, plus the region and client ID
    /// that are written onto it at build time.
    Token {
        region: Region,
        client_id: String,
        token: Token,
    },
    /// A profile store; the token is loaded from its current workspace at
    /// build time.
    Store(ProfileStore),
}
/// Builder returned by [`OAuthStrategy::with_token`] and
/// [`OAuthStrategy::with_profile`]; finish it with `build`.
pub struct OAuthStrategyBuilder {
    /// Where the initial token comes from.
    source: OAuthTokenSource,
    /// Explicit base URL for the refresher. When `None`, the URL comes from
    /// the environment, or else is derived from the region (direct token) or
    /// the token's issuer (profile store).
    base_url_override: Option<url::Url>,
}
impl OAuthStrategyBuilder {
#[cfg(any(test, feature = "test-utils"))]
pub fn base_url(mut self, url: url::Url) -> Self {
self.base_url_override = Some(url);
self
}
pub fn build(self) -> Result<OAuthStrategy, AuthError> {
match self.source {
OAuthTokenSource::Token {
region,
client_id,
mut token,
} => {
let base_url = match self.base_url_override {
Some(url) => url,
None => crate::cts_base_url_from_env()?
.unwrap_or(CtsServiceDiscovery::endpoint(region)?),
};
let crn = token
.workspace_id()
.map(|ws| Crn::new(region, ws))
.map_err(|e| {
warn!("Could not extract workspace CRN from token: {e}");
e
})
.ok();
let region_id = region.identifier();
let device_instance_id = token.device_instance_id().map(String::from);
token.set_region(®ion_id);
token.set_client_id(&client_id);
let refresher = OAuthRefresher::new(
None,
ensure_trailing_slash(base_url),
&client_id,
®ion_id,
device_instance_id,
);
Ok(OAuthStrategy {
crn,
inner: AutoRefresh::with_token(refresher, token),
})
}
OAuthTokenSource::Store(store) => {
let ws_store = store.current_workspace_store()?;
let token: Token = ws_store.load_profile()?;
let region_str = token
.region()
.ok_or(AuthError::NotAuthenticated)?
.to_string();
let client_id = token
.client_id()
.ok_or(AuthError::NotAuthenticated)?
.to_string();
let crn = token
.workspace_crn()
.map_err(|e| {
warn!("Could not extract workspace CRN from token: {e}");
e
})
.ok();
let device_instance_id = token.device_instance_id().map(String::from);
let base_url = match self.base_url_override {
Some(url) => url,
None => crate::cts_base_url_from_env()?.unwrap_or(token.issuer()?),
};
let refresher = OAuthRefresher::new(
Some(ws_store),
ensure_trailing_slash(base_url),
&client_id,
®ion_str,
device_instance_id,
);
Ok(OAuthStrategy {
crn,
inner: AutoRefresh::with_token(refresher, token),
})
}
}
}
}