jwk_simple/jwks/
cache.rs

1//! Caching traits and wrappers for JWKS key sources.
2//!
3//! This module provides a [`KeyCache`] trait for caching keys by their ID,
4//! and a [`CachedKeySet`] wrapper that combines any cache with any key source.
5//!
6//! For a ready-to-use in-memory implementation, enable the `inmemory-cache` feature
7//! and use [`InMemoryKeyCache`](super::InMemoryKeyCache).
8
9use crate::error::Result;
10use crate::jwk::Key;
11
12use super::{KeySet, KeySource};
13
14/// A trait for caching keys by their ID.
15///
16/// Implementations can provide different caching strategies (in-memory, Redis, etc.)
17/// while the [`CachedKeySet`] handles the cache-aside pattern.
18///
19/// # Examples
20///
21/// ```ignore
22/// use jwk_simple::{KeyCache, InMemoryKeyCache};
23/// use std::time::Duration;
24///
25/// let cache = InMemoryKeyCache::new(Duration::from_secs(300));
26///
27/// // Store a key
28/// cache.set("my-kid", key).await;
29///
30/// // Retrieve a key
31/// if let Some(key) = cache.get("my-kid").await {
32///     // Use the cached key
33/// }
34/// ```
35#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
36#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
37pub trait KeyCache {
38    /// Gets a key from the cache by its ID.
39    ///
40    /// Returns `None` if the key is not in the cache or has expired.
41    async fn get(&self, kid: &str) -> Option<Key>;
42
43    /// Stores a key in the cache.
44    ///
45    /// The key ID is extracted from the key's `kid` field if not provided explicitly.
46    async fn set(&self, kid: &str, key: Key);
47
48    /// Removes a key from the cache.
49    async fn remove(&self, kid: &str);
50
51    /// Clears all keys from the cache.
52    async fn clear(&self);
53}
54
/// A caching wrapper for any [`KeySource`] implementation.
///
/// Single-key lookups ([`KeySource::get_key`]) follow the cache-aside pattern.
/// The cache is consulted first, and the underlying source is queried only on a
/// cache miss. A key returned by the source is then stored in the cache under the
/// requested key ID.
///
/// Lookups the source cannot satisfy are *not* cached. A repeated request for an
/// unknown key ID therefore reaches the source every time.
///
/// [`KeySource::get_keyset`] always fetches from the underlying source. It then
/// stores every key that carries a `kid` in the cache. Entries already present in
/// the cache are not evicted by this call.
///
/// # Type Parameters
///
/// * `C` - The cache implementation. The [`KeySource`] impl for this type requires
///   `C: KeyCache + Send + Sync`.
/// * `S` - The underlying key source. The [`KeySource`] impl for this type requires
///   `S: KeySource + Send + Sync`.
///
/// # Examples
///
/// ```ignore
/// use jwk_simple::{CachedKeySet, InMemoryKeyCache, RemoteKeySet, KeySource};
/// use std::time::Duration;
///
/// // Create a cached remote JWKS
/// let cache = InMemoryKeyCache::new(Duration::from_secs(300));
/// let remote = RemoteKeySet::new("https://example.com/.well-known/jwks.json");
/// let cached = CachedKeySet::new(cache, remote);
///
/// // First call fetches from remote, caches the key
/// let key = cached.get_key("kid").await?;
///
/// // Later calls are served from the cache while the entry is present
/// let key = cached.get_key("kid").await?;
/// ```
pub struct CachedKeySet<C, S> {
    /// Cache consulted before the source and populated from the source's results.
    cache: C,
    /// Underlying key source, queried on cache misses and for full key sets.
    source: S,
}
88
89impl<C, S> CachedKeySet<C, S> {
90    /// Creates a new cached key source.
91    ///
92    /// # Arguments
93    ///
94    /// * `cache` - The cache implementation to use.
95    /// * `source` - The underlying key source to fetch from on cache misses.
96    pub fn new(cache: C, source: S) -> Self {
97        Self { cache, source }
98    }
99
100    /// Returns a reference to the cache.
101    pub fn cache(&self) -> &C {
102        &self.cache
103    }
104
105    /// Returns a reference to the underlying source.
106    pub fn source(&self) -> &S {
107        &self.source
108    }
109}
110
111impl<C: std::fmt::Debug, S: std::fmt::Debug> std::fmt::Debug for CachedKeySet<C, S> {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("CachedKeySet")
114            .field("cache", &self.cache)
115            .field("source", &self.source)
116            .finish()
117    }
118}
119
120#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
121#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
122impl<C, S> KeySource for CachedKeySet<C, S>
123where
124    C: KeyCache + Send + Sync,
125    S: KeySource + Send + Sync,
126{
127    async fn get_key(&self, kid: &str) -> Result<Option<Key>> {
128        // Try cache first
129        if let Some(key) = self.cache.get(kid).await {
130            return Ok(Some(key));
131        }
132
133        // Cache miss: fetch from underlying source
134        if let Some(key) = self.source.get_key(kid).await? {
135            self.cache.set(kid, key.clone()).await;
136            return Ok(Some(key));
137        }
138
139        Ok(None)
140    }
141
142    async fn get_keyset(&self) -> Result<KeySet> {
143        // Fetch from source
144        let keyset = self.source.get_keyset().await?;
145
146        // Cache all keys that have a kid
147        for key in &keyset.keys {
148            if let Some(kid) = &key.kid {
149                self.cache.set(kid, key.clone()).await;
150            }
151        }
152
153        Ok(keyset)
154    }
155}