use std::collections::HashMap;
use serde::Serialize;
/// Authentication strategy selected from environment variables by [`detect_from`].
///
/// Serializes to the same `snake_case` identifier returned by
/// [`AuthStrategy::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthStrategy {
    /// `CLAUDE_CODE_USE_BEDROCK` is truthy (highest precedence).
    Bedrock,
    /// `CLAUDE_CODE_USE_VERTEX` is truthy and Bedrock is not enabled.
    Vertex,
    /// `ANTHROPIC_API_KEY` is non-blank and no cloud provider is enabled.
    ApiKey,
    /// `CLAUDE_CODE_OAUTH_TOKEN` is non-blank and nothing of higher precedence applies.
    OauthToken,
    /// Fallback when none of the environment signals above is present.
    Subscription,
}

impl AuthStrategy {
    /// Returns the stable `snake_case` identifier for this strategy.
    ///
    /// Kept in sync with the serde representation produced by
    /// `#[serde(rename_all = "snake_case")]`. `const` so it can also be used
    /// when building compile-time tables or constants.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Bedrock => "bedrock",
            Self::Vertex => "vertex",
            Self::ApiKey => "api_key",
            Self::OauthToken => "oauth_token",
            Self::Subscription => "subscription",
        }
    }
}
/// Result of [`detect_from`]: the chosen [`AuthStrategy`] together with every
/// environment signal it was derived from.
///
/// All flags are reported regardless of which strategy won the precedence
/// order, so callers can see e.g. that an API key is present even when
/// Bedrock is selected.
#[derive(Debug, Clone, Serialize)]
pub struct AuthSummary {
/// Strategy chosen by precedence: Bedrock > Vertex > API key > OAuth token > subscription.
pub strategy: AuthStrategy,
/// `ANTHROPIC_API_KEY` is present and not blank (whitespace-only counts as unset).
pub has_anthropic_api_key: bool,
/// `CLAUDE_CODE_OAUTH_TOKEN` is present and not blank (whitespace-only counts as unset).
pub has_oauth_token: bool,
/// `CLAUDE_CODE_USE_BEDROCK` holds a truthy value (anything non-blank except 0/false/no/off).
pub bedrock_enabled: bool,
/// `CLAUDE_CODE_USE_VERTEX` holds a truthy value (anything non-blank except 0/false/no/off).
pub vertex_enabled: bool,
}
/// Category of an authentication-related failure, as produced by
/// [`classify_failure`].
///
/// Serializes to the same `snake_case` identifier returned by
/// [`AuthErrorKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthErrorKind {
    /// Output indicates no credentials are configured
    /// (e.g. "not authenticated", "claude login", "no credentials").
    NotAuthenticated,
    /// Output mentions that a session or token has expired.
    Expired,
    /// Output indicates rejected credentials
    /// (e.g. "invalid api key", "401 Unauthorized", "403 Forbidden").
    InvalidCredentials,
    /// Output indicates throttling ("rate limit", "429", "quota", ...).
    RateLimit,
    /// Output mentions Bedrock or Vertex together with an auth signal.
    ProviderError,
    /// stderr mentions auth or credentials but no more specific kind matched.
    Other,
}

