use std::{
collections::HashMap,
io::IsTerminal,
net::{SocketAddr, TcpListener},
sync::Arc,
time::Duration,
};
use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{SecondsFormat, Utc};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use sha2::{Digest, Sha256};
use tokio::sync::RwLock;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::{
Credential, Result,
auth::AuthProvider,
auth::CredentialRequest,
auth::storage::{CredentialKey, CredentialStorage, default_storage, storage_for},
config::CredentialStore,
error::CliCoreError,
};
/// Loopback port used to build the default redirect URI when no explicit
/// redirect URI has been configured.
const REDIRECT_PORT_DEFAULT: u16 = 7_443;
/// Number of seconds before the recorded expiry at which a token is already
/// treated as expired, so it does not lapse in the middle of a request.
const TOKEN_EXPIRY_BUFFER_SECS: i64 = 30;
/// Token material persisted per environment (serialized as JSON) and held in
/// the provider's in-memory cache.
///
/// `Zeroize`/`ZeroizeOnDrop` wipe the secret fields when a value is dropped, and
/// `Debug` is implemented manually with redaction so tokens are never printed.
#[derive(Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
struct StoredToken {
/// Bearer access token (secret).
access_token: String,
/// Expiry as a Unix timestamp in seconds (compared with `Utc::now().timestamp()`).
expires_at: i64,
/// Refresh token, when the authorization server issued one (secret).
refresh_token: Option<String>,
/// Scopes recorded as granted for this token. Blobs that lack this field
/// deserialize to an empty list; the list is not secret, so zeroize skips it.
#[serde(default)]
#[zeroize(skip)]
scopes: Vec<String>,
}
/// Redacting `Debug`: shows expiry and scopes, but never the token strings.
/// The refresh token is only reported as present or absent.
impl std::fmt::Debug for StoredToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let refresh_marker: &str = match self.refresh_token {
            Some(_) => "Some([redacted])",
            None => "None",
        };
        let mut out = f.debug_struct("StoredToken");
        out.field("access_token", &"[redacted]");
        out.field("expires_at", &self.expires_at);
        out.field("refresh_token", &refresh_marker);
        out.field("scopes", &self.scopes);
        out.finish()
    }
}
impl StoredToken {
    /// Returns `true` while the access token is still usable.
    ///
    /// The token is treated as expired `TOKEN_EXPIRY_BUFFER_SECS` seconds early,
    /// so it does not lapse mid-request.
    fn is_valid(&self) -> bool {
        let now = Utc::now().timestamp();
        // `expires_at` comes from persisted storage or the token endpoint.
        // Saturate so an extreme value cannot overflow (which panics in debug builds).
        self.expires_at.saturating_sub(TOKEN_EXPIRY_BUFFER_SECS) > now
    }
}
/// OAuth 2.0 authorization-code + PKCE auth provider for interactive CLI logins.
///
/// Tokens are cached in memory per environment and persisted through a
/// lazily selected `CredentialStorage`. The derived `Debug` prints cached tokens
/// through `StoredToken`'s redacting `Debug` implementation.
#[derive(Debug)]
pub struct PkceAuthProvider {
/// Provider name; also the source of `env_prefix` and part of the credential key.
name: String,
/// Authorization endpoint (overridable via `{PREFIX}_OAUTH_AUTH_URL`).
auth_url: String,
/// Token endpoint (overridable via `{PREFIX}_OAUTH_TOKEN_URL`).
token_url: String,
/// OAuth client id (overridable via `{PREFIX}_OAUTH_CLIENT_ID`).
client_id: String,
/// Default scopes requested on every login.
scopes: Vec<String>,
/// Port for the default `http://127.0.0.1:{port}/callback` redirect URI.
redirect_port: u16,
/// Explicit redirect URI; when set it replaces the default built from `redirect_port`.
redirect_uri: Option<String>,
/// Application id used in credential keys and to choose the default storage.
app_id: String,
/// Upper-cased provider name with `-` replaced by `_`, used for env-var overrides.
env_prefix: String,
/// Injected storage backend; takes precedence over `store_mode`.
storage_override: Option<Arc<dyn CredentialStorage>>,
/// Explicit credential store mode, used when no override is injected.
store_mode: Option<CredentialStore>,
/// Storage backend resolved on first use.
storage: tokio::sync::OnceCell<Arc<dyn CredentialStorage>>,
/// Ordered JWT claim names consulted to derive the credential identity.
identity_claims: Vec<String>,
/// In-memory token cache keyed by environment name.
cache: Arc<RwLock<HashMap<String, StoredToken>>>,
}
/// JWT claims consulted, in priority order, to derive a human-readable identity
/// unless `with_identity_claims` replaces the list. The first non-empty string wins.
const DEFAULT_IDENTITY_CLAIMS: &[&str] = &[
    "email",
    "preferred_username",
    "username",
    "name",
    "sub",
];
impl PkceAuthProvider {
#[must_use]
pub fn new(
name: impl Into<String>,
auth_url: impl Into<String>,
token_url: impl Into<String>,
client_id: impl Into<String>,
scopes: &[impl AsRef<str>],
) -> Self {
let name = name.into();
let env_prefix = name.to_uppercase().replace('-', "_");
Self {
name,
auth_url: auth_url.into(),
token_url: token_url.into(),
client_id: client_id.into(),
scopes: scopes.iter().map(|s| s.as_ref().to_owned()).collect(),
redirect_port: REDIRECT_PORT_DEFAULT,
redirect_uri: None,
app_id: String::new(),
env_prefix,
storage_override: None,
store_mode: None,
storage: tokio::sync::OnceCell::new(),
identity_claims: DEFAULT_IDENTITY_CLAIMS
.iter()
.map(|claim| (*claim).to_owned())
.collect(),
cache: Arc::new(RwLock::new(HashMap::new())),
}
}
#[must_use]
pub fn with_redirect_port(mut self, port: u16) -> Self {
self.redirect_port = port;
self
}
#[must_use]
pub fn with_redirect_uri(mut self, uri: impl Into<String>) -> Self {
self.redirect_uri = Some(uri.into());
self
}
#[must_use]
pub fn with_app_id(mut self, app_id: impl Into<String>) -> Self {
self.app_id = app_id.into();
self
}
#[must_use]
pub fn with_extra_scopes(mut self, scopes: &[impl AsRef<str>]) -> Self {
self.scopes
.extend(scopes.iter().map(|s| s.as_ref().to_owned()));
self
}
#[must_use]
pub fn with_storage(mut self, storage: Arc<dyn CredentialStorage>) -> Self {
self.storage_override = Some(storage);
self
}
#[must_use]
pub fn with_credential_store(mut self, mode: CredentialStore) -> Self {
self.store_mode = Some(mode);
self
}
#[must_use]
#[deprecated(
since = "0.3.0",
note = "use with_credential_store(CredentialStore::Auto) or (CredentialStore::Keyring)"
)]
pub fn with_file_fallback(self, enabled: bool) -> Self {
self.with_credential_store(if enabled {
CredentialStore::Auto
} else {
CredentialStore::Keyring
})
}
#[must_use]
pub fn with_identity_claims(mut self, claims: &[impl AsRef<str>]) -> Self {
self.identity_claims = claims.iter().map(|c| c.as_ref().to_owned()).collect();
self
}
fn build_credential(&self, env: &str, token: &StoredToken) -> Credential {
let claims = decode_jwt_claims(&token.access_token);
let identity = claims
.as_ref()
.map(|claims| extract_identity(claims, &self.identity_claims))
.unwrap_or_default();
let sub = claims
.as_ref()
.and_then(|claims| claims.get("sub"))
.and_then(Value::as_str)
.unwrap_or_default()
.to_owned();
Credential {
token: token.access_token.clone(),
env: env.to_owned(),
provider: self.name.clone(),
expires_at: chrono::DateTime::from_timestamp(token.expires_at, 0)
.map(|dt| dt.to_rfc3339_opts(SecondsFormat::Secs, true))
.unwrap_or_default(),
identity,
sub,
..Credential::default()
}
}
fn effective_client_id(&self) -> String {
let key = format!("{}_OAUTH_CLIENT_ID", self.env_prefix);
std::env::var(&key).unwrap_or_else(|_| self.client_id.clone())
}
fn effective_auth_url(&self) -> String {
let key = format!("{}_OAUTH_AUTH_URL", self.env_prefix);
std::env::var(&key).unwrap_or_else(|_| self.auth_url.clone())
}
fn effective_token_url(&self) -> String {
let key = format!("{}_OAUTH_TOKEN_URL", self.env_prefix);
std::env::var(&key).unwrap_or_else(|_| self.token_url.clone())
}
fn effective_redirect_uri(&self) -> String {
self.redirect_uri
.clone()
.unwrap_or_else(|| format!("http://127.0.0.1:{}/callback", self.redirect_port))
}
fn parse_redirect_uri(&self) -> Result<(u16, String)> {
let uri_str = self.effective_redirect_uri();
let parsed = url::Url::parse(&uri_str)
.map_err(|e| CliCoreError::message(format!("invalid redirect URI '{uri_str}': {e}")))?;
let port = parsed
.port()
.or_else(|| parsed.port_or_known_default())
.ok_or_else(|| {
CliCoreError::message(format!("redirect URI '{uri_str}' has no port"))
})?;
let path = parsed.path().to_owned();
Ok((port, path))
}
fn credential_key<'key>(&'key self, env: &'key str) -> CredentialKey<'key> {
CredentialKey::new(&self.app_id, &self.name, env)
}
async fn storage(&self) -> &Arc<dyn CredentialStorage> {
self.storage
.get_or_init(async || {
if let Some(storage) = &self.storage_override {
storage.clone()
} else if let Some(mode) = self.store_mode {
storage_for(mode)
} else {
default_storage(&self.app_id)
}
})
.await
}
async fn load_stored(&self, env: &str) -> Option<StoredToken> {
let key = self.credential_key(env);
let raw = self.storage().await.load(&key).await?;
match serde_json::from_str::<StoredToken>(&raw) {
Ok(token) => Some(token),
Err(e) => {
tracing::warn!(env, error = %e, "stored token JSON invalid; clearing");
self.storage().await.delete(&key).await;
None
}
}
}
async fn save_stored(&self, env: &str, token: &StoredToken) -> Result<()> {
let json = serde_json::to_string(token).map_err(CliCoreError::from)?;
let key = self.credential_key(env);
self.storage().await.save(&key, &json).await
}
async fn delete_stored(&self, env: &str) {
let key = self.credential_key(env);
self.storage().await.delete(&key).await;
}
async fn cached_token(&self, env: &str) -> Option<StoredToken> {
let cache = self.cache.read().await;
cache.get(env).filter(|t| t.is_valid()).cloned()
}
async fn store_cached_token(&self, env: &str, token: StoredToken) {
let mut cache = self.cache.write().await;
cache.insert(env.to_owned(), token);
}
async fn resolve_token(&self, env: &str) -> Result<StoredToken> {
if let Some(token) = self.existing_token(env).await? {
return Ok(token);
}
self.reauthenticate(env, &self.scopes).await
}
async fn existing_token(&self, env: &str) -> Result<Option<StoredToken>> {
if let Some(token) = self.cached_token(env).await {
return Ok(Some(token));
}
if let Some(token) = self.load_stored(env).await {
if token.is_valid() {
self.store_cached_token(env, token.clone()).await;
return Ok(Some(token));
}
if let Some(refresh_token) = token.refresh_token.as_deref()
&& let Ok(mut refreshed) = self
.refresh_access_token(refresh_token, &token.scopes)
.await
{
if refreshed.refresh_token.is_none() {
refreshed.refresh_token = Some(refresh_token.to_owned());
}
self.save_stored(env, &refreshed).await?;
self.store_cached_token(env, refreshed.clone()).await;
return Ok(Some(refreshed));
}
}
Ok(None)
}
async fn reauthenticate(&self, env: &str, scopes: &[String]) -> Result<StoredToken> {
let token = self.run_pkce_flow_with(scopes).await?;
self.save_stored(env, &token).await?;
self.store_cached_token(env, token.clone()).await;
Ok(token)
}
async fn run_pkce_flow_with(&self, scopes: &[String]) -> Result<StoredToken> {
let (code_verifier, code_challenge) = pkce_challenge();
let state = random_state();
let client_id = self.effective_client_id();
let auth_url = self.effective_auth_url();
let redirect_uri = self.effective_redirect_uri();
let scope = scopes.join(" ");
let auth_params = [
("response_type", "code"),
("client_id", &client_id),
("redirect_uri", &redirect_uri),
("scope", &scope),
("state", &state),
("code_challenge", &code_challenge),
("code_challenge_method", "S256"),
];
let url = url::Url::parse_with_params(&auth_url, &auth_params)
.map_err(|err| CliCoreError::message(format!("invalid auth URL: {err}")))?;
let (bind_port, callback_path) = self.parse_redirect_uri()?;
let listener =
TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], bind_port))).map_err(|err| {
CliCoreError::message(format!(
"failed to bind callback server on port {bind_port}: {err}"
))
})?;
tracing::info!("Opening browser for authentication…");
tracing::info!("If the browser does not open, visit:\n {url}");
drop(open::that(url.as_str()));
let code =
wait_for_callback(listener, &state, &callback_path, Duration::from_secs(120)).await?;
self.exchange_code_for_token(&code, &code_verifier, scopes)
.await
}
async fn exchange_code_for_token(
&self,
code: &str,
code_verifier: &str,
requested_scopes: &[String],
) -> Result<StoredToken> {
let redirect_uri = self.effective_redirect_uri();
let client_id = self.effective_client_id();
let token_url = self.effective_token_url();
let params = [
("grant_type", "authorization_code"),
("client_id", &client_id),
("redirect_uri", &redirect_uri),
("code", code),
("code_verifier", code_verifier),
];
let response = reqwest::Client::new()
.post(&token_url)
.form(¶ms)
.send()
.await
.map_err(|err| CliCoreError::message(format!("token request failed: {err}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(CliCoreError::message(format!(
"token endpoint returned {status}: {body}"
)));
}
parse_token_response(response, requested_scopes).await
}
async fn refresh_access_token(
&self,
refresh_token: &str,
prior_scopes: &[String],
) -> Result<StoredToken> {
let client_id = self.effective_client_id();
let token_url = self.effective_token_url();
let params = [
("grant_type", "refresh_token"),
("client_id", &client_id),
("refresh_token", refresh_token),
];
let response = reqwest::Client::new()
.post(&token_url)
.form(¶ms)
.send()
.await
.map_err(|err| CliCoreError::message(format!("token refresh failed: {err}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(CliCoreError::message(format!(
"refresh endpoint returned {status}: {body}"
)));
}
parse_token_response(response, prior_scopes).await
}
}
#[async_trait]
impl AuthProvider for PkceAuthProvider {
    /// The provider name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Credential for `env` with the default scopes. It uses a cached, stored or
    /// refreshed token when possible, and otherwise runs an interactive login.
    async fn get_credential(&self, env: &str, _command: &str, _tier: &str) -> Result<Credential> {
        self.resolve_token(env)
            .await
            .map(|token| self.build_credential(env, &token))
    }

