1use crate::error::KeyError;
2use crate::{Claims, Key, KeyOperation};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::path::Path;
5use std::sync::Arc;
6
7#[cfg(feature = "jwks-loader")]
8use std::time::Duration;
9
10#[derive(Default, Clone)]
12pub struct KeySet {
13 pub keys: Vec<Arc<Key>>,
15}
16
17impl Serialize for KeySet {
18 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19 where
20 S: Serializer,
21 {
22 use serde::ser::SerializeStruct;
24
25 let mut state = serializer.serialize_struct("KeySet", 1)?;
26 state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>())?;
27 state.end()
28 }
29}
30
31impl<'de> Deserialize<'de> for KeySet {
32 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33 where
34 D: Deserializer<'de>,
35 {
36 #[derive(Deserialize)]
38 struct RawKeySet {
39 keys: Vec<Key>,
40 }
41
42 let raw = RawKeySet::deserialize(deserializer)?;
43 Ok(KeySet {
44 keys: raw.keys.into_iter().map(Arc::new).collect(),
45 })
46 }
47}
48
49impl KeySet {
50 #[allow(clippy::should_implement_trait)]
51 pub fn from_str(s: &str) -> crate::Result<Self> {
52 Ok(serde_json::from_str(s)?)
53 }
54
55 pub fn from_file<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
56 let json = std::fs::read_to_string(&path)?;
57 Ok(serde_json::from_str(&json)?)
58 }
59
60 pub fn to_str(&self) -> crate::Result<String> {
61 Ok(serde_json::to_string(&self)?)
62 }
63
64 pub fn to_file<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
65 let json = serde_json::to_string(&self)?;
66 std::fs::write(path, json)?;
67 Ok(())
68 }
69
70 pub fn to_public_set(&self) -> crate::Result<KeySet> {
71 Ok(KeySet {
72 keys: self
73 .keys
74 .iter()
75 .map(|key| key.as_ref().to_public().map(Arc::new))
76 .collect::<Result<Vec<Arc<Key>>, _>>()?,
77 })
78 }
79
80 pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
81 self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
82 }
83
84 pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
85 self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
86 }
87
88 pub fn encode(&self, payload: &Claims) -> crate::Result<String> {
89 let key = self
90 .find_supported_key(&KeyOperation::Sign)
91 .ok_or(KeyError::NoSigningKey)?;
92 key.encode(payload)
93 }
94
95 pub fn decode(&self, token: &str) -> crate::Result<Claims> {
96 let header = jsonwebtoken::decode_header(token)?;
97
98 let key = match header.kid {
99 Some(kid) => self
100 .find_key(kid.as_str())
101 .ok_or_else(|| crate::Error::from(KeyError::KeyNotFound(kid))),
102 None => {
103 if self.keys.len() == 1 {
105 Ok(self.keys[0].clone())
106 } else {
107 Err(KeyError::MissingKid.into())
108 }
109 }
110 }?;
111
112 key.decode(token)
113 }
114}
115
/// Fetches a JWKS document over HTTP from `jwks_uri` and parses it into a [`KeySet`].
///
/// The request is bounded by a 10 second client timeout.
///
/// # Errors
///
/// Fails if the HTTP client cannot be built, the request fails or times out,
/// the server answers with an error status, or the body is not a valid key set.
#[cfg(feature = "jwks-loader")]
pub async fn load_keys(jwks_uri: &str) -> crate::Result<KeySet> {
    const FETCH_TIMEOUT: Duration = Duration::from_secs(10);

    let client = reqwest::Client::builder().timeout(FETCH_TIMEOUT).build()?;
    let response = client.get(jwks_uri).send().await?.error_for_status()?;
    let body = response.text().await?;

    KeySet::from_str(&body)
}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Algorithm;
    use std::path::PathBuf;
    use std::time::{Duration, SystemTime};

    /// Builds claims that expire one hour from now.
    fn create_test_claims() -> Claims {
        Claims {
            root: "test-path".to_string(),
            publish: vec!["test-pub".into()],
            subscribe: vec!["test-sub".into()],
            expires: Some(SystemTime::now() + Duration::from_secs(3600)),
            issued: Some(SystemTime::now()),
        }
    }

    /// Generates a fresh ES256 key, optionally tagged with `kid`.
    fn create_test_key(kid: Option<&str>) -> Key {
        let kid = kid.map(|s| crate::KeyId::decode(s).unwrap());
        Key::generate(Algorithm::ES256, kid).expect("failed to generate key")
    }

    /// Deletes the wrapped path when dropped, so a temporary file is cleaned up
    /// even if an assertion panics partway through a test.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the write itself failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_keyset_from_str_valid() {
        let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#;
        let set = KeySet::from_str(json).expect("valid JWKS should parse");
        assert_eq!(set.keys.len(), 1);
        assert_eq!(set.keys[0].kid.as_deref(), Some("1"));
        assert!(set.find_key("1").is_some());
    }

    #[test]
    fn test_keyset_from_str_invalid_json() {
        let result = KeySet::from_str("invalid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_keyset_from_str_empty() {
        let json = r#"{"keys":[]}"#;
        let set = KeySet::from_str(json).expect("empty JWKS should parse");
        assert!(set.keys.is_empty());
    }

    #[test]
    fn test_keyset_to_str() {
        let key = create_test_key(Some("1"));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let json = set.to_str().expect("serialization should succeed");
        assert!(json.contains("\"keys\""));
        assert!(json.contains("\"kid\":\"1\""));
    }

    #[test]
    fn test_keyset_serde_round_trip() {
        let key1 = create_test_key(Some("1"));
        let key2 = create_test_key(Some("2"));
        let set = KeySet {
            keys: vec![Arc::new(key1), Arc::new(key2)],
        };

        let json = set.to_str().expect("serialization should succeed");
        let deserialized = KeySet::from_str(&json).expect("round trip should parse");

        assert_eq!(deserialized.keys.len(), 2);
        assert!(deserialized.find_key("1").is_some());
        assert!(deserialized.find_key("2").is_some());
    }

    #[test]
    fn test_find_key_success() {
        let key = create_test_key(Some("my-key"));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let found = set
            .find_key("my-key")
            .expect("key with matching kid should be found");
        assert_eq!(found.kid.as_deref(), Some("my-key"));
    }

    #[test]
    fn test_find_key_missing() {
        let key = create_test_key(Some("my-key"));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        assert!(set.find_key("other-key").is_none());
    }

    #[test]
    fn test_find_key_no_kid() {
        let key = create_test_key(None);
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        assert!(set.find_key("any-key").is_none());
    }

    #[test]
    fn test_find_supported_key() {
        let mut sign_key = create_test_key(Some("sign"));
        sign_key.operations = [KeyOperation::Sign].into();

        let mut verify_key = create_test_key(Some("verify"));
        verify_key.operations = [KeyOperation::Verify].into();

        let set = KeySet {
            keys: vec![Arc::new(sign_key), Arc::new(verify_key)],
        };

        let found_sign = set
            .find_supported_key(&KeyOperation::Sign)
            .expect("signing key should be found");
        assert_eq!(found_sign.kid.as_deref(), Some("sign"));

        let found_verify = set
            .find_supported_key(&KeyOperation::Verify)
            .expect("verifying key should be found");
        assert_eq!(found_verify.kid.as_deref(), Some("verify"));
    }

    #[test]
    fn test_to_public_set() {
        let key = create_test_key(Some("1"));

        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let public_set = set.to_public_set().expect("failed to convert to public set");
        assert_eq!(public_set.keys.len(), 1);

        let public_key = &public_set.keys[0];
        assert_eq!(public_key.kid.as_deref(), Some("1"));
        assert!(public_key.operations.contains(&KeyOperation::Verify));
        assert!(!public_key.operations.contains(&KeyOperation::Sign));
    }

    #[test]
    fn test_to_public_set_fails_for_symmetric() {
        let key = Key::generate(Algorithm::HS256, Some(crate::KeyId::decode("sym").unwrap())).unwrap();
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        // `KeySet` is not `Debug`, so check the outcome without `expect_err`.
        assert!(set.to_public_set().is_err());
    }

    #[test]
    fn test_encode_success() {
        let key = create_test_key(Some("1"));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };
        let claims = create_test_claims();

        let token = set.encode(&claims).expect("encoding should succeed");
        assert!(!token.is_empty());
    }

    #[test]
    fn test_encode_no_signing_key() {
        let mut key = create_test_key(Some("1"));
        key.operations = [KeyOperation::Verify].into();
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };
        let claims = create_test_claims();

        let err = set
            .encode(&claims)
            .expect_err("encoding without a signing key should fail");
        assert!(err.to_string().contains("cannot find signing key"));
    }

    #[test]
    fn test_decode_success_with_kid() {
        let key = create_test_key(Some("1"));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };
        let claims = create_test_claims();

        let token = set.encode(&claims).expect("encoding should succeed");
        let decoded = set.decode(&token).expect("decoding should succeed");

        assert_eq!(decoded.root, claims.root);
    }

    #[test]
    fn test_decode_success_single_key_no_kid() {
        let key = create_test_key(None);
        let claims = create_test_claims();

        let token = key.encode(&claims).expect("encoding should succeed");

        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        let decoded = set.decode(&token).expect("single key without kid should decode");
        assert_eq!(decoded.root, claims.root);
    }

    #[test]
    fn test_decode_fail_multiple_keys_no_kid() {
        let key1 = create_test_key(None);
        let key2 = create_test_key(None);

        let set = KeySet {
            keys: vec![Arc::new(key1), Arc::new(key2)],
        };

        let claims = create_test_claims();
        let token = set.keys[0].encode(&claims).expect("encoding should succeed");

        let err = set
            .decode(&token)
            .expect_err("ambiguous key selection should fail");
        assert!(err.to_string().contains("missing kid"));
    }

    #[test]
    fn test_decode_fail_unknown_kid() {
        let key1 = create_test_key(Some("1"));
        let key2 = create_test_key(Some("2"));

        let set1 = KeySet {
            keys: vec![Arc::new(key1)],
        };
        let set2 = KeySet {
            keys: vec![Arc::new(key2)],
        };

        let claims = create_test_claims();
        let token = set1.encode(&claims).expect("encoding should succeed");

        let err = set2
            .decode(&token)
            .expect_err("decoding with an unknown kid should fail");
        assert!(err.to_string().contains("cannot find key with kid 1"));
    }

    #[test]
    fn test_file_io() {
        let key = create_test_key(Some("1"));
        let set = KeySet {
            keys: vec![Arc::new(key)],
        };

        // Include the process id so concurrent test runs never share a file.
        let path = std::env::temp_dir().join(format!(
            "test_keyset_{}_{}.json",
            std::process::id(),
            SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        // Removed on drop, including when an assertion below panics.
        let file = TempFile(path);

        set.to_file(&file.0).expect("failed to write to file");

        let loaded = KeySet::from_file(&file.0).expect("failed to read from file");
        assert_eq!(loaded.keys.len(), 1);
        assert_eq!(loaded.keys[0].kid.as_deref(), Some("1"));
    }
}