moq_token/
key.rs

1use std::{collections::HashSet, fmt, path::Path as StdPath, sync::OnceLock};
2
3use base64::Engine;
4use jsonwebtoken::{DecodingKey, EncodingKey, Header};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6
7use crate::{Algorithm, Claims};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
10#[serde(rename_all = "camelCase")]
11pub enum KeyOperation {
12	Sign,
13	Verify,
14	Decrypt,
15	Encrypt,
16}
17
18/// Similar to JWK but not quite the same because it's annoying to implement.
19#[derive(Clone, Serialize, Deserialize)]
20pub struct Key {
21	/// The algorithm used by the key.
22	#[serde(rename = "alg")]
23	pub algorithm: Algorithm,
24
25	/// The operations that the key can perform.
26	#[serde(rename = "key_ops")]
27	pub operations: HashSet<KeyOperation>,
28
29	/// The secret key as base64url (unpadded).
30	#[serde(
31		rename = "k",
32		serialize_with = "serialize_base64url",
33		deserialize_with = "deserialize_base64url"
34	)]
35	pub secret: Vec<u8>,
36
37	/// The key ID, useful for rotating keys.
38	#[serde(skip_serializing_if = "Option::is_none")]
39	pub kid: Option<String>,
40
41	// Cached for performance reasons, unfortunately.
42	#[serde(skip)]
43	pub(crate) decode: OnceLock<DecodingKey>,
44
45	#[serde(skip)]
46	pub(crate) encode: OnceLock<EncodingKey>,
47}
48
49impl fmt::Debug for Key {
50	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51		f.debug_struct("Key")
52			.field("algorithm", &self.algorithm)
53			.field("operations", &self.operations)
54			.field("kid", &self.kid)
55			.finish()
56	}
57}
58
59impl Key {
60	#[allow(clippy::should_implement_trait)]
61	pub fn from_str(s: &str) -> anyhow::Result<Self> {
62		Ok(serde_json::from_str(s)?)
63	}
64
65	pub fn from_file<P: AsRef<StdPath>>(path: P) -> anyhow::Result<Self> {
66		let contents = std::fs::read_to_string(&path)?;
67		// It's base64url encoded
68		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(contents.trim())?;
69		let json = String::from_utf8(decoded)?;
70		Ok(serde_json::from_str(&json)?)
71	}
72
73	pub fn to_str(&self) -> anyhow::Result<String> {
74		Ok(serde_json::to_string(self)?)
75	}
76
77	pub fn to_file<P: AsRef<StdPath>>(&self, path: P) -> anyhow::Result<()> {
78		// Serialize to JSON first
79		let json = serde_json::to_string(self)?;
80		// Then encode as base64url
81		let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes());
82		std::fs::write(path, encoded)?;
83		Ok(())
84	}
85
86	pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
87		if !self.operations.contains(&KeyOperation::Verify) {
88			anyhow::bail!("key does not support verification");
89		}
90
91		let decode = self.decode.get_or_init(|| match self.algorithm {
92			Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => DecodingKey::from_secret(&self.secret),
93			/*
94			Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => DecodingKey::from_rsa_der(&self.der),
95			Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => DecodingKey::from_rsa_der(&self.der),
96			Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_der(&self.der),
97			Algorithm::EdDSA => DecodingKey::from_ed_der(&self.der),
98			*/
99		});
100
101		let mut validation = jsonwebtoken::Validation::new(self.algorithm.into());
102		validation.required_spec_claims = Default::default(); // Don't require exp, but still validate it if present
103
104		let token = jsonwebtoken::decode::<Claims>(token, decode, &validation)?;
105		token.claims.validate()?;
106
107		Ok(token.claims)
108	}
109
110	pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
111		if !self.operations.contains(&KeyOperation::Sign) {
112			anyhow::bail!("key does not support signing");
113		}
114
115		payload.validate()?;
116
117		let encode = self.encode.get_or_init(|| match self.algorithm {
118			Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => EncodingKey::from_secret(&self.secret),
119			/*
120			Algorithm::PS256 | Algorithm::PS384 | Algorithm::PS512 => EncodingKey::from_rsa_der(&self.der),
121			Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => EncodingKey::from_rsa_der(&self.der),
122			Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_der(&self.der),
123			Algorithm::EdDSA => EncodingKey::from_ed_der(&self.der),
124			*/
125		});
126
127		let mut header = Header::new(self.algorithm.into());
128		header.kid = self.kid.clone();
129		let token = jsonwebtoken::encode(&header, &payload, encode)?;
130		Ok(token)
131	}
132
133	/// Generate a key pair for the given algorithm, returning the private and public keys.
134	pub fn generate(algorithm: Algorithm, id: Option<String>) -> Self {
135		let private_key = match algorithm {
136			Algorithm::HS256 => generate_hmac_key::<32>(),
137			Algorithm::HS384 => generate_hmac_key::<48>(),
138			Algorithm::HS512 => generate_hmac_key::<64>(),
139			/*
140			Algorithm::RS256 => generate_rsa_key(rsa::KeySize::Rsa2048),
141			Algorithm::RS384 => generate_rsa_key(rsa::KeySize::Rsa2048),
142			Algorithm::RS512 => generate_rsa_key(rsa::KeySize::Rsa2048),
143			Algorithm::ES256 => generate_ec_key(&signature::ECDSA_P256_SHA256_FIXED_SIGNING),
144			Algorithm::ES384 => generate_ec_key(&signature::ECDSA_P384_SHA384_FIXED_SIGNING),
145			Algorithm::PS256 => generate_rsa_key(rsa::KeySize::Rsa2048),
146			Algorithm::PS384 => generate_rsa_key(rsa::KeySize::Rsa2048),
147			Algorithm::PS512 => generate_rsa_key(rsa::KeySize::Rsa2048),
148			Algorithm::EdDSA => generate_ed25519_key(),
149			*/
150		};
151
152		Key {
153			kid: id.clone(),
154			operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
155			algorithm,
156			secret: private_key,
157			decode: Default::default(),
158			encode: Default::default(),
159		}
160
161		/*
162		let public_key = Key {
163			kid: id,
164			operations: [KeyOperation::Verify].into(),
165			algorithm,
166			der: public_key,
167			decode: Default::default(),
168			encode: Default::default(),
169		};
170
171		(private_key, public_key)
172		*/
173	}
174}
175
176/// Serialize bytes as base64url without padding
177fn serialize_base64url<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
178where
179	S: Serializer,
180{
181	let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
182	serializer.serialize_str(&encoded)
183}
184
185/// Deserialize base64url string to bytes, supporting both padded and unpadded formats for backwards compatibility
186fn deserialize_base64url<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
187where
188	D: Deserializer<'de>,
189{
190	let s = String::deserialize(deserializer)?;
191
192	// Try to decode as unpadded base64url first (preferred format)
193	base64::engine::general_purpose::URL_SAFE_NO_PAD
194		.decode(&s)
195		.or_else(|_| {
196			// Fall back to padded base64url for backwards compatibility
197			base64::engine::general_purpose::URL_SAFE.decode(&s)
198		})
199		.map_err(serde::de::Error::custom)
200}
201
202fn generate_hmac_key<const SIZE: usize>() -> Vec<u8> {
203	let mut key = [0u8; SIZE];
204	aws_lc_rs::rand::fill(&mut key).unwrap();
205	key.to_vec()
206}
207
#[cfg(test)]
mod tests {
	use super::*;
	use std::time::{Duration, SystemTime};

