Skip to main content

moq_token/
set.rs

1use crate::error::KeyError;
2use crate::{Claims, Key, KeyOperation};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::path::Path;
5use std::sync::Arc;
6
7#[cfg(feature = "jwks-loader")]
8use std::time::Duration;
9
10/// JWK Set to spec <https://datatracker.ietf.org/doc/html/rfc7517#section-5>
11#[derive(Default, Clone)]
12pub struct KeySet {
13	/// Vec of an arbitrary number of Json Web Keys
14	pub keys: Vec<Arc<Key>>,
15}
16
17impl Serialize for KeySet {
18	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19	where
20		S: Serializer,
21	{
22		// Serialize as a struct with a `keys` field
23		use serde::ser::SerializeStruct;
24
25		let mut state = serializer.serialize_struct("KeySet", 1)?;
26		state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>())?;
27		state.end()
28	}
29}
30
31impl<'de> Deserialize<'de> for KeySet {
32	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33	where
34		D: Deserializer<'de>,
35	{
36		// Deserialize into a temporary Vec<Key>
37		#[derive(Deserialize)]
38		struct RawKeySet {
39			keys: Vec<Key>,
40		}
41
42		let raw = RawKeySet::deserialize(deserializer)?;
43		Ok(KeySet {
44			keys: raw.keys.into_iter().map(Arc::new).collect(),
45		})
46	}
47}
48
49impl KeySet {
50	#[allow(clippy::should_implement_trait)]
51	pub fn from_str(s: &str) -> crate::Result<Self> {
52		Ok(serde_json::from_str(s)?)
53	}
54
55	pub fn from_file<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
56		let json = std::fs::read_to_string(&path)?;
57		Ok(serde_json::from_str(&json)?)
58	}
59
60	pub fn to_str(&self) -> crate::Result<String> {
61		Ok(serde_json::to_string(&self)?)
62	}
63
64	pub fn to_file<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
65		let json = serde_json::to_string(&self)?;
66		std::fs::write(path, json)?;
67		Ok(())
68	}
69
70	pub fn to_public_set(&self) -> crate::Result<KeySet> {
71		Ok(KeySet {
72			keys: self
73				.keys
74				.iter()
75				.map(|key| key.as_ref().to_public().map(Arc::new))
76				.collect::<Result<Vec<Arc<Key>>, _>>()?,
77		})
78	}
79
80	pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
81		self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
82	}
83
84	pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
85		self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
86	}
87
88	pub fn encode(&self, payload: &Claims) -> crate::Result<String> {
89		let key = self
90			.find_supported_key(&KeyOperation::Sign)
91			.ok_or(KeyError::NoSigningKey)?;
92		key.encode(payload)
93	}
94
95	pub fn decode(&self, token: &str) -> crate::Result<Claims> {
96		let header = jsonwebtoken::decode_header(token)?;
97
98		let key = match header.kid {
99			Some(kid) => self
100				.find_key(kid.as_str())
101				.ok_or_else(|| crate::Error::from(KeyError::KeyNotFound(kid))),
102			None => {
103				// If we only have one key we can use it without a kid
104				if self.keys.len() == 1 {
105					Ok(self.keys[0].clone())
106				} else {
107					Err(KeyError::MissingKid.into())
108				}
109			}
110		}?;
111
112		key.decode(token)
113	}
114}
115
/// Fetch a JWK Set from `jwks_uri` over HTTP and parse it.
///
/// The whole request is bounded by a 10-second client timeout.
///
/// # Errors
///
/// Returns an error if any of these fail:
/// - building the HTTP client
/// - sending the request
/// - the response status (a non-success status is an error)
/// - reading the body
/// - parsing the body as a JWK Set
#[cfg(feature = "jwks-loader")]
pub async fn load_keys(jwks_uri: &str) -> crate::Result<KeySet> {
	let request_timeout = Duration::from_secs(10);
	let client = reqwest::Client::builder().timeout(request_timeout).build()?;

	let response = client.get(jwks_uri).send().await?;
	let body = response.error_for_status()?.text().await?;

	KeySet::from_str(&body)
}
124
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Algorithm;
	use std::path::PathBuf;
	use std::time::{Duration, SystemTime};

	/// Claims rooted at `test-path` that expire one hour after creation.
	fn create_test_claims() -> Claims {
		Claims {
			root: "test-path".to_string(),
			publish: vec!["test-pub".into()],
			subscribe: vec!["test-sub".into()],
			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
			issued: Some(SystemTime::now()),
		}
	}

	/// Generate a fresh ES256 key, tagged with `kid` when one is given.
	fn create_test_key(kid: Option<&str>) -> Key {
		let kid = kid.map(|s| crate::KeyId::decode(s).unwrap());
		Key::generate(Algorithm::ES256, kid).expect("failed to generate key")
	}

	/// Wrap owned keys into a [`KeySet`].
	fn set_of(keys: Vec<Key>) -> KeySet {
		KeySet {
			keys: keys.into_iter().map(Arc::new).collect(),
		}
	}

	/// A temp-dir file path that is removed on drop. The file is therefore
	/// cleaned up even when an assertion panics partway through a test.
	struct TempPath(PathBuf);

	impl TempPath {
		/// Build a path that is unique per process and per call.
		fn unique(prefix: &str) -> Self {
			let nanos = SystemTime::now()
				.duration_since(SystemTime::UNIX_EPOCH)
				.unwrap()
				.as_nanos();
			// pid + timestamp keeps concurrent test binaries from colliding.
			let name = format!("{}_{}_{}.json", prefix, std::process::id(), nanos);
			Self(std::env::temp_dir().join(name))
		}
	}

	impl Drop for TempPath {
		fn drop(&mut self) {
			// Best effort: the file may not exist if the write itself failed.
			let _ = std::fs::remove_file(&self.0);
		}
	}

	#[test]
	fn test_keyset_from_str_valid() {
		let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#;
		let set = KeySet::from_str(json).expect("valid JWK Set should parse");
		assert_eq!(set.keys.len(), 1);
		assert_eq!(set.keys[0].kid.as_deref(), Some("1"));
		assert!(set.find_key("1").is_some());
	}

	#[test]
	fn test_keyset_from_str_invalid_json() {
		// `KeySet` is not `Debug`, so `unwrap_err` is unavailable here.
		assert!(KeySet::from_str("invalid json").is_err());
	}

	#[test]
	fn test_keyset_from_str_empty() {
		let set = KeySet::from_str(r#"{"keys":[]}"#).unwrap();
		assert!(set.keys.is_empty());
	}

	#[test]
	fn test_keyset_to_str() {
		let set = set_of(vec![create_test_key(Some("1"))]);

		let json = set.to_str().unwrap();
		assert!(json.contains("\"keys\""));
		assert!(json.contains("\"kid\":\"1\""));
	}

	#[test]
	fn test_keyset_serde_round_trip() {
		let set = set_of(vec![create_test_key(Some("1")), create_test_key(Some("2"))]);

		let json = set.to_str().unwrap();
		let deserialized = KeySet::from_str(&json).unwrap();

		assert_eq!(deserialized.keys.len(), 2);
		assert!(deserialized.find_key("1").is_some());
		assert!(deserialized.find_key("2").is_some());
	}

	#[test]
	fn test_find_key_success() {
		let set = set_of(vec![create_test_key(Some("my-key"))]);

		let found = set.find_key("my-key").expect("key should be found by kid");
		assert_eq!(found.kid.as_deref(), Some("my-key"));
	}

	#[test]
	fn test_find_key_missing() {
		let set = set_of(vec![create_test_key(Some("my-key"))]);
		assert!(set.find_key("other-key").is_none());
	}

	#[test]
	fn test_find_key_no_kid() {
		let set = set_of(vec![create_test_key(None)]);
		assert!(set.find_key("any-key").is_none());
	}

	#[test]
	fn test_find_supported_key() {
		let mut sign_key = create_test_key(Some("sign"));
		sign_key.operations = [KeyOperation::Sign].into();

		let mut verify_key = create_test_key(Some("verify"));
		verify_key.operations = [KeyOperation::Verify].into();

		let set = set_of(vec![sign_key, verify_key]);

		let found_sign = set
			.find_supported_key(&KeyOperation::Sign)
			.expect("signing key should be found");
		assert_eq!(found_sign.kid.as_deref(), Some("sign"));

		let found_verify = set
			.find_supported_key(&KeyOperation::Verify)
			.expect("verifying key should be found");
		assert_eq!(found_verify.kid.as_deref(), Some("verify"));
	}

	#[test]
	fn test_to_public_set() {
		// ES256 is asymmetric, so a public-only form exists.
		let set = set_of(vec![create_test_key(Some("1"))]);

		let public_set = set.to_public_set().expect("failed to convert to public set");
		assert_eq!(public_set.keys.len(), 1);

		let public_key = &public_set.keys[0];
		assert_eq!(public_key.kid.as_deref(), Some("1"));
		assert!(public_key.operations.contains(&KeyOperation::Verify));
		assert!(!public_key.operations.contains(&KeyOperation::Sign));
	}

	#[test]
	fn test_to_public_set_fails_for_symmetric() {
		let key = Key::generate(Algorithm::HS256, Some(crate::KeyId::decode("sym").unwrap())).unwrap();
		let set = set_of(vec![key]);

		// `KeySet` is not `Debug`, so `unwrap_err` is unavailable here.
		assert!(set.to_public_set().is_err());
	}

	#[test]
	fn test_encode_success() {
		let set = set_of(vec![create_test_key(Some("1"))]);

		let token = set.encode(&create_test_claims()).unwrap();
		assert!(!token.is_empty());
	}

	#[test]
	fn test_encode_no_signing_key() {
		let mut key = create_test_key(Some("1"));
		key.operations = [KeyOperation::Verify].into();
		let set = set_of(vec![key]);

		let err = set.encode(&create_test_claims()).unwrap_err();
		assert!(err.to_string().contains("cannot find signing key"));
	}

	#[test]
	fn test_encode_empty_set() {
		let err = KeySet::default().encode(&create_test_claims()).unwrap_err();
		assert!(err.to_string().contains("cannot find signing key"));
	}

	#[test]
	fn test_decode_success_with_kid() {
		let set = set_of(vec![create_test_key(Some("1"))]);
		let claims = create_test_claims();

		let token = set.encode(&claims).unwrap();
		let decoded = set.decode(&token).unwrap();

		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_decode_success_single_key_no_kid() {
		let key = create_test_key(None);
		let claims = create_test_claims();

		// Sign with the key directly, so the header carries no `kid`.
		let token = key.encode(&claims).unwrap();
		let set = set_of(vec![key]);

		let decoded = set.decode(&token).unwrap();
		assert_eq!(decoded.root, claims.root);
	}

	#[test]
	fn test_decode_fail_multiple_keys_no_kid() {
		let set = set_of(vec![create_test_key(None), create_test_key(None)]);

		// Sign with one of the keys directly, so the header carries no `kid`.
		let token = set.keys[0].encode(&create_test_claims()).unwrap();

		let err = set.decode(&token).unwrap_err();
		assert!(err.to_string().contains("missing kid"));
	}

	#[test]
	fn test_decode_empty_set_no_kid() {
		let token = create_test_key(None).encode(&create_test_claims()).unwrap();

		let err = KeySet::default().decode(&token).unwrap_err();
		assert!(err.to_string().contains("missing kid"));
	}

	#[test]
	fn test_decode_fail_unknown_kid() {
		let set1 = set_of(vec![create_test_key(Some("1"))]);
		let set2 = set_of(vec![create_test_key(Some("2"))]);

		let token = set1.encode(&create_test_claims()).unwrap();

		let err = set2.decode(&token).unwrap_err();
		assert!(err.to_string().contains("cannot find key with kid 1"));
	}

	#[test]
	fn test_file_io() {
		let set = set_of(vec![create_test_key(Some("1"))]);

		// Removed on drop, including when an assertion below panics.
		let path = TempPath::unique("test_keyset");

		set.to_file(&path.0).expect("failed to write to file");

		let loaded = KeySet::from_file(&path.0).expect("failed to read from file");
		assert_eq!(loaded.keys.len(), 1);
		assert_eq!(loaded.keys[0].kid.as_deref(), Some("1"));
	}
}