    /// Scope-aware credential lookup.
    ///
    /// An existing token is used when it already grants `req.meta.scopes`. When
    /// scopes are missing, an interactive session re-authenticates with the union
    /// of default, granted and required scopes; a non-interactive session fails
    /// instead. Without any usable token, a login requests defaults + required
    /// scopes. After every login, the new token must actually grant the
    /// required scopes.
    async fn get_credential_for(&self, req: &CredentialRequest<'_>) -> Result<Credential> {
        let env = req.env;
        let required = &req.meta.scopes;
        let login_scopes = if let Some(current) = self.existing_token(env).await? {
            let granted = granted_scopes(&current);
            match plan_step_up(&self.scopes, &granted, required, session_is_interactive()) {
                StepUp::Covered => return Ok(self.build_credential(env, &current)),
                StepUp::MissingNonInteractive(missing) => {
                    return Err(missing_scope_error(env, &missing));
                }
                StepUp::Reauthenticate(union) => union,
            }
        } else {
            union_scopes(&self.scopes, &[], required)
        };
        let fresh = self.reauthenticate(env, &login_scopes).await?;
        ensure_granted(env, &fresh, required)?;
        Ok(self.build_credential(env, &fresh))
    }

    /// Reports the persisted credential for `env` without refreshing it or
    /// starting a login.
    async fn status(&self, env: &str) -> Result<Credential> {
        let token = self.load_stored(env).await.ok_or_else(|| {
            CliCoreError::message(format!("not logged in for environment {env:?}"))
        })?;
        Ok(self.build_credential(env, &token))
    }

