1use crate::generate::generate;
2use crate::{Algorithm, Claims};
3use anyhow::{bail, Context};
4use base64::Engine;
5use elliptic_curve::pkcs8::EncodePrivateKey;
6use elliptic_curve::SecretKey;
7use jsonwebtoken::{DecodingKey, EncodingKey, Header};
8use rsa::pkcs1::EncodeRsaPrivateKey;
9use rsa::BigUint;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::sync::OnceLock;
12use std::{collections::HashSet, fmt, path::Path as StdPath};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
15#[serde(rename_all = "camelCase")]
16pub enum KeyOperation {
17 Sign,
18 Verify,
19 Decrypt,
20 Encrypt,
21}
22
23#[derive(Clone, Serialize, Deserialize)]
25#[serde(tag = "kty")]
26pub enum KeyType {
27 EC {
29 #[serde(rename = "crv")]
30 curve: EllipticCurve,
31 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
33 x: Vec<u8>,
34 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
36 y: Vec<u8>,
37 #[serde(
39 default,
40 skip_serializing_if = "Option::is_none",
41 serialize_with = "serialize_base64url_optional",
42 deserialize_with = "deserialize_base64url_optional"
43 )]
44 d: Option<Vec<u8>>,
45 },
46 RSA {
48 #[serde(flatten)]
49 public: RsaPublicKey,
50 #[serde(flatten, skip_serializing_if = "Option::is_none")]
51 private: Option<RsaPrivateKey>,
52 },
53 #[serde(rename = "oct")]
55 OCT {
56 #[serde(
58 rename = "k",
59 default,
60 serialize_with = "serialize_base64url",
61 deserialize_with = "deserialize_base64url"
62 )]
63 secret: Vec<u8>,
64 },
65 OKP {
67 #[serde(rename = "crv")]
68 curve: EllipticCurve,
69 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
70 x: Vec<u8>,
71 #[serde(
72 rename = "d",
73 default,
74 skip_serializing_if = "Option::is_none",
75 serialize_with = "serialize_base64url_optional",
76 deserialize_with = "deserialize_base64url_optional"
77 )]
78 d: Option<Vec<u8>>,
79 },
80}
81
82#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
84pub enum EllipticCurve {
85 #[serde(rename = "P-256")]
86 P256,
87 #[serde(rename = "P-384")]
88 P384,
89 #[serde(rename = "Ed25519")]
93 Ed25519,
94}
95
96#[derive(Clone, Serialize, Deserialize)]
98pub struct RsaPublicKey {
99 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
100 pub n: Vec<u8>,
101 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
102 pub e: Vec<u8>,
103}
104
105#[derive(Clone, Serialize, Deserialize)]
107pub struct RsaPrivateKey {
108 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
109 pub d: Vec<u8>,
110 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
111 pub p: Vec<u8>,
112 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
113 pub q: Vec<u8>,
114 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
115 pub dp: Vec<u8>,
116 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
117 pub dq: Vec<u8>,
118 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
119 pub qi: Vec<u8>,
120 #[serde(skip_serializing_if = "Option::is_none")]
121 pub oth: Option<Vec<RsaAdditionalPrime>>,
122}
123
124#[derive(Clone, Serialize, Deserialize)]
125pub struct RsaAdditionalPrime {
126 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
127 pub r: Vec<u8>,
128 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
129 pub d: Vec<u8>,
130 #[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
131 pub t: Vec<u8>,
132}
133
134#[derive(Clone, Serialize, Deserialize)]
137#[serde(remote = "Self")]
138pub struct Key {
139 #[serde(rename = "alg")]
141 pub algorithm: Algorithm,
142
143 #[serde(rename = "key_ops")]
145 pub operations: HashSet<KeyOperation>,
146
147 #[serde(flatten)]
149 pub key: KeyType,
150
151 #[serde(skip_serializing_if = "Option::is_none")]
153 pub kid: Option<String>,
154
155 #[serde(skip)]
157 pub(crate) decode: OnceLock<DecodingKey>,
158
159 #[serde(skip)]
160 pub(crate) encode: OnceLock<EncodingKey>,
161}
162
163impl<'de> Deserialize<'de> for Key {
164 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165 where
166 D: Deserializer<'de>,
167 {
168 let mut value = serde_json::Value::deserialize(deserializer)?;
169
170 if let Some(obj) = value.as_object_mut() {
174 if !obj.contains_key("kty") {
175 obj.insert("kty".to_string(), serde_json::Value::String("oct".to_string()));
176 }
177 }
178
179 Self::deserialize(value).map_err(serde::de::Error::custom)
180 }
181}
182
183impl Serialize for Key {
184 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
185 where
186 S: Serializer,
187 {
188 Self::serialize(self, serializer)
189 }
190}
191
192impl fmt::Debug for Key {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("Key")
195 .field("algorithm", &self.algorithm)
196 .field("operations", &self.operations)
197 .field("kid", &self.kid)
198 .finish()
199 }
200}
201
202impl Key {
203 #[allow(clippy::should_implement_trait)]
204 pub fn from_str(s: &str) -> anyhow::Result<Self> {
205 Ok(serde_json::from_str(s)?)
206 }
207
208 pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
209 let contents = std::fs::read_to_string(&path)?;
210 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
212 let json = String::from_utf8(decoded)?;
213 Ok(serde_json::from_str(&json)?)
214 }
215
216 pub fn to_str(&self) -> anyhow::Result<String> {
217 Ok(serde_json::to_string(self)?)
218 }
219
220 pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
221 let json = serde_json::to_string(self)?;
223 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
225 std::fs::write(path, encoded)?;
226 Ok(())
227 }
228
229 pub fn to_public(&self) -> anyhow::Result<Self> {
230 if !self.operations.contains(&KeyOperation::Verify) {
231 return Err(anyhow::anyhow!("This key doesn't support the Verify operation"));
232 }
233
234 let key = match self.key {
235 KeyType::RSA { ref public, .. } => Ok(KeyType::RSA {
236 public: public.clone(),
237 private: None,
238 }),
239 KeyType::EC {
240 ref x,
241 ref y,
242 ref curve,
243 ..
244 } => Ok(KeyType::EC {
245 x: x.clone(),
246 y: y.clone(),
247 curve: curve.clone(),
248 d: None,
249 }),
250 KeyType::OCT { .. } => Err(anyhow::anyhow!("OCT key cannot be converted to public key")),
251 KeyType::OKP { ref x, ref curve, .. } => Ok(KeyType::OKP {
252 x: x.clone(),
253 curve: curve.clone(),
254 d: None,
255 }),
256 };
257
258 match key {
259 Ok(key) => Ok(Self {
260 algorithm: self.algorithm,
261 operations: [KeyOperation::Verify].into(),
262 key,
263 kid: self.kid.clone(),
264 decode: Default::default(),
265 encode: Default::default(),
266 }),
267 Err(err) => Err(anyhow::anyhow!("Failed to convert key: {}", err)),
268 }
269 }
270
271 fn to_decoding_key(&self) -> anyhow::Result<&DecodingKey> {
272 if let Some(key) = self.decode.get() {
273 return Ok(key);
274 }
275
276 let decoding_key = match self.key {
277 KeyType::OCT { ref secret } => match self.algorithm {
278 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret),
279 _ => bail!("Invalid algorithm for key type OCT"),
280 },
281 KeyType::EC {
282 ref curve,
283 ref x,
284 ref y,
285 ..
286 } => match curve {
287 EllipticCurve::P256 => {
288 if self.algorithm != Algorithm::ES256 {
289 bail!("Invalid algorithm for P-256 curve");
290 }
291 if x.len() != 32 || y.len() != 32 {
292 bail!("Invalid coordinate length for P-256");
293 }
294
295 DecodingKey::from_ec_components(
296 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
297 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
298 )?
299 }
300 EllipticCurve::P384 => {
301 if self.algorithm != Algorithm::ES384 {
302 bail!("Invalid algorithm for P-384 curve");
303 }
304 if x.len() != 48 || y.len() != 48 {
305 bail!("Invalid coordinate length for P-384");
306 }
307
308 DecodingKey::from_ec_components(
309 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
310 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
311 )?
312 }
313 _ => bail!("Invalid curve for EC key"),
314 },
315 KeyType::OKP { ref curve, ref x, .. } => match curve {
316 EllipticCurve::Ed25519 => {
317 if self.algorithm != Algorithm::EdDSA {
318 bail!("Invalid algorithm for Ed25519 curve");
319 }
320
321 DecodingKey::from_ed_components(
322 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
323 )?
324 }
325 _ => bail!("Invalid curve for OKP key"),
326 },
327 KeyType::RSA { ref public, .. } => {
328 DecodingKey::from_rsa_raw_components(public.n.as_ref(), public.e.as_ref())
329 }
330 };
331
332 Ok(self.decode.get_or_init(|| decoding_key))
333 }
334
335 fn to_encoding_key(&self) -> anyhow::Result<&EncodingKey> {
336 if let Some(key) = self.encode.get() {
337 return Ok(key);
338 }
339
340 let encoding_key = match self.key {
341 KeyType::OCT { ref secret } => match self.algorithm {
342 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret),
343 _ => bail!("Invalid algorithm for key type OCT"),
344 },
345 KeyType::EC { ref curve, ref d, .. } => {
346 let d = d.as_ref().context("Missing private key")?;
347
348 match curve {
349 EllipticCurve::P256 => {
350 let secret_key = SecretKey::<p256::NistP256>::from_slice(d)?;
351 let doc = secret_key.to_pkcs8_der()?;
352 EncodingKey::from_ec_der(doc.as_bytes())
353 }
354 EllipticCurve::P384 => {
355 let secret_key = SecretKey::<p384::NistP384>::from_slice(d)?;
356 let doc = secret_key.to_pkcs8_der()?;
357 EncodingKey::from_ec_der(doc.as_bytes())
358 }
359 _ => bail!("Invalid curve for EC key"),
360 }
361 }
362 KeyType::OKP {
363 ref curve,
364 ref d,
365 ref x,
366 } => {
367 let d = d.as_ref().context("Missing private key")?;
368
369 let key_pair =
370 aws_lc_rs::signature::Ed25519KeyPair::from_seed_and_public_key(d.as_slice(), x.as_slice())?;
371
372 match curve {
373 EllipticCurve::Ed25519 => EncodingKey::from_ed_der(key_pair.to_pkcs8()?.as_ref()),
374 _ => bail!("Invalid curve for OKP key"),
375 }
376 }
377 KeyType::RSA {
378 ref public,
379 ref private,
380 } => {
381 let n = BigUint::from_bytes_be(&public.n);
382 let e = BigUint::from_bytes_be(&public.e);
383 let private = private.as_ref().context("Missing private key")?;
384 let d = BigUint::from_bytes_be(&private.d);
385 let p = BigUint::from_bytes_be(&private.p);
386 let q = BigUint::from_bytes_be(&private.q);
387
388 let rsa = rsa::RsaPrivateKey::from_components(n, e, d, vec![p, q]);
389 let pem = rsa?.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF);
390
391 EncodingKey::from_rsa_pem(pem?.as_bytes())?
392 }
393 };
394
395 Ok(self.encode.get_or_init(|| encoding_key))
396 }
397
398 pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
399 if !self.operations.contains(&KeyOperation::Verify) {
400 bail!("key does not support verification");
401 }
402
403 let decode: anyhow::Result<&DecodingKey> = self.to_decoding_key();
404
405 match decode {
406 Ok(decode) => {
407 let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
408 validation.required_spec_claims = Default::default(); let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
411 token.claims.validate()?;
412
413 Ok(token.claims)
414 }
415 Err(e) => Err(anyhow::anyhow!("Failed to decode key: {}", e)),
416 }
417 }
418
419 pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
420 if !self.operations.contains(&KeyOperation::Sign) {
421 bail!("key does not support signing");
422 }
423
424 payload.validate()?;
425
426 let encode: anyhow::Result<&EncodingKey> = self.to_encoding_key();
427
428 match encode {
429 Ok(encode) => {
430 let mut header = Header::new(self.algorithm.into());
431 header.kid = self.kid.clone();
432 let token = jsonwebtoken::encode(&header, &payload, encode)?;
433 Ok(token)
434 }
435 Err(e) => Err(anyhow::anyhow!("Failed to encode key: {}", e)),
436 }
437 }
438
439 pub fn generate(algorithm: Algorithm, id: Option<String>) -> anyhow::Result<Self> {
441 generate(algorithm, id)
442 }
443}
444
445fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
447where
448 S: Serializer,
449{
450 let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
451 serializer.serialize_str(&encoded)
452}
453
454fn serialize_base64url_optional<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
455where
456 S: Serializer,
457{
458 match bytes {
459 Some(b) => serialize_base64url(b, serializer),
460 None => serializer.serialize_none(),
461 }
462}
463
464fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
466where
467 D: Deserializer<'de>,
468{
469 let s = String::deserialize(deserializer)?;
470
471 base64::engine::general_purpose::URL_SAFE_NO_PAD
473 .decode(&s)
474 .or_else(|_| {
475 base64::engine::general_purpose::URL_SAFE.decode(&s)
477 })
478 .map_err(serde::de::Error::custom)
479}
480
481fn deserialize_base64url_optional<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
482where
483 D: Deserializer<'de>,
484{
485 let s: Option<String> = Option::deserialize(deserializer)?;
486 match s {
487 Some(s) => {
488 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
489 .decode(&s)
490 .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(&s))
491 .map_err(serde::de::Error::custom)?;
492 Ok(Some(decoded))
493 }
494 None => Ok(None),
495 }
496}
497
498#[cfg(test)]
499mod tests {
500 use super::*;
501 use std::time::{Duration, SystemTime};
502
503 fn create_test_key() -> Key {
504 Key {
505 algorithm: Algorithm::HS256,
506 operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
507 key: KeyType::OCT {
508 secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
509 },
510 kid: Some("test-key-1".to_string()),
511 decode: Default::default(),
512 encode: Default::default(),
513 }
514 }
515
516 fn create_test_claims() -> Claims {
517 Claims {
518 root: "test-path".to_string(),
519 publish: vec!["test-pub".into()],
520 cluster: false,
521 subscribe: vec!["test-sub".into()],
522 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
523 issued: Some(SystemTime::now()),
524 }
525 }
526
527 #[test]
528 fn test_key_from_str_valid() {
529 let key = create_test_key();
530 let json = key.to_str().unwrap();
531 let loaded_key = Key::from_str(&json).unwrap();
532
533 assert_eq!(loaded_key.algorithm, key.algorithm);
534 assert_eq!(loaded_key.operations, key.operations);
535 match (loaded_key.key, key.key) {
536 (KeyType::OCT { secret: loaded_secret }, KeyType::OCT { secret }) => {
537 assert_eq!(loaded_secret, secret);
538 }
539 _ => panic!("Expected OCT key"),
540 }
541 assert_eq!(loaded_key.kid, key.kid);
542 }
543
544 #[test]
546 fn test_key_oct_backwards_compatibility() {
547 let json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68"}"#;
548 let key = Key::from_str(json);
549
550 assert!(key.is_ok());
551 let key = key.unwrap();
552
553 if let KeyType::OCT { ref secret, .. } = key.key {
554 let base64_key = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret);
555 assert_eq!(base64_key, "Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68");
556 } else {
557 panic!("Expected OCT key");
558 }
559
560 let key_str = key.to_str();
561 assert!(key_str.is_ok());
562 let key_str = key_str.unwrap();
563
564 assert!(key_str.contains("\"alg\":\"HS256\""));
566 assert!(key_str.contains("\"key_ops\""));
567 assert!(key_str.contains("\"sign\""));
568 assert!(key_str.contains("\"verify\""));
569 assert!(key_str.contains("\"kty\":\"oct\""));
570 }
571
572 #[test]
573 fn test_key_from_str_invalid_json() {
574 let result = Key::from_str("invalid json");
575 assert!(result.is_err());
576 }
577
578 #[test]
579 fn test_key_to_str() {
580 let key = create_test_key();
581 let json = key.to_str().unwrap();
582 assert!(json.contains("\"alg\":\"HS256\""));
583 assert!(json.contains("\"key_ops\""));
584 assert!(json.contains("\"sign\""));
585 assert!(json.contains("\"verify\""));
586 assert!(json.contains("\"kid\":\"test-key-1\""));
587 assert!(json.contains("\"kty\":\"oct\""));
588 }
589
590 #[test]
591 fn test_key_sign_success() {
592 let key = create_test_key();
593 let claims = create_test_claims();
594 let token = key.encode(&claims).unwrap();
595
596 assert!(!token.is_empty());
597 assert_eq!(token.matches('.').count(), 2); }
599
600 #[test]
601 fn test_key_sign_no_permission() {
602 let mut key = create_test_key();
603 key.operations = [KeyOperation::Verify].into();
604 let claims = create_test_claims();
605
606 let result = key.encode(&claims);
607 assert!(result.is_err());
608 assert!(result.unwrap_err().to_string().contains("key does not support signing"));
609 }
610
611 #[test]
612 fn test_key_sign_invalid_claims() {
613 let key = create_test_key();
614 let invalid_claims = Claims {
615 root: "test-path".to_string(),
616 publish: vec![],
617 subscribe: vec![],
618 cluster: false,
619 expires: None,
620 issued: None,
621 };
622
623 let result = key.encode(&invalid_claims);
624 assert!(result.is_err());
625 assert!(result
626 .unwrap_err()
627 .to_string()
628 .contains("no publish or subscribe allowed; token is useless"));
629 }
630
631 #[test]
632 fn test_key_verify_success() {
633 let key = create_test_key();
634 let claims = create_test_claims();
635 let token = key.encode(&claims).unwrap();
636
637 let verified_claims = key.decode(&token).unwrap();
638 assert_eq!(verified_claims.root, claims.root);
639 assert_eq!(verified_claims.publish, claims.publish);
640 assert_eq!(verified_claims.subscribe, claims.subscribe);
641 assert_eq!(verified_claims.cluster, claims.cluster);
642 }
643
644 #[test]
645 fn test_key_verify_no_permission() {
646 let mut key = create_test_key();
647 key.operations = [KeyOperation::Sign].into();
648
649 let result = key.decode("some.jwt.token");
650 assert!(result.is_err());
651 assert!(result
652 .unwrap_err()
653 .to_string()
654 .contains("key does not support verification"));
655 }
656
657 #[test]
658 fn test_key_verify_invalid_token() {
659 let key = create_test_key();
660 let result = key.decode("invalid-token");
661 assert!(result.is_err());
662 }
663
664 #[test]
665 fn test_key_verify_path_mismatch() {
666 let key = create_test_key();
667 let claims = create_test_claims();
668 let token = key.encode(&claims).unwrap();
669
670 let result = key.decode(&token);
672 assert!(result.is_ok());
673 }
674
675 #[test]
676 fn test_key_verify_expired_token() {
677 let key = create_test_key();
678 let mut claims = create_test_claims();
679 claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); let token = key.encode(&claims).unwrap();
681
682 let result = key.decode(&token);
683 assert!(result.is_err());
684 }
685
686 #[test]
687 fn test_key_verify_token_without_exp() {
688 let key = create_test_key();
689 let claims = Claims {
690 root: "test-path".to_string(),
691 publish: vec!["".to_string()],
692 subscribe: vec!["".to_string()],
693 cluster: false,
694 expires: None,
695 issued: None,
696 };
697 let token = key.encode(&claims).unwrap();
698
699 let verified_claims = key.decode(&token).unwrap();
700 assert_eq!(verified_claims.root, claims.root);
701 assert_eq!(verified_claims.publish, claims.publish);
702 assert_eq!(verified_claims.subscribe, claims.subscribe);
703 assert_eq!(verified_claims.expires, None);
704 }
705
706 #[test]
707 fn test_key_round_trip() {
708 let key = create_test_key();
709 let original_claims = Claims {
710 root: "test-path".to_string(),
711 publish: vec!["test-pub".into()],
712 subscribe: vec!["test-sub".into()],
713 cluster: true,
714 expires: Some(SystemTime::now() + Duration::from_secs(3600)),
715 issued: Some(SystemTime::now()),
716 };
717
718 let token = key.encode(&original_claims).unwrap();
719 let verified_claims = key.decode(&token).unwrap();
720
721 assert_eq!(verified_claims.root, original_claims.root);
722 assert_eq!(verified_claims.publish, original_claims.publish);
723 assert_eq!(verified_claims.subscribe, original_claims.subscribe);
724 assert_eq!(verified_claims.cluster, original_claims.cluster);
725 }
726
727 #[test]
728 fn test_key_generate_hs256() {
729 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
730 assert!(key.is_ok());
731 let key = key.unwrap();
732
733 assert_eq!(key.algorithm, Algorithm::HS256);
734 assert_eq!(key.kid, Some("test-id".to_string()));
735 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
736
737 match key.key {
738 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 32),
739 _ => panic!("Expected OCT key"),
740 }
741 }
742
743 #[test]
744 fn test_key_generate_hs384() {
745 let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
746 assert!(key.is_ok());
747 let key = key.unwrap();
748
749 assert_eq!(key.algorithm, Algorithm::HS384);
750
751 match key.key {
752 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 48),
753 _ => panic!("Expected OCT key"),
754 }
755 }
756
757 #[test]
758 fn test_key_generate_hs512() {
759 let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
760 assert!(key.is_ok());
761 let key = key.unwrap();
762
763 assert_eq!(key.algorithm, Algorithm::HS512);
764
765 match key.key {
766 KeyType::OCT { ref secret } => assert_eq!(secret.len(), 64),
767 _ => panic!("Expected OCT key"),
768 }
769 }
770
771 #[test]
772 fn test_key_generate_rs512() {
773 let key = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
774 assert!(key.is_ok());
775 let key = key.unwrap();
776
777 assert_eq!(key.algorithm, Algorithm::RS512);
778 assert!(matches!(key.key, KeyType::RSA { .. }));
779 match key.key {
780 KeyType::RSA {
781 ref public,
782 ref private,
783 } => {
784 assert!(private.is_some());
785 assert_eq!(public.n.len(), 256);
786 assert_eq!(public.e.len(), 3);
787 }
788 _ => panic!("Expected RSA key"),
789 }
790 }
791
792 #[test]
793 fn test_key_generate_es256() {
794 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
795 assert!(key.is_ok());
796 let key = key.unwrap();
797
798 assert_eq!(key.algorithm, Algorithm::ES256);
799 assert!(matches!(key.key, KeyType::EC { .. }))
800 }
801
802 #[test]
803 fn test_key_generate_ps512() {
804 let key = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
805 assert!(key.is_ok());
806 let key = key.unwrap();
807
808 assert_eq!(key.algorithm, Algorithm::PS512);
809 assert!(matches!(key.key, KeyType::RSA { .. }));
810 }
811
812 #[test]
813 fn test_key_generate_eddsa() {
814 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
815 assert!(key.is_ok());
816 let key = key.unwrap();
817
818 assert_eq!(key.algorithm, Algorithm::EdDSA);
819 assert!(matches!(key.key, KeyType::OKP { .. }));
820 }
821
822 #[test]
823 fn test_key_generate_without_id() {
824 let key = Key::generate(Algorithm::HS256, None);
825 assert!(key.is_ok());
826 let key = key.unwrap();
827
828 assert_eq!(key.algorithm, Algorithm::HS256);
829 assert_eq!(key.kid, None);
830 assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
831 }
832
833 #[test]
834 fn test_public_key_conversion_hmac() {
835 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string())).expect("HMAC key generation failed");
836
837 assert!(key.to_public().is_err());
838 }
839
840 #[test]
841 fn test_public_key_conversion_rsa() {
842 let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
843 assert!(key.is_ok());
844 let key = key.unwrap();
845
846 let public_key = key.to_public().unwrap();
847 assert_eq!(key.kid, public_key.kid);
848 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
849 assert!(public_key.encode.get().is_none());
850 assert!(public_key.decode.get().is_none());
851 assert!(matches!(public_key.key, KeyType::RSA { .. }));
852
853 if let KeyType::RSA { public, private } = &public_key.key {
854 assert!(private.is_none());
855
856 if let KeyType::RSA { public: src_public, .. } = &key.key {
857 assert_eq!(public.e, src_public.e);
858 assert_eq!(public.n, src_public.n);
859 } else {
860 unreachable!("Expected RSA key")
861 }
862 } else {
863 unreachable!("Expected RSA key");
864 }
865 }
866
867 #[test]
868 fn test_public_key_conversion_es() {
869 let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
870 assert!(key.is_ok());
871 let key = key.unwrap();
872
873 let public_key = key.to_public().unwrap();
874 assert_eq!(key.kid, public_key.kid);
875 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
876 assert!(public_key.encode.get().is_none());
877 assert!(public_key.decode.get().is_none());
878 assert!(matches!(public_key.key, KeyType::EC { .. }));
879
880 if let KeyType::EC { x, y, d, curve } = &public_key.key {
881 assert!(d.is_none());
882
883 if let KeyType::EC {
884 x: src_x,
885 y: src_y,
886 curve: src_curve,
887 ..
888 } = &key.key
889 {
890 assert_eq!(x, src_x);
891 assert_eq!(y, src_y);
892 assert_eq!(curve, src_curve);
893 } else {
894 unreachable!("Expected EC key")
895 }
896 } else {
897 unreachable!("Expected EC key");
898 }
899 }
900
901 #[test]
902 fn test_public_key_conversion_ed() {
903 let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
904 assert!(key.is_ok());
905 let key = key.unwrap();
906
907 let public_key = key.to_public().unwrap();
908 assert_eq!(key.kid, public_key.kid);
909 assert_eq!(public_key.operations, [KeyOperation::Verify].into());
910 assert!(public_key.encode.get().is_none());
911 assert!(public_key.decode.get().is_none());
912 assert!(matches!(public_key.key, KeyType::OKP { .. }));
913
914 if let KeyType::OKP { x, d, curve } = &public_key.key {
915 assert!(d.is_none());
916
917 if let KeyType::OKP {
918 x: src_x,
919 curve: src_curve,
920 ..
921 } = &key.key
922 {
923 assert_eq!(x, src_x);
924 assert_eq!(curve, src_curve);
925 } else {
926 unreachable!("Expected OKP key")
927 }
928 } else {
929 unreachable!("Expected OKP key");
930 }
931 }
932
933 #[test]
934 fn test_key_generate_sign_verify_cycle() {
935 let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
936 assert!(key.is_ok());
937 let key = key.unwrap();
938
939 let claims = create_test_claims();
940
941 let token = key.encode(&claims).unwrap();
942 let verified_claims = key.decode(&token).unwrap();
943
944 assert_eq!(verified_claims.root, claims.root);
945 assert_eq!(verified_claims.publish, claims.publish);
946 assert_eq!(verified_claims.subscribe, claims.subscribe);
947 assert_eq!(verified_claims.cluster, claims.cluster);
948 }
949
950 #[test]
951 fn test_key_debug_no_secret() {
952 let key = create_test_key();
953 let debug_str = format!("{key:?}");
954
955 assert!(debug_str.contains("algorithm: HS256"));
956 assert!(debug_str.contains("operations"));
957 assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
958 assert!(!debug_str.contains("secret")); }
960
961 #[test]
962 fn test_key_operations_enum() {
963 let sign_op = KeyOperation::Sign;
964 let verify_op = KeyOperation::Verify;
965 let decrypt_op = KeyOperation::Decrypt;
966 let encrypt_op = KeyOperation::Encrypt;
967
968 assert_eq!(sign_op, KeyOperation::Sign);
969 assert_eq!(verify_op, KeyOperation::Verify);
970 assert_eq!(decrypt_op, KeyOperation::Decrypt);
971 assert_eq!(encrypt_op, KeyOperation::Encrypt);
972
973 assert_ne!(sign_op, verify_op);
974 assert_ne!(decrypt_op, encrypt_op);
975 }
976
977 #[test]
978 fn test_key_operations_serde() {
979 let operations = [KeyOperation::Sign, KeyOperation::Verify];
980 let json = serde_json::to_string(&operations).unwrap();
981 assert!(json.contains("\"sign\""));
982 assert!(json.contains("\"verify\""));
983
984 let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
985 assert_eq!(deserialized, operations);
986 }
987
988 #[test]
989 fn test_key_serde() {
990 let key = create_test_key();
991 let json = serde_json::to_string(&key).unwrap();
992 let deserialized: Key = serde_json::from_str(&json).unwrap();
993
994 assert_eq!(deserialized.algorithm, key.algorithm);
995 assert_eq!(deserialized.operations, key.operations);
996 assert_eq!(deserialized.kid, key.kid);
997
998 if let (
999 KeyType::OCT {
1000 secret: original_secret,
1001 },
1002 KeyType::OCT {
1003 secret: deserialized_secret,
1004 },
1005 ) = (&key.key, &deserialized.key)
1006 {
1007 assert_eq!(deserialized_secret, original_secret);
1008 } else {
1009 panic!("Expected both keys to be OCT variant");
1010 }
1011 }
1012
1013 #[test]
1014 fn test_key_clone() {
1015 let key = create_test_key();
1016 let cloned = key.clone();
1017
1018 assert_eq!(cloned.algorithm, key.algorithm);
1019 assert_eq!(cloned.operations, key.operations);
1020 assert_eq!(cloned.kid, key.kid);
1021
1022 if let (
1023 KeyType::OCT {
1024 secret: original_secret,
1025 },
1026 KeyType::OCT { secret: cloned_secret },
1027 ) = (&key.key, &cloned.key)
1028 {
1029 assert_eq!(cloned_secret, original_secret);
1030 } else {
1031 panic!("Expected both keys to be OCT variant");
1032 }
1033 }
1034
1035 #[test]
1036 fn test_hmac_algorithms() {
1037 let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1038 let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1039 let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
1040
1041 let claims = create_test_claims();
1042
1043 for key in [key_256, key_384, key_512] {
1045 assert!(key.is_ok());
1046 let key = key.unwrap();
1047
1048 let token = key.encode(&claims).unwrap();
1049 let verified_claims = key.decode(&token).unwrap();
1050 assert_eq!(verified_claims.root, claims.root);
1051 }
1052 }
1053
1054 #[test]
1055 fn test_rsa_pkcs1_asymmetric_algorithms() {
1056 let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1057 let key_rs384 = Key::generate(Algorithm::RS384, Some("test-id".to_string()));
1058 let key_rs512 = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
1059
1060 for key in [key_rs256, key_rs384, key_rs512] {
1061 test_asymmetric_key(key);
1062 }
1063 }
1064
1065 #[test]
1066 fn test_rsa_pss_asymmetric_algorithms() {
1067 let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1068 let key_ps384 = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1069 let key_ps512 = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
1070
1071 for key in [key_ps256, key_ps384, key_ps512] {
1072 test_asymmetric_key(key);
1073 }
1074 }
1075
1076 #[test]
1077 fn test_ec_asymmetric_algorithms() {
1078 let key_es256 = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
1079 let key_es384 = Key::generate(Algorithm::ES384, Some("test-id".to_string()));
1080
1081 for key in [key_es256, key_es384] {
1082 test_asymmetric_key(key);
1083 }
1084 }
1085
1086 #[test]
1087 fn test_ed_asymmetric_algorithms() {
1088 let key_eddsa = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
1089
1090 test_asymmetric_key(key_eddsa);
1091 }
1092
1093 fn test_asymmetric_key(key: anyhow::Result<Key>) {
1094 assert!(key.is_ok());
1095 let key = key.unwrap();
1096
1097 let claims = create_test_claims();
1098 let token = key.encode(&claims).unwrap();
1099
1100 let private_verified_claims = key.decode(&token).unwrap();
1101 assert_eq!(
1102 private_verified_claims.root, claims.root,
1103 "validation using private key"
1104 );
1105
1106 let public_verified_claims = key.to_public().unwrap().decode(&token).unwrap();
1107 assert_eq!(public_verified_claims.root, claims.root, "validation using public key");
1108 }
1109
/// A token signed with an HS256 key must not decode with a separately generated HS384 key.
#[test]
fn test_cross_algorithm_verification_fails() {
    // `expect` keeps the generation error in the panic output; `assert!(x.is_ok())` would
    // discard it.
    let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()))
        .expect("HS256 key generation should succeed");
    let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()))
        .expect("HS384 key generation should succeed");

    let claims = create_test_claims();
    let token = key_256
        .encode(&claims)
        .expect("encoding with the HS256 key should succeed");

    let result = key_384.decode(&token);
    assert!(
        result.is_err(),
        "an HS256 token must not decode with an HS384 key"
    );
}
1127
/// A token signed with an RS256 key must not decode with a separately generated PS256 key,
/// neither through the full PS256 key nor through its `to_public()` counterpart.
#[test]
fn test_asymmetric_cross_algorithm_verification_fails() {
    // `expect` keeps the generation error in the panic output; `assert!(x.is_ok())` would
    // discard it.
    let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()))
        .expect("RS256 key generation should succeed");
    let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()))
        .expect("PS256 key generation should succeed");

    let claims = create_test_claims();
    let token = key_rs256
        .encode(&claims)
        .expect("encoding with the RS256 key should succeed");

    // Both decode attempts are made before either is asserted on.
    let private_result = key_ps256.decode(&token);
    let public_result = key_ps256
        .to_public()
        .expect("deriving the PS256 public key should succeed")
        .decode(&token);

    assert!(
        private_result.is_err(),
        "an RS256 token must not decode with the full PS256 key"
    );
    assert!(
        public_result.is_err(),
        "an RS256 token must not decode with the public PS256 key"
    );
}
1147
1148 #[test]
1149 fn test_rsa_pkcs1_public_key_conversion() {
1150 let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1151 assert!(key.is_ok());
1152 let key = key.unwrap();
1153
1154 assert!(key.operations.contains(&KeyOperation::Sign));
1155 assert!(key.operations.contains(&KeyOperation::Verify));
1156
1157 let public_key = key.to_public().unwrap();
1158 assert!(!public_key.operations.contains(&KeyOperation::Sign));
1159 assert!(public_key.operations.contains(&KeyOperation::Verify));
1160
1161 match key.key {
1162 KeyType::RSA {
1163 ref public,
1164 ref private,
1165 } => {
1166 assert!(private.is_some());
1167 assert_eq!(public.n.len(), 256);
1168 assert_eq!(public.e.len(), 3);
1169
1170 match public_key.key {
1171 KeyType::RSA {
1172 public: ref public_public,
1173 private: ref public_private,
1174 } => {
1175 assert!(public_private.is_none());
1176 assert_eq!(public.n, public_public.n);
1177 assert_eq!(public.e, public_public.e);
1178 }
1179 _ => panic!("Expected public key to be an RSA key"),
1180 }
1181 }
1182 _ => panic!("Expected private key to be an RSA key"),
1183 }
1184 }
1185
1186 #[test]
1187 fn test_rsa_pss_public_key_conversion() {
1188 let key = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1189 assert!(key.is_ok());
1190 let key = key.unwrap();
1191
1192 assert!(key.operations.contains(&KeyOperation::Sign));
1193 assert!(key.operations.contains(&KeyOperation::Verify));
1194
1195 let public_key = key.to_public().unwrap();
1196 assert!(!public_key.operations.contains(&KeyOperation::Sign));
1197 assert!(public_key.operations.contains(&KeyOperation::Verify));
1198
1199 match key.key {
1200 KeyType::RSA {
1201 ref public,
1202 ref private,
1203 } => {
1204 assert!(private.is_some());
1205 assert_eq!(public.n.len(), 256);
1206 assert_eq!(public.e.len(), 3);
1207
1208 match public_key.key {
1209 KeyType::RSA {
1210 public: ref public_public,
1211 private: ref public_private,
1212 } => {
1213 assert!(public_private.is_none());
1214 assert_eq!(public.n, public_public.n);
1215 assert_eq!(public.e, public_public.e);
1216 }
1217 _ => panic!("Expected public key to be an RSA key"),
1218 }
1219 }
1220 _ => panic!("Expected private key to be an RSA key"),
1221 }
1222 }
1223
/// Serializes the test key to JSON and checks the encoding of its `k` member.
///
/// The value must contain none of `=`, `+` or `/`, and decoding it with the unpadded URL-safe
/// base64 engine must yield the key's original secret bytes.
#[test]
fn test_base64url_serialization() {
    let key = create_test_key();
    let json = serde_json::to_string(&key).unwrap();

    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    let k_value = parsed["k"].as_str().unwrap();

    // No padding and no characters from the standard (non-URL-safe) alphabet.
    assert!(!k_value.contains('='));
    assert!(!k_value.contains('+'));
    assert!(!k_value.contains('/'));

    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(k_value)
        .unwrap();

    if let KeyType::OCT {
        secret: original_secret,
    } = &key.key
    {
        assert_eq!(decoded, *original_secret);
    } else {
        // Only one key is inspected here, so the message refers to a single key.
        panic!("Expected key to be OCT variant");
    }
}
1252
/// Deserializes a JWK whose `k` value carries no `=` padding and checks the parsed algorithm,
/// key id and secret bytes.
#[test]
fn test_backwards_compatibility_unpadded_base64url() {
    const UNPADDED_JWK: &str = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;

    let parsed = serde_json::from_str::<Key>(UNPADDED_JWK).unwrap();

    assert_eq!(parsed.algorithm, Algorithm::HS256);
    assert_eq!(parsed.kid, Some(String::from("test-key-1")));

    match &parsed.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256")
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1269
/// Deserializes a JWK whose `k` value ends in `=` padding and checks the parsed algorithm,
/// key id and secret bytes.
#[test]
fn test_backwards_compatibility_padded_base64url() {
    const PADDED_JWK: &str = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;

    let parsed = serde_json::from_str::<Key>(PADDED_JWK).unwrap();

    assert_eq!(parsed.algorithm, Algorithm::HS256);
    assert_eq!(parsed.kid, Some(String::from("test-key-1")));

    match &parsed.key {
        KeyType::OCT { secret } => {
            assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256")
        }
        _ => panic!("Expected key to be OCT variant"),
    }
}
1286
/// Round-trips the test key through `to_file` / `from_file`.
///
/// The file written by `to_file` must contain none of `{`, `}` or `"`, must decode with the
/// unpadded URL-safe base64 engine into UTF-8 text that parses as JSON, and must load back
/// into a key with the same algorithm, operations, key id and secret.
#[test]
fn test_file_io_base64url() {
    /// Deletes the wrapped path when dropped, so the temp file is removed even when one of
    /// the assertions below panics and unwinds out of the test.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created if `to_file` failed.
            std::fs::remove_file(&self.0).ok();
        }
    }

    let key = create_test_key();

    // The temp directory is shared between processes, so the file name includes the process
    // id to keep concurrently running test binaries from overwriting each other's file.
    let temp_file = TempFile(
        std::env::temp_dir().join(format!("test_jwk_{}.key", std::process::id())),
    );
    let temp_path = &temp_file.0;

    key.to_file(temp_path).unwrap();

    let contents = std::fs::read_to_string(temp_path).unwrap();

    // The on-disk form must not be raw JSON.
    assert!(!contents.contains('{'));
    assert!(!contents.contains('}'));
    assert!(!contents.contains('"'));

    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(&contents)
        .unwrap();
    let json_str = String::from_utf8(decoded).unwrap();
    let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();

    let loaded_key = Key::from_file(temp_path).unwrap();
    assert_eq!(loaded_key.algorithm, key.algorithm);
    assert_eq!(loaded_key.operations, key.operations);
    assert_eq!(loaded_key.kid, key.kid);

    if let (
        KeyType::OCT {
            secret: original_secret,
        },
        KeyType::OCT {
            secret: loaded_secret,
        },
    ) = (&key.key, &loaded_key.key)
    {
        assert_eq!(loaded_secret, original_secret);
    } else {
        panic!("Expected both keys to be OCT variant");
    }
}
1332}