auth_framework/tokens/
mod.rs

1//! Token management and validation for the authentication framework.
2use crate::errors::{AuthError, Result, TokenError};
3use crate::providers::{OAuthProvider, ProfileExtractor, UserProfile};
4use chrono::{DateTime, Utc};
5use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "postgres-storage")]
8use sqlx::FromRow;
9use std::collections::HashMap;
10use std::time::Duration;
11use uuid::Uuid;
12
/// Represents an authentication token with all associated metadata.
///
/// Built by [`AuthToken::new`] (and by `TokenManager::create_auth_token`). With the
/// `postgres-storage` feature enabled the struct also derives `sqlx::FromRow`.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes the raw
/// `access_token` and `refresh_token` values — avoid logging this type verbatim.
#[cfg_attr(feature = "postgres-storage", derive(FromRow))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthToken {
    /// Unique token identifier (a random UUID v4 string when built by `AuthToken::new`)
    pub token_id: String,

    /// User identifier this token belongs to
    pub user_id: String,

    /// Access token value (a signed JWT when produced by `TokenManager`)
    pub access_token: String,

    /// Token type (e.g., "bearer"); `AuthToken::new` sets `"Bearer"`
    pub token_type: Option<String>,

    /// Subject claim
    pub subject: Option<String>,

    /// Token issuer
    pub issuer: Option<String>,

    /// Optional refresh token
    pub refresh_token: Option<String>,

    /// When the token was issued
    pub issued_at: DateTime<Utc>,

    /// When the token expires
    pub expires_at: DateTime<Utc>,

    /// Scopes granted to this token (kept free of duplicates by `add_scope`)
    pub scopes: Vec<String>,

    /// Authentication method used to obtain this token
    pub auth_method: String,

    /// Client ID that requested this token
    pub client_id: Option<String>,

    /// User profile data (optional)
    pub user_profile: Option<UserProfile>,

    /// User's permissions
    pub permissions: Vec<String>,

    /// User's roles
    pub roles: Vec<String>,

    /// Additional token metadata (revocation state, usage tracking, custom data)
    pub metadata: TokenMetadata,
}
65
/// Additional metadata that can be attached to tokens.
///
/// `Default` yields a non-revoked, never-used record with every optional field
/// unset and an empty `custom` map. With the `postgres-storage` feature, this type
/// is decoded from a JSON column through its `sqlx::Decode` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TokenMetadata {
    /// IP address where the token was issued
    pub issued_ip: Option<String>,

    /// User agent of the client
    pub user_agent: Option<String>,

    /// Device identifier
    pub device_id: Option<String>,

    /// Session identifier
    pub session_id: Option<String>,

    /// Whether this token has been revoked (set by `AuthToken::revoke`)
    pub revoked: bool,

    /// When the token was revoked (if applicable)
    pub revoked_at: Option<DateTime<Utc>>,

    /// Reason for revocation
    pub revoked_reason: Option<String>,

    /// Last time this token was used (updated by `AuthToken::mark_used`)
    pub last_used: Option<DateTime<Utc>>,

    /// Number of times this token has been used (incremented by `AuthToken::mark_used`)
    pub use_count: u64,

    /// Custom metadata, read and written by `AuthToken::{get,add}_custom_claim`
    pub custom: HashMap<String, serde_json::Value>,
}
99
100#[cfg(feature = "postgres-storage")]
101use sqlx::{Decode, Postgres, Type, postgres::PgValueRef};
102
#[cfg(feature = "postgres-storage")]
impl<'r> Decode<'r, Postgres> for TokenMetadata {
    /// Decode a Postgres JSON value and deserialize it into `TokenMetadata`.
    ///
    /// Both the column decode and the serde conversion errors are surfaced as
    /// `sqlx::error::BoxDynError`.
    fn decode(value: PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let raw = <serde_json::Value as Decode<Postgres>>::decode(value)?;
        // `?` boxes the `serde_json::Error` into `BoxDynError` via `From`.
        Ok(serde_json::from_value::<Self>(raw)?)
    }
}
110
#[cfg(feature = "postgres-storage")]
/// `TokenMetadata` is stored as a JSON document, so it advertises exactly the same
/// Postgres type information as `serde_json::Value`.
impl Type<Postgres> for TokenMetadata {
    /// Same type info as `serde_json::Value`.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <serde_json::Value as Type<Postgres>>::type_info()
    }
    /// Accept any column type that `serde_json::Value` accepts.
    fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
        <serde_json::Value as Type<Postgres>>::compatible(ty)
    }
}
120
/// Information about a user extracted from a token.
///
/// Produced by `TokenManager::extract_token_info` from validated JWT claims.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenInfo {
    /// User identifier (the JWT `sub` claim)
    pub user_id: String,

    /// Username or email (the custom `username` claim, if it is a string)
    pub username: Option<String>,

    /// User's email address (the custom `email` claim, if it is a string)
    pub email: Option<String>,

    /// User's display name (the custom `name` claim, if it is a string)
    pub name: Option<String>,

    /// User's roles
    pub roles: Vec<String>,

    /// User's permissions
    pub permissions: Vec<String>,

    /// Additional user attributes (all custom, non-standard claims)
    pub attributes: HashMap<String, serde_json::Value>,
}
145
/// JWT claims structure used internally.
///
/// Registered claims (`sub`, `iss`, `aud`, `exp`, `iat`, `nbf`, `jti`) are
/// required when deserializing; time claims are Unix timestamps in seconds.
///
/// NOTE(review): `aud` is a single `String`, so tokens whose audience is a JSON
/// array will fail to deserialize. `Option` fields are serialized as JSON `null`
/// when `None` (there is no `skip_serializing_if`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject (user ID)
    pub sub: String,

    /// Issuer
    pub iss: String,

    /// Audience
    pub aud: String,

    /// Expiration time
    pub exp: i64,

    /// Issued at
    pub iat: i64,

    /// Not before
    pub nbf: i64,

    /// JWT ID
    pub jti: String,

    /// Scopes, space-separated
    pub scope: String,

    /// User permissions
    pub permissions: Option<Vec<String>>,

    /// User roles
    pub roles: Option<Vec<String>>,

    /// Client ID
    pub client_id: Option<String>,

    /// Custom claims.
    ///
    /// Flattened: this map only receives keys that do NOT match one of the named
    /// fields above, so e.g. a `"roles"` claim is never found here.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