    /// Forgets the token for `env` in both persistent storage and the memory cache.
    async fn logout(&self, env: &str) -> Result<()> {
        self.delete_stored(env).await;
        self.cache.write().await.remove(env);
        Ok(())
    }

    /// Environments that currently have an entry in the in-memory cache.
    async fn list_environments(&self) -> Result<Vec<String>> {
        let envs = self.cache.read().await.keys().cloned().collect();
        Ok(envs)
    }
}
/// Generates a PKCE pair `(code_verifier, code_challenge)`.
///
/// The verifier is 32 random bytes encoded as unpadded base64url (43 characters).
/// The challenge is the unpadded base64url SHA-256 of the verifier, as used by
/// the `S256` method.
fn pkce_challenge() -> (String, String) {
    let entropy: [u8; 32] = rand::rng().random();
    let code_verifier = URL_SAFE_NO_PAD.encode(entropy);
    let code_challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(code_verifier.as_bytes()));
    (code_verifier, code_challenge)
}
/// Random OAuth `state` value: 16 random bytes as unpadded base64url. It binds
/// the callback to the request this process started.
fn random_state() -> String {
    URL_SAFE_NO_PAD.encode(rand::rng().random::<[u8; 16]>())
}
async fn wait_for_callback(
listener: TcpListener,
expected_state: &str,
expected_path: &str,
timeout: Duration,
) -> Result<String> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
listener
.set_nonblocking(true)
.map_err(|err| CliCoreError::message(format!("callback server setup failed: {err}")))?;
let listener = tokio::net::TcpListener::from_std(listener)
.map_err(|err| CliCoreError::message(format!("callback server setup failed: {err}")))?;
let expected_state = expected_state.to_owned();
let expected_path = expected_path.to_owned();
let result = tokio::time::timeout(timeout, async move {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(conn) => conn,
Err(_) => {
tokio::time::sleep(Duration::from_millis(50)).await;
continue;
}
};
let mut buf = vec![0_u8; 4096];
let n = match stream.read(&mut buf).await {
Ok(0) | Err(_) => continue,
Ok(n) => n,
};
let request = String::from_utf8_lossy(&buf[..n]);
if extract_request_path(&request).as_deref() != Some(expected_path.as_str()) {
drop(
stream
.write_all(b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n")
.await,
);
continue;
}
let code = extract_query_param(&request, "code");
let state = extract_query_param(&request, "state");
let html_response = if state.as_deref() == Some(&expected_state) && code.is_some() {
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
<html><body>Authentication successful. You may close this window.</body></html>"
} else {
"HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
<html><body>Authentication failed. Please try again.</body></html>"
};
drop(stream.write_all(html_response.as_bytes()).await);
if state.as_deref() == Some(expected_state.as_str()) {
return code
.ok_or_else(|| CliCoreError::message("no authorization code in callback"));
}
}
})
.await;
match result {
Ok(inner) => inner,
Err(_) => Err(CliCoreError::message(
"timed out waiting for OAuth callback",
)),
}
}
/// Returns the path of the HTTP request target on the first line of `request`,
/// with any query string removed. Returns `None` when there is no request line
/// or no target.
fn extract_request_path(request: &str) -> Option<String> {
    let target = request.lines().next()?.split_whitespace().nth(1)?;
    let path = match target.find('?') {
        Some(query_start) => &target[..query_start],
        None => target,
    };
    Some(path.to_owned())
}
/// Returns the percent-decoded value of the first query parameter called `name`
/// in the request target on the first line of `request`. Pairs without `=` are
/// tolerated by the form-urlencoded parser.
fn extract_query_param(request: &str, name: &str) -> Option<String> {
    let request_line = request.lines().next()?;
    let (_, query) = request_line.split_whitespace().nth(1)?.split_once('?')?;
    url::form_urlencoded::parse(query.as_bytes())
        .find_map(|(key, value)| (key == name).then(|| value.into_owned()))
}
/// Subset of the OAuth 2.0 token-endpoint JSON response that this provider uses.
/// Unknown fields are ignored by serde.
///
/// `Debug` is implemented by hand rather than derived, so the access and refresh
/// tokens can never end up in logs or panic messages.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
    expires_in: Option<i64>,
    refresh_token: Option<String>,
    scope: Option<String>,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"[redacted]")
            .field("expires_in", &self.expires_in)
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[redacted]"),
            )
            .field("scope", &self.scope)
            .finish()
    }
}
/// Decodes the JSON claims from the payload segment of a JWT-shaped token.
///
/// The signature is NOT verified, so the result is informational only (identity
/// display, scope hints). Returns `None` for opaque tokens or for payloads that
/// are not base64url-encoded JSON objects.
fn decode_jwt_claims(token: &str) -> Option<Map<String, Value>> {
    let payload = token.split('.').nth(1)?;
    // RFC 7515 mandates unpadded base64url, but some issuers append `=` padding.
    // Strip it so the strict no-pad decoder still accepts such tokens.
    let bytes = URL_SAFE_NO_PAD.decode(payload.trim_end_matches('=')).ok()?;
    serde_json::from_slice(&bytes).ok()
}
/// Order-preserving union: `defaults` as given, followed by each scope from
/// `granted` and then `required` that is not already in the result.
fn union_scopes(defaults: &[String], granted: &[String], required: &[String]) -> Vec<String> {
    granted
        .iter()
        .chain(required)
        .fold(defaults.to_vec(), |mut acc, scope| {
            if !acc.iter().any(|existing| existing == scope) {
                acc.push(scope.clone());
            }
            acc
        })
}
fn scopes_from_jwt(token: &str) -> Vec<String> {
let Some(claims) = decode_jwt_claims(token) else {
return Vec::new();
};
for key in ["scope", "scp"] {
if let Some(value) = claims.get(key) {
let scopes = scopes_from_claim(value);
if !scopes.is_empty() {
return scopes;
}
}
}
Vec::new()
}
/// Flattens a scope claim into individual scope strings.
///
/// A string claim is split on whitespace. For an array claim, each string
/// element is split on whitespace and non-string elements are skipped. Any other
/// JSON type yields nothing.
fn scopes_from_claim(value: &Value) -> Vec<String> {
    let mut scopes = Vec::new();
    if let Some(text) = value.as_str() {
        scopes.extend(text.split_whitespace().map(str::to_owned));
    } else if let Some(items) = value.as_array() {
        for item in items.iter().filter_map(Value::as_str) {
            scopes.extend(item.split_whitespace().map(str::to_owned));
        }
    }
    scopes
}
fn granted_scopes(token: &StoredToken) -> Vec<String> {
let mut scopes = scopes_from_jwt(&token.access_token);
for scope in &token.scopes {
if !scopes.contains(scope) {
scopes.push(scope.clone());
}
}
scopes
}
/// Decision reached by comparing the scopes a command requires with those a token holds.
#[derive(Debug, PartialEq, Eq)]
enum StepUp {
    /// Every required scope is already granted; the existing token can be used.
    Covered,
    /// Scopes are missing and no terminal is attached, so no login is attempted.
    /// Carries only the missing scopes.
    MissingNonInteractive(Vec<String>),
    /// Scopes are missing and the session is interactive: log in again requesting
    /// this union of default, granted and required scopes.
    Reauthenticate(Vec<String>),
}
/// Decides how to satisfy `required` scopes given the `granted` ones.
///
/// - Nothing missing: `Covered`.
/// - Interactive session: `Reauthenticate` with the ordered union of defaults,
///   granted and required scopes.
/// - Otherwise: `MissingNonInteractive` listing only the missing scopes.
fn plan_step_up(
    defaults: &[String],
    granted: &[String],
    required: &[String],
    interactive: bool,
) -> StepUp {
    let missing: Vec<String> = required
        .iter()
        .filter(|scope| !granted.contains(*scope))
        .cloned()
        .collect();
    match (missing.is_empty(), interactive) {
        (true, _) => StepUp::Covered,
        (false, true) => StepUp::Reauthenticate(union_scopes(defaults, granted, required)),
        (false, false) => StepUp::MissingNonInteractive(missing),
    }
}
/// `true` when stdin, stdout or stderr is attached to a terminal. Any one of
/// them being a TTY is enough to consider an interactive browser login viable.
fn session_is_interactive() -> bool {
    let stdin_is_tty = std::io::stdin().is_terminal();
    stdin_is_tty || std::io::stdout().is_terminal() || std::io::stderr().is_terminal()
}
/// Verifies that a freshly issued `token` actually holds every `required` scope.
///
/// # Errors
/// Lists the scopes the authorization server did not grant for `env`.
fn ensure_granted(env: &str, token: &StoredToken, required: &[String]) -> Result<()> {
    let granted = granted_scopes(token);
    let missing: Vec<&str> = required
        .iter()
        .filter(|scope| !granted.contains(*scope))
        .map(String::as_str)
        .collect();
    if missing.is_empty() {
        return Ok(());
    }
    Err(CliCoreError::message(format!(
        "authorization server did not grant required scope(s) for {env:?}: {}",
        missing.join(", ")
    )))
}
/// Error for a non-interactive session whose token lacks `missing` scopes.
/// The message includes the exact `auth login` command that would grant them.
fn missing_scope_error(env: &str, missing: &[String]) -> CliCoreError {
    let display = missing.join(", ");
    let mut hint = String::new();
    for scope in missing {
        if !hint.is_empty() {
            hint.push(' ');
        }
        hint.push_str("--scope ");
        hint.push_str(scope);
    }
    CliCoreError::message(format!(
        "access token for {env:?} is missing required scope(s): {display}; \
        run `auth login --env {env} {hint}` in an interactive terminal"
    ))
}
fn extract_identity(claims: &Map<String, Value>, priority: &[String]) -> String {
priority
.iter()
.filter_map(|name| claims.get(name).and_then(Value::as_str))
.find(|value| !value.is_empty())
.unwrap_or_default()
.to_owned()
}
/// Converts a successful token-endpoint response into a `StoredToken`.
///
/// - A missing `expires_in` is treated as 3600 seconds.
/// - Granted scopes come from the response's `scope` field when it is present
///   and non-empty; otherwise `requested_scopes` is recorded.
///
/// # Errors
/// Fails when the body is not the expected JSON.
async fn parse_token_response(
    response: reqwest::Response,
    requested_scopes: &[String],
) -> Result<StoredToken> {
    let body: TokenResponse = response
        .json()
        .await
        .map_err(|err| CliCoreError::message(format!("failed to parse token response: {err}")))?;
    let expires_in = body.expires_in.unwrap_or(3600);
    // `expires_in` is server-controlled. Saturate so a huge value cannot overflow
    // (panic in debug builds, wrap to a past timestamp in release).
    let expires_at = Utc::now().timestamp().saturating_add(expires_in);
    let scopes = body
        .scope
        .as_deref()
        .map(|scope| {
            scope
                .split_whitespace()
                .map(str::to_owned)
                .collect::<Vec<_>>()
        })
        .filter(|scopes| !scopes.is_empty())
        .unwrap_or_else(|| requested_scopes.to_vec());
    Ok(StoredToken {
        access_token: body.access_token,
        expires_at,
        refresh_token: body.refresh_token,
        scopes,
    })
}
#[cfg(test)]
mod tests {
    //! Unit tests for the PKCE provider. They cover the in-memory cache, scope
    //! extraction and step-up planning, redirect and callback parsing, JWT claim
    //! decoding, and delegation to injected or file-based credential storage.
    use serde_json::json;