	/// Deletes the wrapped path on drop, so temp files are cleaned up even if an assertion panics.
	struct TempFile(std::path::PathBuf);

	impl Drop for TempFile {
		fn drop(&mut self) {
			let _ = std::fs::remove_file(&self.0);
		}
	}

	fn create_test_key() -> Key {
		Key {
			algorithm: Algorithm::HS256,
			operations: [KeyOperation::Sign, KeyOperation::Verify].into(),
			secret: b"test-secret-that-is-long-enough-for-hmac-sha256".to_vec(),
			kid: Some("test-key-1".to_string()),
			decode: Default::default(),
			encode: Default::default(),
		}
	}

	fn create_test_claims() -> Claims {
		Claims {
			root: "test-path".to_string(),
			publish: vec!["test-pub".into()],
			cluster: false,
			subscribe: vec!["test-sub".into()],
			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
			issued: Some(SystemTime::now()),
		}
	}

	#[test]
	fn test_key_from_str_valid() {
		let key = create_test_key();
		let json = key.to_str().unwrap();
		let loaded_key = Key::from_str(&json).unwrap();

		assert_eq!(loaded_key.algorithm, key.algorithm);
		assert_eq!(loaded_key.operations, key.operations);
		assert_eq!(loaded_key.secret, key.secret);
		assert_eq!(loaded_key.kid, key.kid);
	}

