claude_agent/auth/
credential.rs

1//! Credential types.
2
3use std::fmt;
4
5use chrono::{DateTime, Duration, Utc};
6use serde::{Deserialize, Serialize};
7
/// OAuth credential from Claude Code CLI.
///
/// (De)serialized with camelCase keys (`accessToken`, `refreshToken`,
/// `expiresAt`, `scopes`, `subscriptionType`). Optional fields that are `None`
/// are left out of serialized output, and a missing `scopes` key deserializes
/// as an empty list.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OAuthCredential {
    /// Access token; rendered as `[redacted]` by this type's `Debug` impl.
    pub access_token: String,
    /// Refresh token, if present; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Expiry as a Unix timestamp; `expires_at_datetime` reads it as whole seconds.
    // NOTE(review): confirm the CLI does not write this field in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<i64>,
    /// OAuth scopes attached to the token; empty when absent from the input.
    #[serde(default)]
    pub scopes: Vec<String>,
    /// Subscription type string, if present; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub subscription_type: Option<String>,
}
22
23impl fmt::Debug for OAuthCredential {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_struct("OAuthCredential")
26            .field("access_token", &"[redacted]")
27            .field(
28                "refresh_token",
29                &self.refresh_token.as_ref().map(|_| "[redacted]"),
30            )
31            .field("expires_at", &self.expires_at)
32            .field("scopes", &self.scopes)
33            .field("subscription_type", &self.subscription_type)
34            .finish()
35    }
36}
37
38impl OAuthCredential {
39    /// Get expiration as DateTime.
40    pub fn expires_at_datetime(&self) -> Option<DateTime<Utc>> {
41        self.expires_at
42            .map(|ts| DateTime::from_timestamp(ts, 0).unwrap_or_else(Utc::now))
43    }
44
45    /// Check if token is expired.
46    pub fn is_expired(&self) -> bool {
47        self.expires_at_datetime()
48            .map(|exp| Utc::now() >= exp)
49            .unwrap_or(false)
50    }
51
52    /// Check if token needs refresh (within 5 minutes of expiry).
53    pub fn needs_refresh(&self) -> bool {
54        self.expires_at_datetime()
55            .map(|exp| Utc::now() >= exp - Duration::minutes(5))
56            .unwrap_or(false)
57    }
58}
59
60/// Authentication credential.
61#[derive(Clone)]
62pub enum Credential {
63    ApiKey(String),
64    OAuth(OAuthCredential),
65}
66
67impl fmt::Debug for Credential {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        match self {
70            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"[redacted]").finish(),
71            Self::OAuth(oauth) => f.debug_tuple("OAuth").field(oauth).finish(),
72        }
73    }
74}
75
76impl Default for Credential {
77    /// Default credential is an empty API key (placeholder for cloud providers).
78    fn default() -> Self {
79        Self::ApiKey(String::new())
80    }
81}
82
83impl Credential {
84    /// Create API Key credential.
85    pub fn api_key(key: impl Into<String>) -> Self {
86        Self::ApiKey(key.into())
87    }
88
89    /// Create OAuth credential.
90    pub fn oauth(token: impl Into<String>) -> Self {
91        Self::OAuth(OAuthCredential {
92            access_token: token.into(),
93            refresh_token: None,
94            expires_at: None,
95            scopes: vec![],
96            subscription_type: None,
97        })
98    }
99
100    /// Check if this is a default (empty) credential.
101    /// Used for cloud providers that handle auth differently.
102    pub fn is_default(&self) -> bool {
103        match self {
104            Self::ApiKey(key) => key.is_empty(),
105            Self::OAuth(oauth) => oauth.access_token.is_empty(),
106        }
107    }
108
109    /// Check if credential is expired.
110    pub fn is_expired(&self) -> bool {
111        match self {
112            Credential::ApiKey(_) => false,
113            Credential::OAuth(oauth) => oauth.is_expired(),
114        }
115    }
116
117    /// Check if credential needs refresh.
118    pub fn needs_refresh(&self) -> bool {
119        match self {
120            Credential::ApiKey(_) => false,
121            Credential::OAuth(oauth) => oauth.needs_refresh(),
122        }
123    }
124
125    /// Get credential type name.
126    pub fn credential_type(&self) -> &'static str {
127        match self {
128            Credential::ApiKey(_) => "api_key",
129            Credential::OAuth(_) => "oauth",
130        }
131    }
132
133    /// Check if this is an OAuth credential.
134    pub fn is_oauth(&self) -> bool {
135        matches!(self, Credential::OAuth(_))
136    }
137
138    /// Check if this is an API key credential.
139    pub fn is_api_key(&self) -> bool {
140        matches!(self, Credential::ApiKey(_))
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an OAuth credential where only the expiry varies.
    fn oauth_expiring_at(expires_at: Option<i64>) -> OAuthCredential {
        OAuthCredential {
            access_token: "test".into(),
            refresh_token: None,
            expires_at,
            scopes: vec![],
            subscription_type: None,
        }
    }

    #[test]
    fn test_api_key_credential() {
        let cred = Credential::api_key("sk-ant-api-test");
        assert!(!cred.is_expired());
        assert!(!cred.needs_refresh());
        assert_eq!(cred.credential_type(), "api_key");
        assert!(cred.is_api_key());
        assert!(!cred.is_oauth());
    }

    #[test]
    fn test_oauth_credential() {
        let cred = Credential::oauth("sk-ant-oat01-test");
        assert_eq!(cred.credential_type(), "oauth");
        assert!(cred.is_oauth());
        assert!(!cred.is_api_key());
        // A bare access token has no expiry, so it never expires or needs refresh.
        assert!(!cred.is_expired());
        assert!(!cred.needs_refresh());
    }

    #[test]
    fn test_oauth_expiry() {
        assert!(oauth_expiring_at(Some(0)).is_expired());
        assert!(!oauth_expiring_at(Some(Utc::now().timestamp() + 3600)).is_expired());
        assert!(!oauth_expiring_at(None).is_expired());
    }

    #[test]
    fn test_oauth_needs_refresh_window() {
        // Inside the 5-minute window: refresh is due, but the token is still valid.
        let soon = oauth_expiring_at(Some(Utc::now().timestamp() + 60));
        assert!(soon.needs_refresh());
        assert!(!soon.is_expired());

        // Well outside the window.
        let later = oauth_expiring_at(Some(Utc::now().timestamp() + 3600));
        assert!(!later.needs_refresh());

        // Already-expired tokens need refresh too.
        assert!(oauth_expiring_at(Some(0)).needs_refresh());

        // An unknown expiry never triggers a refresh.
        assert!(!oauth_expiring_at(None).needs_refresh());
    }

    #[test]
    fn test_is_default() {
        assert!(Credential::default().is_default());
        assert!(Credential::api_key("").is_default());
        assert!(Credential::oauth("").is_default());
        assert!(!Credential::api_key("sk-ant-api-test").is_default());
        assert!(!Credential::oauth("sk-ant-oat01-test").is_default());
    }

    #[test]
    fn test_debug_redacts_secrets() {
        let api = format!("{:?}", Credential::api_key("sk-ant-api-secret"));
        assert!(!api.contains("sk-ant-api-secret"));
        assert!(api.contains("[redacted]"));

        let oauth = OAuthCredential {
            access_token: "access-secret".into(),
            refresh_token: Some("refresh-secret".into()),
            ..oauth_expiring_at(None)
        };
        let out = format!("{:?}", Credential::OAuth(oauth));
        assert!(!out.contains("access-secret"));
        assert!(!out.contains("refresh-secret"));
        assert!(out.contains("[redacted]"));
    }
}