claude_agent/auth/providers/
chain.rs1use async_trait::async_trait;
4
5use crate::auth::{ClaudeCliProvider, Credential, CredentialProvider, EnvironmentProvider};
6use crate::{Error, Result};
7
8pub struct ChainProvider {
10 providers: Vec<Box<dyn CredentialProvider>>,
11}
12
13impl ChainProvider {
14 pub fn new(providers: Vec<Box<dyn CredentialProvider>>) -> Self {
16 Self { providers }
17 }
18
19 pub fn with<P: CredentialProvider + 'static>(mut self, provider: P) -> Self {
21 self.providers.push(Box::new(provider));
22 self
23 }
24}
25
26impl Default for ChainProvider {
27 fn default() -> Self {
28 Self {
29 providers: vec![
30 Box::new(EnvironmentProvider::new()),
31 Box::new(ClaudeCliProvider::new()),
32 ],
33 }
34 }
35}
36
37#[async_trait]
38impl CredentialProvider for ChainProvider {
39 fn name(&self) -> &str {
40 "chain"
41 }
42
43 async fn resolve(&self) -> Result<Credential> {
44 let mut errors = Vec::new();
45
46 for provider in &self.providers {
47 match provider.resolve().await {
48 Ok(cred) => {
49 tracing::debug!("Credential resolved from: {}", provider.name());
50 return Ok(cred);
51 }
52 Err(e) => {
53 tracing::debug!("Provider {} failed: {}", provider.name(), e);
54 errors.push(format!("{}: {}", provider.name(), e));
55 }
56 }
57 }
58
59 Err(Error::auth(format!(
60 "No credentials found. Tried: {}",
61 errors.join(", ")
62 )))
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ExplicitProvider;

    #[tokio::test]
    async fn test_chain_first_success() {
        let chain = ChainProvider::new(vec![])
            .with(ExplicitProvider::api_key("first"))
            .with(ExplicitProvider::api_key("second"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "first"));
    }

    #[tokio::test]
    async fn test_chain_fallback() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR"))
            .with(ExplicitProvider::api_key("fallback"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "fallback"));
    }

    #[tokio::test]
    async fn test_chain_all_fail() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_1"))
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_2"));

        assert!(chain.resolve().await.is_err());
    }

    /// A chain with no providers must fail rather than panic or succeed.
    #[tokio::test]
    async fn test_chain_empty_fails() {
        let chain = ChainProvider::new(vec![]);

        assert!(chain.resolve().await.is_err());
    }

    /// Providers handed to `new` take priority over ones appended with `with`.
    #[tokio::test]
    async fn test_chain_new_precedes_with() {
        let chain = ChainProvider::new(vec![Box::new(ExplicitProvider::api_key("boxed"))])
            .with(ExplicitProvider::api_key("appended"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "boxed"));
    }
}