gel_jwt/
lib.rs

1#[cfg(feature = "python_extension")]
2pub mod python;
3
4use std::fmt::Debug;
5
6mod bare_key;
7mod key;
8mod registry;
9mod sig;
10
11#[cfg(feature = "gel")]
12mod gel;
13#[cfg(feature = "gel")]
14pub use gel::{GelPrivateKeyRegistry, GelPublicKeyRegistry};
15
16pub use bare_key::{BareKey, BarePrivateKey, BarePublicKey};
17pub use key::{Key, KeyType, PrivateKey, PublicKey};
18pub use registry::KeyRegistry;
19pub use sig::{Any, SigningContext, ValidationContext, ValidationType};
20
21#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, Eq, PartialEq)]
22pub enum ValidationError {
23    /// The token format or signature was invalid
24    #[display("Invalid token")]
25    Invalid(
26        #[from]
27        #[error(not(source))]
28        OpaqueValidationFailureReason,
29    ),
30    /// The key is invalid
31    KeyError(#[from] KeyError),
32}
33
34impl ValidationError {
35    /// Display an error not intended for the end-user as it may leak information about the keys
36    /// and/or tokens.
37    pub fn error_string_not_for_user(&self) -> String {
38        match self {
39            ValidationError::Invalid(OpaqueValidationFailureReason::Failure(s)) => {
40                format!("Invalid token: {s}")
41            }
42            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue(
43                claim,
44                value,
45            )) => format!("Invalid claim value for {claim}: {value:?}"),
46            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidHeader(
47                header,
48                value,
49                expected,
50            )) => format!("Invalid header {header}: {value:?}, expected {expected:?}"),
51            ValidationError::Invalid(OpaqueValidationFailureReason::NoAppropriateKey) => {
52                "No appropriate key found".to_string()
53            }
54            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidSignature) => {
55                "Invalid signature".to_string()
56            }
57            ValidationError::KeyError(error) => format!("Key error: {error}"),
58        }
59    }
60}
61
/// The reason a token failed validation.
///
/// The `Debug` implementation is intentionally redacted: it prints `...` for every
/// variant so that details about the token failure are not leaked through debug
/// formatting or printing.
#[derive(PartialEq, Eq)]
pub enum OpaqueValidationFailureReason {
    /// No appropriate key was found.
    NoAppropriateKey,
    /// The signature was invalid.
    InvalidSignature,
    /// A claim had an invalid value: claim name, then the offending value (if any).
    InvalidClaimValue(String, Option<String>),
    /// A header was invalid: header name, the value seen, then the expected value (if any).
    InvalidHeader(String, String, Option<String>),
    /// Any other failure, with a free-form description.
    Failure(String),
}

