jwk_simple/jwks/cache/
moka.rs1use std::time::Duration;
2
3use moka::future::Cache;
4
5use crate::error::Result;
6use crate::jwks::KeySet;
7
8use super::KeyCache;
9
/// Fixed key under which the single JWKS entry is stored in the moka cache.
const KEYSET_CACHE_KEY: &str = "jwks";

/// Default time-to-live used by [`MokaKeyCache::with_default_ttl`] and `Default`: five minutes.
pub const DEFAULT_MOKA_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
14
15#[derive(Debug)]
17pub struct MokaKeyCache {
18 cache: Cache<&'static str, KeySet>,
19 ttl: Duration,
20}
21
22impl MokaKeyCache {
23 pub fn new(ttl: Duration) -> Self {
25 let cache = Cache::builder().max_capacity(1).time_to_live(ttl).build();
26
27 Self { cache, ttl }
28 }
29
30 pub fn with_default_ttl() -> Self {
32 Self::new(DEFAULT_MOKA_CACHE_TTL)
33 }
34
35 pub fn ttl(&self) -> Duration {
37 self.ttl
38 }
39}
40
41impl Default for MokaKeyCache {
42 fn default() -> Self {
43 Self::with_default_ttl()
44 }
45}
46
47#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
48#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
49impl KeyCache for MokaKeyCache {
50 async fn get(&self) -> Result<Option<KeySet>> {
51 Ok(self.cache.get(&KEYSET_CACHE_KEY).await)
52 }
53
54 async fn set(&self, keyset: KeySet) -> Result<()> {
55 self.cache.insert(KEYSET_CACHE_KEY, keyset).await;
56
57 Ok(())
58 }
59
60 async fn clear(&self) -> Result<()> {
61 self.cache.invalidate_all();
62
63 Ok(())
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Key;
    use crate::error::Error;
    use crate::jwks::{CachedKeyStore, KeyStore};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// JWKS document containing one symmetric key with kid `test-key`.
    const SINGLE_KEY_JSON: &str = r#"{"keys": [{"kty": "oct", "kid": "test-key", "k": "AQAB"}]}"#;

    /// Parses a JWKS test fixture; a malformed fixture is a bug in the test itself.
    fn keyset_from_json(json: &str) -> KeySet {
        serde_json::from_str(json).expect("test fixture must be a valid JWKS document")
    }

    #[tokio::test]
    async fn moka_cache_basic() {
        let keyset = keyset_from_json(SINGLE_KEY_JSON);

        let cache = MokaKeyCache::new(Duration::from_secs(300));

        assert!(cache.get().await.unwrap().is_none());

        cache.set(keyset.clone()).await.unwrap();
        let cached = cache.get().await.unwrap();
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);

        cache.clear().await.unwrap();
        assert!(cache.get().await.unwrap().is_none());
    }

    /// Regression: a store issued right after `clear` must be visible again.
    #[tokio::test]
    async fn moka_cache_set_after_clear_repopulates() {
        let cache = MokaKeyCache::new(Duration::from_secs(300));

        cache.set(keyset_from_json(SINGLE_KEY_JSON)).await.unwrap();
        cache.clear().await.unwrap();
        assert!(cache.get().await.unwrap().is_none());

        cache.set(keyset_from_json(SINGLE_KEY_JSON)).await.unwrap();
        let cached = cache.get().await.unwrap();
        assert_eq!(cached.map(|ks| ks.len()), Some(1));
    }

    #[tokio::test]
    async fn moka_cache_reports_configured_ttl() {
        assert_eq!(MokaKeyCache::default().ttl(), DEFAULT_MOKA_CACHE_TTL);
        assert_eq!(
            MokaKeyCache::with_default_ttl().ttl(),
            DEFAULT_MOKA_CACHE_TTL
        );
        assert_eq!(
            MokaKeyCache::new(Duration::from_millis(5)).ttl(),
            Duration::from_millis(5)
        );
    }

    #[tokio::test]
    async fn moka_cache_expiration() {
        let keyset = keyset_from_json(SINGLE_KEY_JSON);

        let cache = MokaKeyCache::new(Duration::from_millis(20));

        cache.set(keyset).await.unwrap();
        assert!(cache.get().await.unwrap().is_some());

        tokio::time::sleep(Duration::from_millis(40)).await;
        cache.cache.run_pending_tasks().await;

        assert!(cache.get().await.unwrap().is_none());
    }

    /// Key store that returns the configured key sets in order, one per fetch,
    /// repeating the last one once exhausted.
    struct RotatingKeyStore {
        keysets: Vec<KeySet>,
        call_count: AtomicUsize,
    }

    impl RotatingKeyStore {
        fn new(keysets: Vec<KeySet>) -> Self {
            Self {
                keysets,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `get_keyset` has been called.
        fn fetch_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl KeyStore for RotatingKeyStore {
        async fn get_keyset(&self) -> crate::error::Result<KeySet> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let keyset = self
                .keysets
                .get(idx)
                .unwrap_or_else(|| self.keysets.last().unwrap());
            Ok(keyset.clone())
        }
    }

    /// Key store whose every fetch fails.
    struct FailingKeyStore;

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl KeyStore for FailingKeyStore {
        async fn get_keyset(&self) -> crate::error::Result<KeySet> {
            Err(Error::Other("mock source failure".to_string()))
        }
    }

    #[tokio::test]
    async fn cached_key_store_refetches_on_unknown_kid() {
        let initial =
            keyset_from_json(r#"{"keys": [{"kty": "oct", "kid": "old-key", "k": "AQAB"}]}"#);
        let rotated = keyset_from_json(
            r#"{"keys": [
                {"kty": "oct", "kid": "old-key", "k": "AQAB"},
                {"kty": "oct", "kid": "new-key", "k": "AQAB"}
            ]}"#,
        );

        let source = RotatingKeyStore::new(vec![initial, rotated]);
        let cached = CachedKeyStore::new(MokaKeyCache::new(Duration::from_secs(300)), source);

        let key = cached.get_key("old-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 1);

        let key = cached.get_key("old-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 1);

        let key = cached.get_key("new-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 2);

        let key = cached.get_key("new-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 2);
    }

    #[tokio::test]
    async fn cached_key_store_source_error_propagates() {
        let cached =
            CachedKeyStore::new(MokaKeyCache::new(Duration::from_secs(300)), FailingKeyStore);

        let err = cached.get_keyset().await.unwrap_err();
        assert!(matches!(err, Error::Other(_)));
    }

    #[tokio::test]
    async fn cached_key_store_get_and_invalidate() {
        let static_source = keyset_from_json(SINGLE_KEY_JSON);

        let cached = CachedKeyStore::new(MokaKeyCache::with_default_ttl(), static_source);

        let key = cached.get_key("test-key").await.unwrap();
        assert!(key.is_some());

        let cached_keyset = cached.cache().get().await.unwrap();
        assert!(cached_keyset.is_some());

        cached.cache().clear().await.unwrap();
        let cleared = cached.cache().get().await.unwrap();
        assert!(cleared.is_none());
    }

    #[tokio::test]
    async fn cached_key_store_get_keyset() {
        let static_source = keyset_from_json(
            r#"{"keys": [
                {"kty": "oct", "kid": "key1", "k": "AQAB"},
                {"kty": "oct", "kid": "key2", "k": "AQAB"}
            ]}"#,
        );

        let cached = CachedKeyStore::new(MokaKeyCache::with_default_ttl(), static_source);

        let keyset = cached.get_keyset().await.unwrap();
        assert_eq!(keyset.len(), 2);

        let cached_keyset = cached.cache().get().await.unwrap();
        assert!(cached_keyset.is_some());
        assert_eq!(cached_keyset.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn cached_key_store_get_key_miss() {
        let static_source = keyset_from_json(SINGLE_KEY_JSON);

        let cached = CachedKeyStore::new(MokaKeyCache::with_default_ttl(), static_source);

        let key: Option<Key> = cached.get_key("nonexistent").await.unwrap();
        assert!(key.is_none());
    }
}