jwk_simple/jwks/
inmemory_cache.rs

//! In-memory key cache implementation built on Tokio's `RwLock`.
//!
//! This module provides [`InMemoryKeyCache`], a thread-safe in-memory cache
//! with TTL-based expiration, and [`InMemoryCachedKeySet`], a convenience type
//! alias for caching any key source with in-memory storage.
6
7use std::collections::HashMap;
8use std::time::Duration;
9
10use tokio::sync::RwLock;
11
12use crate::jwk::Key;
13
14use super::cache::{CachedKeySet, KeyCache};
15
/// TTL used by [`InMemoryKeyCache::default`] and the `with_default_ttl`
/// constructors: five minutes.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
18
19/// Internal cache entry with timestamp.
20struct CacheEntry {
21    key: Key,
22    inserted_at: std::time::Instant,
23}
24
25/// An in-memory key cache with TTL-based expiration.
26///
27/// Keys are stored in memory and automatically expire after the configured TTL.
28/// This implementation is thread-safe and can be shared across tasks.
29///
30/// # Examples
31///
32/// ```ignore
33/// use jwk_simple::InMemoryKeyCache;
34/// use std::time::Duration;
35///
36/// // Create a cache with 5-minute TTL
37/// let cache = InMemoryKeyCache::new(Duration::from_secs(300));
38///
39/// // Or use the default TTL
40/// let cache = InMemoryKeyCache::default();
41/// ```
42pub struct InMemoryKeyCache {
43    entries: RwLock<HashMap<String, CacheEntry>>,
44    ttl: Duration,
45}
46
47impl InMemoryKeyCache {
48    /// Creates a new in-memory cache with the specified TTL.
49    pub fn new(ttl: Duration) -> Self {
50        Self {
51            entries: RwLock::new(HashMap::new()),
52            ttl,
53        }
54    }
55
56    /// Creates a new in-memory cache with the default TTL (5 minutes).
57    pub fn with_default_ttl() -> Self {
58        Self::new(DEFAULT_CACHE_TTL)
59    }
60
61    /// Returns the configured TTL.
62    pub fn ttl(&self) -> Duration {
63        self.ttl
64    }
65
66    /// Returns the number of entries currently in the cache.
67    ///
68    /// Note: This count may include expired entries that haven't been cleaned up yet.
69    pub async fn len(&self) -> usize {
70        self.entries.read().await.len()
71    }
72
73    /// Returns `true` if the cache is empty.
74    pub async fn is_empty(&self) -> bool {
75        self.entries.read().await.is_empty()
76    }
77}
78
79impl Default for InMemoryKeyCache {
80    fn default() -> Self {
81        Self::with_default_ttl()
82    }
83}
84
85impl std::fmt::Debug for InMemoryKeyCache {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("InMemoryKeyCache")
88            .field("ttl", &self.ttl)
89            .finish_non_exhaustive()
90    }
91}
92
93#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
94#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
95impl KeyCache for InMemoryKeyCache {
96    async fn get(&self, kid: &str) -> Option<Key> {
97        let entries = self.entries.read().await;
98        entries.get(kid).and_then(|entry| {
99            if entry.inserted_at.elapsed() < self.ttl {
100                Some(entry.key.clone())
101            } else {
102                None
103            }
104        })
105    }
106
107    async fn set(&self, kid: &str, key: Key) {
108        let mut entries = self.entries.write().await;
109        entries.insert(
110            kid.to_string(),
111            CacheEntry {
112                key,
113                inserted_at: std::time::Instant::now(),
114            },
115        );
116    }
117
118    async fn remove(&self, kid: &str) {
119        let mut entries = self.entries.write().await;
120        entries.remove(kid);
121    }
122
123    async fn clear(&self) {
124        let mut entries = self.entries.write().await;
125        entries.clear();
126    }
127}
128
129/// Convenience type alias for a cached key set using in-memory caching.
130pub type InMemoryCachedKeySet<S> = CachedKeySet<InMemoryKeyCache, S>;
131
132impl<S> InMemoryCachedKeySet<S> {
133    /// Creates a new cached key source with in-memory caching and the specified TTL.
134    pub fn with_ttl(source: S, ttl: Duration) -> Self {
135        Self::new(InMemoryKeyCache::new(ttl), source)
136    }
137
138    /// Creates a new cached key source with in-memory caching and the default TTL.
139    pub fn with_default_ttl(source: S) -> Self {
140        Self::new(InMemoryKeyCache::with_default_ttl(), source)
141    }
142
143    /// Invalidates the cache, forcing fresh fetches on subsequent requests.
144    pub async fn invalidate(&self) {
145        self.cache().clear().await;
146    }
147
148    /// Removes a specific key from the cache.
149    pub async fn invalidate_key(&self, kid: &str) {
150        self.cache().remove(kid).await;
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwks::{KeySet, KeySource};

    /// JWKS containing a single symmetric key with `kid = "test-key"`.
    const ONE_KEY: &str = r#"{"keys": [{"kty": "oct", "kid": "test-key", "k": "AQAB"}]}"#;

    /// JWKS containing two symmetric keys, `key1` and `key2`.
    const TWO_KEYS: &str = r#"{"keys": [
        {"kty": "oct", "kid": "key1", "k": "AQAB"},
        {"kty": "oct", "kid": "key2", "k": "AQAB"}
    ]}"#;

    /// Builds a symmetric test key with `kid = "test-key"`.
    fn test_key() -> Key {
        serde_json::from_str(r#"{"kty": "oct", "kid": "test-key", "k": "AQAB"}"#).unwrap()
    }

    /// Parses a JWKS document into a static key source.
    fn key_set(json: &str) -> KeySet {
        serde_json::from_str(json).unwrap()
    }

    #[tokio::test]
    async fn test_in_memory_cache_basic() {
        let cache = InMemoryKeyCache::new(Duration::from_secs(300));

        // Initially empty
        assert!(cache.get("test-key").await.is_none());

        // Set and get
        cache.set("test-key", test_key()).await;
        let cached = cache.get("test-key").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().kid, Some("test-key".to_string()));

        // Remove
        cache.remove("test-key").await;
        assert!(cache.get("test-key").await.is_none());
    }

    #[tokio::test]
    async fn test_default_ttl() {
        assert_eq!(InMemoryKeyCache::default().ttl(), DEFAULT_CACHE_TTL);
        assert_eq!(InMemoryKeyCache::with_default_ttl().ttl(), DEFAULT_CACHE_TTL);
        assert_eq!(
            InMemoryKeyCache::new(Duration::from_secs(7)).ttl(),
            Duration::from_secs(7)
        );
    }

    #[tokio::test]
    async fn test_in_memory_cache_expiration() {
        // Zero TTL: entries are already expired when read back. This check is
        // deterministic and does not depend on scheduling.
        let cache = InMemoryKeyCache::new(Duration::ZERO);
        cache.set("test-key", test_key()).await;
        assert!(cache.get("test-key").await.is_none());

        // Short TTL: `sleep` waits at least the requested time, so the entry is
        // guaranteed to be past its TTL afterwards. We deliberately do not
        // assert presence before sleeping; a scheduler stall could make that
        // flaky, and presence is covered by `test_in_memory_cache_basic`.
        let cache = InMemoryKeyCache::new(Duration::from_millis(50));
        cache.set("test-key", test_key()).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(cache.get("test-key").await.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_cache_overwrite() {
        let cache = InMemoryKeyCache::new(Duration::from_secs(300));

        cache.set("test-key", test_key()).await;
        cache.set("test-key", test_key()).await;

        assert_eq!(cache.len().await, 1);
        assert!(cache.get("test-key").await.is_some());
    }

    #[tokio::test]
    async fn test_in_memory_cache_clear() {
        let cache = InMemoryKeyCache::new(Duration::from_secs(300));

        cache.set("key1", test_key()).await;
        cache.set("key2", test_key()).await;

        assert_eq!(cache.len().await, 2);

        cache.clear().await;

        assert!(cache.is_empty().await);
    }

    #[tokio::test]
    async fn test_debug_omits_key_material() {
        let cache = InMemoryKeyCache::new(Duration::from_secs(300));
        cache.set("test-key", test_key()).await;

        let rendered = format!("{cache:?}");
        assert!(rendered.contains("ttl"));
        assert!(!rendered.contains("AQAB"));
    }

    #[tokio::test]
    async fn test_cached_key_set() {
        let cached = InMemoryCachedKeySet::with_default_ttl(key_set(ONE_KEY));

        // First call should fetch from source and cache
        let key = cached.get_key("test-key").await.unwrap();
        assert!(key.is_some());

        // Verify it's cached
        let cached_key = cached.cache().get("test-key").await;
        assert!(cached_key.is_some());

        // Second call should use cache
        let key2 = cached.get_key("test-key").await.unwrap();
        assert!(key2.is_some());
    }

    #[tokio::test]
    async fn test_cached_key_set_miss() {
        let cached = InMemoryCachedKeySet::with_default_ttl(key_set(ONE_KEY));

        // Request a non-existent key
        let key = cached.get_key("nonexistent").await.unwrap();
        assert!(key.is_none());
    }

    #[tokio::test]
    async fn test_cached_key_set_get_keyset() {
        let cached = InMemoryCachedKeySet::with_default_ttl(key_set(TWO_KEYS));

        // Fetch keyset should cache all keys
        let keyset = cached.get_keyset().await.unwrap();
        assert_eq!(keyset.len(), 2);

        // Both keys should be cached
        assert!(cached.cache().get("key1").await.is_some());
        assert!(cached.cache().get("key2").await.is_some());
    }

    #[tokio::test]
    async fn test_cached_key_set_invalidate() {
        let cached = InMemoryCachedKeySet::with_default_ttl(key_set(ONE_KEY));

        // Populate cache
        let _ = cached.get_key("test-key").await.unwrap();
        assert!(cached.cache().get("test-key").await.is_some());

        // Invalidate
        cached.invalidate().await;
        assert!(cached.cache().get("test-key").await.is_none());
    }

    #[tokio::test]
    async fn test_cached_key_set_invalidate_key() {
        let cached = InMemoryCachedKeySet::with_default_ttl(key_set(TWO_KEYS));

        // Populate cache
        let _ = cached.get_keyset().await.unwrap();

        // Invalidate only one key
        cached.invalidate_key("key1").await;

        assert!(cached.cache().get("key1").await.is_none());
        assert!(cached.cache().get("key2").await.is_some());
    }
}