    use super::*;

    /// Provider with dummy endpoints and a single default `openid` scope.
    fn test_provider() -> PkceAuthProvider {
        PkceAuthProvider::new(
            "test",
            "https://example.com/auth",
            "https://example.com/token",
            "client-id",
            &["openid"],
        )
    }

    /// Token valid for an hour, with no refresh token and no recorded scopes.
    fn valid_token(access_token: &str) -> StoredToken {
        StoredToken {
            access_token: access_token.to_owned(),
            expires_at: Utc::now().timestamp() + 3600,
            refresh_token: None,
            scopes: Vec::new(),
        }
    }

    /// Token valid for an hour, carrying the given recorded scopes.
    fn token_with_scopes(access_token: &str, scopes: &[&str]) -> StoredToken {
        StoredToken {
            access_token: access_token.to_owned(),
            expires_at: Utc::now().timestamp() + 3600,
            refresh_token: None,
            scopes: scopes.iter().map(|s| (*s).to_owned()).collect(),
        }
    }

    /// Token that already falls inside the expiry buffer.
    fn expired_token() -> StoredToken {
        StoredToken {
            access_token: "old-token".to_owned(),
            expires_at: Utc::now().timestamp() - TOKEN_EXPIRY_BUFFER_SECS - 1,
            refresh_token: None,
            scopes: Vec::new(),
        }
    }