186
/// Token manager for creating, validating, and managing tokens.
///
/// Construct with `TokenManager::new_hmac` (HS256) or `TokenManager::new_rsa`
/// (RS256). Cloning rebuilds the keys from the retained `key_material`.
pub struct TokenManager {
    /// JWT encoding key
    encoding_key: EncodingKey,

    /// JWT decoding key
    decoding_key: DecodingKey,

    /// Key material for recreating keys during clone
    key_material: KeyMaterial,

    /// JWT algorithm
    algorithm: Algorithm,

    /// Token issuer (written to and required in the `iss` claim)
    issuer: String,

    /// Token audience (written to and required in the `aud` claim)
    audience: String,

    /// Default token lifetime, used when callers pass `None` (one hour unless overridden)
    default_lifetime: Duration,
}
210
/// Key material for cloning TokenManager
///
/// Holds the raw bytes the keys were built from so `Clone for TokenManager` can
/// reconstruct the encoding/decoding keys.
///
/// NOTE(review): secrets stay in ordinary heap memory for the manager's lifetime
/// and are not zeroized on drop.
#[derive(Clone)]
enum KeyMaterial {
    /// HMAC secret
    Hmac(Vec<u8>),
    /// RSA private and public keys, as the PEM bytes passed to `new_rsa`
    Rsa { private: Vec<u8>, public: Vec<u8> },
}
219
220impl AuthToken {
221    /// Create a new authentication token.
222    pub fn new(
223        user_id: impl Into<String>,
224        access_token: impl Into<String>,
225        expires_in: std::time::Duration,
226        auth_method: impl Into<String>,
227    ) -> Self {
228        let now = Utc::now();
229        let expires_in_chrono =
230            chrono::Duration::from_std(expires_in).unwrap_or(chrono::Duration::hours(1));
231
232        Self {
233            token_id: Uuid::new_v4().to_string(),
234            user_id: user_id.into(),
235            access_token: access_token.into(),
236            refresh_token: None,
237            token_type: Some("Bearer".to_string()),
238            subject: None,
239            issuer: None,
240            issued_at: now,
241            expires_at: now + expires_in_chrono,
242            scopes: Vec::new(),
243            auth_method: auth_method.into(),
244            client_id: None,
245            user_profile: None,
246            permissions: Vec::new(),
247            roles: Vec::new(),
248            metadata: TokenMetadata::default(),
249        }
250    }
251
252    /// Get the access token string.
253    pub fn access_token(&self) -> &str {
254        &self.access_token
255    }
256
257    /// Get the user ID.
258    pub fn user_id(&self) -> &str {
259        &self.user_id
260    }
261
262    /// Get the expiration time.
263    pub fn expires_at(&self) -> DateTime<Utc> {
264        self.expires_at
265    }
266
267    /// Get the token value
268    pub fn token_value(&self) -> &str {
269        &self.access_token
270    }
271
272    /// Get the token type
273    pub fn token_type(&self) -> Option<&str> {
274        self.token_type.as_deref()
275    }
276
277    /// Get the subject claim
278    pub fn subject(&self) -> Option<&str> {
279        self.subject.as_deref()
280    }
281
282    /// Get the issuer
283    pub fn issuer(&self) -> Option<&str> {
284        self.issuer.as_deref()
285    }
286
287    /// Check if the token has expired.
288    pub fn is_expired(&self) -> bool {
289        Utc::now() > self.expires_at
290    }
291
292    /// Check if the token is expiring within the given duration.
293    pub fn is_expiring(&self, within: Duration) -> bool {
294        Utc::now() + within > self.expires_at
295    }
296
297    /// Check if the token has been revoked.
298    pub fn is_revoked(&self) -> bool {
299        self.metadata.revoked
300    }
301
302    /// Check if the token is valid (not expired and not revoked).
303    pub fn is_valid(&self) -> bool {
304        !self.is_expired() && !self.is_revoked()
305    }
306
307    /// Revoke the token.
308    pub fn revoke(&mut self, reason: Option<String>) {
309        self.metadata.revoked = true;
310        self.metadata.revoked_at = Some(Utc::now());
311        self.metadata.revoked_reason = reason;
312    }
313
314    /// Update the last used time and increment use count.
315    pub fn mark_used(&mut self) {
316        self.metadata.last_used = Some(Utc::now());
317        self.metadata.use_count += 1;
318    }
319
320    /// Add a scope to the token.
321    pub fn add_scope(&mut self, scope: impl Into<String>) {
322        let scope = scope.into();
323        if !self.scopes.contains(&scope) {
324            self.scopes.push(scope);
325        }
326    }
327
328    /// Check if the token has a specific scope.
329    pub fn has_scope(&self, scope: &str) -> bool {
330        self.scopes.contains(&scope.to_string())
331    }
332
333    /// Set the refresh token.
334    pub fn with_refresh_token(mut self, refresh_token: impl Into<String>) -> Self {
335        self.refresh_token = Some(refresh_token.into());
336        self
337    }
338
339    /// Set the client ID.
340    pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
341        self.client_id = Some(client_id.into());
342        self
343    }
344
345    /// Set the token scopes.
346    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
347        self.scopes = scopes;
348        self
349    }
350
351    /// Add metadata to the token.
352    pub fn with_metadata(mut self, metadata: TokenMetadata) -> Self {
353        self.metadata = metadata;
354        self
355    }
356
357    /// Get time until expiration.
358    pub fn time_until_expiry(&self) -> Duration {
359        let now = Utc::now();
360        if self.expires_at > now {
361            (self.expires_at - now).to_std().unwrap_or(Duration::ZERO)
362        } else {
363            Duration::ZERO
364        }
365    }
366
367    /// Add a custom claim to the token metadata
368    pub fn add_custom_claim(&mut self, key: impl Into<String>, value: serde_json::Value) {
369        self.metadata.custom.insert(key.into(), value);
370    }
371
372    /// Get a custom claim from the token metadata
373    pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
374        self.metadata.custom.get(key)
375    }
376
377    /// Check if the token has a specific permission
378    pub fn has_permission(&self, permission: &str) -> bool {
379        self.permissions.contains(&permission.to_string())
380    }
381
382    /// Add a permission to the token
383    pub fn add_permission(&mut self, permission: impl Into<String>) {
384        let permission = permission.into();
385        if !self.permissions.contains(&permission) {
386            self.permissions.push(permission);
387        }
388    }
389
390    /// Add a role to the token
391    pub fn add_role(&mut self, role: impl Into<String>) {
392        let role = role.into();
393        if !self.roles.contains(&role) {
394            self.roles.push(role);
395        }
396    }
397
398    /// Check if the token has a specific role
399    pub fn has_role(&self, role: &str) -> bool {
400        self.roles.contains(&role.to_string())
401    }
402
403    /// Set the permissions
404    pub fn with_permissions(mut self, permissions: Vec<String>) -> Self {
405        self.permissions = permissions;
406        self
407    }
408
409    /// Set the roles
410    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
411        self.roles = roles;
412        self
413    }
414}
415
416impl Clone for TokenManager {
417    fn clone(&self) -> Self {
418        match &self.key_material {
419            KeyMaterial::Hmac(secret) => Self {
420                encoding_key: EncodingKey::from_secret(secret),
421                decoding_key: DecodingKey::from_secret(secret),
422                key_material: self.key_material.clone(),
423                algorithm: self.algorithm,
424                issuer: self.issuer.clone(),
425                audience: self.audience.clone(),
426                default_lifetime: self.default_lifetime,
427            },
428            KeyMaterial::Rsa { private, public } => Self {
429                encoding_key: EncodingKey::from_rsa_pem(private).expect("Valid RSA private key"),
430                decoding_key: DecodingKey::from_rsa_pem(public).expect("Valid RSA public key"),
431                key_material: self.key_material.clone(),
432                algorithm: self.algorithm,
433                issuer: self.issuer.clone(),
434                audience: self.audience.clone(),
435                default_lifetime: self.default_lifetime,
436            },
437        }
438    }
439}
440
441impl TokenManager {
442    /// Create a new token manager with HMAC key.
443    pub fn new_hmac(secret: &[u8], issuer: impl Into<String>, audience: impl Into<String>) -> Self {
444        Self {
445            encoding_key: EncodingKey::from_secret(secret),
446            decoding_key: DecodingKey::from_secret(secret),
447            key_material: KeyMaterial::Hmac(secret.to_vec()),
448            algorithm: Algorithm::HS256,
449            issuer: issuer.into(),
450            audience: audience.into(),
451            default_lifetime: Duration::from_secs(3600), // 1 hour
452        }
453    }
454
455    /// Create a new token manager with RSA keys.
456    ///
457    /// ## RSA Key Format Support
458    ///
459    /// This method supports RSA keys in both standard PEM formats:
460    /// - **PKCS#1**: `-----BEGIN RSA PRIVATE KEY-----` (traditional RSA format)
461    /// - **PKCS#8**: `-----BEGIN PRIVATE KEY-----` (modern standard format, recommended)
462    ///
463    /// Both formats are automatically detected and parsed. No format conversion is required.
464    ///
465    /// ## Example
466    ///
467    /// ```rust,no_run
468    /// use auth_framework::tokens::TokenManager;
469    ///
470    /// // Both PKCS#1 and PKCS#8 formats work
471    /// let private_key = include_bytes!("../../private.pem");  // Either format
472    /// let public_key = include_bytes!("../../public.pem");
473    ///
474    /// let manager = TokenManager::new_rsa(
475    ///     private_key,
476    ///     public_key,
477    ///     "my-service",
478    ///     "my-audience"
479    /// )?;
480    /// # Ok::<(), auth_framework::errors::AuthError>(())
481    /// ```
482    pub fn new_rsa(
483        private_key: &[u8],
484        public_key: &[u8],
485        issuer: impl Into<String>,
486        audience: impl Into<String>,
487    ) -> Result<Self> {
488        let encoding_key = EncodingKey::from_rsa_pem(private_key)
489            .map_err(|e| AuthError::crypto(format!("Invalid RSA private key: {e}")))?;
490
491        let decoding_key = DecodingKey::from_rsa_pem(public_key)
492            .map_err(|e| AuthError::crypto(format!("Invalid RSA public key: {e}")))?;
493
494        Ok(Self {
495            encoding_key,
496            decoding_key,
497            key_material: KeyMaterial::Rsa {
498                private: private_key.to_vec(),
499                public: public_key.to_vec(),
500            },
501            algorithm: Algorithm::RS256,
502            issuer: issuer.into(),
503            audience: audience.into(),
504            default_lifetime: Duration::from_secs(3600), // 1 hour
505        })
506    }
507
508    /// Set the default token lifetime.
509    pub fn with_default_lifetime(mut self, lifetime: Duration) -> Self {
510        self.default_lifetime = lifetime;
511        self
512    }
513
514    /// Create a new JWT token.
515    pub fn create_jwt_token(
516        &self,
517        user_id: impl Into<String>,
518        scopes: Vec<String>,
519        lifetime: Option<Duration>,
520    ) -> Result<String> {
521        let user_id = user_id.into();
522        let lifetime = lifetime.unwrap_or(self.default_lifetime);
523        let now = Utc::now();
524        let exp = now + chrono::Duration::from_std(lifetime).unwrap_or(chrono::Duration::hours(1));
525
526        let claims = JwtClaims {
527            sub: user_id,
528            iss: self.issuer.clone(),
529            aud: self.audience.clone(),
530            exp: exp.timestamp(),
531            iat: now.timestamp(),
532            nbf: now.timestamp(),
533            jti: Uuid::new_v4().to_string(),
534            scope: scopes.join(" "),
535            permissions: None,
536            roles: None,
537            client_id: None,
538            custom: HashMap::new(),
539        };
540
541        let header = Header::new(self.algorithm);
542
543        encode(&header, &claims, &self.encoding_key)
544            .map_err(|e| TokenError::creation_failed(format!("JWT encoding failed: {e}")).into())
545    }
546
547    /// Validate and decode a JWT token.
548    pub fn validate_jwt_token(&self, token: &str) -> Result<JwtClaims> {
549        let mut validation = Validation::new(self.algorithm);
550        validation.set_issuer(&[&self.issuer]);
551        validation.set_audience(&[&self.audience]);
552
553        let token_data =
554            decode::<JwtClaims>(token, &self.decoding_key, &validation).map_err(|e| {
555                match e.kind() {
556                    jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
557                        AuthError::Token(TokenError::Expired)
558                    }
559                    _ => AuthError::Token(TokenError::Invalid {
560                        message: "Invalid token format".to_string(),
561                    }),
562                }
563            })?;
564
565        Ok(token_data.claims)
566    }
567
568    /// Create a complete authentication token with JWT.
569    pub fn create_auth_token(
570        &self,
571        user_id: impl Into<String>,
572        scopes: Vec<String>,
573        auth_method: impl Into<String>,
574        lifetime: Option<std::time::Duration>,
575    ) -> Result<AuthToken> {
576        let user_id_str = user_id.into();
577        let lifetime = lifetime.unwrap_or(self.default_lifetime);
578
579        let jwt_token = self.create_jwt_token(&user_id_str, scopes.clone(), Some(lifetime))?;
580
581        let token =
582            AuthToken::new(user_id_str, jwt_token, lifetime, auth_method).with_scopes(scopes);
583
584        Ok(token)
585    }
586
587    /// Validate an authentication token.
588    pub fn validate_auth_token(&self, token: &AuthToken) -> Result<()> {
589        // Check if token is expired
590        if token.is_expired() {
591            return Err(TokenError::Expired.into());
592        }
593
594        // Check if token is revoked
595        if token.is_revoked() {
596            return Err(TokenError::Invalid {
597                message: "Token has been revoked".to_string(),
598            }
599            .into());
600        }
601
602        // Validate JWT if it's a JWT token
603        if token.auth_method == "jwt" || token.access_token.contains('.') {
604            self.validate_jwt_token(&token.access_token)?;
605        }
606
607        Ok(())
608    }
609
610    /// Refresh a token (create a new one with extended lifetime).
611    pub fn refresh_token(&self, token: &AuthToken) -> Result<AuthToken> {
612        if token.is_expired() {
613            return Err(TokenError::Expired.into());
614        }
615
616        if token.is_revoked() {
617            return Err(TokenError::Invalid {
618                message: "Cannot refresh revoked token".to_string(),
619            }
620            .into());
621        }
622
623        // Create a new token with the same properties but new expiry
624        self.create_auth_token(
625            &token.user_id,
626            token.scopes.clone(),
627            &token.auth_method,
628            Some(self.default_lifetime),
629        )
630    }
631
632    /// Extract token information from a JWT.
633    pub fn extract_token_info(&self, token: &str) -> Result<TokenInfo> {
634        let claims = self.validate_jwt_token(token)?;
635
636        Ok(TokenInfo {
637            user_id: claims.sub,
638            username: claims
639                .custom
640                .get("username")
641                .and_then(|v| v.as_str())
642                .map(|s| s.to_string()),
643            email: claims
644                .custom
645                .get("email")
646                .and_then(|v| v.as_str())
647                .map(|s| s.to_string()),
648            name: claims
649                .custom
650                .get("name")
651                .and_then(|v| v.as_str())
652                .map(|s| s.to_string()),
653            roles: claims
654                .custom
655                .get("roles")
656                .and_then(|v| v.as_array())
657                .map(|arr| {
658                    arr.iter()
659                        .filter_map(|v| v.as_str())
660                        .map(|s| s.to_string())
661                        .collect()
662                })
663                .unwrap_or_default(),
664            permissions: claims
665                .scope
666                .split_whitespace()
667                .map(|s| s.to_string())
668                .collect(),
669            attributes: claims.custom,
670        })
671    }
672}
673
/// Trait for converting tokens to user profiles
///
/// Declared with `#[async_trait]`, so the methods return boxed futures.
#[async_trait::async_trait]
pub trait TokenToProfile {
    /// Convert this token to a user profile using the specified provider
    /// and a default `ProfileExtractor`.
    async fn to_profile(&self, provider: &OAuthProvider) -> Result<UserProfile>;

