Skip to main content

claude_agent/auth/
cache.rs

1//! Credential caching layer.
2
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use tokio::sync::Mutex;
8
9use super::{Credential, CredentialProvider};
10use crate::Result;
11
/// Default time a fetched credential stays cached (five minutes) unless overridden via
/// `CachedProvider::ttl`.
const DEFAULT_TTL: Duration = Duration::from_secs(5 * 60);
13
14struct CacheEntry {
15    credential: Credential,
16    fetched_at: Instant,
17}
18
19/// A caching wrapper around any CredentialProvider.
20///
21/// Uses a `Mutex` to prevent thundering herd: when the cache is empty or expired,
22/// only one caller fetches a new credential while others wait.
23pub struct CachedProvider<P> {
24    inner: P,
25    cache: Arc<Mutex<Option<CacheEntry>>>,
26    ttl: Duration,
27}
28
29impl<P: CredentialProvider> CachedProvider<P> {
30    pub fn new(provider: P) -> Self {
31        Self {
32            inner: provider,
33            cache: Arc::new(Mutex::new(None)),
34            ttl: DEFAULT_TTL,
35        }
36    }
37
38    pub fn ttl(mut self, ttl: Duration) -> Self {
39        self.ttl = ttl;
40        self
41    }
42
43    pub async fn invalidate(&self) {
44        let mut cache = self.cache.lock().await;
45        *cache = None;
46    }
47
48    fn is_expired(&self, entry: &CacheEntry) -> bool {
49        entry.fetched_at.elapsed() > self.ttl
50    }
51
52    fn credential_expired(&self, cred: &Credential) -> bool {
53        if let Credential::OAuth(oauth) = cred {
54            oauth.is_expired()
55        } else {
56            false
57        }
58    }
59}
60
61#[async_trait]
62impl<P: CredentialProvider> CredentialProvider for CachedProvider<P> {
63    fn name(&self) -> &str {
64        self.inner.name()
65    }
66
67    async fn resolve(&self) -> Result<Credential> {
68        // Hold mutex through entire check-fetch-store to prevent thundering herd
69        let mut cache = self.cache.lock().await;
70
71        if let Some(ref entry) = *cache
72            && !self.is_expired(entry)
73            && !self.credential_expired(&entry.credential)
74        {
75            return Ok(entry.credential.clone());
76        }
77
78        let credential = self.inner.resolve().await?;
79
80        *cache = Some(CacheEntry {
81            credential: credential.clone(),
82            fetched_at: Instant::now(),
83        });
84
85        Ok(credential)
86    }
87
88    async fn refresh(&self) -> Result<Credential> {
89        let credential = self.inner.refresh().await?;
90
91        let mut cache = self.cache.lock().await;
92        *cache = Some(CacheEntry {
93            credential: credential.clone(),
94            fetched_at: Instant::now(),
95        });
96
97        Ok(credential)
98    }
99
100    fn supports_refresh(&self) -> bool {
101        self.inner.supports_refresh()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that counts how often `resolve` and `refresh` reach it.
    struct CountingProvider {
        calls: AtomicUsize,
        refreshes: AtomicUsize,
    }

    impl CountingProvider {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
                refreshes: AtomicUsize::new(0),
            }
        }

        /// Number of `resolve` calls that reached this provider.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        /// Number of `refresh` calls that reached this provider.
        fn refresh_count(&self) -> usize {
            self.refreshes.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        fn name(&self) -> &str {
            "counting"
        }

        async fn resolve(&self) -> Result<Credential> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(Credential::api_key("test-key"))
        }

        async fn refresh(&self) -> Result<Credential> {
            self.refreshes.fetch_add(1, Ordering::SeqCst);
            Ok(Credential::api_key("refreshed-key"))
        }
    }

    #[tokio::test]
    async fn test_caching() {
        let inner = CountingProvider::new();
        let cached = CachedProvider::new(inner);

        // First call should hit the provider
        let _ = cached.resolve().await.unwrap();
        assert_eq!(1, cached.inner.call_count());

        // Second call should use cache
        let _ = cached.resolve().await.unwrap();
        assert_eq!(1, cached.inner.call_count());
    }

    #[tokio::test]
    async fn test_invalidate() {
        let inner = CountingProvider::new();
        let cached = CachedProvider::new(inner);

        let _ = cached.resolve().await.unwrap();
        assert_eq!(1, cached.inner.call_count());

        cached.invalidate().await;

        let _ = cached.resolve().await.unwrap();
        assert_eq!(2, cached.inner.call_count());
    }

    #[tokio::test]
    async fn test_ttl_expiry() {
        let inner = CountingProvider::new();
        let cached = CachedProvider::new(inner).ttl(Duration::from_millis(10));

        let _ = cached.resolve().await.unwrap();
        assert_eq!(1, cached.inner.call_count());

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(20)).await;

        let _ = cached.resolve().await.unwrap();
        assert_eq!(2, cached.inner.call_count());
    }

    #[test]
    fn test_name_delegates_to_inner() {
        let cached = CachedProvider::new(CountingProvider::new());
        assert_eq!("counting", cached.name());
    }

    #[tokio::test]
    async fn test_refresh_populates_cache() {
        let cached = CachedProvider::new(CountingProvider::new());

        // Refresh goes straight to the inner provider's refresh, not resolve.
        let _ = cached.refresh().await.unwrap();
        assert_eq!(1, cached.inner.refresh_count());
        assert_eq!(0, cached.inner.call_count());

        // The refreshed credential is now served from the cache.
        let _ = cached.resolve().await.unwrap();
        assert_eq!(0, cached.inner.call_count());
        assert_eq!(1, cached.inner.refresh_count());
    }
}