claude_agent/auth/
helper.rs1use std::process::Stdio;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use tokio::process::Command;
8use tokio::sync::RwLock;
9
10use crate::{Error, Result};
11
12#[derive(Debug)]
13pub struct ApiKeyHelper {
14 command: String,
15 ttl: Duration,
16 cache: RwLock<Option<CachedKey>>,
17}
18
19#[derive(Debug, Clone)]
20struct CachedKey {
21 key: String,
22 expires_at: Instant,
23}
24
25impl ApiKeyHelper {
26 pub fn new(command: impl Into<String>) -> Self {
27 Self {
28 command: command.into(),
29 ttl: Duration::from_secs(3600),
30 cache: RwLock::new(None),
31 }
32 }
33
34 pub fn with_ttl(mut self, ttl: Duration) -> Self {
35 self.ttl = ttl;
36 self
37 }
38
39 pub fn with_ttl_ms(mut self, ttl_ms: u64) -> Self {
40 self.ttl = Duration::from_millis(ttl_ms);
41 self
42 }
43
44 pub fn from_env() -> Option<Self> {
45 let command = std::env::var("ANTHROPIC_API_KEY_HELPER").ok()?;
46 let ttl_ms = std::env::var("CLAUDE_CODE_API_KEY_HELPER_TTL_MS")
47 .ok()
48 .and_then(|v| v.parse().ok())
49 .unwrap_or(3_600_000);
50
51 Some(Self::new(command).with_ttl_ms(ttl_ms))
52 }
53
54 pub async fn get_key(&self) -> Result<String> {
55 {
56 let cache = self.cache.read().await;
57 if let Some(ref cached) = *cache
58 && Instant::now() < cached.expires_at
59 {
60 return Ok(cached.key.clone());
61 }
62 }
63
64 let key = self.execute_helper().await?;
65
66 let cached = CachedKey {
67 key: key.clone(),
68 expires_at: Instant::now() + self.ttl,
69 };
70
71 *self.cache.write().await = Some(cached);
72
73 Ok(key)
74 }
75
76 async fn execute_helper(&self) -> Result<String> {
77 let output = Command::new("sh")
78 .arg("-c")
79 .arg(&self.command)
80 .stdout(Stdio::piped())
81 .stderr(Stdio::piped())
82 .output()
83 .await
84 .map_err(|e| Error::auth(format!("Failed to execute API key helper: {}", e)))?;
85
86 if !output.status.success() {
87 let stderr = String::from_utf8_lossy(&output.stderr);
88 return Err(Error::auth(format!(
89 "API key helper failed: {}",
90 stderr.trim()
91 )));
92 }
93
94 let key = String::from_utf8_lossy(&output.stdout).trim().to_string();
95
96 if key.is_empty() {
97 return Err(Error::auth("API key helper returned empty key"));
98 }
99
100 Ok(key)
101 }
102
103 pub async fn invalidate(&self) {
104 *self.cache.write().await = None;
105 }
106}
107
108#[derive(Debug)]
109pub struct AwsCredentialRefresh {
110 auth_refresh_cmd: Option<String>,
111 credential_export_cmd: Option<String>,
112}
113
114impl AwsCredentialRefresh {
115 pub fn new() -> Self {
116 Self {
117 auth_refresh_cmd: None,
118 credential_export_cmd: None,
119 }
120 }
121
122 pub fn from_settings(
123 auth_refresh: Option<String>,
124 credential_export: Option<String>,
125 ) -> Option<Self> {
126 if auth_refresh.is_none() && credential_export.is_none() {
127 return None;
128 }
129
130 Some(Self {
131 auth_refresh_cmd: auth_refresh,
132 credential_export_cmd: credential_export,
133 })
134 }
135
136 pub async fn refresh(&self) -> Result<Option<AwsCredentials>> {
137 if let Some(ref cmd) = self.credential_export_cmd {
138 return self.export_credentials(cmd).await.map(Some);
139 }
140
141 if let Some(ref cmd) = self.auth_refresh_cmd {
142 self.run_auth_refresh(cmd).await?;
143 }
144
145 Ok(None)
146 }
147
148 async fn run_auth_refresh(&self, cmd: &str) -> Result<()> {
149 let output = Command::new("sh")
150 .arg("-c")
151 .arg(cmd)
152 .stdout(Stdio::piped())
153 .stderr(Stdio::piped())
154 .output()
155 .await
156 .map_err(|e| Error::auth(format!("AWS auth refresh failed: {}", e)))?;
157
158 if !output.status.success() {
159 let stderr = String::from_utf8_lossy(&output.stderr);
160 return Err(Error::auth(format!(
161 "AWS auth refresh failed: {}",
162 stderr.trim()
163 )));
164 }
165
166 Ok(())
167 }
168
169 async fn export_credentials(&self, cmd: &str) -> Result<AwsCredentials> {
170 let output = Command::new("sh")
171 .arg("-c")
172 .arg(cmd)
173 .stdout(Stdio::piped())
174 .stderr(Stdio::piped())
175 .output()
176 .await
177 .map_err(|e| Error::auth(format!("AWS credential export failed: {}", e)))?;
178
179 if !output.status.success() {
180 let stderr = String::from_utf8_lossy(&output.stderr);
181 return Err(Error::auth(format!(
182 "AWS credential export failed: {}",
183 stderr.trim()
184 )));
185 }
186
187 let stdout = String::from_utf8_lossy(&output.stdout);
188 let json: serde_json::Value = serde_json::from_str(&stdout)
189 .map_err(|e| Error::auth(format!("Invalid credential JSON: {}", e)))?;
190
191 let creds = json
192 .get("Credentials")
193 .ok_or_else(|| Error::auth("Missing Credentials in response"))?;
194
195 Ok(AwsCredentials {
196 access_key_id: creds
197 .get("AccessKeyId")
198 .and_then(|v| v.as_str())
199 .ok_or_else(|| Error::auth("Missing AccessKeyId"))?
200 .to_string(),
201 secret_access_key: creds
202 .get("SecretAccessKey")
203 .and_then(|v| v.as_str())
204 .ok_or_else(|| Error::auth("Missing SecretAccessKey"))?
205 .to_string(),
206 session_token: creds
207 .get("SessionToken")
208 .and_then(|v| v.as_str())
209 .map(String::from),
210 })
211 }
212}
213
214impl Default for AwsCredentialRefresh {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
/// AWS credentials parsed from a credential-export command's output.
///
/// `Debug` is implemented by hand so that the secret access key and session
/// token never appear in logs or `{:?}` output; the access key id is shown
/// to keep diagnostics useful.
#[derive(Clone)]
pub struct AwsCredentials {
    pub access_key_id: String,
    pub secret_access_key: String,
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            // Still reveal whether a token is present, just not its value.
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
226
227#[derive(Debug)]
228pub struct CredentialManager {
229 api_key_helper: Option<Arc<ApiKeyHelper>>,
230 aws_refresh: Option<Arc<AwsCredentialRefresh>>,
231}
232
233impl CredentialManager {
234 pub fn new() -> Self {
235 Self {
236 api_key_helper: None,
237 aws_refresh: None,
238 }
239 }
240
241 pub fn with_api_key_helper(mut self, helper: ApiKeyHelper) -> Self {
242 self.api_key_helper = Some(Arc::new(helper));
243 self
244 }
245
246 pub fn with_aws_refresh(mut self, refresh: AwsCredentialRefresh) -> Self {
247 self.aws_refresh = Some(Arc::new(refresh));
248 self
249 }
250
251 pub async fn get_api_key(&self) -> Result<Option<String>> {
252 match &self.api_key_helper {
253 Some(helper) => helper.get_key().await.map(Some),
254 None => Ok(None),
255 }
256 }
257
258 pub async fn refresh_aws(&self) -> Result<Option<AwsCredentials>> {
259 match &self.aws_refresh {
260 Some(refresh) => refresh.refresh().await,
261 None => Ok(None),
262 }
263 }
264}
265
266impl Default for CredentialManager {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Prints credential JSON in the shape `AwsCredentialRefresh` expects.
    const EXPORT_CMD: &str = r#"echo '{"Credentials":{"AccessKeyId":"AK","SecretAccessKey":"SK","SessionToken":"ST"}}'"#;

    /// `$$` expands to the PID of the spawned shell, so every real execution
    /// of this helper yields a different value; equal keys prove a cache hit.
    const PER_RUN_UNIQUE_CMD: &str = "echo $$";

    #[tokio::test]
    async fn test_api_key_helper_echo() {
        let helper = ApiKeyHelper::new("echo test-key");
        let key = helper.get_key().await.unwrap();
        assert_eq!(key, "test-key");
    }

    #[tokio::test]
    async fn test_api_key_helper_trims_output() {
        let helper = ApiKeyHelper::new("printf '  spaced-key  \\n'");
        assert_eq!(helper.get_key().await.unwrap(), "spaced-key");
    }

    #[tokio::test]
    async fn test_api_key_helper_caching() {
        let helper = ApiKeyHelper::new(PER_RUN_UNIQUE_CMD).with_ttl(Duration::from_secs(60));

        let key1 = helper.get_key().await.unwrap();
        let key2 = helper.get_key().await.unwrap();
        assert_eq!(key1, key2, "second call should be served from the cache");
    }

    #[tokio::test]
    async fn test_api_key_helper_invalidate_forces_rerun() {
        let helper = ApiKeyHelper::new(PER_RUN_UNIQUE_CMD).with_ttl(Duration::from_secs(60));

        let key1 = helper.get_key().await.unwrap();
        helper.invalidate().await;
        let key2 = helper.get_key().await.unwrap();
        assert_ne!(key1, key2, "invalidate should force the helper to run again");
    }

    #[tokio::test]
    async fn test_api_key_helper_zero_ttl_disables_cache() {
        let helper = ApiKeyHelper::new(PER_RUN_UNIQUE_CMD).with_ttl(Duration::ZERO);

        let key1 = helper.get_key().await.unwrap();
        let key2 = helper.get_key().await.unwrap();
        assert_ne!(key1, key2, "a zero TTL should never serve a cached key");
    }

    #[tokio::test]
    async fn test_api_key_helper_failure() {
        let helper = ApiKeyHelper::new("exit 1");
        assert!(helper.get_key().await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_helper_empty_output_is_error() {
        let helper = ApiKeyHelper::new("true");
        assert!(helper.get_key().await.is_err());
    }

    #[test]
    fn test_aws_from_settings_requires_a_command() {
        assert!(AwsCredentialRefresh::from_settings(None, None).is_none());
        assert!(AwsCredentialRefresh::from_settings(Some("true".to_string()), None).is_some());
    }

    #[tokio::test]
    async fn test_aws_export_parses_credentials() {
        let refresh =
            AwsCredentialRefresh::from_settings(None, Some(EXPORT_CMD.to_string())).unwrap();

        let creds = refresh
            .refresh()
            .await
            .unwrap()
            .expect("export command should yield credentials");
        assert_eq!(creds.access_key_id, "AK");
        assert_eq!(creds.secret_access_key, "SK");
        assert_eq!(creds.session_token.as_deref(), Some("ST"));
    }

    #[tokio::test]
    async fn test_aws_export_missing_secret_is_error() {
        let cmd = r#"echo '{"Credentials":{"AccessKeyId":"AK"}}'"#;
        let refresh = AwsCredentialRefresh::from_settings(None, Some(cmd.to_string())).unwrap();
        assert!(refresh.refresh().await.is_err());
    }

    #[tokio::test]
    async fn test_aws_export_invalid_json_is_error() {
        let refresh =
            AwsCredentialRefresh::from_settings(None, Some("echo not-json".to_string())).unwrap();
        assert!(refresh.refresh().await.is_err());
    }

    #[tokio::test]
    async fn test_aws_export_takes_precedence_over_auth_refresh() {
        // The failing auth-refresh command must not run when an export command is set.
        let refresh = AwsCredentialRefresh::from_settings(
            Some("exit 1".to_string()),
            Some(EXPORT_CMD.to_string()),
        )
        .unwrap();
        assert!(refresh.refresh().await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_aws_auth_refresh_only_returns_none() {
        let refresh = AwsCredentialRefresh::from_settings(Some("true".to_string()), None).unwrap();
        assert!(refresh.refresh().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_aws_auth_refresh_failure() {
        let refresh =
            AwsCredentialRefresh::from_settings(Some("exit 3".to_string()), None).unwrap();
        assert!(refresh.refresh().await.is_err());
    }

    #[test]
    fn test_credential_manager_default() {
        let manager = CredentialManager::default();
        assert!(manager.api_key_helper.is_none());
        assert!(manager.aws_refresh.is_none());
    }

    #[tokio::test]
    async fn test_credential_manager_without_sources_returns_none() {
        let manager = CredentialManager::new();
        assert!(manager.get_api_key().await.unwrap().is_none());
        assert!(manager.refresh_aws().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_credential_manager_uses_api_key_helper() {
        let manager =
            CredentialManager::new().with_api_key_helper(ApiKeyHelper::new("echo managed-key"));
        assert_eq!(
            manager.get_api_key().await.unwrap().as_deref(),
            Some("managed-key")
        );
        assert!(manager.refresh_aws().await.unwrap().is_none());
    }
}