    /// Convert this token to a user profile with a custom extractor
    async fn to_profile_with_extractor(
        &self,
        provider: &OAuthProvider,
        extractor: &ProfileExtractor,
    ) -> Result<UserProfile>;
}
687
688#[async_trait::async_trait]
689impl TokenToProfile for AuthToken {
690    async fn to_profile(&self, provider: &OAuthProvider) -> Result<UserProfile> {
691        let extractor = ProfileExtractor::new();
692        extractor.extract_profile(self, provider).await
693    }
694
695    async fn to_profile_with_extractor(
696        &self,
697        provider: &OAuthProvider,
698        extractor: &ProfileExtractor,
699    ) -> Result<UserProfile> {
700        extractor.extract_profile(self, provider).await
701    }
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Lifetime used by the long-lived fixtures below.
    const ONE_HOUR: Duration = Duration::from_secs(3600);

    /// Build the password-authenticated fixture token for `user123`.
    fn password_token(lifetime: Duration) -> AuthToken {
        AuthToken::new("user123", "token123", lifetime, "password")
    }

    #[test]
    fn test_auth_token_creation() {
        let token = password_token(ONE_HOUR);

        assert_eq!(token.user_id(), "user123");
        assert_eq!(token.access_token(), "token123");
        assert!(!token.is_expired());
        assert!(!token.is_revoked());
        assert!(token.is_valid());
    }

    #[test]
    fn test_token_expiry() {
        let token = password_token(Duration::from_millis(1));

        // Sleep well past the 1 ms lifetime so the check is not timing-sensitive.
        std::thread::sleep(Duration::from_millis(10));

        assert!(token.is_expired());
        assert!(!token.is_valid());
    }

    #[test]
    fn test_token_revocation() {
        let mut token = password_token(ONE_HOUR);
        assert!(!token.is_revoked());

        let reason = "User logout".to_string();
        token.revoke(Some(reason.clone()));

        assert!(token.is_revoked());
        assert!(!token.is_valid());
        assert_eq!(token.metadata.revoked_reason, Some(reason));
    }

    #[tokio::test]
    async fn test_jwt_token_manager() {
        let manager = TokenManager::new_hmac(b"test-secret-key", "test-issuer", "test-audience");
        let scopes = vec!["read".to_string(), "write".to_string()];

        let jwt = manager
            .create_jwt_token("user123", scopes, Some(ONE_HOUR))
            .unwrap();
        let claims = manager.validate_jwt_token(&jwt).unwrap();

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.scope, "read write");
    }
}
774
775// #[cfg(test)]
776// pub mod token_edge_tests;
777
778