1use crate::errors::{AuthError, Result, TokenError};
3use crate::providers::{OAuthProvider, ProfileExtractor, UserProfile};
4use chrono::{DateTime, Utc};
5use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "postgres-storage")]
8use sqlx::FromRow;
9use std::collections::HashMap;
10use std::time::Duration;
11use uuid::Uuid;
12
13#[cfg_attr(feature = "postgres-storage", derive(FromRow))]
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AuthToken {
17 pub token_id: String,
19
20 pub user_id: String,
22
23 pub access_token: String,
25
26 pub token_type: Option<String>,
28
29 pub subject: Option<String>,
31
32 pub issuer: Option<String>,
34
35 pub refresh_token: Option<String>,
37
38 pub issued_at: DateTime<Utc>,
40
41 pub expires_at: DateTime<Utc>,
43
44 pub scopes: Vec<String>,
46
47 pub auth_method: String,
49
50 pub client_id: Option<String>,
52
53 pub user_profile: Option<UserProfile>,
55
56 pub permissions: Vec<String>,
58
59 pub roles: Vec<String>,
61
62 pub metadata: TokenMetadata,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize, Default)]
68pub struct TokenMetadata {
69 pub issued_ip: Option<String>,
71
72 pub user_agent: Option<String>,
74
75 pub device_id: Option<String>,
77
78 pub session_id: Option<String>,
80
81 pub revoked: bool,
83
84 pub revoked_at: Option<DateTime<Utc>>,
86
87 pub revoked_reason: Option<String>,
89
90 pub last_used: Option<DateTime<Utc>>,
92
93 pub use_count: u64,
95
96 pub custom: HashMap<String, serde_json::Value>,
98}
99
100#[cfg(feature = "postgres-storage")]
101use sqlx::{Decode, Postgres, Type, postgres::PgValueRef};
102
#[cfg(feature = "postgres-storage")]
impl<'r> Decode<'r, Postgres> for TokenMetadata {
    /// Decodes the column as a JSON value first and then deserializes that
    /// value into `TokenMetadata`; either step's error is returned boxed.
    fn decode(value: PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let raw = <serde_json::Value as Decode<Postgres>>::decode(value)?;
        // `?` boxes the serde error into `BoxDynError` through the standard
        // `From` conversion.
        let metadata: Self = serde_json::from_value(raw)?;
        Ok(metadata)
    }
}
110
#[cfg(feature = "postgres-storage")]
impl Type<Postgres> for TokenMetadata {
    /// Reports the same Postgres type information as `serde_json::Value`.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <serde_json::Value as Type<Postgres>>::type_info()
    }

    /// Accepts exactly the column types that `serde_json::Value` accepts.
    fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
        <serde_json::Value as Type<Postgres>>::compatible(ty)
    }
}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct TokenInfo {
124 pub user_id: String,
126
127 pub username: Option<String>,
129
130 pub email: Option<String>,
132
133 pub name: Option<String>,
135
136 pub roles: Vec<String>,
138
139 pub permissions: Vec<String>,
141
142 pub attributes: HashMap<String, serde_json::Value>,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct JwtClaims {
149 pub sub: String,
151
152 pub iss: String,
154
155 pub aud: String,
157
158 pub exp: i64,
160
161 pub iat: i64,
163
164 pub nbf: i64,
166
167 pub jti: String,
169
170 pub scope: String,
172
173 pub permissions: Option<Vec<String>>,
175
176 pub roles: Option<Vec<String>>,
178
179 pub client_id: Option<String>,
181
182 #[serde(flatten)]
184 pub custom: HashMap<String, serde_json::Value>,
185}
186
187pub struct TokenManager {
189 encoding_key: EncodingKey,
191
192 decoding_key: DecodingKey,
194
195 key_material: KeyMaterial,
197
198 algorithm: Algorithm,
200
201 issuer: String,
203
204 audience: String,
206
207 default_lifetime: Duration,
209}
210
/// Raw key bytes a `TokenManager` was constructed from.
#[derive(Clone)]
enum KeyMaterial {
    /// Shared secret for HMAC signing.
    Hmac(Vec<u8>),
    /// PEM-encoded RSA key pair.
    Rsa { private: Vec<u8>, public: Vec<u8> },
}
219
220impl AuthToken {
221 pub fn new(
223 user_id: impl Into<String>,
224 access_token: impl Into<String>,
225 expires_in: std::time::Duration,
226 auth_method: impl Into<String>,
227 ) -> Self {
228 let now = Utc::now();
229 let expires_in_chrono =
230 chrono::Duration::from_std(expires_in).unwrap_or(chrono::Duration::hours(1));
231
232 Self {
233 token_id: Uuid::new_v4().to_string(),
234 user_id: user_id.into(),
235 access_token: access_token.into(),
236 refresh_token: None,
237 token_type: Some("Bearer".to_string()),
238 subject: None,
239 issuer: None,
240 issued_at: now,
241 expires_at: now + expires_in_chrono,
242 scopes: Vec::new(),
243 auth_method: auth_method.into(),
244 client_id: None,
245 user_profile: None,
246 permissions: Vec::new(),
247 roles: Vec::new(),
248 metadata: TokenMetadata::default(),
249 }
250 }
251
252 pub fn access_token(&self) -> &str {
254 &self.access_token
255 }
256
257 pub fn user_id(&self) -> &str {
259 &self.user_id
260 }
261
262 pub fn expires_at(&self) -> DateTime<Utc> {
264 self.expires_at
265 }
266
267 pub fn token_value(&self) -> &str {
269 &self.access_token
270 }
271
272 pub fn token_type(&self) -> Option<&str> {
274 self.token_type.as_deref()
275 }
276
277 pub fn subject(&self) -> Option<&str> {
279 self.subject.as_deref()
280 }
281
282 pub fn issuer(&self) -> Option<&str> {
284 self.issuer.as_deref()
285 }
286
287 pub fn is_expired(&self) -> bool {
289 Utc::now() > self.expires_at
290 }
291
292 pub fn is_expiring(&self, within: Duration) -> bool {
294 Utc::now() + within > self.expires_at
295 }
296
297 pub fn is_revoked(&self) -> bool {
299 self.metadata.revoked
300 }
301
302 pub fn is_valid(&self) -> bool {
304 !self.is_expired() && !self.is_revoked()
305 }
306
307 pub fn revoke(&mut self, reason: Option<String>) {
309 self.metadata.revoked = true;
310 self.metadata.revoked_at = Some(Utc::now());
311 self.metadata.revoked_reason = reason;
312 }
313
314 pub fn mark_used(&mut self) {
316 self.metadata.last_used = Some(Utc::now());
317 self.metadata.use_count += 1;
318 }
319
320 pub fn add_scope(&mut self, scope: impl Into<String>) {
322 let scope = scope.into();
323 if !self.scopes.contains(&scope) {
324 self.scopes.push(scope);
325 }
326 }
327
328 pub fn has_scope(&self, scope: &str) -> bool {
330 self.scopes.contains(&scope.to_string())
331 }
332
333 pub fn with_refresh_token(mut self, refresh_token: impl Into<String>) -> Self {
335 self.refresh_token = Some(refresh_token.into());
336 self
337 }
338
339 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
341 self.client_id = Some(client_id.into());
342 self
343 }
344
345 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
347 self.scopes = scopes;
348 self
349 }
350
351 pub fn with_metadata(mut self, metadata: TokenMetadata) -> Self {
353 self.metadata = metadata;
354 self
355 }
356
357 pub fn time_until_expiry(&self) -> Duration {
359 let now = Utc::now();
360 if self.expires_at > now {
361 (self.expires_at - now).to_std().unwrap_or(Duration::ZERO)
362 } else {
363 Duration::ZERO
364 }
365 }
366
367 pub fn add_custom_claim(&mut self, key: impl Into<String>, value: serde_json::Value) {
369 self.metadata.custom.insert(key.into(), value);
370 }
371
372 pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
374 self.metadata.custom.get(key)
375 }
376
377 pub fn has_permission(&self, permission: &str) -> bool {
379 self.permissions.contains(&permission.to_string())
380 }
381
382 pub fn add_permission(&mut self, permission: impl Into<String>) {
384 let permission = permission.into();
385 if !self.permissions.contains(&permission) {
386 self.permissions.push(permission);
387 }
388 }
389
390 pub fn add_role(&mut self, role: impl Into<String>) {
392 let role = role.into();
393 if !self.roles.contains(&role) {
394 self.roles.push(role);
395 }
396 }
397
398 pub fn has_role(&self, role: &str) -> bool {
400 self.roles.contains(&role.to_string())
401 }
402
403 pub fn with_permissions(mut self, permissions: Vec<String>) -> Self {
405 self.permissions = permissions;
406 self
407 }
408
409 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
411 self.roles = roles;
412 self
413 }
414}
415
416impl Clone for TokenManager {
417 fn clone(&self) -> Self {
418 match &self.key_material {
419 KeyMaterial::Hmac(secret) => Self {
420 encoding_key: EncodingKey::from_secret(secret),
421 decoding_key: DecodingKey::from_secret(secret),
422 key_material: self.key_material.clone(),
423 algorithm: self.algorithm,
424 issuer: self.issuer.clone(),
425 audience: self.audience.clone(),
426 default_lifetime: self.default_lifetime,
427 },
428 KeyMaterial::Rsa { private, public } => Self {
429 encoding_key: EncodingKey::from_rsa_pem(private).expect("Valid RSA private key"),
430 decoding_key: DecodingKey::from_rsa_pem(public).expect("Valid RSA public key"),
431 key_material: self.key_material.clone(),
432 algorithm: self.algorithm,
433 issuer: self.issuer.clone(),
434 audience: self.audience.clone(),
435 default_lifetime: self.default_lifetime,
436 },
437 }
438 }
439}
440
441impl TokenManager {
442 pub fn new_hmac(secret: &[u8], issuer: impl Into<String>, audience: impl Into<String>) -> Self {
444 Self {
445 encoding_key: EncodingKey::from_secret(secret),
446 decoding_key: DecodingKey::from_secret(secret),
447 key_material: KeyMaterial::Hmac(secret.to_vec()),
448 algorithm: Algorithm::HS256,
449 issuer: issuer.into(),
450 audience: audience.into(),
451 default_lifetime: Duration::from_secs(3600), }
453 }
454
455 pub fn new_rsa(
483 private_key: &[u8],
484 public_key: &[u8],
485 issuer: impl Into<String>,
486 audience: impl Into<String>,
487 ) -> Result<Self> {
488 let encoding_key = EncodingKey::from_rsa_pem(private_key)
489 .map_err(|e| AuthError::crypto(format!("Invalid RSA private key: {e}")))?;
490
491 let decoding_key = DecodingKey::from_rsa_pem(public_key)
492 .map_err(|e| AuthError::crypto(format!("Invalid RSA public key: {e}")))?;
493
494 Ok(Self {
495 encoding_key,
496 decoding_key,
497 key_material: KeyMaterial::Rsa {
498 private: private_key.to_vec(),
499 public: public_key.to_vec(),
500 },
501 algorithm: Algorithm::RS256,
502 issuer: issuer.into(),
503 audience: audience.into(),
504 default_lifetime: Duration::from_secs(3600), })
506 }
507
508 pub fn with_default_lifetime(mut self, lifetime: Duration) -> Self {
510 self.default_lifetime = lifetime;
511 self
512 }
513
514 pub fn create_jwt_token(
516 &self,
517 user_id: impl Into<String>,
518 scopes: Vec<String>,
519 lifetime: Option<Duration>,
520 ) -> Result<String> {
521 let user_id = user_id.into();
522 let lifetime = lifetime.unwrap_or(self.default_lifetime);
523 let now = Utc::now();
524 let exp = now + chrono::Duration::from_std(lifetime).unwrap_or(chrono::Duration::hours(1));
525
526 let claims = JwtClaims {
527 sub: user_id,
528 iss: self.issuer.clone(),
529 aud: self.audience.clone(),
530 exp: exp.timestamp(),
531 iat: now.timestamp(),
532 nbf: now.timestamp(),
533 jti: Uuid::new_v4().to_string(),
534 scope: scopes.join(" "),
535 permissions: None,
536 roles: None,
537 client_id: None,
538 custom: HashMap::new(),
539 };
540
541 let header = Header::new(self.algorithm);
542
543 encode(&header, &claims, &self.encoding_key)
544 .map_err(|e| TokenError::creation_failed(format!("JWT encoding failed: {e}")).into())
545 }
546
547 pub fn validate_jwt_token(&self, token: &str) -> Result<JwtClaims> {
549 let mut validation = Validation::new(self.algorithm);
550 validation.set_issuer(&[&self.issuer]);
551 validation.set_audience(&[&self.audience]);
552
553 let token_data =
554 decode::<JwtClaims>(token, &self.decoding_key, &validation).map_err(|e| {
555 match e.kind() {
556 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
557 AuthError::Token(TokenError::Expired)
558 }
559 _ => AuthError::Token(TokenError::Invalid {
560 message: "Invalid token format".to_string(),
561 }),
562 }
563 })?;
564
565 Ok(token_data.claims)
566 }
567
568 pub fn create_auth_token(
570 &self,
571 user_id: impl Into<String>,
572 scopes: Vec<String>,
573 auth_method: impl Into<String>,
574 lifetime: Option<std::time::Duration>,
575 ) -> Result<AuthToken> {
576 let user_id_str = user_id.into();
577 let lifetime = lifetime.unwrap_or(self.default_lifetime);
578
579 let jwt_token = self.create_jwt_token(&user_id_str, scopes.clone(), Some(lifetime))?;
580
581 let token =
582 AuthToken::new(user_id_str, jwt_token, lifetime, auth_method).with_scopes(scopes);
583
584 Ok(token)
585 }
586
587 pub fn validate_auth_token(&self, token: &AuthToken) -> Result<()> {
589 if token.is_expired() {
591 return Err(TokenError::Expired.into());
592 }
593
594 if token.is_revoked() {
596 return Err(TokenError::Invalid {
597 message: "Token has been revoked".to_string(),
598 }
599 .into());
600 }
601
602 if token.auth_method == "jwt" || token.access_token.contains('.') {
604 self.validate_jwt_token(&token.access_token)?;
605 }
606
607 Ok(())
608 }
609
610 pub fn refresh_token(&self, token: &AuthToken) -> Result<AuthToken> {
612 if token.is_expired() {
613 return Err(TokenError::Expired.into());
614 }
615
616 if token.is_revoked() {
617 return Err(TokenError::Invalid {
618 message: "Cannot refresh revoked token".to_string(),
619 }
620 .into());
621 }
622
623 self.create_auth_token(
625 &token.user_id,
626 token.scopes.clone(),
627 &token.auth_method,
628 Some(self.default_lifetime),
629 )
630 }
631
632 pub fn extract_token_info(&self, token: &str) -> Result<TokenInfo> {
634 let claims = self.validate_jwt_token(token)?;
635
636 Ok(TokenInfo {
637 user_id: claims.sub,
638 username: claims
639 .custom
640 .get("username")
641 .and_then(|v| v.as_str())
642 .map(|s| s.to_string()),
643 email: claims
644 .custom
645 .get("email")
646 .and_then(|v| v.as_str())
647 .map(|s| s.to_string()),
648 name: claims
649 .custom
650 .get("name")
651 .and_then(|v| v.as_str())
652 .map(|s| s.to_string()),
653 roles: claims
654 .custom
655 .get("roles")
656 .and_then(|v| v.as_array())
657 .map(|arr| {
658 arr.iter()
659 .filter_map(|v| v.as_str())
660 .map(|s| s.to_string())
661 .collect()
662 })
663 .unwrap_or_default(),
664 permissions: claims
665 .scope
666 .split_whitespace()
667 .map(|s| s.to_string())
668 .collect(),
669 attributes: claims.custom,
670 })
671 }
672}
673
674#[async_trait::async_trait]
676pub trait TokenToProfile {
677 async fn to_profile(&self, provider: &OAuthProvider) -> Result<UserProfile>;
679
680 async fn to_profile_with_extractor(
682 &self,
683 provider: &OAuthProvider,
684 extractor: &ProfileExtractor,
685 ) -> Result<UserProfile>;
686}
687
688#[async_trait::async_trait]
689impl TokenToProfile for AuthToken {
690 async fn to_profile(&self, provider: &OAuthProvider) -> Result<UserProfile> {
691 let extractor = ProfileExtractor::new();
692 extractor.extract_profile(self, provider).await
693 }
694
695 async fn to_profile_with_extractor(
696 &self,
697 provider: &OAuthProvider,
698 extractor: &ProfileExtractor,
699 ) -> Result<UserProfile> {
700 extractor.extract_profile(self, provider).await
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `user123` / `token123` password token shared by the tests.
    fn sample_token(lifetime: Duration) -> AuthToken {
        AuthToken::new("user123", "token123", lifetime, "password")
    }

    /// A fresh one-hour token exposes its fields and is valid.
    #[test]
    fn test_auth_token_creation() {
        let token = sample_token(Duration::from_secs(3600));

        assert_eq!(token.user_id(), "user123");
        assert_eq!(token.access_token(), "token123");
        assert!(!token.is_expired());
        assert!(!token.is_revoked());
        assert!(token.is_valid());
    }

    /// A one-millisecond token is expired after a short wait.
    #[test]
    fn test_token_expiry() {
        let token = sample_token(Duration::from_millis(1));

        // Wait past the one-millisecond lifetime.
        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(token.is_expired());
        assert!(!token.is_valid());
    }

    /// Revoking flips the revoked flag, invalidates the token and stores the
    /// reason.
    #[test]
    fn test_token_revocation() {
        let mut token = sample_token(Duration::from_secs(3600));
        assert!(!token.is_revoked());

        let reason = "User logout".to_string();
        token.revoke(Some(reason.clone()));

        assert!(token.is_revoked());
        assert!(!token.is_valid());
        assert_eq!(token.metadata.revoked_reason, Some(reason));
    }

    /// A JWT created by an HMAC manager validates and round-trips its claims.
    #[tokio::test]
    async fn test_jwt_token_manager() {
        let secret = b"test-secret-key";
        let manager = TokenManager::new_hmac(secret, "test-issuer", "test-audience");

        let scopes = vec!["read".to_string(), "write".to_string()];
        let token = manager
            .create_jwt_token("user123", scopes, Some(Duration::from_secs(3600)))
            .unwrap();

        let claims = manager.validate_jwt_token(&token).unwrap();
        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.scope, "read write");
    }
}
774
775