Skip to main content

jwk_simple/jwks/cache/moka.rs

1use std::time::Duration;
2
3use moka::future::Cache;
4
5use crate::error::Result;
6use crate::jwks::KeySet;
7
8use super::KeyCache;
9
/// Fixed key under which the single cached [`KeySet`] is stored in the Moka cache.
const KEYSET_CACHE_KEY: &str = "jwks";

/// Default cache TTL for [`MokaKeyCache`] (5 minutes).
pub const DEFAULT_MOKA_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
14
15/// A Moka-backed in-memory key cache with TTL-based expiration.
16#[derive(Debug)]
17pub struct MokaKeyCache {
18    cache: Cache<&'static str, KeySet>,
19    ttl: Duration,
20}
21
22impl MokaKeyCache {
23    /// Creates a new Moka-backed cache with the specified TTL.
24    pub fn new(ttl: Duration) -> Self {
25        let cache = Cache::builder().max_capacity(1).time_to_live(ttl).build();
26
27        Self { cache, ttl }
28    }
29
30    /// Creates a new Moka-backed cache with the default TTL (5 minutes).
31    pub fn with_default_ttl() -> Self {
32        Self::new(DEFAULT_MOKA_CACHE_TTL)
33    }
34
35    /// Returns the configured TTL.
36    pub fn ttl(&self) -> Duration {
37        self.ttl
38    }
39}
40
41impl Default for MokaKeyCache {
42    fn default() -> Self {
43        Self::with_default_ttl()
44    }
45}
46
47#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
48#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
49impl KeyCache for MokaKeyCache {
50    async fn get(&self) -> Result<Option<KeySet>> {
51        Ok(self.cache.get(&KEYSET_CACHE_KEY).await)
52    }
53
54    async fn set(&self, keyset: KeySet) -> Result<()> {
55        self.cache.insert(KEYSET_CACHE_KEY, keyset).await;
56
57        Ok(())
58    }
59
60    async fn clear(&self) -> Result<()> {
61        self.cache.invalidate_all();
62
63        Ok(())
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Key;
    use crate::error::Error;
    use crate::jwks::{CachedKeyStore, KeyStore};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// JWKS containing a single symmetric key with `kid` = `test-key`.
    const SINGLE_KEY_JWKS: &str = r#"{"keys": [{"kty": "oct", "kid": "test-key", "k": "AQAB"}]}"#;

    /// Parses a JWKS fixture, panicking with a clear message if the fixture is malformed.
    fn parse_keyset(json: &str) -> KeySet {
        serde_json::from_str(json).expect("test JWKS fixture must be valid")
    }

    #[tokio::test]
    async fn moka_cache_basic() {
        let cache = MokaKeyCache::new(Duration::from_secs(300));

        assert!(cache.get().await.unwrap().is_none());

        cache.set(parse_keyset(SINGLE_KEY_JWKS)).await.unwrap();
        let cached = cache
            .get()
            .await
            .unwrap()
            .expect("keyset should be cached after set");
        assert_eq!(cached.len(), 1);

        cache.clear().await.unwrap();
        assert!(cache.get().await.unwrap().is_none());
    }

    #[test]
    fn moka_cache_reports_ttl() {
        assert_eq!(MokaKeyCache::default().ttl(), DEFAULT_MOKA_CACHE_TTL);
        assert_eq!(MokaKeyCache::with_default_ttl().ttl(), DEFAULT_MOKA_CACHE_TTL);
        assert_eq!(
            MokaKeyCache::new(Duration::from_secs(7)).ttl(),
            Duration::from_secs(7)
        );
    }

    #[tokio::test]
    async fn moka_cache_set_after_clear() {
        let cache = MokaKeyCache::with_default_ttl();
        cache.set(parse_keyset(SINGLE_KEY_JWKS)).await.unwrap();
        cache.clear().await.unwrap();

        // A key set stored immediately after `clear` must be visible right away.
        cache.set(parse_keyset(SINGLE_KEY_JWKS)).await.unwrap();
        let cached = cache
            .get()
            .await
            .unwrap()
            .expect("keyset stored after clear should be cached");
        assert_eq!(cached.len(), 1);
    }

    #[tokio::test]
    async fn moka_cache_expiration() {
        // The TTL leaves headroom so the first read doesn't race expiry on a slow machine.
        // The sleep comfortably exceeds the TTL.
        let ttl = Duration::from_millis(50);
        let cache = MokaKeyCache::new(ttl);

        cache.set(parse_keyset(SINGLE_KEY_JWKS)).await.unwrap();
        assert!(cache.get().await.unwrap().is_some());

        tokio::time::sleep(ttl * 2).await;
        cache.cache.run_pending_tasks().await;

        assert!(cache.get().await.unwrap().is_none());
    }

    /// Key store that returns `keysets[n]` on the n-th fetch, then repeats the last one.
    struct RotatingKeyStore {
        keysets: Vec<KeySet>,
        call_count: AtomicUsize,
    }

    impl RotatingKeyStore {
        fn new(keysets: Vec<KeySet>) -> Self {
            Self {
                keysets,
                call_count: AtomicUsize::new(0),
            }
        }

        fn fetch_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl KeyStore for RotatingKeyStore {
        async fn get_keyset(&self) -> crate::error::Result<KeySet> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let keyset = self
                .keysets
                .get(idx)
                .or_else(|| self.keysets.last())
                .expect("RotatingKeyStore requires at least one keyset");
            Ok(keyset.clone())
        }
    }

    /// Key store whose every fetch fails.
    struct FailingKeyStore;

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl KeyStore for FailingKeyStore {
        async fn get_keyset(&self) -> crate::error::Result<KeySet> {
            Err(Error::Other("mock source failure".to_string()))
        }
    }

    #[tokio::test]
    async fn cached_key_store_refetches_on_unknown_kid() {
        let initial =
            parse_keyset(r#"{"keys": [{"kty": "oct", "kid": "old-key", "k": "AQAB"}]}"#);
        let rotated = parse_keyset(
            r#"{"keys": [
                {"kty": "oct", "kid": "old-key", "k": "AQAB"},
                {"kty": "oct", "kid": "new-key", "k": "AQAB"}
            ]}"#,
        );

        let source = RotatingKeyStore::new(vec![initial, rotated]);
        let cached = CachedKeyStore::new(MokaKeyCache::new(Duration::from_secs(300)), source);

        let key = cached.get_key("old-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 1);

        let key = cached.get_key("old-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 1);

        let key = cached.get_key("new-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 2);

        let key = cached.get_key("new-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(cached.store().fetch_count(), 2);
    }

    #[tokio::test]
    async fn cached_key_store_source_error_propagates() {
        let cached =
            CachedKeyStore::new(MokaKeyCache::new(Duration::from_secs(300)), FailingKeyStore);

        let err = cached.get_keyset().await.unwrap_err();
        assert!(matches!(err, Error::Other(_)));
    }

    #[tokio::test]
    async fn cached_key_store_get_and_invalidate() {
        let static_source = parse_keyset(SINGLE_KEY_JWKS);
        let cached = CachedKeyStore::new(MokaKeyCache::with_default_ttl(), static_source);

        let key = cached.get_key("test-key").await.unwrap();
        assert!(key.is_some());

        assert!(cached.cache().get().await.unwrap().is_some());

        cached.cache().clear().await.unwrap();
        assert!(cached.cache().get().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn cached_key_store_get_keyset() {
        let static_source = parse_keyset(
            r#"{"keys": [
            {"kty": "oct", "kid": "key1", "k": "AQAB"},
            {"kty": "oct", "kid": "key2", "k": "AQAB"}
        ]}"#,
        );
        let cached = CachedKeyStore::new(MokaKeyCache::with_default_ttl(), static_source);

        let keyset = cached.get_keyset().await.unwrap();
        assert_eq!(keyset.len(), 2);

        let cached_keyset = cached
            .cache()
            .get()
            .await
            .unwrap()
            .expect("keyset should be cached after get_keyset");
        assert_eq!(cached_keyset.len(), 2);
    }

    #[tokio::test]
    async fn cached_key_store_get_key_miss() {
        let static_source = parse_keyset(SINGLE_KEY_JWKS);
        let cached = CachedKeyStore::new(MokaKeyCache::with_default_ttl(), static_source);

        let key: Option<Key> = cached.get_key("nonexistent").await.unwrap();
        assert!(key.is_none());
    }
}