	#[test]
	fn test_key_from_str_invalid_json() {
		let result = Key::from_str("invalid json");
		assert!(result.is_err());
	}

	#[test]
	fn test_key_to_str() {
		let key = create_test_key();
		let json = key.to_str().unwrap();
		assert!(json.contains("\"alg\":\"HS256\""));
		assert!(json.contains("\"key_ops\""));
		assert!(json.contains("\"sign\""));
		assert!(json.contains("\"verify\""));
		assert!(json.contains("\"kid\":\"test-key-1\""));
	}

	#[test]
	fn test_key_sign_success() {
		let key = create_test_key();
		let claims = create_test_claims();
		let token = key.encode(&claims).unwrap();

		assert!(!token.is_empty());
		assert_eq!(token.matches('.').count(), 2); // JWT format: header.payload.signature
	}

	#[test]
	fn test_key_sign_no_permission() {
		let mut key = create_test_key();
		key.operations = [KeyOperation::Verify].into();
		let claims = create_test_claims();

		let result = key.encode(&claims);
		assert!(result.is_err());
		assert!(result.unwrap_err().to_string().contains("key does not support signing"));
	}

	#[test]
	fn test_key_sign_invalid_claims() {
		let key = create_test_key();
		let invalid_claims = Claims {
			root: "test-path".to_string(),
			publish: vec![],
			subscribe: vec![],
			cluster: false,
			expires: None,
			issued: None,
		};

		let result = key.encode(&invalid_claims);
		assert!(result.is_err());
		assert!(result
			.unwrap_err()
			.to_string()
			.contains("no publish or subscribe allowed; token is useless"));
	}

	#[test]
	fn test_key_verify_success() {
		let key = create_test_key();
		let claims = create_test_claims();
		let token = key.encode(&claims).unwrap();

		let verified_claims = key.decode(&token).unwrap();
		assert_eq!(verified_claims.root, claims.root);
		assert_eq!(verified_claims.publish, claims.publish);
		assert_eq!(verified_claims.subscribe, claims.subscribe);
		assert_eq!(verified_claims.cluster, claims.cluster);
	}

	#[test]
	fn test_key_verify_no_permission() {
		let mut key = create_test_key();
		key.operations = [KeyOperation::Sign].into();

		let result = key.decode("some.jwt.token");
		assert!(result.is_err());
		assert!(result
			.unwrap_err()
			.to_string()
			.contains("key does not support verification"));
	}

	#[test]
	fn test_key_verify_invalid_token() {
		let key = create_test_key();
		let result = key.decode("invalid-token");
		assert!(result.is_err());
	}

	#[test]
	fn test_key_verify_tampered_payload() {
		// Splice the payload of one validly signed token onto the header and signature of
		// another; the signature no longer covers the payload, so verification must fail.
		let key = create_test_key();
		let token = key.encode(&create_test_claims()).unwrap();

		let mut other_claims = create_test_claims();
		other_claims.publish = vec!["someone-else".into()];
		let other_token = key.encode(&other_claims).unwrap();

		let parts: Vec<&str> = token.split('.').collect();
		let other_parts: Vec<&str> = other_token.split('.').collect();
		let forged = format!("{}.{}.{}", parts[0], other_parts[1], parts[2]);

		assert!(key.decode(&forged).is_err());
	}

	#[test]
	fn test_key_verify_wrong_secret() {
		let signer = create_test_key();
		let mut verifier = create_test_key();
		verifier.secret = b"a-different-secret-that-is-also-long-enough-for-hmac".to_vec();

		let token = signer.encode(&create_test_claims()).unwrap();
		assert!(verifier.decode(&token).is_err());
	}

	#[test]
	fn test_key_verify_expired_token() {
		let key = create_test_key();
		let mut claims = create_test_claims();
		claims.expires = Some(SystemTime::now() - Duration::from_secs(3600)); // 1 hour ago
		let token = key.encode(&claims).unwrap();

		let result = key.decode(&token);
		assert!(result.is_err());
	}

