1use crate::generate::generate;
2use crate::{Algorithm, Claims};
3use anyhow::{Context, bail};
4use base64::Engine;
5use elliptic_curve::SecretKey;
6use elliptic_curve::pkcs8::EncodePrivateKey;
7use jsonwebtoken::{DecodingKey, EncodingKey, Header};
8use rsa::BigUint;
9use rsa::pkcs1::EncodeRsaPrivateKey;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::sync::OnceLock;
12use std::{collections::HashSet, fmt, path::Path as StdPath};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
16#[serde(rename_all = "camelCase")]
17pub enum KeyOperation {
18 Sign,
19 Verify,
20 Decrypt,
21 Encrypt,
22}
23
24#[derive(Clone, Serialize, Deserialize)]
26#[serde(tag = "kty")]
27pub enum KeyType {
28 EC {
30 #[serde(rename = "crv")]
31 curve: EllipticCurve,
32 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
34 x: Vec<u8>,
35 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
37 y: Vec<u8>,
38 #[serde(
40 default,
41 skip_serializing_if = "Option::is_none",
42 serialize_with = "serialize_base64url_optional",
43 deserialize_with = "deserialize_base64url_optional"
44 )]
45 d: Option<Vec<u8>>,
46 },
47 RSA {
49 #[serde(flatten)]
50 public: RsaPublicKey,
51 #[serde(flatten, skip_serializing_if = "Option::is_none")]
52 private: Option<RsaPrivateKey>,
53 },
54 #[serde(rename = "oct")]
56 OCT {
57 #[serde(
59 rename = "k",
60 default,
61 serialize_with = "serialize_base64url",
62 deserialize_with = "deserialize_base64url"
63 )]
64 secret: Vec<u8>,
65 },
66 OKP {
68 #[serde(rename = "crv")]
69 curve: EllipticCurve,
70 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
71 x: Vec<u8>,
72 #[serde(
73 rename = "d",
74 default,
75 skip_serializing_if = "Option::is_none",
76 serialize_with = "serialize_base64url_optional",
77 deserialize_with = "deserialize_base64url_optional"
78 )]
79 d: Option<Vec<u8>>,
80 },
81}
82
83#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
87pub enum EllipticCurve {
88 #[serde(rename = "P-256")]
89 P256,
90 #[serde(rename = "P-384")]
91 P384,
92 #[serde(rename = "Ed25519")]
96 Ed25519,
97}
98
99#[derive(Clone, Serialize, Deserialize)]
103pub struct RsaPublicKey {
104 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
105 pub n: Vec<u8>,
106 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
107 pub e: Vec<u8>,
108}
109
110#[derive(Clone, Serialize, Deserialize)]
114pub struct RsaPrivateKey {
115 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
116 pub d: Vec<u8>,
117 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
118 pub p: Vec<u8>,
119 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
120 pub q: Vec<u8>,
121 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
122 pub dp: Vec<u8>,
123 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
124 pub dq: Vec<u8>,
125 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
126 pub qi: Vec<u8>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 pub oth: Option<Vec<RsaAdditionalPrime>>,
129}
130
131#[derive(Clone, Serialize, Deserialize)]
133pub struct RsaAdditionalPrime {
134 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
135 pub r: Vec<u8>,
136 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
137 pub d: Vec<u8>,
138 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
139 pub t: Vec<u8>,
140}
141
142#[derive(Clone, Serialize, Deserialize)]
145#[serde(remote = "Self")]
146pub struct Key {
147 #[serde(rename = "alg")]
149 pub algorithm: Algorithm,
150
151 #[serde(rename = "key_ops")]
153 pub operations: HashSet<KeyOperation>,
154
155 #[serde(flatten)]
157 pub key: KeyType,
158
159 #[serde(skip_serializing_if = "Option::is_none")]
161 pub kid: Option<String>,
162
163 #[serde(skip)]
165 pub(crate) decode: OnceLock<DecodingKey>,
166
167 #[serde(skip)]
168 pub(crate) encode: OnceLock<EncodingKey>,
169}
170
171impl<'de> Deserialize<'de> for Key {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 let mut value = serde_json::Value::deserialize(deserializer)?;
177
178 if let Some(obj) = value.as_object_mut()
182 && !obj.contains_key("kty")
183 {
184 obj.insert("kty".to_string(), serde_json::Value::String("oct".to_string()));
185 }
186
187 Self::deserialize(value).map_err(serde::de::Error::custom)
188 }
189}
190
191impl Serialize for Key {
192 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193 where
194 S: Serializer,
195 {
196 Self::serialize(self, serializer)
197 }
198}
199
200impl fmt::Debug for Key {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("Key")
203 .field("algorithm", &self.algorithm)
204 .field("operations", &self.operations)
205 .field("kid", &self.kid)
206 .finish()
207 }
208}
209
210impl Key {
211 #[allow(clippy::should_implement_trait)]
212 pub fn from_str(s: &str) -> anyhow::Result<Self> {
213 Ok(serde_json::from_str(s)?)
214 }
215
216 pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
217 let contents = std::fs::read_to_string(&path)?;
218 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
220 let json = String::from_utf8(decoded)?;
221 Ok(serde_json::from_str(&json)?)
222 }
223
224 pub fn to_str(&self) -> anyhow::Result<String> {
225 Ok(serde_json::to_string(self)?)
226 }
227
228 pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
229 let json = serde_json::to_string(self)?;
231 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
233 std::fs::write(path, encoded)?;
234 Ok(())
235 }
236
237 pub fn to_public(&self) -> anyhow::Result<Self> {
238 if !self.operations.contains(&KeyOperation::Verify) {
239 return Err(anyhow::anyhow!("This key doesn't support the Verify operation"));
240 }
241
242 let key = match self.key {
243 KeyType::RSA { ref public, .. } => Ok(KeyType::RSA {
244 public: public.clone(),
245 private: None,
246 }),
247 KeyType::EC {
248 ref x,
249 ref y,
250 ref curve,
251 ..
252 } => Ok(KeyType::EC {
253 x: x.clone(),
254 y: y.clone(),
255 curve: curve.clone(),
256 d: None,
257 }),
258 KeyType::OCT { .. } => Err(anyhow::anyhow!("OCT key cannot be converted to public key")),
259 KeyType::OKP { ref x, ref curve, .. } => Ok(KeyType::OKP {
260 x: x.clone(),
261 curve: curve.clone(),
262 d: None,
263 }),
264 };
265
266 match key {
267 Ok(key) => Ok(Self {
268 algorithm: self.algorithm,
269 operations: [KeyOperation::Verify].into(),
270 key,
271 kid: self.kid.clone(),
272 decode: Default::default(),
273 encode: Default::default(),
274 }),
275 Err(err) => Err(anyhow::anyhow!("Failed to convert key: {}", err)),
276 }
277 }
278
279 fn to_decoding_key(&self) -> anyhow::Result<&DecodingKey> {
280 if let Some(key) = self.decode.get() {
281 return Ok(key);
282 }
283
284 let decoding_key = match self.key {
285 KeyType::OCT { ref secret } => match self.algorithm {
286 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret),
287 _ => bail!("Invalid algorithm for key type OCT"),
288 },
289 KeyType::EC {
290 ref curve,
291 ref x,
292 ref y,
293 ..
294 } => match curve {
295 EllipticCurve::P256 => {
296 if self.algorithm != Algorithm::ES256 {
297 bail!("Invalid algorithm for P-256 curve");
298 }
299 if x.len() != 32 || y.len() != 32 {
300 bail!("Invalid coordinate length for P-256");
301 }
302
303 DecodingKey::from_ec_components(
304 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
305 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
306 )?
307 }
308 EllipticCurve::P384 => {
309 if self.algorithm != Algorithm::ES384 {
310 bail!("Invalid algorithm for P-384 curve");
311 }
312 if x.len() != 48 || y.len() != 48 {
313 bail!("Invalid coordinate length for P-384");
314 }
315
316 DecodingKey::from_ec_components(
317 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
318 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
319 )?
320 }
321 _ => bail!("Invalid curve for EC key"),
322 },
323 KeyType::OKP { ref curve, ref x, .. } => match curve {
324 EllipticCurve::Ed25519 => {
325 if self.algorithm != Algorithm::EdDSA {
326 bail!("Invalid algorithm for Ed25519 curve");
327 }
328
329 DecodingKey::from_ed_components(
330 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
331 )?
332 }
333 _ => bail!("Invalid curve for OKP key"),
334 },
335 KeyType::RSA { ref public, .. } => {
336 DecodingKey::from_rsa_raw_components(public.n.as_ref(), public.e.as_ref())
337 }
338 };
339
340 Ok(self.decode.get_or_init(|| decoding_key))
341 }
342
343 fn to_encoding_key(&self) -> anyhow::Result<&EncodingKey> {
344 if let Some(key) = self.encode.get() {
345 return Ok(key);
346 }
347
348 let encoding_key = match self.key {
349 KeyType::OCT { ref secret } => match self.algorithm {
350 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret),
351 _ => bail!("Invalid algorithm for key type OCT"),
352 },
353 KeyType::EC { ref curve, ref d, .. } => {
354 let d = d.as_ref().context("Missing private key")?;
355
356 match curve {
357 EllipticCurve::P256 => {
358 let secret_key = SecretKey::<p256::NistP256>::from_slice(d)?;
359 let doc = secret_key.to_pkcs8_der()?;
360 EncodingKey::from_ec_der(doc.as_bytes())
361 }
362 EllipticCurve::P384 => {
363 let secret_key = SecretKey::<p384::NistP384>::from_slice(d)?;
364 let doc = secret_key.to_pkcs8_der()?;
365 EncodingKey::from_ec_der(doc.as_bytes())
366 }
367 _ => bail!("Invalid curve for EC key"),
368 }
369 }
370 KeyType::OKP {
371 ref curve,
372 ref d,
373 ref x,
374 } => {
375 let d = d.as_ref().context("Missing private key")?;
376
377 let key_pair =
378 aws_lc_rs::signature::Ed25519KeyPair::from_seed_and_public_key(d.as_slice(), x.as_slice())?;
379
380 match curve {
381 EllipticCurve::Ed25519 => EncodingKey::from_ed_der(key_pair.to_pkcs8()?.as_ref()),
382 _ => bail!("Invalid curve for OKP key"),
383 }
384 }
385 KeyType::RSA {
386 ref public,
387 ref private,
388 } => {
389 let n = BigUint::from_bytes_be(&public.n);
390 let e = BigUint::from_bytes_be(&public.e);
391 let private = private.as_ref().context("Missing private key")?;
392 let d = BigUint::from_bytes_be(&private.d);
393 let p = BigUint::from_bytes_be(&private.p);
394 let q = BigUint::from_bytes_be(&private.q);
395
396 let rsa = rsa::RsaPrivateKey::from_components(n, e, d, vec![p, q]);
397 let pem = rsa?.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF);
398
399 EncodingKey::from_rsa_pem(pem?.as_bytes())?
400 }
401 };
402
403 Ok(self.encode.get_or_init(|| encoding_key))
404 }
405
406 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
407 if !self.operations.contains(&KeyOperation::Verify) {
408 bail!("key does not support verification");
409 }
410
411 let decode: anyhow::Result<&DecodingKey> = self.to_decoding_key();
412
413 match decode {
414 Ok(decode) => {
415 let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
416 validation.required_spec_claims = Default::default(); let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
419 token.claims.validate()?;
420
421 Ok(token.claims)
422 }
423 Err(e) => Err(anyhow::anyhow!("Failed to decode key: {}", e)),
424 }
425 }
426
427 pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
428 if !self.operations.contains(&KeyOperation::Sign) {
429 bail!("key does not support signing");
430 }
431
432 payload.validate()?;
433
434 let encode: anyhow::Result<&EncodingKey> = self.to_encoding_key();
435
436 match encode {
437 Ok(encode) => {
438 let mut header = Header::new(self.algorithm.into());
439 header.kid = self.kid.clone();
440 let token = jsonwebtoken::encode(&header, &payload, encode)?;
441 Ok(token)
442 }
443 Err(e) => Err(anyhow::anyhow!("Failed to encode key: {}", e)),
444 }
445 }
446
447 pub fn generate(algorithm: Algorithm, id: Option<String>) -> anyhow::Result<Self> {
449 generate(algorithm, id)
450 }
451}
452
453fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
455where
456 S: Serializer,
457{
458 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
459 serializer.serialize_str(&encoded)
460}
461
462fn serialize_base64url_optional<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
463where
464 S: Serializer,
465{
466 match bytes {
467 Some(b) => serialize_base64url(b, serializer),
468 None => serializer.serialize_none(),
469 }
470}
471
472fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
474where
475 D: Deserializer<'de>,
476{
477 let s = String::deserialize(deserializer)?;
478
479 base64::engine::general_purpose::URL_SAFE_NO_PAD
481 .decode(&s)
482 .or_else(|_| {
483 base64::engine::general_purpose::URL_SAFE.decode(&s)
485 })
486 .map_err(serde::de::Error::custom)
487}
488
489fn deserialize_base64url_optional<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
490where
491 D: Deserializer<'de>,
492{
493 let s: Option<String> = Option::deserialize(deserializer)?;
494 match s {
495 Some(s) => {
496 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
497 .decode(&s)
498 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(&s))
499 .map_err(serde::de::Error::custom)?;
500 Ok(Some(decoded))
501 }
502 None => Ok(None),
503 }
504}
505
506#[cfg(test)]
507mod tests {
508 use super::*;
509 use std::time::{Duration, SystemTime};
510
511 fn create_test_key() -> Key {
512 Key {
513 algorithm: Algorithm::HS256,
514 operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
515 key: KeyType::OCT {
516 secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
517 },
518 kid: Some("test-key-1".to_string()),
519 decode: Default::default(),
520 encode: Default::default(),
521 }
522 }
523
524 fn create_test_claims() -> Claims {
525 Claims {
526 root: "test-path".to_string(),
527 publish: vec!["test-pub".into()],
528 cluster: false,
529 subscribe: vec!["test-sub".into()],
530 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
531 issued: Some(SystemTime::now()),
532 }
533 }
534
535 #[test]
536 fn test_key_from_str_valid() {
537 let key = create_test_key();
538 let json = key.to_str().unwrap();
539 let loaded_key = Key::from_str(&json).unwrap();
540
541 assert_eq!(loaded_key.algorithm, key.algorithm);
542 assert_eq!(loaded_key.operations, key.operations);
543 match (loaded_key.key, key.key) {
544 (KeyType::OCT { secret: loaded_secret }, KeyType::OCT { secret }) => {
545 assert_eq!(loaded_secret, secret);
546 }
547 _ => panic!("Expected OCT key"),
548 }
549 assert_eq!(loaded_key.kid, key.kid);
550 }
551
552 #[test]
554 fn test_key_oct_backwards_compatibility() {
555 let json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68"}"#;
556 let key = Key::from_str(json);
557
558 assert!(key.is_ok());
559 let key = key.unwrap();
560
561 if let KeyType::OCT { ref secret, .. } = key.key {
562 let base64_key = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret);
563 assert_eq!(base64_key, "Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68");
564 } else {
565 panic!("Expected OCT key");
566 }
567
568 let key_str = key.to_str();
569 assert!(key_str.is_ok());
570 let key_str = key_str.unwrap();
571
572 assert!(key_str.contains("\"alg\":\"HS256\""));
574 assert!(key_str.contains("\"key_ops\""));
575 assert!(key_str.contains("\"sign\""));
576 assert!(key_str.contains("\"verify\""));
577 assert!(key_str.contains("\"kty\":\"oct\""));
578 }
579
580 #[test]
581 fn test_key_from_str_invalid_json() {
582 let result = Key::from_str("invalid json");
583 assert!(result.is_err());
584 }
585
586 #[test]
587 fn test_key_to_str() {
588 let key = create_test_key();
589 let json = key.to_str().unwrap();
590 assert!(json.contains("\"alg\":\"HS256\""));
591 assert!(json.contains("\"key_ops\""));
592 assert!(json.contains("\"sign\""));
593 assert!(json.contains("\"verify\""));
594 assert!(json.contains("\"kid\":\"test-key-1\""));
595 assert!(json.contains("\"kty\":\"oct\""));
596 }
597
598 #[test]
599 fn test_key_sign_success() {
600 let key = create_test_key();
601 let claims = create_test_claims();
602 let token = key.encode(&claims).unwrap();
603
604 assert!(!token.is_empty());
605 assert_eq!(token.matches('.').count(), 2); }
607
608 #[test]
609 fn test_key_sign_no_permission() {
610 let mut key = create_test_key();
611 key.operations = [KeyOperation::Verify].into();
612 let claims = create_test_claims();
613
614 let result = key.encode(&claims);
615 assert!(result.is_err());
616 assert!(result.unwrap_err().to_string().contains("key does not support signing"));
617 }
618
619 #[test]
620 fn test_key_sign_invalid_claims() {
621 let key = create_test_key();
622 let invalid_claims = Claims {
623 root: "test-path".to_string(),
624 publish: vec![],
625 subscribe: vec![],
626 cluster: false,
627 expires: None,
628 issued: None,
629 };
630
631 let result = key.encode(&invalid_claims);
632 assert!(result.is_err());
633 assert!(
634 result
635 .unwrap_err()
636 .to_string()
637 .contains("no publish or subscribe allowed; token is useless")
638 );
639 }
640
641 #[test]
642 fn test_key_verify_success() {
643 let key = create_test_key();
644 let claims = create_test_claims();
645 let token = key.encode(&claims).unwrap();
646
647 let verified_claims = key.decode(&token).unwrap();
648 assert_eq!(verified_claims.root, claims.root);
649 assert_eq!(verified_claims.publish, claims.publish);
650 assert_eq!(verified_claims.subscribe, claims.subscribe);
651 assert_eq!(verified_claims.cluster, claims.cluster);
652 }
653
654 #[test]
655 fn test_key_verify_no_permission() {
656 let mut key = create_test_key();
657 key.operations = [KeyOperation::Sign].into();
658
659 let result = key.decode("some.jwt.token");
660 assert!(result.is_err());
661 assert!(
662 result
663 .unwrap_err()
664 .to_string()
665 .contains("key does not support verification")
666 );
667 }
668
669 #[test]
670 fn test_key_verify_invalid_token() {
671 let key = create_test_key();
672 let result = key.decode("invalid-token");
673 assert!(result.is_err());
674 }
675
676 #[test]
677 fn test_key_verify_path_mismatch() {
678 let key = create_test_key();
679 let claims = create_test_claims();
680 let token = key.encode(&claims).unwrap();
681
682 let result = key.decode(&token);
684 assert!(result.is_ok());
685 }
686
687 #[test]
688 fn test_key_verify_expired_token() {
689 let key = create_test_key();
690 let mut claims = create_test_claims();
691 claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); let token = key.encode(&claims).unwrap();
693
694 let result = key.decode(&token);
695 assert!(result.is_err());
696 }
697
698 #[test]
699 fn test_key_verify_token_without_exp() {
700 let key = create_test_key();
701 let claims = Claims {
702 root: "test-path".to_string(),
703 publish: vec!["".to_string()],
704 subscribe: vec!["".to_string()],
705 cluster: false,
706 expires: None,
707 issued: None,
708 };
709 let token = key.encode(&claims).unwrap();
710
711 let verified_claims = key.decode(&token).unwrap();
712 assert_eq!(verified_claims.root, claims.root);
713 assert_eq!(verified_claims.publish, claims.publish);
714 assert_eq!(verified_claims.subscribe, claims.subscribe);
715 assert_eq!(verified_claims.expires, None);
716 }
717
718 #[test]
719 fn test_key_round_trip() {
720 let key = create_test_key();
721 let original_claims = Claims {
722 root: "test-path".to_string(),
723 publish: vec!["test-pub".into()],
724 subscribe: vec!["test-sub".into()],
725 cluster: true,
726 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
727 issued: Some(SystemTime::now()),
728 };
729
730 let token = key.encode(&original_claims).unwrap();
731 let verified_claims = key.decode(&token).unwrap();
732
733 assert_eq!(verified_claims.root, original_claims.root);
734 assert_eq!(verified_claims.publish, original_claims.publish);
735 assert_eq!(verified_claims.subscribe, original_claims.subscribe);
736 assert_eq!(verified_claims.cluster, original_claims.cluster);
737 }
738
739 #[test]
740 fn test_key_generate_hs256() {
741 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
742 assert!(key.is_ok());
743 let key = key.unwrap();
744
745 assert_eq!(key.algorithm, Algorithm::HS256);
746 assert_eq!(key.kid, Some("test-id".to_string()));
747 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
748
749 match key.key {
750 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 32),
751 _ => panic!("Expected OCT key"),
752 }
753 }
754
755 #[test]
756 fn test_key_generate_hs384() {
757 let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
758 assert!(key.is_ok());
759 let key = key.unwrap();
760
761 assert_eq!(key.algorithm, Algorithm::HS384);
762
763 match key.key {
764 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 48),
765 _ => panic!("Expected OCT key"),
766 }
767 }
768
769 #[test]
770 fn test_key_generate_hs512() {
771 let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
772 assert!(key.is_ok());
773 let key = key.unwrap();
774
775 assert_eq!(key.algorithm, Algorithm::HS512);
776
777 match key.key {
778 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 64),
779 _ => panic!("Expected OCT key"),
780 }
781 }
782
783 #[test]
784 fn test_key_generate_rs512() {
785 let key = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
786 assert!(key.is_ok());
787 let key = key.unwrap();
788
789 assert_eq!(key.algorithm, Algorithm::RS512);
790 assert!(matches!(key.key, KeyType::RSA { .. }));
791 match key.key {
792 KeyType::RSA {
793 ref public,
794 ref private,
795 } => {
796 assert!(private.is_some());
797 assert_eq!(public.n.len(), 256);
798 assert_eq!(public.e.len(), 3);
799 }
800 _ => panic!("Expected RSA key"),
801 }
802 }
803
804 #[test]
805 fn test_key_generate_es256() {
806 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
807 assert!(key.is_ok());
808 let key = key.unwrap();
809
810 assert_eq!(key.algorithm, Algorithm::ES256);
811 assert!(matches!(key.key, KeyType::EC { .. }))
812 }
813
814 #[test]
815 fn test_key_generate_ps512() {
816 let key = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
817 assert!(key.is_ok());
818 let key = key.unwrap();
819
820 assert_eq!(key.algorithm, Algorithm::PS512);
821 assert!(matches!(key.key, KeyType::RSA { .. }));
822 }
823
824 #[test]
825 fn test_key_generate_eddsa() {
826 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
827 assert!(key.is_ok());
828 let key = key.unwrap();
829
830 assert_eq!(key.algorithm, Algorithm::EdDSA);
831 assert!(matches!(key.key, KeyType::OKP { .. }));
832 }
833
834 #[test]
835 fn test_key_generate_without_id() {
836 let key = Key::generate(Algorithm::HS256, None);
837 assert!(key.is_ok());
838 let key = key.unwrap();
839
840 assert_eq!(key.algorithm, Algorithm::HS256);
841 assert_eq!(key.kid, None);
842 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
843 }
844
845 #[test]
846 fn test_public_key_conversion_hmac() {
847 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string())).expect("HMAC key generation failed");
848
849 assert!(key.to_public().is_err());
850 }
851
852 #[test]
853 fn test_public_key_conversion_rsa() {
854 let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
855 assert!(key.is_ok());
856 let key = key.unwrap();
857
858 let public_key = key.to_public().unwrap();
859 assert_eq!(key.kid, public_key.kid);
860 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
861 assert!(public_key.encode.get().is_none());
862 assert!(public_key.decode.get().is_none());
863 assert!(matches!(public_key.key, KeyType::RSA { .. }));
864
865 if let KeyType::RSA { public, private } = &public_key.key {
866 assert!(private.is_none());
867
868 if let KeyType::RSA { public: src_public, .. } = &key.key {
869 assert_eq!(public.e, src_public.e);
870 assert_eq!(public.n, src_public.n);
871 } else {
872 unreachable!("Expected RSA key")
873 }
874 } else {
875 unreachable!("Expected RSA key");
876 }
877 }
878
879 #[test]
880 fn test_public_key_conversion_es() {
881 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
882 assert!(key.is_ok());
883 let key = key.unwrap();
884
885 let public_key = key.to_public().unwrap();
886 assert_eq!(key.kid, public_key.kid);
887 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
888 assert!(public_key.encode.get().is_none());
889 assert!(public_key.decode.get().is_none());
890 assert!(matches!(public_key.key, KeyType::EC { .. }));
891
892 if let KeyType::EC { x, y, d, curve } = &public_key.key {
893 assert!(d.is_none());
894
895 if let KeyType::EC {
896 x: src_x,
897 y: src_y,
898 curve: src_curve,
899 ..
900 } = &key.key
901 {
902 assert_eq!(x, src_x);
903 assert_eq!(y, src_y);
904 assert_eq!(curve, src_curve);
905 } else {
906 unreachable!("Expected EC key")
907 }
908 } else {
909 unreachable!("Expected EC key");
910 }
911 }
912
913 #[test]
914 fn test_public_key_conversion_ed() {
915 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
916 assert!(key.is_ok());
917 let key = key.unwrap();
918
919 let public_key = key.to_public().unwrap();
920 assert_eq!(key.kid, public_key.kid);
921 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
922 assert!(public_key.encode.get().is_none());
923 assert!(public_key.decode.get().is_none());
924 assert!(matches!(public_key.key, KeyType::OKP { .. }));
925
926 if let KeyType::OKP { x, d, curve } = &public_key.key {
927 assert!(d.is_none());
928
929 if let KeyType::OKP {
930 x: src_x,
931 curve: src_curve,
932 ..
933 } = &key.key
934 {
935 assert_eq!(x, src_x);
936 assert_eq!(curve, src_curve);
937 } else {
938 unreachable!("Expected OKP key")
939 }
940 } else {
941 unreachable!("Expected OKP key");
942 }
943 }
944
945 #[test]
946 fn test_key_generate_sign_verify_cycle() {
947 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
948 assert!(key.is_ok());
949 let key = key.unwrap();
950
951 let claims = create_test_claims();
952
953 let token = key.encode(&claims).unwrap();
954 let verified_claims = key.decode(&token).unwrap();
955
956 assert_eq!(verified_claims.root, claims.root);
957 assert_eq!(verified_claims.publish, claims.publish);
958 assert_eq!(verified_claims.subscribe, claims.subscribe);
959 assert_eq!(verified_claims.cluster, claims.cluster);
960 }
961
962 #[test]
963 fn test_key_debug_no_secret() {
964 let key = create_test_key();
965 let debug_str = format!("{key:?}");
966
967 assert!(debug_str.contains("algorithm: HS256"));
968 assert!(debug_str.contains("operations"));
969 assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
970 assert!(!debug_str.contains("secret")); }
972
973 #[test]
974 fn test_key_operations_enum() {
975 let sign_op = KeyOperation::Sign;
976 let verify_op = KeyOperation::Verify;
977 let decrypt_op = KeyOperation::Decrypt;
978 let encrypt_op = KeyOperation::Encrypt;
979
980 assert_eq!(sign_op, KeyOperation::Sign);
981 assert_eq!(verify_op, KeyOperation::Verify);
982 assert_eq!(decrypt_op, KeyOperation::Decrypt);
983 assert_eq!(encrypt_op, KeyOperation::Encrypt);
984
985 assert_ne!(sign_op, verify_op);
986 assert_ne!(decrypt_op, encrypt_op);
987 }
988
989 #[test]
990 fn test_key_operations_serde() {
991 let operations = [KeyOperation::Sign, KeyOperation::Verify];
992 let json = serde_json::to_string(&operations).unwrap();
993 assert!(json.contains("\"sign\""));
994 assert!(json.contains("\"verify\""));
995
996 let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
997 assert_eq!(deserialized, operations);
998 }
999
1000 #[test]
1001 fn test_key_serde() {
1002 let key = create_test_key();
1003 let json = serde_json::to_string(&key).unwrap();
1004 let deserialized: Key = serde_json::from_str(&json).unwrap();
1005
1006 assert_eq!(deserialized.algorithm, key.algorithm);
1007 assert_eq!(deserialized.operations, key.operations);
1008 assert_eq!(deserialized.kid, key.kid);
1009
1010 if let (
1011 KeyType::OCT {
1012 secret: original_secret,
1013 },
1014 KeyType::OCT {
1015 secret: deserialized_secret,
1016 },
1017 ) = (&key.key, &deserialized.key)
1018 {
1019 assert_eq!(deserialized_secret, original_secret);
1020 } else {
1021 panic!("Expected both keys to be OCT variant");
1022 }
1023 }
1024
1025 #[test]
1026 fn test_key_clone() {
1027 let key = create_test_key();
1028 let cloned = key.clone();
1029
1030 assert_eq!(cloned.algorithm, key.algorithm);
1031 assert_eq!(cloned.operations, key.operations);
1032 assert_eq!(cloned.kid, key.kid);
1033
1034 if let (
1035 KeyType::OCT {
1036 secret: original_secret,
1037 },
1038 KeyType::OCT { secret: cloned_secret },
1039 ) = (&key.key, &cloned.key)
1040 {
1041 assert_eq!(cloned_secret, original_secret);
1042 } else {
1043 panic!("Expected both keys to be OCT variant");
1044 }
1045 }
1046
1047 #[test]
1048 fn test_hmac_algorithms() {
1049 let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1050 let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1051 let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
1052
1053 let claims = create_test_claims();
1054
1055 for key in [key_256, key_384, key_512] {
1057 assert!(key.is_ok());
1058 let key = key.unwrap();
1059
1060 let token = key.encode(&claims).unwrap();
1061 let verified_claims = key.decode(&token).unwrap();
1062 assert_eq!(verified_claims.root, claims.root);
1063 }
1064 }
1065
1066 #[test]
1067 fn test_rsa_pkcs1_asymmetric_algorithms() {
1068 let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1069 let key_rs384 = Key::generate(Algorithm::RS384, Some("test-id".to_string()));
1070 let key_rs512 = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
1071
1072 for key in [key_rs256, key_rs384, key_rs512] {
1073 test_asymmetric_key(key);
1074 }
1075 }
1076
1077 #[test]
1078 fn test_rsa_pss_asymmetric_algorithms() {
1079 let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1080 let key_ps384 = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1081 let key_ps512 = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
1082
1083 for key in [key_ps256, key_ps384, key_ps512] {
1084 test_asymmetric_key(key);
1085 }
1086 }
1087
1088 #[test]
1089 fn test_ec_asymmetric_algorithms() {
1090 let key_es256 = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
1091 let key_es384 = Key::generate(Algorithm::ES384, Some("test-id".to_string()));
1092
1093 for key in [key_es256, key_es384] {
1094 test_asymmetric_key(key);
1095 }
1096 }
1097
1098 #[test]
1099 fn test_ed_asymmetric_algorithms() {
1100 let key_eddsa = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
1101
1102 test_asymmetric_key(key_eddsa);
1103 }
1104
1105 fn test_asymmetric_key(key: anyhow::Result<Key>) {
1106 assert!(key.is_ok());
1107 let key = key.unwrap();
1108
1109 let claims = create_test_claims();
1110 let token = key.encode(&claims).unwrap();
1111
1112 let private_verified_claims = key.decode(&token).unwrap();
1113 assert_eq!(
1114 private_verified_claims.root, claims.root,
1115 "validation using private key"
1116 );
1117
1118 let public_verified_claims = key.to_public().unwrap().decode(&token).unwrap();
1119 assert_eq!(public_verified_claims.root, claims.root, "validation using public key");
1120 }
1121
/// A token encoded with an HS256 key must be rejected when decoded with an
/// HS384 key.
#[test]
fn test_cross_algorithm_verification_fails() {
    // `expect` keeps the generation error in the panic output; the former
    // `assert!(key.is_ok())` discarded it.
    let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()))
        .expect("HS256 key generation should succeed");
    let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()))
        .expect("HS384 key generation should succeed");

    let claims = create_test_claims();
    let token = key_256.encode(&claims).unwrap();

    let result = key_384.decode(&token);
    assert!(
        result.is_err(),
        "a token encoded with the HS256 key must not decode with the HS384 key"
    );
}
1139
/// A token encoded with an RS256 key must be rejected by a PS256 key, both
/// when decoding with the PS256 key itself and with its `to_public` counterpart.
#[test]
fn test_asymmetric_cross_algorithm_verification_fails() {
    // `expect` keeps the generation error in the panic output; the former
    // `assert!(key.is_ok())` discarded it.
    let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()))
        .expect("RS256 key generation should succeed");
    let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()))
        .expect("PS256 key generation should succeed");

    let claims = create_test_claims();
    let token = key_rs256.encode(&claims).unwrap();

    // Both decode attempts run before either result is asserted on.
    let private_result = key_ps256.decode(&token);
    let public_result = key_ps256.to_public().unwrap().decode(&token);

    assert!(
        private_result.is_err(),
        "a token encoded with the RS256 key must not decode with the PS256 key"
    );
    assert!(
        public_result.is_err(),
        "a token encoded with the RS256 key must not decode with the public PS256 key"
    );
}
1159
/// Checks `to_public` on a generated RS256 key.
///
/// The generated key must allow both `Sign` and `Verify` and carry private
/// RSA material; the converted key must allow only `Verify`, carry no private
/// material, and keep the same `n` and `e`.
#[test]
fn test_rsa_pkcs1_public_key_conversion() {
    // `expect` keeps the generation error in the panic output; the former
    // `assert!(key.is_ok())` discarded it.
    let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()))
        .expect("RS256 key generation should succeed");

    assert!(key.operations.contains(&KeyOperation::Sign));
    assert!(key.operations.contains(&KeyOperation::Verify));

    let public_key = key.to_public().unwrap();
    assert!(!public_key.operations.contains(&KeyOperation::Sign));
    assert!(public_key.operations.contains(&KeyOperation::Verify));

    let KeyType::RSA { public, private } = &key.key else {
        panic!("Expected private key to be an RSA key");
    };
    assert!(private.is_some());
    // Expected lengths of the generated modulus and exponent
    // (presumably in bytes, i.e. a 2048-bit modulus; confirm against `RsaPublicKey`).
    assert_eq!(public.n.len(), 256);
    assert_eq!(public.e.len(), 3);

    let KeyType::RSA {
        public: public_public,
        private: public_private,
    } = &public_key.key
    else {
        panic!("Expected public key to be an RSA key");
    };
    assert!(public_private.is_none());
    assert_eq!(public.n, public_public.n);
    assert_eq!(public.e, public_public.e);
}
1197
/// Checks `to_public` on a generated PS384 key.
///
/// The generated key must allow both `Sign` and `Verify` and carry private
/// RSA material; the converted key must allow only `Verify`, carry no private
/// material, and keep the same `n` and `e`.
#[test]
fn test_rsa_pss_public_key_conversion() {
    // `expect` keeps the generation error in the panic output; the former
    // `assert!(key.is_ok())` discarded it.
    let key = Key::generate(Algorithm::PS384, Some("test-id".to_string()))
        .expect("PS384 key generation should succeed");

    assert!(key.operations.contains(&KeyOperation::Sign));
    assert!(key.operations.contains(&KeyOperation::Verify));

    let public_key = key.to_public().unwrap();
    assert!(!public_key.operations.contains(&KeyOperation::Sign));
    assert!(public_key.operations.contains(&KeyOperation::Verify));

    let KeyType::RSA { public, private } = &key.key else {
        panic!("Expected private key to be an RSA key");
    };
    assert!(private.is_some());
    // Expected lengths of the generated modulus and exponent
    // (presumably in bytes, i.e. a 2048-bit modulus; confirm against `RsaPublicKey`).
    assert_eq!(public.n.len(), 256);
    assert_eq!(public.e.len(), 3);

    let KeyType::RSA {
        public: public_public,
        private: public_private,
    } = &public_key.key
    else {
        panic!("Expected public key to be an RSA key");
    };
    assert!(public_private.is_none());
    assert_eq!(public.n, public_public.n);
    assert_eq!(public.e, public_public.e);
}
1235
/// Serializes the test key to JSON and checks that its `k` member contains
/// none of `=`, `+` or `/`, and that decoding it with `URL_SAFE_NO_PAD`
/// yields the key's original secret.
#[test]
fn test_base64url_serialization() {
    let key = create_test_key();
    let serialized = serde_json::to_string(&key).unwrap();

    let document: serde_json::Value = serde_json::from_str(&serialized).unwrap();
    let encoded_secret = document["k"].as_str().unwrap();

    // Padding first, then the two characters of the standard alphabet that
    // the URL-safe alphabet replaces.
    for forbidden in ['=', '+', '/'] {
        assert!(!encoded_secret.contains(forbidden));
    }

    let raw_secret = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(encoded_secret)
        .unwrap();

    match &key.key {
        KeyType::OCT { secret } => assert_eq!(raw_secret, *secret),
        _ => panic!("Expected both keys to be OCT variant"),
    }
}
1264
/// Deserializes an `oct` JWK whose `k` value has no trailing `=` and checks
/// the resulting algorithm, key id and secret bytes.
#[test]
fn test_backwards_compatibility_unpadded_base64url() {
    let jwk = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;

    let parsed: Key = serde_json::from_str(jwk).unwrap();
    assert_eq!(parsed.algorithm, Algorithm::HS256);
    assert_eq!(parsed.kid, Some("test-key-1".to_string()));

    match &parsed.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1281
/// Deserializes an `oct` JWK whose `k` value ends with a `=` padding
/// character and checks the resulting algorithm, key id and secret bytes.
#[test]
fn test_backwards_compatibility_padded_base64url() {
    let jwk = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;

    let parsed: Key = serde_json::from_str(jwk).unwrap();
    assert_eq!(parsed.algorithm, Algorithm::HS256);
    assert_eq!(parsed.kid, Some("test-key-1".to_string()));

    match &parsed.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1298
/// Round-trips the test key through `to_file` / `from_file`.
///
/// The written file must contain none of `{`, `}` or `"`, must decode with
/// `URL_SAFE_NO_PAD` into UTF-8 text that parses as JSON, and must load back
/// into a key with the same algorithm, operations, key id and secret.
#[test]
fn test_file_io_base64url() {
    /// Removes the wrapped path when dropped, so the key file is cleaned up
    /// even when an assertion in the test panics.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: a missing file is not an error here.
            std::fs::remove_file(&self.0).ok();
        }
    }

    let key = create_test_key();

    // The temp directory is shared between processes, so the file name
    // includes the process id to avoid clashing with a concurrent test run.
    let temp_file = TempFile(
        std::env::temp_dir().join(format!("test_jwk_{}.key", std::process::id())),
    );
    let temp_path = &temp_file.0;

    key.to_file(temp_path).unwrap();

    let contents = std::fs::read_to_string(temp_path).unwrap();

    // The file must not contain raw JSON syntax.
    assert!(!contents.contains('{'));
    assert!(!contents.contains('}'));
    assert!(!contents.contains('"'));

    // The whole file decodes as unpadded base64url into a JSON document.
    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(&contents)
        .unwrap();
    let json_str = String::from_utf8(decoded).unwrap();
    let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();

    let loaded_key = Key::from_file(temp_path).unwrap();
    assert_eq!(loaded_key.algorithm, key.algorithm);
    assert_eq!(loaded_key.operations, key.operations);
    assert_eq!(loaded_key.kid, key.kid);

    if let (
        KeyType::OCT {
            secret: original_secret,
        },
        KeyType::OCT { secret: loaded_secret },
    ) = (&key.key, &loaded_key.key)
    {
        assert_eq!(loaded_secret, original_secret);
    } else {
        panic!("Expected both keys to be OCT variant");
    }
}
1344}