claude_agent/auth/providers/
chain.rs

1//! Chain credential provider.
2
3use async_trait::async_trait;
4
5use crate::auth::{ClaudeCliProvider, Credential, CredentialProvider, EnvironmentProvider};
6use crate::{Error, Result};
7
8/// Chain provider that tries multiple providers in order.
9pub struct ChainProvider {
10    providers: Vec<Box<dyn CredentialProvider>>,
11}
12
13impl ChainProvider {
14    /// Create with specified providers.
15    pub fn new(providers: Vec<Box<dyn CredentialProvider>>) -> Self {
16        Self { providers }
17    }
18
19    /// Add a provider to the chain.
20    pub fn with<P: CredentialProvider + 'static>(mut self, provider: P) -> Self {
21        self.providers.push(Box::new(provider));
22        self
23    }
24}
25
26impl Default for ChainProvider {
27    fn default() -> Self {
28        Self {
29            providers: vec![
30                Box::new(EnvironmentProvider::new()),
31                Box::new(ClaudeCliProvider::new()),
32            ],
33        }
34    }
35}
36
37#[async_trait]
38impl CredentialProvider for ChainProvider {
39    fn name(&self) -> &str {
40        "chain"
41    }
42
43    async fn resolve(&self) -> Result<Credential> {
44        let mut errors = Vec::new();
45
46        for provider in &self.providers {
47            match provider.resolve().await {
48                Ok(cred) => {
49                    tracing::debug!("Credential resolved from: {}", provider.name());
50                    return Ok(cred);
51                }
52                Err(e) => {
53                    tracing::debug!("Provider {} failed: {}", provider.name(), e);
54                    errors.push(format!("{}: {}", provider.name(), e));
55                }
56            }
57        }
58
59        Err(Error::auth(format!(
60            "No credentials found. Tried: {}",
61            errors.join(", ")
62        )))
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ExplicitProvider;

    /// With two working providers, the credential of the first one is returned.
    #[tokio::test]
    async fn test_chain_first_success() {
        let providers = ChainProvider::new(Vec::new())
            .with(ExplicitProvider::api_key("first"))
            .with(ExplicitProvider::api_key("second"));

        let resolved = providers.resolve().await.unwrap();
        let came_from_first = match resolved {
            Credential::ApiKey(key) => key == "first",
            _ => false,
        };
        assert!(came_from_first);
    }

    /// A failing provider is skipped and the next one supplies the credential.
    ///
    /// Assumes `NONEXISTENT_VAR` is not set in the test environment.
    #[tokio::test]
    async fn test_chain_fallback() {
        let providers = ChainProvider::new(Vec::new())
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR"))
            .with(ExplicitProvider::api_key("fallback"));

        let resolved = providers.resolve().await.unwrap();
        let came_from_fallback = match resolved {
            Credential::ApiKey(key) => key == "fallback",
            _ => false,
        };
        assert!(came_from_fallback);
    }

    /// When every provider fails, the chain as a whole reports an error.
    ///
    /// Assumes neither `NONEXISTENT_VAR_1` nor `NONEXISTENT_VAR_2` is set.
    #[tokio::test]
    async fn test_chain_all_fail() {
        let providers = ChainProvider::new(Vec::new())
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_1"))
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_2"));

        let outcome = providers.resolve().await;
        assert!(outcome.is_err());
    }
}