	#[test]
	fn test_key_verify_token_without_exp() {
		let key = create_test_key();
		let claims = Claims {
			root: "test-path".to_string(),
			publish: vec!["".to_string()],
			subscribe: vec!["".to_string()],
			cluster: false,
			expires: None,
			issued: None,
		};
		let token = key.encode(&claims).unwrap();

		let verified_claims = key.decode(&token).unwrap();
		assert_eq!(verified_claims.root, claims.root);
		assert_eq!(verified_claims.publish, claims.publish);
		assert_eq!(verified_claims.subscribe, claims.subscribe);
		assert_eq!(verified_claims.expires, None);
	}

	#[test]
	fn test_key_round_trip() {
		let key = create_test_key();
		let original_claims = Claims {
			root: "test-path".to_string(),
			publish: vec!["test-pub".into()],
			subscribe: vec!["test-sub".into()],
			cluster: true,
			expires: Some(SystemTime::now() + Duration::from_secs(3600)),
			issued: Some(SystemTime::now()),
		};

		let token = key.encode(&original_claims).unwrap();
		let verified_claims = key.decode(&token).unwrap();

		assert_eq!(verified_claims.root, original_claims.root);
		assert_eq!(verified_claims.publish, original_claims.publish);
		assert_eq!(verified_claims.subscribe, original_claims.subscribe);
		assert_eq!(verified_claims.cluster, original_claims.cluster);
	}

	#[test]
	fn test_key_generate_hs256() {
		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
		assert_eq!(key.algorithm, Algorithm::HS256);
		assert_eq!(key.kid, Some("test-id".to_string()));
		assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
		assert_eq!(key.secret.len(), 32);
	}

	#[test]
	fn test_key_generate_hs384() {
		let key = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
		assert_eq!(key.algorithm, Algorithm::HS384);
		assert_eq!(key.secret.len(), 48);
	}

	#[test]
	fn test_key_generate_hs512() {
		let key = Key::generate(Algorithm::HS512, Some("test-id".to_string()));
		assert_eq!(key.algorithm, Algorithm::HS512);
		assert_eq!(key.secret.len(), 64);
	}

	#[test]
	fn test_key_generate_without_id() {
		let key = Key::generate(Algorithm::HS256, None);
		assert_eq!(key.algorithm, Algorithm::HS256);
		assert_eq!(key.kid, None);
		assert_eq!(key.operations, [KeyOperation::Sign, KeyOperation::Verify].into());
	}

	#[test]
	fn test_key_generate_sign_verify_cycle() {
		let key = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
		let claims = create_test_claims();

		let token = key.encode(&claims).unwrap();
		let verified_claims = key.decode(&token).unwrap();

		assert_eq!(verified_claims.root, claims.root);
		assert_eq!(verified_claims.publish, claims.publish);
		assert_eq!(verified_claims.subscribe, claims.subscribe);
		assert_eq!(verified_claims.cluster, claims.cluster);
	}

	#[test]
	fn test_key_debug_no_secret() {
		let key = create_test_key();
		let debug_str = format!("{key:?}");

		assert!(debug_str.contains("algorithm: HS256"));
		assert!(debug_str.contains("operations"));
		assert!(debug_str.contains("kid: Some(\"test-key-1\")"));
		assert!(!debug_str.contains("secret")); // Should not contain secret
	}

	#[test]
	fn test_key_operations_enum() {
		let sign_op = KeyOperation::Sign;
		let verify_op = KeyOperation::Verify;
		let decrypt_op = KeyOperation::Decrypt;
		let encrypt_op = KeyOperation::Encrypt;

		assert_eq!(sign_op, KeyOperation::Sign);
		assert_eq!(verify_op, KeyOperation::Verify);
		assert_eq!(decrypt_op, KeyOperation::Decrypt);
		assert_eq!(encrypt_op, KeyOperation::Encrypt);

		assert_ne!(sign_op, verify_op);
		assert_ne!(decrypt_op, encrypt_op);
	}

	#[test]
	fn test_key_operations_serde() {
		let operations = [KeyOperation::Sign, KeyOperation::Verify];
		let json = serde_json::to_string(&operations).unwrap();
		assert!(json.contains("\"sign\""));
		assert!(json.contains("\"verify\""));

		let deserialized: Vec<KeyOperation> = serde_json::from_str(&json).unwrap();
		assert_eq!(deserialized, operations);
	}

