1mod claims;
2mod error;
3mod jwks;
4
5pub use claims::Claims;
6pub use error::AuthError;
7pub use jsonwebtoken::Algorithm;
8
9use std::sync::RwLock;
10
11use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
12
13pub struct EasyAuth {
14 decoding_keys: RwLock<Vec<(Option<String>, DecodingKey)>>,
15 validation: Validation,
16}
17
18impl EasyAuth {
19 pub fn from_jwks_json(jwks_json: &str) -> Result<Self, AuthError> {
27 let keys = jwks::parse_jwks(jwks_json)?;
28
29 let mut validation = Validation::new(Algorithm::RS256);
30 validation.validate_exp = true;
31
32 Ok(Self {
33 decoding_keys: RwLock::new(keys),
34 validation,
35 })
36 }
37
38 pub fn from_pem(pem: &str) -> Result<Self, AuthError> {
46 let key = DecodingKey::from_rsa_pem(pem.as_bytes())
47 .map_err(|e| AuthError::InvalidKey(format!("Failed to parse PEM: {}", e)))?;
48
49 let mut validation = Validation::new(Algorithm::RS256);
50 validation.validate_exp = true;
51
52 Ok(Self {
53 decoding_keys: RwLock::new(vec![(None, key)]),
54 validation,
55 })
56 }
57
58 pub fn update_jwks(&self, jwks_json: &str) -> Result<(), AuthError> {
63 let keys = jwks::parse_jwks(jwks_json)?;
64 let mut guard = self.decoding_keys.write().expect("decoding_keys poisoned");
65 *guard = keys;
66 Ok(())
67 }
68
69 pub fn validate(&self, token: &str) -> Result<Claims, AuthError> {
80 self.decode_token(token)
81 }
82
83 fn decode_token(&self, token: &str) -> Result<Claims, AuthError> {
84 let header = decode_header(token)?;
85 let kid = header.kid.as_deref();
86
87 let guard = self.decoding_keys.read().expect("decoding_keys poisoned");
88
89 let decoding_key = Self::find_key(&guard, kid)?;
90 let token_data = decode::<Claims>(token, decoding_key, &self.validation)?;
91
92 Ok(token_data.claims)
93 }
94
95 fn find_key<'a>(
96 keys: &'a [(Option<String>, DecodingKey)],
97 kid: Option<&str>,
98 ) -> Result<&'a DecodingKey, AuthError> {
99 if keys.is_empty() {
100 return Err(AuthError::InvalidKey("No keys available".to_string()));
101 }
102
103 match kid {
104 Some(kid) => {
105 for (key_kid, key) in keys {
107 if key_kid.as_deref() == Some(kid) {
108 return Ok(key);
109 }
110 }
111 let all_keys_have_kids = keys.iter().all(|(k, _)| k.is_some());
115 if all_keys_have_kids {
116 Err(AuthError::KeyNotFound(kid.to_string()))
117 } else {
118 Ok(&keys[0].1)
119 }
120 }
121 None => Ok(&keys[0].1),
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    use jsonwebtoken::{encode, EncodingKey, Header};
    use rand::rngs::OsRng;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::EncodePublicKey;
    use rsa::traits::PublicKeyParts;
    use rsa::RsaPrivateKey;
    use serde::Serialize;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// `kid` used by `generate_test_keys` and `create_token`.
    const DEFAULT_KID: &str = "test-key";

    /// Payload signed into test tokens.
    #[derive(Debug, Serialize)]
    struct TestClaims {
        sub: String,
        domain_roles: Vec<String>,
        exp: u64,
        iat: u64,
    }

    impl TestClaims {
        /// Claims for `sub` with `roles`, issued now and expiring in one hour.
        fn valid_for_an_hour(sub: &str, roles: &[&str]) -> Self {
            let now = now_secs();
            Self {
                sub: sub.to_string(),
                domain_roles: roles.iter().map(|&role| role.to_string()).collect(),
                exp: now + 3600,
                iat: now,
            }
        }
    }

    /// A freshly generated RSA key pair in the forms the tests consume.
    struct TestKeys {
        /// Private key used to sign test tokens.
        encoding_key: EncodingKey,
        /// Public key as PEM, for `EasyAuth::from_pem`.
        pem_public: String,
        /// Single-key JWKS document holding the public key.
        jwks_json: String,
    }

    /// Generates a 2048-bit RSA key pair published under `DEFAULT_KID`.
    fn generate_test_keys() -> TestKeys {
        generate_test_keys_with_kid(DEFAULT_KID)
    }

    /// Generates a 2048-bit RSA key pair whose JWKS entry carries `kid`.
    fn generate_test_keys_with_kid(kid: &str) -> TestKeys {
        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let public_key = private_key.to_public_key();

        let private_pem = private_key.to_pkcs1_pem(Default::default()).unwrap();
        let pem_public = public_key.to_public_key_pem(Default::default()).unwrap();
        let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap();

        // JWK RSA parameters: unpadded base64url of the big-endian modulus/exponent.
        let n = URL_SAFE_NO_PAD.encode(private_key.n().to_bytes_be());
        let e = URL_SAFE_NO_PAD.encode(private_key.e().to_bytes_be());
        let jwks_json = format!(
            r#"{{"keys":[{{"kty":"RSA","kid":"{}","use":"sig","alg":"RS256","n":"{}","e":"{}"}}]}}"#,
            kid, n, e
        );

        TestKeys {
            encoding_key,
            pem_public,
            jwks_json,
        }
    }

    /// Signs `claims` with RS256; `kid` becomes the header `kid` (`None` omits it).
    fn sign(keys: &TestKeys, claims: &TestClaims, kid: Option<&str>) -> String {
        let mut header = Header::new(Algorithm::RS256);
        header.kid = kid.map(str::to_string);
        encode(&header, claims, &keys.encoding_key).unwrap()
    }

    /// Signs a token whose header `kid` is `DEFAULT_KID`.
    fn create_token(keys: &TestKeys, claims: &TestClaims) -> String {
        sign(keys, claims, Some(DEFAULT_KID))
    }

    /// Signs a token whose header `kid` is `kid`.
    fn create_token_with_kid(keys: &TestKeys, claims: &TestClaims, kid: &str) -> String {
        sign(keys, claims, Some(kid))
    }

    /// Signs a token with no `kid` header at all.
    fn create_token_without_kid(keys: &TestKeys, claims: &TestClaims) -> String {
        sign(keys, claims, None)
    }

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_allowed_domain_roles_with_matching_role() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let test_claims =
            TestClaims::valid_for_an_hour("user-123", &["moon:user", "example:admin"]);
        let token = create_token(&keys, &test_claims);

        let claims = auth.validate(&token).unwrap();
        assert!(claims.allowed_domain_roles(&["moon:user"]));
        assert_eq!(claims.sub, "user-123");
    }

    #[test]
    fn test_allowed_domain_roles_without_matching_role() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-123", &["example:viewer"]);
        let token = create_token(&keys, &test_claims);

        let claims = auth.validate(&token).unwrap();
        assert!(!claims.allowed_domain_roles(&["moon:admin"]));
    }

    #[test]
    fn test_is_subject_matching() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_pem(&keys.pem_public).unwrap();

        let test_claims =
            TestClaims::valid_for_an_hour("295fafbb-7da3-4881-858f-e6ea5d2b65ae", &[]);
        let token = create_token_without_kid(&keys, &test_claims);

        let claims = auth.validate(&token).unwrap();
        assert!(claims.is_subject("295fafbb-7da3-4881-858f-e6ea5d2b65ae"));
    }

    #[test]
    fn test_is_subject_not_matching() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_pem(&keys.pem_public).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-123", &[]);
        let token = create_token_without_kid(&keys, &test_claims);

        let claims = auth.validate(&token).unwrap();
        assert!(!claims.is_subject("different-user"));
    }

    #[test]
    fn test_validate() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-456", &["test:role"]);
        let token = create_token(&keys, &test_claims);

        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-456");
        assert_eq!(claims.domain_roles, vec!["test:role".to_string()]);
    }

    #[test]
    fn test_combined_checks() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-789", &["api:read", "api:write"]);
        let token = create_token(&keys, &test_claims);

        let claims = auth.validate(&token).unwrap();
        assert!(claims.allowed_domain_roles(&["api:read"]));
        assert!(claims.is_subject("user-789"));

        // "Owner or admin" style check: either predicate is enough.
        assert!(claims.is_subject("user-789") || claims.allowed_domain_roles(&["admin"]));
    }

    #[test]
    fn test_expired_token() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let now = now_secs();
        let test_claims = TestClaims {
            sub: "user-123".to_string(),
            domain_roles: vec!["moon:user".to_string()],
            exp: now - 3600,
            iat: now - 7200,
        };
        let token = create_token(&keys, &test_claims);

        let result = auth.validate(&token);
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_invalid_signature() {
        let trusted = generate_test_keys();
        let attacker = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&trusted.jwks_json).unwrap();

        // Same kid as the trusted key, but signed with a different private key.
        let test_claims = TestClaims::valid_for_an_hour("user-123", &["moon:user"]);
        let token = create_token(&attacker, &test_claims);

        let result = auth.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_malformed_token() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let result = auth.validate("not.a.valid.token");
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_invalid_jwks() {
        let result = EasyAuth::from_jwks_json("not valid json");
        assert!(matches!(result, Err(AuthError::JsonError(_))));
    }

    #[test]
    fn test_empty_jwks() {
        let result = EasyAuth::from_jwks_json(r#"{"keys":[]}"#);
        assert!(matches!(result, Err(AuthError::InvalidKey(_))));
    }

    #[test]
    fn test_key_not_found() {
        let keys = generate_test_keys_with_kid("old-key");
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-123", &[]);
        let token = create_token_with_kid(&keys, &test_claims, "rotated-new-key");

        let result = auth.validate(&token);
        assert!(
            matches!(result, Err(AuthError::KeyNotFound(ref kid)) if kid == "rotated-new-key"),
            "Expected KeyNotFound for unknown kid, got: {:?}",
            result
        );
    }

    #[test]
    fn test_update_jwks() {
        let old_keys = generate_test_keys_with_kid("old-key");
        let new_keys = generate_test_keys_with_kid("new-key");
        let auth = EasyAuth::from_jwks_json(&old_keys.jwks_json).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-123", &[]);
        let token = create_token_with_kid(&new_keys, &test_claims, "new-key");

        // Before rotation the new kid is unknown...
        assert!(matches!(
            auth.validate(&token),
            Err(AuthError::KeyNotFound(_))
        ));

        // ...and after loading the new JWKS the same token validates.
        auth.update_jwks(&new_keys.jwks_json).unwrap();
        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-123");
    }

    #[test]
    fn test_update_jwks_invalid_json_keeps_existing_keys() {
        let keys = generate_test_keys();
        let auth = EasyAuth::from_jwks_json(&keys.jwks_json).unwrap();

        let result = auth.update_jwks("not valid json");
        assert!(matches!(result, Err(AuthError::JsonError(_))));

        // The failed update must not have touched the original key set.
        let test_claims = TestClaims::valid_for_an_hour("user-123", &[]);
        let token = create_token(&keys, &test_claims);
        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-123");
    }

    #[test]
    fn test_pem_fallback_no_key_not_found() {
        let keys = generate_test_keys();
        // A PEM key has no kid, so it must accept tokens carrying any kid.
        let auth = EasyAuth::from_pem(&keys.pem_public).unwrap();

        let test_claims = TestClaims::valid_for_an_hour("user-123", &[]);
        let token = create_token_with_kid(&keys, &test_claims, "any-kid");

        let claims = auth.validate(&token).unwrap();
        assert_eq!(claims.sub, "user-123");
    }
}