Skip to main content

moq_token/
key.rs

1use crate::generate::generate;
2use crate::{Algorithm, Claims};
3use anyhow::{Context, bail};
4use base64::Engine;
5use elliptic_curve::SecretKey;
6use elliptic_curve::pkcs8::EncodePrivateKey;
7use jsonwebtoken::{DecodingKey, EncodingKey, Header};
8use rsa::BigUint;
9use rsa::pkcs1::EncodeRsaPrivateKey;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use std::sync::OnceLock;
12use std::{collections::HashSet, fmt, path::Path as StdPath};
13
14/// Cryptographic operations that a key can perform.
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
16#[serde(rename_all = "camelCase")]
17pub enum KeyOperation {
18	Sign,
19	Verify,
20	Decrypt,
21	Encrypt,
22}
23
24/// <https://datatracker.ietf.org/doc/html/rfc7518#section-6>
25#[derive(Clone, Serialize, Deserialize)]
26#[serde(tag = "kty")]
27pub enum KeyType {
28	/// <https://datatracker.ietf.org/doc/html/rfc7518#section-6.2>
29	EC {
30		#[serde(rename = "crv")]
31		curve: EllipticCurve,
32		/// The X-coordinate of an EC key
33		#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
34		x: Vec<u8>,
35		/// The Y-coordinate of an EC key
36		#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
37		y: Vec<u8>,
38		/// The private value of an EC key
39		#[serde(
40			default,
41			skip_serializing_if = "Option::is_none",
42			serialize_with = "serialize_base64url_optional",
43			deserialize_with = "deserialize_base64url_optional"
44		)]
45		d: Option<Vec<u8>>,
46	},
47	/// <https://datatracker.ietf.org/doc/html/rfc7518#section-6.3>
48	RSA {
49		#[serde(flatten)]
50		public: RsaPublicKey,
51		#[serde(flatten, skip_serializing_if = "Option::is_none")]
52		private: Option<RsaPrivateKey>,
53	},
54	/// <https://datatracker.ietf.org/doc/html/rfc7518#section-6.4>
55	#[serde(rename = "oct")]
56	OCT {
57		/// The secret key as base64url (unpadded).
58		#[serde(
59			rename = "k",
60			default,
61			serialize_with = "serialize_base64url",
62			deserialize_with = "deserialize_base64url"
63		)]
64		secret: Vec<u8>,
65	},
66	/// <https://datatracker.ietf.org/doc/html/rfc8037#section-2>
67	OKP {
68		#[serde(rename = "crv")]
69		curve: EllipticCurve,
70		#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
71		x: Vec<u8>,
72		#[serde(
73			rename = "d",
74			default,
75			skip_serializing_if = "Option::is_none",
76			serialize_with = "serialize_base64url_optional",
77			deserialize_with = "deserialize_base64url_optional"
78		)]
79		d: Option<Vec<u8>>,
80	},
81}
82
83/// Supported elliptic curves for EC and OKP key types.
84///
85/// See <https://datatracker.ietf.org/doc/html/rfc7518#section-6.2.1.1>
86#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Debug)]
87pub enum EllipticCurve {
88	#[serde(rename = "P-256")]
89	P256,
90	#[serde(rename = "P-384")]
91	P384,
92	// jsonwebtoken doesn't support the ES512 algorithm, so we can't implement this
93	// #[serde(rename = "P-521")]
94	// P521,
95	#[serde(rename = "Ed25519")]
96	Ed25519,
97}
98
99/// RSA public key parameters.
100///
101/// See <https://datatracker.ietf.org/doc/html/rfc7518#section-6.3.1>
102#[derive(Clone, Serialize, Deserialize)]
103pub struct RsaPublicKey {
104	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
105	pub n: Vec<u8>,
106	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
107	pub e: Vec<u8>,
108}
109
110/// RSA private key parameters.
111///
112/// See <https://datatracker.ietf.org/doc/html/rfc7518#section-6.3.2>
113#[derive(Clone, Serialize, Deserialize)]
114pub struct RsaPrivateKey {
115	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
116	pub d: Vec<u8>,
117	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
118	pub p: Vec<u8>,
119	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
120	pub q: Vec<u8>,
121	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
122	pub dp: Vec<u8>,
123	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
124	pub dq: Vec<u8>,
125	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
126	pub qi: Vec<u8>,
127	#[serde(skip_serializing_if = "Option::is_none")]
128	pub oth: Option<Vec<RsaAdditionalPrime>>,
129}
130
131/// Additional prime information for multi-prime RSA keys.
132#[derive(Clone, Serialize, Deserialize)]
133pub struct RsaAdditionalPrime {
134	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
135	pub r: Vec<u8>,
136	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
137	pub d: Vec<u8>,
138	#[serde(serialize_with = "serialize_base64url", deserialize_with = "deserialize_base64url")]
139	pub t: Vec<u8>,
140}
141
142/// JWK, almost to spec (<https://datatracker.ietf.org/doc/html/rfc7517>) but not quite the same
143/// because it's annoying to implement.
144#[derive(Clone, Serialize, Deserialize)]
145#[serde(remote = "Self")]
146pub struct Key {
147	/// The algorithm used by the key.
148	#[serde(rename = "alg")]
149	pub algorithm: Algorithm,
150
151	/// The operations that the key can perform.
152	#[serde(rename = "key_ops")]
153	pub operations: HashSet<KeyOperation>,
154
155	/// Defaults to KeyType::OCT
156	#[serde(flatten)]
157	pub key: KeyType,
158
159	/// The key ID, useful for rotating keys.
160	#[serde(skip_serializing_if = "Option::is_none")]
161	pub kid: Option<String>,
162
163	// Cached for performance reasons, unfortunately.
164	#[serde(skip)]
165	pub(crate) decode: OnceLock<DecodingKey>,
166
167	#[serde(skip)]
168	pub(crate) encode: OnceLock<EncodingKey>,
169}
170
171impl<'de> Deserialize<'de> for Key {
172	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173	where
174		D: Deserializer<'de>,
175	{
176		let mut value = serde_json::Value::deserialize(deserializer)?;
177
178		// Normally the "kty" parameter is required in a JWK: https://datatracker.ietf.org/doc/html/rfc7517#section-4.1
179		// But for backwards compatibility we need to default to "oct" because in a previous
180		// implementation the parameter was omitted, and we want to keep previously generated tokens valid
181		if let Some(obj) = value.as_object_mut()
182			&& !obj.contains_key("kty")
183		{
184			obj.insert("kty".to_string(), serde_json::Value::String("oct".to_string()));
185		}
186
187		Self::deserialize(value).map_err(serde::de::Error::custom)
188	}
189}
190
191impl Serialize for Key {
192	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193	where
194		S: Serializer,
195	{
196		Self::serialize(self, serializer)
197	}
198}
199
200impl fmt::Debug for Key {
201	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202		f.debug_struct("Key")
203			.field("algorithm", &self.algorithm)
204			.field("operations", &self.operations)
205			.field("kid", &self.kid)
206			.finish()
207	}
208}
209
210impl Key {
211	#[allow(clippy::should_implement_trait)]
212	pub fn from_str(s: &str) -> anyhow::Result<Self> {
213		Ok(serde_json::from_str(s)?)
214	}
215
216	pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
217		let contents = std::fs::read_to_string(&path)?;
218		// It's base64url encoded
219		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
220		let json = String::from_utf8(decoded)?;
221		Ok(serde_json::from_str(&json)?)
222	}
223
224	pub fn to_str(&self) -> anyhow::Result<String> {
225		Ok(serde_json::to_string(self)?)
226	}
227
228	pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
229		// Serialize to JSON first
230		let json = serde_json::to_string(self)?;
231		// Then encode as base64url
232		let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
233		std::fs::write(path, encoded)?;
234		Ok(())
235	}
236
237	pub fn to_public(&self) -> anyhow::Result<Self> {
238		if !self.operations.contains(&KeyOperation::Verify) {
239			return Err(anyhow::anyhow!("This key doesn't support the Verify operation"));
240		}
241
242		let key = match self.key {
243			KeyType::RSA { ref public, .. } => Ok(KeyType::RSA {
244				public: public.clone(),
245				private: None,
246			}),
247			KeyType::EC {
248				ref x,
249				ref y,
250				ref curve,
251				..
252			} => Ok(KeyType::EC {
253				x: x.clone(),
254				y: y.clone(),
255				curve: curve.clone(),
256				d: None,
257			}),
258			KeyType::OCT { .. } => Err(anyhow::anyhow!("OCT key cannot be converted to public key")),
259			KeyType::OKP { ref x, ref curve, .. } => Ok(KeyType::OKP {
260				x: x.clone(),
261				curve: curve.clone(),
262				d: None,
263			}),
264		};
265
266		match key {
267			Ok(key) => Ok(Self {
268				algorithm: self.algorithm,
269				operations: [KeyOperation::Verify].into(),
270				key,
271				kid: self.kid.clone(),
272				decode: Default::default(),
273				encode: Default::default(),
274			}),
275			Err(err) => Err(anyhow::anyhow!("Failed to convert key: {}", err)),
276		}
277	}
278
279	fn to_decoding_key(&self) -> anyhow::Result<&DecodingKey> {
280		if let Some(key) = self.decode.get() {
281			return Ok(key);
282		}
283
284		let decoding_key = match self.key {
285			KeyType::OCT { ref secret } => match self.algorithm {
286				Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(secret),
287				_ => bail!("Invalid algorithm for key type OCT"),
288			},
289			KeyType::EC {
290				ref curve,
291				ref x,
292				ref y,
293				..
294			} => match curve {
295				EllipticCurve::P256 => {
296					if self.algorithm != Algorithm::ES256 {
297						bail!("Invalid algorithm for P-256 curve");
298					}
299					if x.len() != 32 || y.len() != 32 {
300						bail!("Invalid coordinate length for P-256");
301					}
302
303					DecodingKey::from_ec_components(
304						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
305						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
306					)?
307				}
308				EllipticCurve::P384 => {
309					if self.algorithm != Algorithm::ES384 {
310						bail!("Invalid algorithm for P-384 curve");
311					}
312					if x.len() != 48 || y.len() != 48 {
313						bail!("Invalid coordinate length for P-384");
314					}
315
316					DecodingKey::from_ec_components(
317						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
318						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(y).as_ref(),
319					)?
320				}
321				_ => bail!("Invalid curve for EC key"),
322			},
323			KeyType::OKP { ref curve, ref x, .. } => match curve {
324				EllipticCurve::Ed25519 => {
325					if self.algorithm != Algorithm::EdDSA {
326						bail!("Invalid algorithm for Ed25519 curve");
327					}
328
329					DecodingKey::from_ed_components(
330						base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(x).as_ref(),
331					)?
332				}
333				_ => bail!("Invalid curve for OKP key"),
334			},
335			KeyType::RSA { ref public, .. } => {
336				DecodingKey::from_rsa_raw_components(public.n.as_ref(), public.e.as_ref())
337			}
338		};
339
340		Ok(self.decode.get_or_init(|| decoding_key))
341	}
342
343	fn to_encoding_key(&self) -> anyhow::Result<&EncodingKey> {
344		if let Some(key) = self.encode.get() {
345			return Ok(key);
346		}
347
348		let encoding_key = match self.key {
349			KeyType::OCT { ref secret } => match self.algorithm {
350				Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(secret),
351				_ => bail!("Invalid algorithm for key type OCT"),
352			},
353			KeyType::EC { ref curve, ref d, .. } => {
354				let d = d.as_ref().context("Missing private key")?;
355
356				match curve {
357					EllipticCurve::P256 => {
358						let secret_key = SecretKey::<p256::NistP256>::from_slice(d)?;
359						let doc = secret_key.to_pkcs8_der()?;
360						EncodingKey::from_ec_der(doc.as_bytes())
361					}
362					EllipticCurve::P384 => {
363						let secret_key = SecretKey::<p384::NistP384>::from_slice(d)?;
364						let doc = secret_key.to_pkcs8_der()?;
365						EncodingKey::from_ec_der(doc.as_bytes())
366					}
367					_ => bail!("Invalid curve for EC key"),
368				}
369			}
370			KeyType::OKP {
371				ref curve,
372				ref d,
373				ref x,
374			} => {
375				let d = d.as_ref().context("Missing private key")?;
376
377				let key_pair =
378					aws_lc_rs::signature::Ed25519KeyPair::from_seed_and_public_key(d.as_slice(), x.as_slice())?;
379
380				match curve {
381					EllipticCurve::Ed25519 => EncodingKey::from_ed_der(key_pair.to_pkcs8()?.as_ref()),
382					_ => bail!("Invalid curve for OKP key"),
383				}
384			}
385			KeyType::RSA {
386				ref public,
387				ref private,
388			} => {
389				let n = BigUint::from_bytes_be(&public.n);
390				let e = BigUint::from_bytes_be(&public.e);
391				let private = private.as_ref().context("Missing private key")?;
392				let d = BigUint::from_bytes_be(&private.d);
393				let p = BigUint::from_bytes_be(&private.p);
394				let q = BigUint::from_bytes_be(&private.q);
395
396				let rsa = rsa::RsaPrivateKey::from_components(n, e, d, vec![p, q]);
397				let pem = rsa?.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF);
398
399				EncodingKey::from_rsa_pem(pem?.as_bytes())?
400			}
401		};
402
403		Ok(self.encode.get_or_init(|| encoding_key))
404	}
405
406	pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
407		if !self.operations.contains(&KeyOperation::Verify) {
408			bail!("key does not support verification");
409		}
410
411		let decode: anyhow::Result<&DecodingKey> = self.to_decoding_key();
412
413		match decode {
414			Ok(decode) => {
415				let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
416				validation.required_spec_claims = Default::default(); // Don't require exp, but still validate it if present
417				validation.validate_exp = false; // We validate exp ourselves to handle null values
418
419				let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
420
421				if let Some(exp) = token.claims.expires
422					&& exp < std::time::SystemTime::now()
423				{
424					anyhow::bail!("token has expired");
425				}
426
427				token.claims.validate()?;
428
429				Ok(token.claims)
430			}
431			Err(e) => Err(anyhow::anyhow!("Failed to decode key: {}", e)),
432		}
433	}
434
435	pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
436		if !self.operations.contains(&KeyOperation::Sign) {
437			bail!("key does not support signing");
438		}
439
440		payload.validate()?;
441
442		let encode: anyhow::Result<&EncodingKey> = self.to_encoding_key();
443
444		match encode {
445			Ok(encode) => {
446				let mut header = Header::new(self.algorithm.into());
447				header.kid = self.kid.clone();
448				let token = jsonwebtoken::encode(&header, &payload, encode)?;
449				Ok(token)
450			}
451			Err(e) => Err(anyhow::anyhow!("Failed to encode key: {}", e)),
452		}
453	}
454
455	/// Generate a key pair for the given algorithm, returning the private and public keys.
456	pub fn generate(algorithm: Algorithm, id: Option<String>) -> anyhow::Result<Self> {
457		generate(algorithm, id)
458	}
459}
460
461/// Serialize bytes as base64url without padding
462fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
463where
464	S: Serializer,
465{
466	let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
467	serializer.serialize_str(&encoded)
468}
469
470fn serialize_base64url_optional<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
471where
472	S: Serializer,
473{
474	match bytes {
475		Some(b) => serialize_base64url(b, serializer),
476		None => serializer.serialize_none(),
477	}
478}
479
480/// Deserialize base64url string to bytes, supporting both padded and unpadded formats for backwards compatibility
481fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
482where
483	D: Deserializer<'de>,
484{
485	let s = String::deserialize(deserializer)?;
486
487	// Try to decode as unpadded base64url first (preferred format)
488	base64::engine::general_purpose::URL_SAFE_NO_PAD
489		.decode(&s)
490		.or_else(|_| {
491			// Fall back to padded base64url for backwards compatibility
492			base64::engine::general_purpose::URL_SAFE.decode(&s)
493		})
494		.map_err(serde::de::Error::custom)
495}
496
497fn deserialize_base64url_optional<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
498where
499	D: Deserializer<'de>,
500{
501	let s: Option<String> = Option::deserialize(deserializer)?;
502	match s {
503		Some(s) => {
504			let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
505				.decode(&s)
506				.or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(&s))
507				.map_err(serde::de::Error::custom)?;
508			Ok(Some(decoded))
509		}
510		None => Ok(None),
511	}
512}
513
514#[cfg(test)]
515mod tests {
516	use super::*;
517	use std::time::{Duration, SystemTime};
518
519	fn create_test_key() -> Key {
520		Key {
521			algorithm: Algorithm::HS256,
522			operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
523			key: KeyType::OCT {
524				secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
525			},
526			kid: Some("test-key-1".to_string()),
527			decode: Default::default(),
528			encode: Default::default(),
529		}
530	}
531
532	fn create_test_claims() -> Claims {
533		Claims {
534			root: "test-path".to_string(),
535			publish: vec!["test-pub".into()],
536			cluster: false,
537			subscribe: vec!["test-sub".into()],
538			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
539			issued: Some(SystemTime::now()),
540		}
541	}
542
543	#[test]
544	fn test_key_from_str_valid() {
545		let key = create_test_key();
546		let json = key.to_str().unwrap();
547		let loaded_key = Key::from_str(&json).unwrap();
548
549		assert_eq!(loaded_key.algorithm, key.algorithm);
550		assert_eq!(loaded_key.operations, key.operations);
551		match (loaded_key.key, key.key) {
552			(KeyType::OCT { secret: loaded_secret }, KeyType::OCT { secret }) => {
553				assert_eq!(loaded_secret, secret);
554			}
555			_ => panic!("Expected OCT key"),
556		}
557		assert_eq!(loaded_key.kid, key.kid);
558	}
559
560	/// Tests whether Key::from_str() works for keys without a kty value to fall back to OCT
561	#[test]
562	fn test_key_oct_backwards_compatibility() {
563		let json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68"}"#;
564		let key = Key::from_str(json);
565
566		assert!(key.is_ok());
567		let key = key.unwrap();
568
569		if let KeyType::OCT { ref secret, .. } = key.key {
570			let base64_key = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret);
571			assert_eq!(base64_key, "Fp8kipWUJeUFqeSqWym_tRC_tyI8z-QpqopIGrbrD68");
572		} else {
573			panic!("Expected OCT key");
574		}
575
576		let key_str = key.to_str();
577		assert!(key_str.is_ok());
578		let key_str = key_str.unwrap();
579
580		// After serializing again it must contain the kty
581		assert!(key_str.contains("\"alg\":\"HS256\""));
582		assert!(key_str.contains("\"key_ops\""));
583		assert!(key_str.contains("\"sign\""));
584		assert!(key_str.contains("\"verify\""));
585		assert!(key_str.contains("\"kty\":\"oct\""));
586	}
587
588	#[test]
589	fn test_key_from_str_invalid_json() {
590		let result = Key::from_str("invalid json");
591		assert!(result.is_err());
592	}
593
594	#[test]
595	fn test_key_to_str() {
596		let key = create_test_key();
597		let json = key.to_str().unwrap();
598		assert!(json.contains("\"alg\":\"HS256\""));
599		assert!(json.contains("\"key_ops\""));
600		assert!(json.contains("\"sign\""));
601		assert!(json.contains("\"verify\""));
602		assert!(json.contains("\"kid\":\"test-key-1\""));
603		assert!(json.contains("\"kty\":\"oct\""));
604	}
605
606	#[test]
607	fn test_key_sign_success() {
608		let key = create_test_key();
609		let claims = create_test_claims();
610		let token = key.encode(&claims).unwrap();
611
612		assert!(!token.is_empty());
613		assert_eq!(token.matches('.').count(), 2); // JWT format: header.payload.signature
614	}
615
616	#[test]
617	fn test_key_sign_no_permission() {
618		let mut key = create_test_key();
619		key.operations = [KeyOperation::Verify].into();
620		let claims = create_test_claims();
621
622		let result = key.encode(&claims);
623		assert!(result.is_err());
624		assert!(result.unwrap_err().to_string().contains("key does not support signing"));
625	}
626
627	#[test]
628	fn test_key_sign_invalid_claims() {
629		let key = create_test_key();
630		let invalid_claims = Claims {
631			root: "test-path".to_string(),
632			publish: vec![],
633			subscribe: vec![],
634			cluster: false,
635			expires: None,
636			issued: None,
637		};
638
639		let result = key.encode(&invalid_claims);
640		assert!(result.is_err());
641		assert!(
642			result
643				.unwrap_err()
644				.to_string()
645				.contains("no publish or subscribe allowed; token is useless")
646		);
647	}
648
649	#[test]
650	fn test_key_verify_success() {
651		let key = create_test_key();
652		let claims = create_test_claims();
653		let token = key.encode(&claims).unwrap();
654
655		let verified_claims = key.decode(&token).unwrap();
656		assert_eq!(verified_claims.root, claims.root);
657		assert_eq!(verified_claims.publish, claims.publish);
658		assert_eq!(verified_claims.subscribe, claims.subscribe);
659		assert_eq!(verified_claims.cluster, claims.cluster);
660	}
661
662	#[test]
663	fn test_key_verify_no_permission() {
664		let mut key = create_test_key();
665		key.operations = [KeyOperation::Sign].into();
666
667		let result = key.decode("some.jwt.token");
668		assert!(result.is_err());
669		assert!(
670			result
671				.unwrap_err()
672				.to_string()
673				.contains("key does not support verification")
674		);
675	}
676
677	#[test]
678	fn test_key_verify_invalid_token() {
679		let key = create_test_key();
680		let result = key.decode("invalid-token");
681		assert!(result.is_err());
682	}
683
684	#[test]
685	fn test_key_verify_path_mismatch() {
686		let key = create_test_key();
687		let claims = create_test_claims();
688		let token = key.encode(&claims).unwrap();
689
690		// This test was expecting a path mismatch error, but now decode succeeds
691		let result = key.decode(&token);
692		assert!(result.is_ok());
693	}
694
695	#[test]
696	fn test_key_verify_expired_token() {
697		let key = create_test_key();
698		let mut claims = create_test_claims();
699		claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); // 1 hour ago
700		let token = key.encode(&claims).unwrap();
701
702		let result = key.decode(&token);
703		assert!(result.is_err());
704	}
705
706	#[test]
707	fn test_key_verify_token_without_exp() {
708		let key = create_test_key();
709		let claims = Claims {
710			root: "test-path".to_string(),
711			publish: vec!["".to_string()],
712			subscribe: vec!["".to_string()],
713			cluster: false,
714			expires: None,
715			issued: None,
716		};
717		let token = key.encode(&claims).unwrap();
718
719		let verified_claims = key.decode(&token).unwrap();
720		assert_eq!(verified_claims.root, claims.root);
721		assert_eq!(verified_claims.publish, claims.publish);
722		assert_eq!(verified_claims.subscribe, claims.subscribe);
723		assert_eq!(verified_claims.expires, None);
724	}
725
726	#[test]
727	fn test_key_round_trip() {
728		let key = create_test_key();
729		let original_claims = Claims {
730			root: "test-path".to_string(),
731			publish: vec!["test-pub".into()],
732			subscribe: vec!["test-sub".into()],
733			cluster: true,
734			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
735			issued: Some(SystemTime::now()),
736		};
737
738		let token = key.encode(&original_claims).unwrap();
739		let verified_claims = key.decode(&token).unwrap();
740
741		assert_eq!(verified_claims.root, original_claims.root);
742		assert_eq!(verified_claims.publish, original_claims.publish);
743		assert_eq!(verified_claims.subscribe, original_claims.subscribe);
744		assert_eq!(verified_claims.cluster, original_claims.cluster);
745	}
746
747	#[test]
748	fn test_key_generate_hs256() {
749		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
750		assert!(key.is_ok());
751		let key = key.unwrap();
752
753		assert_eq!(key.algorithm, Algorithm::HS256);
754		assert_eq!(key.kid, Some("test-id".to_string()));
755		assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
756
757		match key.key {
758			KeyType::OCT { ref secret } => assert_eq!(secret.len(), 32),
759			_ => panic!("Expected OCT key"),
760		}
761	}
762
763	#[test]
764	fn test_key_generate_hs384() {
765		let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
766		assert!(key.is_ok());
767		let key = key.unwrap();
768
769		assert_eq!(key.algorithm, Algorithm::HS384);
770
771		match key.key {
772			KeyType::OCT { ref secret } => assert_eq!(secret.len(), 48),
773			_ => panic!("Expected OCT key"),
774		}
775	}
776
777	#[test]
778	fn test_key_generate_hs512() {
779		let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
780		assert!(key.is_ok());
781		let key = key.unwrap();
782
783		assert_eq!(key.algorithm, Algorithm::HS512);
784
785		match key.key {
786			KeyType::OCT { ref secret } => assert_eq!(secret.len(), 64),
787			_ => panic!("Expected OCT key"),
788		}
789	}
790
791	#[test]
792	fn test_key_generate_rs512() {
793		let key = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
794		assert!(key.is_ok());
795		let key = key.unwrap();
796
797		assert_eq!(key.algorithm, Algorithm::RS512);
798		assert!(matches!(key.key, KeyType::RSA { .. }));
799		match key.key {
800			KeyType::RSA {
801				ref public,
802				ref private,
803			} => {
804				assert!(private.is_some());
805				assert_eq!(public.n.len(), 256);
806				assert_eq!(public.e.len(), 3);
807			}
808			_ => panic!("Expected RSA key"),
809		}
810	}
811
812	#[test]
813	fn test_key_generate_es256() {
814		let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
815		assert!(key.is_ok());
816		let key = key.unwrap();
817
818		assert_eq!(key.algorithm, Algorithm::ES256);
819		assert!(matches!(key.key, KeyType::EC { .. }))
820	}
821
822	#[test]
823	fn test_key_generate_ps512() {
824		let key = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
825		assert!(key.is_ok());
826		let key = key.unwrap();
827
828		assert_eq!(key.algorithm, Algorithm::PS512);
829		assert!(matches!(key.key, KeyType::RSA { .. }));
830	}
831
832	#[test]
833	fn test_key_generate_eddsa() {
834		let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
835		assert!(key.is_ok());
836		let key = key.unwrap();
837
838		assert_eq!(key.algorithm, Algorithm::EdDSA);
839		assert!(matches!(key.key, KeyType::OKP { .. }));
840	}
841
842	#[test]
843	fn test_key_generate_without_id() {
844		let key = Key::generate(Algorithm::HS256, None);
845		assert!(key.is_ok());
846		let key = key.unwrap();
847
848		assert_eq!(key.algorithm, Algorithm::HS256);
849		assert_eq!(key.kid, None);
850		assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
851	}
852
853	#[test]
854	fn test_public_key_conversion_hmac() {
855		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string())).expect("HMAC key generation failed");
856
857		assert!(key.to_public().is_err());
858	}
859
860	#[test]
861	fn test_public_key_conversion_rsa() {
862		let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
863		assert!(key.is_ok());
864		let key = key.unwrap();
865
866		let public_key = key.to_public().unwrap();
867		assert_eq!(key.kid, public_key.kid);
868		assert_eq!(public_key.operations, [KeyOperation::Verify].into());
869		assert!(public_key.encode.get().is_none());
870		assert!(public_key.decode.get().is_none());
871		assert!(matches!(public_key.key, KeyType::RSA { .. }));
872
873		if let KeyType::RSA { public, private } = &public_key.key {
874			assert!(private.is_none());
875
876			if let KeyType::RSA { public: src_public, .. } = &key.key {
877				assert_eq!(public.e, src_public.e);
878				assert_eq!(public.n, src_public.n);
879			} else {
880				unreachable!("Expected RSA key")
881			}
882		} else {
883			unreachable!("Expected RSA key");
884		}
885	}
886
887	#[test]
888	fn test_public_key_conversion_es() {
889		let key = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
890		assert!(key.is_ok());
891		let key = key.unwrap();
892
893		let public_key = key.to_public().unwrap();
894		assert_eq!(key.kid, public_key.kid);
895		assert_eq!(public_key.operations, [KeyOperation::Verify].into());
896		assert!(public_key.encode.get().is_none());
897		assert!(public_key.decode.get().is_none());
898		assert!(matches!(public_key.key, KeyType::EC { .. }));
899
900		if let KeyType::EC { x, y, d, curve } = &public_key.key {
901			assert!(d.is_none());
902
903			if let KeyType::EC {
904				x: src_x,
905				y: src_y,
906				curve: src_curve,
907				..
908			} = &key.key
909			{
910				assert_eq!(x, src_x);
911				assert_eq!(y, src_y);
912				assert_eq!(curve, src_curve);
913			} else {
914				unreachable!("Expected EC key")
915			}
916		} else {
917			unreachable!("Expected EC key");
918		}
919	}
920
921	#[test]
922	fn test_public_key_conversion_ed() {
923		let key = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
924		assert!(key.is_ok());
925		let key = key.unwrap();
926
927		let public_key = key.to_public().unwrap();
928		assert_eq!(key.kid, public_key.kid);
929		assert_eq!(public_key.operations, [KeyOperation::Verify].into());
930		assert!(public_key.encode.get().is_none());
931		assert!(public_key.decode.get().is_none());
932		assert!(matches!(public_key.key, KeyType::OKP { .. }));
933
934		if let KeyType::OKP { x, d, curve } = &public_key.key {
935			assert!(d.is_none());
936
937			if let KeyType::OKP {
938				x: src_x,
939				curve: src_curve,
940				..
941			} = &key.key
942			{
943				assert_eq!(x, src_x);
944				assert_eq!(curve, src_curve);
945			} else {
946				unreachable!("Expected OKP key")
947			}
948		} else {
949			unreachable!("Expected OKP key");
950		}
951	}
952
953	#[test]
954	fn test_key_generate_sign_verify_cycle() {
955		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
956		assert!(key.is_ok());
957		let key = key.unwrap();
958
959		let claims = create_test_claims();
960
961		let token = key.encode(&claims).unwrap();
962		let verified_claims = key.decode(&token).unwrap();
963
964		assert_eq!(verified_claims.root, claims.root);
965		assert_eq!(verified_claims.publish, claims.publish);
966		assert_eq!(verified_claims.subscribe, claims.subscribe);
967		assert_eq!(verified_claims.cluster, claims.cluster);
968	}
969
970	#[test]
971	fn test_key_debug_no_secret() {
972		let key = create_test_key();
973		let debug_str = format!("{key:?}");
974
975		assert!(debug_str.contains("algorithm: HS256"));
976		assert!(debug_str.contains("operations"));
977		assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
978		assert!(!debug_str.contains("secret")); // Should not contain secret
979	}
980
981	#[test]
982	fn test_key_operations_enum() {
983		let sign_op = KeyOperation::Sign;
984		let verify_op = KeyOperation::Verify;
985		let decrypt_op = KeyOperation::Decrypt;
986		let encrypt_op = KeyOperation::Encrypt;
987
988		assert_eq!(sign_op, KeyOperation::Sign);
989		assert_eq!(verify_op, KeyOperation::Verify);
990		assert_eq!(decrypt_op, KeyOperation::Decrypt);
991		assert_eq!(encrypt_op, KeyOperation::Encrypt);
992
993		assert_ne!(sign_op, verify_op);
994		assert_ne!(decrypt_op, encrypt_op);
995	}
996
997	#[test]
998	fn test_key_operations_serde() {
999		let operations = [KeyOperation::Sign, KeyOperation::Verify];
1000		let json = serde_json::to_string(&operations).unwrap();
1001		assert!(json.contains("\"sign\""));
1002		assert!(json.contains("\"verify\""));
1003
1004		let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
1005		assert_eq!(deserialized, operations);
1006	}
1007
1008	#[test]
1009	fn test_key_serde() {
1010		let key = create_test_key();
1011		let json = serde_json::to_string(&key).unwrap();
1012		let deserialized: Key = serde_json::from_str(&json).unwrap();
1013
1014		assert_eq!(deserialized.algorithm, key.algorithm);
1015		assert_eq!(deserialized.operations, key.operations);
1016		assert_eq!(deserialized.kid, key.kid);
1017
1018		if let (
1019			KeyType::OCT {
1020				secret: original_secret,
1021			},
1022			KeyType::OCT {
1023				secret: deserialized_secret,
1024			},
1025		) = (&key.key, &deserialized.key)
1026		{
1027			assert_eq!(deserialized_secret, original_secret);
1028		} else {
1029			panic!("Expected both keys to be OCT variant");
1030		}
1031	}
1032
1033	#[test]
1034	fn test_key_clone() {
1035		let key = create_test_key();
1036		let cloned = key.clone();
1037
1038		assert_eq!(cloned.algorithm, key.algorithm);
1039		assert_eq!(cloned.operations, key.operations);
1040		assert_eq!(cloned.kid, key.kid);
1041
1042		if let (
1043			KeyType::OCT {
1044				secret: original_secret,
1045			},
1046			KeyType::OCT { secret: cloned_secret },
1047		) = (&key.key, &cloned.key)
1048		{
1049			assert_eq!(cloned_secret, original_secret);
1050		} else {
1051			panic!("Expected both keys to be OCT variant");
1052		}
1053	}
1054
1055	#[test]
1056	fn test_hmac_algorithms() {
1057		let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1058		let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1059		let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
1060
1061		let claims = create_test_claims();
1062
1063		// Test that each algorithm can sign and verify
1064		for key in [key_256, key_384, key_512] {
1065			assert!(key.is_ok());
1066			let key = key.unwrap();
1067
1068			let token = key.encode(&claims).unwrap();
1069			let verified_claims = key.decode(&token).unwrap();
1070			assert_eq!(verified_claims.root, claims.root);
1071		}
1072	}
1073
1074	#[test]
1075	fn test_rsa_pkcs1_asymmetric_algorithms() {
1076		let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1077		let key_rs384 = Key::generate(Algorithm::RS384, Some("test-id".to_string()));
1078		let key_rs512 = Key::generate(Algorithm::RS512, Some("test-id".to_string()));
1079
1080		for key in [key_rs256, key_rs384, key_rs512] {
1081			test_asymmetric_key(key);
1082		}
1083	}
1084
1085	#[test]
1086	fn test_rsa_pss_asymmetric_algorithms() {
1087		let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1088		let key_ps384 = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1089		let key_ps512 = Key::generate(Algorithm::PS512, Some("test-id".to_string()));
1090
1091		for key in [key_ps256, key_ps384, key_ps512] {
1092			test_asymmetric_key(key);
1093		}
1094	}
1095
1096	#[test]
1097	fn test_ec_asymmetric_algorithms() {
1098		let key_es256 = Key::generate(Algorithm::ES256, Some("test-id".to_string()));
1099		let key_es384 = Key::generate(Algorithm::ES384, Some("test-id".to_string()));
1100
1101		for key in [key_es256, key_es384] {
1102			test_asymmetric_key(key);
1103		}
1104	}
1105
1106	#[test]
1107	fn test_ed_asymmetric_algorithms() {
1108		let key_eddsa = Key::generate(Algorithm::EdDSA, Some("test-id".to_string()));
1109
1110		test_asymmetric_key(key_eddsa);
1111	}
1112
1113	fn test_asymmetric_key(key: anyhow::Result<Key>) {
1114		assert!(key.is_ok());
1115		let key = key.unwrap();
1116
1117		let claims = create_test_claims();
1118		let token = key.encode(&claims).unwrap();
1119
1120		let private_verified_claims = key.decode(&token).unwrap();
1121		assert_eq!(
1122			private_verified_claims.root, claims.root,
1123			"validation using private key"
1124		);
1125
1126		let public_verified_claims = key.to_public().unwrap().decode(&token).unwrap();
1127		assert_eq!(public_verified_claims.root, claims.root, "validation using public key");
1128	}
1129
1130	#[test]
1131	fn test_cross_algorithm_verification_fails() {
1132		let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
1133		assert!(key_256.is_ok());
1134		let key_256 = key_256.unwrap();
1135
1136		let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
1137		assert!(key_384.is_ok());
1138		let key_384 = key_384.unwrap();
1139
1140		let claims = create_test_claims();
1141		let token = key_256.encode(&claims).unwrap();
1142
1143		// Different algorithm should fail verification
1144		let result = key_384.decode(&token);
1145		assert!(result.is_err());
1146	}
1147
1148	#[test]
1149	fn test_asymmetric_cross_algorithm_verification_fails() {
1150		let key_rs256 = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1151		assert!(key_rs256.is_ok());
1152		let key_rs256 = key_rs256.unwrap();
1153
1154		let key_ps256 = Key::generate(Algorithm::PS256, Some("test-id".to_string()));
1155		assert!(key_ps256.is_ok());
1156		let key_ps256 = key_ps256.unwrap();
1157
1158		let claims = create_test_claims();
1159		let token = key_rs256.encode(&claims).unwrap();
1160
1161		// Different algorithm should fail verification
1162		let private_result = key_ps256.decode(&token);
1163		let public_result = key_ps256.to_public().unwrap().decode(&token);
1164		assert!(private_result.is_err());
1165		assert!(public_result.is_err());
1166	}
1167
1168	#[test]
1169	fn test_rsa_pkcs1_public_key_conversion() {
1170		let key = Key::generate(Algorithm::RS256, Some("test-id".to_string()));
1171		assert!(key.is_ok());
1172		let key = key.unwrap();
1173
1174		assert!(key.operations.contains(&KeyOperation::Sign));
1175		assert!(key.operations.contains(&KeyOperation::Verify));
1176
1177		let public_key = key.to_public().unwrap();
1178		assert!(!public_key.operations.contains(&KeyOperation::Sign));
1179		assert!(public_key.operations.contains(&KeyOperation::Verify));
1180
1181		match key.key {
1182			KeyType::RSA {
1183				ref public,
1184				ref private,
1185			} => {
1186				assert!(private.is_some());
1187				assert_eq!(public.n.len(), 256);
1188				assert_eq!(public.e.len(), 3);
1189
1190				match public_key.key {
1191					KeyType::RSA {
1192						public: ref public_public,
1193						private: ref public_private,
1194					} => {
1195						assert!(public_private.is_none());
1196						assert_eq!(public.n, public_public.n);
1197						assert_eq!(public.e, public_public.e);
1198					}
1199					_ => panic!("Expected public key to be an RSA key"),
1200				}
1201			}
1202			_ => panic!("Expected private key to be an RSA key"),
1203		}
1204	}
1205
1206	#[test]
1207	fn test_rsa_pss_public_key_conversion() {
1208		let key = Key::generate(Algorithm::PS384, Some("test-id".to_string()));
1209		assert!(key.is_ok());
1210		let key = key.unwrap();
1211
1212		assert!(key.operations.contains(&KeyOperation::Sign));
1213		assert!(key.operations.contains(&KeyOperation::Verify));
1214
1215		let public_key = key.to_public().unwrap();
1216		assert!(!public_key.operations.contains(&KeyOperation::Sign));
1217		assert!(public_key.operations.contains(&KeyOperation::Verify));
1218
1219		match key.key {
1220			KeyType::RSA {
1221				ref public,
1222				ref private,
1223			} => {
1224				assert!(private.is_some());
1225				assert_eq!(public.n.len(), 256);
1226				assert_eq!(public.e.len(), 3);
1227
1228				match public_key.key {
1229					KeyType::RSA {
1230						public: ref public_public,
1231						private: ref public_private,
1232					} => {
1233						assert!(public_private.is_none());
1234						assert_eq!(public.n, public_public.n);
1235						assert_eq!(public.e, public_public.e);
1236					}
1237					_ => panic!("Expected public key to be an RSA key"),
1238				}
1239			}
1240			_ => panic!("Expected private key to be an RSA key"),
1241		}
1242	}
1243
1244	#[test]
1245	fn test_base64url_serialization() {
1246		let key = create_test_key();
1247		let json = serde_json::to_string(&key).unwrap();
1248
1249		// Check that the secret is base64url encoded without padding
1250		let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
1251		let k_value = parsed["k"].as_str().unwrap();
1252
1253		// Base64url should not contain padding characters
1254		assert!(!k_value.contains('='));
1255		assert!(!k_value.contains('+'));
1256		assert!(!k_value.contains('/'));
1257
1258		// Verify it decodes correctly
1259		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
1260			.decode(k_value)
1261			.unwrap();
1262
1263		if let KeyType::OCT {
1264			secret: original_secret,
1265		} = &key.key
1266		{
1267			assert_eq!(decoded, *original_secret);
1268		} else {
1269			panic!("Expected both keys to be OCT variant");
1270		}
1271	}
1272
1273	#[test]
1274	fn test_backwards_compatibility_unpadded_base64url() {
1275		// Create a JSON with unpadded base64url (new format)
1276		let unpadded_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;
1277
1278		// Should be able to deserialize new format
1279		let key: Key = serde_json::from_str(unpadded_json).unwrap();
1280		assert_eq!(key.algorithm, Algorithm::HS256);
1281		assert_eq!(key.kid, Some("test-key-1".to_string()));
1282
1283		if let KeyType::OCT { secret } = &key.key {
1284			assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
1285		} else {
1286			panic!("Expected key to be OCT variant");
1287		}
1288	}
1289
1290	#[test]
1291	fn test_backwards_compatibility_padded_base64url() {
1292		// Create a JSON with padded base64url (old format) - same secret but with padding
1293		let padded_json = r#"{"kty":"oct","alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;
1294
1295		// Should be able to deserialize old format for backwards compatibility
1296		let key: Key = serde_json::from_str(padded_json).unwrap();
1297		assert_eq!(key.algorithm, Algorithm::HS256);
1298		assert_eq!(key.kid, Some("test-key-1".to_string()));
1299
1300		if let KeyType::OCT { secret } = &key.key {
1301			assert_eq!(secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
1302		} else {
1303			panic!("Expected key to be OCT variant");
1304		}
1305	}
1306
1307	#[test]
1308	fn test_file_io_base64url() {
1309		let key = create_test_key();
1310		let temp_dir = std::env::temp_dir();
1311		let temp_path = temp_dir.join("test_jwk.key");
1312
1313		// Write key to file
1314		key.to_file(&temp_path).unwrap();
1315
1316		// Read file contents
1317		let contents = std::fs::read_to_string(&temp_path).unwrap();
1318
1319		// Should be base64url encoded
1320		assert!(!contents.contains('{'));
1321		assert!(!contents.contains('}'));
1322		assert!(!contents.contains('"'));
1323
1324		// Decode and verify it's valid JSON
1325		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
1326			.decode(&contents)
1327			.unwrap();
1328		let json_str = String::from_utf8(decoded).unwrap();
1329		let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();
1330
1331		// Read key back from file
1332		let loaded_key = Key::from_file(&temp_path).unwrap();
1333		assert_eq!(loaded_key.algorithm, key.algorithm);
1334		assert_eq!(loaded_key.operations, key.operations);
1335		assert_eq!(loaded_key.kid, key.kid);
1336
1337		if let (
1338			KeyType::OCT {
1339				secret: original_secret,
1340			},
1341			KeyType::OCT { secret: loaded_secret },
1342		) = (&key.key, &loaded_key.key)
1343		{
1344			assert_eq!(loaded_secret, original_secret);
1345		} else {
1346			panic!("Expected both keys to be OCT variant");
1347		}
1348
1349		// Clean up
1350		std::fs::remove_file(temp_path).ok();
1351	}
1352}