Skip to main content

claude_agent/auth/
credential.rs

//! Credential types.
2
3use std::fmt;
4
5use chrono::{DateTime, Duration, Utc};
6use secrecy::{ExposeSecret, SecretString};
7use serde::{Deserialize, Serialize};
8
9mod secret_serde {
10    use secrecy::{ExposeSecret, SecretString};
11    use serde::{Deserialize, Deserializer, Serializer};
12
13    pub fn serialize<S: Serializer>(
14        secret: &SecretString,
15        serializer: S,
16    ) -> Result<S::Ok, S::Error> {
17        serializer.serialize_str(secret.expose_secret())
18    }
19
20    pub fn deserialize<'de, D: Deserializer<'de>>(
21        deserializer: D,
22    ) -> Result<SecretString, D::Error> {
23        let s = String::deserialize(deserializer)?;
24        Ok(SecretString::from(s))
25    }
26}
27
28mod option_secret_serde {
29    use secrecy::{ExposeSecret, SecretString};
30    use serde::{Deserialize, Deserializer, Serializer};
31
32    pub fn serialize<S: Serializer>(
33        secret: &Option<SecretString>,
34        serializer: S,
35    ) -> Result<S::Ok, S::Error> {
36        match secret {
37            Some(s) => serializer.serialize_some(s.expose_secret()),
38            None => serializer.serialize_none(),
39        }
40    }
41
42    pub fn deserialize<'de, D: Deserializer<'de>>(
43        deserializer: D,
44    ) -> Result<Option<SecretString>, D::Error> {
45        let opt = Option::<String>::deserialize(deserializer)?;
46        Ok(opt.map(SecretString::from))
47    }
48}
49
50#[derive(Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct OAuthCredential {
53    #[serde(with = "secret_serde")]
54    pub access_token: SecretString,
55    #[serde(
56        with = "option_secret_serde",
57        default,
58        skip_serializing_if = "Option::is_none"
59    )]
60    pub refresh_token: Option<SecretString>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub expires_at: Option<i64>,
63    #[serde(default)]
64    pub scopes: Vec<String>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub subscription_type: Option<String>,
67}
68
69impl fmt::Debug for OAuthCredential {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("OAuthCredential")
72            .field("access_token", &"[redacted]")
73            .field(
74                "refresh_token",
75                &self.refresh_token.as_ref().map(|_| "[redacted]"),
76            )
77            .field("expires_at", &self.expires_at)
78            .field("scopes", &self.scopes)
79            .field("subscription_type", &self.subscription_type)
80            .finish()
81    }
82}
83
84impl OAuthCredential {
85    pub fn expires_at_datetime(&self) -> Option<DateTime<Utc>> {
86        self.expires_at.and_then(|ts| {
87            DateTime::from_timestamp(ts, 0).or_else(|| {
88                tracing::warn!(
89                    timestamp = ts,
90                    "Invalid expires_at timestamp, treating as expired"
91                );
92                None
93            })
94        })
95    }
96
97    pub fn is_expired(&self) -> bool {
98        match (self.expires_at, self.expires_at_datetime()) {
99            (Some(_), None) => true,
100            (_, Some(exp)) => Utc::now() >= exp,
101            (None, None) => false,
102        }
103    }
104
105    /// Returns true within 5 minutes of expiry.
106    pub fn needs_refresh(&self) -> bool {
107        match (self.expires_at, self.expires_at_datetime()) {
108            (Some(_), None) => true,
109            (_, Some(exp)) => Utc::now() >= exp - Duration::minutes(5),
110            (None, None) => false,
111        }
112    }
113}
114
115#[derive(Clone)]
116pub enum Credential {
117    ApiKey(SecretString),
118    OAuth(OAuthCredential),
119}
120
121impl fmt::Debug for Credential {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        match self {
124            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"[redacted]").finish(),
125            Self::OAuth(oauth) => f.debug_tuple("OAuth").field(oauth).finish(),
126        }
127    }
128}
129
130impl Credential {
131    /// Create a placeholder credential for cloud providers that handle
132    /// authentication through their own token mechanisms (Bedrock, Vertex, Foundry).
133    pub fn placeholder() -> Self {
134        Self::ApiKey(SecretString::from(""))
135    }
136
137    pub fn api_key(key: impl Into<String>) -> Self {
138        Self::ApiKey(SecretString::from(key.into()))
139    }
140
141    pub fn oauth(token: impl Into<String>) -> Self {
142        Self::OAuth(OAuthCredential {
143            access_token: SecretString::from(token.into()),
144            refresh_token: None,
145            expires_at: None,
146            scopes: vec![],
147            subscription_type: None,
148        })
149    }
150
151    pub fn is_placeholder(&self) -> bool {
152        match self {
153            Self::ApiKey(key) => key.expose_secret().is_empty(),
154            Self::OAuth(oauth) => oauth.access_token.expose_secret().is_empty(),
155        }
156    }
157
158    pub fn is_expired(&self) -> bool {
159        match self {
160            Credential::ApiKey(_) => false,
161            Credential::OAuth(oauth) => oauth.is_expired(),
162        }
163    }
164
165    pub fn needs_refresh(&self) -> bool {
166        match self {
167            Credential::ApiKey(_) => false,
168            Credential::OAuth(oauth) => oauth.needs_refresh(),
169        }
170    }
171
172    pub fn credential_type(&self) -> &'static str {
173        match self {
174            Credential::ApiKey(_) => "api_key",
175            Credential::OAuth(_) => "oauth",
176        }
177    }
178
179    pub fn is_oauth(&self) -> bool {
180        matches!(self, Credential::OAuth(_))
181    }
182
183    pub fn is_api_key(&self) -> bool {
184        matches!(self, Credential::ApiKey(_))
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`OAuthCredential`] with the given expiry and fixed
    /// values for every other field.
    fn oauth_with_expiry(expires_at: Option<i64>) -> OAuthCredential {
        OAuthCredential {
            access_token: SecretString::from("test"),
            refresh_token: None,
            expires_at,
            scopes: vec![],
            subscription_type: None,
        }
    }

    #[test]
    fn test_api_key_credential() {
        let cred = Credential::api_key("sk-ant-api-test");
        assert!(!cred.is_expired());
        assert!(!cred.needs_refresh());
        assert_eq!(cred.credential_type(), "api_key");
        assert!(cred.is_api_key());
        assert!(!cred.is_oauth());
        assert!(!cred.is_placeholder());
    }

    #[test]
    fn test_oauth_credential() {
        let cred = Credential::oauth("sk-ant-oat01-test");
        assert_eq!(cred.credential_type(), "oauth");
        assert!(cred.is_oauth());
        assert!(!cred.is_api_key());
        // No expiry set: never expired, never due for refresh.
        assert!(!cred.is_expired());
        assert!(!cred.needs_refresh());
    }

    #[test]
    fn test_placeholder() {
        let cred = Credential::placeholder();
        assert!(cred.is_placeholder());
        assert!(cred.is_api_key());
        // An empty OAuth access token also counts as a placeholder.
        assert!(Credential::oauth("").is_placeholder());
    }

    #[test]
    fn test_oauth_expiry() {
        assert!(oauth_with_expiry(Some(0)).is_expired());
        assert!(!oauth_with_expiry(Some(Utc::now().timestamp() + 3600)).is_expired());
    }

    #[test]
    fn test_oauth_needs_refresh_window() {
        let now = Utc::now().timestamp();

        // Inside the 5-minute window: still valid but due for refresh.
        let soon = oauth_with_expiry(Some(now + 60));
        assert!(!soon.is_expired());
        assert!(soon.needs_refresh());

        // Well outside the window.
        assert!(!oauth_with_expiry(Some(now + 3600)).needs_refresh());

        // Already expired credentials also need a refresh.
        assert!(oauth_with_expiry(Some(0)).needs_refresh());
    }

    #[test]
    fn test_unrepresentable_timestamp_is_treated_as_expired() {
        let cred = oauth_with_expiry(Some(i64::MAX));
        assert!(cred.expires_at_datetime().is_none());
        assert!(cred.is_expired());
        assert!(cred.needs_refresh());
    }

    #[test]
    fn test_credential_debug_redacts_secrets() {
        let cred = Credential::api_key("super-secret-key");
        let debug = format!("{:?}", cred);
        assert!(!debug.contains("super-secret-key"));
        assert!(debug.contains("[redacted]"));

        let oauth = Credential::oauth("super-secret-token");
        let debug = format!("{:?}", oauth);
        assert!(!debug.contains("super-secret-token"));
        assert!(debug.contains("[redacted]"));
    }

    #[test]
    fn test_oauth_debug_redacts_tokens() {
        let oauth = OAuthCredential {
            access_token: SecretString::from("secret-token"),
            refresh_token: Some(SecretString::from("secret-refresh")),
            expires_at: None,
            scopes: vec![],
            subscription_type: None,
        };
        let debug = format!("{:?}", oauth);
        assert!(!debug.contains("secret-token"));
        assert!(!debug.contains("secret-refresh"));
        assert!(debug.contains("[redacted]"));
    }
}