mod protocol;
use cts_common::{CtsServiceDiscovery, Region, ServiceDiscovery};
use url::Url;
use std::time::{SystemTime, UNIX_EPOCH};
use std::path::PathBuf;
use stack_profile::ProfileStore;
use crate::{ensure_trailing_slash, http_client, AuthError, DeviceIdentity, Token};
use protocol::{
DeviceCode, DeviceCodeRequest, DeviceCodeResponse, ErrorResponse, TokenRequest, TokenResponse,
};
#[cfg(test)]
mod tests;
/// Settings needed to start a device-code authorization request.
///
/// Built through [`DeviceCodeStrategy::new`] or [`DeviceCodeStrategy::builder`];
/// [`DeviceCodeStrategy::begin`] uses these values to request a device code.
pub struct DeviceCodeStrategy {
/// Region copied into the `PendingDeviceCode` produced by `begin`.
region: Region,
/// Base URL that the `oauth/device/code` and `oauth/device/token` paths are
/// joined onto. The builder passes it through `ensure_trailing_slash`.
base_url: Url,
/// Client identifier sent in the device-code request form.
client_id: String,
/// Optional directory forwarded to the pending code; when `None`, the token
/// is later saved via `ProfileStore::resolve(None)`.
profile_dir: Option<PathBuf>,
/// Optional device identity; its instance id and name are sent with the
/// device-code request when present.
device_identity: Option<DeviceIdentity>,
}
impl DeviceCodeStrategy {
pub fn new(region: Region, client_id: impl Into<String>) -> Result<Self, AuthError> {
Self::builder(region, client_id).build()
}
pub fn builder(region: Region, client_id: impl Into<String>) -> DeviceCodeStrategyBuilder {
DeviceCodeStrategyBuilder {
region,
client_id: client_id.into(),
base_url_override: None,
profile_dir: None,
device_identity: None,
}
}
pub async fn begin(&self) -> Result<PendingDeviceCode, AuthError> {
let client = http_client();
let code_url = self.base_url.join("oauth/device/code")?;
tracing::debug!(url = %code_url, client_id = %self.client_id, "requesting device code");
let device_instance_id = self
.device_identity
.as_ref()
.map(|d| d.device_instance_id.to_string());
let code_resp = client
.post(code_url)
.form(&DeviceCodeRequest {
client_id: &self.client_id,
device_instance_id: device_instance_id.as_deref(),
device_name: self
.device_identity
.as_ref()
.map(|d| d.device_name.as_str()),
})
.send()
.await?;
if !code_resp.status().is_success() {
let err: ErrorResponse = code_resp.json().await?;
tracing::debug!(error = %err.error, "device code request failed");
return Err(match err.error.as_str() {
"invalid_client" => AuthError::InvalidClient,
_ => AuthError::Server(err.error_description),
});
}
let code: DeviceCodeResponse = code_resp.json().await?;
let token_url = self.base_url.join("oauth/device/token")?;
tracing::debug!(
user_code = %code.user_code,
expires_in = code.expires_in,
"device code received"
);
Ok(PendingDeviceCode {
token_url,
region: self.region,
client_id: self.client_id.clone(),
device_code: code.device_code,
user_code: code.user_code,
verification_uri: code.verification_uri,
verification_uri_complete: code.verification_uri_complete,
expires_in: code.expires_in,
profile_dir: self.profile_dir.clone(),
device_identity: self.device_identity.clone(),
})
}
}
/// Builder that collects the settings for a [`DeviceCodeStrategy`].
///
/// Created by [`DeviceCodeStrategy::builder`] and consumed by `build`.
pub struct DeviceCodeStrategyBuilder {
/// Region stored on the built strategy and passed to
/// `CtsServiceDiscovery::endpoint` when resolving the base URL.
region: Region,
/// Client identifier stored on the built strategy.
client_id: String,
/// Explicit base URL; when `Some`, `build` uses it as the base URL.
base_url_override: Option<Url>,
/// Optional profile directory stored on the built strategy.
profile_dir: Option<PathBuf>,
/// Optional device identity stored on the built strategy.
device_identity: Option<DeviceIdentity>,
}
impl DeviceCodeStrategyBuilder {
    /// Overrides the base URL used by the built strategy.
    ///
    /// Only compiled for tests or with the `test-utils` feature.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn base_url(mut self, url: Url) -> Self {
        self.base_url_override = Some(url);
        self
    }

    /// Sets the directory that the resulting token is saved under.
    ///
    /// Only compiled for tests or with the `test-utils` feature.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn profile_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.profile_dir = Some(dir.into());
        self
    }

    /// Attaches a device identity to the strategy being built.
    pub fn device_identity(mut self, identity: DeviceIdentity) -> Self {
        self.device_identity = Some(identity);
        self
    }

    /// Resolves the base URL and produces the [`DeviceCodeStrategy`].
    ///
    /// The base URL is taken from the first source that provides one:
    ///
    /// 1. the explicit override set on this builder,
    /// 2. the value returned by `crate::cts_base_url_from_env()`,
    /// 3. `CtsServiceDiscovery::endpoint` for the configured region.
    ///
    /// The chosen URL is passed through `ensure_trailing_slash`.
    ///
    /// # Errors
    ///
    /// Propagates an error from reading the environment override, or from
    /// service discovery when it is actually consulted.
    pub fn build(self) -> Result<DeviceCodeStrategy, AuthError> {
        let base_url = match self.base_url_override {
            Some(url) => url,
            // Service discovery is only consulted when the environment gives no
            // URL. Passing it to `unwrap_or` would evaluate it (and propagate its
            // error) even when the environment already supplied a base URL.
            None => match crate::cts_base_url_from_env()? {
                Some(url) => url,
                None => CtsServiceDiscovery::endpoint(self.region)?,
            },
        };
        Ok(DeviceCodeStrategy {
            region: self.region,
            base_url: ensure_trailing_slash(base_url),
            client_id: self.client_id,
            profile_dir: self.profile_dir,
            device_identity: self.device_identity,
        })
    }
}
/// A device-code authorization that has been requested but not yet completed.
///
/// Returned by [`DeviceCodeStrategy::begin`]; `poll_for_token` consumes it to
/// obtain and save a [`Token`].
#[derive(Debug)]
pub struct PendingDeviceCode {
/// URL that `poll_for_token` POSTs token requests to.
token_url: Url,
/// Region whose identifier is recorded on the resulting token.
region: Region,
/// Client identifier sent with each token request and recorded on the token.
client_id: String,
/// Device code sent with each token request.
device_code: DeviceCode,
/// User code from the device-code response, exposed via `user_code()`.
user_code: String,
/// Verification URI from the device-code response.
verification_uri: String,
/// Verification URI passed to `open::that` by `open_in_browser`.
verification_uri_complete: String,
/// Lifetime of the device code in seconds; used to compute the polling deadline.
expires_in: u64,
/// Optional directory for the profile store; `ProfileStore::resolve(None)`
/// is used when this is `None`.
profile_dir: Option<PathBuf>,
/// Optional device identity whose instance id is recorded on the token.
device_identity: Option<DeviceIdentity>,
}
impl PendingDeviceCode {
    /// Returns the user code from the device-code response.
    pub fn user_code(&self) -> &str {
        &self.user_code
    }

