1use crate::generate::generate;
2use crate::{Algorithm, Claims};
3use anyhow::{Context, bail};
4use base64::Engine;
5use elliptic_curve::SecretKey;
6use elliptic_curve::pkcs8::EncodePrivateKey;
7use jsonwebtoken::{DecodingKey, EncodingKey, Header};
8use rsa::BigUint;
9use rsa::pkcs1::EncodeRsaPrivateKey;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::sync::OnceLock;
12use std::{collections::HashSet, fmt, path::Path as StdPath};
13
/// Operations a [`Key`] may be used for; stored in the key's `key_ops` member.
///
/// Variants are (de)serialized in camelCase, so `Sign` appears as `"sign"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
pub enum KeyOperation {
    /// Required by `Key::encode` before a token is signed.
    Sign,
    /// Required by `Key::decode` and `Key::to_public`.
    Verify,
    /// Not consulted by any code visible in this file.
    Decrypt,
    /// Not consulted by any code visible in this file.
    Encrypt,
}
23
/// Key material of a [`Key`], selected on the wire by the `kty` member.
///
/// Byte fields are written as base64url strings through the
/// `serialize_base64url*` / `deserialize_base64url*` helpers.
/// `Debug` is not derived for this type.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "kty")]
pub enum KeyType {
    /// Elliptic-curve key (`"kty": "EC"`).
    EC {
        /// Curve identifier, serialized as `crv`.
        #[serde(rename = "crv")]
        curve: EllipticCurve,
        /// Public X coordinate.
        #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
        x: Vec<u8>,
        /// Public Y coordinate.
        #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
        y: Vec<u8>,
        /// Private scalar; `None` for a public-only key and then omitted from the output.
        #[serde(
            default,
            skip_serializing_if = "Option::is_none",
            serialize_with = "serialize_base64url_optional",
            deserialize_with = "deserialize_base64url_optional"
        )]
        d: Option<Vec<u8>>,
    },
    /// RSA key (`"kty": "RSA"`); both parts are flattened into the same JSON object.
    RSA {
        /// Public components (`n`, `e`).
        #[serde(flatten)]
        public: RsaPublicKey,
        /// Private components; `None` for a public-only key.
        #[serde(flatten, skip_serializing_if = "Option::is_none")]
        private: Option<RsaPrivateKey>,
    },
    /// Symmetric key (`"kty": "oct"`), used with the HS* algorithms.
    #[serde(rename = "oct")]
    OCT {
        /// Shared secret, serialized as `k`.
        ///
        /// NOTE(review): `default` means an object without `k` deserializes to an
        /// empty secret instead of failing — confirm this is intended.
        #[serde(
            rename = "k",
            default,
            serialize_with = "serialize_base64url",
            deserialize_with = "deserialize_base64url"
        )]
        secret: Vec<u8>,
    },
    /// Octet key pair (`"kty": "OKP"`); only `Ed25519` is accepted by the key builders in this file.
    OKP {
        /// Curve identifier, serialized as `crv`.
        #[serde(rename = "crv")]
        curve: EllipticCurve,
        /// Public key bytes.
        #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
        x: Vec<u8>,
        /// Private seed; `None` for a public-only key and then omitted from the output.
        #[serde(
            rename = "d",
            default,
            skip_serializing_if = "Option::is_none",
            serialize_with = "serialize_base64url_optional",
            deserialize_with = "deserialize_base64url_optional"
        )]
        d: Option<Vec<u8>>,
    },
}
82
/// Curve identifiers accepted in the `crv` member of EC and OKP keys.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
pub enum EllipticCurve {
    /// Serialized as `"P-256"`; paired with `ES256` by the key builders in this file.
    #[serde(rename = "P-256")]
    P256,
    /// Serialized as `"P-384"`; paired with `ES384` by the key builders in this file.
    #[serde(rename = "P-384")]
    P384,
    /// Serialized as `"Ed25519"`; paired with `EdDSA` and only valid for OKP keys.
    #[serde(rename = "Ed25519")]
    Ed25519,
}
98
/// Public half of an RSA key; both fields are base64url strings on the wire.
#[derive(Clone, Serialize, Deserialize)]
pub struct RsaPublicKey {
    /// Modulus, as big-endian bytes (read with `BigUint::from_bytes_be`).
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub n: Vec<u8>,
    /// Public exponent, as big-endian bytes.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub e: Vec<u8>,
}
109
/// Private half of an RSA key; byte fields are base64url strings on the wire.
///
/// Only `d`, `p` and `q` are read when the signing key is built in this file.
#[derive(Clone, Serialize, Deserialize)]
pub struct RsaPrivateKey {
    /// Private exponent, as big-endian bytes.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub d: Vec<u8>,
    /// First prime factor, as big-endian bytes.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub p: Vec<u8>,
    /// Second prime factor, as big-endian bytes.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub q: Vec<u8>,
    /// Presumably the JWK first CRT exponent; not read in this file — TODO confirm.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub dp: Vec<u8>,
    /// Presumably the JWK second CRT exponent; not read in this file — TODO confirm.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub dq: Vec<u8>,
    /// Presumably the JWK first CRT coefficient; not read in this file — TODO confirm.
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub qi: Vec<u8>,
    /// Additional primes, omitted from the output when `None`; not read in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oth: Option<Vec<RsaAdditionalPrime>>,
}
130
/// One entry of the `oth` list of an [`RsaPrivateKey`]; all fields are base64url strings on the wire.
///
/// NOTE(review): field meanings presumably follow the JWK "other primes info"
/// members (`r` prime factor, `d` factor CRT exponent, `t` factor CRT coefficient);
/// nothing in this file reads them — confirm against the spec.
#[derive(Clone, Serialize, Deserialize)]
pub struct RsaAdditionalPrime {
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub r: Vec<u8>,
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub d: Vec<u8>,
    #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
    pub t: Vec<u8>,
}
141
/// A signing/verification key together with its algorithm, permitted operations and optional id.
///
/// The derive uses `remote = "Self"`; the `Serialize` / `Deserialize` trait
/// impls for `Key` are written by hand further down in this file.
#[derive(Clone, Serialize, Deserialize)]
#[serde(remote = "Self")]
pub struct Key {
    /// Algorithm this key is used with; serialized as `alg`.
    #[serde(rename = "alg")]
    pub algorithm: Algorithm,

