clawspec_core/client/oauth2/
token.rs1#![allow(unused_assignments)]
8
9use std::fmt;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use tokio::sync::RwLock;
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16#[derive(Clone, Zeroize, ZeroizeOnDrop)]
18pub struct OAuth2Token {
19 access_token: String,
21 #[zeroize(skip)]
23 expires_at: Option<Instant>,
24 refresh_token: Option<String>,
26}
27
28impl OAuth2Token {
29 pub fn new(access_token: impl Into<String>) -> Self {
31 Self {
32 access_token: access_token.into(),
33 expires_at: None,
34 refresh_token: None,
35 }
36 }
37
38 pub fn with_expiry(access_token: impl Into<String>, expires_in: Duration) -> Self {
40 Self {
41 access_token: access_token.into(),
42 expires_at: Some(Instant::now() + expires_in),
43 refresh_token: None,
44 }
45 }
46
47 #[must_use]
49 pub fn with_refresh_token(mut self, refresh_token: impl Into<String>) -> Self {
50 self.refresh_token = Some(refresh_token.into());
51 self
52 }
53
54 pub fn access_token(&self) -> &str {
56 &self.access_token
57 }
58
59 pub fn refresh_token(&self) -> Option<&str> {
61 self.refresh_token.as_deref()
62 }
63
64 pub fn is_expired(&self) -> bool {
68 self.expires_at.is_some_and(|exp| Instant::now() >= exp)
69 }
70
71 pub fn should_refresh(&self, threshold: Duration) -> bool {
75 self.expires_at
76 .is_some_and(|exp| Instant::now() + threshold >= exp)
77 }
78
79 pub fn time_until_expiry(&self) -> Option<Duration> {
81 self.expires_at.and_then(|exp| {
82 let now = Instant::now();
83 if now >= exp { None } else { Some(exp - now) }
84 })
85 }
86}
87
88impl fmt::Debug for OAuth2Token {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 f.debug_struct("OAuth2Token")
91 .field("access_token", &"[REDACTED]")
92 .field("expires_at", &self.expires_at)
93 .field(
94 "refresh_token",
95 &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
96 )
97 .finish()
98 }
99}
100
101#[derive(Debug, Clone, Default)]
106pub struct TokenCache {
107 inner: Arc<RwLock<Option<OAuth2Token>>>,
108}
109
110impl TokenCache {
111 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn with_token(token: OAuth2Token) -> Self {
118 Self {
119 inner: Arc::new(RwLock::new(Some(token))),
120 }
121 }
122
123 pub async fn get(&self) -> Option<OAuth2Token> {
125 let guard = self.inner.read().await;
126 guard.as_ref().filter(|t| !t.is_expired()).cloned()
127 }
128
129 pub async fn should_refresh(&self, threshold: Duration) -> bool {
136 let guard = self.inner.read().await;
137 match guard.as_ref() {
138 None => true,
139 Some(token) => token.should_refresh(threshold),
140 }
141 }
142
143 pub async fn set(&self, token: OAuth2Token) {
145 let mut guard = self.inner.write().await;
146 *guard = Some(token);
147 }
148
149 #[cfg_attr(not(test), allow(dead_code))]
151 pub async fn clear(&self) {
152 let mut guard = self.inner.write().await;
153 *guard = None;
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    const TEN_SECONDS: Duration = Duration::from_secs(10);
    const HALF_MINUTE: Duration = Duration::from_secs(30);
    const ONE_MINUTE: Duration = Duration::from_secs(60);
    const ONE_HOUR: Duration = Duration::from_secs(3600);

    #[test]
    fn should_create_token() {
        let token = OAuth2Token::new("access-token-123");

        assert_eq!(token.access_token(), "access-token-123");
        assert_eq!(token.refresh_token(), None);
        assert!(!token.is_expired(), "a token without expiry must not be expired");
    }

    #[test]
    fn should_create_token_with_expiry() {
        let token = OAuth2Token::with_expiry("token", ONE_HOUR);

        assert!(!token.is_expired());
        assert!(token.time_until_expiry().is_some());
    }

    #[test]
    fn should_detect_expired_token() {
        let token = OAuth2Token::with_expiry("token", Duration::ZERO);

        assert!(token.is_expired(), "a zero lifetime must be expired immediately");
    }

    #[test]
    fn should_detect_refresh_needed() {
        let token = OAuth2Token::with_expiry("token", HALF_MINUTE);

        // Expires in 30s: inside a 60s refresh window...
        assert!(token.should_refresh(ONE_MINUTE));
        // ...but outside a 10s one.
        assert!(!token.should_refresh(TEN_SECONDS));
    }

    #[test]
    fn should_add_refresh_token() {
        let token = OAuth2Token::new("access").with_refresh_token("refresh");

        assert_eq!(token.refresh_token(), Some("refresh"));
    }

    #[test]
    fn should_redact_debug_output() {
        let token = OAuth2Token::new("secret-token").with_refresh_token("secret-refresh");
        let rendered = format!("{token:?}");

        assert!(rendered.contains("[REDACTED]"));
        for secret in ["secret-token", "secret-refresh"] {
            assert!(!rendered.contains(secret), "debug output leaked {secret}");
        }
    }

    #[tokio::test]
    async fn should_cache_token() {
        let cache = TokenCache::new();
        assert!(cache.get().await.is_none(), "a fresh cache must be empty");

        cache.set(OAuth2Token::new("cached-token")).await;

        let cached = cache.get().await.expect("Token should be cached");
        assert_eq!(cached.access_token(), "cached-token");
    }

    #[tokio::test]
    async fn should_not_return_expired_token() {
        let cache = TokenCache::new();
        cache
            .set(OAuth2Token::with_expiry("expired", Duration::ZERO))
            .await;

        assert!(cache.get().await.is_none());
    }

    #[tokio::test]
    async fn should_clear_cache() {
        let cache = TokenCache::new();
        cache.set(OAuth2Token::new("token")).await;
        assert!(cache.get().await.is_some());

        cache.clear().await;

        assert!(cache.get().await.is_none());
    }

    #[tokio::test]
    async fn should_detect_refresh_needed_in_cache() {
        let cache = TokenCache::new();

        // An empty cache always asks for a refresh.
        assert!(cache.should_refresh(ONE_MINUTE).await);

        // A token expiring in 30s falls inside the 60s window.
        cache
            .set(OAuth2Token::with_expiry("token", HALF_MINUTE))
            .await;
        assert!(cache.should_refresh(ONE_MINUTE).await);

        // A token valid for an hour does not.
        cache.set(OAuth2Token::with_expiry("token", ONE_HOUR)).await;
        assert!(!cache.should_refresh(ONE_MINUTE).await);
    }
}