Skip to main content

moq_token/
set.rs

1use crate::error::KeyError;
2use crate::{Claims, Key, KeyOperation};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::path::Path;
5use std::sync::Arc;
6
7#[cfg(feature = "jwks-loader")]
8use std::time::Duration;
9
10/// JWK Set to spec <https://datatracker.ietf.org/doc/html/rfc7517#section-5>
11#[derive(Default, Clone)]
12pub struct KeySet {
13	/// Vec of an arbitrary number of Json Web Keys
14	pub keys: Vec<Arc<Key>>,
15}
16
17impl Serialize for KeySet {
18	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19	where
20		S: Serializer,
21	{
22		// Serialize as a struct with a `keys` field
23		use serde::ser::SerializeStruct;
24
25		let mut state = serializer.serialize_struct("KeySet", 1)?;
26		state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>())?;
27		state.end()
28	}
29}
30
31impl<'de> Deserialize<'de> for KeySet {
32	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33	where
34		D: Deserializer<'de>,
35	{
36		// Deserialize into a temporary Vec<Key>
37		#[derive(Deserialize)]
38		struct RawKeySet {
39			keys: Vec<Key>,
40		}
41
42		let raw = RawKeySet::deserialize(deserializer)?;
43		Ok(KeySet {
44			keys: raw.keys.into_iter().map(Arc::new).collect(),
45		})
46	}
47}
48
49impl KeySet {
50	#[allow(clippy::should_implement_trait)]
51	pub fn from_str(s: &str) -> crate::Result<Self> {
52		Ok(serde_json::from_str(s)?)
53	}
54
55	pub fn from_file<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
56		let json = std::fs::read_to_string(&path)?;
57		Ok(serde_json::from_str(&json)?)
58	}
59
60	pub fn to_str(&self) -> crate::Result<String> {
61		Ok(serde_json::to_string(&self)?)
62	}
63
64	pub fn to_file<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
65		let json = serde_json::to_string(&self)?;
66		std::fs::write(path, json)?;
67		Ok(())
68	}
69
70	pub fn to_public_set(&self) -> crate::Result<KeySet> {
71		Ok(KeySet {
72			keys: self
73				.keys
74				.iter()
75				.map(|key| key.as_ref().to_public().map(Arc::new))
76				.collect::<Result<Vec<Arc<Key>>, _>>()?,
77		})
78	}
79
80	pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
81		self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
82	}
83
84	pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
85		self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
86	}
87
88	pub fn encode(&self, payload: &Claims) -> crate::Result<String> {
89		let key = self
90			.find_supported_key(&KeyOperation::Sign)
91			.ok_or(KeyError::NoSigningKey)?;
92		key.encode(payload)
93	}
94
95	pub fn decode(&self, token: &str) -> crate::Result<Claims> {
96		let header = jsonwebtoken::decode_header(token)?;
97
98		let key = match header.kid {
99			Some(kid) => self
100				.find_key(kid.as_str())
101				.ok_or_else(|| crate::Error::from(KeyError::KeyNotFound(kid))),
102			None => {
103				// If we only have one key we can use it without a kid
104				if self.keys.len() == 1 {
105					Ok(self.keys[0].clone())
106				} else {
107					Err(KeyError::MissingKid.into())
108				}
109			}
110		}?;
111
112		key.decode(token)
113	}
114}
115
/// Fetches a JWK Set from `jwks_uri` over HTTP and parses it.
///
/// Each call builds a fresh client with a 10-second request timeout. A non-success
/// HTTP status is reported as an error before the body is read.
#[cfg(feature = "jwks-loader")]
pub async fn load_keys(jwks_uri: &str) -> crate::Result<KeySet> {
	let timeout = Duration::from_secs(10);
	let client = reqwest::Client::builder().timeout(timeout).build()?;

	let response = client.get(jwks_uri).send().await?;
	let response = response.error_for_status()?;
	let body = response.text().await?;

	KeySet::from_str(&body)
}
124
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Algorithm;
	use std::path::PathBuf;
	use std::time::{Duration, SystemTime};

	/// Claims valid for one hour from now, with distinct publish/subscribe paths.
	fn create_test_claims() -> Claims {
		Claims {
			root: "test-path".to_string(),
			publish: vec!["test-pub".into()],
			cluster: false,
			subscribe: vec!["test-sub".into()],
			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
			issued: Some(SystemTime::now()),
		}
	}

	/// Generates a fresh ES256 key, tagged with `kid` when one is given.
	fn create_test_key(kid: Option<&str>) -> Key {
		let kid = kid.map(|s| crate::KeyId::decode(s).unwrap());
		Key::generate(Algorithm::ES256, kid).expect("failed to generate key")
	}

	/// Builds a `KeySet` holding `keys` in the given order.
	fn set_of(keys: Vec<Key>) -> KeySet {
		KeySet {
			keys: keys.into_iter().map(Arc::new).collect(),
		}
	}

	/// A unique path in the system temp dir whose file is deleted on drop,
	/// so cleanup still happens when an assertion panics mid-test.
	struct TempPath(PathBuf);

	impl TempPath {
		fn new(prefix: &str) -> Self {
			let nanos = SystemTime::now()
				.duration_since(SystemTime::UNIX_EPOCH)
				.expect("system clock is before the UNIX epoch")
				.as_nanos();
			// Process id + timestamp keeps concurrent test runs from colliding.
			let pid = std::process::id();
			Self(std::env::temp_dir().join(format!("{prefix}_{pid}_{nanos}.json")))
		}
	}

	impl Drop for TempPath {
		fn drop(&mut self) {
			// Best effort: the file may never have been created.
			let _ = std::fs::remove_file(&self.0);
		}
	}

	#[test]
	fn test_keyset_from_str_valid() {
		let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#;
		let set = KeySet::from_str(json).expect("valid JWKS should parse");
		assert_eq!(set.keys.len(), 1);
		assert_eq!(set.keys[0].kid.as_deref(), Some("1"));
		assert!(set.find_key("1").is_some());
	}

	#[test]
	fn test_keyset_from_str_invalid_json() {
		assert!(KeySet::from_str("invalid json").is_err());
	}

	#[test]
	fn test_keyset_from_str_empty() {
		let set = KeySet::from_str(r#"{"keys":[]}"#).expect("empty JWKS should parse");
		assert!(set.keys.is_empty());
	}

	#[test]
	fn test_keyset_to_str() {
		let set = set_of(vec![create_test_key(Some("1"))]);

		let json = set.to_str().expect("failed to serialize");
		assert!(json.contains("\"keys\""));
		assert!(json.contains("\"kid\":\"1\""));
	}

	#[test]
	fn test_keyset_serde_round_trip() {
		let set = set_of(vec![create_test_key(Some("1")), create_test_key(Some("2"))]);

		let json = set.to_str().expect("failed to serialize");
		let deserialized = KeySet::from_str(&json).expect("failed to deserialize");

		assert_eq!(deserialized.keys.len(), 2);
		assert!(deserialized.find_key("1").is_some());
		assert!(deserialized.find_key("2").is_some());
	}

	#[test]
	fn test_find_key_success() {
		let set = set_of(vec![create_test_key(Some("my-key"))]);

		let found = set.find_key("my-key").expect("key should be found");
		assert_eq!(found.kid.as_deref(), Some("my-key"));
	}

	#[test]
	fn test_find_key_missing() {
		let set = set_of(vec![create_test_key(Some("my-key"))]);
		assert!(set.find_key("other-key").is_none());
	}

	#[test]
	fn test_find_key_no_kid() {
		let set = set_of(vec![create_test_key(None)]);
		assert!(set.find_key("any-key").is_none());
	}

	#[test]
	fn test_find_supported_key() {
		let mut sign_key = create_test_key(Some("sign"));
		sign_key.operations = [KeyOperation::Sign].into();

		let mut verify_key = create_test_key(Some("verify"));
		verify_key.operations = [KeyOperation::Verify].into();

		let set = set_of(vec![sign_key, verify_key]);

		let found_sign = set
			.find_supported_key(&KeyOperation::Sign)
			.expect("signing key should be found");
		assert_eq!(found_sign.kid.as_deref(), Some("sign"));

		let found_verify = set
			.find_supported_key(&KeyOperation::Verify)
			.expect("verification key should be found");
		assert_eq!(found_verify.kid.as_deref(), Some("verify"));
	}

	#[test]
	fn test_to_public_set() {
		// Use an asymmetric key (ES256) so the public half can be separated.
		let set = set_of(vec![create_test_key(Some("1"))]);

		let public_set = set.to_public_set().expect("failed to convert to public set");
		assert_eq!(public_set.keys.len(), 1);

		let public_key = &public_set.keys[0];
		assert_eq!(public_key.kid.as_deref(), Some("1"));
		assert!(public_key.operations.contains(&KeyOperation::Verify));
		assert!(!public_key.operations.contains(&KeyOperation::Sign));

		// Without a signing key, the public set must refuse to encode.
		assert!(public_set.encode(&create_test_claims()).is_err());
	}

	#[test]
	fn test_public_set_decodes_tokens_from_private_set() {
		let private_set = set_of(vec![create_test_key(Some("1"))]);
		let public_set = private_set.to_public_set().expect("failed to convert to public set");
		let claims = create_test_claims();

		let token = private_set.encode(&claims).expect("failed to encode");
		let decoded = public_set.decode(&token).expect("public key should verify the token");
		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_to_public_set_fails_for_symmetric() {
		let key = Key::generate(Algorithm::HS256, Some(crate::KeyId::decode("sym").unwrap())).unwrap();
		let set = set_of(vec![key]);

		assert!(set.to_public_set().is_err());
	}

	#[test]
	fn test_encode_success() {
		let set = set_of(vec![create_test_key(Some("1"))]);

		let token = set.encode(&create_test_claims()).expect("failed to encode");
		assert!(!token.is_empty());
	}

	#[test]
	fn test_encode_no_signing_key() {
		let mut key = create_test_key(Some("1"));
		key.operations = [KeyOperation::Verify].into();
		let set = set_of(vec![key]);

		let err = set
			.encode(&create_test_claims())
			.expect_err("encode should fail without a signing key");
		assert!(err.to_string().contains("cannot find signing key"));
	}

	#[test]
	fn test_decode_success_with_kid() {
		let set = set_of(vec![create_test_key(Some("1"))]);
		let claims = create_test_claims();

		let token = set.encode(&claims).expect("failed to encode");
		let decoded = set.decode(&token).expect("failed to decode");

		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_decode_success_single_key_no_kid() {
		// A key without a kid produces a token without a kid.
		let key = create_test_key(None);
		let claims = create_test_claims();

		// Encode with the key directly, then decode through the set.
		let token = key.encode(&claims).expect("failed to encode");
		let set = set_of(vec![key]);

		let decoded = set.decode(&token).expect("single-key set should decode without kid");
		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_decode_fail_multiple_keys_no_kid() {
		let set = set_of(vec![create_test_key(None), create_test_key(None)]);
		let claims = create_test_claims();

		// Encode with one of the keys directly.
		let token = set.keys[0].encode(&claims).expect("failed to encode");

		let err = set.decode(&token).expect_err("decode should fail without a kid");
		assert!(err.to_string().contains("missing kid"));
	}

	#[test]
	fn test_decode_fail_unknown_kid() {
		let set1 = set_of(vec![create_test_key(Some("1"))]);
		let set2 = set_of(vec![create_test_key(Some("2"))]);

		let token = set1.encode(&create_test_claims()).expect("failed to encode");

		let err = set2.decode(&token).expect_err("decode should fail for an unknown kid");
		assert!(err.to_string().contains("cannot find key with kid 1"));
	}

	#[test]
	fn test_file_io() {
		let set = set_of(vec![create_test_key(Some("1"))]);
		// Removed on drop, even if an assertion below panics.
		let path = TempPath::new("test_keyset");

		set.to_file(&path.0).expect("failed to write to file");

		let loaded = KeySet::from_file(&path.0).expect("failed to read from file");
		assert_eq!(loaded.keys.len(), 1);
		assert_eq!(loaded.keys[0].kid.as_deref(), Some("1"));
	}

	#[test]
	fn test_from_file_missing() {
		// Never written, so reading must fail.
		let path = TempPath::new("test_keyset_missing");
		assert!(KeySet::from_file(&path.0).is_err());
	}
}