use std::{collections::HashSet, fmt, path::Path as StdPath, sync::OnceLock};

use anyhow::Context as _;
use base64::Engine;
use jsonwebtoken::{DecodingKey, EncodingKey, Header};
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::{Algorithm, Claims};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
10#[serde(rename_all = "camelCase")]
11pub enum KeyOperation {
12 Sign,
13 Verify,
14 Decrypt,
15 Encrypt,
16}
17
/// A symmetric JSON Web Key used to sign and verify JWTs with an HMAC algorithm.
///
/// Serializes with JWK member names: `alg`, `key_ops`, `k` (the secret, written as
/// unpadded base64url) and an optional `kid`. The lazily built `jsonwebtoken` keys are
/// never serialized and start empty after deserialization.
#[derive(Clone, Serialize, Deserialize)]
pub struct Key {
    /// Algorithm used for both signing and verification; serialized as `alg`.
    #[serde(rename = "alg")]
    pub algorithm: Algorithm,

    /// Operations this key may perform; serialized as `key_ops`.
    #[serde(rename = "key_ops")]
    pub operations: HashSet<KeyOperation>,

    /// Raw secret bytes; serialized as `k`. Written as unpadded base64url, and read back
    /// from either unpadded or padded base64url.
    #[serde(
        rename = "k",
        serialize_with = "serialize_base64url",
        deserialize_with = "deserialize_base64url"
    )]
    pub secret: Vec<u8>,

    /// Optional key identifier, copied into the JWT header by `encode`; omitted from the
    /// JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,

    /// Verification key derived from `secret`, built on the first `decode` call.
    // NOTE(review): this cache (and `encode` below) is never invalidated, so mutating the
    // public `secret`/`algorithm` fields after first use leaves a stale key — confirm
    // callers never do that.
    #[serde(skip)]
    pub(crate) decode: OnceLock<DecodingKey>,

    /// Signing key derived from `secret`, built on the first `encode` call.
    #[serde(skip)]
    pub(crate) encode: OnceLock<EncodingKey>,
}
48
49impl fmt::Debug for Key {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("Key")
52 .field("algorithm", &self.algorithm)
53 .field("operations", &self.operations)
54 .field("kid", &self.kid)
55 .finish()
56 }
57}
58
59impl Key {
60 #[allow(clippy::should_implement_trait)]
61 pub fn from_str(s: &str) -> anyhow::Result<Self> {
62 Ok(serde_json::from_str(s)?)
63 }
64
65 pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
66 let contents = std::fs::read_to_string(&path)?;
67 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
69 let json = String::from_utf8(decoded)?;
70 Ok(serde_json::from_str(&json)?)
71 }
72
73 pub fn to_str(&self) -> anyhow::Result<String> {
74 Ok(serde_json::to_string(self)?)
75 }
76
77 pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
78 let json = serde_json::to_string(self)?;
80 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
82 std::fs::write(path, encoded)?;
83 Ok(())
84 }
85
86 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
87 if !self.operations.contains(&KeyOperation::Verify) {
88 anyhow::bail!("key does not support verification");
89 }
90
91 let decode = self.decode.get_or_init(|| match self.algorithm {
92 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(&self.secret),
93 });
100
101 let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
102 validation.required_spec_claims = Default::default(); let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
105 token.claims.validate()?;
106
107 Ok(token.claims)
108 }
109
110 pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
111 if !self.operations.contains(&KeyOperation::Sign) {
112 anyhow::bail!("key does not support signing");
113 }
114
115 payload.validate()?;
116
117 let encode = self.encode.get_or_init(|| match self.algorithm {
118 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(&self.secret),
119 });
126
127 let mut header = Header::new(self.algorithm.into());
128 header.kid = self.kid.clone();
129 let token = jsonwebtoken::encode(&header, &payload, encode)?;
130 Ok(token)
131 }
132
133 pub fn generate(algorithm: Algorithm, id: Option<String>) -> Self {
135 let private_key = match algorithm {
136 Algorithm::HS256 => generate_hmac_key::<32>(),
137 Algorithm::HS384 => generate_hmac_key::<48>(),
138 Algorithm::HS512 => generate_hmac_key::<64>(),
139 };
151
152 Key {
153 kid: id.clone(),
154 operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
155 algorithm,
156 secret: private_key,
157 decode: Default::default(),
158 encode: Default::default(),
159 }
160
161 }
174}
175
176fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
178where
179 S: Serializer,
180{
181 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
182 serializer.serialize_str(&encoded)
183}
184
185fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
187where
188 D: Deserializer<'de>,
189{
190 let s = String::deserialize(deserializer)?;
191
192 base64::engine::general_purpose::URL_SAFE_NO_PAD
194 .decode(&s)
195 .or_else(|_| {
196 base64::engine::general_purpose::URL_SAFE.decode(&s)
198 })
199 .map_err(serde::de::Error::custom)
200}
201
202fn generate_hmac_key<const SIZE: usize>() -> Vec<u8> {
203 let mut key = [0u8; SIZE];
204 aws_lc_rs::rand::fill(&mut key).unwrap();
205 key.to_vec()
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// Builds a deterministic HS256 key that may both sign and verify.
    fn create_test_key() -> Key {
        Key {
            algorithm: Algorithm::HS256,
            operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
            secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
            kid: Some("test-key-1".to_string()),
            decode: Default::default(),
            encode: Default::default(),
        }
    }

    /// Builds claims that can be signed successfully and expire one hour from now.
    fn create_test_claims() -> Claims {
        Claims {
            root: "test-path".to_string(),
            publish: vec!["test-pub".into()],
            cluster: false,
            subscribe: vec!["test-sub".into()],
            expires: Some(SystemTime::now() + Duration::from_secs(3600)),
            issued: Some(SystemTime::now()),
        }
    }

    /// Removes the wrapped file on drop, so a failing assertion doesn't leave it behind.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created if the test failed early.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_key_from_str_valid() {
        let key = create_test_key();
        let json = key.to_str().unwrap();
        let loaded_key = Key::from_str(&json).unwrap();

        assert_eq!(loaded_key.algorithm, key.algorithm);
        assert_eq!(loaded_key.operations, key.operations);
        assert_eq!(loaded_key.secret, key.secret);
        assert_eq!(loaded_key.kid, key.kid);
    }

    #[test]
    fn test_key_from_str_invalid_json() {
        let result = Key::from_str("invalid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_key_to_str() {
        let key = create_test_key();
        let json = key.to_str().unwrap();
        assert!(json.contains("\"alg\":\"HS256\""));
        assert!(json.contains("\"key_ops\""));
        assert!(json.contains("\"sign\""));
        assert!(json.contains("\"verify\""));
        assert!(json.contains("\"kid\":\"test-key-1\""));
    }

    #[test]
    fn test_key_sign_success() {
        let key = create_test_key();
        let claims = create_test_claims();
        let token = key.encode(&claims).unwrap();

        assert!(!token.is_empty());
        // A compact JWT is header.payload.signature.
        assert_eq!(token.matches('.').count(), 2);
    }

    #[test]
    fn test_key_sign_no_permission() {
        let mut key = create_test_key();
        key.operations = [KeyOperation::Verify].into();
        let claims = create_test_claims();

        let result = key.encode(&claims);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("key does not support signing"));
    }

    #[test]
    fn test_key_sign_invalid_claims() {
        let key = create_test_key();
        let invalid_claims = Claims {
            root: "test-path".to_string(),
            publish: vec![],
            subscribe: vec![],
            cluster: false,
            expires: None,
            issued: None,
        };

        let result = key.encode(&invalid_claims);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no publish or subscribe allowed; token is useless"));
    }

    #[test]
    fn test_key_verify_success() {
        let key = create_test_key();
        let claims = create_test_claims();
        let token = key.encode(&claims).unwrap();

        let verified_claims = key.decode(&token).unwrap();
        assert_eq!(verified_claims.root, claims.root);
        assert_eq!(verified_claims.publish, claims.publish);
        assert_eq!(verified_claims.subscribe, claims.subscribe);
        assert_eq!(verified_claims.cluster, claims.cluster);
    }

    #[test]
    fn test_key_verify_no_permission() {
        let mut key = create_test_key();
        key.operations = [KeyOperation::Sign].into();

        let result = key.decode("some.jwt.token");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("key does not support verification"));
    }

    #[test]
    fn test_key_verify_invalid_token() {
        let key = create_test_key();
        let result = key.decode("invalid-token");
        assert!(result.is_err());
    }

    #[test]
    fn test_key_verify_path_mismatch() {
        // NOTE(review): despite its name, this only checks that a freshly signed token
        // verifies; no mismatching root/path is exercised here.
        let key = create_test_key();
        let claims = create_test_claims();
        let token = key.encode(&claims).unwrap();

        let result = key.decode(&token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_key_verify_expired_token() {
        let key = create_test_key();
        let mut claims = create_test_claims();
        claims.expires = Some(SystemTime::now() - Duration::from_secs(3600));
        let token = key.encode(&claims).unwrap();

        let result = key.decode(&token);
        assert!(result.is_err());
    }

    #[test]
    fn test_key_verify_token_without_exp() {
        let key = create_test_key();
        let claims = Claims {
            root: "test-path".to_string(),
            publish: vec!["".to_string()],
            subscribe: vec!["".to_string()],
            cluster: false,
            expires: None,
            issued: None,
        };
        let token = key.encode(&claims).unwrap();

        let verified_claims = key.decode(&token).unwrap();
        assert_eq!(verified_claims.root, claims.root);
        assert_eq!(verified_claims.publish, claims.publish);
        assert_eq!(verified_claims.subscribe, claims.subscribe);
        assert_eq!(verified_claims.expires, None);
    }

    #[test]
    fn test_key_round_trip() {
        let key = create_test_key();
        let original_claims = Claims {
            root: "test-path".to_string(),
            publish: vec!["test-pub".into()],
            subscribe: vec!["test-sub".into()],
            cluster: true,
            expires: Some(SystemTime::now() + Duration::from_secs(3600)),
            issued: Some(SystemTime::now()),
        };

        let token = key.encode(&original_claims).unwrap();
        let verified_claims = key.decode(&token).unwrap();

        assert_eq!(verified_claims.root, original_claims.root);
        assert_eq!(verified_claims.publish, original_claims.publish);
        assert_eq!(verified_claims.subscribe, original_claims.subscribe);
        assert_eq!(verified_claims.cluster, original_claims.cluster);
    }

    #[test]
    fn test_key_generate_hs256() {
        let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
        assert_eq!(key.algorithm, Algorithm::HS256);
        assert_eq!(key.kid, Some("test-id".to_string()));
        assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
        assert_eq!(key.secret.len(), 32);
    }

    #[test]
    fn test_key_generate_hs384() {
        let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
        assert_eq!(key.algorithm, Algorithm::HS384);
        assert_eq!(key.secret.len(), 48);
    }

    #[test]
    fn test_key_generate_hs512() {
        let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
        assert_eq!(key.algorithm, Algorithm::HS512);
        assert_eq!(key.secret.len(), 64);
    }

    #[test]
    fn test_key_generate_without_id() {
        let key = Key::generate(Algorithm::HS256, None);
        assert_eq!(key.algorithm, Algorithm::HS256);
        assert_eq!(key.kid, None);
        assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
    }

    #[test]
    fn test_key_generate_sign_verify_cycle() {
        let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
        let claims = create_test_claims();

        let token = key.encode(&claims).unwrap();
        let verified_claims = key.decode(&token).unwrap();

        assert_eq!(verified_claims.root, claims.root);
        assert_eq!(verified_claims.publish, claims.publish);
        assert_eq!(verified_claims.subscribe, claims.subscribe);
        assert_eq!(verified_claims.cluster, claims.cluster);
    }

    #[test]
    fn test_key_debug_no_secret() {
        let key = create_test_key();
        let debug_str = format!("{key:?}");

        assert!(debug_str.contains("algorithm: HS256"));
        assert!(debug_str.contains("operations"));
        assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
        // The secret must never show up in Debug output.
        assert!(!debug_str.contains("secret"));
    }

    #[test]
    fn test_key_operations_enum() {
        let sign_op = KeyOperation::Sign;
        let verify_op = KeyOperation::Verify;
        let decrypt_op = KeyOperation::Decrypt;
        let encrypt_op = KeyOperation::Encrypt;

        assert_eq!(sign_op, KeyOperation::Sign);
        assert_eq!(verify_op, KeyOperation::Verify);
        assert_eq!(decrypt_op, KeyOperation::Decrypt);
        assert_eq!(encrypt_op, KeyOperation::Encrypt);

        assert_ne!(sign_op, verify_op);
        assert_ne!(decrypt_op, encrypt_op);
    }

    #[test]
    fn test_key_operations_serde() {
        let operations = [KeyOperation::Sign, KeyOperation::Verify];
        let json = serde_json::to_string(&operations).unwrap();
        assert!(json.contains("\"sign\""));
        assert!(json.contains("\"verify\""));

        let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, operations);
    }

    #[test]
    fn test_key_serde() {
        let key = create_test_key();
        let json = serde_json::to_string(&key).unwrap();
        let deserialized: Key = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.algorithm, key.algorithm);
        assert_eq!(deserialized.operations, key.operations);
        assert_eq!(deserialized.secret, key.secret);
        assert_eq!(deserialized.kid, key.kid);
    }

    #[test]
    fn test_key_clone() {
        let key = create_test_key();
        let cloned = key.clone();

        assert_eq!(cloned.algorithm, key.algorithm);
        assert_eq!(cloned.operations, key.operations);
        assert_eq!(cloned.secret, key.secret);
        assert_eq!(cloned.kid, key.kid);
    }

    #[test]
    fn test_different_algorithms() {
        let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
        let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
        let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));

        let claims = create_test_claims();

        for key in [key_256, key_384, key_512] {
            let token = key.encode(&claims).unwrap();
            let verified_claims = key.decode(&token).unwrap();
            assert_eq!(verified_claims.root, claims.root);
        }
    }

    #[test]
    fn test_cross_algorithm_verification_fails() {
        let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
        let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));

        let claims = create_test_claims();
        let token = key_256.encode(&claims).unwrap();

        // Different algorithm (and different secret) must be rejected.
        let result = key_384.decode(&token);
        assert!(result.is_err());
    }

    #[test]
    fn test_base64url_serialization() {
        let key = create_test_key();
        let json = serde_json::to_string(&key).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        let k_value = parsed["k"].as_str().unwrap();

        // Unpadded base64url: no padding and no standard-alphabet characters.
        assert!(!k_value.contains('='));
        assert!(!k_value.contains('+'));
        assert!(!k_value.contains('/'));

        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(k_value)
            .unwrap();
        assert_eq!(decoded, key.secret);
    }

    #[test]
    fn test_backwards_compatibility_unpadded_base64url() {
        let unpadded_json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;

        let key: Key = serde_json::from_str(unpadded_json).unwrap();
        assert_eq!(key.secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        assert_eq!(key.algorithm, Algorithm::HS256);
        assert_eq!(key.kid, Some("test-key-1".to_string()));
    }

    #[test]
    fn test_backwards_compatibility_padded_base64url() {
        let padded_json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;

        let key: Key = serde_json::from_str(padded_json).unwrap();
        assert_eq!(key.secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        assert_eq!(key.algorithm, Algorithm::HS256);
        assert_eq!(key.kid, Some("test-key-1".to_string()));
    }

    #[test]
    fn test_file_io_base64url() {
        let key = create_test_key();
        // Include the process id so concurrent test runs sharing the temp dir don't
        // overwrite or delete each other's file; the guard removes it even on failure.
        let temp = TempFile(
            std::env::temp_dir().join(format!("test_jwk_{}.key", std::process::id())),
        );
        let temp_path = temp.0.as_path();

        key.to_file(temp_path).unwrap();

        let contents = std::fs::read_to_string(temp_path).unwrap();

        // The file holds base64url text, not raw JSON.
        assert!(!contents.contains('{'));
        assert!(!contents.contains('}'));
        assert!(!contents.contains('"'));

        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(&contents)
            .unwrap();
        let json_str = String::from_utf8(decoded).unwrap();
        let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        let loaded_key = Key::from_file(temp_path).unwrap();
        assert_eq!(loaded_key.algorithm, key.algorithm);
        assert_eq!(loaded_key.operations, key.operations);
        assert_eq!(loaded_key.secret, key.secret);
        assert_eq!(loaded_key.kid, key.kid);
    }
}