// claude_agent/auth/cache.rs

1//! Credential caching layer.
2
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use tokio::sync::RwLock;
8
9use super::{Credential, CredentialProvider};
10use crate::Result;
11
/// Cache lifetime used by `CachedProvider::new` unless overridden with `with_ttl`: five minutes.
const DEFAULT_TTL: Duration = Duration::from_secs(5 * 60);
13
14struct CacheEntry {
15    credential: Credential,
16    fetched_at: Instant,
17}
18
19/// A caching wrapper around any CredentialProvider.
20pub struct CachedProvider<P> {
21    inner: P,
22    cache: Arc<RwLock<Option<CacheEntry>>>,
23    ttl: Duration,
24}
25
26impl<P: CredentialProvider> CachedProvider<P> {
27    pub fn new(provider: P) -> Self {
28        Self {
29            inner: provider,
30            cache: Arc::new(RwLock::new(None)),
31            ttl: DEFAULT_TTL,
32        }
33    }
34
35    pub fn with_ttl(mut self, ttl: Duration) -> Self {
36        self.ttl = ttl;
37        self
38    }
39
40    pub async fn invalidate(&self) {
41        let mut cache = self.cache.write().await;
42        *cache = None;
43    }
44
45    fn is_expired(&self, entry: &CacheEntry) -> bool {
46        entry.fetched_at.elapsed() > self.ttl
47    }
48
49    fn credential_expired(&self, cred: &Credential) -> bool {
50        if let Credential::OAuth(oauth) = cred {
51            oauth.is_expired()
52        } else {
53            false
54        }
55    }
56}
57
58#[async_trait]
59impl<P: CredentialProvider> CredentialProvider for CachedProvider<P> {
60    fn name(&self) -> &str {
61        self.inner.name()
62    }
63
64    async fn resolve(&self) -> Result<Credential> {
65        // Check cache first
66        {
67            let cache = self.cache.read().await;
68            if let Some(ref entry) = *cache
69                && !self.is_expired(entry)
70                && !self.credential_expired(&entry.credential)
71            {
72                return Ok(entry.credential.clone());
73            }
74        }
75
76        // Cache miss or expired - fetch new credential
77        let credential = self.inner.resolve().await?;
78
79        // Update cache
80        let mut cache = self.cache.write().await;
81        *cache = Some(CacheEntry {
82            credential: credential.clone(),
83            fetched_at: Instant::now(),
84        });
85
86        Ok(credential)
87    }
88
89    async fn refresh(&self) -> Result<Credential> {
90        let credential = self.inner.refresh().await?;
91
92        let mut cache = self.cache.write().await;
93        *cache = Some(CacheEntry {
94            credential: credential.clone(),
95            fetched_at: Instant::now(),
96        });
97
98        Ok(credential)
99    }
100
101    fn supports_refresh(&self) -> bool {
102        self.inner.supports_refresh()
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that records how many times its `resolve` is reached.
    struct CountingProvider {
        calls: AtomicUsize,
    }

    impl CountingProvider {
        fn new() -> Self {
            let calls = AtomicUsize::new(0);
            Self { calls }
        }

        /// Number of `resolve` calls that reached this provider so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        fn name(&self) -> &str {
            "counting"
        }

        async fn resolve(&self) -> Result<Credential> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(Credential::api_key("test-key"))
        }
    }

    /// Resolves once through `cached`, then reports the inner provider's
    /// running call count.
    async fn resolve_and_count(cached: &CachedProvider<CountingProvider>) -> usize {
        let _ = cached.resolve().await.unwrap();
        cached.inner.call_count()
    }

    #[tokio::test]
    async fn test_caching() {
        let cached = CachedProvider::new(CountingProvider::new());

        // The first lookup reaches the provider; the second is a cache hit.
        assert_eq!(resolve_and_count(&cached).await, 1);
        assert_eq!(resolve_and_count(&cached).await, 1);
    }

    #[tokio::test]
    async fn test_invalidate() {
        let cached = CachedProvider::new(CountingProvider::new());
        assert_eq!(resolve_and_count(&cached).await, 1);

        // Dropping the entry forces the next lookup back to the provider.
        cached.invalidate().await;
        assert_eq!(resolve_and_count(&cached).await, 2);
    }

    #[tokio::test]
    async fn test_ttl_expiry() {
        let ttl = Duration::from_millis(10);
        let cached = CachedProvider::new(CountingProvider::new()).with_ttl(ttl);
        assert_eq!(resolve_and_count(&cached).await, 1);

        // Sleep for twice the TTL so the cached entry is definitely stale.
        tokio::time::sleep(ttl * 2).await;

        assert_eq!(resolve_and_count(&cached).await, 2);
    }
}