jwk_simple/jwks.rs
1//! JSON Web Key Set (JWKS) as defined in RFC 7517 Section 5.
2//!
3//! A JWKS is a collection of JWK objects, typically used for key distribution
4//! and discovery. This module also defines the [`KeySource`] trait, which abstracts
5//! over different sources of JWK keys.
6
7use serde::{Deserialize, Deserializer, Serialize};
8
9use crate::error::Result;
10use crate::jwk::{Algorithm, Key, KeyType, KeyUse};
11
12mod cache;
13#[cfg(feature = "http")]
14mod remote;
15#[cfg(feature = "cache-inmemory")]
16mod inmemory_cache;
17#[cfg(all(feature = "cloudflare", target_arch = "wasm32"))]
18pub mod cloudflare;
19
20pub use cache::{CachedKeySet, KeyCache};
21#[cfg(feature = "http")]
22pub use remote::{RemoteKeySet, DEFAULT_TIMEOUT};
23#[cfg(feature = "cache-inmemory")]
24pub use inmemory_cache::{InMemoryCachedKeySet, InMemoryKeyCache, DEFAULT_CACHE_TTL};
25
26/// A trait for types that can provide JWK keys.
27///
28/// This trait abstracts over different sources of keys, whether from
29/// a static set, a remote HTTP endpoint, or a cached source.
30///
31/// # Naming
32///
33/// This trait is called `KeySource` (not `Jwks`) to distinguish it from:
34/// - [`KeySet`] - The data structure holding keys
35/// - A potential `Jwks` type for serving keys via HTTP endpoints
36///
37/// # Async and Send Bounds
38///
39/// On native targets, the trait requires `Send + Sync` and futures are `Send`.
40/// On WASM targets, these bounds are relaxed since everything is single-threaded.
41///
42/// # Examples
43///
44/// Using a static key set:
45///
46/// ```
47/// use jwk_simple::{KeySource, KeySet};
48///
49/// # async fn example() -> jwk_simple::Result<()> {
50/// let source: KeySet = serde_json::from_str(r#"{"keys": []}"#)?;
51/// let key = source.get_key("some-kid").await?;
52/// # Ok(())
53/// # }
54/// ```
55///
56/// Generic code that works with any source:
57///
58/// ```ignore
59/// async fn verify_token<S: KeySource>(source: &S, kid: &str) -> Result<()> {
60/// let key = source.get_key(kid).await?.ok_or(Error::KeyNotFound)?;
61/// // ... verify with key
62/// Ok(())
63/// }
64/// ```
65#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
66#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
67pub trait KeySource {
68 /// Gets a key by its key ID (`kid`).
69 ///
70 /// Returns `Ok(None)` if no key with the given ID exists.
71 /// Returns `Err` if the lookup failed (e.g., network error for remote sources).
72 ///
73 /// # Arguments
74 ///
75 /// * `kid` - The key ID to look up.
76 async fn get_key(&self, kid: &str) -> Result<Option<Key>>;
77
78 /// Gets all available keys as a [`KeySet`].
79 ///
80 /// For remote sources, this may trigger a fetch if the cache is empty or expired.
81 async fn get_keyset(&self) -> Result<KeySet>;
82}
83
84// Implement KeySource for KeySet (static, immediate)
85#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
86#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
87impl KeySource for KeySet {
88 async fn get_key(&self, kid: &str) -> Result<Option<Key>> {
89 Ok(self.find_by_kid(kid).cloned())
90 }
91
92 async fn get_keyset(&self) -> Result<KeySet> {
93 Ok(self.clone())
94 }
95}
96
97/// A JSON Web Key Set (RFC 7517 Section 5).
98///
99/// A KeySet contains a collection of keys that can be looked up by various
100/// criteria such as key ID (`kid`), algorithm, or key use.
101///
102/// # RFC Compliance
103///
104/// Per RFC 7517 Section 5:
105/// > "Implementations SHOULD ignore JWKs within a JWK Set that use 'kty'
106/// > (key type) values that are not understood by them"
107///
108/// This implementation follows this guidance by silently skipping keys with
109/// unknown `kty` values during deserialization rather than failing.
110///
111/// # Examples
112///
113/// Parse a JWKS from JSON:
114///
115/// ```
116/// use jwk_simple::KeySet;
117///
118/// let json = r#"{
119/// "keys": [
120/// {
121/// "kty": "RSA",
122/// "kid": "key-1",
123/// "use": "sig",
124/// "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
125/// "e": "AQAB"
126/// }
127/// ]
128/// }"#;
129///
130/// let jwks: KeySet = serde_json::from_str(json).unwrap();
131/// assert_eq!(jwks.len(), 1);
132/// ```
133///
134/// Keys with unknown `kty` values are silently skipped:
135///
136/// ```
137/// use jwk_simple::KeySet;
138///
139/// let json = r#"{
140/// "keys": [
141/// {"kty": "UNKNOWN", "data": "ignored"},
142/// {"kty": "oct", "k": "AQAB"}
143/// ]
144/// }"#;
145///
146/// let jwks: KeySet = serde_json::from_str(json).unwrap();
147/// assert_eq!(jwks.len(), 1); // Only the "oct" key is included
148/// ```
149#[derive(Debug, Clone, Serialize, Default)]
150pub struct KeySet {
151 /// The collection of keys.
152 pub keys: Vec<Key>,
153}
154
155impl<'de> Deserialize<'de> for KeySet {
156 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
157 where
158 D: Deserializer<'de>,
159 {
160 // Helper struct for raw deserialization
161 #[derive(Deserialize)]
162 struct RawJwkSet {
163 keys: Vec<serde_json::Value>,
164 }
165
166 let raw = RawJwkSet::deserialize(deserializer)?;
167
168 // Try to parse each key, silently skipping those with unknown kty values
169 // per RFC 7517 Section 5
170 let keys: Vec<Key> = raw
171 .keys
172 .into_iter()
173 .filter_map(|value| {
174 // Try to deserialize as a Key
175 serde_json::from_value::<Key>(value).ok()
176 })
177 .collect();
178
179 Ok(KeySet { keys })
180 }
181}
182
183impl KeySet {
184 /// Creates a new empty KeySet.
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// use jwk_simple::KeySet;
190 ///
191 /// let jwks = KeySet::new();
192 /// assert!(jwks.is_empty());
193 /// ```
194 pub fn new() -> Self {
195 Self { keys: Vec::new() }
196 }
197
198 /// Creates a KeySet from a vector of keys.
199 ///
200 /// # Examples
201 ///
202 /// ```
203 /// use jwk_simple::{KeySet, Key};
204 ///
205 /// let keys = vec![]; // Would contain Key instances
206 /// let jwks = KeySet::from_keys(keys);
207 /// ```
208 pub fn from_keys(keys: Vec<Key>) -> Self {
209 Self { keys }
210 }
211 /// Returns the number of keys in the set.
212 pub fn len(&self) -> usize {
213 self.keys.len()
214 }
215
216 /// Returns `true` if the set contains no keys.
217 pub fn is_empty(&self) -> bool {
218 self.keys.is_empty()
219 }
220
221 /// Adds a key to the set.
222 ///
223 /// # Examples
224 ///
225 /// ```
226 /// use jwk_simple::{KeySet, Key};
227 ///
228 /// let mut jwks = KeySet::new();
229 /// // jwks.add_key(some_jwk);
230 /// ```
231 pub fn add_key(&mut self, key: Key) {
232 self.keys.push(key);
233 }
234
235 /// Removes and returns a key by its ID.
236 ///
237 /// # Arguments
238 ///
239 /// * `kid` - The key ID to look for.
240 ///
241 /// # Returns
242 ///
243 /// The removed key, or `None` if not found.
244 pub fn remove_by_kid(&mut self, kid: &str) -> Option<Key> {
245 if let Some(pos) = self.keys.iter().position(|k| k.kid.as_deref() == Some(kid)) {
246 Some(self.keys.remove(pos))
247 } else {
248 None
249 }
250 }
251
252 /// Finds a key by its ID (`kid`).
253 ///
254 /// # Arguments
255 ///
256 /// * `kid` - The key ID to look for.
257 ///
258 /// # Returns
259 ///
260 /// A reference to the key, or `None` if not found.
261 ///
262 /// # Examples
263 ///
264 /// ```
265 /// use jwk_simple::KeySet;
266 ///
267 /// let json = r#"{"keys": [{"kty": "oct", "kid": "my-key", "k": "AQAB"}]}"#;
268 /// let jwks: KeySet = serde_json::from_str(json).unwrap();
269 ///
270 /// let key = jwks.find_by_kid("my-key");
271 /// assert!(key.is_some());
272 ///
273 /// let missing = jwks.find_by_kid("unknown");
274 /// assert!(missing.is_none());
275 /// ```
276 pub fn find_by_kid(&self, kid: &str) -> Option<&Key> {
277 self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
278 }
279
280 /// Finds all keys with the specified algorithm.
281 ///
282 /// # Arguments
283 ///
284 /// * `alg` - The algorithm to filter by.
285 ///
286 /// # Returns
287 ///
288 /// A vector of references to matching keys.
289 ///
290 /// # Examples
291 ///
292 /// ```
293 /// use jwk_simple::{KeySet, Algorithm};
294 ///
295 /// let json = r#"{"keys": [{"kty": "RSA", "alg": "RS256", "n": "AQAB", "e": "AQAB"}]}"#;
296 /// let jwks: KeySet = serde_json::from_str(json).unwrap();
297 ///
298 /// let rs256_keys = jwks.find_by_alg(&Algorithm::Rs256);
299 /// assert_eq!(rs256_keys.len(), 1);
300 /// ```
301 pub fn find_by_alg(&self, alg: &Algorithm) -> Vec<&Key> {
302 self.keys
303 .iter()
304 .filter(|k| k.alg.as_ref() == Some(alg))
305 .collect()
306 }
307
308 /// Finds all keys with the specified key type.
309 ///
310 /// # Arguments
311 ///
312 /// * `kty` - The key type to filter by.
313 ///
314 /// # Returns
315 ///
316 /// A vector of references to matching keys.
317 pub fn find_by_kty(&self, kty: KeyType) -> Vec<&Key> {
318 self.keys.iter().filter(|k| k.kty == kty).collect()
319 }
320
321 /// Finds all keys with the specified use.
322 ///
323 /// # Arguments
324 ///
325 /// * `key_use` - The key use to filter by.
326 ///
327 /// # Returns
328 ///
329 /// A vector of references to matching keys.
330 ///
331 /// # Examples
332 ///
333 /// ```
334 /// use jwk_simple::{KeySet, KeyUse};
335 ///
336 /// let json = r#"{"keys": [{"kty": "RSA", "use": "sig", "n": "AQAB", "e": "AQAB"}]}"#;
337 /// let jwks: KeySet = serde_json::from_str(json).unwrap();
338 ///
339 /// let signing_keys = jwks.find_by_use(KeyUse::Signature);
340 /// assert_eq!(signing_keys.len(), 1);
341 /// ```
342 pub fn find_by_use(&self, key_use: KeyUse) -> Vec<&Key> {
343 self.keys
344 .iter()
345 .filter(|k| k.key_use.as_ref() == Some(&key_use))
346 .collect()
347 }
348
349 /// Finds all signing keys.
350 ///
351 /// A key is considered a signing key if:
352 /// - It has `use: "sig"`, OR
353 /// - It has no `use` specified (default behavior for signature keys)
354 ///
355 /// # Returns
356 ///
357 /// A vector of references to signing keys.
358 pub fn signing_keys(&self) -> Vec<&Key> {
359 self.keys
360 .iter()
361 .filter(|k| k.key_use.is_none() || k.key_use.as_ref() == Some(&KeyUse::Signature))
362 .collect()
363 }
364
365 /// Finds all encryption keys.
366 ///
367 /// # Returns
368 ///
369 /// A vector of references to encryption keys.
370 pub fn encryption_keys(&self) -> Vec<&Key> {
371 self.keys
372 .iter()
373 .filter(|k| k.key_use.as_ref() == Some(&KeyUse::Encryption))
374 .collect()
375 }
376
377 /// Returns the first signing key, if any.
378 ///
379 /// This is a convenience method for cases where only one signing key is expected.
380 ///
381 /// # Examples
382 ///
383 /// ```
384 /// use jwk_simple::KeySet;
385 ///
386 /// let json = r#"{"keys": [{"kty": "RSA", "use": "sig", "n": "AQAB", "e": "AQAB"}]}"#;
387 /// let jwks: KeySet = serde_json::from_str(json).unwrap();
388 ///
389 /// let key = jwks.first_signing_key().expect("expected a signing key");
390 /// ```
391 pub fn first_signing_key(&self) -> Option<&Key> {
392 self.signing_keys().into_iter().next()
393 }
394
395 /// Returns the first key, if any.
396 ///
397 /// # Examples
398 ///
399 /// ```
400 /// use jwk_simple::KeySet;
401 ///
402 /// let jwks = KeySet::new();
403 /// assert!(jwks.first().is_none());
404 /// ```
405 pub fn first(&self) -> Option<&Key> {
406 self.keys.first()
407 }
408
409 /// Returns an iterator over the keys.
410 pub fn iter(&self) -> impl Iterator<Item = &Key> {
411 self.keys.iter()
412 }
413
414 /// Validates all keys in the set.
415 ///
416 /// # Errors
417 ///
418 /// Returns the first validation error encountered, if any.
419 pub fn validate(&self) -> Result<()> {
420 for key in &self.keys {
421 key.validate()?;
422 }
423 Ok(())
424 }
425
426 /// Finds a key by its JWK thumbprint.
427 ///
428 /// # Arguments
429 ///
430 /// * `thumbprint` - The base64url-encoded SHA-256 thumbprint.
431 ///
432 /// # Returns
433 ///
434 /// A reference to the key, or `None` if not found.
435 pub fn find_by_thumbprint(&self, thumbprint: &str) -> Option<&Key> {
436 self.keys
437 .iter()
438 .find(|k| k.thumbprint() == thumbprint)
439 }
440}
441
442impl IntoIterator for KeySet {
443 type Item = Key;
444 type IntoIter = std::vec::IntoIter<Key>;
445
446 fn into_iter(self) -> Self::IntoIter {
447 self.keys.into_iter()
448 }
449}
450
451impl<'a> IntoIterator for &'a KeySet {
452 type Item = &'a Key;
453 type IntoIter = std::slice::Iter<'a, Key>;
454
455 fn into_iter(self) -> Self::IntoIter {
456 self.keys.iter()
457 }
458}
459
460impl FromIterator<Key> for KeySet {
461 fn from_iter<I: IntoIterator<Item = Key>>(iter: I) -> Self {
462 Self {
463 keys: iter.into_iter().collect(),
464 }
465 }
466}
467
468impl std::ops::Index<usize> for KeySet {
469 type Output = Key;
470
471 fn index(&self, index: usize) -> &Self::Output {
472 &self.keys[index]
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_JWKS: &str = r#"{
        "keys": [
            {
                "kty": "RSA",
                "kid": "rsa-key-1",
                "use": "sig",
                "alg": "RS256",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB"
            },
            {
                "kty": "EC",
                "kid": "ec-key-1",
                "use": "sig",
                "alg": "ES256",
                "crv": "P-256",
                "x": "MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4",
                "y": "4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"
            },
            {
                "kty": "RSA",
                "kid": "rsa-enc-1",
                "use": "enc",
                "n": "sXchDaQebSXKcvL0vwlG",
                "e": "AQAB"
            }
        ]
    }"#;

    #[test]
    fn test_parse_jwks() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();
        assert_eq!(jwks.len(), 3);
    }

    #[test]
    fn test_parse_skips_unknown_kty() {
        let json = r#"{"keys": [
            {"kty": "UNKNOWN", "data": "ignored"},
            {"kty": "oct", "kid": "ok", "k": "AQAB"}
        ]}"#;
        let jwks: KeySet = serde_json::from_str(json).unwrap();
        assert_eq!(jwks.len(), 1);
        assert!(jwks.find_by_kid("ok").is_some());
    }

    #[test]
    fn test_parse_skips_non_object_entries() {
        // Entries are parsed one by one; anything that is not a key is dropped.
        let json = r#"{"keys": [42, null, {"kty": "oct", "kid": "ok", "k": "AQAB"}]}"#;
        let jwks: KeySet = serde_json::from_str(json).unwrap();
        assert_eq!(jwks.len(), 1);
        assert!(jwks.find_by_kid("ok").is_some());
    }

    #[test]
    fn test_parse_requires_keys_member() {
        assert!(serde_json::from_str::<KeySet>("{}").is_err());
        assert!(serde_json::from_str::<KeySet>(r#"{"keys": {}}"#).is_err());
    }

    #[test]
    fn test_find_by_kid() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        assert!(jwks.find_by_kid("rsa-key-1").is_some());
        assert!(jwks.find_by_kid("ec-key-1").is_some());
        assert!(jwks.find_by_kid("unknown").is_none());
    }

    #[test]
    fn test_remove_by_kid() {
        let mut jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let removed = jwks.remove_by_kid("ec-key-1").unwrap();
        assert_eq!(removed.kid.as_deref(), Some("ec-key-1"));
        assert_eq!(jwks.len(), 2);
        assert!(jwks.find_by_kid("ec-key-1").is_none());

        // Remaining keys keep their relative order.
        assert_eq!(jwks[0].kid.as_deref(), Some("rsa-key-1"));
        assert_eq!(jwks[1].kid.as_deref(), Some("rsa-enc-1"));

        // Removing again (or an unknown kid) is a no-op.
        assert!(jwks.remove_by_kid("ec-key-1").is_none());
        assert!(jwks.remove_by_kid("unknown").is_none());
        assert_eq!(jwks.len(), 2);
    }

    #[test]
    fn test_add_key_and_from_keys() {
        let key: Key =
            serde_json::from_str(r#"{"kty": "oct", "kid": "k1", "k": "AQAB"}"#).unwrap();

        let mut jwks = KeySet::new();
        jwks.add_key(key.clone());
        assert_eq!(jwks.len(), 1);
        assert!(jwks.find_by_kid("k1").is_some());

        let jwks = KeySet::from_keys(vec![key]);
        assert_eq!(jwks.len(), 1);
        assert_eq!(jwks.first().and_then(|k| k.kid.as_deref()), Some("k1"));
    }

    #[test]
    fn test_find_by_alg() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let rs256_keys = jwks.find_by_alg(&Algorithm::Rs256);
        assert_eq!(rs256_keys.len(), 1);

        let es256_keys = jwks.find_by_alg(&Algorithm::Es256);
        assert_eq!(es256_keys.len(), 1);
    }

    #[test]
    fn test_find_by_use() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let sig_keys = jwks.find_by_use(KeyUse::Signature);
        assert_eq!(sig_keys.len(), 2);

        let enc_keys = jwks.find_by_use(KeyUse::Encryption);
        assert_eq!(enc_keys.len(), 1);
    }

    #[test]
    fn test_signing_keys() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let signing = jwks.signing_keys();
        assert_eq!(signing.len(), 2);
    }

    #[test]
    fn test_signing_keys_include_keys_without_use() {
        let json = r#"{"keys": [
            {"kty": "oct", "kid": "enc", "use": "enc", "k": "AQAB"},
            {"kty": "oct", "kid": "no-use", "k": "AQAB"}
        ]}"#;
        let jwks: KeySet = serde_json::from_str(json).unwrap();
        assert_eq!(jwks.len(), 2);

        let signing = jwks.signing_keys();
        assert_eq!(signing.len(), 1);
        assert_eq!(signing[0].kid.as_deref(), Some("no-use"));

        // The first signing key skips the leading encryption key.
        assert_eq!(
            jwks.first_signing_key().and_then(|k| k.kid.as_deref()),
            Some("no-use")
        );
        assert_eq!(jwks.encryption_keys().len(), 1);
    }

    #[test]
    fn test_encryption_keys() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let encryption = jwks.encryption_keys();
        assert_eq!(encryption.len(), 1);
        assert_eq!(encryption[0].kid.as_deref(), Some("rsa-enc-1"));
    }

    #[test]
    fn test_first_signing_key() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let first = jwks.first_signing_key().unwrap();
        assert_eq!(first.kid.as_deref(), Some("rsa-key-1"));
    }

    #[test]
    fn test_empty_jwks() {
        let jwks = KeySet::new();
        assert!(jwks.is_empty());
        assert_eq!(jwks.len(), 0);
        assert!(jwks.first().is_none());
        assert!(jwks.first_signing_key().is_none());
        assert!(jwks.signing_keys().is_empty());
        assert!(jwks.encryption_keys().is_empty());
    }

    #[test]
    fn test_serde_roundtrip() {
        let original: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let parsed: KeySet = serde_json::from_str(&json).unwrap();
        assert_eq!(original.len(), parsed.len());
    }

    #[test]
    fn test_iterator() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let count = jwks.iter().count();
        assert_eq!(count, 3);

        let kids: Vec<_> = jwks.iter().filter_map(|k| k.kid.as_deref()).collect();
        assert!(kids.contains(&"rsa-key-1"));
    }

    #[test]
    fn test_into_iterator_preserves_order() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();

        let mut borrowed = Vec::new();
        for key in &jwks {
            borrowed.extend(key.kid.as_deref());
        }
        assert_eq!(borrowed, ["rsa-key-1", "ec-key-1", "rsa-enc-1"]);

        let owned: Vec<String> = jwks.into_iter().filter_map(|k| k.kid).collect();
        assert_eq!(owned, ["rsa-key-1", "ec-key-1", "rsa-enc-1"]);
    }

    #[test]
    fn test_from_iterator() {
        let keys = vec![];
        let jwks: KeySet = keys.into_iter().collect();
        assert!(jwks.is_empty());
    }

    #[test]
    fn test_index() {
        let jwks: KeySet = serde_json::from_str(SAMPLE_JWKS).unwrap();
        let first = &jwks[0];
        assert_eq!(first.kid, Some("rsa-key-1".to_string()));
    }

    #[tokio::test]
    async fn test_jwkset_implements_source() {
        let json = r#"{"keys": [{"kty": "oct", "kid": "test-key", "k": "AQAB"}]}"#;
        let source: KeySet = serde_json::from_str(json).unwrap();

        // Test get_key
        let key = source.get_key("test-key").await.unwrap();
        assert!(key.is_some());
        assert_eq!(key.unwrap().kid, Some("test-key".to_string()));

        // Test missing key
        let missing = source.get_key("nonexistent").await.unwrap();
        assert!(missing.is_none());

        // Test get_keyset
        let keyset = source.get_keyset().await.unwrap();
        assert_eq!(keyset.len(), 1);
    }
}