1#[cfg(feature = "python_extension")]
2pub mod python;
3
4use std::fmt::Debug;
5
6mod bare_key;
7mod key;
8mod registry;
9mod sig;
10
11#[cfg(feature = "gel")]
12mod gel;
13#[cfg(feature = "gel")]
14pub use gel::{GelPrivateKeyRegistry, GelPublicKeyRegistry};
15
16pub use bare_key::{BareKey, BarePrivateKey, BarePublicKey};
17pub use key::{Key, KeyType, PrivateKey, PublicKey};
18pub use registry::KeyRegistry;
19pub use sig::{Any, SigningContext, ValidationContext, ValidationType};
20
21#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, Eq, PartialEq)]
22pub enum ValidationError {
23 #[display("Invalid token")]
25 Invalid(
26 #[from]
27 #[error(not(source))]
28 OpaqueValidationFailureReason,
29 ),
30 KeyError(#[from] KeyError),
32}
33
34impl ValidationError {
35 pub fn error_string_not_for_user(&self) -> String {
38 match self {
39 ValidationError::Invalid(OpaqueValidationFailureReason::Failure(s)) => {
40 format!("Invalid token: {s}")
41 }
42 ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue(
43 claim,
44 value,
45 )) => format!("Invalid claim value for {claim}: {value:?}"),
46 ValidationError::Invalid(OpaqueValidationFailureReason::InvalidHeader(
47 header,
48 value,
49 expected,
50 )) => format!("Invalid header {header}: {value:?}, expected {expected:?}"),
51 ValidationError::Invalid(OpaqueValidationFailureReason::NoAppropriateKey) => {
52 "No appropriate key found".to_string()
53 }
54 ValidationError::Invalid(OpaqueValidationFailureReason::InvalidSignature) => {
55 "Invalid signature".to_string()
56 }
57 ValidationError::KeyError(error) => format!("Key error: {error}"),
58 }
59 }
60}
61
/// The concrete reason a token failed validation.
///
/// Wrapped by `ValidationError::Invalid`, whose `Display` text is only
/// "Invalid token"; the detail carried here is for diagnostics only.
#[derive(Eq, PartialEq)]
pub enum OpaqueValidationFailureReason {
    /// No key was available to validate the token.
    NoAppropriateKey,
    /// The token's signature did not verify.
    InvalidSignature,
    /// A claim had a rejected value: `(claim name, offending value if present)`.
    InvalidClaimValue(String, Option<String>),
    /// A header had an unexpected value: `(header name, actual value, expected value if known)`.
    InvalidHeader(String, String, Option<String>),
    /// Any other failure, described by a free-form message.
    Failure(String),
}

