use std::time::Duration;
use serde::Deserialize;
use serde_json::json;
use tracing::{debug, trace};
use url::Url;
use crate::credentials::user_credentials::{
prompt_user, AccessTokenResponse, NewTokenError, DEFAULT_REQUESTED_SCOPES,
};
use super::{PollingInfo, RefreshTokenError, UserToken};
/// Credentials provider that logs a user in through the OAuth 2.0 device authorization grant
/// against an identity provider (Okta) and refreshes the resulting tokens.
#[derive(Debug)]
pub struct OktaUserCredentials {
    /// Base URL of the identity provider; endpoint paths such as `v1/token` are appended to it.
    idp_base_url: Url,
    /// OAuth client id, sent as `client_id` with every request to the identity provider.
    idp_client_id: String,
}
/// Returns a copy of `left` with the `/`-separated segments of `right` appended to its path.
///
/// Each segment of `right` is appended on its own so that the `/` separators stay real path
/// separators. Passing `"v1/token"` as a single segment to `PathSegmentsMut::push` would
/// percent-encode the slash and produce `v1%2Ftoken`. A trailing empty segment in `left`
/// (a base URL ending in `/`) is dropped first so the result never contains `//`. Empty
/// segments in `right` are ignored.
///
/// # Panics
///
/// Panics if `left` is a cannot-be-a-base URL (for example a `mailto:` or `data:` URL),
/// because such URLs have no path segments to extend.
fn url_push(left: &Url, right: &str) -> Url {
    let mut url = left.clone();
    url.path_segments_mut()
        .expect("url should be well formed")
        .pop_if_empty()
        .extend(right.split('/').filter(|segment| !segment.is_empty()));
    url
}
impl OktaUserCredentials {
    /// Creates a credentials provider for the identity provider at `idp_base_url`.
    ///
    /// The provider identifies itself to the identity provider with the OAuth client id
    /// `idp_client_id`.
    pub fn new(idp_base_url: &Url, idp_client_id: &str) -> Self {
        Self {
            idp_base_url: idp_base_url.to_owned(),
            idp_client_id: idp_client_id.to_string(),
        }
    }

    /// Polls the `v1/token` endpoint with the device code from `polling_info`.
    ///
    /// Polling continues until the server issues a token or reports a failure. The first
    /// request is sent immediately. After every `authorization_pending` or `slow_down` answer
    /// the loop sleeps for the current interval. That interval starts at 5 seconds and grows
    /// by 5 seconds on each `slow_down`. There is no client-side timeout; polling stops only
    /// when one of the outcomes below occurs.
    ///
    /// # Errors
    ///
    /// - [`NewTokenError::PollTokenRequestFailed`] if a request cannot be sent.
    /// - [`NewTokenError::PollTokenBadResponse`] if a `200` body cannot be parsed as a token.
    /// - [`NewTokenError::PollTokenBadPendingResponse`] if a `400` body cannot be parsed as one
    ///   of the OAuth error codes handled here.
    /// - [`NewTokenError::PollTokenAuthFailed`] on `invalid_grant`, `access_denied` or
    ///   `expired_token`. It carries the server's `error_description` when one is present.
    /// - [`NewTokenError::PollTokenUnexpected`] for any other HTTP status.
    async fn poll_access_token(
        &self,
        client: &reqwest::Client,
        polling_info: &PollingInfo,
    ) -> Result<UserToken, NewTokenError> {
        debug!(target: "console_credentials", "Logging in - polling for access token");
        let mut interval = Duration::from_secs(5);
        let url = url_push(&self.idp_base_url, "v1/token");
        loop {
            let response = client
                .post(url.to_owned())
                .form(&json!({
                    "client_id": self.idp_client_id,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                    "device_code": polling_info.device_code,
                }))
                .send()
                .await
                .map_err(NewTokenError::PollTokenRequestFailed)?;
            match response.status().as_u16() {
                200 => {
                    let token: AccessTokenResponse = response.json().await.map_err(|e| {
                        trace!(
                            target: "console_credentials",
                            "Failed to parse token access response: {e:?}"
                        );
                        NewTokenError::PollTokenBadResponse(e)
                    })?;
                    debug!(target: "console_credentials",
                        "Access Token Acquired - expires_in(s): {}",
                        &token.expires_in
                    );
                    return Ok(token.into());
                }
                400 => {
                    // OAuth error codes this flow handles (RFC 8628 §3.5, RFC 6749 §5.2).
                    // Any other code makes the deserialization below fail.
                    #[derive(Deserialize, Debug)]
                    #[serde(rename_all = "snake_case")]
                    enum AuthError {
                        AuthorizationPending,
                        SlowDown,
                        InvalidGrant,
                        AccessDenied,
                        ExpiredToken,
                    }
                    #[derive(Deserialize)]
                    struct PendingResponse {
                        error: AuthError,
                        error_description: Option<String>,
                    }
                    let PendingResponse {
                        error,
                        error_description,
                    } = response
                        .json()
                        .await
                        .map_err(NewTokenError::PollTokenBadPendingResponse)?;
                    match error {
                        // The user has not finished logging in yet: poll again.
                        AuthError::AuthorizationPending => {}
                        // RFC 8628 §3.5: add 5 seconds for this and all later requests.
                        AuthError::SlowDown => interval += Duration::from_secs(5),
                        // Terminal failures: prefer the server's human-readable reason.
                        AuthError::InvalidGrant
                        | AuthError::AccessDenied
                        | AuthError::ExpiredToken => {
                            let reason =
                                error_description.unwrap_or_else(|| format!("{error:?}"));
                            return Err(NewTokenError::PollTokenAuthFailed(reason));
                        }
                    }
                    crate::sleep::sleep(interval).await;
                }
                code => {
                    let body = response.text().await.ok();
                    debug!(
                        target: "console_credentials",
                        "error response returned body {body:?}"
                    );
                    return Err(NewTokenError::PollTokenUnexpected(format!(
                        "Unexpected response code: {code}"
                    )));
                }
            }
        }
    }

