Skip to main content

claude_agent/auth/providers/
chain.rs

1//! Chain credential provider.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8#[cfg(feature = "cli-integration")]
9use crate::auth::ClaudeCliProvider;
10use crate::auth::{Credential, CredentialProvider, EnvironmentProvider};
11use crate::{Error, Result};
12
13pub struct ChainProvider {
14    providers: Vec<Arc<dyn CredentialProvider>>,
15    last_successful: RwLock<Option<Arc<dyn CredentialProvider>>>,
16}
17
18impl ChainProvider {
19    pub fn new(providers: Vec<Box<dyn CredentialProvider>>) -> Self {
20        Self {
21            providers: providers.into_iter().map(Arc::from).collect(),
22            last_successful: RwLock::new(None),
23        }
24    }
25
26    pub fn provider<P: CredentialProvider + 'static>(mut self, provider: P) -> Self {
27        self.providers.push(Arc::new(provider));
28        self
29    }
30}
31
32#[cfg(feature = "cli-integration")]
33impl Default for ChainProvider {
34    fn default() -> Self {
35        Self {
36            providers: vec![
37                Arc::new(EnvironmentProvider::new()),
38                Arc::new(ClaudeCliProvider::new()),
39            ],
40            last_successful: RwLock::new(None),
41        }
42    }
43}
44
45#[cfg(not(feature = "cli-integration"))]
46impl Default for ChainProvider {
47    fn default() -> Self {
48        Self {
49            providers: vec![Arc::new(EnvironmentProvider::new())],
50            last_successful: RwLock::new(None),
51        }
52    }
53}
54
55#[async_trait]
56impl CredentialProvider for ChainProvider {
57    fn name(&self) -> &str {
58        "chain"
59    }
60
61    async fn resolve(&self) -> Result<Credential> {
62        let mut errors = Vec::new();
63
64        for provider in &self.providers {
65            match provider.resolve().await {
66                Ok(cred) => {
67                    tracing::debug!("Credential resolved from: {}", provider.name());
68                    *self.last_successful.write().await = Some(Arc::clone(provider));
69                    return Ok(cred);
70                }
71                Err(e) => {
72                    tracing::debug!("Provider {} failed: {}", provider.name(), e);
73                    errors.push(format!("{}: {}", provider.name(), e));
74                }
75            }
76        }
77
78        Err(Error::auth(format!(
79            "No credentials found. Tried: {}",
80            errors.join(", ")
81        )))
82    }
83
84    async fn refresh(&self) -> Result<Credential> {
85        let provider = self.last_successful.read().await;
86        match provider.as_ref() {
87            Some(p) if p.supports_refresh() => p.refresh().await,
88            Some(_) => Err(Error::auth(
89                "Last successful provider does not support refresh",
90            )),
91            None => Err(Error::auth("No provider has successfully resolved yet")),
92        }
93    }
94
95    fn supports_refresh(&self) -> bool {
96        self.last_successful
97            .try_read()
98            .ok()
99            .and_then(|guard| guard.as_ref().map(|p| p.supports_refresh()))
100            .unwrap_or(false)
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ExplicitProvider;
    use secrecy::ExposeSecret;

    #[tokio::test]
    async fn test_chain_first_success() {
        let chain = ChainProvider::new(vec![])
            .provider(ExplicitProvider::api_key("first"))
            .provider(ExplicitProvider::api_key("second"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(&cred, Credential::ApiKey(k) if k.expose_secret() == "first"));
    }

    #[tokio::test]
    async fn test_chain_fallback() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR"))
            .provider(ExplicitProvider::api_key("fallback"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(&cred, Credential::ApiKey(k) if k.expose_secret() == "fallback"));
    }

    #[tokio::test]
    async fn test_chain_all_fail() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR_1"))
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR_2"));

        assert!(chain.resolve().await.is_err());
    }

    /// Test double that always resolves and supports refresh, returning
    /// distinguishable keys for `resolve` vs `refresh`.
    struct RefreshableProvider;

    #[async_trait]
    impl CredentialProvider for RefreshableProvider {
        fn name(&self) -> &str {
            "refreshable"
        }

        async fn resolve(&self) -> Result<Credential> {
            Ok(Credential::api_key("refreshable-key"))
        }

        fn supports_refresh(&self) -> bool {
            true
        }

        async fn refresh(&self) -> Result<Credential> {
            Ok(Credential::api_key("refreshed-key"))
        }
    }

    #[tokio::test]
    async fn test_supports_refresh_after_resolve() {
        let chain = ChainProvider::new(vec![]).provider(RefreshableProvider);

        assert!(!chain.supports_refresh());

        let _ = chain.resolve().await.unwrap();

        assert!(chain.supports_refresh());
    }

    #[tokio::test]
    async fn test_supports_refresh_with_non_refreshable() {
        let chain = ChainProvider::new(vec![]).provider(ExplicitProvider::api_key("key"));

        let _ = chain.resolve().await.unwrap();

        assert!(!chain.supports_refresh());
    }

    #[tokio::test]
    async fn test_chain_tracks_last_successful() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR"))
            .provider(ExplicitProvider::api_key("fallback"));

        let _ = chain.resolve().await.unwrap();

        let last = chain.last_successful.read().await;
        assert!(last.is_some());
        assert_eq!(last.as_ref().unwrap().name(), "explicit");
    }

    // `refresh` must be routed to the provider that actually resolved,
    // not simply the first provider in the chain.
    #[tokio::test]
    async fn test_refresh_delegates_to_last_successful() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR"))
            .provider(RefreshableProvider);

        let _ = chain.resolve().await.unwrap();

        let cred = chain.refresh().await.unwrap();
        assert!(matches!(&cred, Credential::ApiKey(k) if k.expose_secret() == "refreshed-key"));
    }

    #[tokio::test]
    async fn test_refresh_before_resolve_fails() {
        let chain = ChainProvider::new(vec![]).provider(RefreshableProvider);

        assert!(chain.refresh().await.is_err());
    }

    #[tokio::test]
    async fn test_refresh_with_non_refreshable_fails() {
        let chain = ChainProvider::new(vec![]).provider(ExplicitProvider::api_key("key"));

        let _ = chain.resolve().await.unwrap();

        assert!(chain.refresh().await.is_err());
    }
}