claude_agent/auth/
helper.rs1use std::process::Stdio;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use secrecy::{ExposeSecret, SecretString};
8use tokio::process::Command;
9use tokio::sync::Mutex;
10
11use crate::{Error, Result};
12
13use std::fmt;
14
15async fn run_shell_command(cmd: &str, context: &str) -> Result<String> {
16 let output = Command::new("sh")
17 .arg("-c")
18 .arg(cmd)
19 .stdout(Stdio::piped())
20 .stderr(Stdio::piped())
21 .output()
22 .await
23 .map_err(|e| Error::auth(format!("{} failed: {}", context, e)))?;
24
25 if !output.status.success() {
26 let stderr = String::from_utf8_lossy(&output.stderr);
27 return Err(Error::auth(format!(
28 "{} failed: {}",
29 context,
30 stderr.trim()
31 )));
32 }
33
34 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
35}
36
37#[derive(Debug)]
38pub struct ApiKeyHelper {
39 command: String,
40 ttl: Duration,
41 cache: Mutex<Option<CachedKey>>,
42}
43
44#[derive(Clone)]
45struct CachedKey {
46 key: SecretString,
47 expires_at: Instant,
48}
49
50impl fmt::Debug for CachedKey {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("CachedKey")
53 .field("key", &"[redacted]")
54 .field("expires_at", &self.expires_at)
55 .finish()
56 }
57}
58
59impl ApiKeyHelper {
60 pub fn new(command: impl Into<String>) -> Self {
61 Self {
62 command: command.into(),
63 ttl: Duration::from_secs(3600),
64 cache: Mutex::new(None),
65 }
66 }
67
68 pub fn ttl(mut self, ttl: Duration) -> Self {
69 self.ttl = ttl;
70 self
71 }
72
73 pub fn ttl_ms(mut self, ttl_ms: u64) -> Self {
74 self.ttl = Duration::from_millis(ttl_ms);
75 self
76 }
77
78 pub fn from_env() -> Option<Self> {
79 let command = std::env::var("ANTHROPIC_API_KEY_HELPER").ok()?;
80 let ttl_ms = std::env::var("CLAUDE_CODE_API_KEY_HELPER_TTL_MS")
81 .ok()
82 .and_then(|v| v.parse().ok())
83 .unwrap_or(3_600_000);
84
85 Some(Self::new(command).ttl_ms(ttl_ms))
86 }
87
88 pub async fn get_key(&self) -> Result<SecretString> {
89 let mut cache = self.cache.lock().await;
90
91 if let Some(ref cached) = *cache
92 && Instant::now() < cached.expires_at
93 {
94 return Ok(cached.key.clone());
95 }
96
97 let key = run_shell_command(&self.command, "API key helper").await?;
98
99 if key.is_empty() {
100 return Err(Error::auth("API key helper returned empty key"));
101 }
102
103 let secret_key = SecretString::from(key);
104
105 *cache = Some(CachedKey {
106 key: secret_key.clone(),
107 expires_at: Instant::now() + self.ttl,
108 });
109
110 Ok(secret_key)
111 }
112
113 pub async fn invalidate(&self) {
114 *self.cache.lock().await = None;
115 }
116}
117
118#[derive(Debug)]
119pub struct AwsCredentialRefresh {
120 auth_refresh_cmd: Option<String>,
121 credential_export_cmd: Option<String>,
122}
123
124impl AwsCredentialRefresh {
125 pub fn new() -> Self {
126 Self {
127 auth_refresh_cmd: None,
128 credential_export_cmd: None,
129 }
130 }
131
132 pub fn from_settings(
133 auth_refresh: Option<String>,
134 credential_export: Option<String>,
135 ) -> Option<Self> {
136 if auth_refresh.is_none() && credential_export.is_none() {
137 return None;
138 }
139
140 Some(Self {
141 auth_refresh_cmd: auth_refresh,
142 credential_export_cmd: credential_export,
143 })
144 }
145
146 pub async fn refresh(&self) -> Result<Option<AwsCredentials>> {
147 if let Some(ref cmd) = self.credential_export_cmd {
148 return self.export_credentials(cmd).await.map(Some);
149 }
150
151 if let Some(ref cmd) = self.auth_refresh_cmd {
152 run_shell_command(cmd, "AWS auth refresh").await?;
153 }
154
155 Ok(None)
156 }
157
158 async fn export_credentials(&self, cmd: &str) -> Result<AwsCredentials> {
159 let stdout = run_shell_command(cmd, "AWS credential export").await?;
160
161 let json: serde_json::Value = serde_json::from_str(&stdout)
162 .map_err(|e| Error::auth(format!("Invalid credential JSON: {}", e)))?;
163
164 let creds = json
165 .get("Credentials")
166 .ok_or_else(|| Error::auth("Missing Credentials in response"))?;
167
168 Ok(AwsCredentials {
169 access_key_id: creds
170 .get("AccessKeyId")
171 .and_then(|v| v.as_str())
172 .ok_or_else(|| Error::auth("Missing AccessKeyId"))?
173 .to_string(),
174 secret_access_key: SecretString::from(
175 creds
176 .get("SecretAccessKey")
177 .and_then(|v| v.as_str())
178 .ok_or_else(|| Error::auth("Missing SecretAccessKey"))?
179 .to_string(),
180 ),
181 session_token: creds
182 .get("SessionToken")
183 .and_then(|v| v.as_str())
184 .map(|s| SecretString::from(s.to_string())),
185 })
186 }
187}
188
189impl Default for AwsCredentialRefresh {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195#[derive(Clone)]
196pub struct AwsCredentials {
197 pub access_key_id: String,
198 secret_access_key: SecretString,
199 session_token: Option<SecretString>,
200}
201
202impl AwsCredentials {
203 pub fn secret_access_key(&self) -> &str {
204 self.secret_access_key.expose_secret()
205 }
206
207 pub fn session_token(&self) -> Option<&str> {
208 self.session_token.as_ref().map(|s| s.expose_secret())
209 }
210}
211
212impl fmt::Debug for AwsCredentials {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 f.debug_struct("AwsCredentials")
215 .field("access_key_id", &self.access_key_id)
216 .field("secret_access_key", &"[redacted]")
217 .field(
218 "session_token",
219 &self.session_token.as_ref().map(|_| "[redacted]"),
220 )
221 .finish()
222 }
223}
224
225#[derive(Debug)]
226pub struct CredentialManager {
227 api_key_helper: Option<Arc<ApiKeyHelper>>,
228 aws_refresh: Option<Arc<AwsCredentialRefresh>>,
229}
230
231impl CredentialManager {
232 pub fn new() -> Self {
233 Self {
234 api_key_helper: None,
235 aws_refresh: None,
236 }
237 }
238
239 pub fn api_key_helper(mut self, helper: ApiKeyHelper) -> Self {
240 self.api_key_helper = Some(Arc::new(helper));
241 self
242 }
243
244 pub fn aws_refresh(mut self, refresh: AwsCredentialRefresh) -> Self {
245 self.aws_refresh = Some(Arc::new(refresh));
246 self
247 }
248
249 pub async fn get_api_key(&self) -> Result<Option<SecretString>> {
250 match &self.api_key_helper {
251 Some(helper) => helper.get_key().await.map(Some),
252 None => Ok(None),
253 }
254 }
255
256 pub async fn refresh_aws(&self) -> Result<Option<AwsCredentials>> {
257 match &self.aws_refresh {
258 Some(refresh) => refresh.refresh().await,
259 None => Ok(None),
260 }
261 }
262}
263
264impl Default for CredentialManager {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_api_key_helper_echo() {
        let helper = ApiKeyHelper::new("echo test-key");
        let key = helper.get_key().await.unwrap();
        assert_eq!(key.expose_secret(), "test-key");
    }

    #[tokio::test]
    async fn test_api_key_helper_caching() {
        let helper = ApiKeyHelper::new("echo test-key").ttl(Duration::from_secs(60));

        let key1 = helper.get_key().await.unwrap();
        let key2 = helper.get_key().await.unwrap();
        assert_eq!(key1.expose_secret(), key2.expose_secret());
    }

    #[tokio::test]
    async fn test_api_key_helper_failure() {
        let helper = ApiKeyHelper::new("exit 1");
        assert!(helper.get_key().await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_helper_empty_output_is_error() {
        // `true` succeeds but prints nothing.
        let helper = ApiKeyHelper::new("true");
        assert!(helper.get_key().await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_helper_trims_output() {
        let helper = ApiKeyHelper::new("printf '  padded-key \\n'");
        let key = helper.get_key().await.unwrap();
        assert_eq!(key.expose_secret(), "padded-key");
    }

    #[tokio::test]
    async fn test_api_key_helper_invalidate_clears_cache() {
        let helper = ApiKeyHelper::new("echo test-key");
        helper.get_key().await.unwrap();
        assert!(helper.cache.lock().await.is_some());

        helper.invalidate().await;
        assert!(helper.cache.lock().await.is_none());

        let key = helper.get_key().await.unwrap();
        assert_eq!(key.expose_secret(), "test-key");
        assert!(helper.cache.lock().await.is_some());
    }

    #[tokio::test]
    async fn test_api_key_helper_debug_redacts_key() {
        // The key "secret" appears in the output but not in the command text.
        let helper = ApiKeyHelper::new("printf 'sec'; printf 'ret'");
        let key = helper.get_key().await.unwrap();
        assert_eq!(key.expose_secret(), "secret");
        assert!(!format!("{:?}", helper).contains("secret"));
    }

    #[test]
    fn test_aws_from_settings_requires_a_command() {
        assert!(AwsCredentialRefresh::from_settings(None, None).is_none());
        assert!(AwsCredentialRefresh::from_settings(Some("true".into()), None).is_some());
        assert!(AwsCredentialRefresh::from_settings(None, Some("true".into())).is_some());
    }

    #[tokio::test]
    async fn test_aws_credential_export_parses_json() {
        let cmd = r#"echo '{"Credentials":{"AccessKeyId":"AKID","SecretAccessKey":"SECRET","SessionToken":"TOKEN"}}'"#;
        let refresh = AwsCredentialRefresh::from_settings(None, Some(cmd.to_string())).unwrap();
        let creds = refresh
            .refresh()
            .await
            .unwrap()
            .expect("export command yields credentials");

        assert_eq!(creds.access_key_id, "AKID");
        assert_eq!(creds.secret_access_key(), "SECRET");
        assert_eq!(creds.session_token(), Some("TOKEN"));

        let debug = format!("{:?}", creds);
        assert!(debug.contains("AKID"));
        assert!(!debug.contains("SECRET"));
        assert!(!debug.contains("TOKEN"));
    }

    #[tokio::test]
    async fn test_aws_credential_export_rejects_bad_output() {
        for cmd in [
            "echo 'not json'",
            "echo '{}'",
            r#"echo '{"Credentials":{"AccessKeyId":"AKID"}}'"#,
            r#"echo '{"Credentials":{"SecretAccessKey":"SECRET"}}'"#,
        ] {
            let refresh =
                AwsCredentialRefresh::from_settings(None, Some(cmd.to_string())).unwrap();
            assert!(refresh.refresh().await.is_err(), "expected error for {cmd}");
        }
    }

    #[tokio::test]
    async fn test_aws_auth_refresh_only() {
        let ok = AwsCredentialRefresh::from_settings(Some("true".into()), None).unwrap();
        assert!(ok.refresh().await.unwrap().is_none());

        let failing = AwsCredentialRefresh::from_settings(Some("exit 1".into()), None).unwrap();
        assert!(failing.refresh().await.is_err());
    }

    #[test]
    fn test_credential_manager_default() {
        let manager = CredentialManager::default();
        assert!(manager.api_key_helper.is_none());
        assert!(manager.aws_refresh.is_none());
    }

    #[tokio::test]
    async fn test_credential_manager_without_sources() {
        let manager = CredentialManager::new();
        assert!(manager.get_api_key().await.unwrap().is_none());
        assert!(manager.refresh_aws().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_credential_manager_delegates_to_helper() {
        let manager = CredentialManager::new().api_key_helper(ApiKeyHelper::new("echo test-key"));
        let key = manager.get_api_key().await.unwrap().unwrap();
        assert_eq!(key.expose_secret(), "test-key");
    }
}