Skip to main content

easy_auth_sdk/
lib.rs

mod claims;
mod error;
mod jwks;

pub use claims::Claims;
pub use error::AuthError;
pub use jsonwebtoken::Algorithm;

use std::sync::{PoisonError, RwLock};

use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
12
13pub struct EasyAuth {
14    decoding_keys: RwLock<Vec<(Option<String>, DecodingKey)>>,
15    validation: Validation,
16}
17
18impl EasyAuth {
19    /// Create an EasyAuth instance from a JWKS JSON string.
20    ///
21    /// # Example
22    /// ```ignore
23    /// let jwks_json = r#"{"keys":[...]}"#;
24    /// let auth = EasyAuth::from_jwks_json(jwks_json)?;
25    /// ```
26    pub fn from_jwks_json(jwks_json: &str) -> Result<Self, AuthError> {
27        let keys = jwks::parse_jwks(jwks_json)?;
28
29        let mut validation = Validation::new(Algorithm::RS256);
30        validation.validate_exp = true;
31
32        Ok(Self {
33            decoding_keys: RwLock::new(keys),
34            validation,
35        })
36    }
37
38    /// Create an EasyAuth instance from a PEM-encoded public key.
39    ///
40    /// # Example
41    /// ```ignore
42    /// let pem = "-----BEGIN PUBLIC KEY-----\n...";
43    /// let auth = EasyAuth::from_pem(pem)?;
44    /// ```
45    pub fn from_pem(pem: &str) -> Result<Self, AuthError> {
46        let key = DecodingKey::from_rsa_pem(pem.as_bytes())
47            .map_err(|e| AuthError::InvalidKey(format!("Failed to parse PEM: {}", e)))?;
48
49        let mut validation = Validation::new(Algorithm::RS256);
50        validation.validate_exp = true;
51
52        Ok(Self {
53            decoding_keys: RwLock::new(vec![(None, key)]),
54            validation,
55        })
56    }
57
58    /// Hot-swap the JWKS keys without reconstructing the `EasyAuth` instance.
59    ///
60    /// Call this when a `KeyNotFound` error indicates the signing keys have
61    /// been rotated and the current key set is stale.
62    pub fn update_jwks(&self, jwks_json: &str) -> Result<(), AuthError> {
63        let keys = jwks::parse_jwks(jwks_json)?;
64        let mut guard = self.decoding_keys.write().expect("decoding_keys poisoned");
65        *guard = keys;
66        Ok(())
67    }
68
69    /// Validate the JWT token and return the claims.
70    ///
71    /// Verifies the token signature and expiration, then returns the decoded claims.
72    ///
73    /// # Example
74    /// ```ignore
75    /// let claims = auth.validate(&token)?;
76    /// println!("User: {}", claims.sub);
77    /// println!("Roles: {:?}", claims.domain_roles);
78    /// ```
79    pub fn validate(&self, token: &str) -> Result<Claims, AuthError> {
80        self.decode_token(token)
81    }
82
83    fn decode_token(&self, token: &str) -> Result<Claims, AuthError> {
84        let header = decode_header(token)?;
85        let kid = header.kid.as_deref();
86
87        let guard = self.decoding_keys.read().expect("decoding_keys poisoned");
88
89        let decoding_key = Self::find_key(&guard, kid)?;
90        let token_data = decode::<Claims>(token, decoding_key, &self.validation)?;
91
92        Ok(token_data.claims)
93    }
94
95    fn find_key<'a>(
96        keys: &'a [(Option<String>, DecodingKey)],
97        kid: Option<&str>,
98    ) -> Result<&'a DecodingKey, AuthError> {
99        if keys.is_empty() {
100            return Err(AuthError::InvalidKey("No keys available".to_string()));
101        }
102
103        match kid {
104            Some(kid) => {
105                // Look for exact kid match
106                for (key_kid, key) in keys {
107                    if key_kid.as_deref() == Some(kid) {
108                        return Ok(key);
109                    }
110                }
111                // No match — if all stored keys have kids, this is a key-not-found
112                // (the signing key rotated and we don't have the new one).
113                // If stored keys have no kids (e.g. PEM), fall back to first key.
114                let all_keys_have_kids = keys.iter().all(|(k, _)| k.is_some());
115                if all_keys_have_kids {
116                    Err(AuthError::KeyNotFound(kid.to_string()))
117                } else {
118                    Ok(&keys[0].1)
119                }
120            }
121            None => Ok(&keys[0].1),
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    use jsonwebtoken::{encode, EncodingKey, Header};
    use rand::rngs::OsRng;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::EncodePublicKey;
    use rsa::traits::PublicKeyParts;
    use rsa::RsaPrivateKey;
    use serde::Serialize;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// `kid` written into the JWKS by `generate_test_keys` and into tokens by `create_token`.
    const DEFAULT_KID: &str = "test-key";

    #[derive(Debug, Serialize)]
    struct TestClaims {
        sub: String,
        domain_roles: Vec<String>,
        exp: u64,
        iat: u64,
    }

    /// A freshly generated RSA key pair in the formats the tests need.
    struct TestKeys {
        /// Private key used to sign test tokens.
        encoding_key: EncodingKey,
        /// SPKI PEM public key, for `EasyAuth::from_pem`.
        pem_public: String,
        /// Single-key JWKS document describing the public key.
        jwks_json: String,
    }

    /// Generate a key pair whose JWKS entry uses `DEFAULT_KID`.
    fn generate_test_keys() -> TestKeys {
        generate_test_keys_with_kid(DEFAULT_KID)
    }

    /// Generate a 2048-bit RSA key pair whose JWKS entry is labelled with `kid`.
    fn generate_test_keys_with_kid(kid: &str) -> TestKeys {
        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let public_key = private_key.to_public_key();

        let private_pem = private_key.to_pkcs1_pem(Default::default()).unwrap();
        let public_pem = public_key.to_public_key_pem(Default::default()).unwrap();

        let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap();

        // JWKS encodes the RSA modulus and exponent as unpadded base64url big-endian bytes.
        let n = URL_SAFE_NO_PAD.encode(private_key.n().to_bytes_be());
        let e = URL_SAFE_NO_PAD.encode(private_key.e().to_bytes_be());

        let jwks_json = format!(
            r#"{{"keys":[{{"kty":"RSA","kid":"{}","use":"sig","alg":"RS256","n":"{}","e":"{}"}}]}}"#,
            kid, n, e
        );

        TestKeys {
            encoding_key,
            pem_public: public_pem,
            jwks_json,
        }
    }

    /// Sign `claims` with RS256, tagging the header with `DEFAULT_KID`.
    fn create_token(keys: &TestKeys, claims: &TestClaims) -> String {
        create_token_with_kid(keys, claims, DEFAULT_KID)
    }

    /// Sign `claims` with RS256, tagging the header with `kid`.
    fn create_token_with_kid(keys: &TestKeys, claims: &TestClaims, kid: &str) -> String {
        let mut header = Header::new(Algorithm::RS256);
        header.kid = Some(kid.to_string());
        encode(&header, claims, &keys.encoding_key).unwrap()
    }

    /// Sign `claims` with RS256 and no `kid` in the header.
    fn create_token_without_kid(keys: &TestKeys, claims: &TestClaims) -> String {
        let mut header = Header::new(Algorithm::RS256);
        header.kid = None;
        encode(&header, claims, &keys.encoding_key).unwrap()
    }

    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Claims for `sub` holding `domain_roles`, issued now and valid for one hour.
    fn valid_claims(sub: &str, domain_roles: &[&str]) -> TestClaims {
        let now = now_secs();
        TestClaims {
            sub: sub.to_string(),
            domain_roles: domain_roles.iter().map(|r| r.to_string()).collect(),
            exp: now + 3600,
            iat: now,
        }
    }

    #[test]
    fn test_allowed_domain_roles_with_matching_role() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let token = create_token(&keys, &valid_claims("user-123", &["moon:user", "example:admin"]));
        let claims = auth.validate(&token).unwrap();
        assert!(claims.allowed_domain_roles(&["moon:user"]));
        assert_eq!(claims.sub, "user-123");
    }

    #[test]
    fn test_allowed_domain_roles_without_matching_role() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let token = create_token(&keys, &valid_claims("user-123", &["example:viewer"]));
        let claims = auth.validate(&token).unwrap();
        assert!(!claims.allowed_domain_roles(&["moon:admin"]));
    }

    #[test]
    fn test_is_subject_matching() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_pem(&keys.pem_public).unwrap();

        let subject = "295fafbb-7da3-4881-858f-e6ea5d2b65ae";
        let token = create_token_without_kid(&keys, &valid_claims(subject, &[]));

        let claims = auth.validate(&token).unwrap();
        assert!(claims.is_subject(subject));
    }

    #[test]
    fn test_is_subject_not_matching() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_pem(&keys.pem_public).unwrap();

        let token = create_token_without_kid(&keys, &valid_claims("user-123", &[]));

        let claims = auth.validate(&token).unwrap();
        assert!(!claims.is_subject("different-user"));
    }

    #[test]
    fn test_validate() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let token = create_token(&keys, &valid_claims("user-456", &["test:role"]));
        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-456");
        assert_eq!(claims.domain_roles, vec!["test:role".to_string()]);
    }

    #[test]
    fn test_combined_checks() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let token = create_token(&keys, &valid_claims("user-789", &["api:read", "api:write"]));

        // Validate once, check multiple times
        let claims = auth.validate(&token).unwrap();
        assert!(claims.allowed_domain_roles(&["api:read"]));
        assert!(claims.is_subject("user-789"));

        // OR logic: allow if subject matches OR has admin role
        assert!(claims.is_subject("user-789") || claims.allowed_domain_roles(&["admin"]));
    }

    #[test]
    fn test_expired_token() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let now = now_secs();
        let test_claims = TestClaims {
            sub: "user-123".to_string(),
            domain_roles: vec!["moon:user".to_string()],
            exp: now - 3600, // Expired 1 hour ago
            iat: now - 7200,
        };

        let token = create_token(&keys, &test_claims);
        let result = auth.validate(&token);

        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_invalid_signature() {
        let trusted = generate_test_keys();
        let attacker = generate_test_keys();

        // Trust only the first key pair...
        let auth = EasyAuth::from_jwks_json(&trusted.jwks_json).unwrap();

        // ...but sign with the second one under the same kid.
        let token = create_token(&attacker, &valid_claims("user-123", &["moon:user"]));

        let result = auth.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_malformed_token() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let result = auth.validate("not.a.valid.token");
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_invalid_jwks() {
        let result = EasyAuth::from_jwks_json("not valid json");
        assert!(matches!(result, Err(AuthError::JsonError(_))));
    }

    #[test]
    fn test_empty_jwks() {
        let result = EasyAuth::from_jwks_json(r#"{"keys":[]}"#);
        assert!(matches!(result, Err(AuthError::InvalidKey(_))));
    }

    #[test]
    fn test_key_not_found() {
        let keys = generate_test_keys_with_kid("old-key");
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        // Sign with a kid that doesn't exist in the JWKS
        let token = create_token_with_kid(&keys, &valid_claims("user-123", &[]), "rotated-new-key");
        let result = auth.validate(&token);
        assert!(
            matches!(result, Err(AuthError::KeyNotFound(ref kid)) if kid == "rotated-new-key"),
            "Expected KeyNotFound for unknown kid, got: {:?}",
            result
        );
    }

    #[test]
    fn test_update_jwks() {
        let old_keys = generate_test_keys_with_kid("old-key");
        let new_keys = generate_test_keys_with_kid("new-key");
        let auth = EasyAuth::from_jwks_json(&old_keys.jwks_json).unwrap();

        // Token signed with new key initially fails
        let token = create_token_with_kid(&new_keys, &valid_claims("user-123", &[]), "new-key");
        assert!(matches!(
            auth.validate(&token),
            Err(AuthError::KeyNotFound(_))
        ));

        // After updating JWKS with new keys, validation succeeds
        auth.update_jwks(&new_keys.jwks_json).unwrap();
        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-123");
    }

    #[test]
    fn test_update_jwks_invalid_json_keeps_existing_keys() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let result = auth.update_jwks("not valid json");
        assert!(matches!(result, Err(AuthError::JsonError(_))));

        // The failed update must leave the previously loaded keys in service.
        let token = create_token(&keys, &valid_claims("user-123", &[]));
        assert_eq!(auth.validate(&token).unwrap().sub, "user-123");
    }

    #[test]
    fn test_pem_fallback_no_key_not_found() {
        // PEM keys have no kid — should fall back to first key, not KeyNotFound
        let keys = generate_test_keys();
        let auth = EasyAuth::from_pem(&keys.pem_public).unwrap();

        // Token with an unknown kid still works because PEM keys have no kids
        let token = create_token_with_kid(&keys, &valid_claims("user-123", &[]), "any-kid");
        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-123");
    }
}