claude_agent/auth/
cache.rs1use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use tokio::sync::Mutex;
8
9use super::{Credential, CredentialProvider};
10use crate::Result;
11
12const DEFAULT_TTL: Duration = Duration::from_secs(300); struct CacheEntry {
15 credential: Credential,
16 fetched_at: Instant,
17}
18
19pub struct CachedProvider<P> {
24 inner: P,
25 cache: Arc<Mutex<Option<CacheEntry>>>,
26 ttl: Duration,
27}
28
29impl<P: CredentialProvider> CachedProvider<P> {
30 pub fn new(provider: P) -> Self {
31 Self {
32 inner: provider,
33 cache: Arc::new(Mutex::new(None)),
34 ttl: DEFAULT_TTL,
35 }
36 }
37
38 pub fn ttl(mut self, ttl: Duration) -> Self {
39 self.ttl = ttl;
40 self
41 }
42
43 pub async fn invalidate(&self) {
44 let mut cache = self.cache.lock().await;
45 *cache = None;
46 }
47
48 fn is_expired(&self, entry: &CacheEntry) -> bool {
49 entry.fetched_at.elapsed() > self.ttl
50 }
51
52 fn credential_expired(&self, cred: &Credential) -> bool {
53 if let Credential::OAuth(oauth) = cred {
54 oauth.is_expired()
55 } else {
56 false
57 }
58 }
59}
60
61#[async_trait]
62impl<P: CredentialProvider> CredentialProvider for CachedProvider<P> {
63 fn name(&self) -> &str {
64 self.inner.name()
65 }
66
67 async fn resolve(&self) -> Result<Credential> {
68 let mut cache = self.cache.lock().await;
70
71 if let Some(ref entry) = *cache
72 && !self.is_expired(entry)
73 && !self.credential_expired(&entry.credential)
74 {
75 return Ok(entry.credential.clone());
76 }
77
78 let credential = self.inner.resolve().await?;
79
80 *cache = Some(CacheEntry {
81 credential: credential.clone(),
82 fetched_at: Instant::now(),
83 });
84
85 Ok(credential)
86 }
87
88 async fn refresh(&self) -> Result<Credential> {
89 let credential = self.inner.refresh().await?;
90
91 let mut cache = self.cache.lock().await;
92 *cache = Some(CacheEntry {
93 credential: credential.clone(),
94 fetched_at: Instant::now(),
95 });
96
97 Ok(credential)
98 }
99
100 fn supports_refresh(&self) -> bool {
101 self.inner.supports_refresh()
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that always yields the same API key and records how many
    /// times `resolve` was invoked.
    struct CountingProvider {
        calls: AtomicUsize,
    }

    impl CountingProvider {
        fn new() -> Self {
            Self { calls: AtomicUsize::new(0) }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        fn name(&self) -> &str {
            "counting"
        }

        async fn resolve(&self) -> Result<Credential> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(Credential::api_key("test-key"))
        }
    }

    /// Resolves once through `provider` and reports how many times the inner
    /// provider has been hit so far.
    async fn resolve_and_count(provider: &CachedProvider<CountingProvider>) -> usize {
        let _ = provider.resolve().await.unwrap();
        provider.inner.call_count()
    }

    #[tokio::test]
    async fn test_caching() {
        let cached = CachedProvider::new(CountingProvider::new());

        assert_eq!(resolve_and_count(&cached).await, 1);
        // The second lookup is served from cache; the inner provider is untouched.
        assert_eq!(resolve_and_count(&cached).await, 1);
    }

    #[tokio::test]
    async fn test_invalidate() {
        let cached = CachedProvider::new(CountingProvider::new());

        assert_eq!(resolve_and_count(&cached).await, 1);
        cached.invalidate().await;
        // With the cache cleared, the next lookup must reach the inner provider.
        assert_eq!(resolve_and_count(&cached).await, 2);
    }

    #[tokio::test]
    async fn test_ttl_expiry() {
        let cached = CachedProvider::new(CountingProvider::new()).ttl(Duration::from_millis(10));

        assert_eq!(resolve_and_count(&cached).await, 1);
        // Sleep well past the 10 ms TTL so the cached entry goes stale.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(resolve_and_count(&cached).await, 2);
    }
}