claude_agent/auth/providers/
chain.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use crate::auth::{ClaudeCliProvider, Credential, CredentialProvider, EnvironmentProvider};
9use crate::{Error, Result};
10
11pub struct ChainProvider {
12 providers: Vec<Arc<dyn CredentialProvider>>,
13 last_successful: RwLock<Option<Arc<dyn CredentialProvider>>>,
14}
15
16impl ChainProvider {
17 pub fn new(providers: Vec<Box<dyn CredentialProvider>>) -> Self {
18 Self {
19 providers: providers.into_iter().map(Arc::from).collect(),
20 last_successful: RwLock::new(None),
21 }
22 }
23
24 pub fn with<P: CredentialProvider + 'static>(mut self, provider: P) -> Self {
25 self.providers.push(Arc::new(provider));
26 self
27 }
28}
29
30impl Default for ChainProvider {
31 fn default() -> Self {
32 Self {
33 providers: vec![
34 Arc::new(EnvironmentProvider::new()),
35 Arc::new(ClaudeCliProvider::new()),
36 ],
37 last_successful: RwLock::new(None),
38 }
39 }
40}
41
42#[async_trait]
43impl CredentialProvider for ChainProvider {
44 fn name(&self) -> &str {
45 "chain"
46 }
47
48 async fn resolve(&self) -> Result<Credential> {
49 let mut errors = Vec::new();
50
51 for provider in &self.providers {
52 match provider.resolve().await {
53 Ok(cred) => {
54 tracing::debug!("Credential resolved from: {}", provider.name());
55 *self.last_successful.write().await = Some(Arc::clone(provider));
56 return Ok(cred);
57 }
58 Err(e) => {
59 tracing::debug!("Provider {} failed: {}", provider.name(), e);
60 errors.push(format!("{}: {}", provider.name(), e));
61 }
62 }
63 }
64
65 Err(Error::auth(format!(
66 "No credentials found. Tried: {}",
67 errors.join(", ")
68 )))
69 }
70
71 async fn refresh(&self) -> Result<Credential> {
72 let provider = self.last_successful.read().await;
73 match provider.as_ref() {
74 Some(p) if p.supports_refresh() => p.refresh().await,
75 Some(_) => Err(Error::auth(
76 "Last successful provider does not support refresh",
77 )),
78 None => Err(Error::auth("No provider has successfully resolved yet")),
79 }
80 }
81
82 fn supports_refresh(&self) -> bool {
83 false
84 }
85}
86
#[cfg(test)]
mod tests {
    // Unit tests for ChainProvider ordering, fallback and last-success bookkeeping.
    use super::*;
    use crate::auth::ExplicitProvider;

    #[tokio::test]
    async fn test_chain_first_success() {
        let chain = ChainProvider::new(vec![])
            .with(ExplicitProvider::api_key("first"))
            .with(ExplicitProvider::api_key("second"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "first"));
    }

    #[tokio::test]
    async fn test_chain_fallback() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR"))
            .with(ExplicitProvider::api_key("fallback"));

        let cred = chain.resolve().await.unwrap();
        assert!(matches!(cred, Credential::ApiKey(k) if k == "fallback"));
    }

    #[tokio::test]
    async fn test_chain_all_fail() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_1"))
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR_2"));

        assert!(chain.resolve().await.is_err());
        // A failed resolve must not record any provider.
        assert!(chain.last_successful.read().await.is_none());
    }

    #[tokio::test]
    async fn test_chain_empty_fails() {
        let chain = ChainProvider::new(vec![]);

        assert!(chain.resolve().await.is_err());
        assert!(chain.last_successful.read().await.is_none());
    }

    #[tokio::test]
    async fn test_refresh_before_resolve_fails() {
        let chain = ChainProvider::new(vec![]).with(ExplicitProvider::api_key("key"));

        // Nothing has resolved yet, so there is no provider to delegate to.
        assert!(chain.refresh().await.is_err());
    }

    #[tokio::test]
    async fn test_chain_tracks_last_successful() {
        let chain = ChainProvider::new(vec![])
            .with(EnvironmentProvider::with_var("NONEXISTENT_VAR"))
            .with(ExplicitProvider::api_key("fallback"));

        let _ = chain.resolve().await.unwrap();

        let last = chain.last_successful.read().await;
        let provider = last
            .as_ref()
            .expect("a successful resolve should record its provider");
        assert_eq!(provider.name(), "explicit");
    }
}