jwk_simple/jwks/
inmemory_cache.rs1use std::collections::HashMap;
8use std::time::Duration;
9
10use tokio::sync::RwLock;
11
12use crate::jwk::Key;
13
14use super::cache::{CachedKeySet, KeyCache};
15
/// Time-to-live used by [`InMemoryKeyCache::with_default_ttl`]: five minutes.
pub const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
18
19struct CacheEntry {
21 key: Key,
22 inserted_at: std::time::Instant,
23}
24
25pub struct InMemoryKeyCache {
43 entries: RwLock<HashMap<String, CacheEntry>>,
44 ttl: Duration,
45}
46
47impl InMemoryKeyCache {
48 pub fn new(ttl: Duration) -> Self {
50 Self {
51 entries: RwLock::new(HashMap::new()),
52 ttl,
53 }
54 }
55
56 pub fn with_default_ttl() -> Self {
58 Self::new(DEFAULT_CACHE_TTL)
59 }
60
61 pub fn ttl(&self) -> Duration {
63 self.ttl
64 }
65
66 pub async fn len(&self) -> usize {
70 self.entries.read().await.len()
71 }
72
73 pub async fn is_empty(&self) -> bool {
75 self.entries.read().await.is_empty()
76 }
77}
78
79impl Default for InMemoryKeyCache {
80 fn default() -> Self {
81 Self::with_default_ttl()
82 }
83}
84
85impl std::fmt::Debug for InMemoryKeyCache {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("InMemoryKeyCache")
88 .field("ttl", &self.ttl)
89 .finish_non_exhaustive()
90 }
91}
92
93#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
94#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
95impl KeyCache for InMemoryKeyCache {
96 async fn get(&self, kid: &str) -> Option<Key> {
97 let entries = self.entries.read().await;
98 entries.get(kid).and_then(|entry| {
99 if entry.inserted_at.elapsed() < self.ttl {
100 Some(entry.key.clone())
101 } else {
102 None
103 }
104 })
105 }
106
107 async fn set(&self, kid: &str, key: Key) {
108 let mut entries = self.entries.write().await;
109 entries.insert(
110 kid.to_string(),
111 CacheEntry {
112 key,
113 inserted_at: std::time::Instant::now(),
114 },
115 );
116 }
117
118 async fn remove(&self, kid: &str) {
119 let mut entries = self.entries.write().await;
120 entries.remove(kid);
121 }
122
123 async fn clear(&self) {
124 let mut entries = self.entries.write().await;
125 entries.clear();
126 }
127}
128
129pub type InMemoryCachedKeySet<S> = CachedKeySet<InMemoryKeyCache, S>;
131
132impl<S> InMemoryCachedKeySet<S> {
133 pub fn with_ttl(source: S, ttl: Duration) -> Self {
135 Self::new(InMemoryKeyCache::new(ttl), source)
136 }
137
138 pub fn with_default_ttl(source: S) -> Self {
140 Self::new(InMemoryKeyCache::with_default_ttl(), source)
141 }
142
143 pub async fn invalidate(&self) {
145 self.cache().clear().await;
146 }
147
148 pub async fn invalidate_key(&self, kid: &str) {
150 self.cache().remove(kid).await;
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwks::{KeySet, KeySource};

    /// A single symmetric JWK whose `kid` is "test-key".
    const TEST_KEY_JSON: &str = r#"{"kty": "oct", "kid": "test-key", "k": "AQAB"}"#;

    /// A JWKS that contains only the "test-key" JWK.
    const ONE_KEY_SET_JSON: &str =
        r#"{"keys": [{"kty": "oct", "kid": "test-key", "k": "AQAB"}]}"#;

    /// A JWKS that contains two symmetric JWKs, "key1" and "key2".
    const TWO_KEY_SET_JSON: &str = r#"{"keys": [
        {"kty": "oct", "kid": "key1", "k": "AQAB"},
        {"kty": "oct", "kid": "key2", "k": "AQAB"}
    ]}"#;

    /// Parses [`TEST_KEY_JSON`] into a [`Key`].
    fn test_key() -> Key {
        serde_json::from_str(TEST_KEY_JSON).unwrap()
    }

    /// Parses `json` as a [`KeySet`] and wraps it with the default-TTL cache.
    fn cached_set(json: &str) -> InMemoryCachedKeySet<KeySet> {
        let source: KeySet = serde_json::from_str(json).unwrap();
        InMemoryCachedKeySet::with_default_ttl(source)
    }

    #[tokio::test]
    async fn test_in_memory_cache_basic() {
        let cache = InMemoryKeyCache::new(Duration::from_secs(300));
        assert!(cache.get("test-key").await.is_none());

        cache.set("test-key", test_key()).await;
        let hit = cache
            .get("test-key")
            .await
            .expect("freshly inserted key should be cached");
        assert_eq!(hit.kid, Some("test-key".to_string()));

        cache.remove("test-key").await;
        assert!(cache.get("test-key").await.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_cache_expiration() {
        let cache = InMemoryKeyCache::new(Duration::from_millis(50));

        cache.set("test-key", test_key()).await;
        assert!(cache.get("test-key").await.is_some());

        // Wait twice the TTL so the entry is clearly stale.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(cache.get("test-key").await.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_cache_clear() {
        let key = test_key();
        let cache = InMemoryKeyCache::new(Duration::from_secs(300));

        cache.set("key1", key.clone()).await;
        cache.set("key2", key).await;
        assert_eq!(cache.len().await, 2);

        cache.clear().await;
        assert!(cache.is_empty().await);
    }

    #[tokio::test]
    async fn test_cached_key_set() {
        let cached = cached_set(ONE_KEY_SET_JSON);

        // The first lookup should succeed and leave the key in the cache.
        assert!(cached.get_key("test-key").await.unwrap().is_some());
        assert!(cached.cache().get("test-key").await.is_some());

        // A repeated lookup must succeed as well.
        assert!(cached.get_key("test-key").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_cached_key_set_miss() {
        let cached = cached_set(ONE_KEY_SET_JSON);
        assert!(cached.get_key("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cached_key_set_get_keyset() {
        let cached = cached_set(TWO_KEY_SET_JSON);

        let keyset = cached.get_keyset().await.unwrap();
        assert_eq!(keyset.len(), 2);

        assert!(cached.cache().get("key1").await.is_some());
        assert!(cached.cache().get("key2").await.is_some());
    }

    #[tokio::test]
    async fn test_cached_key_set_invalidate() {
        let cached = cached_set(ONE_KEY_SET_JSON);

        let _ = cached.get_key("test-key").await.unwrap();
        assert!(cached.cache().get("test-key").await.is_some());

        cached.invalidate().await;
        assert!(cached.cache().get("test-key").await.is_none());
    }

    #[tokio::test]
    async fn test_cached_key_set_invalidate_key() {
        let cached = cached_set(TWO_KEY_SET_JSON);
        let _ = cached.get_keyset().await.unwrap();

        cached.invalidate_key("key1").await;

        assert!(cached.cache().get("key1").await.is_none());
        assert!(cached.cache().get("key2").await.is_some());
    }
}