Skip to main content

claude_agent/auth/
helper.rs

1//! API Key Helper for dynamic credential generation.
2
3use std::process::Stdio;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use secrecy::{ExposeSecret, SecretString};
8use tokio::process::Command;
9use tokio::sync::Mutex;
10
11use crate::{Error, Result};
12
13use std::fmt;
14
15async fn run_shell_command(cmd: &str, context: &str) -> Result<String> {
16    let output = Command::new("sh")
17        .arg("-c")
18        .arg(cmd)
19        .stdout(Stdio::piped())
20        .stderr(Stdio::piped())
21        .output()
22        .await
23        .map_err(|e| Error::auth(format!("{} failed: {}", context, e)))?;
24
25    if !output.status.success() {
26        let stderr = String::from_utf8_lossy(&output.stderr);
27        return Err(Error::auth(format!(
28            "{} failed: {}",
29            context,
30            stderr.trim()
31        )));
32    }
33
34    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
35}
36
/// Obtains an API key by running a user-configured shell command and caching
/// the result for a configurable TTL (one hour unless overridden).
///
/// NOTE(review): the derived `Debug` prints `command` verbatim, so avoid
/// embedding secrets directly in the command string.
#[derive(Debug)]
pub struct ApiKeyHelper {
    /// Shell command executed via `sh -c`; its trimmed stdout is the key.
    command: String,
    /// How long a fetched key is reused before the command is run again.
    ttl: Duration,
    /// Most recently fetched key and its expiry. `None` before the first
    /// successful fetch and after `invalidate`.
    cache: Mutex<Option<CachedKey>>,
}
43
/// A fetched API key together with the instant after which it is no longer
/// served from the cache.
#[derive(Clone)]
struct CachedKey {
    /// The key produced by the helper command.
    key: SecretString,
    /// Fetch time plus the helper's TTL; the key is reused only while
    /// `Instant::now()` is strictly before this.
    expires_at: Instant,
}
49
50impl fmt::Debug for CachedKey {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("CachedKey")
53            .field("key", &"[redacted]")
54            .field("expires_at", &self.expires_at)
55            .finish()
56    }
57}
58
59impl ApiKeyHelper {
60    pub fn new(command: impl Into<String>) -> Self {
61        Self {
62            command: command.into(),
63            ttl: Duration::from_secs(3600),
64            cache: Mutex::new(None),
65        }
66    }
67
68    pub fn ttl(mut self, ttl: Duration) -> Self {
69        self.ttl = ttl;
70        self
71    }
72
73    pub fn ttl_ms(mut self, ttl_ms: u64) -> Self {
74        self.ttl = Duration::from_millis(ttl_ms);
75        self
76    }
77
78    pub fn from_env() -> Option<Self> {
79        let command = std::env::var("ANTHROPIC_API_KEY_HELPER").ok()?;
80        let ttl_ms = std::env::var("CLAUDE_CODE_API_KEY_HELPER_TTL_MS")
81            .ok()
82            .and_then(|v| v.parse().ok())
83            .unwrap_or(3_600_000);
84
85        Some(Self::new(command).ttl_ms(ttl_ms))
86    }
87
88    pub async fn get_key(&self) -> Result<SecretString> {
89        let mut cache = self.cache.lock().await;
90
91        if let Some(ref cached) = *cache
92            && Instant::now() < cached.expires_at
93        {
94            return Ok(cached.key.clone());
95        }
96
97        let key = run_shell_command(&self.command, "API key helper").await?;
98
99        if key.is_empty() {
100            return Err(Error::auth("API key helper returned empty key"));
101        }
102
103        let secret_key = SecretString::from(key);
104
105        *cache = Some(CachedKey {
106            key: secret_key.clone(),
107            expires_at: Instant::now() + self.ttl,
108        });
109
110        Ok(secret_key)
111    }
112
113    pub async fn invalidate(&self) {
114        *self.cache.lock().await = None;
115    }
116}
117
/// Optional shell commands used to refresh AWS credentials.
#[derive(Debug)]
pub struct AwsCredentialRefresh {
    /// Command run only for its side effects; its output is discarded. Used
    /// only when `credential_export_cmd` is not set.
    auth_refresh_cmd: Option<String>,
    /// Command whose stdout is parsed as JSON with a top-level `Credentials`
    /// object. Takes precedence over `auth_refresh_cmd` when both are set.
    credential_export_cmd: Option<String>,
}
123
124impl AwsCredentialRefresh {
125    pub fn new() -> Self {
126        Self {
127            auth_refresh_cmd: None,
128            credential_export_cmd: None,
129        }
130    }
131
132    pub fn from_settings(
133        auth_refresh: Option<String>,
134        credential_export: Option<String>,
135    ) -> Option<Self> {
136        if auth_refresh.is_none() && credential_export.is_none() {
137            return None;
138        }
139
140        Some(Self {
141            auth_refresh_cmd: auth_refresh,
142            credential_export_cmd: credential_export,
143        })
144    }
145
146    pub async fn refresh(&self) -> Result<Option<AwsCredentials>> {
147        if let Some(ref cmd) = self.credential_export_cmd {
148            return self.export_credentials(cmd).await.map(Some);
149        }
150
151        if let Some(ref cmd) = self.auth_refresh_cmd {
152            run_shell_command(cmd, "AWS auth refresh").await?;
153        }
154
155        Ok(None)
156    }
157
158    async fn export_credentials(&self, cmd: &str) -> Result<AwsCredentials> {
159        let stdout = run_shell_command(cmd, "AWS credential export").await?;
160
161        let json: serde_json::Value = serde_json::from_str(&stdout)
162            .map_err(|e| Error::auth(format!("Invalid credential JSON: {}", e)))?;
163
164        let creds = json
165            .get("Credentials")
166            .ok_or_else(|| Error::auth("Missing Credentials in response"))?;
167
168        Ok(AwsCredentials {
169            access_key_id: creds
170                .get("AccessKeyId")
171                .and_then(|v| v.as_str())
172                .ok_or_else(|| Error::auth("Missing AccessKeyId"))?
173                .to_string(),
174            secret_access_key: SecretString::from(
175                creds
176                    .get("SecretAccessKey")
177                    .and_then(|v| v.as_str())
178                    .ok_or_else(|| Error::auth("Missing SecretAccessKey"))?
179                    .to_string(),
180            ),
181            session_token: creds
182                .get("SessionToken")
183                .and_then(|v| v.as_str())
184                .map(|s| SecretString::from(s.to_string())),
185        })
186    }
187}
188
189impl Default for AwsCredentialRefresh {
190    fn default() -> Self {
191        Self::new()
192    }
193}
194
/// AWS credentials parsed from the output of a credential-export command.
///
/// The secret parts are private fields, readable only through the accessor
/// methods.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID, taken from `Credentials.AccessKeyId`.
    pub access_key_id: String,
    /// Secret access key, taken from `Credentials.SecretAccessKey`.
    secret_access_key: SecretString,
    /// Session token from `Credentials.SessionToken`, if the command provided one.
    session_token: Option<SecretString>,
}
201
202impl AwsCredentials {
203    pub fn secret_access_key(&self) -> &str {
204        self.secret_access_key.expose_secret()
205    }
206
207    pub fn session_token(&self) -> Option<&str> {
208        self.session_token.as_ref().map(|s| s.expose_secret())
209    }
210}
211
212impl fmt::Debug for AwsCredentials {
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        f.debug_struct("AwsCredentials")
215            .field("access_key_id", &self.access_key_id)
216            .field("secret_access_key", &"[redacted]")
217            .field(
218                "session_token",
219                &self.session_token.as_ref().map(|_| "[redacted]"),
220            )
221            .finish()
222    }
223}
224
/// Holds the optional dynamic credential sources: an API-key helper and an
/// AWS credential refresher. Either may be absent.
#[derive(Debug)]
pub struct CredentialManager {
    /// Source used by `get_api_key`; `None` means no helper is configured.
    api_key_helper: Option<Arc<ApiKeyHelper>>,
    /// Source used by `refresh_aws`; `None` means no AWS refresh is configured.
    aws_refresh: Option<Arc<AwsCredentialRefresh>>,
}
230
231impl CredentialManager {
232    pub fn new() -> Self {
233        Self {
234            api_key_helper: None,
235            aws_refresh: None,
236        }
237    }
238
239    pub fn api_key_helper(mut self, helper: ApiKeyHelper) -> Self {
240        self.api_key_helper = Some(Arc::new(helper));
241        self
242    }
243
244    pub fn aws_refresh(mut self, refresh: AwsCredentialRefresh) -> Self {
245        self.aws_refresh = Some(Arc::new(refresh));
246        self
247    }
248
249    pub async fn get_api_key(&self) -> Result<Option<SecretString>> {
250        match &self.api_key_helper {
251            Some(helper) => helper.get_key().await.map(Some),
252            None => Ok(None),
253        }
254    }
255
256    pub async fn refresh_aws(&self) -> Result<Option<AwsCredentials>> {
257        match &self.aws_refresh {
258            Some(refresh) => refresh.refresh().await,
259            None => Ok(None),
260        }
261    }
262}
263
264impl Default for CredentialManager {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a shell command that appends to a per-test counter file and
    /// prints how many times it has run, plus the file path for cleanup.
    ///
    /// The output changes on every invocation, so a test can tell a cached key
    /// from a freshly fetched one.
    fn counting_command(name: &str) -> (String, std::path::PathBuf) {
        let path = std::env::temp_dir().join(format!(
            "api_key_helper_{}_{}.count",
            name,
            std::process::id()
        ));
        // Start from a clean slate in case an earlier run left the file behind.
        let _ = std::fs::remove_file(&path);
        let cmd = format!("echo x >> '{0}'; wc -l < '{0}'", path.display());
        (cmd, path)
    }

    #[tokio::test]
    async fn test_api_key_helper_echo() {
        let helper = ApiKeyHelper::new("echo test-key");
        let key = helper.get_key().await.unwrap();
        assert_eq!(key.expose_secret(), "test-key");
    }

    #[tokio::test]
    async fn test_api_key_helper_caching() {
        let (cmd, path) = counting_command("caching");
        let helper = ApiKeyHelper::new(cmd).ttl(Duration::from_secs(60));

        let key1 = helper.get_key().await.unwrap();
        let key2 = helper.get_key().await.unwrap();
        // The second call must come from the cache, not from re-running the command.
        assert_eq!(key1.expose_secret(), "1");
        assert_eq!(key2.expose_secret(), "1");

        let _ = std::fs::remove_file(path);
    }

    #[tokio::test]
    async fn test_api_key_helper_invalidate_refetches() {
        let (cmd, path) = counting_command("invalidate");
        let helper = ApiKeyHelper::new(cmd).ttl(Duration::from_secs(60));

        assert_eq!(helper.get_key().await.unwrap().expose_secret(), "1");
        helper.invalidate().await;
        assert_eq!(helper.get_key().await.unwrap().expose_secret(), "2");

        let _ = std::fs::remove_file(path);
    }

    #[tokio::test]
    async fn test_api_key_helper_zero_ttl_refetches() {
        let (cmd, path) = counting_command("zero_ttl");
        let helper = ApiKeyHelper::new(cmd).ttl(Duration::ZERO);

        assert_eq!(helper.get_key().await.unwrap().expose_secret(), "1");
        assert_eq!(helper.get_key().await.unwrap().expose_secret(), "2");

        let _ = std::fs::remove_file(path);
    }

    #[tokio::test]
    async fn test_api_key_helper_failure() {
        let helper = ApiKeyHelper::new("exit 1");
        assert!(helper.get_key().await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_helper_blank_output_is_error() {
        let helper = ApiKeyHelper::new("printf '   '");
        assert!(helper.get_key().await.is_err());
    }

    #[test]
    fn test_aws_refresh_from_settings() {
        assert!(AwsCredentialRefresh::from_settings(None, None).is_none());
        assert!(AwsCredentialRefresh::from_settings(Some("true".into()), None).is_some());
        assert!(AwsCredentialRefresh::from_settings(None, Some("true".into())).is_some());
    }

    #[tokio::test]
    async fn test_aws_auth_refresh_only_returns_none() {
        let refresh = AwsCredentialRefresh::from_settings(Some("true".into()), None).unwrap();
        assert!(refresh.refresh().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_aws_credential_export_parses_json() {
        let cmd = r#"echo '{"Credentials":{"AccessKeyId":"AKID","SecretAccessKey":"SECRET","SessionToken":"TOKEN"}}'"#;
        let refresh = AwsCredentialRefresh::from_settings(None, Some(cmd.to_string())).unwrap();
        let creds = refresh
            .refresh()
            .await
            .unwrap()
            .expect("export command yields credentials");

        assert_eq!(creds.access_key_id, "AKID");
        assert_eq!(creds.secret_access_key(), "SECRET");
        assert_eq!(creds.session_token(), Some("TOKEN"));

        // Debug output must show the key ID but never the secret parts.
        let debug = format!("{:?}", creds);
        assert!(debug.contains("AKID"));
        assert!(!debug.contains("SECRET"));
        assert!(!debug.contains("TOKEN"));
    }

    #[tokio::test]
    async fn test_aws_credential_export_missing_secret_is_error() {
        let cmd = r#"echo '{"Credentials":{"AccessKeyId":"AKID"}}'"#;
        let refresh = AwsCredentialRefresh::from_settings(None, Some(cmd.to_string())).unwrap();
        assert!(refresh.refresh().await.is_err());
    }

    #[test]
    fn test_credential_manager_default() {
        let manager = CredentialManager::default();
        assert!(manager.api_key_helper.is_none());
        assert!(manager.aws_refresh.is_none());
    }

    #[tokio::test]
    async fn test_credential_manager_without_sources_yields_none() {
        let manager = CredentialManager::new();
        assert!(manager.get_api_key().await.unwrap().is_none());
        assert!(manager.refresh_aws().await.unwrap().is_none());
    }
}