1use crate::error::KeyError;
2use crate::{Claims, Key, KeyOperation};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::path::Path;
5use std::sync::Arc;
6
7#[cfg(feature = "jwks-loader")]
8use std::time::Duration;
9
/// A collection of keys that serializes as a JWK-Set-style `{"keys": [...]}` object.
///
/// Keys are held behind `Arc`, so lookups such as `find_key` return shared
/// handles, and cloning a `KeySet` clones the handles rather than the key data.
#[derive(Default, Clone)]
pub struct KeySet {
    /// The keys in this set. Order is preserved, and lookups return the first match.
    pub keys: Vec<Arc<Key>>,
}
16
17impl Serialize for KeySet {
18 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19 where
20 S: Serializer,
21 {
22 use serde::ser::SerializeStruct;
24
25 let mut state = serializer.serialize_struct("KeySet", 1)?;
26 state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>())?;
27 state.end()
28 }
29}
30
31impl<'de> Deserialize<'de> for KeySet {
32 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33 where
34 D: Deserializer<'de>,
35 {
36 #[derive(Deserialize)]
38 struct RawKeySet {
39 keys: Vec<Key>,
40 }
41
42 let raw = RawKeySet::deserialize(deserializer)?;
43 Ok(KeySet {
44 keys: raw.keys.into_iter().map(Arc::new).collect(),
45 })
46 }
47}
48
49impl KeySet {
50 #[allow(clippy::should_implement_trait)]
51 pub fn from_str(s: &str) -> crate::Result<Self> {
52 Ok(serde_json::from_str(s)?)
53 }
54
55 pub fn from_file<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
56 let json = std::fs::read_to_string(&path)?;
57 Ok(serde_json::from_str(&json)?)
58 }
59
60 pub fn to_str(&self) -> crate::Result<String> {
61 Ok(serde_json::to_string(&self)?)
62 }
63
64 pub fn to_file<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
65 let json = serde_json::to_string(&self)?;
66 std::fs::write(path, json)?;
67 Ok(())
68 }
69
70 pub fn to_public_set(&self) -> crate::Result<KeySet> {
71 Ok(KeySet {
72 keys: self
73 .keys
74 .iter()
75 .map(|key| key.as_ref().to_public().map(Arc::new))
76 .collect::<Result<Vec<Arc<Key>>, _>>()?,
77 })
78 }
79
80 pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
81 self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
82 }
83
84 pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
85 self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
86 }
87
88 pub fn encode(&self, payload: &Claims) -> crate::Result<String> {
89 let key = self
90 .find_supported_key(&KeyOperation::Sign)
91 .ok_or(KeyError::NoSigningKey)?;
92 key.encode(payload)
93 }
94
95 pub fn decode(&self, token: &str) -> crate::Result<Claims> {
96 let header = jsonwebtoken::decode_header(token)?;
97
98 let key = match header.kid {
99 Some(kid) => self
100 .find_key(kid.as_str())
101 .ok_or_else(|| crate::Error::from(KeyError::KeyNotFound(kid))),
102 None => {
103 if self.keys.len() == 1 {
105 Ok(self.keys[0].clone())
106 } else {
107 Err(KeyError::MissingKid.into())
108 }
109 }
110 }?;
111
112 key.decode(token)
113 }
114}
115
/// Fetches a JWKS document from `jwks_uri` over HTTP and parses it into a [`KeySet`].
///
/// The request uses a fresh client with a 10-second total timeout. Non-success
/// HTTP statuses are treated as errors.
///
/// # Errors
/// Fails on client construction, transport errors or timeout, an HTTP error
/// status, a body that cannot be read as text, or a body that does not parse
/// as a key set.
#[cfg(feature = "jwks-loader")]
pub async fn load_keys(jwks_uri: &str) -> crate::Result<KeySet> {
    // Bounds the whole request, so an unresponsive endpoint cannot stall the caller indefinitely.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    let client = reqwest::Client::builder().timeout(REQUEST_TIMEOUT).build()?;
    let response = client.get(jwks_uri).send().await?;
    let body = response.error_for_status()?.text().await?;

    KeySet::from_str(&body)
}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Algorithm;
    use std::time::{Duration, SystemTime};

    /// Claims issued now and expiring one hour from now.
    fn create_test_claims() -> Claims {
        Claims {
            root: "test-path".to_string(),
            publish: vec!["test-pub".into()],
            cluster: false,
            subscribe: vec!["test-sub".into()],
            expires: Some(SystemTime::now() + Duration::from_secs(3600)),
            issued: Some(SystemTime::now()),
        }
    }

    /// Generates a fresh ES256 key, tagged with `kid` when one is given.
    fn create_test_key(kid: Option<&str>) -> Key {
        let kid = kid.map(|s| crate::KeyId::decode(s).unwrap());
        Key::generate(Algorithm::ES256, kid).expect("failed to generate key")
    }

    /// Builds a `KeySet` that owns `keys`, preserving their order.
    fn key_set(keys: Vec<Key>) -> KeySet {
        KeySet {
            keys: keys.into_iter().map(Arc::new).collect(),
        }
    }

    /// Deletes the wrapped path on drop, so the temp file is cleaned up even
    /// when an assertion in the test panics.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the write itself failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// A well-formed single-key JSON set parses, and its kid is searchable.
    #[test]
    fn test_keyset_from_str_valid() {
        let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#;
        let set = KeySet::from_str(json).expect("valid key set should parse");
        assert_eq!(set.keys.len(), 1);
        assert_eq!(set.keys[0].kid.as_deref(), Some("1"));
        assert!(set.find_key("1").is_some());
    }

    /// Non-JSON input is rejected.
    #[test]
    fn test_keyset_from_str_invalid_json() {
        let result = KeySet::from_str("invalid json");
        assert!(result.is_err());
    }

    /// An empty `keys` array yields an empty set.
    #[test]
    fn test_keyset_from_str_empty() {
        let json = r#"{"keys":[]}"#;
        let set = KeySet::from_str(json).unwrap();
        assert!(set.keys.is_empty());
    }

    /// Serialized output has a `keys` field and carries each key's kid.
    #[test]
    fn test_keyset_to_str() {
        let set = key_set(vec![create_test_key(Some("1"))]);

        let json = set.to_str().unwrap();
        assert!(json.contains("\"keys\""));
        assert!(json.contains("\"kid\":\"1\""));
    }

    /// Serialize-then-parse keeps every key and its kid.
    #[test]
    fn test_keyset_serde_round_trip() {
        let set = key_set(vec![create_test_key(Some("1")), create_test_key(Some("2"))]);

        let json = set.to_str().unwrap();
        let deserialized = KeySet::from_str(&json).unwrap();

        assert_eq!(deserialized.keys.len(), 2);
        assert!(deserialized.find_key("1").is_some());
        assert!(deserialized.find_key("2").is_some());
    }

    /// `find_key` returns the key whose kid matches.
    #[test]
    fn test_find_key_success() {
        let set = key_set(vec![create_test_key(Some("my-key"))]);

        let found = set.find_key("my-key").expect("key with kid `my-key` should be found");
        assert_eq!(found.kid.as_deref(), Some("my-key"));
    }

    /// `find_key` returns `None` for an unknown kid.
    #[test]
    fn test_find_key_missing() {
        let set = key_set(vec![create_test_key(Some("my-key"))]);

        assert!(set.find_key("other-key").is_none());
    }

    /// Keys without a kid never match a lookup.
    #[test]
    fn test_find_key_no_kid() {
        let set = key_set(vec![create_test_key(None)]);

        assert!(set.find_key("any-key").is_none());
    }

    /// `find_supported_key` picks the key whose operations include the request.
    #[test]
    fn test_find_supported_key() {
        let mut sign_key = create_test_key(Some("sign"));
        sign_key.operations = [KeyOperation::Sign].into();

        let mut verify_key = create_test_key(Some("verify"));
        verify_key.operations = [KeyOperation::Verify].into();

        let set = key_set(vec![sign_key, verify_key]);

        let found_sign = set
            .find_supported_key(&KeyOperation::Sign)
            .expect("a signing key should be found");
        assert_eq!(found_sign.kid.as_deref(), Some("sign"));

        let found_verify = set
            .find_supported_key(&KeyOperation::Verify)
            .expect("a verifying key should be found");
        assert_eq!(found_verify.kid.as_deref(), Some("verify"));
    }

    /// Public conversion keeps the kid, allows Verify, and drops Sign.
    #[test]
    fn test_to_public_set() {
        let set = key_set(vec![create_test_key(Some("1"))]);

        let public_set = set.to_public_set().expect("failed to convert to public set");
        assert_eq!(public_set.keys.len(), 1);

        let public_key = &public_set.keys[0];
        assert_eq!(public_key.kid.as_deref(), Some("1"));
        assert!(public_key.operations.contains(&KeyOperation::Verify));
        assert!(!public_key.operations.contains(&KeyOperation::Sign));
    }

    /// Symmetric keys have no public form, so conversion fails.
    #[test]
    fn test_to_public_set_fails_for_symmetric() {
        let key = Key::generate(Algorithm::HS256, Some(crate::KeyId::decode("sym").unwrap())).unwrap();
        let set = key_set(vec![key]);

        let result = set.to_public_set();
        assert!(result.is_err());
    }

    /// Encoding with a signing-capable key yields a non-empty token.
    #[test]
    fn test_encode_success() {
        let set = key_set(vec![create_test_key(Some("1"))]);
        let claims = create_test_claims();

        let token = set.encode(&claims).unwrap();
        assert!(!token.is_empty());
    }

    /// Encoding fails with a descriptive error when no key can sign.
    #[test]
    fn test_encode_no_signing_key() {
        let mut key = create_test_key(Some("1"));
        key.operations = [KeyOperation::Verify].into();
        let set = key_set(vec![key]);
        let claims = create_test_claims();

        let result = set.encode(&claims);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot find signing key"));
    }

    /// A token carrying a kid decodes with the matching key.
    #[test]
    fn test_decode_success_with_kid() {
        let set = key_set(vec![create_test_key(Some("1"))]);
        let claims = create_test_claims();

        let token = set.encode(&claims).unwrap();
        let decoded = set.decode(&token).unwrap();

        assert_eq!(decoded.root, claims.root);
    }

    /// A token without a kid decodes when the set holds exactly one key.
    #[test]
    fn test_decode_success_single_key_no_kid() {
        let key = create_test_key(None);
        let claims = create_test_claims();

        let token = key.encode(&claims).unwrap();

        let set = key_set(vec![key]);

        let decoded = set.decode(&token).unwrap();
        assert_eq!(decoded.root, claims.root);
    }

    /// A token without a kid is rejected when several keys could match.
    #[test]
    fn test_decode_fail_multiple_keys_no_kid() {
        let set = key_set(vec![create_test_key(None), create_test_key(None)]);

        let claims = create_test_claims();
        let token = set.keys[0].encode(&claims).unwrap();

        let result = set.decode(&token);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing kid"));
    }

    /// A token whose kid is absent from the set is rejected.
    #[test]
    fn test_decode_fail_unknown_kid() {
        let set1 = key_set(vec![create_test_key(Some("1"))]);
        let set2 = key_set(vec![create_test_key(Some("2"))]);

        let claims = create_test_claims();
        let token = set1.encode(&claims).unwrap();

        let result = set2.decode(&token);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot find key with kid 1"));
    }

    /// Writing to and reading back from a file preserves the set.
    #[test]
    fn test_file_io() {
        let set = key_set(vec![create_test_key(Some("1"))]);

        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        // The pid plus nanosecond timestamp keeps concurrent test runs from colliding.
        let filename = format!("test_keyset_{}_{}.json", std::process::id(), nanos);
        // The guard removes the file when the test ends, whether it passes or panics.
        let temp = TempFile(std::env::temp_dir().join(filename));

        set.to_file(&temp.0).expect("failed to write to file");

        let loaded = KeySet::from_file(&temp.0).expect("failed to read from file");
        assert_eq!(loaded.keys.len(), 1);
        assert_eq!(loaded.keys[0].kid.as_deref(), Some("1"));
    }
}