// claude_agent/auth/cache.rs
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::RwLock;

use super::{Credential, CredentialProvider};
use crate::Result;
12const DEFAULT_TTL: Duration = Duration::from_secs(300); struct CacheEntry {
15 credential: Credential,
16 fetched_at: Instant,
17}
18
19pub struct CachedProvider<P> {
21 inner: P,
22 cache: Arc<RwLock<Option<CacheEntry>>>,
23 ttl: Duration,
24}
25
26impl<P: CredentialProvider> CachedProvider<P> {
27 pub fn new(provider: P) -> Self {
28 Self {
29 inner: provider,
30 cache: Arc::new(RwLock::new(None)),
31 ttl: DEFAULT_TTL,
32 }
33 }
34
35 pub fn with_ttl(mut self, ttl: Duration) -> Self {
36 self.ttl = ttl;
37 self
38 }
39
40 pub async fn invalidate(&self) {
41 let mut cache = self.cache.write().await;
42 *cache = None;
43 }
44
45 fn is_expired(&self, entry: &CacheEntry) -> bool {
46 entry.fetched_at.elapsed() > self.ttl
47 }
48
49 fn credential_expired(&self, cred: &Credential) -> bool {
50 if let Credential::OAuth(oauth) = cred {
51 oauth.is_expired()
52 } else {
53 false
54 }
55 }
56}
57
58#[async_trait]
59impl<P: CredentialProvider> CredentialProvider for CachedProvider<P> {
60 fn name(&self) -> &str {
61 self.inner.name()
62 }
63
64 async fn resolve(&self) -> Result<Credential> {
65 {
67 let cache = self.cache.read().await;
68 if let Some(ref entry) = *cache
69 && !self.is_expired(entry)
70 && !self.credential_expired(&entry.credential)
71 {
72 return Ok(entry.credential.clone());
73 }
74 }
75
76 let credential = self.inner.resolve().await?;
78
79 let mut cache = self.cache.write().await;
81 *cache = Some(CacheEntry {
82 credential: credential.clone(),
83 fetched_at: Instant::now(),
84 });
85
86 Ok(credential)
87 }
88
89 async fn refresh(&self) -> Result<Credential> {
90 let credential = self.inner.refresh().await?;
91
92 let mut cache = self.cache.write().await;
93 *cache = Some(CacheEntry {
94 credential: credential.clone(),
95 fetched_at: Instant::now(),
96 });
97
98 Ok(credential)
99 }
100
101 fn supports_refresh(&self) -> bool {
102 self.inner.supports_refresh()
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stub provider that hands out a fixed API key and counts `resolve` calls.
    struct CountingProvider {
        calls: AtomicUsize,
    }

    impl CountingProvider {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of times `resolve` has reached this provider.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        fn name(&self) -> &str {
            "counting"
        }

        async fn resolve(&self) -> Result<Credential> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(Credential::api_key("test-key"))
        }
    }

    /// A fresh cache (default TTL) in front of a fresh counting provider.
    fn counted_cache() -> CachedProvider<CountingProvider> {
        CachedProvider::new(CountingProvider::new())
    }

    /// Resolves once through the cache, then reports how many inner fetches happened so far.
    async fn resolve_and_count(cached: &CachedProvider<CountingProvider>) -> usize {
        let _ = cached.resolve().await.unwrap();
        cached.inner.call_count()
    }

    #[tokio::test]
    async fn test_caching() {
        let cached = counted_cache();

        assert_eq!(1, resolve_and_count(&cached).await);
        // The second lookup is served from the cache, so the count stays put.
        assert_eq!(1, resolve_and_count(&cached).await);
    }

    #[tokio::test]
    async fn test_invalidate() {
        let cached = counted_cache();
        assert_eq!(1, resolve_and_count(&cached).await);

        cached.invalidate().await;

        assert_eq!(2, resolve_and_count(&cached).await);
    }

    #[tokio::test]
    async fn test_ttl_expiry() {
        let cached = counted_cache().with_ttl(Duration::from_millis(10));
        assert_eq!(1, resolve_and_count(&cached).await);

        // Outlive the 10 ms TTL so the entry is considered stale.
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(2, resolve_and_count(&cached).await);
    }
}