jwk_simple/jwks/cache.rs
1//! Caching traits and wrappers for JWKS key sources.
2//!
3//! This module provides a [`KeyCache`] trait for caching keys by their ID,
4//! and a [`CachedKeySet`] wrapper that combines any cache with any key source.
5//!
6//! For a ready-to-use in-memory implementation, enable the `inmemory-cache` feature
7//! and use [`InMemoryKeyCache`](super::InMemoryKeyCache).
8
9use crate::error::Result;
10use crate::jwk::Key;
11
12use super::{KeySet, KeySource};
13
14/// A trait for caching keys by their ID.
15///
16/// Implementations can provide different caching strategies (in-memory, Redis, etc.)
17/// while the [`CachedKeySet`] handles the cache-aside pattern.
18///
19/// # Examples
20///
21/// ```ignore
22/// use jwk_simple::{KeyCache, InMemoryKeyCache};
23/// use std::time::Duration;
24///
25/// let cache = InMemoryKeyCache::new(Duration::from_secs(300));
26///
27/// // Store a key
28/// cache.set("my-kid", key).await;
29///
30/// // Retrieve a key
31/// if let Some(key) = cache.get("my-kid").await {
32/// // Use the cached key
33/// }
34/// ```
35#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
36#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
37pub trait KeyCache {
38 /// Gets a key from the cache by its ID.
39 ///
40 /// Returns `None` if the key is not in the cache or has expired.
41 async fn get(&self, kid: &str) -> Option<Key>;
42
43 /// Stores a key in the cache.
44 ///
45 /// The key ID is extracted from the key's `kid` field if not provided explicitly.
46 async fn set(&self, kid: &str, key: Key);
47
48 /// Removes a key from the cache.
49 async fn remove(&self, kid: &str);
50
51 /// Clears all keys from the cache.
52 async fn clear(&self);
53}
54
/// A caching wrapper for any [`KeySource`] implementation.
///
/// This wrapper uses the cache-aside pattern for single-key lookups:
/// [`KeySource::get_key`] first checks the cache for a key, and only fetches from
/// the underlying source on a cache miss. Keys retrieved from the source are then
/// stored in the cache for future requests.
///
/// [`KeySource::get_keyset`] always reads from the underlying source. It stores
/// every returned key that has a `kid` in the cache.
///
/// # Type Parameters
///
/// * `C` - The cache implementation; must implement [`KeyCache`] for the wrapper to
///   implement [`KeySource`].
/// * `S` - The underlying key source; must implement [`KeySource`] for the wrapper
///   to implement [`KeySource`].
///
/// # Examples
///
/// ```ignore
/// use jwk_simple::{CachedKeySet, InMemoryKeyCache, RemoteKeySet, KeySource};
/// use std::time::Duration;
///
/// // Create a cached remote JWKS
/// let cache = InMemoryKeyCache::new(Duration::from_secs(300));
/// let remote = RemoteKeySet::new("https://example.com/.well-known/jwks.json");
/// let cached = CachedKeySet::new(cache, remote);
///
/// // First call fetches from remote, caches the key
/// let key = cached.get_key("kid").await?;
///
/// // Subsequent calls use the cache
/// let key = cached.get_key("kid").await?;
/// ```
pub struct CachedKeySet<C, S> {
    /// Cache consulted before the source and filled from it.
    cache: C,
    /// Key source queried when the cache cannot answer.
    source: S,
}
88
89impl<C, S> CachedKeySet<C, S> {
90 /// Creates a new cached key source.
91 ///
92 /// # Arguments
93 ///
94 /// * `cache` - The cache implementation to use.
95 /// * `source` - The underlying key source to fetch from on cache misses.
96 pub fn new(cache: C, source: S) -> Self {
97 Self { cache, source }
98 }
99
100 /// Returns a reference to the cache.
101 pub fn cache(&self) -> &C {
102 &self.cache
103 }
104
105 /// Returns a reference to the underlying source.
106 pub fn source(&self) -> &S {
107 &self.source
108 }
109}
110
111impl<C: std::fmt::Debug, S: std::fmt::Debug> std::fmt::Debug for CachedKeySet<C, S> {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("CachedKeySet")
114 .field("cache", &self.cache)
115 .field("source", &self.source)
116 .finish()
117 }
118}
119
120#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
121#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
122impl<C, S> KeySource for CachedKeySet<C, S>
123where
124 C: KeyCache + Send + Sync,
125 S: KeySource + Send + Sync,
126{
127 async fn get_key(&self, kid: &str) -> Result<Option<Key>> {
128 // Try cache first
129 if let Some(key) = self.cache.get(kid).await {
130 return Ok(Some(key));
131 }
132
133 // Cache miss: fetch from underlying source
134 if let Some(key) = self.source.get_key(kid).await? {
135 self.cache.set(kid, key.clone()).await;
136 return Ok(Some(key));
137 }
138
139 Ok(None)
140 }
141
142 async fn get_keyset(&self) -> Result<KeySet> {
143 // Fetch from source
144 let keyset = self.source.get_keyset().await?;
145
146 // Cache all keys that have a kid
147 for key in &keyset.keys {
148 if let Some(kid) = &key.kid {
149 self.cache.set(kid, key.clone()).await;
150 }
151 }
152
153 Ok(keyset)
154 }
155}