jwk_simple/
jwks.rs

1//! JSON Web Key Set (JWKS) as defined in RFC 7517 Section 5.
2//!
3//! A JWKS is a collection of JWK objects, typically used for key distribution
4//! and discovery. This module also defines the [`KeySource`] trait, which abstracts
5//! over different sources of JWK keys.
6
7use serde::{Deserialize, Deserializer, Serialize};
8
9use crate::error::Result;
10use crate::jwk::{Algorithm, Key, KeyType, KeyUse};
11
12mod cache;
13#[cfg(feature = "http")]
14mod remote;
15#[cfg(feature = "cache-inmemory")]
16mod inmemory_cache;
17#[cfg(all(feature = "cloudflare", target_arch = "wasm32"))]
18pub mod cloudflare;
19
20pub use cache::{CachedKeySet, KeyCache};
21#[cfg(feature = "http")]
22pub use remote::{RemoteKeySet, DEFAULT_TIMEOUT};
23#[cfg(feature = "cache-inmemory")]
24pub use inmemory_cache::{InMemoryCachedKeySet, InMemoryKeyCache, DEFAULT_CACHE_TTL};
25
26/// A trait for types that can provide JWK keys.
27///
28/// This trait abstracts over different sources of keys, whether from
29/// a static set, a remote HTTP endpoint, or a cached source.
30///
31/// # Naming
32///
33/// This trait is called `KeySource` (not `Jwks`) to distinguish it from:
34/// - [`KeySet`] - The data structure holding keys
35/// - A potential `Jwks` type for serving keys via HTTP endpoints
36///
37/// # Async and Send Bounds
38///
39/// On native targets, the trait requires `Send + Sync` and futures are `Send`.
40/// On WASM targets, these bounds are relaxed since everything is single-threaded.
41///
42/// # Examples
43///
44/// Using a static key set:
45///
46/// ```
47/// use jwk_simple::{KeySource, KeySet};
48///
49/// # async fn example() -> jwk_simple::Result<()> {
50/// let source: KeySet = serde_json::from_str(r#"{"keys": []}"#)?;
51/// let key = source.get_key("some-kid").await?;
52/// # Ok(())
53/// # }
54/// ```
55///
56/// Generic code that works with any source:
57///
58/// ```ignore
59/// async fn verify_token<S: KeySource>(source: &S, kid: &str) -> Result<()> {
60///     let key = source.get_key(kid).await?.ok_or(Error::KeyNotFound)?;
61///     // ... verify with key
62///     Ok(())
63/// }
64/// ```
65#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
66#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
67pub trait KeySource {
68    /// Gets a key by its key ID (`kid`).
69    ///
70    /// Returns `Ok(None)` if no key with the given ID exists.
71    /// Returns `Err` if the lookup failed (e.g., network error for remote sources).
72    ///
73    /// # Arguments
74    ///
75    /// * `kid` - The key ID to look up.
76    async fn get_key(&self, kid: &str) -> Result<Option<Key>>;
77
78    /// Gets all available keys as a [`KeySet`].
79    ///
80    /// For remote sources, this may trigger a fetch if the cache is empty or expired.
81    async fn get_keyset(&self) -> Result<KeySet>;
82}
83
84// Implement KeySource for KeySet (static, immediate)
85#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
86#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
87impl KeySource for KeySet {
88    async fn get_key(&self, kid: &str) -> Result<Option<Key>> {
89        Ok(self.find_by_kid(kid).cloned())
90    }
91
92    async fn get_keyset(&self) -> Result<KeySet> {
93        Ok(self.clone())
94    }
95}
96
97/// A JSON Web Key Set (RFC 7517 Section 5).
98///
99/// A KeySet contains a collection of keys that can be looked up by various
100/// criteria such as key ID (`kid`), algorithm, or key use.
101///
102/// # RFC Compliance
103///
104/// Per RFC 7517 Section 5:
105/// > "Implementations SHOULD ignore JWKs within a JWK Set that use 'kty'
106/// > (key type) values that are not understood by them"
107///
108/// This implementation follows this guidance by silently skipping keys with
109/// unknown `kty` values during deserialization rather than failing.
110///
111/// # Examples
112///
113/// Parse a JWKS from JSON:
114///
115/// ```
116/// use jwk_simple::KeySet;
117///
118/// let json = r#"{
119///     "keys": [
120///         {
121///             "kty": "RSA",
122///             "kid": "key-1",
123///             "use": "sig",
124///             "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
125///             "e": "AQAB"
126///         }
127///     ]
128/// }"#;
129///
130/// let jwks: KeySet = serde_json::from_str(json).unwrap();
131/// assert_eq!(jwks.len(), 1);
132/// ```
133///
134/// Keys with unknown `kty` values are silently skipped:
135///
136/// ```
137/// use jwk_simple::KeySet;
138///
139/// let json = r#"{
140///     "keys": [
141///         {"kty": "UNKNOWN", "data": "ignored"},
142///         {"kty": "oct", "k": "AQAB"}
143///     ]
144/// }"#;
145///
146/// let jwks: KeySet = serde_json::from_str(json).unwrap();
147/// assert_eq!(jwks.len(), 1); // Only the "oct" key is included
148/// ```
149#[derive(Debug, Clone, Serialize, Default)]
150pub struct KeySet {
151    /// The collection of keys.
152    pub keys: Vec<Key>,
153}
154
155impl<'de> Deserialize<'de> for KeySet {
156    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
157    where
158        D: Deserializer<'de>,
159    {
160        // Helper struct for raw deserialization
161        #[derive(Deserialize)]
162        struct RawJwkSet {
163            keys: Vec<serde_json::Value>,
164        }
165
166        let raw = RawJwkSet::deserialize(deserializer)?;
167
168        // Try to parse each key, silently skipping those with unknown kty values
169        // per RFC 7517 Section 5
170        let keys: Vec<Key> = raw
171            .keys
172            .into_iter()
173            .filter_map(|value| {
174                // Try to deserialize as a Key
175                serde_json::from_value::<Key>(value).ok()
176            })
177            .collect();
178
179        Ok(KeySet { keys })
180    }
181}
182
183impl KeySet {
184    /// Creates a new empty KeySet.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// use jwk_simple::KeySet;
190    ///
191    /// let jwks = KeySet::new();
192    /// assert!(jwks.is_empty());
193    /// ```
194    pub fn new() -> Self {
195        Self { keys: Vec::new() }
196    }
197
198    /// Creates a KeySet from a vector of keys.
199    ///
200    /// # Examples
201    ///
202    /// ```
203    /// use jwk_simple::{KeySet, Key};
204    ///
205    /// let keys = vec![]; // Would contain Key instances
206    /// let jwks = KeySet::from_keys(keys);
207    /// ```
208    pub fn from_keys(keys: Vec<Key>) -> Self {
209        Self { keys }
210    }
211    /// Returns the number of keys in the set.
212    pub fn len(&self) -> usize {
213        self.keys.len()
214    }
215
216    /// Returns `true` if the set contains no keys.
217    pub fn is_empty(&self) -> bool {
218        self.keys.is_empty()
219    }
220
221    /// Adds a key to the set.
222    ///
223    /// # Examples
224    ///
225    /// ```
226    /// use jwk_simple::{KeySet, Key};
227    ///
228    /// let mut jwks = KeySet::new();
229    /// // jwks.add_key(some_jwk);
230    /// ```
231    pub fn add_key(&mut self, key: Key) {
232        self.keys.push(key);
233    }
234
235    /// Removes and returns a key by its ID.
236    ///
237    /// # Arguments
238    ///
239    /// * `kid` - The key ID to look for.
240    ///
241    /// # Returns
242    ///
243    /// The removed key, or `None` if not found.
244    pub fn remove_by_kid(&mut self, kid: &str) -> Option<Key> {
245        if let Some(pos) = self.keys.iter().position(|k| k.kid.as_deref() == Some(kid)) {
246            Some(self.keys.remove(pos))
247        } else {
248            None
249        }
250    }
251
252    /// Finds a key by its ID (`kid`).
253    ///
254    /// # Arguments
255    ///
256    /// * `kid` - The key ID to look for.
257    ///
258    /// # Returns
259    ///
260    /// A reference to the key, or `None` if not found.
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use jwk_simple::KeySet;
266    ///
267    /// let json = r#"{"keys": [{"kty": "oct", "kid": "my-key", "k": "AQAB"}]}"#;
268    /// let jwks: KeySet = serde_json::from_str(json).unwrap();
269    ///
270    /// let key = jwks.find_by_kid("my-key");
271    /// assert!(key.is_some());
272    ///
273    /// let missing = jwks.find_by_kid("unknown");
274    /// assert!(missing.is_none());
275    /// ```
276    pub fn find_by_kid(&self, kid: &str) -> Option<&Key> {
277        self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
278    }
279
280    /// Finds all keys with the specified algorithm.
281    ///
282    /// # Arguments
283    ///
284    /// * `alg` - The algorithm to filter by.
285    ///
286    /// # Returns
287    ///
288    /// A vector of references to matching keys.
289    ///
290    /// # Examples
291    ///
292    /// ```
293    /// use jwk_simple::{KeySet, Algorithm};
294    ///
295    /// let json = r#"{"keys": [{"kty": "RSA", "alg": "RS256", "n": "AQAB", "e": "AQAB"}]}"#;
296    /// let jwks: KeySet = serde_json::from_str(json).unwrap();
297    ///
298    /// let rs256_keys = jwks.find_by_alg(&Algorithm::Rs256);
299    /// assert_eq!(rs256_keys.len(), 1);
300    /// ```
301    pub fn find_by_alg(&self, alg: &Algorithm) -> Vec<&Key> {
302        self.keys
303            .iter()
304            .filter(|k| k.alg.as_ref() == Some(alg))
305            .collect()
306    }
307
308    /// Finds all keys with the specified key type.
309    ///
310    /// # Arguments
311    ///
312    /// * `kty` - The key type to filter by.
313    ///
314    /// # Returns
315    ///
316    /// A vector of references to matching keys.
317    pub fn find_by_kty(&self, kty: KeyType) -> Vec<&Key> {
318        self.keys.iter().filter(|k| k.kty == kty).collect()
319    }
320
321    /// Finds all keys with the specified use.
322    ///
323    /// # Arguments
324    ///
325    /// * `key_use` - The key use to filter by.
326    ///
327    /// # Returns
328    ///
329    /// A vector of references to matching keys.
330    ///
331    /// # Examples
332    ///
333    /// ```
334    /// use jwk_simple::{KeySet, KeyUse};
335    ///
336    /// let json = r#"{"keys": [{"kty": "RSA", "use": "sig", "n": "AQAB", "e": "AQAB"}]}"#;
337    /// let jwks: KeySet = serde_json::from_str(json).unwrap();
338    ///
339    /// let signing_keys = jwks.find_by_use(KeyUse::Signature);
340    /// assert_eq!(signing_keys.len(), 1);
341    /// ```
342    pub fn find_by_use(&self, key_use: KeyUse) -> Vec<&Key> {
343        self.keys
344            .iter()
345            .filter(|k| k.key_use.as_ref() == Some(&key_use))
346            .collect()
347    }
348
349    /// Finds all signing keys.
350    ///
351    /// A key is considered a signing key if:
352    /// - It has `use: "sig"`, OR
353    /// - It has no `use` specified (default behavior for signature keys)
354    ///
355    /// # Returns
356    ///
357    /// A vector of references to signing keys.
358    pub fn signing_keys(&self) -> Vec<&Key> {
359        self.keys
360            .iter()
361            .filter(|k| k.key_use.is_none() || k.key_use.as_ref() == Some(&KeyUse::Signature))
362            .collect()
363    }
364
365    /// Finds all encryption keys.
366    ///
367    /// # Returns
368    ///
369    /// A vector of references to encryption keys.
370    pub fn encryption_keys(&self) -> Vec<&Key> {
371        self.keys
372            .iter()
373            .filter(|k| k.key_use.as_ref() == Some(&KeyUse::Encryption))
374            .collect()
375    }
376
377    /// Returns the first signing key, if any.
378    ///
379    /// This is a convenience method for cases where only one signing key is expected.
380    ///
381    /// # Examples
382    ///
383    /// ```
384    /// use jwk_simple::KeySet;
385    ///
386    /// let json = r#"{"keys": [{"kty": "RSA", "use": "sig", "n": "AQAB", "e": "AQAB"}]}"#;
387    /// let jwks: KeySet = serde_json::from_str(json).unwrap();
388    ///
389    /// let key = jwks.first_signing_key().expect("expected a signing key");
390    /// ```
391    pub fn first_signing_key(&self) -> Option<&Key> {
392        self.signing_keys().into_iter().next()
393    }
394
395    /// Returns the first key, if any.
396    ///
397    /// # Examples
398    ///
399    /// ```
400    /// use jwk_simple::KeySet;
401    ///
402    /// let jwks = KeySet::new();
403    /// assert!(jwks.first().is_none());
404    /// ```
405    pub fn first(&self) -> Option<&Key> {
406        self.keys.first()
407    }
408
409    /// Returns an iterator over the keys.
410    pub fn iter(&self) -> impl Iterator<Item = &Key> {
411        self.keys.iter()
412    }
413
414    /// Validates all keys in the set.
415    ///
416    /// # Errors
417    ///
418    /// Returns the first validation error encountered, if any.
419    pub fn validate(&self) -> Result<()> {
420        for key in &self.keys {
421            key.validate()?;
422        }
423        Ok(())
424    }
425
426    /// Finds a key by its JWK thumbprint.
427    ///
428    /// # Arguments
429    ///
430    /// * `thumbprint` - The base64url-encoded SHA-256 thumbprint.
431    ///
432    /// # Returns
433    ///
434    /// A reference to the key, or `None` if not found.
435    pub fn find_by_thumbprint(&self, thumbprint: &str) -> Option<&Key> {
436        self.keys
437            .iter()
438            .find(|k| k.thumbprint() == thumbprint)
439    }
440}
441
442impl IntoIterator for KeySet {
443    type Item = Key;
444    type IntoIter = std::vec::IntoIter<Key>;
445
446    fn into_iter(self) -> Self::IntoIter {
447        self.keys.into_iter()
448    }
449}
450
451impl<'a> IntoIterator for &'a KeySet {
452    type Item = &'a Key;
453    type IntoIter = std::slice::Iter<'a, Key>;
454
455    fn into_iter(self) -> Self::IntoIter {
456        self.keys.iter()
457    }
458}
459
460impl FromIterator<Key> for KeySet {
461    fn from_iter<I: IntoIterator<Item = Key>>(iter: I) -> Self {
462        Self {
463            keys: iter.into_iter().collect(),
464        }
465    }
466}
467
468impl std::ops::Index<usize> for KeySet {
469    type Output = Key;
470
471    fn index(&self, index: usize) -> &Self::Output {
472        &self.keys[index]
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_JWKS: &str = r#"{
        "keys": [
            {
                "kty": "RSA",
                "kid": "rsa-key-1",
                "use": "sig",
                "alg": "RS256",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB"
            },
            {
                "kty": "EC",
                "kid": "ec-key-1",
                "use": "sig",
                "alg": "ES256",
                "crv": "P-256",
                "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
                "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"
            },
            {
                "kty": "RSA",
                "kid": "rsa-enc-1",
                "use": "enc",
                "n": "sXchDaQebSXKcvL0vwlG",
                "e": "AQAB"
            }
        ]
    }"#;

    /// Parses `SAMPLE_JWKS`, which every test below assumes to be valid.
    fn sample() -> KeySet {
        serde_json::from_str(SAMPLE_JWKS).expect("SAMPLE_JWKS must parse")
    }

    #[test]
    fn test_parse_jwks() {
        let jwks = sample();
        assert_eq!(jwks.len(), 3);
    }

    #[test]
    fn test_unknown_kty_is_skipped() {
        let json = r#"{"keys": [
            {"kty": "UNKNOWN", "data": "ignored"},
            {"kty": "oct", "kid": "sym", "k": "AQAB"}
        ]}"#;
        let jwks: KeySet = serde_json::from_str(json).unwrap();
        assert_eq!(jwks.len(), 1);
        assert!(jwks.find_by_kid("sym").is_some());
    }

    #[test]
    fn test_missing_keys_member_is_rejected() {
        assert!(serde_json::from_str::<KeySet>("{}").is_err());
    }

    #[test]
    fn test_find_by_kid() {
        let jwks = sample();

        assert!(jwks.find_by_kid("rsa-key-1").is_some());
        assert!(jwks.find_by_kid("ec-key-1").is_some());
        assert!(jwks.find_by_kid("unknown").is_none());
    }

    #[test]
    fn test_find_by_alg() {
        let jwks = sample();

        let rs256_keys = jwks.find_by_alg(&Algorithm::Rs256);
        assert_eq!(rs256_keys.len(), 1);

        let es256_keys = jwks.find_by_alg(&Algorithm::Es256);
        assert_eq!(es256_keys.len(), 1);
    }

    #[test]
    fn test_find_by_use() {
        let jwks = sample();

        let sig_keys = jwks.find_by_use(KeyUse::Signature);
        assert_eq!(sig_keys.len(), 2);

        let enc_keys = jwks.find_by_use(KeyUse::Encryption);
        assert_eq!(enc_keys.len(), 1);
    }

    #[test]
    fn test_signing_keys() {
        let jwks = sample();

        let signing = jwks.signing_keys();
        assert_eq!(signing.len(), 2);
    }

    #[test]
    fn test_key_without_use_counts_as_signing() {
        let json = r#"{"keys": [{"kty": "oct", "kid": "no-use", "k": "AQAB"}]}"#;
        let jwks: KeySet = serde_json::from_str(json).unwrap();

        assert_eq!(jwks.signing_keys().len(), 1);
        assert!(jwks.encryption_keys().is_empty());
        assert_eq!(
            jwks.first_signing_key().and_then(|k| k.kid.as_deref()),
            Some("no-use")
        );
    }

    #[test]
    fn test_encryption_keys() {
        let jwks = sample();

        let encryption = jwks.encryption_keys();
        assert_eq!(encryption.len(), 1);
    }

    #[test]
    fn test_first_signing_key() {
        let jwks = sample();

        // The first key in document order is a signing key.
        let first = jwks.first_signing_key().unwrap();
        assert_eq!(first.kid.as_deref(), Some("rsa-key-1"));
        assert!(first.key_use == Some(KeyUse::Signature) || first.key_use.is_none());
    }

    #[test]
    fn test_empty_jwks() {
        let jwks = KeySet::new();
        assert!(jwks.is_empty());
        assert_eq!(jwks.len(), 0);
        assert!(jwks.first().is_none());
        assert!(jwks.first_signing_key().is_none());
    }

    #[test]
    fn test_add_key() {
        let key: Key =
            serde_json::from_str(r#"{"kty": "oct", "kid": "added", "k": "AQAB"}"#).unwrap();
        let mut jwks = KeySet::new();
        jwks.add_key(key);

        assert_eq!(jwks.len(), 1);
        assert!(jwks.find_by_kid("added").is_some());
    }

    #[test]
    fn test_remove_by_kid() {
        let mut jwks = sample();

        let removed = jwks.remove_by_kid("ec-key-1").expect("ec-key-1 is present");
        assert_eq!(removed.kid.as_deref(), Some("ec-key-1"));
        assert_eq!(jwks.len(), 2);
        assert!(jwks.find_by_kid("ec-key-1").is_none());

        // Removing again is a no-op.
        assert!(jwks.remove_by_kid("ec-key-1").is_none());
        assert_eq!(jwks.len(), 2);

        // Remaining keys keep their relative order.
        assert_eq!(jwks[0].kid.as_deref(), Some("rsa-key-1"));
        assert_eq!(jwks[1].kid.as_deref(), Some("rsa-enc-1"));
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = sample();
        let json = serde_json::to_string(&original).unwrap();
        let parsed: KeySet = serde_json::from_str(&json).unwrap();

        assert_eq!(original.len(), parsed.len());
        let kids = |set: &KeySet| set.iter().map(|k| k.kid.clone()).collect::<Vec<_>>();
        assert_eq!(kids(&original), kids(&parsed));
    }

    #[test]
    fn test_iterator() {
        let jwks = sample();

        let count = jwks.iter().count();
        assert_eq!(count, 3);

        let kids: Vec<_> = jwks.iter().filter_map(|k| k.kid.as_deref()).collect();
        assert!(kids.contains(&"rsa-key-1"));
    }

    #[test]
    fn test_from_iterator() {
        let keys: Vec<Key> = Vec::new();
        let jwks: KeySet = keys.into_iter().collect();
        assert!(jwks.is_empty());
    }

    #[test]
    fn test_index() {
        let jwks = sample();
        let first = &jwks[0];
        assert_eq!(first.kid, Some("rsa-key-1".to_string()));
    }

    #[tokio::test]
    async fn test_jwkset_implements_source() {
        let json = r#"{"keys": [{"kty": "oct", "kid": "test-key", "k": "AQAB"}]}"#;
        let source: KeySet = serde_json::from_str(json).unwrap();

        // Test get_key
        let key = source.get_key("test-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(key.unwrap().kid, Some("test-key".to_string()));

        // Test missing key
        let missing = source.get_key("nonexistent").await.unwrap();
        assert!(missing.is_none());

        // Test get_keyset
        let keyset = source.get_keyset().await.unwrap();
        assert_eq!(keyset.len(), 1);
    }
}