pub mod user_token;
mod auth0;
mod okta;
use std::{
path::Path,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_mutex::Mutex as AsyncMutex;
use async_trait::async_trait;
use auth0::Auth0UserCredentials;
use miette::Diagnostic;
use okta::OktaUserCredentials;
use serde::Deserialize;
use thiserror::Error;
use tracing::{debug, error};
use url::Url;
use crate::{
config::idp_provider::IdpProvider,
credentials::{
token_store::TokenStore, AutoRefreshable, ClearTokenError, Credentials, GetTokenError,
TokenExpiry,
},
};
pub use user_token::UserToken;
/// Default space-separated scope string for token requests.
///
/// Not referenced in this file; presumably consumed by the `auth0`/`okta` provider
/// modules when they start a login — confirm there. In OpenID Connect, `offline_access`
/// asks the IdP to issue a refresh token, which the refresh paths in this module need.
pub const DEFAULT_REQUESTED_SCOPES: &str = "offline_access cipherstash:admin";
/// Fields read from the IdP's JSON response when an interactive (device-code) login starts.
///
/// Field names are used as-is for the JSON keys (no `#[serde(rename)]`). The request that
/// produces this response lives in the provider modules, not in this file.
#[derive(Deserialize)]
pub(crate) struct PollingInfo {
    /// Short code that `prompt_user` prints so the user can compare it with the browser.
    pub user_code: String,
    /// Code identifying this login attempt. Not read in this file — presumably sent back
    /// by the provider when polling for the token; verify in `auth0`/`okta`.
    pub device_code: String,
    /// URL the user must visit; `prompt_user` tries to open it and also prints it.
    pub verification_uri_complete: String,
}
/// Token response returned by the IdP, deserialized from JSON.
///
/// Turned into a [`UserToken`] by the `From` impl below, which converts the relative
/// `expires_in` into an absolute expiry timestamp.
///
/// NOTE(review): `refresh_token` is not an `Option`, so a response that omits it fails to
/// deserialize. Confirm every endpoint decoded into this type always returns one.
#[derive(Deserialize)]
pub(crate) struct AccessTokenResponse {
    /// Token that can later be redeemed for a new access token.
    pub refresh_token: String,
    /// The access token itself.
    pub access_token: String,
    /// Lifetime of `access_token`, in seconds relative to when the response is converted
    /// (the `From` impl adds it to `now_secs()`).
    pub expires_in: u64,
}
impl From<AccessTokenResponse> for UserToken {
    /// Converts an IdP token response into a [`UserToken`] with an absolute expiry.
    ///
    /// `expires_in` is relative, so it is added to [`now_secs`]. The addition saturates:
    /// `expires_in` comes straight from the IdP's JSON. A huge value must not panic in a
    /// debug build, nor wrap around to a timestamp in the past in a release build (which
    /// would make a freshly issued token look already expired).
    fn from(value: AccessTokenResponse) -> Self {
        Self {
            access_token: value.access_token,
            refresh_token: value.refresh_token,
            expiry: now_secs().saturating_add(value.expires_in),
        }
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
pub fn now_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Expected system time to be greater than UNIX_EPOCH");
    since_epoch.as_secs()
}
/// Columns of padding between the frame's side borders and the user code.
const USER_CODE_BOX_PADDING: usize = 6;

/// Walks the user through approving the login in a browser.
///
/// First tries to open `polling_info.verification_uri_complete` in a browser; if that
/// fails, a notice is printed instead of returning an error. Then prints the URL and the
/// `user_code` in a frame, so the user can check it against the code shown in the browser.
/// All output goes to stdout.
pub(crate) fn prompt_user(polling_info: &PollingInfo) {
    if open::that(&polling_info.verification_uri_complete).is_err() {
        println!(
            "Failed to open web browser. Please manually click the link in the following message."
        )
    }
    println!();
    println!("### ACTION REQUIRED ###");
    println!();
    println!(
        "Visit {} to complete authentication by following the below steps:",
        polling_info.verification_uri_complete
    );
    println!();
    println!("1. Verify that this code matches the code in your browser");
    println!();
    for row in user_code_box(&polling_info.user_code) {
        println!("{row}");
    }
    println!();
    println!("2. If the codes match, click on the confirm button in the browser");
    println!();
    println!("Waiting for authentication...");
}

/// Renders `user_code` inside an ASCII frame and returns the rows top to bottom.
///
/// Every row is the same width: the inner width is the code's `char` count (not its byte
/// length) plus `USER_CODE_BOX_PADDING` on each side. Border, blank and code rows are all
/// derived from that one value, so they always line up.
fn user_code_box(user_code: &str) -> [String; 5] {
    let inner_width = user_code.chars().count() + 2 * USER_CODE_BOX_PADDING;
    let padding = " ".repeat(USER_CODE_BOX_PADDING);
    let border = format!(" +{}+", "-".repeat(inner_width));
    let blank = format!(" |{}|", " ".repeat(inner_width));
    let code_row = format!(" |{padding}{user_code}{padding}|");
    [border.clone(), blank.clone(), code_row, blank, border]
}
/// Errors from redeeming a refresh token with the IdP.
///
/// Both variants wrap a `reqwest::Error` and include it in their `Display` text. No field
/// is marked `#[source]`/`#[from]`, so `Error::source()` returns `None` for these variants.
#[derive(Diagnostic, Error, Debug)]
pub enum RefreshTokenError {
    /// The refresh request itself failed.
    #[error("Failed to redeem refresh token: {0}")]
    RequestFailed(reqwest::Error),
    /// A response arrived but its JSON body could not be parsed.
    #[error("Failed to parse json response: {0}")]
    BadResponse(reqwest::Error),
}
/// Errors from acquiring a brand-new token through the interactive device-code login.
///
/// Only `UrlParse` declares a source (`#[from]`, which also lets `?` convert a
/// `url::ParseError`). The other variants embed their cause only in the `Display` text.
#[derive(Diagnostic, Error, Debug)]
pub enum NewTokenError {
    /// Building an endpoint URL failed.
    #[error("Failed to parse Url: {0}")]
    UrlParse(#[from] url::ParseError),
    /// The request for a device code failed.
    #[error("Failed to get device code: {0}")]
    DeviceCodeRequestFailed(reqwest::Error),
    /// The device-code response could not be parsed (as `PollingInfo`, presumably).
    #[error("Failed to parse polling info json response: {0}")]
    DeviceCodeBadResponse(reqwest::Error),
    /// A poll request for the token failed.
    #[error("Failed to poll for new token: {0}")]
    PollTokenRequestFailed(reqwest::Error),
    /// A successful poll response could not be parsed as an access-token response.
    #[error("Failed to parse access token response: {0}")]
    PollTokenBadResponse(reqwest::Error),
    /// A "still pending" poll response could not be parsed.
    #[error("Failed to parse pending auth response: {0}")]
    PollTokenBadPendingResponse(reqwest::Error),
    /// The IdP reported that the device-code authentication failed; the string is the IdP's reason.
    #[error("Device code authentication failed: {0}")]
    PollTokenAuthFailed(String),
    /// The poll response carried an error code the client does not recognise.
    #[error("Unexpected error code in response body: {0}")]
    PollTokenUnexpected(String),
}
/// Credentials for a human user, obtained from an external IdP (Auth0 or Okta).
pub struct UserCredentials {
    /// Persisted token cache. Behind an async mutex so concurrent `get_token`/`refresh`
    /// callers can hold it across the network call and redeem a refresh token only once.
    token_store: AsyncMutex<TokenStore<UserToken>>,
    /// IdP-specific client used to acquire and refresh tokens.
    provider: UserCredentialsProvider,
}
/// The IdP client chosen from `IdpProvider` in `UserCredentials::new`.
enum UserCredentialsProvider {
    /// Auth0-backed flows.
    Auth0(Auth0UserCredentials),
    /// Okta-backed flows.
    Okta(OktaUserCredentials),
}
impl UserCredentials {
    /// Builds user credentials that persist tokens at `idp_token_path` and talk to the IdP
    /// selected by `idp_provider` at `idp_base_url`.
    ///
    /// `idp_audience` is passed only to the Auth0 client; the Okta client is built from
    /// the base URL and client id alone.
    pub fn new(
        idp_token_path: &Path,
        idp_base_url: &Url,
        idp_audience: &str,
        idp_client_id: &str,
        idp_provider: IdpProvider,
    ) -> Self {
        let provider = match idp_provider {
            IdpProvider::Auth0 => {
                let auth0 = Auth0UserCredentials::new(idp_base_url, idp_audience, idp_client_id);
                UserCredentialsProvider::Auth0(auth0)
            }
            IdpProvider::Okta => {
                let okta = OktaUserCredentials::new(idp_base_url, idp_client_id);
                UserCredentialsProvider::Okta(okta)
            }
        };
        let token_store = AsyncMutex::new(TokenStore::new(idp_token_path));
        Self {
            token_store,
            provider,
        }
    }

    /// Runs the provider's new-token (interactive login) flow, saves the result to the
    /// token store and returns it.
    ///
    /// The store lock is taken only after the flow finishes, so other callers are not
    /// blocked while the user is logging in.
    ///
    /// # Errors
    /// - [`GetTokenError::AcquireNewTokenFailed`] if the provider flow fails.
    /// - [`GetTokenError::PersistTokenError`] if the token cannot be written to the store.
    pub async fn authenticate_interactively(&self) -> Result<UserToken, GetTokenError> {
        let fresh_token = match self.provider.acquire_new_token().await {
            Ok(token) => token,
            Err(err) => return Err(GetTokenError::AcquireNewTokenFailed(Box::new(err))),
        };
        let mut store = self.token_store.lock().await;
        if let Err(err) = store.set(&fresh_token) {
            return Err(GetTokenError::PersistTokenError(Box::new(err)));
        }
        Ok(fresh_token)
    }
}
impl UserCredentialsProvider {
async fn refresh_access_token(
&self,
cached_token: &UserToken,
) -> Result<Option<UserToken>, RefreshTokenError> {
match self {
Self::Auth0(creds) => creds.refresh_access_token(cached_token).await,
Self::Okta(creds) => creds.refresh_access_token(cached_token).await,
}
}
async fn acquire_new_token(&self) -> Result<UserToken, NewTokenError> {
match self {
Self::Auth0(creds) => creds.acquire_new_token().await,
Self::Okta(creds) => creds.acquire_new_token().await,
}
}
}
#[async_trait]
impl Credentials for UserCredentials {
    type Token = UserToken;

    /// Returns an unexpired access token, refreshing it through the IdP if the cached token
    /// has expired.
    ///
    /// The store is read first under a lock that is released right away. A still-valid token
    /// is therefore returned without holding the lock across any network call. Otherwise the
    /// lock is re-acquired and held for the whole refresh, and the token is read again under
    /// it. If another caller already refreshed while this one waited, that token is reused
    /// and the refresh token is not redeemed a second time.
    ///
    /// # Errors
    /// - [`GetTokenError::RefreshTokenFailed`] if the refresh request fails (also logged).
    /// - [`GetTokenError::PersistTokenError`] if the refreshed token cannot be stored.
    /// - [`GetTokenError::MissingOrExpired`] if no token is cached, or the provider issued no
    ///   replacement for an expired one.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        let cached_token = {
            let mut token_store = self.token_store.lock().await;
            token_store.get()
        };
        // The snapshot is not used again, so move the token out instead of cloning it.
        if let Some(token) = cached_token.filter(|token| !token.is_expired()) {
            return Ok(token);
        }

        let mut token_store = self.token_store.lock().await;
        if let Some(cached_token) = token_store.get() {
            // Another caller may have refreshed while we waited for the lock.
            if !cached_token.is_expired() {
                return Ok(cached_token);
            }
            if let Some(new_token) = self
                .provider
                .refresh_access_token(&cached_token)
                .await
                .map_err(|e| {
                    error!("Failed to refresh token: {}", e);
                    GetTokenError::RefreshTokenFailed(Box::new(e))
                })?
            {
                token_store
                    .set(&new_token)
                    .map_err(|e| GetTokenError::PersistTokenError(Box::new(e)))?;
                return Ok(new_token);
            }
        }
        Err(GetTokenError::MissingOrExpired)
    }

    /// Clears the token store.
    ///
    /// # Errors
    /// Returns [`ClearTokenError`] wrapping the store's error if clearing fails.
    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        let mut token_store = self.token_store.lock().await;
        token_store
            .clear()
            .map_err(|e| ClearTokenError(Box::new(e)))
    }
}
#[async_trait]
impl AutoRefreshable for UserCredentials {
    /// Refreshes the cached token if it is due, and returns how long to wait before the
    /// next call.
    ///
    /// The delay is the token's own `refresh_interval()` when a fresh token is available.
    /// Otherwise (no cached token, the provider issued no replacement, the refresh failed,
    /// or the new token could not be saved) it is `min_refresh_interval()`. Failures are
    /// logged as warnings rather than returned.
    async fn refresh(&self) -> Duration {
        // Optimistic read: the lock is released before any network call.
        let snapshot = {
            let mut store = self.token_store.lock().await;
            store.get()
        };
        if let Some(current) = snapshot.as_ref().filter(|t| !t.should_refresh()) {
            debug!(target: "console_credentials", "Access token is still new");
            return current.refresh_interval();
        }

        // Hold the lock for the whole refresh so concurrent callers do not redeem the
        // same refresh token twice; re-read the token in case someone else already did.
        let mut store = self.token_store.lock().await;
        let current = match store.get() {
            Some(token) => token,
            None => return Self::Token::min_refresh_interval(),
        };
        if !current.should_refresh() {
            debug!(target: "console_credentials", "Access token already refreshed by another caller");
            return current.refresh_interval();
        }

        debug!(target: "console_credentials", "Access token close to expiry, refreshing");
        match self.provider.refresh_access_token(&current).await {
            Ok(Some(fresh)) => match store.set(&fresh) {
                Ok(_) => {
                    debug!(target: "console_credentials", "Access token refreshed and saved to disk");
                    return fresh.refresh_interval();
                }
                Err(err) => {
                    tracing::warn!(
                        target: "console_credentials",
                        error = %err,
                        "Failed to persist refreshed token"
                    );
                }
            },
            Ok(None) => {
                tracing::warn!(
                    target: "console_credentials",
                    "Token refresh returned no new token"
                );
            }
            Err(err) => {
                tracing::warn!(
                    target: "console_credentials",
                    error = %err,
                    "Failed to refresh user token"
                );
            }
        }
        Self::Token::min_refresh_interval()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::idp_provider::IdpProvider;
    use crate::credentials::test_utils::CountingState;
    use std::sync::Arc;

    /// Mock `/oauth/token` handler. It records entry and exit on the shared counter and
    /// sleeps in between, so overlapping requests show up in the peak count.
    async fn slow_refresh(
        axum::extract::State(state): axum::extract::State<CountingState>,
    ) -> axum::Json<serde_json::Value> {
        state.enter();
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        state.exit();
        let body = serde_json::json!({
            "refresh_token": "new-refresh",
            "access_token": "new-access",
            "expires_in": 3600u64
        });
        axum::Json(body)
    }

    /// Five concurrent `get_token` calls on an expired token must produce exactly one
    /// refresh request, with no overlap, thanks to the double-checked read under the lock.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_user_credentials_serializes_refresh_token_exchange() {
        // Start a local mock IdP whose token endpoint tracks concurrency.
        let state = CountingState::new();
        let stats = state.clone();
        let app = axum::Router::new()
            .route("/oauth/token", axum::routing::post(slow_refresh))
            .with_state(state);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        // Seed the token file with an already-expired token.
        let tmp = tempfile::TempDir::new().unwrap();
        let token_path = tmp.path().join("idp_token.json");
        let base_url = Url::parse(&format!("http://{addr}")).unwrap();
        let stale = UserToken::new_from_raw("test-refresh", "expired-access", 0);
        let stale_json = serde_json::to_string(&stale).unwrap();
        std::fs::write(&token_path, stale_json).unwrap();

        let creds = Arc::new(UserCredentials::new(
            &token_path,
            &base_url,
            "test-audience",
            "test-client-id",
            IdpProvider::Auth0,
        ));

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let creds = Arc::clone(&creds);
                tokio::spawn(async move { creds.get_token().await.unwrap() })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        let peak = stats.peak();
        let total = stats.total();
        assert_eq!(
            peak, 1,
            "Expected serialized refresh but peak concurrency was {peak}. \
             Concurrent refresh-token exchange can cause auth failures \
             with IdPs that rotate refresh tokens.",
        );
        assert_eq!(
            total, 1,
            "Expected exactly 1 refresh request but got {total}. \
             Double-check pattern should let waiters use the refreshed token.",
        );
    }
}