1use std::collections::HashMap;
37
38use serde::Serialize;
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
44#[serde(rename_all = "snake_case")]
45pub enum AuthStrategy {
46 Bedrock,
50 Vertex,
53 ApiKey,
55 OauthToken,
58 Subscription,
63}
64
65impl AuthStrategy {
66 pub fn as_str(self) -> &'static str {
69 match self {
70 Self::Bedrock => "bedrock",
71 Self::Vertex => "vertex",
72 Self::ApiKey => "api_key",
73 Self::OauthToken => "oauth_token",
74 Self::Subscription => "subscription",
75 }
76 }
77}
78
79#[derive(Debug, Clone, Serialize)]
83pub struct AuthSummary {
84 pub strategy: AuthStrategy,
86 pub has_anthropic_api_key: bool,
88 pub has_oauth_token: bool,
90 pub bedrock_enabled: bool,
92 pub vertex_enabled: bool,
94}
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
109#[serde(rename_all = "snake_case")]
110pub enum AuthErrorKind {
111 NotAuthenticated,
115 Expired,
118 InvalidCredentials,
122 RateLimit,
126 ProviderError,
130 Other,
135}
136
137impl AuthErrorKind {
138 pub fn as_str(self) -> &'static str {
141 match self {
142 Self::NotAuthenticated => "not_authenticated",
143 Self::Expired => "expired",
144 Self::InvalidCredentials => "invalid_credentials",
145 Self::RateLimit => "rate_limit",
146 Self::ProviderError => "provider_error",
147 Self::Other => "other",
148 }
149 }
150}
151
152pub fn classify_failure(_exit_code: i32, stdout: &str, stderr: &str) -> Option<AuthErrorKind> {
169 let combined = format!("{stdout}\n{stderr}").to_ascii_lowercase();
170
171 let mentions_provider = combined.contains("bedrock") || combined.contains("vertex");
175 let mentions_auth_signal = combined.contains("auth")
176 || combined.contains("credential")
177 || combined.contains("401")
178 || combined.contains("403")
179 || combined.contains("forbidden")
180 || combined.contains("unauthorized");
181 if mentions_provider && mentions_auth_signal {
182 return Some(AuthErrorKind::ProviderError);
183 }
184
185 if combined.contains("rate limit")
186 || combined.contains("too many requests")
187 || combined.contains("429")
188 || combined.contains("quota")
189 {
190 return Some(AuthErrorKind::RateLimit);
191 }
192
193 if combined.contains("expired")
194 || combined.contains("session has expired")
195 || combined.contains("token expired")
196 {
197 return Some(AuthErrorKind::Expired);
198 }
199
200 if combined.contains("invalid api key")
201 || combined.contains("invalid token")
202 || combined.contains("401")
203 || combined.contains("unauthorized")
204 || combined.contains("403")
205 || combined.contains("forbidden")
206 {
207 return Some(AuthErrorKind::InvalidCredentials);
208 }
209
210 if combined.contains("not authenticated")
211 || combined.contains("claude login")
212 || combined.contains("no credentials")
213 || combined.contains("no auth")
214 {
215 return Some(AuthErrorKind::NotAuthenticated);
216 }
217
218 if stderr.to_ascii_lowercase().contains("auth")
223 || stderr.to_ascii_lowercase().contains("credential")
224 {
225 return Some(AuthErrorKind::Other);
226 }
227
228 None
229}
230
231pub fn detect() -> AuthSummary {
234 let env: HashMap<String, String> = std::env::vars().collect();
235 detect_from(&env)
236}
237
238pub fn detect_from(env: &HashMap<String, String>) -> AuthSummary {
242 let bedrock_enabled = is_truthy(env.get("CLAUDE_CODE_USE_BEDROCK").map(String::as_str));
243 let vertex_enabled = is_truthy(env.get("CLAUDE_CODE_USE_VERTEX").map(String::as_str));
244 let has_anthropic_api_key = is_set(env.get("ANTHROPIC_API_KEY").map(String::as_str));
245 let has_oauth_token = is_set(env.get("CLAUDE_CODE_OAUTH_TOKEN").map(String::as_str));
246
247 let strategy = if bedrock_enabled {
248 AuthStrategy::Bedrock
249 } else if vertex_enabled {
250 AuthStrategy::Vertex
251 } else if has_anthropic_api_key {
252 AuthStrategy::ApiKey
253 } else if has_oauth_token {
254 AuthStrategy::OauthToken
255 } else {
256 AuthStrategy::Subscription
257 };
258
259 AuthSummary {
260 strategy,
261 has_anthropic_api_key,
262 has_oauth_token,
263 bedrock_enabled,
264 vertex_enabled,
265 }
266}
267
/// Reports whether an environment value is present and contains something
/// other than whitespace.
fn is_set(value: Option<&str>) -> bool {
    match value {
        Some(raw) => !raw.trim().is_empty(),
        None => false,
    }
}
272
/// Interprets an environment flag as a boolean.
///
/// A missing or blank value is `false`. After trimming, the values `0`,
/// `false`, `no`, and `off` (ASCII case-insensitive) are `false`. Any other
/// non-blank value counts as `true`.
fn is_truthy(value: Option<&str>) -> bool {
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];

    match value.map(str::trim) {
        None | Some("") => false,
        Some(flag) => {
            let lowered = flag.to_ascii_lowercase();
            !FALSY.contains(&lowered.as_str())
        }
    }
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned environment map from borrowed key/value pairs.
    fn env(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
            .collect()
    }

    #[test]
    fn empty_env_is_subscription() {
        let s = detect_from(&env(&[]));
        assert_eq!(s.strategy, AuthStrategy::Subscription);
        assert!(!s.has_anthropic_api_key);
        assert!(!s.has_oauth_token);
        assert!(!s.bedrock_enabled);
        assert!(!s.vertex_enabled);
    }

    #[test]
    fn api_key_takes_precedence_over_oauth_token() {
        let s = detect_from(&env(&[
            ("ANTHROPIC_API_KEY", "sk-abc"),
            ("CLAUDE_CODE_OAUTH_TOKEN", "tok-xyz"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::ApiKey);
        assert!(s.has_anthropic_api_key);
        assert!(s.has_oauth_token);
    }

    #[test]
    fn oauth_token_alone_picks_oauth() {
        let s = detect_from(&env(&[("CLAUDE_CODE_OAUTH_TOKEN", "tok-xyz")]));
        assert_eq!(s.strategy, AuthStrategy::OauthToken);
        assert!(!s.has_anthropic_api_key);
        assert!(s.has_oauth_token);
    }

    #[test]
    fn bedrock_overrides_api_key() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_BEDROCK", "1"),
            ("ANTHROPIC_API_KEY", "sk-abc"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Bedrock);
        assert!(s.bedrock_enabled);
        assert!(s.has_anthropic_api_key);
    }

    #[test]
    fn vertex_overrides_oauth_token() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_VERTEX", "true"),
            ("CLAUDE_CODE_OAUTH_TOKEN", "tok-xyz"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Vertex);
        assert!(s.vertex_enabled);
    }

    #[test]
    fn bedrock_takes_precedence_over_vertex_when_both_set() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_BEDROCK", "1"),
            ("CLAUDE_CODE_USE_VERTEX", "1"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Bedrock);
        assert!(s.bedrock_enabled);
        assert!(s.vertex_enabled);
    }

    #[test]
    fn empty_string_does_not_count_as_set() {
        let s = detect_from(&env(&[
            ("ANTHROPIC_API_KEY", ""),
            ("CLAUDE_CODE_OAUTH_TOKEN", "   "),
        ]));
        assert_eq!(s.strategy, AuthStrategy::Subscription);
    }

    #[test]
    fn explicit_falsy_disables_provider_flag() {
        let s = detect_from(&env(&[
            ("CLAUDE_CODE_USE_BEDROCK", "0"),
            ("CLAUDE_CODE_USE_VERTEX", "false"),
            ("ANTHROPIC_API_KEY", "sk-abc"),
        ]));
        assert_eq!(s.strategy, AuthStrategy::ApiKey);
        assert!(!s.bedrock_enabled);
        assert!(!s.vertex_enabled);
    }

    #[test]
    fn truthy_values_recognized() {
        // Surrounding whitespace is trimmed before interpretation.
        for v in ["1", "true", "TRUE", "yes", "on", "anything", " 1 "] {
            let s = detect_from(&env(&[("CLAUDE_CODE_USE_BEDROCK", v)]));
            assert_eq!(s.strategy, AuthStrategy::Bedrock, "value {v:?}");
        }
    }

    #[test]
    fn falsy_values_recognized() {
        for v in ["0", "false", "FALSE", "no", "off", " off "] {
            let s = detect_from(&env(&[("CLAUDE_CODE_USE_BEDROCK", v)]));
            assert_eq!(s.strategy, AuthStrategy::Subscription, "value {v:?}");
            assert!(!s.bedrock_enabled, "value {v:?}");
        }
    }

    #[test]
    fn classify_returns_none_for_unrelated_failure() {
        assert_eq!(classify_failure(1, "no match found", ""), None);
        assert_eq!(
            classify_failure(2, "", "syntax error near unexpected token"),
            None
        );
    }

    #[test]
    fn classify_not_authenticated_from_stderr_hint() {
        assert_eq!(
            classify_failure(1, "", "Not authenticated. Run `claude login` to sign in."),
            Some(AuthErrorKind::NotAuthenticated)
        );
        assert_eq!(
            classify_failure(1, "", "no credentials configured"),
            Some(AuthErrorKind::NotAuthenticated)
        );
    }

    #[test]
    fn classify_expired_session() {
        assert_eq!(
            classify_failure(1, "", "Your session has expired. Please log in again."),
            Some(AuthErrorKind::Expired)
        );
        assert_eq!(
            classify_failure(1, "", "token expired at 2025-01-01T00:00:00Z"),
            Some(AuthErrorKind::Expired)
        );
        // Expired is checked before invalid-credentials signals like 401.
        assert_eq!(
            classify_failure(1, "", "HTTP 401 Unauthorized: token expired"),
            Some(AuthErrorKind::Expired)
        );
    }

    #[test]
    fn classify_invalid_api_key() {
        assert_eq!(
            classify_failure(1, "", "Invalid API key. Check ANTHROPIC_API_KEY."),
            Some(AuthErrorKind::InvalidCredentials)
        );
        assert_eq!(
            classify_failure(1, "", "HTTP 401 Unauthorized"),
            Some(AuthErrorKind::InvalidCredentials)
        );
        assert_eq!(
            classify_failure(1, "", "403 Forbidden"),
            Some(AuthErrorKind::InvalidCredentials)
        );
    }

    #[test]
    fn classify_rate_limit_takes_precedence_over_invalid_creds() {
        // Pure rate-limit messages.
        assert_eq!(
            classify_failure(1, "", "Rate limit exceeded. Please wait."),
            Some(AuthErrorKind::RateLimit)
        );
        assert_eq!(
            classify_failure(1, "", "HTTP 429 Too Many Requests"),
            Some(AuthErrorKind::RateLimit)
        );
        assert_eq!(
            classify_failure(1, "", "quota exceeded for this account"),
            Some(AuthErrorKind::RateLimit)
        );
        // Messages that carry BOTH a rate-limit signal and an
        // invalid-credentials signal must still classify as RateLimit.
        assert_eq!(
            classify_failure(1, "", "HTTP 429 Too Many Requests (unauthorized burst)"),
            Some(AuthErrorKind::RateLimit)
        );
        assert_eq!(
            classify_failure(1, "", "403 Forbidden: quota exceeded"),
            Some(AuthErrorKind::RateLimit)
        );
    }

    #[test]
    fn classify_provider_error_when_bedrock_plus_auth_signal() {
        assert_eq!(
            classify_failure(
                1,
                "",
                "Bedrock auth failed: AWS credentials not found in chain"
            ),
            Some(AuthErrorKind::ProviderError)
        );
        assert_eq!(
            classify_failure(
                1,
                "",
                "Vertex unauthorized -- check GOOGLE_APPLICATION_CREDENTIALS"
            ),
            Some(AuthErrorKind::ProviderError)
        );
    }

    #[test]
    fn classify_falls_back_to_other_for_bare_auth_mention() {
        assert_eq!(
            classify_failure(1, "", "auth subsystem returned an unexpected error"),
            Some(AuthErrorKind::Other)
        );
    }

    #[test]
    fn classify_does_not_match_auth_in_stdout_only() {
        assert_eq!(
            classify_failure(0, "auth_helper enabled, all clear", ""),
            None
        );
    }

    #[test]
    fn classify_examines_stdout_for_specific_patterns() {
        assert_eq!(
            classify_failure(1, "Invalid API key", ""),
            Some(AuthErrorKind::InvalidCredentials)
        );
    }

    #[test]
    fn auth_error_kind_as_str_matches_serde_repr() {
        for k in [
            AuthErrorKind::NotAuthenticated,
            AuthErrorKind::Expired,
            AuthErrorKind::InvalidCredentials,
            AuthErrorKind::RateLimit,
            AuthErrorKind::ProviderError,
            AuthErrorKind::Other,
        ] {
            let json = serde_json::to_string(&k).expect("serialize");
            assert_eq!(json, format!("\"{}\"", k.as_str()));
        }
    }

    #[test]
    fn as_str_matches_serde_repr() {
        assert_eq!(AuthStrategy::Bedrock.as_str(), "bedrock");
        assert_eq!(AuthStrategy::Vertex.as_str(), "vertex");
        assert_eq!(AuthStrategy::ApiKey.as_str(), "api_key");
        assert_eq!(AuthStrategy::OauthToken.as_str(), "oauth_token");
        assert_eq!(AuthStrategy::Subscription.as_str(), "subscription");

        for s in [
            AuthStrategy::Bedrock,
            AuthStrategy::Vertex,
            AuthStrategy::ApiKey,
            AuthStrategy::OauthToken,
            AuthStrategy::Subscription,
        ] {
            let json = serde_json::to_string(&s).expect("serialize");
            assert_eq!(json, format!("\"{}\"", s.as_str()));
        }
    }
}