1use crate::{Claims, Key, KeyOperation};
2use anyhow::Context;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Default, Clone)]
10pub struct KeySet {
11 pub keys: Vec<Arc<Key>>,
13}
14
15impl Serialize for KeySet {
16 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
17 where
18 S: Serializer,
19 {
20 use serde::ser::SerializeStruct;
22
23 let mut state = serializer.serialize_struct("KeySet", 1)?;
24 state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>())?;
25 state.end()
26 }
27}
28
29impl<'de> Deserialize<'de> for KeySet {
30 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
31 where
32 D: Deserializer<'de>,
33 {
34 #[derive(Deserialize)]
36 struct RawKeySet {
37 keys: Vec<Key>,
38 }
39
40 let raw = RawKeySet::deserialize(deserializer)?;
41 Ok(KeySet {
42 keys: raw.keys.into_iter().map(Arc::new).collect(),
43 })
44 }
45}
46
47impl KeySet {
48 #[allow(clippy::should_implement_trait)]
49 pub fn from_str(s: &str) -> anyhow::Result<Self> {
50 Ok(serde_json::from_str(s)?)
51 }
52
53 pub fn from_file<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
54 let json = std::fs::read_to_string(&path)?;
55 Ok(serde_json::from_str(&json)?)
56 }
57
58 pub fn to_str(&self) -> anyhow::Result<String> {
59 Ok(serde_json::to_string(&self)?)
60 }
61
62 pub fn to_file<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
63 let json = serde_json::to_string(&self)?;
64 std::fs::write(path, json)?;
65 Ok(())
66 }
67
68 pub fn to_public_set(&self) -> anyhow::Result<KeySet> {
69 Ok(KeySet {
70 keys: self
71 .keys
72 .iter()
73 .map(|key| {
74 key.as_ref()
75 .to_public()
76 .map(Arc::new)
77 .map_err(|e| anyhow::anyhow!("failed to get public key from jwks: {:?}", e))
78 })
79 .collect::<Result<Vec<Arc<Key>>, _>>()?,
80 })
81 }
82
83 pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
84 self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
85 }
86
87 pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
88 self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
89 }
90
91 pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
92 let key = self
93 .find_supported_key(&KeyOperation::Sign)
94 .context("cannot find signing key")?;
95 key.encode(payload)
96 }
97
98 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
99 let header = jsonwebtoken::decode_header(token).context("failed to decode JWT header")?;
100
101 let key = match header.kid {
102 Some(kid) => self
103 .find_key(kid.as_str())
104 .ok_or_else(|| anyhow::anyhow!("cannot find key with kid {kid}")),
105 None => {
106 if self.keys.len() == 1 {
108 Ok(self.keys[0].clone())
109 } else {
110 anyhow::bail!("missing kid in JWT header")
111 }
112 }
113 }?;
114
115 key.decode(token)
116 }
117}
118
/// Downloads a JWKS document from `jwks_uri` and parses it into a [`KeySet`].
///
/// The whole request is bounded by a 10 second client timeout. Responses with a 4xx/5xx
/// status are treated as errors rather than parsed.
///
/// # Errors
///
/// Fails if the HTTP client cannot be built, the request fails or times out, the endpoint
/// returns an error status, the body cannot be read, or the body is not a valid key set.
#[cfg(feature = "jwks-loader")]
pub async fn load_keys(jwks_uri: &str) -> anyhow::Result<KeySet> {
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    let http = reqwest::Client::builder()
        .timeout(REQUEST_TIMEOUT)
        .build()
        .context("failed to build reqwest client")?;

    let response = http
        .get(jwks_uri)
        .send()
        .await
        .with_context(|| format!("failed to GET JWKS from {}", jwks_uri))?;

    // Turn 4xx/5xx responses into errors before attempting to read the body.
    let response = response
        .error_for_status()
        .with_context(|| format!("JWKS endpoint returned error: {}", jwks_uri))?;

    let body = response
        .text()
        .await
        .context("failed to read JWKS response body")?;

    KeySet::from_str(&body).context("Failed to parse JWKS into KeySet")
}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Algorithm;
    use std::time::{Duration, SystemTime};

    /// Claims that expire one hour from now, with one publish and one subscribe path.
    fn create_test_claims() -> Claims {
        Claims {
            root: "test-path".to_string(),
            publish: vec!["test-pub".into()],
            cluster: false,
            subscribe: vec!["test-sub".into()],
            expires: Some(SystemTime::now() + Duration::from_secs(3600)),
            issued: Some(SystemTime::now()),
        }
    }

    /// Generates a fresh ES256 key with the given `kid`.
    fn create_test_key(kid: Option<String>) -> Key {
        Key::generate(Algorithm::ES256, kid).expect("failed to generate key")
    }

    #[test]
    fn test_keyset_from_str_valid() {
        let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#;
        // `expect` (rather than `is_ok` + `unwrap`) surfaces the parse error if this regresses.
        let set = KeySet::from_str(json).expect("valid JWKS should parse");
        assert_eq!(set.keys.len(), 1);
        assert_eq!(set.keys[0].kid.as_deref(), Some("1"));
        assert!(set.find_key("1").is_some());
    }

    #[test]
    fn test_keyset_from_str_invalid_json() {
        let result = KeySet::from_str("invalid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_keyset_from_str_empty() {
        let json = r#"{"keys":[]}"#;
        let set = KeySet::from_str(json).unwrap();
        assert!(set.keys.is_empty());
    }

    #[test]
    fn test_keyset_to_str() {
        let key = create_test_key(Some("1".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let json = set.to_str().unwrap();
        assert!(json.contains("\"keys\""));
        assert!(json.contains("\"kid\":\"1\""));
    }

    #[test]
    fn test_keyset_serde_round_trip() {
        let key1 = create_test_key(Some("1".to_string()));
        let key2 = create_test_key(Some("2".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key1), Arc::new(key2)],
        };

        let json = set.to_str().unwrap();
        let deserialized = KeySet::from_str(&json).unwrap();

        assert_eq!(deserialized.keys.len(), 2);
        assert!(deserialized.find_key("1").is_some());
        assert!(deserialized.find_key("2").is_some());
    }

    #[test]
    fn test_find_key_success() {
        let key = create_test_key(Some("my-key".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let found = set.find_key("my-key").expect("key with kid my-key should be found");
        assert_eq!(found.kid.as_deref(), Some("my-key"));
    }

    #[test]
    fn test_find_key_missing() {
        let key = create_test_key(Some("my-key".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let found = set.find_key("other-key");
        assert!(found.is_none());
    }

    #[test]
    fn test_find_key_no_kid() {
        let key = create_test_key(None);
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let found = set.find_key("any-key");
        assert!(found.is_none());
    }

    #[test]
    fn test_find_supported_key() {
        let mut sign_key = create_test_key(Some("sign".to_string()));
        sign_key.operations = [KeyOperation::Sign].into();

        let mut verify_key = create_test_key(Some("verify".to_string()));
        verify_key.operations = [KeyOperation::Verify].into();

        let set = KeySet {
            keys: vec![Arc::new(sign_key), Arc::new(verify_key)],
        };

        let found_sign = set
            .find_supported_key(&KeyOperation::Sign)
            .expect("a signing key should be found");
        assert_eq!(found_sign.kid.as_deref(), Some("sign"));

        let found_verify = set
            .find_supported_key(&KeyOperation::Verify)
            .expect("a verifying key should be found");
        assert_eq!(found_verify.kid.as_deref(), Some("verify"));
    }

    #[test]
    fn test_to_public_set() {
        let key = create_test_key(Some("1".to_string()));

        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let public_set = set.to_public_set().expect("failed to convert to public set");
        assert_eq!(public_set.keys.len(), 1);

        let public_key = &public_set.keys[0];
        assert_eq!(public_key.kid.as_deref(), Some("1"));
        assert!(public_key.operations.contains(&KeyOperation::Verify));
        assert!(!public_key.operations.contains(&KeyOperation::Sign));
    }

    #[test]
    fn test_to_public_set_fails_for_symmetric() {
        let key = Key::generate(Algorithm::HS256, Some("sym".to_string())).unwrap();
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let result = set.to_public_set();
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_success() {
        let key = create_test_key(Some("1".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };
        let claims = create_test_claims();

        let token = set.encode(&claims).unwrap();
        assert!(!token.is_empty());
    }

    #[test]
    fn test_encode_no_signing_key() {
        let mut key = create_test_key(Some("1".to_string()));
        key.operations = [KeyOperation::Verify].into();
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };
        let claims = create_test_claims();

        let result = set.encode(&claims);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot find signing key"));
    }

    #[test]
    fn test_decode_success_with_kid() {
        let key = create_test_key(Some("1".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };
        let claims = create_test_claims();

        let token = set.encode(&claims).unwrap();
        let decoded = set.decode(&token).unwrap();

        assert_eq!(decoded.root, claims.root);
    }

    #[test]
    fn test_decode_success_single_key_no_kid() {
        let key = create_test_key(None);
        let claims = create_test_claims();

        let token = key.encode(&claims).unwrap();

        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let decoded = set.decode(&token).unwrap();
        assert_eq!(decoded.root, claims.root);
    }

    #[test]
    fn test_decode_fail_multiple_keys_no_kid() {
        let key1 = create_test_key(None);
        let key2 = create_test_key(None);

        let set = KeySet {
            keys: vec![Arc::new(key1), Arc::new(key2)],
        };

        let claims = create_test_claims();
        let token = set.keys[0].encode(&claims).unwrap();

        let result = set.decode(&token);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing kid"));
    }

    #[test]
    fn test_decode_fail_empty_keyset() {
        // A kid-less token cannot be verified when there are no keys at all.
        let token = create_test_key(None).encode(&create_test_claims()).unwrap();
        assert!(KeySet::default().decode(&token).is_err());
    }

    #[test]
    fn test_decode_fail_unknown_kid() {
        let key1 = create_test_key(Some("1".to_string()));
        let key2 = create_test_key(Some("2".to_string()));

        let set1 = KeySet {
            keys: vec![Arc::new(key1)],
        };
        let set2 = KeySet {
            keys: vec![Arc::new(key2)],
        };

        let claims = create_test_claims();
        let token = set1.encode(&claims).unwrap();

        let result = set2.decode(&token);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot find key with kid 1"));
    }

    #[test]
    fn test_from_file_missing() {
        let path = std::env::temp_dir().join(format!(
            "test_keyset_missing_{}.json",
            std::process::id()
        ));
        assert!(KeySet::from_file(&path).is_err());
    }

    #[test]
    fn test_file_io() {
        let key = create_test_key(Some("1".to_string()));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        // PID plus timestamp keeps concurrent test runs from clobbering each other's file.
        let filename = format!(
            "test_keyset_{}_{}.json",
            std::process::id(),
            SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .expect("system clock is before the Unix epoch")
                .as_nanos()
        );
        let path = std::env::temp_dir().join(filename);

        // Perform the fallible steps, then delete the file BEFORE asserting, so a failing
        // assertion never leaves generated key material behind in the temp directory.
        let written = set.to_file(&path);
        let loaded = KeySet::from_file(&path);
        let _ = std::fs::remove_file(&path);

        written.expect("failed to write to file");
        let loaded = loaded.expect("failed to read from file");
        assert_eq!(loaded.keys.len(), 1);
        assert_eq!(loaded.keys[0].kid.as_deref(), Some("1"));
    }
}