    /// A valid token stored in the cache is returned unchanged.
    #[tokio::test]
    async fn cache_stores_and_retrieves_valid_token() {
        let provider = test_provider();
        provider
            .store_cached_token("dev", valid_token("access-abc"))
            .await;
        let cached = provider
            .cached_token("dev")
            .await
            .expect("expected cached token to be present");
        assert_eq!(cached.access_token, "access-abc");
    }

    /// Expired cache entries are filtered out.
    #[tokio::test]
    async fn cached_token_ignores_expired_tokens() {
        let provider = test_provider();
        provider.store_cached_token("dev", expired_token()).await;
        assert!(
            provider.cached_token("dev").await.is_none(),
            "expired token should not be returned from cache"
        );
    }

    /// A space-separated `scope` string claim is split into scopes.
    #[test]
    fn scopes_from_jwt_parses_scope_claim() {
        let token = make_jwt(&json!({ "scope": "a b c" }));
        assert_eq!(scopes_from_jwt(&token), vec!["a", "b", "c"]);
    }

    /// `scp` claims and array-valued claims are handled, and an empty `scope`
    /// falls through to `scp`.
    #[test]
    fn scopes_from_jwt_parses_scp_and_array_claims() {
        let scp = make_jwt(&json!({ "scp": ["a", "b"] }));
        assert_eq!(scopes_from_jwt(&scp), vec!["a", "b"]);
        let array = make_jwt(&json!({ "scope": ["a", "b c"] }));
        assert_eq!(scopes_from_jwt(&array), vec!["a", "b", "c"]);
        let mixed = make_jwt(&json!({ "scope": "", "scp": ["x"] }));
        assert_eq!(scopes_from_jwt(&mixed), vec!["x"]);
    }

