1use crate::errors::{AuthError, Result};
7use crate::server::core::common_validation;
8use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Clone)]
15pub struct JwtConfig {
16 pub algorithm: Algorithm,
18 pub signing_key: EncodingKey,
20 pub verification_key: DecodingKey,
22 pub default_expiration: u64,
24 pub issuer: String,
26 pub audiences: Vec<String>,
28}
29
30impl JwtConfig {
31 pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
33 Self {
34 algorithm: Algorithm::HS256,
35 signing_key: EncodingKey::from_secret(secret),
36 verification_key: DecodingKey::from_secret(secret),
37 default_expiration: 3600, issuer,
39 audiences: vec![],
40 }
41 }
42
43 pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
45 let signing_key = EncodingKey::from_rsa_pem(private_key)
46 .map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
47
48 let verification_key = DecodingKey::from_rsa_pem(public_key)
49 .map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
50
51 Ok(Self {
52 algorithm: Algorithm::RS256,
53 signing_key,
54 verification_key,
55 default_expiration: 3600, issuer,
57 audiences: vec![],
58 })
59 }
60
61 pub fn with_audience(mut self, audience: String) -> Self {
63 self.audiences.push(audience);
64 self
65 }
66
67 pub fn with_expiration(mut self, expiration: u64) -> Self {
69 self.default_expiration = expiration;
70 self
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CommonJwtClaims {
77 pub iss: String,
79 pub sub: String,
81 pub aud: Vec<String>,
83 pub exp: i64,
85 pub iat: i64,
87 pub nbf: Option<i64>,
89 pub jti: Option<String>,
91 #[serde(flatten)]
93 pub custom: HashMap<String, serde_json::Value>,
94}
95
96impl CommonJwtClaims {
97 pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
99 let now = SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .unwrap_or_default()
102 .as_secs() as i64;
103
104 Self {
105 iss: issuer,
106 sub: subject,
107 aud: audiences,
108 exp: expiration,
109 iat: now,
110 nbf: None,
111 jti: None,
112 custom: HashMap::new(),
113 }
114 }
115
116 pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
118 self.custom.insert(key, value);
119 self
120 }
121
122 pub fn with_jti(mut self, jti: String) -> Self {
124 self.jti = Some(jti);
125 self
126 }
127
128 pub fn with_nbf(mut self, nbf: i64) -> Self {
130 self.nbf = Some(nbf);
131 self
132 }
133}
134
/// Creates and verifies JWTs according to a [`JwtConfig`].
pub struct JwtManager {
    /// Algorithm, keys, issuer, audiences and default lifetime used for every token.
    config: JwtConfig,
}
219
220impl JwtManager {
221 pub fn new(config: JwtConfig) -> Self {
223 Self { config }
224 }
225
226 pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
228 let header = Header {
229 alg: self.config.algorithm,
230 ..Default::default()
231 };
232
233 encode(&header, claims, &self.config.signing_key)
234 .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
235 }
236
237 pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
239 where
240 T: Serialize,
241 {
242 let header = Header {
243 alg: self.config.algorithm,
244 ..Default::default()
245 };
246
247 encode(&header, claims, &self.config.signing_key)
248 .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
249 }
250
251 pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
253 common_validation::jwt::validate_jwt_format(token)?;
255
256 let mut validation = Validation::new(self.config.algorithm);
257 validation.set_issuer(&[&self.config.issuer]);
258
259 if !self.config.audiences.is_empty() {
260 validation.set_audience(
261 &self
262 .config
263 .audiences
264 .iter()
265 .map(String::as_str)
266 .collect::<Vec<_>>(),
267 );
268 }
269
270 let token_data =
271 decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
272 .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
273
274 let claims_value = serde_json::to_value(&token_data.claims)
276 .map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
277
278 common_validation::jwt::validate_time_claims(&claims_value)?;
279
280 Ok(token_data.claims)
281 }
282
283 pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
285 where
286 T: for<'de> Deserialize<'de>,
287 {
288 common_validation::jwt::validate_jwt_format(token)?;
289
290 let mut validation = Validation::new(self.config.algorithm);
291 validation.set_issuer(&[&self.config.issuer]);
292
293 if !self.config.audiences.is_empty() {
294 validation.set_audience(
295 &self
296 .config
297 .audiences
298 .iter()
299 .map(String::as_str)
300 .collect::<Vec<_>>(),
301 );
302 }
303
304 let token_data = decode::<T>(token, &self.config.verification_key, &validation)
305 .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
306
307 Ok(token_data.claims)
308 }
309
310 pub fn create_access_token(
312 &self,
313 subject: String,
314 scope: Vec<String>,
315 client_id: Option<String>,
316 ) -> Result<String> {
317 let exp = SystemTime::now()
318 .duration_since(UNIX_EPOCH)
319 .unwrap_or_default()
320 .as_secs() as i64
321 + self.config.default_expiration as i64;
322
323 let mut claims = CommonJwtClaims::new(
324 self.config.issuer.clone(),
325 subject,
326 self.config.audiences.clone(),
327 exp,
328 );
329
330 claims
331 .custom
332 .insert("scope".to_string(), serde_json::json!(scope.join(" ")));
333
334 if let Some(client_id) = client_id {
335 claims.custom.insert(
336 "client_id".to_string(),
337 serde_json::Value::String(client_id),
338 );
339 }
340
341 claims.custom.insert(
342 "token_type".to_string(),
343 serde_json::Value::String("access_token".to_string()),
344 );
345
346 self.create_token(&claims)
347 }
348
349 pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
351 let exp = SystemTime::now()
353 .duration_since(UNIX_EPOCH)
354 .unwrap_or_default()
355 .as_secs() as i64
356 + (self.config.default_expiration * 24) as i64; let mut claims = CommonJwtClaims::new(
359 self.config.issuer.clone(),
360 subject,
361 self.config.audiences.clone(),
362 exp,
363 );
364
365 claims.custom.insert(
366 "client_id".to_string(),
367 serde_json::Value::String(client_id),
368 );
369 claims.custom.insert(
370 "token_type".to_string(),
371 serde_json::Value::String("refresh_token".to_string()),
372 );
373
374 self.create_token(&claims)
375 }
376
377 pub fn create_id_token(
379 &self,
380 subject: String,
381 nonce: Option<String>,
382 auth_time: Option<i64>,
383 user_info: HashMap<String, serde_json::Value>,
384 ) -> Result<String> {
385 let exp = SystemTime::now()
386 .duration_since(UNIX_EPOCH)
387 .unwrap_or_default()
388 .as_secs() as i64
389 + 300; let mut claims = CommonJwtClaims::new(
392 self.config.issuer.clone(),
393 subject,
394 self.config.audiences.clone(),
395 exp,
396 );
397
398 claims.custom.insert(
399 "token_type".to_string(),
400 serde_json::Value::String("id_token".to_string()),
401 );
402
403 if let Some(nonce) = nonce {
404 claims
405 .custom
406 .insert("nonce".to_string(), serde_json::Value::String(nonce));
407 }
408
409 if let Some(auth_time) = auth_time {
410 claims.custom.insert(
411 "auth_time".to_string(),
412 serde_json::Value::Number(auth_time.into()),
413 );
414 }
415
416 for (key, value) in user_info {
418 claims.custom.insert(key, value);
419 }
420
421 self.create_token(&claims)
422 }
423}
424
425pub(crate) mod utils {
427 use super::*;
428
429 #[allow(dead_code)]
439 pub(crate) fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
440 common_validation::jwt::extract_claims_unsafe(token)
441 }
442
443 #[allow(dead_code)]
450 pub(crate) fn is_token_expired(token: &str) -> Result<bool> {
451 let claims = extract_claims_unsafe(token)?;
452
453 let now = SystemTime::now()
454 .duration_since(UNIX_EPOCH)
455 .unwrap_or_default()
456 .as_secs() as i64;
457
458 if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
459 Ok(now >= exp)
460 } else {
461 Ok(false) }
463 }
464
465 #[allow(dead_code)]
471 pub(crate) fn get_token_expiration(token: &str) -> Result<Option<i64>> {
472 let claims = extract_claims_unsafe(token)?;
473 Ok(claims.get("exp").and_then(|v| v.as_i64()))
474 }
475
476 #[allow(dead_code)]
482 pub(crate) fn get_token_subject(token: &str) -> Result<Option<String>> {
483 let claims = extract_claims_unsafe(token)?;
484 Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
485 }
486
487 #[allow(dead_code)]
493 pub(crate) fn get_token_scopes(token: &str) -> Result<Vec<String>> {
494 let claims = extract_claims_unsafe(token)?;
495
496 if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
497 Ok(scope_str.split_whitespace().map(String::from).collect())
498 } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
499 Ok(scopes_array
500 .iter()
501 .filter_map(|v| v.as_str())
502 .map(String::from)
503 .collect())
504 } else {
505 Ok(vec![])
506 }
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Issuer shared by the test manager and the claims it signs.
    const TEST_ISSUER: &str = "https://test-issuer.example.com";

    /// HS256 manager for `TEST_ISSUER` with a secret long enough for HMAC.
    fn make_manager() -> JwtManager {
        let config = JwtConfig::with_symmetric_key(
            b"a-test-secret-key-with-enough-bytes-for-hmac",
            TEST_ISSUER.into(),
        );
        JwtManager::new(config)
    }

    /// Unix timestamp one hour in the future, for `exp` values that must not be expired.
    fn in_one_hour() -> i64 {
        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        (since_epoch.as_secs() + 3600) as i64
    }

    /// Claims from `TEST_ISSUER` for `subject`, with no audience, expiring in one hour.
    fn claims_for(subject: &str) -> CommonJwtClaims {
        CommonJwtClaims::new(TEST_ISSUER.into(), subject.into(), vec![], in_one_hour())
    }

    // ---- JwtConfig builders -------------------------------------------------

    #[test]
    fn test_jwt_config_symmetric() {
        let config = JwtConfig::with_symmetric_key(b"secret", "iss".into());
        assert_eq!(config.issuer, "iss");
        assert_eq!(config.default_expiration, 3600);
    }

    #[test]
    fn test_jwt_config_with_audience() {
        let config =
            JwtConfig::with_symmetric_key(b"secret", "iss".into()).with_audience("aud1".into());
        assert_eq!(config.audiences, vec!["aud1"]);
    }

    #[test]
    fn test_jwt_config_with_expiration() {
        let config = JwtConfig::with_symmetric_key(b"secret", "iss".into()).with_expiration(7200);
        assert_eq!(config.default_expiration, 7200);
    }

    // ---- CommonJwtClaims builders -------------------------------------------

    #[test]
    fn test_claims_new() {
        let claims = CommonJwtClaims::new(
            "issuer".into(),
            "subject".into(),
            vec!["aud".into()],
            9999999999,
        );
        assert_eq!(claims.iss, "issuer");
        assert_eq!(claims.sub, "subject");
        assert!(claims.iat > 0);
    }

    #[test]
    fn test_claims_with_custom_claim() {
        let claims = CommonJwtClaims::new("iss".into(), "sub".into(), vec![], 9999999999)
            .with_custom_claim("role".to_string(), serde_json::json!("admin"));
        assert_eq!(claims.custom.get("role").unwrap(), "admin");
    }

    #[test]
    fn test_claims_with_jti() {
        let claims = CommonJwtClaims::new("iss".into(), "sub".into(), vec![], 9999999999)
            .with_jti("test-jti-value".into());
        assert!(claims.jti.is_some());
    }

    // ---- Sign / verify round trips -------------------------------------------

    #[test]
    fn test_create_and_verify_token() {
        let manager = make_manager();
        let token = manager.create_token(&claims_for("user_123")).unwrap();
        let verified = manager.verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user_123");
    }

    #[test]
    fn test_verify_invalid_token() {
        let manager = make_manager();
        assert!(manager.verify_token("not.a.valid.jwt").is_err());
    }

    #[test]
    fn test_verify_wrong_key() {
        let signer = make_manager();
        let other_key = JwtManager::new(JwtConfig::with_symmetric_key(
            b"different-key-entirely-for-testing",
            TEST_ISSUER.into(),
        ));
        let token = signer.create_token(&claims_for("user")).unwrap();
        assert!(other_key.verify_token(&token).is_err());
    }

    // ---- Token factories -----------------------------------------------------

    #[test]
    fn test_create_access_token() {
        let manager = make_manager();
        let token = manager
            .create_access_token(
                "user_1".into(),
                vec!["read".into()],
                Some("client_1".into()),
            )
            .unwrap();
        let claims = manager.verify_token(&token).unwrap();
        assert_eq!(claims.sub, "user_1");
        assert!(claims.custom.contains_key("scope"));
    }

    #[test]
    fn test_create_refresh_token() {
        let manager = make_manager();
        let token = manager
            .create_refresh_token("user_2".into(), "client_2".into())
            .unwrap();
        let claims = manager.verify_token(&token).unwrap();
        assert_eq!(claims.sub, "user_2");
        assert_eq!(
            claims.custom.get("token_type").unwrap(),
            &serde_json::json!("refresh_token")
        );
    }

    #[test]
    fn test_create_id_token() {
        let manager = make_manager();
        let profile = HashMap::from([
            ("name".into(), serde_json::json!("Test User")),
            ("email".into(), serde_json::json!("test@example.com")),
        ]);
        let token = manager
            .create_id_token("user_3".into(), Some("nonce_123".into()), None, profile)
            .unwrap();
        let claims = manager.verify_token(&token).unwrap();
        assert_eq!(claims.sub, "user_3");
        assert_eq!(claims.custom.get("nonce").unwrap(), "nonce_123");
        assert_eq!(
            claims.custom.get("token_type").unwrap(),
            &serde_json::json!("id_token")
        );
    }

    // ---- Unverified inspection helpers ---------------------------------------

    #[test]
    fn test_extract_claims_unsafe_works() {
        let manager = make_manager();
        let token = manager.create_token(&claims_for("peek_user")).unwrap();
        let extracted = utils::extract_claims_unsafe(&token).unwrap();
        assert_eq!(extracted["sub"], "peek_user");
    }

    #[test]
    fn test_is_token_expired_not_expired() {
        let manager = make_manager();
        let token = manager
            .create_access_token("user".into(), vec![], None)
            .unwrap();
        assert!(!utils::is_token_expired(&token).unwrap());
    }
}