impl AuthErrorKind {
    /// Returns the stable `snake_case` identifier for this error kind.
    ///
    /// Kept in sync with the serde representation produced by
    /// `#[serde(rename_all = "snake_case")]`. `const` so it can also be used
    /// in compile-time contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::NotAuthenticated => "not_authenticated",
            Self::Expired => "expired",
            Self::InvalidCredentials => "invalid_credentials",
            Self::RateLimit => "rate_limit",
            Self::ProviderError => "provider_error",
            Self::Other => "other",
        }
    }
}
/// Classifies a failed command's output into an [`AuthErrorKind`] when it
/// looks authentication-related; returns `None` otherwise.
///
/// Matching is ASCII-case-insensitive over `stdout` and `stderr` joined
/// together (except for the final catch-all). Checks run in this fixed order
/// and the first match wins:
///
/// 1. [`AuthErrorKind::ProviderError`]: mentions `bedrock` or `vertex`
///    together with an auth signal (`auth`, `credential`, `forbidden`,
///    `unauthorized`, status `401`/`403`).
/// 2. [`AuthErrorKind::RateLimit`]: `rate limit`, `too many requests`,
///    `quota`, or status `429`.
/// 3. [`AuthErrorKind::Expired`]: `expired`.
/// 4. [`AuthErrorKind::InvalidCredentials`]: `invalid api key`,
///    `invalid token`, `unauthorized`, `forbidden`, or status `401`/`403`.
/// 5. [`AuthErrorKind::NotAuthenticated`]: `not authenticated`,
///    `claude login`, `no credentials`, `no auth`.
/// 6. [`AuthErrorKind::Other`]: `auth` or `credential` in **stderr only**, so
///    benign stdout mentions (e.g. an `auth_helper` banner) do not trigger it.
///
/// HTTP status codes only match as standalone numbers. Digits embedded in
/// larger numbers (`4012ms`, `request_id=14290`) are not mistaken for
/// `401`/`429`.
///
/// `_exit_code` is currently unused; classification relies only on the text.
pub fn classify_failure(_exit_code: i32, stdout: &str, stderr: &str) -> Option<AuthErrorKind> {
    let combined = format!("{stdout}\n{stderr}").to_ascii_lowercase();

    let mentions_provider = contains_any_phrase(&combined, &["bedrock", "vertex"]);
    let mentions_auth_signal =
        contains_any_phrase(&combined, &["auth", "credential", "forbidden", "unauthorized"])
            || contains_status_code(&combined, &["401", "403"]);
    if mentions_provider && mentions_auth_signal {
        return Some(AuthErrorKind::ProviderError);
    }

    if contains_any_phrase(&combined, &["rate limit", "too many requests", "quota"])
        || contains_status_code(&combined, &["429"])
    {
        return Some(AuthErrorKind::RateLimit);
    }

    // "expired" subsumes more specific phrasings such as "session has expired"
    // and "token expired".
    if combined.contains("expired") {
        return Some(AuthErrorKind::Expired);
    }

    if contains_any_phrase(
        &combined,
        &["invalid api key", "invalid token", "unauthorized", "forbidden"],
    ) || contains_status_code(&combined, &["401", "403"])
    {
        return Some(AuthErrorKind::InvalidCredentials);
    }

    if contains_any_phrase(
        &combined,
        &["not authenticated", "claude login", "no credentials", "no auth"],
    ) {
        return Some(AuthErrorKind::NotAuthenticated);
    }

    // Catch-all looks at stderr alone; lowercase it once for both probes.
    let stderr_lower = stderr.to_ascii_lowercase();
    if stderr_lower.contains("auth") || stderr_lower.contains("credential") {
        return Some(AuthErrorKind::Other);
    }

    None
}

/// Returns `true` if `haystack` contains any of `needles` as a substring.
fn contains_any_phrase(haystack: &str, needles: &[&str]) -> bool {
    needles.iter().any(|needle| haystack.contains(needle))
}