    /// Opaque tokens fall back to the recorded scopes.
    #[test]
    fn granted_scopes_uses_recorded_scopes_for_opaque_token() {
        let token = token_with_scopes("opaque-token", &["a", "b"]);
        assert_eq!(granted_scopes(&token), vec!["a", "b"]);
    }

    /// `ensure_granted` fails when a required scope is absent and passes when
    /// every required scope is present, whether in JWT claims or recorded scopes.
    #[test]
    fn ensure_granted_rejects_a_token_missing_required_scopes() {
        let required = vec!["a".to_owned(), "b".to_owned()];
        let jwt = valid_token(&make_jwt(&json!({ "scope": "a" })));
        let err = ensure_granted("dev", &jwt, &required).expect_err("b is not granted");
        assert!(
            err.to_string().contains("did not grant required scope(s)"),
            "{err}"
        );
        assert!(err.to_string().contains('b'), "{err}");
        let ok = valid_token(&make_jwt(&json!({ "scope": "a b" })));
        ensure_granted("dev", &ok, &required).expect("both granted");
        let opaque = token_with_scopes("opaque", &["a", "b"]);
        ensure_granted("dev", &opaque, &required).expect("recorded scopes granted");
    }

    /// Covers all three step-up outcomes.
    #[test]
    fn plan_step_up_covers_reauths_and_fails_non_interactive() {
        let defaults = vec!["base".to_owned()];
        let granted = vec!["base".to_owned(), "read".to_owned()];
        let read = vec!["read".to_owned()];
        let write = vec!["write".to_owned()];
        assert_eq!(
            plan_step_up(&defaults, &granted, &read, true),
            StepUp::Covered
        );
        assert_eq!(
            plan_step_up(&defaults, &granted, &write, true),
            StepUp::Reauthenticate(vec![
                "base".to_owned(),
                "read".to_owned(),
                "write".to_owned()
            ])
        );
        assert_eq!(
            plan_step_up(&defaults, &granted, &write, false),
            StepUp::MissingNonInteractive(vec!["write".to_owned()])
        );
    }

    /// Recorded scopes on an opaque cached token satisfy a scoped request.
    #[tokio::test]
    async fn get_credential_for_uses_recorded_scopes_for_opaque_token() {
        let provider = test_provider();
        provider
            .store_cached_token("dev", token_with_scopes("opaque-token", &["read", "write"]))
            .await;
        let meta = crate::middleware::CommandMeta {
            scopes: vec!["read".to_owned()],
            ..crate::middleware::CommandMeta::default()
        };
        let req = CredentialRequest::new("dev", "app:list", "read", &meta);
        let credential = provider
            .get_credential_for(&req)
            .await
            .expect("recorded scopes cover the requirement");
        assert_eq!(credential.token, "opaque-token");
    }

    /// Union keeps first-seen order and drops duplicates.
    #[test]
    fn union_scopes_dedupes_and_preserves_order() {
        let defaults = vec!["a".to_owned(), "b".to_owned()];
        let granted = vec!["b".to_owned(), "c".to_owned()];
        let required = vec!["c".to_owned(), "d".to_owned()];
        assert_eq!(
            union_scopes(&defaults, &granted, &required),
            vec!["a", "b", "c", "d"]
        );
    }

    /// Opaque tokens and JWTs without scope claims yield no scopes.
    #[test]
    fn scopes_from_jwt_empty_for_opaque_or_missing() {
        assert!(scopes_from_jwt("opaque-token").is_empty());
        let no_scope = make_jwt(&json!({ "sub": "user" }));
        assert!(scopes_from_jwt(&no_scope).is_empty());
    }

    /// A cached JWT whose scopes cover the request is used without a login.
    #[tokio::test]
    async fn get_credential_for_uses_cached_token_when_scopes_covered() {
        let provider = test_provider();
        let token = valid_token(&make_jwt(&json!({
            "scope": "apps.app-registry:read apps.app-registry:write",
            "sub": "user-1",
        })));
        provider.store_cached_token("dev", token).await;
        let mut meta = crate::middleware::CommandMeta::default();
        meta.set_scopes(vec!["apps.app-registry:read".to_owned()]);
        let req = CredentialRequest {
            env: "dev",
            command: "app:list",
            tier: "read",
            meta: &meta,
        };
        let credential = provider
            .get_credential_for(&req)
            .await
            .expect("cached token covers required scopes");
        assert_eq!(credential.sub, "user-1");
    }

