1use crate::generate::generate;
2use crate::{Algorithm, Claims};
3use anyhow::{bail, Context};
4use base64::Engine;
5use elliptic_curve::pkcs8::EncodePrivateKey;
6use elliptic_curve::SecretKey;
7use jsonwebtoken::{DecodingKey, EncodingKey, Header};
8use rsa::pkcs1::EncodeRsaPrivateKey;
9use rsa::BigUint;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::sync::OnceLock;
12use std::{collections::HashSet, fmt, path::Path as StdPath};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
16#[serde(rename_all = "camelCase")]
17pub enum KeyOperation {
18 Sign,
19 Verify,
20 Decrypt,
21 Encrypt,
22}
23
24#[derive(Clone, Serialize, Deserialize)]
26#[serde(tag = "kty")]
27pub enum KeyType {
28 EC {
30 #[serde(rename = "crv")]
31 curve: EllipticCurve,
32 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
34 x: Vec<u8>,
35 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
37 y: Vec<u8>,
38 #[serde(
40 default,
41 skip_serializing_if = "Option::is_none",
42 serialize_with = "serialize_base64url_optional",
43 deserialize_with = "deserialize_base64url_optional"
44 )]
45 d: Option<Vec<u8>>,
46 },
47 RSA {
49 #[serde(flatten)]
50 public: RsaPublicKey,
51 #[serde(flatten, skip_serializing_if = "Option::is_none")]
52 private: Option<RsaPrivateKey>,
53 },
54 #[serde(rename = "oct")]
56 OCT {
57 #[serde(
59 rename = "k",
60 default,
61 serialize_with = "serialize_base64url",
62 deserialize_with = "deserialize_base64url"
63 )]
64 secret: Vec<u8>,
65 },
66 OKP {
68 #[serde(rename = "crv")]
69 curve: EllipticCurve,
70 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
71 x: Vec<u8>,
72 #[serde(
73 rename = "d",
74 default,
75 skip_serializing_if = "Option::is_none",
76 serialize_with = "serialize_base64url_optional",
77 deserialize_with = "deserialize_base64url_optional"
78 )]
79 d: Option<Vec<u8>>,
80 },
81}
82
83#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
87pub enum EllipticCurve {
88 #[serde(rename = "P-256")]
89 P256,
90 #[serde(rename = "P-384")]
91 P384,
92 #[serde(rename = "Ed25519")]
96 Ed25519,
97}
98
99#[derive(Clone, Serialize, Deserialize)]
103pub struct RsaPublicKey {
104 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
105 pub n: Vec<u8>,
106 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
107 pub e: Vec<u8>,
108}
109
110#[derive(Clone, Serialize, Deserialize)]
114pub struct RsaPrivateKey {
115 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
116 pub d: Vec<u8>,
117 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
118 pub p: Vec<u8>,
119 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
120 pub q: Vec<u8>,
121 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
122 pub dp: Vec<u8>,
123 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
124 pub dq: Vec<u8>,
125 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
126 pub qi: Vec<u8>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 pub oth: Option<Vec<RsaAdditionalPrime>>,
129}
130
131#[derive(Clone, Serialize, Deserialize)]
133pub struct RsaAdditionalPrime {
134 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
135 pub r: Vec<u8>,
136 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
137 pub d: Vec<u8>,
138 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
139 pub t: Vec<u8>,
140}
141
142#[derive(Clone, Serialize, Deserialize)]
145#[serde(remote = "Self")]
146pub struct Key {
147 #[serde(rename = "alg")]
149 pub algorithm: Algorithm,
150
151 #[serde(rename = "key_ops")]
153 pub operations: HashSet<KeyOperation>,
154
155 #[serde(flatten)]
157 pub key: KeyType,
158
159 #[serde(skip_serializing_if = "Option::is_none")]
161 pub kid: Option<String>,
162
163 #[serde(skip)]
165 pub(crate) decode: OnceLock<DecodingKey>,
166
167 #[serde(skip)]
168 pub(crate) encode: OnceLock<EncodingKey>,
169}
170
171impl<'de> Deserialize<'de> for Key {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 let mut value = serde_json::Value::deserialize(deserializer)?;
177
178 if let Some(obj) = value.as_object_mut() {
182 if !obj.contains_key("kty") {
183 obj.insert("kty".to_string(), serde_json::Value::String("oct".to_string()));
184 }
185 }
186
187 Self::deserialize(value).map_err(serde::de::Error::custom)
188 }
189}
190
191impl Serialize for Key {
192 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193 where
194 S: Serializer,
195 {
196 Self::serialize(self, serializer)
197 }
198}
199
200impl fmt::Debug for Key {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("Key")
203 .field("algorithm", &self.algorithm)
204 .field("operations", &self.operations)
205 .field("kid", &self.kid)
206 .finish()
207 }
208}
209
210impl Key {
211 #[allow(clippy::should_implement_trait)]
212 pub fn from_str(s: &str) -> anyhow::Result<Self> {
213 Ok(serde_json::from_str(s)?)
214 }
215
216 pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
217 let contents = std::fs::read_to_string(&path)?;
218 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
220 let json = String::from_utf8(decoded)?;
221 Ok(serde_json::from_str(&json)?)
222 }
223
224 pub fn to_str(&self) -> anyhow::Result<String> {
225 Ok(serde_json::to_string(self)?)
226 }
227
228 pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
229 let json = serde_json::to_string(self)?;
231 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
233 std::fs::write(path, encoded)?;
234 Ok(())
235 }
236
237 pub fn to_public(&self) -> anyhow::Result<Self> {
238 if !self.operations.contains(&KeyOperation::Verify) {
239 return Err(anyhow::anyhow!("This key doesn't support the Verify operation"));
240 }
241
242 let key = match self.key {
243 KeyType::RSA { ref public, .. } => Ok(KeyType::RSA {
244 public: public.clone(),
245 private: None,
246 }),
247 KeyType::EC {
248 ref x,
249 ref y,
250 ref curve,
251 ..
252 } => Ok(KeyType::EC {
253 x: x.clone(),
254 y: y.clone(),
255 curve: curve.clone(),
256 d: None,
257 }),
258 KeyType::OCT { .. } => Err(anyhow::anyhow!("OCT key cannot be converted to public key")),
259 KeyType::OKP { ref x, ref curve, .. } => Ok(KeyType::OKP {
260 x: x.clone(),
261 curve: curve.clone(),
262 d: None,
263 }),
264 };
265
266 match key {
267 Ok(key) => Ok(Self {
268 algorithm: self.algorithm,
269 operations: [KeyOperation::Verify].into(),
270 key,
271 kid: self.kid.clone(),
272 decode: Default::default(),
273 encode: Default::default(),
274 }),
275 Err(err) => Err(anyhow::anyhow!("Failed to convert key: {}", err)),
276 }
277 }
278
279 fn to_decoding_key(&self) -> anyhow::Result<&DecodingKey> {
280 if let Some(key) = self.decode.get() {
281 return Ok(key);
282 }
283
284 let decoding_key = match self.key {
285 KeyType::OCT { ref secret } => match self.algorithm {
286 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret),
287 _ => bail!("Invalid algorithm for key type OCT"),
288 },
289 KeyType::EC {
290 ref curve,
291 ref x,
292 ref y,
293 ..
294 } => match curve {
295 EllipticCurve::P256 => {
296 if self.algorithm != Algorithm::ES256 {
297 bail!("Invalid algorithm for P-256 curve");
298 }
299 if x.len() != 32 || y.len() != 32 {
300 bail!("Invalid coordinate length for P-256");
301 }
302
303 DecodingKey::from_ec_components(
304 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
305 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
306 )?
307 }
308 EllipticCurve::P384 => {
309 if self.algorithm != Algorithm::ES384 {
310 bail!("Invalid algorithm for P-384 curve");
311 }
312 if x.len() != 48 || y.len() != 48 {
313 bail!("Invalid coordinate length for P-384");
314 }
315
316 DecodingKey::from_ec_components(
317 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
318 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
319 )?
320 }
321 _ => bail!("Invalid curve for EC key"),
322 },
323 KeyType::OKP { ref curve, ref x, .. } => match curve {
324 EllipticCurve::Ed25519 => {
325 if self.algorithm != Algorithm::EdDSA {
326 bail!("Invalid algorithm for Ed25519 curve");
327 }
328
329 DecodingKey::from_ed_components(
330 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
331 )?
332 }
333 _ => bail!("Invalid curve for OKP key"),
334 },
335 KeyType::RSA { ref public, .. } => {
336 DecodingKey::from_rsa_raw_components(public.n.as_ref(), public.e.as_ref())
337 }
338 };
339
340 Ok(self.decode.get_or_init(|| decoding_key))
341 }
342
343 fn to_encoding_key(&self) -> anyhow::Result<&EncodingKey> {
344 if let Some(key) = self.encode.get() {
345 return Ok(key);
346 }
347
348 let encoding_key = match self.key {
349 KeyType::OCT { ref secret } => match self.algorithm {
350 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret),
351 _ => bail!("Invalid algorithm for key type OCT"),
352 },
353 KeyType::EC { ref curve, ref d, .. } => {
354 let d = d.as_ref().context("Missing private key")?;
355
356 match curve {
357 EllipticCurve::P256 => {
358 let secret_key = SecretKey::<p256::NistP256>::from_slice(d)?;
359 let doc = secret_key.to_pkcs8_der()?;
360 EncodingKey::from_ec_der(doc.as_bytes())
361 }
362 EllipticCurve::P384 => {
363 let secret_key = SecretKey::<p384::NistP384>::from_slice(d)?;
364 let doc = secret_key.to_pkcs8_der()?;
365 EncodingKey::from_ec_der(doc.as_bytes())
366 }
367 _ => bail!("Invalid curve for EC key"),
368 }
369 }
370 KeyType::OKP {
371 ref curve,
372 ref d,
373 ref x,
374 } => {
375 let d = d.as_ref().context("Missing private key")?;
376
377 let key_pair =
378 aws_lc_rs::signature::Ed25519KeyPair::from_seed_and_public_key(d.as_slice(), x.as_slice())?;
379
380 match curve {
381 EllipticCurve::Ed25519 => EncodingKey::from_ed_der(key_pair.to_pkcs8()?.as_ref()),
382 _ => bail!("Invalid curve for OKP key"),
383 }
384 }
385 KeyType::RSA {
386 ref public,
387 ref private,
388 } => {
389 let n = BigUint::from_bytes_be(&public.n);
390 let e = BigUint::from_bytes_be(&public.e);
391 let private = private.as_ref().context("Missing private key")?;
392 let d = BigUint::from_bytes_be(&private.d);
393 let p = BigUint::from_bytes_be(&private.p);
394 let q = BigUint::from_bytes_be(&private.q);
395
396 let rsa = rsa::RsaPrivateKey::from_components(n, e, d, vec![p, q]);
397 let pem = rsa?.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF);
398
399 EncodingKey::from_rsa_pem(pem?.as_bytes())?
400 }
401 };
402
403 Ok(self.encode.get_or_init(|| encoding_key))
404 }
405
406 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
407 if !self.operations.contains(&KeyOperation::Verify) {
408 bail!("key does not support verification");
409 }
410
411 let decode: anyhow::Result<&DecodingKey> = self.to_decoding_key();
412
413 match decode {
414 Ok(decode) => {
415 let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
416 validation.required_spec_claims = Default::default(); let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
419 token.claims.validate()?;
420
421 Ok(token.claims)
422 }
423 Err(e) => Err(anyhow::anyhow!("Failed to decode key: {}", e)),
424 }
425 }
426
427 pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
428 if !self.operations.contains(&KeyOperation::Sign) {
429 bail!("key does not support signing");
430 }
431
432 payload.validate()?;
433
434 let encode: anyhow::Result<&EncodingKey> = self.to_encoding_key();
435
436 match encode {
437 Ok(encode) => {
438 let mut header = Header::new(self.algorithm.into());
439 header.kid = self.kid.clone();
440 let token = jsonwebtoken::encode(&header, &payload, encode)?;
441 Ok(token)
442 }
443 Err(e) => Err(anyhow::anyhow!("Failed to encode key: {}", e)),
444 }
445 }
446
447 pub fn generate(algorithm: Algorithm, id: Option<String>) -> anyhow::Result<Self> {
449 generate(algorithm, id)
450 }
451}
452
453fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
455where
456 S: Serializer,
457{
458 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
459 serializer.serialize_str(&encoded)
460}
461
462fn serialize_base64url_optional<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
463where
464 S: Serializer,
465{
466 match bytes {
467 Some(b) => serialize_base64url(b, serializer),
468 None => serializer.serialize_none(),
469 }
470}
471
472fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
474where
475 D: Deserializer<'de>,
476{
477 let s = String::deserialize(deserializer)?;
478
479 base64::engine::general_purpose::URL_SAFE_NO_PAD
481 .decode(&s)
482 .or_else(|_| {
483 base64::engine::general_purpose::URL_SAFE.decode(&s)
485 })
486 .map_err(serde::de::Error::custom)
487}
488
489fn deserialize_base64url_optional<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
490where
491 D: Deserializer<'de>,
492{
493 let s: Option<String> = Option::deserialize(deserializer)?;
494 match s {
495 Some(s) => {
496 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
497 .decode(&s)
498 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(&s))
499 .map_err(serde::de::Error::custom)?;
500 Ok(Some(decoded))
501 }
502 None => Ok(None),
503 }
504}
505
506#[cfg(test)]
507mod tests {
508 use super::*;
509 use std::time::{Duration, SystemTime};
510
511 fn create_test_key() -> Key {
512 Key {
513 algorithm: Algorithm::HS256,
514 operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
515 key: KeyType::OCT {
516 secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
517 },
518 kid: Some("test-key-1".to_string()),
519 decode: Default::default(),
520 encode: Default::default(),
521 }
522 }
523
524 fn create_test_claims() -> Claims {
525 Claims {
526 root: "test-path".to_string(),
527 publish: vec!["test-pub".into()],
528 cluster: false,
529 subscribe: vec!["test-sub".into()],
530 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
531 issued: Some(SystemTime::now()),
532 }
533 }
534
535 #[test]
536 fn test_key_from_str_valid() {
537 let key = create_test_key();
538 let json = key.to_str().unwrap();
539 let loaded_key = Key::from_str(&json).unwrap();
540
541 assert_eq!(loaded_key.algorithm, key.algorithm);
542 assert_eq!(loaded_key.operations, key.operations);
543 match (loaded_key.key, key.key) {
544 (KeyType::OCT { secret: loaded_secret }, KeyType::OCT { secret }) => {
545 assert_eq!(loaded_secret, secret);
546 }
547 _ => panic!("Expected OCT key"),
548 }
549 assert_eq!(loaded_key.kid, key.kid);
550 }
551
552 #[test]
554 fn test_key_oct_backwards_compatibility() {
555 let json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68"}"#;
556 let key = Key::from_str(json);
557
558 assert!(key.is_ok());
559 let key = key.unwrap();
560
561 if let KeyType::OCT { ref secret, .. } = key.key {
562 let base64_key = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret);
563 assert_eq!(base64_key, "Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68");
564 } else {
565 panic!("Expected OCT key");
566 }
567
568 let key_str = key.to_str();
569 assert!(key_str.is_ok());
570 let key_str = key_str.unwrap();
571
572 assert!(key_str.contains("\"alg\":\"HS256\""));
574 assert!(key_str.contains("\"key_ops\""));
575 assert!(key_str.contains("\"sign\""));
576 assert!(key_str.contains("\"verify\""));
577 assert!(key_str.contains("\"kty\":\"oct\""));
578 }
579
580 #[test]
581 fn test_key_from_str_invalid_json() {
582 let result = Key::from_str("invalid json");
583 assert!(result.is_err());
584 }
585
586 #[test]
587 fn test_key_to_str() {
588 let key = create_test_key();
589 let json = key.to_str().unwrap();
590 assert!(json.contains("\"alg\":\"HS256\""));
591 assert!(json.contains("\"key_ops\""));
592 assert!(json.contains("\"sign\""));
593 assert!(json.contains("\"verify\""));
594 assert!(json.contains("\"kid\":\"test-key-1\""));
595 assert!(json.contains("\"kty\":\"oct\""));
596 }
597
598 #[test]
599 fn test_key_sign_success() {
600 let key = create_test_key();
601 let claims = create_test_claims();
602 let token = key.encode(&claims).unwrap();
603
604 assert!(!token.is_empty());
605 assert_eq!(token.matches('.').count(), 2); }
607
608 #[test]
609 fn test_key_sign_no_permission() {
610 let mut key = create_test_key();
611 key.operations = [KeyOperation::Verify].into();
612 let claims = create_test_claims();
613
614 let result = key.encode(&claims);
615 assert!(result.is_err());
616 assert!(result.unwrap_err().to_string().contains("key does not support signing"));
617 }
618
619 #[test]
620 fn test_key_sign_invalid_claims() {
621 let key = create_test_key();
622 let invalid_claims = Claims {
623 root: "test-path".to_string(),
624 publish: vec![],
625 subscribe: vec![],
626 cluster: false,
627 expires: None,
628 issued: None,
629 };
630
631 let result = key.encode(&invalid_claims);
632 assert!(result.is_err());
633 assert!(result
634 .unwrap_err()
635 .to_string()
636 .contains("no publish or subscribe allowed; token is useless"));
637 }
638
639 #[test]
640 fn test_key_verify_success() {
641 let key = create_test_key();
642 let claims = create_test_claims();
643 let token = key.encode(&claims).unwrap();
644
645 let verified_claims = key.decode(&token).unwrap();
646 assert_eq!(verified_claims.root, claims.root);
647 assert_eq!(verified_claims.publish, claims.publish);
648 assert_eq!(verified_claims.subscribe, claims.subscribe);
649 assert_eq!(verified_claims.cluster, claims.cluster);
650 }
651
652 #[test]
653 fn test_key_verify_no_permission() {
654 let mut key = create_test_key();
655 key.operations = [KeyOperation::Sign].into();
656
657 let result = key.decode("some.jwt.token");
658 assert!(result.is_err());
659 assert!(result
660 .unwrap_err()
661 .to_string()
662 .contains("key does not support verification"));
663 }
664
665 #[test]
666 fn test_key_verify_invalid_token() {
667 let key = create_test_key();
668 let result = key.decode("invalid-token");
669 assert!(result.is_err());
670 }
671
672 #[test]
673 fn test_key_verify_path_mismatch() {
674 let key = create_test_key();
675 let claims = create_test_claims();
676 let token = key.encode(&claims).unwrap();
677
678 let result = key.decode(&token);
680 assert!(result.is_ok());
681 }
682
683 #[test]
684 fn test_key_verify_expired_token() {
685 let key = create_test_key();
686 let mut claims = create_test_claims();
687 claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); let token = key.encode(&claims).unwrap();
689
690 let result = key.decode(&token);
691 assert!(result.is_err());
692 }
693
694 #[test]
695 fn test_key_verify_token_without_exp() {
696 let key = create_test_key();
697 let claims = Claims {
698 root: "test-path".to_string(),
699 publish: vec!["".to_string()],
700 subscribe: vec!["".to_string()],
701 cluster: false,
702 expires: None,
703 issued: None,
704 };
705 let token = key.encode(&claims).unwrap();
706
707 let verified_claims = key.decode(&token).unwrap();
708 assert_eq!(verified_claims.root, claims.root);
709 assert_eq!(verified_claims.publish, claims.publish);
710 assert_eq!(verified_claims.subscribe, claims.subscribe);
711 assert_eq!(verified_claims.expires, None);
712 }
713
714 #[test]
715 fn test_key_round_trip() {
716 let key = create_test_key();
717 let original_claims = Claims {
718 root: "test-path".to_string(),
719 publish: vec!["test-pub".into()],
720 subscribe: vec!["test-sub".into()],
721 cluster: true,
722 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
723 issued: Some(SystemTime::now()),
724 };
725
726 let token = key.encode(&original_claims).unwrap();
727 let verified_claims = key.decode(&token).unwrap();
728
729 assert_eq!(verified_claims.root, original_claims.root);
730 assert_eq!(verified_claims.publish, original_claims.publish);
731 assert_eq!(verified_claims.subscribe, original_claims.subscribe);
732 assert_eq!(verified_claims.cluster, original_claims.cluster);
733 }
734
735 #[test]
736 fn test_key_generate_hs256() {
737 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
738 assert!(key.is_ok());
739 let key = key.unwrap();
740
741 assert_eq!(key.algorithm, Algorithm::HS256);
742 assert_eq!(key.kid, Some("test-id".to_string()));
743 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
744
745 match key.key {
746 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 32),
747 _ => panic!("Expected OCT key"),
748 }
749 }
750
751 #[test]
752 fn test_key_generate_hs384() {
753 let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
754 assert!(key.is_ok());
755 let key = key.unwrap();
756
757 assert_eq!(key.algorithm, Algorithm::HS384);
758
759 match key.key {
760 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 48),
761 _ => panic!("Expected OCT key"),
762 }
763 }
764
765 #[test]
766 fn test_key_generate_hs512() {
767 let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
768 assert!(key.is_ok());
769 let key = key.unwrap();
770
771 assert_eq!(key.algorithm, Algorithm::HS512);
772
773 match key.key {
774 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 64),
775 _ => panic!("Expected OCT key"),
776 }
777 }
778
779 #[test]
780 fn test_key_generate_rs512() {
781 let key = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
782 assert!(key.is_ok());
783 let key = key.unwrap();
784
785 assert_eq!(key.algorithm, Algorithm::RS512);
786 assert!(matches!(key.key, KeyType::RSA { .. }));
787 match key.key {
788 KeyType::RSA {
789 ref public,
790 ref private,
791 } => {
792 assert!(private.is_some());
793 assert_eq!(public.n.len(), 256);
794 assert_eq!(public.e.len(), 3);
795 }
796 _ => panic!("Expected RSA key"),
797 }
798 }
799
800 #[test]
801 fn test_key_generate_es256() {
802 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
803 assert!(key.is_ok());
804 let key = key.unwrap();
805
806 assert_eq!(key.algorithm, Algorithm::ES256);
807 assert!(matches!(key.key, KeyType::EC { .. }))
808 }
809
810 #[test]
811 fn test_key_generate_ps512() {
812 let key = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
813 assert!(key.is_ok());
814 let key = key.unwrap();
815
816 assert_eq!(key.algorithm, Algorithm::PS512);
817 assert!(matches!(key.key, KeyType::RSA { .. }));
818 }
819
820 #[test]
821 fn test_key_generate_eddsa() {
822 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
823 assert!(key.is_ok());
824 let key = key.unwrap();
825
826 assert_eq!(key.algorithm, Algorithm::EdDSA);
827 assert!(matches!(key.key, KeyType::OKP { .. }));
828 }
829
830 #[test]
831 fn test_key_generate_without_id() {
832 let key = Key::generate(Algorithm::HS256, None);
833 assert!(key.is_ok());
834 let key = key.unwrap();
835
836 assert_eq!(key.algorithm, Algorithm::HS256);
837 assert_eq!(key.kid, None);
838 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
839 }
840
841 #[test]
842 fn test_public_key_conversion_hmac() {
843 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string())).expect("HMAC key generation failed");
844
845 assert!(key.to_public().is_err());
846 }
847
848 #[test]
849 fn test_public_key_conversion_rsa() {
850 let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
851 assert!(key.is_ok());
852 let key = key.unwrap();
853
854 let public_key = key.to_public().unwrap();
855 assert_eq!(key.kid, public_key.kid);
856 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
857 assert!(public_key.encode.get().is_none());
858 assert!(public_key.decode.get().is_none());
859 assert!(matches!(public_key.key, KeyType::RSA { .. }));
860
861 if let KeyType::RSA { public, private } = &public_key.key {
862 assert!(private.is_none());
863
864 if let KeyType::RSA { public: src_public, .. } = &key.key {
865 assert_eq!(public.e, src_public.e);
866 assert_eq!(public.n, src_public.n);
867 } else {
868 unreachable!("Expected RSA key")
869 }
870 } else {
871 unreachable!("Expected RSA key");
872 }
873 }
874
875 #[test]
876 fn test_public_key_conversion_es() {
877 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
878 assert!(key.is_ok());
879 let key = key.unwrap();
880
881 let public_key = key.to_public().unwrap();
882 assert_eq!(key.kid, public_key.kid);
883 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
884 assert!(public_key.encode.get().is_none());
885 assert!(public_key.decode.get().is_none());
886 assert!(matches!(public_key.key, KeyType::EC { .. }));
887
888 if let KeyType::EC { x, y, d, curve } = &public_key.key {
889 assert!(d.is_none());
890
891 if let KeyType::EC {
892 x: src_x,
893 y: src_y,
894 curve: src_curve,
895 ..
896 } = &key.key
897 {
898 assert_eq!(x, src_x);
899 assert_eq!(y, src_y);
900 assert_eq!(curve, src_curve);
901 } else {
902 unreachable!("Expected EC key")
903 }
904 } else {
905 unreachable!("Expected EC key");
906 }
907 }
908
909 #[test]
910 fn test_public_key_conversion_ed() {
911 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
912 assert!(key.is_ok());
913 let key = key.unwrap();
914
915 let public_key = key.to_public().unwrap();
916 assert_eq!(key.kid, public_key.kid);
917 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
918 assert!(public_key.encode.get().is_none());
919 assert!(public_key.decode.get().is_none());
920 assert!(matches!(public_key.key, KeyType::OKP { .. }));
921
922 if let KeyType::OKP { x, d, curve } = &public_key.key {
923 assert!(d.is_none());
924
925 if let KeyType::OKP {
926 x: src_x,
927 curve: src_curve,
928 ..
929 } = &key.key
930 {
931 assert_eq!(x, src_x);
932 assert_eq!(curve, src_curve);
933 } else {
934 unreachable!("Expected OKP key")
935 }
936 } else {
937 unreachable!("Expected OKP key");
938 }
939 }
940
941 #[test]
942 fn test_key_generate_sign_verify_cycle() {
943 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
944 assert!(key.is_ok());
945 let key = key.unwrap();
946
947 let claims = create_test_claims();
948
949 let token = key.encode(&claims).unwrap();
950 let verified_claims = key.decode(&token).unwrap();
951
952 assert_eq!(verified_claims.root, claims.root);
953 assert_eq!(verified_claims.publish, claims.publish);
954 assert_eq!(verified_claims.subscribe, claims.subscribe);
955 assert_eq!(verified_claims.cluster, claims.cluster);
956 }
957
958 #[test]
959 fn test_key_debug_no_secret() {
960 let key = create_test_key();
961 let debug_str = format!("{key:?}");
962
963 assert!(debug_str.contains("algorithm: HS256"));
964 assert!(debug_str.contains("operations"));
965 assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
966 assert!(!debug_str.contains("secret")); }
968
969 #[test]
970 fn test_key_operations_enum() {
971 let sign_op = KeyOperation::Sign;
972 let verify_op = KeyOperation::Verify;
973 let decrypt_op = KeyOperation::Decrypt;
974 let encrypt_op = KeyOperation::Encrypt;
975
976 assert_eq!(sign_op, KeyOperation::Sign);
977 assert_eq!(verify_op, KeyOperation::Verify);
978 assert_eq!(decrypt_op, KeyOperation::Decrypt);
979 assert_eq!(encrypt_op, KeyOperation::Encrypt);
980
981 assert_ne!(sign_op, verify_op);
982 assert_ne!(decrypt_op, encrypt_op);
983 }
984
985 #[test]
986 fn test_key_operations_serde() {
987 let operations = [KeyOperation::Sign, KeyOperation::Verify];
988 let json = serde_json::to_string(&operations).unwrap();
989 assert!(json.contains("\"sign\""));
990 assert!(json.contains("\"verify\""));
991
992 let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
993 assert_eq!(deserialized, operations);
994 }
995
996 #[test]
997 fn test_key_serde() {
998 let key = create_test_key();
999 let json = serde_json::to_string(&key).unwrap();
1000 let deserialized: Key = serde_json::from_str(&json).unwrap();
1001
1002 assert_eq!(deserialized.algorithm, key.algorithm);
1003 assert_eq!(deserialized.operations, key.operations);
1004 assert_eq!(deserialized.kid, key.kid);
1005
1006 if let (
1007 KeyType::OCT {
1008 secret: original_secret,
1009 },
1010 KeyType::OCT {
1011 secret: deserialized_secret,
1012 },
1013 ) = (&key.key, &deserialized.key)
1014 {
1015 assert_eq!(deserialized_secret, original_secret);
1016 } else {
1017 panic!("Expected both keys to be OCT variant");
1018 }
1019 }
1020
1021 #[test]
1022 fn test_key_clone() {
1023 let key = create_test_key();
1024 let cloned = key.clone();
1025
1026 assert_eq!(cloned.algorithm, key.algorithm);
1027 assert_eq!(cloned.operations, key.operations);
1028 assert_eq!(cloned.kid, key.kid);
1029
1030 if let (
1031 KeyType::OCT {
1032 secret: original_secret,
1033 },
1034 KeyType::OCT { secret: cloned_secret },
1035 ) = (&key.key, &cloned.key)
1036 {
1037 assert_eq!(cloned_secret, original_secret);
1038 } else {
1039 panic!("Expected both keys to be OCT variant");
1040 }
1041 }
1042
1043 #[test]
1044 fn test_hmac_algorithms() {
1045 let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1046 let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1047 let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
1048
1049 let claims = create_test_claims();
1050
1051 for key in [key_256, key_384, key_512] {
1053 assert!(key.is_ok());
1054 let key = key.unwrap();
1055
1056 let token = key.encode(&claims).unwrap();
1057 let verified_claims = key.decode(&token).unwrap();
1058 assert_eq!(verified_claims.root, claims.root);
1059 }
1060 }
1061
1062 #[test]
1063 fn test_rsa_pkcs1_asymmetric_algorithms() {
1064 let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1065 let key_rs384 = Key::generate(Algorithm::RS384, Some("test-id".to_string()));
1066 let key_rs512 = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
1067
1068 for key in [key_rs256, key_rs384, key_rs512] {
1069 test_asymmetric_key(key);
1070 }
1071 }
1072
1073 #[test]
1074 fn test_rsa_pss_asymmetric_algorithms() {
1075 let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1076 let key_ps384 = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1077 let key_ps512 = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
1078
1079 for key in [key_ps256, key_ps384, key_ps512] {
1080 test_asymmetric_key(key);
1081 }
1082 }
1083
1084 #[test]
1085 fn test_ec_asymmetric_algorithms() {
1086 let key_es256 = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
1087 let key_es384 = Key::generate(Algorithm::ES384, Some("test-id".to_string()));
1088
1089 for key in [key_es256, key_es384] {
1090 test_asymmetric_key(key);
1091 }
1092 }
1093
1094 #[test]
1095 fn test_ed_asymmetric_algorithms() {
1096 let key_eddsa = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
1097
1098 test_asymmetric_key(key_eddsa);
1099 }
1100
1101 fn test_asymmetric_key(key: anyhow::Result<Key>) {
1102 assert!(key.is_ok());
1103 let key = key.unwrap();
1104
1105 let claims = create_test_claims();
1106 let token = key.encode(&claims).unwrap();
1107
1108 let private_verified_claims = key.decode(&token).unwrap();
1109 assert_eq!(
1110 private_verified_claims.root, claims.root,
1111 "validation using private key"
1112 );
1113
1114 let public_verified_claims = key.to_public().unwrap().decode(&token).unwrap();
1115 assert_eq!(public_verified_claims.root, claims.root, "validation using public key");
1116 }
1117
/// An HS256 token must not verify under an HS384 key.
///
/// Two HS384 keys are checked:
/// - An independently generated one.
/// - One that shares the HS256 key's secret and differs only in `alg`.
///
/// Independently generated HMAC keys have different secrets, so the first check would fail
/// on the signature even if the algorithm were never validated. The second check can only
/// pass if the decoder rejects the algorithm mismatch itself.
#[test]
fn test_cross_algorithm_verification_fails() {
    let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()))
        .expect("HS256 key generation should succeed");
    let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()))
        .expect("HS384 key generation should succeed");

    let claims = create_test_claims();
    let token = key_256.encode(&claims).unwrap();

    assert!(
        key_384.decode(&token).is_err(),
        "HS256 token must not verify under an unrelated HS384 key"
    );

    // Rebuild the HS256 key with only its `alg` changed, so the secret is identical.
    let mut json = serde_json::to_value(&key_256).unwrap();
    json["alg"] = serde_json::Value::from("HS384");
    let same_secret_384: Key = serde_json::from_str(&json.to_string()).unwrap();
    assert!(
        same_secret_384.decode(&token).is_err(),
        "HS256 token must not verify under an HS384 key with the same secret"
    );
}
1135
/// An RS256 (PKCS#1 v1.5) token must not verify under a PS256 (RSA-PSS) key.
///
/// Two PS256 keys are checked, each as the private key and as its `to_public()` half:
/// - An independently generated one.
/// - One that shares the RS256 key's RSA material and differs only in `alg`.
///
/// Independent key pairs reject the token on the signature alone. The shared-material key
/// isolates the algorithm check.
#[test]
fn test_asymmetric_cross_algorithm_verification_fails() {
    let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()))
        .expect("RS256 key generation should succeed");
    let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()))
        .expect("PS256 key generation should succeed");

    let claims = create_test_claims();
    let token = key_rs256.encode(&claims).unwrap();

    assert!(
        key_ps256.decode(&token).is_err(),
        "RS256 token must not verify under an unrelated PS256 private key"
    );
    assert!(
        key_ps256.to_public().unwrap().decode(&token).is_err(),
        "RS256 token must not verify under an unrelated PS256 public key"
    );

    // Rebuild the RS256 key with only its `alg` changed, so n/e (and the private part) match.
    let mut json = serde_json::to_value(&key_rs256).unwrap();
    json["alg"] = serde_json::Value::from("PS256");
    let same_material_ps256: Key = serde_json::from_str(&json.to_string()).unwrap();
    assert!(
        same_material_ps256.decode(&token).is_err(),
        "RS256 token must not verify under a PS256 private key with the same RSA material"
    );
    assert!(
        same_material_ps256.to_public().unwrap().decode(&token).is_err(),
        "RS256 token must not verify under a PS256 public key with the same RSA material"
    );
}
1155
1156 #[test]
1157 fn test_rsa_pkcs1_public_key_conversion() {
1158 let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1159 assert!(key.is_ok());
1160 let key = key.unwrap();
1161
1162 assert!(key.operations.contains(&KeyOperation::Sign));
1163 assert!(key.operations.contains(&KeyOperation::Verify));
1164
1165 let public_key = key.to_public().unwrap();
1166 assert!(!public_key.operations.contains(&KeyOperation::Sign));
1167 assert!(public_key.operations.contains(&KeyOperation::Verify));
1168
1169 match key.key {
1170 KeyType::RSA {
1171 ref public,
1172 ref private,
1173 } => {
1174 assert!(private.is_some());
1175 assert_eq!(public.n.len(), 256);
1176 assert_eq!(public.e.len(), 3);
1177
1178 match public_key.key {
1179 KeyType::RSA {
1180 public: ref public_public,
1181 private: ref public_private,
1182 } => {
1183 assert!(public_private.is_none());
1184 assert_eq!(public.n, public_public.n);
1185 assert_eq!(public.e, public_public.e);
1186 }
1187 _ => panic!("Expected public key to be an RSA key"),
1188 }
1189 }
1190 _ => panic!("Expected private key to be an RSA key"),
1191 }
1192 }
1193
1194 #[test]
1195 fn test_rsa_pss_public_key_conversion() {
1196 let key = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1197 assert!(key.is_ok());
1198 let key = key.unwrap();
1199
1200 assert!(key.operations.contains(&KeyOperation::Sign));
1201 assert!(key.operations.contains(&KeyOperation::Verify));
1202
1203 let public_key = key.to_public().unwrap();
1204 assert!(!public_key.operations.contains(&KeyOperation::Sign));
1205 assert!(public_key.operations.contains(&KeyOperation::Verify));
1206
1207 match key.key {
1208 KeyType::RSA {
1209 ref public,
1210 ref private,
1211 } => {
1212 assert!(private.is_some());
1213 assert_eq!(public.n.len(), 256);
1214 assert_eq!(public.e.len(), 3);
1215
1216 match public_key.key {
1217 KeyType::RSA {
1218 public: ref public_public,
1219 private: ref public_private,
1220 } => {
1221 assert!(public_private.is_none());
1222 assert_eq!(public.n, public_public.n);
1223 assert_eq!(public.e, public_public.e);
1224 }
1225 _ => panic!("Expected public key to be an RSA key"),
1226 }
1227 }
1228 _ => panic!("Expected private key to be an RSA key"),
1229 }
1230 }
1231
/// Serializing an `oct` key must emit `k` as unpadded, URL-safe base64. The checks are:
/// - `k` contains no `=`, `+` or `/`.
/// - `k` decodes back to the original secret.
#[test]
fn test_base64url_serialization() {
    let key = create_test_key();
    let json = serde_json::to_string(&key).unwrap();

    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    let k_value = parsed["k"]
        .as_str()
        .expect("serialized key should have a string `k` field");

    // No padding, and only the URL-safe alphabet (`-`/`_` instead of `+`/`/`).
    assert!(!k_value.contains('='));
    assert!(!k_value.contains('+'));
    assert!(!k_value.contains('/'));

    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(k_value)
        .unwrap();

    // Only one key is inspected here, so the message names a single key.
    let KeyType::OCT {
        secret: original_secret,
    } = &key.key
    else {
        panic!("Expected key to be OCT variant");
    };
    assert_eq!(decoded, *original_secret);
}
1260
/// A JWK whose `k` is base64url *without* trailing `=` padding must deserialize. It must
/// also yield the expected algorithm, kid and secret bytes.
#[test]
fn test_backwards_compatibility_unpadded_base64url() {
    let unpadded_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;
    let parsed_key: Key = serde_json::from_str(unpadded_json).unwrap();

    assert_eq!(parsed_key.algorithm, Algorithm::HS256);
    assert_eq!(parsed_key.kid, Some(String::from("test-key-1")));

    match &parsed_key.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1277
/// A JWK whose `k` is base64url *with* trailing `=` padding must still deserialize. It must
/// decode to the same secret as the unpadded form.
#[test]
fn test_backwards_compatibility_padded_base64url() {
    let padded_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;
    let parsed_key: Key = serde_json::from_str(padded_json).unwrap();

    assert_eq!(parsed_key.algorithm, Algorithm::HS256);
    assert_eq!(parsed_key.kid, Some(String::from("test-key-1")));

    match &parsed_key.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1294
/// `to_file` must write the key as a single base64url string rather than raw JSON. That
/// string must decode to valid JSON, and `from_file` must read back an identical key.
#[test]
fn test_file_io_base64url() {
    /// Removes the wrapped path when dropped. The temp file is therefore cleaned up even
    /// when an assertion below panics.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if `to_file` itself failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let key = create_test_key();
    // PID-suffixed so concurrent test runs sharing the temp dir don't clobber each other.
    let temp_file = TempFile(
        std::env::temp_dir().join(format!("test_jwk_{}.key", std::process::id())),
    );
    let temp_path = &temp_file.0;

    key.to_file(temp_path).unwrap();

    let contents = std::fs::read_to_string(temp_path).unwrap();

    // The file must not contain raw JSON punctuation...
    assert!(!contents.contains('{'));
    assert!(!contents.contains('}'));
    assert!(!contents.contains('"'));

    // ...but must base64url-decode to a JSON document.
    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(&contents)
        .unwrap();
    let json_str = String::from_utf8(decoded).unwrap();
    let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();

    let loaded_key = Key::from_file(temp_path).unwrap();
    assert_eq!(loaded_key.algorithm, key.algorithm);
    assert_eq!(loaded_key.operations, key.operations);
    assert_eq!(loaded_key.kid, key.kid);

    let (
        KeyType::OCT {
            secret: original_secret,
        },
        KeyType::OCT {
            secret: loaded_secret,
        },
    ) = (&key.key, &loaded_key.key)
    else {
        panic!("Expected both keys to be OCT variant");
    };
    assert_eq!(loaded_secret, original_secret);
}
1340}