    /// Returns the verification URI from the device-code response.
    pub fn verification_uri(&self) -> &str {
        &self.verification_uri
    }

    /// Returns the complete verification URI from the device-code response.
    pub fn verification_uri_complete(&self) -> &str {
        &self.verification_uri_complete
    }

    /// Returns the device code's lifetime in seconds, as reported by the server.
    pub fn expires_in(&self) -> u64 {
        self.expires_in
    }

    /// Tries to open the complete verification URI with `open::that`.
    ///
    /// Returns `true` when `open::that` reports success and `false` otherwise;
    /// the underlying error is discarded.
    pub fn open_in_browser(&self) -> bool {
        open::that(&self.verification_uri_complete).is_ok()
    }

    /// Polls the token endpoint until a token is issued, then saves it.
    ///
    /// Polling starts with a 5 second pause between attempts. An
    /// `authorization_pending` reply retries after the current pause; a
    /// `slow_down` reply lengthens the pause by 5 seconds first. Once a token
    /// is issued it is tagged with region, client id and (if configured) the
    /// device instance id, written to the workspace profile store, and returned.
    ///
    /// # Errors
    ///
    /// * [`AuthError::TokenExpired`] when the local deadline passes or the
    ///   server replies `expired_token`.
    /// * [`AuthError::AccessDenied`], [`AuthError::InvalidGrant`] and
    ///   [`AuthError::InvalidClient`] for the matching server error codes.
    /// * [`AuthError::Server`] with the server's description for any other code.
    /// * Any error converted by `?` from sending requests, decoding responses,
    ///   or saving the token.
    pub async fn poll_for_token(self) -> Result<Token, AuthError> {
        let client = http_client();
        let mut interval = tokio::time::Duration::from_secs(5);
        // `expires_in` is server-supplied and `Instant + Duration` panics on
        // overflow, so use the checked form. `None` means the deadline cannot be
        // represented; polling then ends only on a terminal server reply.
        let deadline = tokio::time::Instant::now()
            .checked_add(tokio::time::Duration::from_secs(self.expires_in));
        tracing::debug!(
            url = %self.token_url,
            expires_in = self.expires_in,
            "polling for token"
        );
        loop {
            if matches!(deadline, Some(limit) if tokio::time::Instant::now() >= limit) {
                tracing::debug!("device code expired while polling");
                return Err(AuthError::TokenExpired);
            }
            let resp = client
                .post(self.token_url.clone())
                .form(&TokenRequest {
                    client_id: &self.client_id,
                    device_code: &self.device_code,
                    grant_type: "urn:ietf:params:oauth:grant-type:device_code",
                })
                .send()
                .await?;
            if resp.status().is_success() {
                tracing::debug!("token received");
                let token_resp: TokenResponse = resp.json().await?;
                return self.finalize_token(token_resp);
            }
            let err: ErrorResponse = resp.json().await?;
            match err.error.as_str() {
                "authorization_pending" => {
                    tracing::debug!("authorization pending, retrying");
                }
                "slow_down" => {
                    interval += tokio::time::Duration::from_secs(5);
                    tracing::debug!(interval_secs = interval.as_secs(), "slowing down");
                }
                "expired_token" => return Err(AuthError::TokenExpired),
                "access_denied" => return Err(AuthError::AccessDenied),
                "invalid_grant" => return Err(AuthError::InvalidGrant),
                "invalid_client" => return Err(AuthError::InvalidClient),
                _ => return Err(AuthError::Server(err.error_description)),
            }
            tokio::time::sleep(interval).await;
        }
    }

    /// Turns a successful token response into a tagged [`Token`] and saves it
    /// to the workspace profile store.
    fn finalize_token(&self, token_resp: TokenResponse) -> Result<Token, AuthError> {
        // A clock set before the Unix epoch yields 0 rather than an error.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let mut token = Token {
            access_token: token_resp.access_token,
            token_type: token_resp.token_type,
            // Saturate so an oversized server-supplied lifetime cannot overflow.
            expires_at: now.saturating_add(token_resp.expires_in),
            refresh_token: token_resp.refresh_token,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        token.set_region(self.region.identifier());
        token.set_client_id(&self.client_id);
        if let Some(identity) = &self.device_identity {
            token.set_device_instance_id(identity.device_instance_id.to_string());
        }
        let store = match &self.profile_dir {
            Some(dir) => ProfileStore::new(dir),
            None => ProfileStore::resolve(None)?,
        };
        let workspace_id = token.workspace_id()?;
        store.init_workspace(workspace_id.as_str())?;
        store
            .workspace_store(workspace_id.as_str())?
            .save_profile(&token)?;
        tracing::debug!(
            workspace = workspace_id.as_str(),
            "token saved to workspace directory"
        );
        Ok(token)
    }
}