    /// With no required scopes, any valid cached token is returned.
    #[tokio::test]
    async fn get_credential_for_no_scopes_returns_cached() {
        let provider = test_provider();
        provider
            .store_cached_token("dev", valid_token("opaque"))
            .await;
        let meta = crate::middleware::CommandMeta::default();
        let req = CredentialRequest {
            env: "dev",
            command: "app:list",
            tier: "read",
            meta: &meta,
        };
        let credential = provider
            .get_credential_for(&req)
            .await
            .expect("no scopes required");
        assert_eq!(credential.token, "opaque");
    }

    /// The default redirect URI is built from 127.0.0.1 and the configured port.
    #[test]
    fn redirect_uri_default_uses_127_0_0_1_and_redirect_port() {
        let provider = test_provider().with_redirect_port(9000);
        assert_eq!(
            provider.effective_redirect_uri(),
            "http://127.0.0.1:9000/callback"
        );
    }

    /// An explicit redirect URI is used verbatim.
    #[test]
    fn with_redirect_uri_overrides_default() {
        let provider = test_provider().with_redirect_uri("http://localhost:8080/auth/callback");
        assert_eq!(
            provider.effective_redirect_uri(),
            "http://localhost:8080/auth/callback"
        );
    }

    /// The default URI parses back into the configured port and `/callback`.
    #[test]
    fn parse_redirect_uri_extracts_port_and_path_from_default() {
        let provider = test_provider().with_redirect_port(9000);
        let (port, path) = provider.parse_redirect_uri().expect("valid URI");
        assert_eq!(port, 9000);
        assert_eq!(path, "/callback");
    }

    /// A custom URI supplies both the listener port and the callback path.
    #[test]
    fn parse_redirect_uri_extracts_port_and_path_from_custom_uri() {
        let provider = test_provider().with_redirect_uri("http://localhost:8080/auth/callback");
        let (port, path) = provider.parse_redirect_uri().expect("valid URI");
        assert_eq!(port, 8080);
        assert_eq!(path, "/auth/callback");
    }

    /// Only the port is taken from a custom URI. `parse_redirect_uri` returns
    /// no host, because the listener always binds 127.0.0.1.
    #[test]
    fn with_redirect_uri_does_not_affect_listener_host() {
        let provider = test_provider().with_redirect_uri("http://localhost:7777/callback");
        let (port, _) = provider.parse_redirect_uri().expect("valid URI");
        assert_eq!(port, 7777);
    }

    /// The query string is stripped from the request path.
    #[test]
    fn extract_request_path_strips_query_string() {
        assert_eq!(
            extract_request_path("GET /auth/callback?code=abc&state=xyz HTTP/1.1\r\n"),
            Some("/auth/callback".to_owned()),
        );
    }

    /// A target without a query string is returned as-is.
    #[test]
    fn extract_request_path_handles_no_query_string() {
        assert_eq!(
            extract_request_path("GET /callback HTTP/1.1\r\n"),
            Some("/callback".to_owned()),
        );
    }

    /// Query pairs without `=` do not prevent finding later parameters.
    #[test]
    fn extract_query_param_skips_malformed_pairs() {
        let request = "GET /callback?foo&code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n";
        assert_eq!(
            extract_query_param(request, "code"),
            Some("abc123".to_owned()),
        );
        assert_eq!(
            extract_query_param(request, "state"),
            Some("xyz".to_owned()),
        );
    }

    /// Values are percent-decoded.
    #[test]
    fn extract_query_param_decodes_percent_encoding() {
        let request = "GET /callback?code=a%20b%2Bc&state=ok HTTP/1.1\r\n";
        assert_eq!(
            extract_query_param(request, "code"),
            Some("a b+c".to_owned()),
        );
    }

    /// A cache hit short-circuits the interactive flow.
    #[tokio::test]
    async fn resolve_token_returns_cached_token_without_pkce_flow() {
        let provider = test_provider();
        provider
            .store_cached_token("dev", valid_token("cached-token"))
            .await;
        let resolved = provider
            .resolve_token("dev")
            .await
            .expect("resolve from cache");
        assert_eq!(resolved.access_token, "cached-token");
    }

    /// `list_environments` reports exactly the cached environment keys.
    #[tokio::test]
    async fn list_environments_returns_only_cached_keys() {
        let provider = test_provider();
        provider.store_cached_token("dev", valid_token("t1")).await;
        provider.store_cached_token("prod", valid_token("t2")).await;
        let mut envs = provider.list_environments().await.expect("list");
        envs.sort();
        assert_eq!(envs, ["dev", "prod"]);
    }

    /// A fresh provider reports no environments.
    #[tokio::test]
    async fn list_environments_returns_empty_without_cache() {
        let provider = test_provider();
        let envs = provider.list_environments().await.expect("list");
        assert!(envs.is_empty(), "expected empty list for a fresh provider");
    }

