claude_agent/auth/providers/chain.rs

1//! Chain credential provider.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use crate::auth::{ClaudeCliProvider, Credential, CredentialProvider, EnvironmentProvider};
9use crate::{Error, Result};
10
11pub struct ChainProvider {
12    providers: Vec<Arc<dyn CredentialProvider>>,
13    last_successful: RwLock<Option<Arc<dyn CredentialProvider>>>,
14}
15
16impl ChainProvider {
17    pub fn new(providers: Vec<Box<dyn CredentialProvider>>) -> Self {
18        Self {
19            providers: providers.into_iter().map(Arc::from).collect(),
20            last_successful: RwLock::new(None),
21        }
22    }
23
24    pub fn with<P: CredentialProvider + 'static>(mut self, provider: P) -> Self {
25        self.providers.push(Arc::new(provider));
26        self
27    }
28}
29
30impl Default for ChainProvider {
31    fn default() -> Self {
32        Self {
33            providers: vec![
34                Arc::new(EnvironmentProvider::new()),
35                Arc::new(ClaudeCliProvider::new()),
36            ],
37            last_successful: RwLock::new(None),
38        }
39    }
40}
41
42#[async_trait]
43impl CredentialProvider for ChainProvider {
44    fn name(&self) -> &str {
45        "chain"
46    }
47
48    async fn resolve(&self) -> Result<Credential> {
49        let mut errors = Vec::new();
50
51        for provider in &self.providers {
52            match provider.resolve().await {
53                Ok(cred) => {
54                    tracing::debug!("Credential resolved from: {}", provider.name());
55                    *self.last_successful.write().await = Some(Arc::clone(provider));
56                    return Ok(cred);
57                }
58                Err(e) => {
59                    tracing::debug!("Provider {} failed: {}", provider.name(), e);
60                    errors.push(format!("{}: {}", provider.name(), e));
61                }
62            }
63        }
64
65        Err(Error::auth(format!(
66            "No credentials found. Tried: {}",
67            errors.join(", ")
68        )))
69    }
70
71    async fn refresh(&self) -> Result<Credential> {
72        let provider = self.last_successful.read().await;
73        match provider.as_ref() {
74            Some(p) if p.supports_refresh() => p.refresh().await,
75            Some(_) => Err(Error::auth(
76                "Last successful provider does not support refresh",
77            )),
78            None => Err(Error::auth("No provider has successfully resolved yet")),
79        }
80    }
81
82    fn supports_refresh(&self) -> bool {
83        false
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ExplicitProvider;

    #[tokio::test]
    async fn test_chain_first_success() {
        let chain = ChainProvider::new(vec![])
            .with(ExplicitProvider::api_key("first"))
            .with(ExplicitProvider::api_key("second"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "first"));
    }

    #[tokio::test]
    async fn test_chain_fallback() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR"))
            .with(ExplicitProvider::api_key("fallback"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "fallback"));
    }

    #[tokio::test]
    async fn test_chain_all_fail() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_1"))
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_2"));

        assert!(chain.resolve().await.is_err());
    }

    #[tokio::test]
    async fn test_chain_empty_fails() {
        // A chain with no providers must report an error rather than succeed vacuously.
        let chain = ChainProvider::new(vec![]);

        assert!(chain.resolve().await.is_err());
    }

    #[tokio::test]
    async fn test_chain_tracks_last_successful() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR"))
            .with(ExplicitProvider::api_key("fallback"));

        let _ = chain.resolve().await.unwrap();

        let last = chain.last_successful.read().await;
        assert!(last.is_some());
        assert_eq!(last.as_ref().unwrap().name(), "explicit");
    }

    #[tokio::test]
    async fn test_chain_all_fail_records_nothing() {
        // A failed resolve must not leave a provider recorded for refresh.
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_3"));

        assert!(chain.resolve().await.is_err());
        assert!(chain.last_successful.read().await.is_none());
    }

    #[tokio::test]
    async fn test_chain_refresh_before_resolve_fails() {
        // With no successful resolve there is nothing to delegate the refresh to.
        let chain = ChainProvider::new(vec![]).with(ExplicitProvider::api_key("key"));

        assert!(!chain.supports_refresh());
        assert!(chain.refresh().await.is_err());
    }
}
133}