1use std::fmt;
4
5use chrono::{DateTime, Duration, Utc};
6use secrecy::{ExposeSecret, SecretString};
7use serde::{Deserialize, Serialize};
8
9mod secret_serde {
10 use secrecy::{ExposeSecret, SecretString};
11 use serde::{Deserialize, Deserializer, Serializer};
12
13 pub fn serialize<S: Serializer>(
14 secret: &SecretString,
15 serializer: S,
16 ) -> Result<S::Ok, S::Error> {
17 serializer.serialize_str(secret.expose_secret())
18 }
19
20 pub fn deserialize<'de, D: Deserializer<'de>>(
21 deserializer: D,
22 ) -> Result<SecretString, D::Error> {
23 let s = String::deserialize(deserializer)?;
24 Ok(SecretString::from(s))
25 }
26}
27
28mod option_secret_serde {
29 use secrecy::{ExposeSecret, SecretString};
30 use serde::{Deserialize, Deserializer, Serializer};
31
32 pub fn serialize<S: Serializer>(
33 secret: &Option<SecretString>,
34 serializer: S,
35 ) -> Result<S::Ok, S::Error> {
36 match secret {
37 Some(s) => serializer.serialize_some(s.expose_secret()),
38 None => serializer.serialize_none(),
39 }
40 }
41
42 pub fn deserialize<'de, D: Deserializer<'de>>(
43 deserializer: D,
44 ) -> Result<Option<SecretString>, D::Error> {
45 let opt = Option::<String>::deserialize(deserializer)?;
46 Ok(opt.map(SecretString::from))
47 }
48}
49
50#[derive(Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct OAuthCredential {
53 #[serde(with = "secret_serde")]
54 pub access_token: SecretString,
55 #[serde(
56 with = "option_secret_serde",
57 default,
58 skip_serializing_if = "Option::is_none"
59 )]
60 pub refresh_token: Option<SecretString>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub expires_at: Option<i64>,
63 #[serde(default)]
64 pub scopes: Vec<String>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub subscription_type: Option<String>,
67}
68
69impl fmt::Debug for OAuthCredential {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("OAuthCredential")
72 .field("access_token", &"[redacted]")
73 .field(
74 "refresh_token",
75 &self.refresh_token.as_ref().map(|_| "[redacted]"),
76 )
77 .field("expires_at", &self.expires_at)
78 .field("scopes", &self.scopes)
79 .field("subscription_type", &self.subscription_type)
80 .finish()
81 }
82}
83
84impl OAuthCredential {
85 pub fn expires_at_datetime(&self) -> Option<DateTime<Utc>> {
86 self.expires_at.and_then(|ts| {
87 DateTime::from_timestamp(ts, 0).or_else(|| {
88 tracing::warn!(
89 timestamp = ts,
90 "Invalid expires_at timestamp, treating as expired"
91 );
92 None
93 })
94 })
95 }
96
97 pub fn is_expired(&self) -> bool {
98 match (self.expires_at, self.expires_at_datetime()) {
99 (Some(_), None) => true,
100 (_, Some(exp)) => Utc::now() >= exp,
101 (None, None) => false,
102 }
103 }
104
105 pub fn needs_refresh(&self) -> bool {
107 match (self.expires_at, self.expires_at_datetime()) {
108 (Some(_), None) => true,
109 (_, Some(exp)) => Utc::now() >= exp - Duration::minutes(5),
110 (None, None) => false,
111 }
112 }
113}
114
115#[derive(Clone)]
116pub enum Credential {
117 ApiKey(SecretString),
118 OAuth(OAuthCredential),
119}
120
121impl fmt::Debug for Credential {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 match self {
124 Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"[redacted]").finish(),
125 Self::OAuth(oauth) => f.debug_tuple("OAuth").field(oauth).finish(),
126 }
127 }
128}
129
130impl Credential {
131 pub fn placeholder() -> Self {
134 Self::ApiKey(SecretString::from(""))
135 }
136
137 pub fn api_key(key: impl Into<String>) -> Self {
138 Self::ApiKey(SecretString::from(key.into()))
139 }
140
141 pub fn oauth(token: impl Into<String>) -> Self {
142 Self::OAuth(OAuthCredential {
143 access_token: SecretString::from(token.into()),
144 refresh_token: None,
145 expires_at: None,
146 scopes: vec![],
147 subscription_type: None,
148 })
149 }
150
151 pub fn is_placeholder(&self) -> bool {
152 match self {
153 Self::ApiKey(key) => key.expose_secret().is_empty(),
154 Self::OAuth(oauth) => oauth.access_token.expose_secret().is_empty(),
155 }
156 }
157
158 pub fn is_expired(&self) -> bool {
159 match self {
160 Credential::ApiKey(_) => false,
161 Credential::OAuth(oauth) => oauth.is_expired(),
162 }
163 }
164
165 pub fn needs_refresh(&self) -> bool {
166 match self {
167 Credential::ApiKey(_) => false,
168 Credential::OAuth(oauth) => oauth.needs_refresh(),
169 }
170 }
171
172 pub fn credential_type(&self) -> &'static str {
173 match self {
174 Credential::ApiKey(_) => "api_key",
175 Credential::OAuth(_) => "oauth",
176 }
177 }
178
179 pub fn is_oauth(&self) -> bool {
180 matches!(self, Credential::OAuth(_))
181 }
182
183 pub fn is_api_key(&self) -> bool {
184 matches!(self, Credential::ApiKey(_))
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an OAuth credential with the given expiry and no optional metadata.
    fn oauth_with_expiry(expires_at: Option<i64>) -> OAuthCredential {
        OAuthCredential {
            access_token: SecretString::from("test"),
            refresh_token: None,
            expires_at,
            scopes: vec![],
            subscription_type: None,
        }
    }

    #[test]
    fn test_api_key_credential() {
        let cred = Credential::api_key("sk-ant-api-test");
        assert!(!cred.is_expired());
        assert!(!cred.needs_refresh());
        assert_eq!(cred.credential_type(), "api_key");
        assert!(cred.is_api_key());
        assert!(!cred.is_oauth());
    }

    #[test]
    fn test_oauth_credential() {
        let cred = Credential::oauth("sk-ant-oat01-test");
        assert_eq!(cred.credential_type(), "oauth");
        assert!(cred.is_oauth());
        assert!(!cred.is_api_key());
        // No expiry recorded: never expires and never needs a refresh.
        assert!(!cred.is_expired());
        assert!(!cred.needs_refresh());
    }

    #[test]
    fn test_oauth_expiry() {
        assert!(oauth_with_expiry(Some(0)).is_expired());
        assert!(!oauth_with_expiry(Some(Utc::now().timestamp() + 3600)).is_expired());
    }

    #[test]
    fn test_needs_refresh_inside_buffer() {
        // Expires in one minute: not yet expired, but inside the refresh buffer.
        let soon = oauth_with_expiry(Some(Utc::now().timestamp() + 60));
        assert!(!soon.is_expired());
        assert!(soon.needs_refresh());

        // Expires in an hour: well outside the refresh buffer.
        let later = oauth_with_expiry(Some(Utc::now().timestamp() + 3600));
        assert!(!later.needs_refresh());

        // Already expired: also needs a refresh.
        assert!(oauth_with_expiry(Some(0)).needs_refresh());
    }

    #[test]
    fn test_unrepresentable_expiry_is_treated_as_expired() {
        let cred = oauth_with_expiry(Some(i64::MAX));
        assert!(cred.expires_at_datetime().is_none());
        assert!(cred.is_expired());
        assert!(cred.needs_refresh());
    }

    #[test]
    fn test_placeholder_detection() {
        assert!(Credential::placeholder().is_placeholder());
        assert!(Credential::oauth("").is_placeholder());
        assert!(!Credential::api_key("key").is_placeholder());
        assert!(!Credential::oauth("token").is_placeholder());
    }

    #[test]
    fn test_credential_debug_redacts_secrets() {
        let cred = Credential::api_key("super-secret-key");
        let debug = format!("{:?}", cred);
        assert!(!debug.contains("super-secret-key"));
        assert!(debug.contains("[redacted]"));
    }

    #[test]
    fn test_oauth_debug_redacts_tokens() {
        let oauth = OAuthCredential {
            access_token: SecretString::from("secret-token"),
            refresh_token: Some(SecretString::from("secret-refresh")),
            expires_at: None,
            scopes: vec![],
            subscription_type: None,
        };
        let debug = format!("{:?}", oauth);
        assert!(!debug.contains("secret-token"));
        assert!(!debug.contains("secret-refresh"));
        assert!(debug.contains("[redacted]"));

        // The enum wrapper must not leak the token either.
        let wrapped = format!("{:?}", Credential::OAuth(oauth));
        assert!(wrapped.starts_with("OAuth("));
        assert!(!wrapped.contains("secret-token"));
        assert!(!wrapped.contains("secret-refresh"));
    }
}