impl std::fmt::Debug for OpaqueValidationFailureReason {
    /// Writes a fixed placeholder regardless of the variant or its payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("...")
    }
}
78
79#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug)]
80pub enum SignatureError {
81    /// The token format or signature was invalid
82    #[display("Signature operation failed: {_0}")]
83    SignatureError(#[error(not(source))] String),
84    /// No appropriate key was found
85    #[display("No appropriate signing key found")]
86    NoAppropriateKey,
87    /// The key is invalid
88    KeyError(#[from] KeyError),
89}
90
91#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, Eq, PartialEq)]
92pub enum KeyError {
93    #[display("Invalid PEM format")]
94    InvalidPem,
95    #[display("Invalid JSON format")]
96    InvalidJson,
97    #[display("Unsupported key type: {_0}")]
98    UnsupportedKeyType(#[error(not(source))] String),
99    #[display("Invalid EC key parameters")]
100    InvalidEcParameters,
101    #[display("Failed to decode key")]
102    DecodeError,
103    #[display("Failed to encode key")]
104    EncodeError,
105    #[display("Failed to validate key pair: {_0:?}")]
106    KeyValidationError(#[from] KeyValidationError),
107}
108
109#[derive(derive_more::Error, derive_more::Display, Debug, Eq, PartialEq)]
110pub struct KeyValidationError(#[error(not(source))] String);
111
112#[cfg(test)]
113mod tests {
114    use std::{collections::HashMap, time::Duration};
115
116    use super::*;
117
118    #[test]
119    fn test_key_registry_add_remove() {
120        let mut registry = KeyRegistry::default();
121        registry.add_key(PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap());
122        registry.add_key(PrivateKey::generate(Some("2".to_owned()), KeyType::HS256).unwrap());
123        registry.add_key(PrivateKey::generate(Some("3".to_owned()), KeyType::HS256).unwrap());
124        assert_eq!(registry.len(), 3);
125        assert!(!registry.is_empty());
126        assert!(registry.remove_kid("1"));
127        assert_eq!(registry.len(), 2);
128        assert!(!registry.remove_kid("1"));
129        assert_eq!(registry.len(), 2);
130        assert!(registry.remove_kid("2"));
131        assert_eq!(registry.len(), 1);
132        assert!(!registry.remove_kid("2"));
133        assert_eq!(registry.len(), 1);
134        assert!(registry.remove_kid("3"));
135        assert_eq!(registry.len(), 0);
136        assert!(!registry.remove_kid("3"));
137    }
138
139    #[test]
140    fn test_key_registry_re_add() {
141        let mut registry = KeyRegistry::default();
142        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
143
144        registry.add_key(key.clone_key());
145        assert_eq!(registry.len(), 1);
146        registry.add_key(key.clone_key());
147        assert_eq!(registry.len(), 1);
148        registry.remove_kid("1");
149        assert_eq!(registry.len(), 0);
150        registry.add_key(key);
151        assert_eq!(registry.len(), 1);
152    }
153
154    #[test]
155    fn test_key_registry_add_dupe_kid() {
156        let mut registry = KeyRegistry::default();
157        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
158        registry.add_key(key.clone_key());
159        assert_eq!(registry.len(), 1);
160        registry.add_key(key.clone_key());
161        assert_eq!(registry.len(), 1);
162
163        let key2 = PrivateKey::generate(Some("1".to_owned()), KeyType::RS256).unwrap();
164        registry.add_key(key2.clone_key());
165        assert_eq!(registry.len(), 1);
166        registry.add_key(key2.clone_key());
167        assert_eq!(registry.len(), 1);
168    }
169
170    #[test]
171    fn test_sign() {
172        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
173        let claims = HashMap::from([("hello".to_owned(), "world".into())]);
174        let signing_ctx = SigningContext {
175            expiry: Some(Duration::from_secs(600)),
176            issuer: Some("issuer".to_owned()),
177            audience: Some("audience".to_owned()),
178            ..Default::default()
179        };
180        let mut validation_ctx = ValidationContext::default();
181        validation_ctx.require_claim("aud");
182        validation_ctx.require_claim_with_allow_list("iss", &["issuer"]);
183
184        let token = key.sign(claims.clone(), &signing_ctx).unwrap();
185        println!("token: {token}");
186        let decoded = key.validate(&token, &validation_ctx).unwrap();
187        assert_eq!(decoded, claims);
188    }
189
190    #[test]
191    fn test_sign_no_expiry() {
192        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
193        let claims = HashMap::from([("hello".to_owned(), "world".into())]);
194        let signing_ctx = SigningContext {
195            issuer: Some("issuer".to_owned()),
196            audience: Some("audience".to_owned()),
197            ..Default::default()
198        };
199        let token = key.sign(claims.clone(), &signing_ctx).unwrap();
200        let mut validation_ctx = ValidationContext::default();
201        validation_ctx.require_claim("aud");
202        validation_ctx.require_claim_with_allow_list("iss", &["issuer"]);
203        let decoded = key
204            .validate(&token, &validation_ctx)
205            .map_err(|e| e.error_string_not_for_user())
206            .unwrap();
207        assert_eq!(decoded, claims);
208    }
209
210    #[test]
211    fn load_from_empty() {
212        let mut registry = KeyRegistry::<PrivateKey>::default();
213        let added = registry.add_from_any("").unwrap();
214        assert_eq!(added, 0);
215        registry.add_from_pem("").unwrap();
216        assert_eq!(added, 0);
217        registry.add_from_jwkset("{\"keys\":[]}").unwrap();
218        assert_eq!(added, 0);
219    }
220
221    #[test]
222    fn test_google_jwkset() {
223        let mut registry = KeyRegistry::<Key>::default();
224        let added = registry
225            .add_from_jwkset(include_str!("testcases/jwkset-goog.json"))
226            .unwrap();
227        assert_eq!(added, 2);
228    }
229
230    #[test]
231    fn test_microsoft_jwkset() {
232        let mut registry = KeyRegistry::<Key>::default();
233        let added = registry
234            .add_from_jwkset(include_str!("testcases/jwkset-msft.json"))
235            .unwrap();
236        assert_eq!(added, 9);
237    }
238
239    #[test]
240    fn test_slack_jwkset() {
241        let mut registry = KeyRegistry::<Key>::default();
242        let added = registry
243            .add_from_jwkset(include_str!("testcases/jwkset-slck.json"))
244            .unwrap();
245        assert_eq!(added, 1);
246    }
247
248    #[test]
249    fn test_apple_jwkset() {
250        let mut registry = KeyRegistry::<Key>::default();
251        let added = registry
252            .add_from_jwkset(include_str!("testcases/jwkset-aapl.json"))
253            .unwrap();
254        assert_eq!(added, 3);
255    }
256
257    #[test]
258    fn load_keys_from_jwkset() {
259        let mut registry = KeyRegistry::<PrivateKey>::default();
260        let added = registry
261            .add_from_jwkset(include_str!("testcases/jwkset-pub.json"))
262            .unwrap();
263        assert_eq!(added, 0);
264        let mut registry = KeyRegistry::<PrivateKey>::default();
265        let added = registry
266            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
267            .unwrap();
268        assert_eq!(added, 3);
269    }
270
271    #[test]
272    fn load_pub_keys_from_jwkset() {
273        let mut registry = KeyRegistry::<PublicKey>::default();
274        let added = registry
275            .add_from_jwkset(include_str!("testcases/jwkset-pub.json"))
276            .unwrap();
277        assert_eq!(added, 2);
278        let mut registry = KeyRegistry::<PublicKey>::default();
279        let added = registry
280            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
281            .unwrap();
282        assert_eq!(added, 3);
283    }
284
285    #[test]
286    fn validate_tokens_from_jwkset() {
287        let mut registry = KeyRegistry::<PrivateKey>::default();
288        registry
289            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
290            .unwrap();
291        let keys = registry.into_keys().collect::<Vec<_>>();
292
293        let mut registry = KeyRegistry::<PrivateKey>::default();
294        registry
295            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
296            .unwrap();
297
298        let claims = HashMap::from([("test".to_owned(), "value".into())]);
299        let signing_ctx = SigningContext {
300            issuer: Some("test-issuer".to_owned()),
301            audience: Some("test-audience".to_owned()),
302            ..Default::default()
303        };
304        let mut validation_ctx = ValidationContext::default();
305        validation_ctx.require_claim_with_allow_list("iss", &["test-issuer"]);
306        validation_ctx.require_claim_with_allow_list("aud", &["test-audience"]);
307
308        // Generate and validate a token with each key
309        for key in &keys {
310            let token = key.sign(claims.clone(), &signing_ctx).unwrap();
311            let decoded = registry.validate(&token, &validation_ctx).unwrap();
312            assert_eq!(decoded, claims);
313        }
314
315        // Generate and validate a token with each key against the public keys
316        let mut registry = KeyRegistry::<PublicKey>::default();
317        registry
318            .add_from_jwkset(include_str!("testcases/jwkset-prv.json"))
319            .unwrap();
320        for key in &keys {
321            let token = key.sign(claims.clone(), &signing_ctx).unwrap();
322            let decoded = registry.validate(&token, &validation_ctx).unwrap();
323            assert_eq!(decoded, claims);
324        }
325    }
326
327    #[test]
328    fn test_validate_tokens_from_jwkset_named() {
329        let mut key1 = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
330        let mut key2 = PrivateKey::generate(Some("2".to_owned()), KeyType::HS256).unwrap();
331
332        let claims = HashMap::from([("test".to_owned(), "value".into())]);
333        let signing_ctx = SigningContext {
334            issuer: Some("test-issuer".to_owned()),
335            audience: Some("test-audience".to_owned()),
336            ..Default::default()
337        };
338        let validation_ctx = ValidationContext::default();
339        let token = key1.sign(claims, &signing_ctx).unwrap();
340
341        // Swap the keys so the signature is no longer valid with the specified kid
342        key1.set_kid(Some("2".to_owned()));
343        key2.set_kid(Some("1".to_owned()));
344
345        let mut registry = KeyRegistry::<PrivateKey>::default();
346        registry.add_key(key1);
347        registry.add_key(key2);
348
349        let decoded = registry.validate(&token, &validation_ctx).unwrap_err();
350        assert_eq!(
351            decoded,
352            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidSignature),
353            "{}",
354            decoded.error_string_not_for_user()
355        );
356    }
357
358    #[test]
359    fn test_validate_tokens_from_jwkset_named_allow_deny() {
360        let key = PrivateKey::generate(Some("1".to_owned()), KeyType::HS256).unwrap();
361        let mut registry = KeyRegistry::<PrivateKey>::default();
362        registry.add_key(key);
363
364        let claims = HashMap::from([("jti".to_owned(), "1234".into())]);
365        let signing_ctx = SigningContext::default();
366        let mut validation_ctx = ValidationContext::default();
367        let token = registry.sign(claims.clone(), &signing_ctx).unwrap();
368
369        // With no claim validation, the token should be valid
370        let res = registry.validate(&token, &validation_ctx);
371        assert!(
372            res.is_ok(),
373            "{}",
374            res.unwrap_err().error_string_not_for_user()
375        );
376
377        validation_ctx.require_claim_with_allow_list("jti", &["1234"]);
378        let decoded = registry.validate(&token, &validation_ctx).unwrap();
379        assert_eq!(decoded, Default::default());
380
381        let claims = HashMap::from([("jti".to_owned(), "bad".into())]);
382        let token = registry.sign(claims, &signing_ctx).unwrap();
383        let decoded = registry.validate(&token, &validation_ctx).unwrap_err();
384        assert_eq!(
385            decoded,
386            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue(
387                "jti".to_string(),
388                Some("bad".to_string())
389            ))
390        );
391
392        validation_ctx.require_claim_with_deny_list("jti", &["bad"]);
393        let decoded = registry.validate(&token, &validation_ctx).unwrap_err();
394        assert_eq!(
395            decoded,
396            ValidationError::Invalid(OpaqueValidationFailureReason::InvalidClaimValue(
397                "jti".to_string(),
398                Some("bad".to_string())
399            ))
400        );
401    }
402
403    #[test]
404    fn test_any_json() {
405        let map: HashMap<String, Any> = HashMap::from([
406            ("hello".to_owned(), "world".into()),
407            ("empty".to_owned(), Any::None),
408            ("bool".to_owned(), Any::Bool(true)),
409            ("number".to_owned(), Any::Number(123)),
410            (
411                "array".to_owned(),
412                Any::Array(vec![Any::String("1".into()), Any::String("2".into())]),
413            ),
414        ]);
415        let json = serde_json::to_string(&map).unwrap();
416        assert!(json.contains("\"hello\":\"world\""));
417        assert!(json.contains("\"empty\":null"));
418        assert!(json.contains("\"bool\":true"));
419        assert!(json.contains("\"number\":123"));
420        assert!(json.contains("\"array\":[\"1\",\"2\"]"));
421    }
422}