    /// Builds an unsigned JWT (`alg: none`) carrying `claims`.
    fn make_jwt(claims: &Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none","typ":"JWT"}"#);
        let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).expect("serialize claims"));
        format!("{header}.{payload}.signature")
    }

    /// Claims are decoded from the payload segment.
    #[test]
    fn decode_jwt_claims_extracts_payload() {
        let token = make_jwt(&json!({"email": "user@example.com", "sub": "abc123"}));
        let claims = decode_jwt_claims(&token).expect("claims decode");
        assert_eq!(
            claims.get("email").and_then(Value::as_str),
            Some("user@example.com")
        );
        assert_eq!(claims.get("sub").and_then(Value::as_str), Some("abc123"));
    }

    /// Opaque, non-JSON or undecodable payloads yield `None`.
    #[test]
    fn decode_jwt_claims_returns_none_for_non_jwt() {
        assert!(decode_jwt_claims("opaque-access-token").is_none());
        assert!(decode_jwt_claims("only.two").is_none());
        assert!(decode_jwt_claims("aaa.!!!.bbb").is_none());
    }

    /// Priority order is honored and empty values are skipped.
    #[test]
    fn extract_identity_honors_priority_and_skips_empty() {
        let priority: Vec<String> = DEFAULT_IDENTITY_CLAIMS
            .iter()
            .map(|c| (*c).to_owned())
            .collect();
        let claims = serde_json::from_value(json!({
            "email": "",
            "preferred_username": "jdoe",
            "name": "Jane Doe",
        }))
        .expect("claims map");
        assert_eq!(extract_identity(&claims, &priority), "jdoe");
        let empty = serde_json::from_value(json!({"unrelated": "x"})).expect("claims map");
        assert_eq!(extract_identity(&empty, &priority), "");
    }

    /// Identity, subject, environment and provider are filled from a JWT.
    #[test]
    fn build_credential_populates_identity_and_sub() {
        let provider = test_provider();
        let token = valid_token(&make_jwt(&json!({
            "email": "user@example.com",
            "sub": "subject-1",
        })));
        let credential = provider.build_credential("prod", &token);
        assert_eq!(credential.identity, "user@example.com");
        assert_eq!(credential.sub, "subject-1");
        assert_eq!(credential.env, "prod");
        assert_eq!(credential.provider, "test");
    }

    /// Opaque tokens leave identity and subject blank.
    #[test]
    fn build_credential_leaves_identity_blank_for_opaque_token() {
        let provider = test_provider();
        let token = valid_token("opaque-token");
        let credential = provider.build_credential("prod", &token);
        assert_eq!(credential.identity, "");
        assert_eq!(credential.sub, "");
    }

    /// Custom identity claims replace the defaults entirely.
    #[test]
    fn with_identity_claims_overrides_selection() {
        let provider = test_provider().with_identity_claims(&["custom_user"]);
        let token = valid_token(&make_jwt(&json!({
            "email": "ignored@example.com",
            "custom_user": "picked",
        })));
        let credential = provider.build_credential("prod", &token);
        assert_eq!(credential.identity, "picked");
    }

    /// In-memory `CredentialStorage`, keyed by `app_id/provider/env`.
    #[derive(Debug, Default)]
    struct MemoryStorage {
        entries: std::sync::Mutex<HashMap<String, String>>,
    }

    impl MemoryStorage {
        fn entry_key(key: &CredentialKey<'_>) -> String {
            format!("{}/{}/{}", key.app_id, key.provider, key.env)
        }
    }

    #[async_trait]
    impl CredentialStorage for MemoryStorage {
        async fn load(&self, key: &CredentialKey<'_>) -> Option<String> {
            self.entries
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .get(&Self::entry_key(key))
                .cloned()
        }

        async fn save(&self, key: &CredentialKey<'_>, value: &str) -> Result<()> {
            self.entries
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .insert(Self::entry_key(key), value.to_owned());
            Ok(())
        }

        async fn delete(&self, key: &CredentialKey<'_>) {
            self.entries
                .lock()
                .unwrap_or_else(|e| e.into_inner())
                .remove(&Self::entry_key(key));
        }
    }

    /// The deprecated boolean toggle maps onto the store modes.
    #[test]
    #[allow(deprecated)]
    fn with_file_fallback_maps_to_store_modes() {
        assert_eq!(
            test_provider().with_file_fallback(true).store_mode,
            Some(CredentialStore::Auto)
        );
        assert_eq!(
            test_provider().with_file_fallback(false).store_mode,
            Some(CredentialStore::Keyring)
        );
    }

    /// Builders record the chosen store mode and the injected backend.
    #[test]
    fn builders_record_storage_selection() {
        assert_eq!(
            test_provider()
                .with_credential_store(CredentialStore::File)
                .store_mode,
            Some(CredentialStore::File)
        );
        let provider = test_provider().with_storage(Arc::new(MemoryStorage::default()));
        assert!(provider.storage_override.is_some());
    }

    /// Save, status and logout all go through the injected backend.
    #[tokio::test]
    async fn provider_delegates_to_injected_storage() {
        let mem = Arc::new(MemoryStorage::default());
        let provider = test_provider().with_app_id("app").with_storage(mem.clone());
        assert!(provider.status("dev").await.is_err());
        provider
            .save_stored("dev", &valid_token("tok"))
            .await
            .expect("save");
        let key = CredentialKey::new("app", "test", "dev");
        assert!(mem.load(&key).await.is_some(), "token reached the store");
        let cred = provider.status("dev").await.expect("status");
        assert_eq!(cred.token, "tok");
        provider.logout("dev").await.expect("logout");
        assert!(mem.load(&key).await.is_none(), "token removed on logout");
    }

    /// A stored blob that is not valid JSON is deleted when it is first read.
    #[tokio::test]
    async fn corrupt_stored_blob_self_heals() {
        let mem = Arc::new(MemoryStorage::default());
        let key = CredentialKey::new("app", "test", "dev");
        mem.save(&key, "not-valid-json").await.expect("seed");
        let provider = test_provider().with_app_id("app").with_storage(mem.clone());
        assert!(provider.load_stored("dev").await.is_none());
        assert!(
            mem.load(&key).await.is_none(),
            "corrupt blob should be deleted (self-heal)"
        );
    }

    /// The file store round-trips a token under a temporary XDG_CONFIG_HOME.
    #[tokio::test]
    #[allow(
        clippy::await_holding_lock,
        reason = "the env lock must stay held across awaits so no other test changes \
                  XDG_CONFIG_HOME while this test reads and writes the file store"
    )]
    async fn file_store_round_trips_without_keyring() {
        let dir = tempfile::tempdir().expect("tempdir");
        // Named bindings (not `_`) keep both guards alive until the end of the test.
        let _lock = crate::config::test_env::lock();
        let _env = crate::config::test_env::EnvVarGuard::set("XDG_CONFIG_HOME", Some(dir.path()));
        let provider = test_provider()
            .with_app_id("app")
            .with_credential_store(CredentialStore::File);
        assert!(provider.status("dev").await.is_err());
        provider
            .save_stored("dev", &valid_token("filetok"))
            .await
            .expect("save");
        let cred = provider.status("dev").await.expect("status");
        assert_eq!(cred.token, "filetok");
    }
}