	#[test]
	fn test_key_serde() {
		let key = create_test_key();
		let json = serde_json::to_string(&key).unwrap();
		let deserialized: Key = serde_json::from_str(&json).unwrap();

		assert_eq!(deserialized.algorithm, key.algorithm);
		assert_eq!(deserialized.operations, key.operations);
		assert_eq!(deserialized.secret, key.secret);
		assert_eq!(deserialized.kid, key.kid);
	}

	#[test]
	fn test_key_clone() {
		let key = create_test_key();
		let cloned = key.clone();

		assert_eq!(cloned.algorithm, key.algorithm);
		assert_eq!(cloned.operations, key.operations);
		assert_eq!(cloned.secret, key.secret);
		assert_eq!(cloned.kid, key.kid);
	}

	#[test]
	fn test_different_algorithms() {
		let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
		let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));
		let key_512 = Key::generate(Algorithm::HS512, Some("test-id".to_string()));

		let claims = create_test_claims();

		// Each algorithm can sign and verify its own tokens.
		for key in [key_256, key_384, key_512] {
			let token = key.encode(&claims).unwrap();
			let verified_claims = key.decode(&token).unwrap();
			assert_eq!(verified_claims.root, claims.root);
		}
	}

	#[test]
	fn test_cross_algorithm_verification_fails() {
		let key_256 = Key::generate(Algorithm::HS256, Some("test-id".to_string()));
		let key_384 = Key::generate(Algorithm::HS384, Some("test-id".to_string()));

		let claims = create_test_claims();
		let token = key_256.encode(&claims).unwrap();

		// A key for a different algorithm must reject the token.
		let result = key_384.decode(&token);
		assert!(result.is_err());
	}

	#[test]
	fn test_base64url_serialization() {
		let key = create_test_key();
		let json = serde_json::to_string(&key).unwrap();

		let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
		let k_value = parsed["k"].as_str().unwrap();

		// Unpadded base64url: no '=' padding, and '-'/'_' instead of '+'/'/'.
		assert!(!k_value.contains('='));
		assert!(!k_value.contains('+'));
		assert!(!k_value.contains('/'));

		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
			.decode(k_value)
			.unwrap();
		assert_eq!(decoded, key.secret);
	}

	#[test]
	fn test_backwards_compatibility_unpadded_base64url() {
		// JSON with unpadded base64url (current format).
		let unpadded_json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY","kid":"test-key-1"}"#;

		let key: Key = serde_json::from_str(unpadded_json).unwrap();
		assert_eq!(key.secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
		assert_eq!(key.algorithm, Algorithm::HS256);
		assert_eq!(key.kid, Some("test-key-1".to_string()));
	}

	#[test]
	fn test_backwards_compatibility_padded_base64url() {
		// JSON with padded base64url (older format) encoding the same secret.
		let padded_json = r#"{"alg":"HS256","key_ops":["sign","verify"],"k":"dGVzdC1zZWNyZXQtdGhhdC1pcy1sb25nLWVub3VnaC1mb3ItaG1hYy1zaGEyNTY=","kid":"test-key-1"}"#;

		let key: Key = serde_json::from_str(padded_json).unwrap();
		assert_eq!(key.secret, b"test-secret-that-is-long-enough-for-hmac-sha256");
		assert_eq!(key.algorithm, Algorithm::HS256);
		assert_eq!(key.kid, Some("test-key-1".to_string()));
	}

	#[test]
	fn test_file_io_base64url() {
		let key = create_test_key();
		// Per-process file name so concurrent test runs don't clobber each other; the guard
		// removes the file even if an assertion below fails.
		let temp = TempFile(std::env::temp_dir().join(format!("moq_token_test_jwk_{}.key", std::process::id())));

		key.to_file(&temp.0).unwrap();

		let contents = std::fs::read_to_string(&temp.0).unwrap();

		// The file holds base64url text, not raw JSON.
		assert!(!contents.contains('{'));
		assert!(!contents.contains('}'));
		assert!(!contents.contains('"'));

		let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
			.decode(&contents)
			.unwrap();
		let json_str = String::from_utf8(decoded).unwrap();
		let _: serde_json::Value = serde_json::from_str(&json_str).unwrap();

		let loaded_key = Key::from_file(&temp.0).unwrap();
		assert_eq!(loaded_key.algorithm, key.algorithm);
		assert_eq!(loaded_key.operations, key.operations);
		assert_eq!(loaded_key.secret, key.secret);
		assert_eq!(loaded_key.kid, key.kid);
	}
}