    /// Operations this key may perform; serialized as `key_ops`.
    #[serde(rename = "key_ops")]
    pub operations: HashSet<KeyOperation>,

    /// Key material; its members (including `kty`) are flattened into the same object.
    #[serde(flatten)]
    pub key: KeyType,

    /// Optional key id; omitted from the output when `None`. Copied into the token header by `encode`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kid: Option<String>,

    /// Lazily built verification key; never serialized.
    #[serde(skip)]
    pub(crate) decode: OnceLock<DecodingKey>,

    /// Lazily built signing key; never serialized.
    #[serde(skip)]
    pub(crate) encode: OnceLock<EncodingKey>,
}
170
171impl<'de> Deserialize<'de> for Key {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 let mut value = serde_json::Value::deserialize(deserializer)?;
177
178 if let Some(obj) = value.as_object_mut()
182 && !obj.contains_key("kty")
183 {
184 obj.insert("kty".to_string(), serde_json::Value::String("oct".to_string()));
185 }
186
187 Self::deserialize(value).map_err(serde::de::Error::custom)
188 }
189}
190
impl Serialize for Key {
    /// Serializes the key by delegating to `Self::serialize`.
    ///
    /// NOTE(review): `Key` derives with `#[serde(remote = "Self")]`, so this
    /// call is presumably meant to resolve to the derive-generated inherent
    /// function rather than recurse into this trait method — confirm before
    /// changing how the call is spelled.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Self::serialize(self, serializer)
    }
}
199
200impl fmt::Debug for Key {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("Key")
203 .field("algorithm", &self.algorithm)
204 .field("operations", &self.operations)
205 .field("kid", &self.kid)
206 .finish()
207 }
208}
209
210impl Key {
211 #[allow(clippy::should_implement_trait)]
212 pub fn from_str(s: &str) -> anyhow::Result<Self> {
213 Ok(serde_json::from_str(s)?)
214 }
215
216 pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
217 let contents = std::fs::read_to_string(&path)?;
218 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
220 let json = String::from_utf8(decoded)?;
221 Ok(serde_json::from_str(&json)?)
222 }
223
224 pub fn to_str(&self) -> anyhow::Result<String> {
225 Ok(serde_json::to_string(self)?)
226 }
227
228 pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
229 let json = serde_json::to_string(self)?;
231 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
233 std::fs::write(path, encoded)?;
234 Ok(())
235 }
236
237 pub fn to_public(&self) -> anyhow::Result<Self> {
238 if !self.operations.contains(&KeyOperation::Verify) {
239 return Err(anyhow::anyhow!("This key doesn't support the Verify operation"));
240 }
241
242 let key = match self.key {
243 KeyType::RSA { ref public, .. } => Ok(KeyType::RSA {
244 public: public.clone(),
245 private: None,
246 }),
247 KeyType::EC {
248 ref x,
249 ref y,
250 ref curve,
251 ..
252 } => Ok(KeyType::EC {
253 x: x.clone(),
254 y: y.clone(),
255 curve: curve.clone(),
256 d: None,
257 }),
258 KeyType::OCT { .. } => Err(anyhow::anyhow!("OCT key cannot be converted to public key")),
259 KeyType::OKP { ref x, ref curve, .. } => Ok(KeyType::OKP {
260 x: x.clone(),
261 curve: curve.clone(),
262 d: None,
263 }),
264 };
265
266 match key {
267 Ok(key) => Ok(Self {
268 algorithm: self.algorithm,
269 operations: [KeyOperation::Verify].into(),
270 key,
271 kid: self.kid.clone(),
272 decode: Default::default(),
273 encode: Default::default(),
274 }),
275 Err(err) => Err(anyhow::anyhow!("Failed to convert key: {}", err)),
276 }
277 }
278
279 fn to_decoding_key(&self) -> anyhow::Result<&DecodingKey> {
280 if let Some(key) = self.decode.get() {
281 return Ok(key);
282 }
283
284 let decoding_key = match self.key {
285 KeyType::OCT { ref secret } => match self.algorithm {
286 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret),
287 _ => bail!("Invalid algorithm for key type OCT"),
288 },
289 KeyType::EC {
290 ref curve,
291 ref x,
292 ref y,
293 ..
294 } => match curve {
295 EllipticCurve::P256 => {
296 if self.algorithm != Algorithm::ES256 {
297 bail!("Invalid algorithm for P-256 curve");
298 }
299 if x.len() != 32 || y.len() != 32 {
300 bail!("Invalid coordinate length for P-256");
301 }
302
303 DecodingKey::from_ec_components(
304 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
305 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
306 )?
307 }
308 EllipticCurve::P384 => {
309 if self.algorithm != Algorithm::ES384 {
310 bail!("Invalid algorithm for P-384 curve");
311 }
312 if x.len() != 48 || y.len() != 48 {
313 bail!("Invalid coordinate length for P-384");
314 }
315
316 DecodingKey::from_ec_components(
317 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
318 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
319 )?
320 }
321 _ => bail!("Invalid curve for EC key"),
322 },
323 KeyType::OKP { ref curve, ref x, .. } => match curve {
324 EllipticCurve::Ed25519 => {
325 if self.algorithm != Algorithm::EdDSA {
326 bail!("Invalid algorithm for Ed25519 curve");
327 }
328
329 DecodingKey::from_ed_components(
330 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
331 )?
332 }
333 _ => bail!("Invalid curve for OKP key"),
334 },
335 KeyType::RSA { ref public, .. } => {
336 DecodingKey::from_rsa_raw_components(public.n.as_ref(), public.e.as_ref())
337 }
338 };
339
340 Ok(self.decode.get_or_init(|| decoding_key))
341 }
342
343 fn to_encoding_key(&self) -> anyhow::Result<&EncodingKey> {
344 if let Some(key) = self.encode.get() {
345 return Ok(key);
346 }
347
348 let encoding_key = match self.key {
349 KeyType::OCT { ref secret } => match self.algorithm {
350 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret),
351 _ => bail!("Invalid algorithm for key type OCT"),
352 },
353 KeyType::EC { ref curve, ref d, .. } => {
354 let d = d.as_ref().context("Missing private key")?;
355
356 match curve {
357 EllipticCurve::P256 => {
358 let secret_key = SecretKey::<p256::NistP256>::from_slice(d)?;
359 let doc = secret_key.to_pkcs8_der()?;
360 EncodingKey::from_ec_der(doc.as_bytes())
361 }
362 EllipticCurve::P384 => {
363 let secret_key = SecretKey::<p384::NistP384>::from_slice(d)?;
364 let doc = secret_key.to_pkcs8_der()?;
365 EncodingKey::from_ec_der(doc.as_bytes())
366 }
367 _ => bail!("Invalid curve for EC key"),
368 }
369 }
370 KeyType::OKP {
371 ref curve,
372 ref d,
373 ref x,
374 } => {
375 let d = d.as_ref().context("Missing private key")?;
376
377 let key_pair =
378 aws_lc_rs::signature::Ed25519KeyPair::from_seed_and_public_key(d.as_slice(), x.as_slice())?;
379
380 match curve {
381 EllipticCurve::Ed25519 => EncodingKey::from_ed_der(key_pair.to_pkcs8()?.as_ref()),
382 _ => bail!("Invalid curve for OKP key"),
383 }
384 }
385 KeyType::RSA {
386 ref public,
387 ref private,
388 } => {
389 let n = BigUint::from_bytes_be(&public.n);
390 let e = BigUint::from_bytes_be(&public.e);
391 let private = private.as_ref().context("Missing private key")?;
392 let d = BigUint::from_bytes_be(&private.d);
393 let p = BigUint::from_bytes_be(&private.p);
394 let q = BigUint::from_bytes_be(&private.q);
395
396 let rsa = rsa::RsaPrivateKey::from_components(n, e, d, vec![p, q]);
397 let pem = rsa?.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF);
398
399 EncodingKey::from_rsa_pem(pem?.as_bytes())?
400 }
401 };
402
403 Ok(self.encode.get_or_init(|| encoding_key))
404 }
405
406 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
407 if !self.operations.contains(&KeyOperation::Verify) {
408 bail!("key does not support verification");
409 }
410
411 let decode: anyhow::Result<&DecodingKey> = self.to_decoding_key();
412
413 match decode {
414 Ok(decode) => {
415 let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
416 validation.required_spec_claims = Default::default(); validation.validate_exp = false; let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
420
421 if let Some(exp) = token.claims.expires
422 && exp < std::time::SystemTime::now()
423 {
424 anyhow::bail!("token has expired");
425 }
426
427 token.claims.validate()?;
428
429 Ok(token.claims)
430 }
431 Err(e) => Err(anyhow::anyhow!("Failed to decode key: {}", e)),
432 }
433 }
434
435 pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
436 if !self.operations.contains(&KeyOperation::Sign) {
437 bail!("key does not support signing");
438 }
439
440 payload.validate()?;
441
442 let encode: anyhow::Result<&EncodingKey> = self.to_encoding_key();
443
444 match encode {
445 Ok(encode) => {
446 let mut header = Header::new(self.algorithm.into());
447 header.kid = self.kid.clone();
448 let token = jsonwebtoken::encode(&header, &payload, encode)?;
449 Ok(token)
450 }
451 Err(e) => Err(anyhow::anyhow!("Failed to encode key: {}", e)),
452 }
453 }
454
455 pub fn generate(algorithm: Algorithm, id: Option<String>) -> anyhow::Result<Self> {
457 generate(algorithm, id)
458 }
459}
460
461fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
463where
464 S: Serializer,
465{
466 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
467 serializer.serialize_str(&encoded)
468}
469
470fn serialize_base64url_optional<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
471where
472 S: Serializer,
473{
474 match bytes {
475 Some(b) => serialize_base64url(b, serializer),
476 None => serializer.serialize_none(),
477 }
478}
479
480fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
482where
483 D: Deserializer<'de>,
484{
485 let s = String::deserialize(deserializer)?;
486
487 base64::engine::general_purpose::URL_SAFE_NO_PAD
489 .decode(&s)
490 .or_else(|_| {
491 base64::engine::general_purpose::URL_SAFE.decode(&s)
493 })
494 .map_err(serde::de::Error::custom)
495}
496
497fn deserialize_base64url_optional<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
498where
499 D: Deserializer<'de>,
500{
501 let s: Option<String> = Option::deserialize(deserializer)?;
502 match s {
503 Some(s) => {
504 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
505 .decode(&s)
506 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(&s))
507 .map_err(serde::de::Error::custom)?;
508 Ok(Some(decoded))
509 }
510 None => Ok(None),
511 }
512}
513
514#[cfg(test)]
515mod tests {
516 use super::*;
517 use std::time::{Duration, SystemTime};
518
519 fn create_test_key() -> Key {
520 Key {
521 algorithm: Algorithm::HS256,
522 operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
523 key: KeyType::OCT {
524 secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
525 },
526 kid: Some("test-key-1".to_string()),
527 decode: Default::default(),
528 encode: Default::default(),
529 }
530 }
531
532 fn create_test_claims() -> Claims {
533 Claims {
534 root: "test-path".to_string(),
535 publish: vec!["test-pub".into()],
536 cluster: false,
537 subscribe: vec!["test-sub".into()],
538 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
539 issued: Some(SystemTime::now()),
540 }
541 }
542
543 #[test]
544 fn test_key_from_str_valid() {
545 let key = create_test_key();
546 let json = key.to_str().unwrap();
547 let loaded_key = Key::from_str(&json).unwrap();
548
549 assert_eq!(loaded_key.algorithm, key.algorithm);
550 assert_eq!(loaded_key.operations, key.operations);
551 match (loaded_key.key, key.key) {
552 (KeyType::OCT { secret: loaded_secret }, KeyType::OCT { secret }) => {
553 assert_eq!(loaded_secret, secret);
554 }
555 _ => panic!("Expected OCT key"),
556 }
557 assert_eq!(loaded_key.kid, key.kid);
558 }
559
560 #[test]
562 fn test_key_oct_backwards_compatibility() {
563 let json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68"}"#;
564 let key = Key::from_str(json);
565
566 assert!(key.is_ok());
567 let key = key.unwrap();
568
569 if let KeyType::OCT { ref secret, .. } = key.key {
570 let base64_key = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret);
571 assert_eq!(base64_key, "Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68");
572 } else {
573 panic!("Expected OCT key");
574 }
575
576 let key_str = key.to_str();
577 assert!(key_str.is_ok());
578 let key_str = key_str.unwrap();
579
580 assert!(key_str.contains("\"alg\":\"HS256\""));
582 assert!(key_str.contains("\"key_ops\""));
583 assert!(key_str.contains("\"sign\""));
584 assert!(key_str.contains("\"verify\""));
585 assert!(key_str.contains("\"kty\":\"oct\""));
586 }
587
588 #[test]
589 fn test_key_from_str_invalid_json() {
590 let result = Key::from_str("invalid json");
591 assert!(result.is_err());
592 }
593
594 #[test]
595 fn test_key_to_str() {
596 let key = create_test_key();
597 let json = key.to_str().unwrap();
598 assert!(json.contains("\"alg\":\"HS256\""));
599 assert!(json.contains("\"key_ops\""));
600 assert!(json.contains("\"sign\""));
601 assert!(json.contains("\"verify\""));
602 assert!(json.contains("\"kid\":\"test-key-1\""));
603 assert!(json.contains("\"kty\":\"oct\""));
604 }
605
606 #[test]
607 fn test_key_sign_success() {
608 let key = create_test_key();
609 let claims = create_test_claims();
610 let token = key.encode(&claims).unwrap();
611
612 assert!(!token.is_empty());
613 assert_eq!(token.matches('.').count(), 2); }
615
616 #[test]
617 fn test_key_sign_no_permission() {
618 let mut key = create_test_key();
619 key.operations = [KeyOperation::Verify].into();
620 let claims = create_test_claims();
621
622 let result = key.encode(&claims);
623 assert!(result.is_err());
624 assert!(result.unwrap_err().to_string().contains("key does not support signing"));
625 }
626
627 #[test]
628 fn test_key_sign_invalid_claims() {
629 let key = create_test_key();
630 let invalid_claims = Claims {
631 root: "test-path".to_string(),
632 publish: vec![],
633 subscribe: vec![],
634 cluster: false,
635 expires: None,
636 issued: None,
637 };
638
639 let result = key.encode(&invalid_claims);
640 assert!(result.is_err());
641 assert!(
642 result
643 .unwrap_err()
644 .to_string()
645 .contains("no publish or subscribe allowed; token is useless")
646 );
647 }
648
649 #[test]
650 fn test_key_verify_success() {
651 let key = create_test_key();
652 let claims = create_test_claims();
653 let token = key.encode(&claims).unwrap();
654
655 let verified_claims = key.decode(&token).unwrap();
656 assert_eq!(verified_claims.root, claims.root);
657 assert_eq!(verified_claims.publish, claims.publish);
658 assert_eq!(verified_claims.subscribe, claims.subscribe);
659 assert_eq!(verified_claims.cluster, claims.cluster);
660 }
661
662 #[test]
663 fn test_key_verify_no_permission() {
664 let mut key = create_test_key();
665 key.operations = [KeyOperation::Sign].into();
666
667 let result = key.decode("some.jwt.token");
668 assert!(result.is_err());
669 assert!(
670 result
671 .unwrap_err()
672 .to_string()
673 .contains("key does not support verification")
674 );
675 }
676
677 #[test]
678 fn test_key_verify_invalid_token() {
679 let key = create_test_key();
680 let result = key.decode("invalid-token");
681 assert!(result.is_err());
682 }
683
684 #[test]
685 fn test_key_verify_path_mismatch() {
686 let key = create_test_key();
687 let claims = create_test_claims();
688 let token = key.encode(&claims).unwrap();
689
690 let result = key.decode(&token);
692 assert!(result.is_ok());
693 }
694
695 #[test]
696 fn test_key_verify_expired_token() {
697 let key = create_test_key();
698 let mut claims = create_test_claims();
699 claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); let token = key.encode(&claims).unwrap();
701
702 let result = key.decode(&token);
703 assert!(result.is_err());
704 }
705
706 #[test]
707 fn test_key_verify_token_without_exp() {
708 let key = create_test_key();
709 let claims = Claims {
710 root: "test-path".to_string(),
711 publish: vec!["".to_string()],
712 subscribe: vec!["".to_string()],
713 cluster: false,
714 expires: None,
715 issued: None,
716 };
717 let token = key.encode(&claims).unwrap();
718
719 let verified_claims = key.decode(&token).unwrap();
720 assert_eq!(verified_claims.root, claims.root);
721 assert_eq!(verified_claims.publish, claims.publish);
722 assert_eq!(verified_claims.subscribe, claims.subscribe);
723 assert_eq!(verified_claims.expires, None);
724 }
725
726 #[test]
727 fn test_key_round_trip() {
728 let key = create_test_key();
729 let original_claims = Claims {
730 root: "test-path".to_string(),
731 publish: vec!["test-pub".into()],
732 subscribe: vec!["test-sub".into()],
733 cluster: true,
734 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
735 issued: Some(SystemTime::now()),
736 };
737
738 let token = key.encode(&original_claims).unwrap();
739 let verified_claims = key.decode(&token).unwrap();
740
741 assert_eq!(verified_claims.root, original_claims.root);
742 assert_eq!(verified_claims.publish, original_claims.publish);
743 assert_eq!(verified_claims.subscribe, original_claims.subscribe);
744 assert_eq!(verified_claims.cluster, original_claims.cluster);
745 }
746
747 #[test]
748 fn test_key_generate_hs256() {
749 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
750 assert!(key.is_ok());
751 let key = key.unwrap();
752
753 assert_eq!(key.algorithm, Algorithm::HS256);
754 assert_eq!(key.kid, Some("test-id".to_string()));
755 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
756
757 match key.key {
758 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 32),
759 _ => panic!("Expected OCT key"),
760 }
761 }
762
763 #[test]
764 fn test_key_generate_hs384() {
765 let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
766 assert!(key.is_ok());
767 let key = key.unwrap();
768
769 assert_eq!(key.algorithm, Algorithm::HS384);
770
771 match key.key {
772 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 48),
773 _ => panic!("Expected OCT key"),
774 }
775 }
776
777 #[test]
778 fn test_key_generate_hs512() {
779 let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
780 assert!(key.is_ok());
781 let key = key.unwrap();
782
783 assert_eq!(key.algorithm, Algorithm::HS512);
784
785 match key.key {
786 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 64),
787 _ => panic!("Expected OCT key"),
788 }
789 }
790
791 #[test]
792 fn test_key_generate_rs512() {
793 let key = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
794 assert!(key.is_ok());
795 let key = key.unwrap();
796
797 assert_eq!(key.algorithm, Algorithm::RS512);
798 assert!(matches!(key.key, KeyType::RSA { .. }));
799 match key.key {
800 KeyType::RSA {
801 ref public,
802 ref private,
803 } => {
804 assert!(private.is_some());
805 assert_eq!(public.n.len(), 256);
806 assert_eq!(public.e.len(), 3);
807 }
808 _ => panic!("Expected RSA key"),
809 }
810 }
811
812 #[test]
813 fn test_key_generate_es256() {
814 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
815 assert!(key.is_ok());
816 let key = key.unwrap();
817
818 assert_eq!(key.algorithm, Algorithm::ES256);
819 assert!(matches!(key.key, KeyType::EC { .. }))
820 }
821
822 #[test]
823 fn test_key_generate_ps512() {
824 let key = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
825 assert!(key.is_ok());
826 let key = key.unwrap();
827
828 assert_eq!(key.algorithm, Algorithm::PS512);
829 assert!(matches!(key.key, KeyType::RSA { .. }));
830 }
831
832 #[test]
833 fn test_key_generate_eddsa() {
834 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
835 assert!(key.is_ok());
836 let key = key.unwrap();
837
838 assert_eq!(key.algorithm, Algorithm::EdDSA);
839 assert!(matches!(key.key, KeyType::OKP { .. }));
840 }
841
842 #[test]
843 fn test_key_generate_without_id() {
844 let key = Key::generate(Algorithm::HS256, None);
845 assert!(key.is_ok());
846 let key = key.unwrap();
847
848 assert_eq!(key.algorithm, Algorithm::HS256);
849 assert_eq!(key.kid, None);
850 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
851 }
852
853 #[test]
854 fn test_public_key_conversion_hmac() {
855 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string())).expect("HMAC key generation failed");
856
857 assert!(key.to_public().is_err());
858 }
859
860 #[test]
861 fn test_public_key_conversion_rsa() {
862 let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
863 assert!(key.is_ok());
864 let key = key.unwrap();
865
866 let public_key = key.to_public().unwrap();
867 assert_eq!(key.kid, public_key.kid);
868 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
869 assert!(public_key.encode.get().is_none());
870 assert!(public_key.decode.get().is_none());
871 assert!(matches!(public_key.key, KeyType::RSA { .. }));
872
873 if let KeyType::RSA { public, private } = &public_key.key {
874 assert!(private.is_none());
875
876 if let KeyType::RSA { public: src_public, .. } = &key.key {
877 assert_eq!(public.e, src_public.e);
878 assert_eq!(public.n, src_public.n);
879 } else {
880 unreachable!("Expected RSA key")
881 }
882 } else {
883 unreachable!("Expected RSA key");
884 }
885 }
886
887 #[test]
888 fn test_public_key_conversion_es() {
889 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
890 assert!(key.is_ok());
891 let key = key.unwrap();
892
893 let public_key = key.to_public().unwrap();
894 assert_eq!(key.kid, public_key.kid);
895 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
896 assert!(public_key.encode.get().is_none());
897 assert!(public_key.decode.get().is_none());
898 assert!(matches!(public_key.key, KeyType::EC { .. }));
899
900 if let KeyType::EC { x, y, d, curve } = &public_key.key {
901 assert!(d.is_none());
902
903 if let KeyType::EC {
904 x: src_x,
905 y: src_y,
906 curve: src_curve,
907 ..
908 } = &key.key
909 {
910 assert_eq!(x, src_x);
911 assert_eq!(y, src_y);
912 assert_eq!(curve, src_curve);
913 } else {
914 unreachable!("Expected EC key")
915 }
916 } else {
917 unreachable!("Expected EC key");
918 }
919 }
920
921 #[test]
922 fn test_public_key_conversion_ed() {
923 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
924 assert!(key.is_ok());
925 let key = key.unwrap();
926
927 let public_key = key.to_public().unwrap();
928 assert_eq!(key.kid, public_key.kid);
929 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
930 assert!(public_key.encode.get().is_none());
931 assert!(public_key.decode.get().is_none());
932 assert!(matches!(public_key.key, KeyType::OKP { .. }));
933
934 if let KeyType::OKP { x, d, curve } = &public_key.key {
935 assert!(d.is_none());
936
937 if let KeyType::OKP {
938 x: src_x,
939 curve: src_curve,
940 ..
941 } = &key.key
942 {
943 assert_eq!(x, src_x);
944 assert_eq!(curve, src_curve);
945 } else {
946 unreachable!("Expected OKP key")
947 }
948 } else {
949 unreachable!("Expected OKP key");
950 }
951 }
952
953 #[test]
954 fn test_key_generate_sign_verify_cycle() {
955 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
956 assert!(key.is_ok());
957 let key = key.unwrap();
958
959 let claims = create_test_claims();
960
961 let token = key.encode(&claims).unwrap();
962 let verified_claims = key.decode(&token).unwrap();
963
964 assert_eq!(verified_claims.root, claims.root);
965 assert_eq!(verified_claims.publish, claims.publish);
966 assert_eq!(verified_claims.subscribe, claims.subscribe);
967 assert_eq!(verified_claims.cluster, claims.cluster);
968 }
969
970 #[test]
971 fn test_key_debug_no_secret() {
972 let key = create_test_key();
973 let debug_str = format!("{key:?}");
974
975 assert!(debug_str.contains("algorithm: HS256"));
976 assert!(debug_str.contains("operations"));
977 assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
978 assert!(!debug_str.contains("secret")); }
980
981 #[test]
982 fn test_key_operations_enum() {
983 let sign_op = KeyOperation::Sign;
984 let verify_op = KeyOperation::Verify;
985 let decrypt_op = KeyOperation::Decrypt;
986 let encrypt_op = KeyOperation::Encrypt;
987
988 assert_eq!(sign_op, KeyOperation::Sign);
989 assert_eq!(verify_op, KeyOperation::Verify);
990 assert_eq!(decrypt_op, KeyOperation::Decrypt);
991 assert_eq!(encrypt_op, KeyOperation::Encrypt);
992
993 assert_ne!(sign_op, verify_op);
994 assert_ne!(decrypt_op, encrypt_op);
995 }
996
997 #[test]
998 fn test_key_operations_serde() {
999 let operations = [KeyOperation::Sign, KeyOperation::Verify];
1000 let json = serde_json::to_string(&operations).unwrap();
1001 assert!(json.contains("\"sign\""));
1002 assert!(json.contains("\"verify\""));
1003
1004 let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
1005 assert_eq!(deserialized, operations);
1006 }
1007
1008 #[test]
1009 fn test_key_serde() {
1010 let key = create_test_key();
1011 let json = serde_json::to_string(&key).unwrap();
1012 let deserialized: Key = serde_json::from_str(&json).unwrap();
1013
1014 assert_eq!(deserialized.algorithm, key.algorithm);
1015 assert_eq!(deserialized.operations, key.operations);
1016 assert_eq!(deserialized.kid, key.kid);
1017
1018 if let (
1019 KeyType::OCT {
1020 secret: original_secret,
1021 },
1022 KeyType::OCT {
1023 secret: deserialized_secret,
1024 },
1025 ) = (&key.key, &deserialized.key)
1026 {
1027 assert_eq!(deserialized_secret, original_secret);
1028 } else {
1029 panic!("Expected both keys to be OCT variant");
1030 }
1031 }
1032
1033 #[test]
1034 fn test_key_clone() {
1035 let key = create_test_key();
1036 let cloned = key.clone();
1037
1038 assert_eq!(cloned.algorithm, key.algorithm);
1039 assert_eq!(cloned.operations, key.operations);
1040 assert_eq!(cloned.kid, key.kid);
1041
1042 if let (
1043 KeyType::OCT {
1044 secret: original_secret,
1045 },
1046 KeyType::OCT { secret: cloned_secret },
1047 ) = (&key.key, &cloned.key)
1048 {
1049 assert_eq!(cloned_secret, original_secret);
1050 } else {
1051 panic!("Expected both keys to be OCT variant");
1052 }
1053 }
1054
1055 #[test]
1056 fn test_hmac_algorithms() {
1057 let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1058 let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1059 let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
1060
1061 let claims = create_test_claims();
1062
1063 for key in [key_256, key_384, key_512] {
1065 assert!(key.is_ok());
1066 let key = key.unwrap();
1067
1068 let token = key.encode(&claims).unwrap();
1069 let verified_claims = key.decode(&token).unwrap();
1070 assert_eq!(verified_claims.root, claims.root);
1071 }
1072 }
1073
1074 #[test]
1075 fn test_rsa_pkcs1_asymmetric_algorithms() {
1076 let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1077 let key_rs384 = Key::generate(Algorithm::RS384, Some("test-id".to_string()));
1078 let key_rs512 = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
1079
1080 for key in [key_rs256, key_rs384, key_rs512] {
1081 test_asymmetric_key(key);
1082 }
1083 }
1084
1085 #[test]
1086 fn test_rsa_pss_asymmetric_algorithms() {
1087 let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1088 let key_ps384 = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1089 let key_ps512 = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
1090
1091 for key in [key_ps256, key_ps384, key_ps512] {
1092 test_asymmetric_key(key);
1093 }
1094 }
1095
1096 #[test]
1097 fn test_ec_asymmetric_algorithms() {
1098 let key_es256 = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
1099 let key_es384 = Key::generate(Algorithm::ES384, Some("test-id".to_string()));
1100
1101 for key in [key_es256, key_es384] {
1102 test_asymmetric_key(key);
1103 }
1104 }
1105
1106 #[test]
1107 fn test_ed_asymmetric_algorithms() {
1108 let key_eddsa = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
1109
1110 test_asymmetric_key(key_eddsa);
1111 }
1112
1113 fn test_asymmetric_key(key: anyhow::Result<Key>) {
1114 assert!(key.is_ok());
1115 let key = key.unwrap();
1116
1117 let claims = create_test_claims();
1118 let token = key.encode(&claims).unwrap();
1119
1120 let private_verified_claims = key.decode(&token).unwrap();
1121 assert_eq!(
1122 private_verified_claims.root, claims.root,
1123 "validation using private key"
1124 );
1125
1126 let public_verified_claims = key.to_public().unwrap().decode(&token).unwrap();
1127 assert_eq!(public_verified_claims.root, claims.root, "validation using public key");
1128 }
1129
/// A token encoded with an HS256 key must be rejected when decoded with a
/// separately generated HS384 key.
#[test]
fn test_cross_algorithm_verification_fails() {
    // `expect` keeps the generation error visible if key creation ever fails.
    let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()))
        .expect("HS256 key generation should succeed");
    let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()))
        .expect("HS384 key generation should succeed");

    let claims = create_test_claims();
    let token = key_256
        .encode(&claims)
        .expect("encoding with the HS256 key should succeed");

    let result = key_384.decode(&token);
    assert!(
        result.is_err(),
        "HS384 key must not accept a token encoded with an HS256 key"
    );
}
1147
/// A token encoded with an RS256 key must be rejected by a separately
/// generated PS256 key, both when decoding with the full key and when decoding
/// with its public counterpart.
#[test]
fn test_asymmetric_cross_algorithm_verification_fails() {
    // `expect` keeps the generation error visible if key creation ever fails.
    let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()))
        .expect("RS256 key generation should succeed");
    let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()))
        .expect("PS256 key generation should succeed");

    let claims = create_test_claims();
    let token = key_rs256
        .encode(&claims)
        .expect("encoding with the RS256 key should succeed");

    let private_result = key_ps256.decode(&token);
    let public_result = key_ps256
        .to_public()
        .expect("public key conversion should succeed")
        .decode(&token);

    assert!(
        private_result.is_err(),
        "PS256 key must not accept a token encoded with an RS256 key"
    );
    assert!(
        public_result.is_err(),
        "PS256 public key must not accept a token encoded with an RS256 key"
    );
}
1167
/// Converting a generated RS256 key with `to_public()` must drop the `Sign`
/// operation and the private component while keeping `n` and `e` unchanged.
#[test]
fn test_rsa_pkcs1_public_key_conversion() {
    // `expect` keeps the generation error visible if key creation ever fails.
    let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()))
        .expect("RS256 key generation should succeed");

    assert!(key.operations.contains(&KeyOperation::Sign));
    assert!(key.operations.contains(&KeyOperation::Verify));

    let public_key = key
        .to_public()
        .expect("public key conversion should succeed");
    assert!(!public_key.operations.contains(&KeyOperation::Sign));
    assert!(public_key.operations.contains(&KeyOperation::Verify));

    let KeyType::RSA { public, private } = &key.key else {
        panic!("Expected private key to be an RSA key");
    };
    assert!(private.is_some());
    // NOTE(review): presumably byte lengths (2048-bit modulus, 3-byte exponent) - confirm.
    assert_eq!(public.n.len(), 256);
    assert_eq!(public.e.len(), 3);

    let KeyType::RSA {
        public: public_public,
        private: public_private,
    } = &public_key.key
    else {
        panic!("Expected public key to be an RSA key");
    };
    assert!(public_private.is_none());
    assert_eq!(public.n, public_public.n);
    assert_eq!(public.e, public_public.e);
}
1205
/// Converting a generated PS384 key with `to_public()` must drop the `Sign`
/// operation and the private component while keeping `n` and `e` unchanged.
#[test]
fn test_rsa_pss_public_key_conversion() {
    // `expect` keeps the generation error visible if key creation ever fails.
    let key = Key::generate(Algorithm::PS384, Some("test-id".to_string()))
        .expect("PS384 key generation should succeed");

    assert!(key.operations.contains(&KeyOperation::Sign));
    assert!(key.operations.contains(&KeyOperation::Verify));

    let public_key = key
        .to_public()
        .expect("public key conversion should succeed");
    assert!(!public_key.operations.contains(&KeyOperation::Sign));
    assert!(public_key.operations.contains(&KeyOperation::Verify));

    let KeyType::RSA { public, private } = &key.key else {
        panic!("Expected private key to be an RSA key");
    };
    assert!(private.is_some());
    // NOTE(review): presumably byte lengths (2048-bit modulus, 3-byte exponent) - confirm.
    assert_eq!(public.n.len(), 256);
    assert_eq!(public.e.len(), 3);

    let KeyType::RSA {
        public: public_public,
        private: public_private,
    } = &public_key.key
    else {
        panic!("Expected public key to be an RSA key");
    };
    assert!(public_private.is_none());
    assert_eq!(public.n, public_public.n);
    assert_eq!(public.e, public_public.e);
}
1243
/// The serialized `k` field of the test key must be unpadded, URL-safe base64
/// that decodes back to the key's secret bytes.
#[test]
fn test_base64url_serialization() {
    let key = create_test_key();
    let json = serde_json::to_string(&key).unwrap();

    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    let k_value = parsed["k"]
        .as_str()
        .expect("serialized key should contain a string `k` field");

    // No padding, and none of the characters specific to the standard alphabet.
    assert!(!k_value.contains('='));
    assert!(!k_value.contains('+'));
    assert!(!k_value.contains('/'));

    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(k_value)
        .unwrap();

    // Only one key is inspected here, so the failure message refers to a single key.
    let KeyType::OCT {
        secret: original_secret,
    } = &key.key
    else {
        panic!("Expected key to be OCT variant");
    };
    assert_eq!(decoded, *original_secret);
}
1272
/// A JWK whose `k` value is base64url without trailing padding must
/// deserialize into an HS256 OCT key carrying the expected secret bytes.
#[test]
fn test_backwards_compatibility_unpadded_base64url() {
    let jwk_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;

    let parsed_key = serde_json::from_str::<Key>(jwk_json).unwrap();
    assert_eq!(parsed_key.algorithm, Algorithm::HS256);
    assert_eq!(parsed_key.kid, Some("test-key-1".to_string()));

    match &parsed_key.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1289
/// A JWK whose `k` value is base64url with trailing `=` padding must
/// deserialize into an HS256 OCT key carrying the expected secret bytes.
#[test]
fn test_backwards_compatibility_padded_base64url() {
    let jwk_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;

    let parsed_key = serde_json::from_str::<Key>(jwk_json).unwrap();
    assert_eq!(parsed_key.algorithm, Algorithm::HS256);
    assert_eq!(parsed_key.kid, Some("test-key-1".to_string()));

    match &parsed_key.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1306
/// Writes the test key with `to_file`, checks that the file content is
/// base64url-wrapped JSON rather than raw JSON, and checks that `from_file`
/// loads a key with the same algorithm, operations, key id and secret.
#[test]
fn test_file_io_base64url() {
    // Removes the temporary file on drop, so it is cleaned up even when an
    // assertion below panics.
    struct TempFileGuard(std::path::PathBuf);

    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure to delete must not mask the test result.
            std::fs::remove_file(&self.0).ok();
        }
    }

    let key = create_test_key();

    // The process id keeps concurrent test processes sharing one temp dir
    // from reading or deleting each other's file.
    let file_name = format!("test_jwk_{}.key", std::process::id());
    let temp_file = TempFileGuard(std::env::temp_dir().join(file_name));
    let temp_path = &temp_file.0;

    key.to_file(temp_path).unwrap();

    let contents = std::fs::read_to_string(temp_path).unwrap();

    // Raw JSON punctuation must not appear in the stored form.
    assert!(!contents.contains('{'));
    assert!(!contents.contains('}'));
    assert!(!contents.contains('"'));

    // The stored form must be unpadded base64url wrapping valid UTF-8 JSON.
    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(&contents)
        .unwrap();
    let json_str = String::from_utf8(decoded).unwrap();
    let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();

    let loaded_key = Key::from_file(temp_path).unwrap();
    assert_eq!(loaded_key.algorithm, key.algorithm);
    assert_eq!(loaded_key.operations, key.operations);
    assert_eq!(loaded_key.kid, key.kid);

    let (
        KeyType::OCT {
            secret: original_secret,
        },
        KeyType::OCT {
            secret: loaded_secret,
        },
    ) = (&key.key, &loaded_key.key)
    else {
        panic!("Expected both keys to be OCT variant");
    };
    assert_eq!(loaded_secret, original_secret);
}
1352}