Skip to main content

moq_token/
set.rs

1use crate::{Claims, Key, KeyOperation};
2use anyhow::Context;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8/// JWK Set to spec <https://datatracker.ietf.org/doc/html/rfc7517#section-5>
9#[derive(Default, Clone)]
10pub struct KeySet {
11	/// Vec of an arbitrary number of Json Web Keys
12	pub keys: Vec<Arc<Key>>,
13}
14
15impl Serialize for KeySet {
16	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
17	where
18		S: Serializer,
19	{
20		// Serialize as a struct with a `keys` field
21		use serde::ser::SerializeStruct;
22
23		let mut state = serializer.serialize_struct("KeySet", 1)?;
24		state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>())?;
25		state.end()
26	}
27}
28
29impl<'de> Deserialize<'de> for KeySet {
30	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
31	where
32		D: Deserializer<'de>,
33	{
34		// Deserialize into a temporary Vec<Key>
35		#[derive(Deserialize)]
36		struct RawKeySet {
37			keys: Vec<Key>,
38		}
39
40		let raw = RawKeySet::deserialize(deserializer)?;
41		Ok(KeySet {
42			keys: raw.keys.into_iter().map(Arc::new).collect(),
43		})
44	}
45}
46
47impl KeySet {
48	#[allow(clippy::should_implement_trait)]
49	pub fn from_str(s: &str) -> anyhow::Result<Self> {
50		Ok(serde_json::from_str(s)?)
51	}
52
53	pub fn from_file<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
54		let json = std::fs::read_to_string(&path)?;
55		Ok(serde_json::from_str(&json)?)
56	}
57
58	pub fn to_str(&self) -> anyhow::Result<String> {
59		Ok(serde_json::to_string(&self)?)
60	}
61
62	pub fn to_file<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
63		let json = serde_json::to_string(&self)?;
64		std::fs::write(path, json)?;
65		Ok(())
66	}
67
68	pub fn to_public_set(&self) -> anyhow::Result<KeySet> {
69		Ok(KeySet {
70			keys: self
71				.keys
72				.iter()
73				.map(|key| {
74					key.as_ref()
75						.to_public()
76						.map(Arc::new)
77						.map_err(|e| anyhow::anyhow!("failed to get public key from jwks: {:?}", e))
78				})
79				.collect::<Result<Vec<Arc<Key>>, _>>()?,
80		})
81	}
82
83	pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
84		self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
85	}
86
87	pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
88		self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
89	}
90
91	pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
92		let key = self
93			.find_supported_key(&KeyOperation::Sign)
94			.context("cannot find signing key")?;
95		key.encode(payload)
96	}
97
98	pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
99		let header = jsonwebtoken::decode_header(token).context("failed to decode JWT header")?;
100
101		let key = match header.kid {
102			Some(kid) => self
103				.find_key(kid.as_str())
104				.ok_or_else(|| anyhow::anyhow!("cannot find key with kid {kid}")),
105			None => {
106				// If we only have one key we can use it without a kid
107				if self.keys.len() == 1 {
108					Ok(self.keys[0].clone())
109				} else {
110					anyhow::bail!("missing kid in JWT header")
111				}
112			}
113		}?;
114
115		key.decode(token)
116	}
117}
118
/// Downloads a JWK Set from `jwks_uri` and parses it into a [`KeySet`].
///
/// The HTTP client is configured with a 10 second timeout.
///
/// # Errors
/// Fails if the client cannot be built, the request fails, the endpoint
/// responds with an error status, the body cannot be read, or the body is
/// not a valid JWK Set.
#[cfg(feature = "jwks-loader")]
pub async fn load_keys(jwks_uri: &str) -> anyhow::Result<KeySet> {
	let timeout = Duration::from_secs(10);
	let client = reqwest::Client::builder()
		.timeout(timeout)
		.build()
		.context("failed to build reqwest client")?;

	let response = client
		.get(jwks_uri)
		.send()
		.await
		.with_context(|| format!("failed to GET JWKS from {jwks_uri}"))?;

	// Reject error statuses before touching the body.
	let response = response
		.error_for_status()
		.with_context(|| format!("JWKS endpoint returned error: {jwks_uri}"))?;

	let body = response.text().await.context("failed to read JWKS response body")?;

	KeySet::from_str(&body).context("Failed to parse JWKS into KeySet")
}
140
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Algorithm;
	use std::path::PathBuf;
	use std::time::{Duration, SystemTime};

	/// Claims issued now and expiring one hour from now.
	fn create_test_claims() -> Claims {
		Claims {
			root: "test-path".to_string(),
			publish: vec!["test-pub".into()],
			cluster: false,
			subscribe: vec!["test-sub".into()],
			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
			issued: Some(SystemTime::now()),
		}
	}

	/// A freshly generated ES256 key; panics if generation fails.
	fn create_test_key(kid: Option<String>) -> Key {
		Key::generate(Algorithm::ES256, kid).expect("failed to generate key")
	}

	/// Wraps `key` in a single-element [`KeySet`].
	fn single_key_set(key: Key) -> KeySet {
		KeySet {
			keys: vec![Arc::new(key)],
		}
	}

	/// A uniquely named path in the temp directory that is removed on drop,
	/// so the file is cleaned up even when an assertion panics.
	struct TempFile(PathBuf);

	impl TempFile {
		fn new(prefix: &str) -> Self {
			let nanos = SystemTime::now()
				.duration_since(SystemTime::UNIX_EPOCH)
				.expect("system clock is before the Unix epoch")
				.as_nanos();
			// Include the process id so concurrently running test binaries don't collide.
			let name = format!("{prefix}_{}_{nanos}.json", std::process::id());
			Self(std::env::temp_dir().join(name))
		}
	}

	impl Drop for TempFile {
		fn drop(&mut self) {
			// Best effort: the file may never have been created.
			let _ = std::fs::remove_file(&self.0);
		}
	}

	#[test]
	fn test_keyset_from_str_valid() {
		let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#;
		let set = KeySet::from_str(json).expect("valid JWKS should parse");
		assert_eq!(set.keys.len(), 1);
		assert_eq!(set.keys[0].kid.as_deref(), Some("1"));
		assert!(set.find_key("1").is_some());
	}

	#[test]
	fn test_keyset_from_str_invalid_json() {
		let result = KeySet::from_str("invalid json");
		assert!(result.is_err());
	}

	#[test]
	fn test_keyset_from_str_empty() {
		let json = r#"{"keys":[]}"#;
		let set = KeySet::from_str(json).expect("empty JWKS should parse");
		assert!(set.keys.is_empty());
	}

	#[test]
	fn test_keyset_to_str() {
		let set = single_key_set(create_test_key(Some("1".to_string())));

		let json = set.to_str().expect("serialization should succeed");
		assert!(json.contains("\"keys\""));
		assert!(json.contains("\"kid\":\"1\""));
	}

	#[test]
	fn test_keyset_serde_round_trip() {
		let key1 = create_test_key(Some("1".to_string()));
		let key2 = create_test_key(Some("2".to_string()));
		let set = KeySet {
			keys: vec![Arc::new(key1), Arc::new(key2)],
		};

		let json = set.to_str().expect("serialization should succeed");
		let deserialized = KeySet::from_str(&json).expect("round trip should parse");

		assert_eq!(deserialized.keys.len(), 2);
		assert!(deserialized.find_key("1").is_some());
		assert!(deserialized.find_key("2").is_some());
	}

	#[test]
	fn test_find_key_success() {
		let set = single_key_set(create_test_key(Some("my-key".to_string())));

		let found = set.find_key("my-key").expect("key should be found");
		assert_eq!(found.kid.as_deref(), Some("my-key"));
	}

	#[test]
	fn test_find_key_missing() {
		let set = single_key_set(create_test_key(Some("my-key".to_string())));

		assert!(set.find_key("other-key").is_none());
	}

	#[test]
	fn test_find_key_no_kid() {
		let set = single_key_set(create_test_key(None));

		assert!(set.find_key("any-key").is_none());
	}

	#[test]
	fn test_find_supported_key() {
		let mut sign_key = create_test_key(Some("sign".to_string()));
		sign_key.operations = [KeyOperation::Sign].into();

		let mut verify_key = create_test_key(Some("verify".to_string()));
		verify_key.operations = [KeyOperation::Verify].into();

		let set = KeySet {
			keys: vec![Arc::new(sign_key), Arc::new(verify_key)],
		};

		let found_sign = set
			.find_supported_key(&KeyOperation::Sign)
			.expect("signing key should be found");
		assert_eq!(found_sign.kid.as_deref(), Some("sign"));

		let found_verify = set
			.find_supported_key(&KeyOperation::Verify)
			.expect("verifying key should be found");
		assert_eq!(found_verify.kid.as_deref(), Some("verify"));
	}

	#[test]
	fn test_to_public_set() {
		// Use an asymmetric key (ES256) so the public half can be separated out.
		let set = single_key_set(create_test_key(Some("1".to_string())));

		let public_set = set.to_public_set().expect("failed to convert to public set");
		assert_eq!(public_set.keys.len(), 1);

		let public_key = &public_set.keys[0];
		assert_eq!(public_key.kid.as_deref(), Some("1"));
		assert!(public_key.operations.contains(&KeyOperation::Verify));
		assert!(!public_key.operations.contains(&KeyOperation::Sign));
	}

	#[test]
	fn test_to_public_set_fails_for_symmetric() {
		let key = Key::generate(Algorithm::HS256, Some("sym".to_string())).unwrap();
		let set = single_key_set(key);

		assert!(set.to_public_set().is_err());
	}

	#[test]
	fn test_encode_success() {
		let set = single_key_set(create_test_key(Some("1".to_string())));

		let token = set.encode(&create_test_claims()).expect("encoding should succeed");
		assert!(!token.is_empty());
	}

	#[test]
	fn test_encode_no_signing_key() {
		let mut key = create_test_key(Some("1".to_string()));
		key.operations = [KeyOperation::Verify].into();
		let set = single_key_set(key);

		let err = set
			.encode(&create_test_claims())
			.expect_err("encoding without a signing key must fail");
		assert!(err.to_string().contains("cannot find signing key"));
	}

	#[test]
	fn test_decode_success_with_kid() {
		let set = single_key_set(create_test_key(Some("1".to_string())));
		let claims = create_test_claims();

		let token = set.encode(&claims).unwrap();
		let decoded = set.decode(&token).unwrap();

		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_decode_success_single_key_no_kid() {
		// A key without a kid, used directly so the token header has no kid either.
		let key = create_test_key(None);
		let claims = create_test_claims();
		let token = key.encode(&claims).unwrap();

		let set = single_key_set(key);

		let decoded = set.decode(&token).unwrap();
		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_decode_fail_multiple_keys_no_kid() {
		let key1 = create_test_key(None);
		let key2 = create_test_key(None);

		let set = KeySet {
			keys: vec![Arc::new(key1), Arc::new(key2)],
		};

		// Encode with one of the keys directly, so the header carries no kid.
		let token = set.keys[0].encode(&create_test_claims()).unwrap();

		let err = set.decode(&token).expect_err("ambiguous key choice must fail");
		assert!(err.to_string().contains("missing kid"));
	}

	#[test]
	fn test_decode_fail_empty_set_no_kid() {
		let key = create_test_key(None);
		let token = key.encode(&create_test_claims()).unwrap();

		let err = KeySet::default()
			.decode(&token)
			.expect_err("decoding against an empty set must fail");
		assert!(err.to_string().contains("missing kid"));
	}

	#[test]
	fn test_decode_fail_unknown_kid() {
		let set1 = single_key_set(create_test_key(Some("1".to_string())));
		let set2 = single_key_set(create_test_key(Some("2".to_string())));

		let token = set1.encode(&create_test_claims()).unwrap();

		let err = set2.decode(&token).expect_err("unknown kid must fail");
		assert!(err.to_string().contains("cannot find key with kid 1"));
	}

	#[test]
	fn test_file_io() {
		let set = single_key_set(create_test_key(Some("1".to_string())));
		let file = TempFile::new("test_keyset");

		set.to_file(&file.0).expect("failed to write to file");

		let loaded = KeySet::from_file(&file.0).expect("failed to read from file");
		assert_eq!(loaded.keys.len(), 1);
		assert_eq!(loaded.keys[0].kid.as_deref(), Some("1"));
	}

	#[test]
	fn test_from_file_missing() {
		// Never written, so reading must fail rather than yield an empty set.
		let file = TempFile::new("test_keyset_missing");
		assert!(KeySet::from_file(&file.0).is_err());
	}
}