claude_agent/auth/providers/
chain.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8#[cfg(feature = "cli-integration")]
9use crate::auth::ClaudeCliProvider;
10use crate::auth::{Credential, CredentialProvider, EnvironmentProvider};
11use crate::{Error, Result};
12
13pub struct ChainProvider {
14 providers: Vec<Arc<dyn CredentialProvider>>,
15 last_successful: RwLock<Option<Arc<dyn CredentialProvider>>>,
16}
17
18impl ChainProvider {
19 pub fn new(providers: Vec<Box<dyn CredentialProvider>>) -> Self {
20 Self {
21 providers: providers.into_iter().map(Arc::from).collect(),
22 last_successful: RwLock::new(None),
23 }
24 }
25
26 pub fn provider<P: CredentialProvider + 'static>(mut self, provider: P) -> Self {
27 self.providers.push(Arc::new(provider));
28 self
29 }
30}
31
32#[cfg(feature = "cli-integration")]
33impl Default for ChainProvider {
34 fn default() -> Self {
35 Self {
36 providers: vec![
37 Arc::new(EnvironmentProvider::new()),
38 Arc::new(ClaudeCliProvider::new()),
39 ],
40 last_successful: RwLock::new(None),
41 }
42 }
43}
44
45#[cfg(not(feature = "cli-integration"))]
46impl Default for ChainProvider {
47 fn default() -> Self {
48 Self {
49 providers: vec![Arc::new(EnvironmentProvider::new())],
50 last_successful: RwLock::new(None),
51 }
52 }
53}
54
55#[async_trait]
56impl CredentialProvider for ChainProvider {
57 fn name(&self) -> &str {
58 "chain"
59 }
60
61 async fn resolve(&self) -> Result<Credential> {
62 let mut errors = Vec::new();
63
64 for provider in &self.providers {
65 match provider.resolve().await {
66 Ok(cred) => {
67 tracing::debug!("Credential resolved from: {}", provider.name());
68 *self.last_successful.write().await = Some(Arc::clone(provider));
69 return Ok(cred);
70 }
71 Err(e) => {
72 tracing::debug!("Provider {} failed: {}", provider.name(), e);
73 errors.push(format!("{}: {}", provider.name(), e));
74 }
75 }
76 }
77
78 Err(Error::auth(format!(
79 "No credentials found. Tried: {}",
80 errors.join(", ")
81 )))
82 }
83
84 async fn refresh(&self) -> Result<Credential> {
85 let provider = self.last_successful.read().await;
86 match provider.as_ref() {
87 Some(p) if p.supports_refresh() => p.refresh().await,
88 Some(_) => Err(Error::auth(
89 "Last successful provider does not support refresh",
90 )),
91 None => Err(Error::auth("No provider has successfully resolved yet")),
92 }
93 }
94
95 fn supports_refresh(&self) -> bool {
96 self.last_successful
97 .try_read()
98 .ok()
99 .and_then(|guard| guard.as_ref().map(|p| p.supports_refresh()))
100 .unwrap_or(false)
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ExplicitProvider;
    use secrecy::ExposeSecret;

    #[tokio::test]
    async fn test_chain_first_success() {
        let chain = ChainProvider::new(vec![])
            .provider(ExplicitProvider::api_key("first"))
            .provider(ExplicitProvider::api_key("second"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(&cred, Credential::ApiKey(k) if k.expose_secret() == "first"));
    }

    #[tokio::test]
    async fn test_chain_fallback() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR"))
            .provider(ExplicitProvider::api_key("fallback"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(&cred, Credential::ApiKey(k) if k.expose_secret() == "fallback"));
    }

    #[tokio::test]
    async fn test_chain_all_fail() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR_1"))
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR_2"));

        assert!(chain.resolve().await.is_err());
    }

    #[tokio::test]
    async fn test_empty_chain_fails() {
        let chain = ChainProvider::new(vec![]);

        assert!(chain.resolve().await.is_err());
    }

    /// Test double whose `resolve` and `refresh` return distinct fixed keys,
    /// so tests can tell which path produced a credential.
    struct RefreshableProvider;

    #[async_trait]
    impl CredentialProvider for RefreshableProvider {
        fn name(&self) -> &str {
            "refreshable"
        }

        async fn resolve(&self) -> Result<Credential> {
            Ok(Credential::api_key("refreshable-key"))
        }

        fn supports_refresh(&self) -> bool {
            true
        }

        async fn refresh(&self) -> Result<Credential> {
            Ok(Credential::api_key("refreshed-key"))
        }
    }

    #[tokio::test]
    async fn test_supports_refresh_after_resolve() {
        let chain = ChainProvider::new(vec![]).provider(RefreshableProvider);

        assert!(!chain.supports_refresh());

        let _ = chain.resolve().await.unwrap();

        assert!(chain.supports_refresh());
    }

    #[tokio::test]
    async fn test_supports_refresh_with_non_refreshable() {
        let chain = ChainProvider::new(vec![]).provider(ExplicitProvider::api_key("key"));

        let _ = chain.resolve().await.unwrap();

        assert!(!chain.supports_refresh());
    }

    #[tokio::test]
    async fn test_refresh_delegates_to_last_successful() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR"))
            .provider(RefreshableProvider);

        let _ = chain.resolve().await.unwrap();

        let cred = chain.refresh().await.unwrap();
        assert!(matches!(&cred, Credential::ApiKey(k) if k.expose_secret() == "refreshed-key"));
    }

    #[tokio::test]
    async fn test_refresh_before_resolve_fails() {
        let chain = ChainProvider::new(vec![]).provider(RefreshableProvider);

        assert!(chain.refresh().await.is_err());
    }

    #[tokio::test]
    async fn test_refresh_with_non_refreshable_fails() {
        let chain = ChainProvider::new(vec![]).provider(ExplicitProvider::api_key("key"));

        let _ = chain.resolve().await.unwrap();

        assert!(chain.refresh().await.is_err());
    }

    #[tokio::test]
    async fn test_chain_tracks_last_successful() {
        let chain = ChainProvider::new(vec![])
            .provider(EnvironmentProvider::from_var("NONEXISTENT_VAR"))
            .provider(ExplicitProvider::api_key("fallback"));

        let _ = chain.resolve().await.unwrap();

        let last = chain.last_successful.read().await;
        assert!(last.is_some());
        assert_eq!(last.as_ref().unwrap().name(), "explicit");
    }
}