/// Returns `true` if `haystack` contains any of `codes` as a standalone number,
/// i.e. not immediately preceded or followed by another ASCII digit.
///
/// `codes` must be all-digit strings. For those, a match rejected for a digit
/// neighbour cannot hide an acceptable overlapping match (the overlap would
/// itself have a digit neighbour), so the non-overlapping `match_indices` scan
/// is sufficient.
fn contains_status_code(haystack: &str, codes: &[&str]) -> bool {
    let bytes = haystack.as_bytes();
    let is_digit_at =
        |idx: Option<usize>| idx.and_then(|i| bytes.get(i)).is_some_and(u8::is_ascii_digit);
    codes.iter().any(|code| {
        haystack.match_indices(code).any(|(start, matched)| {
            !is_digit_at(start.checked_sub(1)) && !is_digit_at(Some(start + matched.len()))
        })
    })
}
/// Detects the authentication strategy from the current process environment.
///
/// Thin wrapper over [`detect_from`]. The environment is read with
/// [`std::env::vars_os`] and converted lossily. [`std::env::vars`] would
/// panic if *any* variable, even one unrelated to auth, contained non-UTF-8
/// data. Invalid sequences become U+FFFD, so a non-UTF-8 value for one of the
/// inspected variables is still treated as non-blank.
pub fn detect() -> AuthSummary {
    let env: HashMap<String, String> = std::env::vars_os()
        .map(|(key, value)| {
            (
                key.to_string_lossy().into_owned(),
                value.to_string_lossy().into_owned(),
            )
        })
        .collect();
    detect_from(&env)
}
/// Determines the active [`AuthStrategy`] from an explicit environment map.
///
/// This is the pure core of `detect`. Because it takes the environment as a
/// map, it can be exercised without touching the process environment.
///
/// Precedence (first match wins):
/// 1. `CLAUDE_CODE_USE_BEDROCK` truthy gives [`AuthStrategy::Bedrock`].
/// 2. `CLAUDE_CODE_USE_VERTEX` truthy gives [`AuthStrategy::Vertex`].
/// 3. `ANTHROPIC_API_KEY` non-blank gives [`AuthStrategy::ApiKey`].
/// 4. `CLAUDE_CODE_OAUTH_TOKEN` non-blank gives [`AuthStrategy::OauthToken`].
/// 5. Otherwise [`AuthStrategy::Subscription`].
///
/// Every flag is reported in the returned summary, not only the winning one.
pub fn detect_from(env: &HashMap<String, String>) -> AuthSummary {
    let lookup = |name: &str| env.get(name).map(String::as_str);

    let bedrock_enabled = is_truthy(lookup("CLAUDE_CODE_USE_BEDROCK"));
    let vertex_enabled = is_truthy(lookup("CLAUDE_CODE_USE_VERTEX"));
    let has_anthropic_api_key = is_set(lookup("ANTHROPIC_API_KEY"));
    let has_oauth_token = is_set(lookup("CLAUDE_CODE_OAUTH_TOKEN"));

    // Exhaustive match: each arm spells out which higher-precedence flags must be off.
    let strategy = match (
        bedrock_enabled,
        vertex_enabled,
        has_anthropic_api_key,
        has_oauth_token,
    ) {
        (true, _, _, _) => AuthStrategy::Bedrock,
        (false, true, _, _) => AuthStrategy::Vertex,
        (false, false, true, _) => AuthStrategy::ApiKey,
        (false, false, false, true) => AuthStrategy::OauthToken,
        (false, false, false, false) => AuthStrategy::Subscription,
    };

    AuthSummary {
        strategy,
        has_anthropic_api_key,
        has_oauth_token,
        bedrock_enabled,
        vertex_enabled,
    }
}
/// Returns `true` when `value` is present and contains at least one
/// non-whitespace character; blank or whitespace-only values count as unset.
fn is_set(value: Option<&str>) -> bool {
    matches!(value, Some(raw) if !raw.trim().is_empty())
}
/// Interprets an optional environment-variable value as a boolean flag.
///
/// Absent or blank values are `false`. After trimming, `0`, `false`, `no` and
/// `off` (ASCII case-insensitive) are `false`. Every other non-blank value is
/// `true`, including unrecognised ones such as `"anything"`.
fn is_truthy(value: Option<&str>) -> bool {
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];
    match value.map(str::trim) {
        None | Some("") => false,
        Some(flag) => !FALSY.iter().any(|falsy| flag.eq_ignore_ascii_case(falsy)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned environment map from borrowed key/value pairs.
    fn env(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
            .collect()
    }

    // ---- detect_from -------------------------------------------------------

    #[test]
    fn empty_env_is_subscription() {
        let s = detect_from(&env(&[]));
        assert_eq!(s.strategy, AuthStrategy::Subscription);
        assert!(!s.has_anthropic_api_key);
        assert!(!s.has_oauth_token);
        assert!(!s.bedrock_enabled);
        assert!(!s.vertex_enabled);
    }

    #[test]
    fn api_key_takes_precedence_over_oauth_token() {
        let s = detect_from(&env(&[
            ("ANTHROPIC_API_KEY", "sk-abc"),
            ("CLAUDE_CODE_OAUTH_TOKEN", "tok-xyz"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::ApiKey);
        assert!(s.has_anthropic_api_key);
        assert!(s.has_oauth_token);
    }

    #[test]
    fn oauth_token_alone_picks_oauth() {
        let s = detect_from(&env(&[("CLAUDE_CODE_OAUTH_TOKEN", "tok-xyz")]));
        assert_eq!(s.strategy, AuthStrategy::OauthToken);
        assert!(!s.has_anthropic_api_key);
        assert!(s.has_oauth_token);
    }

    #[test]
    fn bedrock_overrides_api_key() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_BEDROCK", "1"),
            ("ANTHROPIC_API_KEY", "sk-abc"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Bedrock);
        assert!(s.bedrock_enabled);
        assert!(s.has_anthropic_api_key);
    }

    #[test]
    fn vertex_overrides_oauth_token() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_VERTEX", "true"),
            ("CLAUDE_CODE_OAUTH_TOKEN", "tok-xyz"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Vertex);
        assert!(s.vertex_enabled);
    }

    #[test]
    fn bedrock_takes_precedence_over_vertex_when_both_set() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_BEDROCK", "1"),
            ("CLAUDE_CODE_USE_VERTEX", "1"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Bedrock);
        assert!(s.bedrock_enabled);
        assert!(s.vertex_enabled);
    }

    #[test]
    fn empty_string_does_not_count_as_set() {
        let s = detect_from(&env(&[
            ("ANTHROPIC_API_KEY", ""),
            ("CLAUDE_CODE_OAUTH_TOKEN", " "),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Subscription);
    }

    #[test]
    fn explicit_falsy_disables_provider_flag() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_BEDROCK", "0"),
            ("CLAUDE_CODE_USE_VERTEX", "false"),
            ("ANTHROPIC_API_KEY", "sk-abc"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::ApiKey);
        assert!(!s.bedrock_enabled);
        assert!(!s.vertex_enabled);
    }

    #[test]
    fn truthy_values_recognized() {
        for v in ["1", "true", "TRUE", "yes", "on", "anything"] {
            let s = detect_from(&env(&[("CLAUDE_CODE_USE_BEDROCK", v)]));
            assert_eq!(s.strategy, AuthStrategy::Bedrock, "value {v:?}");
        }
    }

    #[test]
    fn falsy_values_recognized() {
        for v in ["0", "false", "FALSE", "no", "off"] {
            let s = detect_from(&env(&[("CLAUDE_CODE_USE_BEDROCK", v)]));
            assert_eq!(s.strategy, AuthStrategy::Subscription, "value {v:?}");
            assert!(!s.bedrock_enabled, "value {v:?}");
        }
    }

    #[test]
    fn flag_values_are_trimmed_before_interpretation() {
        let s = detect_from(&env(&[("CLAUDE_CODE_USE_BEDROCK", " 1 ")]));
        assert_eq!(s.strategy, AuthStrategy::Bedrock);
        for v in ["", "   ", " off ", "\tNO\n"] {
            let s = detect_from(&env(&[("CLAUDE_CODE_USE_BEDROCK", v)]));
            assert!(!s.bedrock_enabled, "value {v:?}");
            assert_eq!(s.strategy, AuthStrategy::Subscription, "value {v:?}");
        }
    }

    // ---- classify_failure --------------------------------------------------

    #[test]
    fn classify_returns_none_for_unrelated_failure() {
        assert_eq!(classify_failure(1, "no match found", ""), None);
        assert_eq!(
            classify_failure(2, "", "syntax error near unexpected token"),
            None
        );
    }

    #[test]
    fn classify_not_authenticated_from_stderr_hint() {
        assert_eq!(
            classify_failure(1, "", "Not authenticated. Run `claude login` to sign in."),
            Some(AuthErrorKind::NotAuthenticated)
        );
        assert_eq!(
            classify_failure(1, "", "no credentials configured"),
            Some(AuthErrorKind::NotAuthenticated)
        );
    }

    #[test]
    fn classify_expired_session() {
        assert_eq!(
            classify_failure(1, "", "Your session has expired. Please log in again."),
            Some(AuthErrorKind::Expired)
        );
        assert_eq!(
            classify_failure(1, "", "token expired at 2025-01-01T00:00:00Z"),
            Some(AuthErrorKind::Expired)
        );
    }

    #[test]
    fn classify_expired_takes_precedence_over_invalid_creds() {
        assert_eq!(
            classify_failure(1, "", "Token expired (HTTP 401 Unauthorized)"),
            Some(AuthErrorKind::Expired)
        );
    }

    #[test]
    fn classify_invalid_api_key() {
        assert_eq!(
            classify_failure(1, "", "Invalid API key. Check ANTHROPIC_API_KEY."),
            Some(AuthErrorKind::InvalidCredentials)
        );
        assert_eq!(
            classify_failure(1, "", "HTTP 401 Unauthorized"),
            Some(AuthErrorKind::InvalidCredentials)
        );
        assert_eq!(
            classify_failure(1, "", "403 Forbidden"),
            Some(AuthErrorKind::InvalidCredentials)
        );
    }

    #[test]
    fn classify_rate_limit_takes_precedence_over_invalid_creds() {
        assert_eq!(
            classify_failure(1, "", "Rate limit exceeded. Please wait."),
            Some(AuthErrorKind::RateLimit)
        );
        assert_eq!(
            classify_failure(1, "", "HTTP 429 Too Many Requests"),
            Some(AuthErrorKind::RateLimit)
        );
        assert_eq!(
            classify_failure(1, "", "quota exceeded for this account"),
            Some(AuthErrorKind::RateLimit)
        );
        // Both an invalid-credentials signal and a rate-limit signal present:
        // rate limit must win.
        assert_eq!(
            classify_failure(1, "", "401 Unauthorized, then 429 Too Many Requests"),
            Some(AuthErrorKind::RateLimit)
        );
    }

    #[test]
    fn classify_provider_error_when_bedrock_plus_auth_signal() {
        assert_eq!(
            classify_failure(
                1,
                "",
                "Bedrock auth failed: AWS credentials not found in chain"
            ),
            Some(AuthErrorKind::ProviderError)
        );
        assert_eq!(
            classify_failure(
                1,
                "",
                "Vertex unauthorized -- check GOOGLE_APPLICATION_CREDENTIALS"
            ),
            Some(AuthErrorKind::ProviderError)
        );
    }

    #[test]
    fn classify_provider_error_takes_precedence_over_rate_limit() {
        assert_eq!(
            classify_failure(1, "", "Bedrock 403 Forbidden: quota exceeded"),
            Some(AuthErrorKind::ProviderError)
        );
    }

    #[test]
    fn classify_provider_mention_without_auth_signal_is_not_provider_error() {
        assert_eq!(
            classify_failure(1, "", "Bedrock request throttled: 429 Too Many Requests"),
            Some(AuthErrorKind::RateLimit)
        );
    }

    #[test]
    fn classify_falls_back_to_other_for_bare_auth_mention() {
        assert_eq!(
            classify_failure(1, "", "auth subsystem returned an unexpected error"),
            Some(AuthErrorKind::Other)
        );
    }

    #[test]
    fn classify_does_not_match_auth_in_stdout_only() {
        assert_eq!(
            classify_failure(0, "auth_helper enabled, all clear", ""),
            None
        );
    }

    #[test]
    fn classify_examines_stdout_for_specific_patterns() {
        assert_eq!(
            classify_failure(1, "Invalid API key", ""),
            Some(AuthErrorKind::InvalidCredentials)
        );
    }

    // ---- string / serde representations ------------------------------------

    #[test]
    fn auth_error_kind_as_str_matches_serde_repr() {
        for k in [
            AuthErrorKind::NotAuthenticated,
            AuthErrorKind::Expired,
            AuthErrorKind::InvalidCredentials,
            AuthErrorKind::RateLimit,
            AuthErrorKind::ProviderError,
            AuthErrorKind::Other,
        ] {
            let json = serde_json::to_string(&k).expect("serialize");
            assert_eq!(json, format!("\"{}\"", k.as_str()));
        }
    }

    #[test]
    fn as_str_matches_serde_repr() {
        assert_eq!(AuthStrategy::Bedrock.as_str(), "bedrock");
        assert_eq!(AuthStrategy::Vertex.as_str(), "vertex");
        assert_eq!(AuthStrategy::ApiKey.as_str(), "api_key");
        assert_eq!(AuthStrategy::OauthToken.as_str(), "oauth_token");
        assert_eq!(AuthStrategy::Subscription.as_str(), "subscription");
        for s in [
            AuthStrategy::Bedrock,
            AuthStrategy::Vertex,
            AuthStrategy::ApiKey,
            AuthStrategy::OauthToken,
            AuthStrategy::Subscription,
        ] {
            let json = serde_json::to_string(&s).expect("serialize");
            assert_eq!(json, format!("\"{}\"", s.as_str()));
        }
    }
}