/// Intentionally opaque: every variant formats as `...`, so `{:?}` output
/// (for example in logs or `unwrap` panic messages) never reveals why
/// validation failed.
impl std::fmt::Debug for OpaqueValidationFailureReason {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("...")
    }
}
78
79#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug)]
80pub enum SignatureError {
81 #[display("Signature operation failed: {_0}")]
83 SignatureError(#[error(not(source))] String),
84 #[display("No appropriate signing key found")]
86 NoAppropriateKey,
87 KeyError(#[from] KeyError),
89}
90
91#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, Eq, PartialEq)]
92pub enum KeyError {
93 #[display("Invalid PEM format")]
94 InvalidPem,
95 #[display("Invalid JSON format")]
96 InvalidJson,
97 #[display("Unsupported key type: {_0}")]
98 UnsupportedKeyType(#[error(not(source))] String),
99 #[display("Invalid EC key parameters")]
100 InvalidEcParameters,
101 #[display("Failed to decode key")]
102 DecodeError,
103 #[display("Failed to encode key")]
104 EncodeError,
105 #[display("Failed to validate key pair: {_0:?}")]
106 KeyValidationError(#[from] KeyValidationError),
107}
108
109#[derive(derive_more::Error, derive_more::Display, Debug, Eq, PartialEq)]
110pub struct KeyValidationError(#[error(not(source))] String);
111
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use super::*;

    /// Distinct kids each occupy a slot; removing a kid succeeds exactly once.
    #[test]
    fn test_key_registry_add_remove() {
        let mut registry = KeyRegistry::default();
        registry.add_key(PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap());
        registry.add_key(PrivateKey::generate(Some("2".to_owned()), KeyType::HS256).unwrap());
        registry.add_key(PrivateKey::generate(Some("3".to_owned()), KeyType::HS256).unwrap());
        assert_eq!(registry.len(), 3);
        assert!(!registry.is_empty());
        assert!(registry.remove_kid("1"));
        assert_eq!(registry.len(), 2);
        assert!(!registry.remove_kid("1"));
        assert_eq!(registry.len(), 2);
        assert!(registry.remove_kid("2"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.remove_kid("2"));
        assert_eq!(registry.len(), 1);
        assert!(registry.remove_kid("3"));
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
        assert!(!registry.remove_kid("3"));
    }

    /// Re-adding the same key does not grow the registry, and a removed key
    /// can be added back.
    #[test]
    fn test_key_registry_re_add() {
        let mut registry = KeyRegistry::default();
        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();

        registry.add_key(key.clone_key());
        assert_eq!(registry.len(), 1);
        registry.add_key(key.clone_key());
        assert_eq!(registry.len(), 1);
        // The removal must report success, not just shrink the registry.
        assert!(registry.remove_kid("1"));
        assert_eq!(registry.len(), 0);
        registry.add_key(key);
        assert_eq!(registry.len(), 1);
    }

    /// A second key with an already-registered kid occupies the same slot,
    /// even when its key type differs.
    #[test]
    fn test_key_registry_add_dupe_kid() {
        let mut registry = KeyRegistry::default();
        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
        registry.add_key(key.clone_key());
        assert_eq!(registry.len(), 1);
        registry.add_key(key.clone_key());
        assert_eq!(registry.len(), 1);

        let key2 = PrivateKey::generate(Some("1".to_owned()), KeyType::RS256).unwrap();
        registry.add_key(key2.clone_key());
        assert_eq!(registry.len(), 1);
        registry.add_key(key2.clone_key());
        assert_eq!(registry.len(), 1);
    }

    /// A signed token round-trips through validation back to the custom claims.
    #[test]
    fn test_sign() {
        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
        let claims = HashMap::from([("hello".to_owned(), "world".into())]);
        let signing_ctx = SigningContext {
            expiry: Some(Duration::from_secs(600)),
            issuer: Some("issuer".to_owned()),
            audience: Some("audience".to_owned()),
            ..Default::default()
        };
        let mut validation_ctx = ValidationContext::default();
        validation_ctx.require_claim("aud");
        validation_ctx.require_claim_with_allow_list("iss", &["issuer"]);

        let token = key.sign(claims.clone(), &signing_ctx).unwrap();
        println!("token: {token}");
        let decoded = key.validate(&token, &validation_ctx).unwrap();
        assert_eq!(decoded, claims);
    }

    /// Tokens signed without an expiry still validate.
    #[test]
    fn test_sign_no_expiry() {
        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
        let claims = HashMap::from([("hello".to_owned(), "world".into())]);
        let signing_ctx = SigningContext {
            issuer: Some("issuer".to_owned()),
            audience: Some("audience".to_owned()),
            ..Default::default()
        };
        let token = key.sign(claims.clone(), &signing_ctx).unwrap();
        let mut validation_ctx = ValidationContext::default();
        validation_ctx.require_claim("aud");
        validation_ctx.require_claim_with_allow_list("iss", &["issuer"]);
        let decoded = key
            .validate(&token, &validation_ctx)
            .map_err(|e| e.error_string_not_for_user())
            .unwrap();
        assert_eq!(decoded, claims);
    }

    /// Every loader accepts empty input and reports zero keys added.
    #[test]
    fn load_from_empty() {
        let mut registry = KeyRegistry::<PrivateKey>::default();
        let added = registry.add_from_any("").unwrap();
        assert_eq!(added, 0);
        // Rebind `added` for each loader so every call's count is really checked.
        let added = registry.add_from_pem("").unwrap();
        assert_eq!(added, 0);
        let added = registry.add_from_jwkset("{\"keys\":[]}").unwrap();
        assert_eq!(added, 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_google_jwkset() {
        let mut registry = KeyRegistry::<Key>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-goog.json"))
            .unwrap();
        assert_eq!(added, 2);
    }

    #[test]
    fn test_microsoft_jwkset() {
        let mut registry = KeyRegistry::<Key>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-msft.json"))
            .unwrap();
        assert_eq!(added, 9);
    }

    #[test]
    fn test_slack_jwkset() {
        let mut registry = KeyRegistry::<Key>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-slck.json"))
            .unwrap();
        assert_eq!(added, 1);
    }

    #[test]
    fn test_apple_jwkset() {
        let mut registry = KeyRegistry::<Key>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-aapl.json"))
            .unwrap();
        assert_eq!(added, 3);
    }

    /// A private-key registry ignores public-only entries.
    #[test]
    fn load_keys_from_jwkset() {
        let mut registry = KeyRegistry::<PrivateKey>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-pub.json"))
            .unwrap();
        assert_eq!(added, 0);
        let mut registry = KeyRegistry::<PrivateKey>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
            .unwrap();
        assert_eq!(added, 3);
    }

    /// A public-key registry accepts both public and private JWK sets.
    #[test]
    fn load_pub_keys_from_jwkset() {
        let mut registry = KeyRegistry::<PublicKey>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-pub.json"))
            .unwrap();
        assert_eq!(added, 2);
        let mut registry = KeyRegistry::<PublicKey>::default();
        let added = registry
            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
            .unwrap();
        assert_eq!(added, 3);
    }

    /// Tokens signed by each private key validate against both a private-key
    /// and a public-key registry loaded from the same JWK set.
    #[test]
    fn validate_tokens_from_jwkset() {
        let mut registry = KeyRegistry::<PrivateKey>::default();
        registry
            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
            .unwrap();
        let keys = registry.into_keys().collect::<Vec<_>>();
        // Guard against the loops below passing vacuously.
        assert!(!keys.is_empty());

        let mut registry = KeyRegistry::<PrivateKey>::default();
        registry
            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
            .unwrap();

        let claims = HashMap::from([("test".to_owned(), "value".into())]);
        let signing_ctx = SigningContext {
            issuer: Some("test-issuer".to_owned()),
            audience: Some("test-audience".to_owned()),
            ..Default::default()
        };
        let mut validation_ctx = ValidationContext::default();
        validation_ctx.require_claim_with_allow_list("iss", &["test-issuer"]);
        validation_ctx.require_claim_with_allow_list("aud", &["test-audience"]);

        for key in &keys {
            let token = key.sign(claims.clone(), &signing_ctx).unwrap();
            let decoded = registry.validate(&token, &validation_ctx).unwrap();
            assert_eq!(decoded, claims);
        }

        let mut registry = KeyRegistry::<PublicKey>::default();
        registry
            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
            .unwrap();
        for key in &keys {
            let token = key.sign(claims.clone(), &signing_ctx).unwrap();
            let decoded = registry.validate(&token, &validation_ctx).unwrap();
            assert_eq!(decoded, claims);
        }
    }

    /// Swapping kids between two keys makes the registry pick the wrong key,
    /// so validation must fail with an invalid signature.
    #[test]
    fn test_validate_tokens_from_jwkset_named() {
        let mut key1 = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
        let mut key2 = PrivateKey::generate(Some("2".to_owned()), KeyType::HS256).unwrap();

        let claims = HashMap::from([("test".to_owned(), "value".into())]);
        let signing_ctx = SigningContext {
            issuer: Some("test-issuer".to_owned()),
            audience: Some("test-audience".to_owned()),
            ..Default::default()
        };
        let validation_ctx = ValidationContext::default();
        let token = key1.sign(claims, &signing_ctx).unwrap();

        key1.set_kid(Some("2".to_owned()));
        key2.set_kid(Some("1".to_owned()));

        let mut registry = KeyRegistry::<PrivateKey>::default();
        registry.add_key(key1);
        registry.add_key(key2);

        let decoded = registry.validate(&token, &validation_ctx).unwrap_err();
        assert_eq!(
            decoded,
            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidSignature),
            "{}",
            decoded.error_string_not_for_user()
        );
    }

    /// Claim allow lists and deny lists both reject a disallowed `jti`.
    #[test]
    fn test_validate_tokens_from_jwkset_named_allow_deny() {
        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
        let mut registry = KeyRegistry::<PrivateKey>::default();
        registry.add_key(key);

        let claims = HashMap::from([("jti".to_owned(), "1234".into())]);
        let signing_ctx = SigningContext::default();
        let mut validation_ctx = ValidationContext::default();
        let token = registry.sign(claims, &signing_ctx).unwrap();

        // Without any claim requirements, the token must validate.
        if let Err(e) = registry.validate(&token, &validation_ctx) {
            panic!("{}", e.error_string_not_for_user());
        }

        validation_ctx.require_claim_with_allow_list("jti", &["1234"]);
        let decoded = registry.validate(&token, &validation_ctx).unwrap();
        assert_eq!(decoded, Default::default());

        let claims = HashMap::from([("jti".to_owned(), "bad".into())]);
        let token = registry.sign(claims, &signing_ctx).unwrap();
        let decoded = registry.validate(&token, &validation_ctx).unwrap_err();
        assert_eq!(
            decoded,
            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue(
                "jti".to_string(),
                Some("bad".to_string())
            ))
        );

        validation_ctx.require_claim_with_deny_list("jti", &["bad"]);
        let decoded = registry.validate(&token, &validation_ctx).unwrap_err();
        assert_eq!(
            decoded,
            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue(
                "jti".to_string(),
                Some("bad".to_string())
            ))
        );
    }

    /// `Any` values serialize to the expected JSON shapes.
    #[test]
    fn test_any_json() {
        let map: HashMap<String, Any> = HashMap::from([
            ("hello".to_owned(), "world".into()),
            ("empty".to_owned(), Any::None),
            ("bool".to_owned(), Any::Bool(true)),
            ("number".to_owned(), Any::Number(123)),
            (
                "array".to_owned(),
                Any::Array(vec![Any::String("1".into()), Any::String("2".into())]),
            ),
        ]);
        let json = serde_json::to_string(&map).unwrap();
        assert!(json.contains("\"hello\":\"world\""));
        assert!(json.contains("\"empty\":null"));
        assert!(json.contains("\"bool\":true"));
        assert!(json.contains("\"number\":123"));
        assert!(json.contains("\"array\":[\"1\",\"2\"]"));
    }
}