    /// Logs the user in with the OAuth 2.0 device authorization flow.
    ///
    /// This requests a device code from `v1/device/authorize` and passes the returned
    /// [`PollingInfo`] to [`prompt_user`]. It then polls `v1/token` until a token is issued.
    ///
    /// # Errors
    ///
    /// - [`NewTokenError::DeviceCodeRequestFailed`] if the device-code request cannot be sent
    ///   or is answered with an error status. In that case the response body is logged at
    ///   debug level.
    /// - [`NewTokenError::DeviceCodeBadResponse`] if the device-code response cannot be parsed.
    /// - Any error produced while polling for the token (the `PollToken*` variants).
    pub async fn acquire_new_token(&self) -> Result<UserToken, NewTokenError> {
        debug!(target: "console_credentials", "Logging in...");
        let client = reqwest::Client::new();
        let url = url_push(&self.idp_base_url, "v1/device/authorize");
        let info_response = client
            .post(url)
            .form(&json!({
                "client_id": &self.idp_client_id,
                "scope": DEFAULT_REQUESTED_SCOPES
            }))
            .send()
            .await
            .map_err(NewTokenError::DeviceCodeRequestFailed)?;
        if let Some(e) = info_response.error_for_status_ref().err() {
            let body = info_response.text().await.ok();
            debug!(
                target: "console_credentials",
                "error response returned body: {body:?}"
            );
            return Err(NewTokenError::DeviceCodeRequestFailed(e));
        }
        let polling_info: PollingInfo = info_response
            .json()
            .await
            .map_err(NewTokenError::DeviceCodeBadResponse)?;
        prompt_user(&polling_info);
        self.poll_access_token(&client, &polling_info).await
    }

    /// Exchanges the refresh token stored in `cached_token` for a fresh access token.
    ///
    /// Returns `Ok(None)` when the token endpoint answers with a non-success status. The
    /// caller can then decide how to recover, for example by logging in again. The rejection
    /// and the response body are logged at debug level in that case.
    ///
    /// # Errors
    ///
    /// - [`RefreshTokenError::RequestFailed`] if the request cannot be sent.
    /// - [`RefreshTokenError::BadResponse`] if a success response cannot be parsed as a token.
    pub async fn refresh_access_token(
        &self,
        cached_token: &UserToken,
    ) -> Result<Option<UserToken>, RefreshTokenError> {
        debug!(target: "console_credentials", "Refreshing Access Token...");
        let client = reqwest::Client::new();
        let url = url_push(&self.idp_base_url, "v1/token");
        let response = client
            .post(url)
            .form(&json!({
                "grant_type": "refresh_token",
                "refresh_token": cached_token.refresh_token,
                "client_id": self.idp_client_id,
                "scope": DEFAULT_REQUESTED_SCOPES
            }))
            .send()
            .await
            .map_err(RefreshTokenError::RequestFailed)?;
        if let Some(e) = response.error_for_status_ref().err() {
            // A rejected refresh is reported as `None`, not as an error; record why it
            // happened so a failing refresh can be diagnosed.
            let body = response.text().await.ok();
            debug!(
                target: "console_credentials",
                "Token refresh rejected: {e}; response body: {body:?}"
            );
            return Ok(None);
        }
        let token: AccessTokenResponse =
            response.json().await.map_err(RefreshTokenError::BadResponse)?;
        debug!(target: "console_credentials",
            "Access Token Acquired - expires_in(s): {}",
            &token.expires_in
        );
        Ok(Some(token.into()))
    }
}