moq_token/
key.rs

1use crate::generate::generate;
2use crate::{Algorithm, Claims};
3use anyhow::{bail, Context};
4use base64::Engine;
5use elliptic_curve::pkcs8::EncodePrivateKey;
6use elliptic_curve::SecretKey;
7use jsonwebtoken::{DecodingKey, EncodingKey, Header};
8use rsa::pkcs1::EncodeRsaPrivateKey;
9use rsa::BigUint;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::sync::OnceLock;
12use std::{collections::HashSet, fmt, path::Path as StdPath};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
15#[serde(rename_all = "camelCase")]
16pub enum KeyOperation {
17	Sign,
18	Verify,
19	Decrypt,
20	Encrypt,
21}
22
23/// https://datatracker.ietf.org/doc/html/rfc7518#section-6
24#[derive(Clone, Serialize, Deserialize)]
25#[serde(tag = "kty")]
26pub enum KeyType {
27	/// https://datatracker.ietf.org/doc/html/rfc7518#section-6.2
28	EC {
29		#[serde(rename = "crv")]
30		curve: EllipticCurve,
31		/// The X-coordinate of an EC key
32		#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
33		x: Vec<u8>,
34		/// The Y-coordinate of an EC key
35		#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
36		y: Vec<u8>,
37		/// The private value of an EC key
38		#[serde(
39			default,
40			skip_serializing_if = "Option::is_none",
41			serialize_with = "serialize_base64url_optional",
42			deserialize_with = "deserialize_base64url_optional"
43		)]
44		d: Option<Vec<u8>>,
45	},
46	/// https://datatracker.ietf.org/doc/html/rfc7518#section-6.3
47	RSA {
48		#[serde(flatten)]
49		public: RsaPublicKey,
50		#[serde(flatten, skip_serializing_if = "Option::is_none")]
51		private: Option<RsaPrivateKey>,
52	},
53	/// https://datatracker.ietf.org/doc/html/rfc7518#section-6.4
54	#[serde(rename = "oct")]
55	OCT {
56		/// The secret key as base64url (unpadded).
57		#[serde(
58			rename = "k",
59			default,
60			serialize_with = "serialize_base64url",
61			deserialize_with = "deserialize_base64url"
62		)]
63		secret: Vec<u8>,
64	},
65	/// https://datatracker.ietf.org/doc/html/rfc8037#section-2
66	OKP {
67		#[serde(rename = "crv")]
68		curve: EllipticCurve,
69		#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
70		x: Vec<u8>,
71		#[serde(
72			rename = "d",
73			default,
74			skip_serializing_if = "Option::is_none",
75			serialize_with = "serialize_base64url_optional",
76			deserialize_with = "deserialize_base64url_optional"
77		)]
78		d: Option<Vec<u8>>,
79	},
80}
81
82/// https://datatracker.ietf.org/doc/html/rfc7518#section-6.2.1.1
83#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
84pub enum EllipticCurve {
85	#[serde(rename = "P-256")]
86	P256,
87	#[serde(rename = "P-384")]
88	P384,
89	// jsonwebtoken doesn't support the ES512 algorithm, so we can't implement this
90	// #[serde(rename = "P-521")]
91	// P521,
92	#[serde(rename = "Ed25519")]
93	Ed25519,
94}
95
96/// https://datatracker.ietf.org/doc/html/rfc7518#section-6.3.1
97#[derive(Clone, Serialize, Deserialize)]
98pub struct RsaPublicKey {
99	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
100	pub n: Vec<u8>,
101	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
102	pub e: Vec<u8>,
103}
104
105/// https://datatracker.ietf.org/doc/html/rfc7518#section-6.3.2
106#[derive(Clone, Serialize, Deserialize)]
107pub struct RsaPrivateKey {
108	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
109	pub d: Vec<u8>,
110	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
111	pub p: Vec<u8>,
112	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
113	pub q: Vec<u8>,
114	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
115	pub dp: Vec<u8>,
116	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
117	pub dq: Vec<u8>,
118	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
119	pub qi: Vec<u8>,
120	#[serde(skip_serializing_if = "Option::is_none")]
121	pub oth: Option<Vec<RsaAdditionalPrime>>,
122}
123
124#[derive(Clone, Serialize, Deserialize)]
125pub struct RsaAdditionalPrime {
126	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
127	pub r: Vec<u8>,
128	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
129	pub d: Vec<u8>,
130	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
131	pub t: Vec<u8>,
132}
133
134/// JWK, almost to spec (https://datatracker.ietf.org/doc/html/rfc7517) but not quite the same
135/// because it's annoying to implement.
136#[derive(Clone, Serialize, Deserialize)]
137#[serde(remote = "Self")]
138pub struct Key {
139	/// The algorithm used by the key.
140	#[serde(rename = "alg")]
141	pub algorithm: Algorithm,
142
143	/// The operations that the key can perform.
144	#[serde(rename = "key_ops")]
145	pub operations: HashSet<KeyOperation>,
146
147	/// Defaults to KeyType::OCT
148	#[serde(flatten)]
149	pub key: KeyType,
150
151	/// The key ID, useful for rotating keys.
152	#[serde(skip_serializing_if = "Option::is_none")]
153	pub kid: Option<String>,
154
155	// Cached for performance reasons, unfortunately.
156	#[serde(skip)]
157	pub(crate) decode: OnceLock<DecodingKey>,
158
159	#[serde(skip)]
160	pub(crate) encode: OnceLock<EncodingKey>,
161}
162
163impl<'de> Deserialize<'de> for Key {
164	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165	where
166		D: Deserializer<'de>,
167	{
168		let mut value = serde_json::Value::deserialize(deserializer)?;
169
170		// Normally the "kty" parameter is required in a JWK: https://datatracker.ietf.org/doc/html/rfc7517#section-4.1
171		// But for backwards compatibility we need to default to "oct" because in a previous
172		// implementation the parameter was omitted, and we want to keep previously generated tokens valid
173		if let Some(obj) = value.as_object_mut() {
174			if !obj.contains_key("kty") {
175				obj.insert("kty".to_string(), serde_json::Value::String("oct".to_string()));
176			}
177		}
178
179		Self::deserialize(value).map_err(serde::de::Error::custom)
180	}
181}
182
183impl Serialize for Key {
184	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
185	where
186		S: Serializer,
187	{
188		Self::serialize(self, serializer)
189	}
190}
191
192impl fmt::Debug for Key {
193	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194		f.debug_struct("Key")
195			.field("algorithm", &self.algorithm)
196			.field("operations", &self.operations)
197			.field("kid", &self.kid)
198			.finish()
199	}
200}
201
202impl Key {
203	#[allow(clippy::should_implement_trait)]
204	pub fn from_str(s: &str) -> anyhow::Result<Self> {
205		Ok(serde_json::from_str(s)?)
206	}
207
208	pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
209		let contents = std::fs::read_to_string(&path)?;
210		// It's base64url encoded
211		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
212		let json = String::from_utf8(decoded)?;
213		Ok(serde_json::from_str(&json)?)
214	}
215
216	pub fn to_str(&self) -> anyhow::Result<String> {
217		Ok(serde_json::to_string(self)?)
218	}
219
220	pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
221		// Serialize to JSON first
222		let json = serde_json::to_string(self)?;
223		// Then encode as base64url
224		let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
225		std::fs::write(path, encoded)?;
226		Ok(())
227	}
228
229	pub fn to_public(&self) -> anyhow::Result<Self> {
230		if !self.operations.contains(&KeyOperation::Verify) {
231			return Err(anyhow::anyhow!("This key doesn't support the Verify operation"));
232		}
233
234		let key = match self.key {
235			KeyType::RSA { ref public, .. } => Ok(KeyType::RSA {
236				public: public.clone(),
237				private: None,
238			}),
239			KeyType::EC {
240				ref x,
241				ref y,
242				ref curve,
243				..
244			} => Ok(KeyType::EC {
245				x: x.clone(),
246				y: y.clone(),
247				curve: curve.clone(),
248				d: None,
249			}),
250			KeyType::OCT { .. } => Err(anyhow::anyhow!("OCT key cannot be converted to public key")),
251			KeyType::OKP { ref x, ref curve, .. } => Ok(KeyType::OKP {
252				x: x.clone(),
253				curve: curve.clone(),
254				d: None,
255			}),
256		};
257
258		match key {
259			Ok(key) => Ok(Self {
260				algorithm: self.algorithm,
261				operations: [KeyOperation::Verify].into(),
262				key,
263				kid: self.kid.clone(),
264				decode: Default::default(),
265				encode: Default::default(),
266			}),
267			Err(err) => Err(anyhow::anyhow!("Failed to convert key: {}", err)),
268		}
269	}
270
271	fn to_decoding_key(&self) -> anyhow::Result<&DecodingKey> {
272		if let Some(key) = self.decode.get() {
273			return Ok(key);
274		}
275
276		let decoding_key = match self.key {
277			KeyType::OCT { ref secret } => match self.algorithm {
278				Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret),
279				_ => bail!("Invalid algorithm for key type OCT"),
280			},
281			KeyType::EC {
282				ref curve,
283				ref x,
284				ref y,
285				..
286			} => match curve {
287				EllipticCurve::P256 => {
288					if self.algorithm != Algorithm::ES256 {
289						bail!("Invalid algorithm for P-256 curve");
290					}
291					if x.len() != 32 || y.len() != 32 {
292						bail!("Invalid coordinate length for P-256");
293					}
294
295					DecodingKey::from_ec_components(
296						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
297						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
298					)?
299				}
300				EllipticCurve::P384 => {
301					if self.algorithm != Algorithm::ES384 {
302						bail!("Invalid algorithm for P-384 curve");
303					}
304					if x.len() != 48 || y.len() != 48 {
305						bail!("Invalid coordinate length for P-384");
306					}
307
308					DecodingKey::from_ec_components(
309						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
310						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
311					)?
312				}
313				_ => bail!("Invalid curve for EC key"),
314			},
315			KeyType::OKP { ref curve, ref x, .. } => match curve {
316				EllipticCurve::Ed25519 => {
317					if self.algorithm != Algorithm::EdDSA {
318						bail!("Invalid algorithm for Ed25519 curve");
319					}
320
321					DecodingKey::from_ed_components(
322						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
323					)?
324				}
325				_ => bail!("Invalid curve for OKP key"),
326			},
327			KeyType::RSA { ref public, .. } => {
328				DecodingKey::from_rsa_raw_components(public.n.as_ref(), public.e.as_ref())
329			}
330		};
331
332		Ok(self.decode.get_or_init(|| decoding_key))
333	}
334
335	fn to_encoding_key(&self) -> anyhow::Result<&EncodingKey> {
336		if let Some(key) = self.encode.get() {
337			return Ok(key);
338		}
339
340		let encoding_key = match self.key {
341			KeyType::OCT { ref secret } => match self.algorithm {
342				Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret),
343				_ => bail!("Invalid algorithm for key type OCT"),
344			},
345			KeyType::EC { ref curve, ref d, .. } => {
346				let d = d.as_ref().context("Missing private key")?;
347
348				match curve {
349					EllipticCurve::P256 => {
350						let secret_key = SecretKey::<p256::NistP256>::from_slice(d)?;
351						let doc = secret_key.to_pkcs8_der()?;
352						EncodingKey::from_ec_der(doc.as_bytes())
353					}
354					EllipticCurve::P384 => {
355						let secret_key = SecretKey::<p384::NistP384>::from_slice(d)?;
356						let doc = secret_key.to_pkcs8_der()?;
357						EncodingKey::from_ec_der(doc.as_bytes())
358					}
359					_ => bail!("Invalid curve for EC key"),
360				}
361			}
362			KeyType::OKP {
363				ref curve,
364				ref d,
365				ref x,
366			} => {
367				let d = d.as_ref().context("Missing private key")?;
368
369				let key_pair =
370					aws_lc_rs::signature::Ed25519KeyPair::from_seed_and_public_key(d.as_slice(), x.as_slice())?;
371
372				match curve {
373					EllipticCurve::Ed25519 => EncodingKey::from_ed_der(key_pair.to_pkcs8()?.as_ref()),
374					_ => bail!("Invalid curve for OKP key"),
375				}
376			}
377			KeyType::RSA {
378				ref public,
379				ref private,
380			} => {
381				let n = BigUint::from_bytes_be(&public.n);
382				let e = BigUint::from_bytes_be(&public.e);
383				let private = private.as_ref().context("Missing private key")?;
384				let d = BigUint::from_bytes_be(&private.d);
385				let p = BigUint::from_bytes_be(&private.p);
386				let q = BigUint::from_bytes_be(&private.q);
387
388				let rsa = rsa::RsaPrivateKey::from_components(n, e, d, vec![p, q]);
389				let pem = rsa?.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF);
390
391				EncodingKey::from_rsa_pem(pem?.as_bytes())?
392			}
393		};
394
395		Ok(self.encode.get_or_init(|| encoding_key))
396	}
397
398	pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
399		if !self.operations.contains(&KeyOperation::Verify) {
400			bail!("key does not support verification");
401		}
402
403		let decode: anyhow::Result<&DecodingKey> = self.to_decoding_key();
404
405		match decode {
406			Ok(decode) => {
407				let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
408				validation.required_spec_claims = Default::default(); // Don't require exp, but still validate it if present
409
410				let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
411				token.claims.validate()?;
412
413				Ok(token.claims)
414			}
415			Err(e) => Err(anyhow::anyhow!("Failed to decode key: {}", e)),
416		}
417	}
418
419	pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
420		if !self.operations.contains(&KeyOperation::Sign) {
421			bail!("key does not support signing");
422		}
423
424		payload.validate()?;
425
426		let encode: anyhow::Result<&EncodingKey> = self.to_encoding_key();
427
428		match encode {
429			Ok(encode) => {
430				let mut header = Header::new(self.algorithm.into());
431				header.kid = self.kid.clone();
432				let token = jsonwebtoken::encode(&header, &payload, encode)?;
433				Ok(token)
434			}
435			Err(e) => Err(anyhow::anyhow!("Failed to encode key: {}", e)),
436		}
437	}
438
439	/// Generate a key pair for the given algorithm, returning the private and public keys.
440	pub fn generate(algorithm: Algorithm, id: Option<String>) -> anyhow::Result<Self> {
441		generate(algorithm, id)
442	}
443}
444
445/// Serialize bytes as base64url without padding
446fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
447where
448	S: Serializer,
449{
450	let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
451	serializer.serialize_str(&encoded)
452}
453
454fn serialize_base64url_optional<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
455where
456	S: Serializer,
457{
458	match bytes {
459		Some(b) => serialize_base64url(b, serializer),
460		None => serializer.serialize_none(),
461	}
462}
463
464/// Deserialize base64url string to bytes, supporting both padded and unpadded formats for backwards compatibility
465fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
466where
467	D: Deserializer<'de>,
468{
469	let s = String::deserialize(deserializer)?;
470
471	// Try to decode as unpadded base64url first (preferred format)
472	base64::engine::general_purpose::URL_SAFE_NO_PAD
473		.decode(&s)
474		.or_else(|_| {
475			// Fall back to padded base64url for backwards compatibility
476			base64::engine::general_purpose::URL_SAFE.decode(&s)
477		})
478		.map_err(serde::de::Error::custom)
479}
480
481fn deserialize_base64url_optional<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
482where
483	D: Deserializer<'de>,
484{
485	let s: Option<String> = Option::deserialize(deserializer)?;
486	match s {
487		Some(s) => {
488			let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
489				.decode(&s)
490				.or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(&s))
491				.map_err(serde::de::Error::custom)?;
492			Ok(Some(decoded))
493		}
494		None => Ok(None),
495	}
496}
497
498#[cfg(test)]
499mod tests {
500	use super::*;
501	use std::time::{Duration, SystemTime};
502
503	fn create_test_key() -> Key {
504		Key {
505			algorithm: Algorithm::HS256,
506			operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
507			key: KeyType::OCT {
508				secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
509			},
510			kid: Some("test-key-1".to_string()),
511			decode: Default::default(),
512			encode: Default::default(),
513		}
514	}
515
516	fn create_test_claims() -> Claims {
517		Claims {
518			root: "test-path".to_string(),
519			publish: vec!["test-pub".into()],
520			cluster: false,
521			subscribe: vec!["test-sub".into()],
522			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
523			issued: Some(SystemTime::now()),
524		}
525	}
526
527	#[test]
528	fn test_key_from_str_valid() {
529		let key = create_test_key();
530		let json = key.to_str().unwrap();
531		let loaded_key = Key::from_str(&json).unwrap();
532
533		assert_eq!(loaded_key.algorithm, key.algorithm);
534		assert_eq!(loaded_key.operations, key.operations);
535		match (loaded_key.key, key.key) {
536			(KeyType::OCT { secret: loaded_secret }, KeyType::OCT { secret }) => {
537				assert_eq!(loaded_secret, secret);
538			}
539			_ => panic!("Expected OCT key"),
540		}
541		assert_eq!(loaded_key.kid, key.kid);
542	}
543
544	/// Tests whether Key::from_str() works for keys without a kty value to fall back to OCT
545	#[test]
546	fn test_key_oct_backwards_compatibility() {
547		let json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68"}"#;
548		let key = Key::from_str(json);
549
550		assert!(key.is_ok());
551		let key = key.unwrap();
552
553		if let KeyType::OCT { ref secret, .. } = key.key {
554			let base64_key = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret);
555			assert_eq!(base64_key, "Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68");
556		} else {
557			panic!("Expected OCT key");
558		}
559
560		let key_str = key.to_str();
561		assert!(key_str.is_ok());
562		let key_str = key_str.unwrap();
563
564		// After serializing again it must contain the kty
565		assert!(key_str.contains("\"alg\":\"HS256\""));
566		assert!(key_str.contains("\"key_ops\""));
567		assert!(key_str.contains("\"sign\""));
568		assert!(key_str.contains("\"verify\""));
569		assert!(key_str.contains("\"kty\":\"oct\""));
570	}
571
572	#[test]
573	fn test_key_from_str_invalid_json() {
574		let result = Key::from_str("invalid json");
575		assert!(result.is_err());
576	}
577
578	#[test]
579	fn test_key_to_str() {
580		let key = create_test_key();
581		let json = key.to_str().unwrap();
582		assert!(json.contains("\"alg\":\"HS256\""));
583		assert!(json.contains("\"key_ops\""));
584		assert!(json.contains("\"sign\""));
585		assert!(json.contains("\"verify\""));
586		assert!(json.contains("\"kid\":\"test-key-1\""));
587		assert!(json.contains("\"kty\":\"oct\""));
588	}
589
590	#[test]
591	fn test_key_sign_success() {
592		let key = create_test_key();
593		let claims = create_test_claims();
594		let token = key.encode(&claims).unwrap();
595
596		assert!(!token.is_empty());
597		assert_eq!(token.matches('.').count(), 2); // JWT format: header.payload.signature
598	}
599
600	#[test]
601	fn test_key_sign_no_permission() {
602		let mut key = create_test_key();
603		key.operations = [KeyOperation::Verify].into();
604		let claims = create_test_claims();
605
606		let result = key.encode(&claims);
607		assert!(result.is_err());
608		assert!(result.unwrap_err().to_string().contains("key does not support signing"));
609	}
610
611	#[test]
612	fn test_key_sign_invalid_claims() {
613		let key = create_test_key();
614		let invalid_claims = Claims {
615			root: "test-path".to_string(),
616			publish: vec![],
617			subscribe: vec![],
618			cluster: false,
619			expires: None,
620			issued: None,
621		};
622
623		let result = key.encode(&invalid_claims);
624		assert!(result.is_err());
625		assert!(result
626			.unwrap_err()
627			.to_string()
628			.contains("no publish or subscribe allowed; token is useless"));
629	}
630
631	#[test]
632	fn test_key_verify_success() {
633		let key = create_test_key();
634		let claims = create_test_claims();
635		let token = key.encode(&claims).unwrap();
636
637		let verified_claims = key.decode(&token).unwrap();
638		assert_eq!(verified_claims.root, claims.root);
639		assert_eq!(verified_claims.publish, claims.publish);
640		assert_eq!(verified_claims.subscribe, claims.subscribe);
641		assert_eq!(verified_claims.cluster, claims.cluster);
642	}
643
644	#[test]
645	fn test_key_verify_no_permission() {
646		let mut key = create_test_key();
647		key.operations = [KeyOperation::Sign].into();
648
649		let result = key.decode("some.jwt.token");
650		assert!(result.is_err());
651		assert!(result
652			.unwrap_err()
653			.to_string()
654			.contains("key does not support verification"));
655	}
656
657	#[test]
658	fn test_key_verify_invalid_token() {
659		let key = create_test_key();
660		let result = key.decode("invalid-token");
661		assert!(result.is_err());
662	}
663
664	#[test]
665	fn test_key_verify_path_mismatch() {
666		let key = create_test_key();
667		let claims = create_test_claims();
668		let token = key.encode(&claims).unwrap();
669
670		// This test was expecting a path mismatch error, but now decode succeeds
671		let result = key.decode(&token);
672		assert!(result.is_ok());
673	}
674
675	#[test]
676	fn test_key_verify_expired_token() {
677		let key = create_test_key();
678		let mut claims = create_test_claims();
679		claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); // 1 hour ago
680		let token = key.encode(&claims).unwrap();
681
682		let result = key.decode(&token);
683		assert!(result.is_err());
684	}
685
686	#[test]
687	fn test_key_verify_token_without_exp() {
688		let key = create_test_key();
689		let claims = Claims {
690			root: "test-path".to_string(),
691			publish: vec!["".to_string()],
692			subscribe: vec!["".to_string()],
693			cluster: false,
694			expires: None,
695			issued: None,
696		};
697		let token = key.encode(&claims).unwrap();
698
699		let verified_claims = key.decode(&token).unwrap();
700		assert_eq!(verified_claims.root, claims.root);
701		assert_eq!(verified_claims.publish, claims.publish);
702		assert_eq!(verified_claims.subscribe, claims.subscribe);
703		assert_eq!(verified_claims.expires, None);
704	}
705
706	#[test]
707	fn test_key_round_trip() {
708		let key = create_test_key();
709		let original_claims = Claims {
710			root: "test-path".to_string(),
711			publish: vec!["test-pub".into()],
712			subscribe: vec!["test-sub".into()],
713			cluster: true,
714			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
715			issued: Some(SystemTime::now()),
716		};
717
718		let token = key.encode(&original_claims).unwrap();
719		let verified_claims = key.decode(&token).unwrap();
720
721		assert_eq!(verified_claims.root, original_claims.root);
722		assert_eq!(verified_claims.publish, original_claims.publish);
723		assert_eq!(verified_claims.subscribe, original_claims.subscribe);
724		assert_eq!(verified_claims.cluster, original_claims.cluster);
725	}
726
727	#[test]
728	fn test_key_generate_hs256() {
729		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
730		assert!(key.is_ok());
731		let key = key.unwrap();
732
733		assert_eq!(key.algorithm, Algorithm::HS256);
734		assert_eq!(key.kid, Some("test-id".to_string()));
735		assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
736
737		match key.key {
738			KeyType::OCT { ref secret } => assert_eq!(secret.len(), 32),
739			_ => panic!("Expected OCT key"),
740		}
741	}
742
743	#[test]
744	fn test_key_generate_hs384() {
745		let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
746		assert!(key.is_ok());
747		let key = key.unwrap();
748
749		assert_eq!(key.algorithm, Algorithm::HS384);
750
751		match key.key {
752			KeyType::OCT { ref secret } => assert_eq!(secret.len(), 48),
753			_ => panic!("Expected OCT key"),
754		}
755	}
756
757	#[test]
758	fn test_key_generate_hs512() {
759		let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
760		assert!(key.is_ok());
761		let key = key.unwrap();
762
763		assert_eq!(key.algorithm, Algorithm::HS512);
764
765		match key.key {
766			KeyType::OCT { ref secret } => assert_eq!(secret.len(), 64),
767			_ => panic!("Expected OCT key"),
768		}
769	}
770
771	#[test]
772	fn test_key_generate_rs512() {
773		let key = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
774		assert!(key.is_ok());
775		let key = key.unwrap();
776
777		assert_eq!(key.algorithm, Algorithm::RS512);
778		assert!(matches!(key.key, KeyType::RSA { .. }));
779		match key.key {
780			KeyType::RSA {
781				ref public,
782				ref private,
783			} => {
784				assert!(private.is_some());
785				assert_eq!(public.n.len(), 256);
786				assert_eq!(public.e.len(), 3);
787			}
788			_ => panic!("Expected RSA key"),
789		}
790	}
791
792	#[test]
793	fn test_key_generate_es256() {
794		let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
795		assert!(key.is_ok());
796		let key = key.unwrap();
797
798		assert_eq!(key.algorithm, Algorithm::ES256);
799		assert!(matches!(key.key, KeyType::EC { .. }))
800	}
801
802	#[test]
803	fn test_key_generate_ps512() {
804		let key = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
805		assert!(key.is_ok());
806		let key = key.unwrap();
807
808		assert_eq!(key.algorithm, Algorithm::PS512);
809		assert!(matches!(key.key, KeyType::RSA { .. }));
810	}
811
812	#[test]
813	fn test_key_generate_eddsa() {
814		let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
815		assert!(key.is_ok());
816		let key = key.unwrap();
817
818		assert_eq!(key.algorithm, Algorithm::EdDSA);
819		assert!(matches!(key.key, KeyType::OKP { .. }));
820	}
821
822	#[test]
823	fn test_key_generate_without_id() {
824		let key = Key::generate(Algorithm::HS256, None);
825		assert!(key.is_ok());
826		let key = key.unwrap();
827
828		assert_eq!(key.algorithm, Algorithm::HS256);
829		assert_eq!(key.kid, None);
830		assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
831	}
832
833	#[test]
834	fn test_public_key_conversion_hmac() {
835		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string())).expect("HMAC key generation failed");
836
837		assert!(key.to_public().is_err());
838	}
839
840	#[test]
841	fn test_public_key_conversion_rsa() {
842		let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
843		assert!(key.is_ok());
844		let key = key.unwrap();
845
846		let public_key = key.to_public().unwrap();
847		assert_eq!(key.kid, public_key.kid);
848		assert_eq!(public_key.operations, [KeyOperation::Verify].into());
849		assert!(public_key.encode.get().is_none());
850		assert!(public_key.decode.get().is_none());
851		assert!(matches!(public_key.key, KeyType::RSA { .. }));
852
853		if let KeyType::RSA { public, private } = &public_key.key {
854			assert!(private.is_none());
855
856			if let KeyType::RSA { public: src_public, .. } = &key.key {
857				assert_eq!(public.e, src_public.e);
858				assert_eq!(public.n, src_public.n);
859			} else {
860				unreachable!("Expected RSA key")
861			}
862		} else {
863			unreachable!("Expected RSA key");
864		}
865	}
866
867	#[test]
868	fn test_public_key_conversion_es() {
869		let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
870		assert!(key.is_ok());
871		let key = key.unwrap();
872
873		let public_key = key.to_public().unwrap();
874		assert_eq!(key.kid, public_key.kid);
875		assert_eq!(public_key.operations, [KeyOperation::Verify].into());
876		assert!(public_key.encode.get().is_none());
877		assert!(public_key.decode.get().is_none());
878		assert!(matches!(public_key.key, KeyType::EC { .. }));
879
880		if let KeyType::EC { x, y, d, curve } = &public_key.key {
881			assert!(d.is_none());
882
883			if let KeyType::EC {
884				x: src_x,
885				y: src_y,
886				curve: src_curve,
887				..
888			} = &key.key
889			{
890				assert_eq!(x, src_x);
891				assert_eq!(y, src_y);
892				assert_eq!(curve, src_curve);
893			} else {
894				unreachable!("Expected EC key")
895			}
896		} else {
897			unreachable!("Expected EC key");
898		}
899	}
900
901	#[test]
902	fn test_public_key_conversion_ed() {
903		let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
904		assert!(key.is_ok());
905		let key = key.unwrap();
906
907		let public_key = key.to_public().unwrap();
908		assert_eq!(key.kid, public_key.kid);
909		assert_eq!(public_key.operations, [KeyOperation::Verify].into());
910		assert!(public_key.encode.get().is_none());
911		assert!(public_key.decode.get().is_none());
912		assert!(matches!(public_key.key, KeyType::OKP { .. }));
913
914		if let KeyType::OKP { x, d, curve } = &public_key.key {
915			assert!(d.is_none());
916
917			if let KeyType::OKP {
918				x: src_x,
919				curve: src_curve,
920				..
921			} = &key.key
922			{
923				assert_eq!(x, src_x);
924				assert_eq!(curve, src_curve);
925			} else {
926				unreachable!("Expected OKP key")
927			}
928		} else {
929			unreachable!("Expected OKP key");
930		}
931	}
932
933	#[test]
934	fn test_key_generate_sign_verify_cycle() {
935		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
936		assert!(key.is_ok());
937		let key = key.unwrap();
938
939		let claims = create_test_claims();
940
941		let token = key.encode(&claims).unwrap();
942		let verified_claims = key.decode(&token).unwrap();
943
944		assert_eq!(verified_claims.root, claims.root);
945		assert_eq!(verified_claims.publish, claims.publish);
946		assert_eq!(verified_claims.subscribe, claims.subscribe);
947		assert_eq!(verified_claims.cluster, claims.cluster);
948	}
949
950	#[test]
951	fn test_key_debug_no_secret() {
952		let key = create_test_key();
953		let debug_str = format!("{key:?}");
954
955		assert!(debug_str.contains("algorithm: HS256"));
956		assert!(debug_str.contains("operations"));
957		assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
958		assert!(!debug_str.contains("secret")); // Should not contain secret
959	}
960
961	#[test]
962	fn test_key_operations_enum() {
963		let sign_op = KeyOperation::Sign;
964		let verify_op = KeyOperation::Verify;
965		let decrypt_op = KeyOperation::Decrypt;
966		let encrypt_op = KeyOperation::Encrypt;
967
968		assert_eq!(sign_op, KeyOperation::Sign);
969		assert_eq!(verify_op, KeyOperation::Verify);
970		assert_eq!(decrypt_op, KeyOperation::Decrypt);
971		assert_eq!(encrypt_op, KeyOperation::Encrypt);
972
973		assert_ne!(sign_op, verify_op);
974		assert_ne!(decrypt_op, encrypt_op);
975	}
976
977	#[test]
978	fn test_key_operations_serde() {
979		let operations = [KeyOperation::Sign, KeyOperation::Verify];
980		let json = serde_json::to_string(&operations).unwrap();
981		assert!(json.contains("\"sign\""));
982		assert!(json.contains("\"verify\""));
983
984		let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
985		assert_eq!(deserialized, operations);
986	}
987
988	#[test]
989	fn test_key_serde() {
990		let key = create_test_key();
991		let json = serde_json::to_string(&key).unwrap();
992		let deserialized: Key = serde_json::from_str(&json).unwrap();
993
994		assert_eq!(deserialized.algorithm, key.algorithm);
995		assert_eq!(deserialized.operations, key.operations);
996		assert_eq!(deserialized.kid, key.kid);
997
998		if let (
999			KeyType::OCT {
1000				secret: original_secret,
1001			},
1002			KeyType::OCT {
1003				secret: deserialized_secret,
1004			},
1005		) = (&key.key, &deserialized.key)
1006		{
1007			assert_eq!(deserialized_secret, original_secret);
1008		} else {
1009			panic!("Expected both keys to be OCT variant");
1010		}
1011	}
1012
1013	#[test]
1014	fn test_key_clone() {
1015		let key = create_test_key();
1016		let cloned = key.clone();
1017
1018		assert_eq!(cloned.algorithm, key.algorithm);
1019		assert_eq!(cloned.operations, key.operations);
1020		assert_eq!(cloned.kid, key.kid);
1021
1022		if let (
1023			KeyType::OCT {
1024				secret: original_secret,
1025			},
1026			KeyType::OCT { secret: cloned_secret },
1027		) = (&key.key, &cloned.key)
1028		{
1029			assert_eq!(cloned_secret, original_secret);
1030		} else {
1031			panic!("Expected both keys to be OCT variant");
1032		}
1033	}
1034
1035	#[test]
1036	fn test_hmac_algorithms() {
1037		let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1038		let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1039		let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
1040
1041		let claims = create_test_claims();
1042
1043		// Test that each algorithm can sign and verify
1044		for key in [key_256, key_384, key_512] {
1045			assert!(key.is_ok());
1046			let key = key.unwrap();
1047
1048			let token = key.encode(&claims).unwrap();
1049			let verified_claims = key.decode(&token).unwrap();
1050			assert_eq!(verified_claims.root, claims.root);
1051		}
1052	}
1053
1054	#[test]
1055	fn test_rsa_pkcs1_asymmetric_algorithms() {
1056		let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1057		let key_rs384 = Key::generate(Algorithm::RS384, Some("test-id".to_string()));
1058		let key_rs512 = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
1059
1060		for key in [key_rs256, key_rs384, key_rs512] {
1061			test_asymmetric_key(key);
1062		}
1063	}
1064
1065	#[test]
1066	fn test_rsa_pss_asymmetric_algorithms() {
1067		let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1068		let key_ps384 = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1069		let key_ps512 = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
1070
1071		for key in [key_ps256, key_ps384, key_ps512] {
1072			test_asymmetric_key(key);
1073		}
1074	}
1075
1076	#[test]
1077	fn test_ec_asymmetric_algorithms() {
1078		let key_es256 = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
1079		let key_es384 = Key::generate(Algorithm::ES384, Some("test-id".to_string()));
1080
1081		for key in [key_es256, key_es384] {
1082			test_asymmetric_key(key);
1083		}
1084	}
1085
1086	#[test]
1087	fn test_ed_asymmetric_algorithms() {
1088		let key_eddsa = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
1089
1090		test_asymmetric_key(key_eddsa);
1091	}
1092
1093	fn test_asymmetric_key(key: anyhow::Result<Key>) {
1094		assert!(key.is_ok());
1095		let key = key.unwrap();
1096
1097		let claims = create_test_claims();
1098		let token = key.encode(&claims).unwrap();
1099
1100		let private_verified_claims = key.decode(&token).unwrap();
1101		assert_eq!(
1102			private_verified_claims.root, claims.root,
1103			"validation using private key"
1104		);
1105
1106		let public_verified_claims = key.to_public().unwrap().decode(&token).unwrap();
1107		assert_eq!(public_verified_claims.root, claims.root, "validation using public key");
1108	}
1109
1110	#[test]
1111	fn test_cross_algorithm_verification_fails() {
1112		let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1113		assert!(key_256.is_ok());
1114		let key_256 = key_256.unwrap();
1115
1116		let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1117		assert!(key_384.is_ok());
1118		let key_384 = key_384.unwrap();
1119
1120		let claims = create_test_claims();
1121		let token = key_256.encode(&claims).unwrap();
1122
1123		// Different algorithm should fail verification
1124		let result = key_384.decode(&token);
1125		assert!(result.is_err());
1126	}
1127
1128	#[test]
1129	fn test_asymmetric_cross_algorithm_verification_fails() {
1130		let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1131		assert!(key_rs256.is_ok());
1132		let key_rs256 = key_rs256.unwrap();
1133
1134		let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1135		assert!(key_ps256.is_ok());
1136		let key_ps256 = key_ps256.unwrap();
1137
1138		let claims = create_test_claims();
1139		let token = key_rs256.encode(&claims).unwrap();
1140
1141		// Different algorithm should fail verification
1142		let private_result = key_ps256.decode(&token);
1143		let public_result = key_ps256.to_public().unwrap().decode(&token);
1144		assert!(private_result.is_err());
1145		assert!(public_result.is_err());
1146	}
1147
1148	#[test]
1149	fn test_rsa_pkcs1_public_key_conversion() {
1150		let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1151		assert!(key.is_ok());
1152		let key = key.unwrap();
1153
1154		assert!(key.operations.contains(&KeyOperation::Sign));
1155		assert!(key.operations.contains(&KeyOperation::Verify));
1156
1157		let public_key = key.to_public().unwrap();
1158		assert!(!public_key.operations.contains(&KeyOperation::Sign));
1159		assert!(public_key.operations.contains(&KeyOperation::Verify));
1160
1161		match key.key {
1162			KeyType::RSA {
1163				ref public,
1164				ref private,
1165			} => {
1166				assert!(private.is_some());
1167				assert_eq!(public.n.len(), 256);
1168				assert_eq!(public.e.len(), 3);
1169
1170				match public_key.key {
1171					KeyType::RSA {
1172						public: ref public_public,
1173						private: ref public_private,
1174					} => {
1175						assert!(public_private.is_none());
1176						assert_eq!(public.n, public_public.n);
1177						assert_eq!(public.e, public_public.e);
1178					}
1179					_ => panic!("Expected public key to be an RSA key"),
1180				}
1181			}
1182			_ => panic!("Expected private key to be an RSA key"),
1183		}
1184	}
1185
1186	#[test]
1187	fn test_rsa_pss_public_key_conversion() {
1188		let key = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1189		assert!(key.is_ok());
1190		let key = key.unwrap();
1191
1192		assert!(key.operations.contains(&KeyOperation::Sign));
1193		assert!(key.operations.contains(&KeyOperation::Verify));
1194
1195		let public_key = key.to_public().unwrap();
1196		assert!(!public_key.operations.contains(&KeyOperation::Sign));
1197		assert!(public_key.operations.contains(&KeyOperation::Verify));
1198
1199		match key.key {
1200			KeyType::RSA {
1201				ref public,
1202				ref private,
1203			} => {
1204				assert!(private.is_some());
1205				assert_eq!(public.n.len(), 256);
1206				assert_eq!(public.e.len(), 3);
1207
1208				match public_key.key {
1209					KeyType::RSA {
1210						public: ref public_public,
1211						private: ref public_private,
1212					} => {
1213						assert!(public_private.is_none());
1214						assert_eq!(public.n, public_public.n);
1215						assert_eq!(public.e, public_public.e);
1216					}
1217					_ => panic!("Expected public key to be an RSA key"),
1218				}
1219			}
1220			_ => panic!("Expected private key to be an RSA key"),
1221		}
1222	}
1223
1224	#[test]
1225	fn test_base64url_serialization() {
1226		let key = create_test_key();
1227		let json = serde_json::to_string(&key).unwrap();
1228
1229		// Check that the secret is base64url encoded without padding
1230		let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
1231		let k_value = parsed["k"].as_str().unwrap();
1232
1233		// Base64url should not contain padding characters
1234		assert!(!k_value.contains('='));
1235		assert!(!k_value.contains('+'));
1236		assert!(!k_value.contains('/'));
1237
1238		// Verify it decodes correctly
1239		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
1240			.decode(k_value)
1241			.unwrap();
1242
1243		if let KeyType::OCT {
1244			secret: original_secret,
1245		} = &key.key
1246		{
1247			assert_eq!(decoded, *original_secret);
1248		} else {
1249			panic!("Expected both keys to be OCT variant");
1250		}
1251	}
1252
1253	#[test]
1254	fn test_backwards_compatibility_unpadded_base64url() {
1255		// Create a JSON with unpadded base64url (new format)
1256		let unpadded_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;
1257
1258		// Should be able to deserialize new format
1259		let key: Key = serde_json::from_str(unpadded_json).unwrap();
1260		assert_eq!(key.algorithm, Algorithm::HS256);
1261		assert_eq!(key.kid, Some("test-key-1".to_string()));
1262
1263		if let KeyType::OCT { secret } = &key.key {
1264			assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
1265		} else {
1266			panic!("Expected key to be OCT variant");
1267		}
1268	}
1269
1270	#[test]
1271	fn test_backwards_compatibility_padded_base64url() {
1272		// Create a JSON with padded base64url (old format) - same secret but with padding
1273		let padded_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;
1274
1275		// Should be able to deserialize old format for backwards compatibility
1276		let key: Key = serde_json::from_str(padded_json).unwrap();
1277		assert_eq!(key.algorithm, Algorithm::HS256);
1278		assert_eq!(key.kid, Some("test-key-1".to_string()));
1279
1280		if let KeyType::OCT { secret } = &key.key {
1281			assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
1282		} else {
1283			panic!("Expected key to be OCT variant");
1284		}
1285	}
1286
1287	#[test]
1288	fn test_file_io_base64url() {
1289		let key = create_test_key();
1290		let temp_dir = std::env::temp_dir();
1291		let temp_path = temp_dir.join("test_jwk.key");
1292
1293		// Write key to file
1294		key.to_file(&temp_path).unwrap();
1295
1296		// Read file contents
1297		let contents = std::fs::read_to_string(&temp_path).unwrap();
1298
1299		// Should be base64url encoded
1300		assert!(!contents.contains('{'));
1301		assert!(!contents.contains('}'));
1302		assert!(!contents.contains('"'));
1303
1304		// Decode and verify it's valid JSON
1305		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
1306			.decode(&contents)
1307			.unwrap();
1308		let json_str = String::from_utf8(decoded).unwrap();
1309		let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();
1310
1311		// Read key back from file
1312		let loaded_key = Key::from_file(&temp_path).unwrap();
1313		assert_eq!(loaded_key.algorithm, key.algorithm);
1314		assert_eq!(loaded_key.operations, key.operations);
1315		assert_eq!(loaded_key.kid, key.kid);
1316
1317		if let (
1318			KeyType::OCT {
1319				secret: original_secret,
1320			},
1321			KeyType::OCT { secret: loaded_secret },
1322		) = (&key.key, &loaded_key.key)
1323		{
1324			assert_eq!(loaded_secret, original_secret);
1325		} else {
1326			panic!("Expected both keys to be OCT variant");
1327		}
1328
1329		// Clean up
1330		std::fs::remove_file(temp_path).ok();
1331	}
1332}