use crate::errors::{AuthError, Result, TokenError};
use crate::providers::{OAuthProvider, ProfileExtractor, UserProfile};
use chrono::{DateTime, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
#[cfg(feature = "postgres-storage")]
use sqlx::FromRow;
use std::collections::HashMap;
use std::time::Duration;
use uuid::Uuid;
/// An issued authentication token together with its ownership, lifetime,
/// authorization data and bookkeeping metadata.
///
/// With the `postgres-storage` feature enabled the struct also derives
/// `sqlx::FromRow`.
#[cfg_attr(feature = "postgres-storage", derive(FromRow))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthToken {
/// Identifier of this token record; `AuthToken::new` generates a random UUID v4.
pub token_id: String,
/// Identifier of the user the token was issued to.
pub user_id: String,
/// The token string itself (a JWT when built by `TokenManager::create_auth_token`).
pub access_token: String,
/// Token type; `AuthToken::new` sets this to `"Bearer"`.
pub token_type: Option<String>,
/// Optional subject; left unset by `AuthToken::new`.
pub subject: Option<String>,
/// Optional issuer; left unset by `AuthToken::new`.
pub issuer: Option<String>,
/// Optional refresh token, attached via `with_refresh_token`.
pub refresh_token: Option<String>,
/// Time at which the token was created.
pub issued_at: DateTime<Utc>,
/// Time after which `is_expired` reports the token as expired.
pub expires_at: DateTime<Utc>,
/// Scopes granted to the token.
pub scopes: Vec<String>,
/// Name of the authentication method that produced the token
/// (`TokenManager::validate_auth_token` treats `"jwt"` specially).
pub auth_method: String,
/// Optional client identifier, attached via `with_client_id`.
pub client_id: Option<String>,
/// Optional user profile; left unset by `AuthToken::new`.
pub user_profile: Option<UserProfile>,
/// Permissions attached to the token.
pub permissions: Vec<String>,
/// Roles attached to the token.
pub roles: Vec<String>,
/// Revocation state, usage statistics and custom claims.
pub metadata: TokenMetadata,
}
/// Bookkeeping data carried alongside an `AuthToken`.
///
/// `Default` yields a non-revoked record with a zero use count, no optional
/// values and an empty custom map.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TokenMetadata {
/// IP address recorded at issuance, if any.
pub issued_ip: Option<String>,
/// User agent string, if recorded.
pub user_agent: Option<String>,
/// Device identifier, if recorded.
pub device_id: Option<String>,
/// Session identifier, if recorded.
pub session_id: Option<String>,
/// Whether the token has been revoked; set by `AuthToken::revoke`.
pub revoked: bool,
/// Time of revocation; set by `AuthToken::revoke`.
pub revoked_at: Option<DateTime<Utc>>,
/// Optional reason passed to `AuthToken::revoke`.
pub revoked_reason: Option<String>,
/// Time of the most recent `AuthToken::mark_used` call.
pub last_used: Option<DateTime<Utc>>,
/// Number of `AuthToken::mark_used` calls.
pub use_count: u64,
/// Arbitrary key/value data; written by `AuthToken::add_custom_claim`.
pub custom: HashMap<String, serde_json::Value>,
}
#[cfg(feature = "postgres-storage")]
use sqlx::{Decode, Postgres, Type, postgres::PgValueRef};
#[cfg(feature = "postgres-storage")]
impl<'r> Decode<'r, Postgres> for TokenMetadata {
    /// Decodes the column as a JSON value and then deserializes that value
    /// into `TokenMetadata`; either failure is returned as a boxed error.
    fn decode(value: PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let raw = <serde_json::Value as Decode<Postgres>>::decode(value)?;
        // `?` boxes the serde error through the standard `From` conversion.
        let metadata = serde_json::from_value::<Self>(raw)?;
        Ok(metadata)
    }
}
#[cfg(feature = "postgres-storage")]
impl Type<Postgres> for TokenMetadata {
    /// Reports the same Postgres type information as `serde_json::Value`.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        type Json = serde_json::Value;
        <Json as Type<Postgres>>::type_info()
    }

    /// Accepts exactly the column types that `serde_json::Value` accepts.
    fn compatible(candidate: &sqlx::postgres::PgTypeInfo) -> bool {
        type Json = serde_json::Value;
        <Json as Type<Postgres>>::compatible(candidate)
    }
}
/// User-facing summary of a validated token, as produced by
/// `TokenManager::extract_token_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenInfo {
/// Identifier of the user (taken from the JWT `sub` claim).
pub user_id: String,
/// Value of the `username` claim when it is present and a string.
pub username: Option<String>,
/// Value of the `email` claim when it is present and a string.
pub email: Option<String>,
/// Value of the `name` claim when it is present and a string.
pub name: Option<String>,
/// Role names associated with the user.
pub roles: Vec<String>,
/// Permissions; `extract_token_info` fills this from the whitespace-separated `scope` claim.
pub permissions: Vec<String>,
/// Additional claims that are not covered by a dedicated field.
pub attributes: HashMap<String, serde_json::Value>,
}
/// Claim set serialized into and deserialized from the JWTs handled by
/// `TokenManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
/// Subject; `create_jwt_token` stores the user id here.
pub sub: String,
/// Issuer; `create_jwt_token` copies the manager's issuer.
pub iss: String,
/// Audience; `create_jwt_token` copies the manager's audience.
pub aud: String,
/// Expiry time as a Unix timestamp in seconds.
pub exp: i64,
/// Issue time as a Unix timestamp in seconds.
pub iat: i64,
/// Not-before time as a Unix timestamp in seconds.
pub nbf: i64,
/// Token identifier; `create_jwt_token` generates a random UUID v4.
pub jti: String,
/// Scopes joined with single spaces.
pub scope: String,
/// Optional permission list; `create_jwt_token` leaves it unset.
pub permissions: Option<Vec<String>>,
/// Optional role list; `create_jwt_token` leaves it unset.
pub roles: Option<Vec<String>>,
/// Optional client identifier; `create_jwt_token` leaves it unset.
pub client_id: Option<String>,
/// Any further claims. Because of `#[serde(flatten)]` these live at the top
/// level of the JSON object, and only keys not matched by a named field
/// above end up in this map on deserialization.
#[serde(flatten)]
pub custom: HashMap<String, serde_json::Value>,
}
/// Creates and validates JWTs and `AuthToken`s for a single issuer/audience pair.
pub struct TokenManager {
// Key used to sign new JWTs.
encoding_key: EncodingKey,
// Key used to verify JWT signatures.
decoding_key: DecodingKey,
// Raw key bytes, retained so the `Clone` impl can rebuild both keys.
key_material: KeyMaterial,
// Signing algorithm (`HS256` from `new_hmac`, `RS256` from `new_rsa`).
algorithm: Algorithm,
// Written to the `iss` claim and required during validation.
issuer: String,
// Written to the `aud` claim and required during validation.
audience: String,
// Lifetime used when a caller does not supply one.
default_lifetime: Duration,
}
/// Raw key bytes a `TokenManager` was constructed from.
#[derive(Clone)]
enum KeyMaterial {
/// Shared HMAC secret.
Hmac(Vec<u8>),
/// RSA key pair; both parts are PEM-encoded bytes (parsed with `from_rsa_pem`).
Rsa { private: Vec<u8>, public: Vec<u8> },
}
impl AuthToken {
    /// Creates a token for `user_id` that wraps `access_token` and expires
    /// `expires_in` from now.
    ///
    /// The token gets a fresh UUID v4 id, the type `"Bearer"`, no scopes,
    /// permissions or roles, and default metadata.
    ///
    /// If `expires_in` cannot be represented (it does not fit a chrono
    /// duration, or adding it to the current time overflows), the lifetime
    /// falls back to one hour instead of panicking.
    pub fn new(
        user_id: impl Into<String>,
        access_token: impl Into<String>,
        expires_in: std::time::Duration,
        auth_method: impl Into<String>,
    ) -> Self {
        let now = Utc::now();
        // Checked arithmetic: a plain `now + duration` panics on overflow.
        let expires_at = chrono::Duration::from_std(expires_in)
            .ok()
            .and_then(|lifetime| now.checked_add_signed(lifetime))
            .unwrap_or_else(|| now + chrono::Duration::hours(1));
        Self {
            token_id: Uuid::new_v4().to_string(),
            user_id: user_id.into(),
            access_token: access_token.into(),
            refresh_token: None,
            token_type: Some("Bearer".to_string()),
            subject: None,
            issuer: None,
            issued_at: now,
            expires_at,
            scopes: Vec::new(),
            auth_method: auth_method.into(),
            client_id: None,
            user_profile: None,
            permissions: Vec::new(),
            roles: Vec::new(),
            metadata: TokenMetadata::default(),
        }
    }

    /// Returns the access token string.
    pub fn access_token(&self) -> &str {
        &self.access_token
    }

    /// Returns the id of the user the token belongs to.
    pub fn user_id(&self) -> &str {
        &self.user_id
    }

    /// Returns the expiry time.
    pub fn expires_at(&self) -> DateTime<Utc> {
        self.expires_at
    }

    /// Returns the access token string (same value as [`Self::access_token`]).
    pub fn token_value(&self) -> &str {
        &self.access_token
    }

    /// Returns the token type, if set.
    pub fn token_type(&self) -> Option<&str> {
        self.token_type.as_deref()
    }

    /// Returns the subject, if set.
    pub fn subject(&self) -> Option<&str> {
        self.subject.as_deref()
    }

    /// Returns the issuer, if set.
    pub fn issuer(&self) -> Option<&str> {
        self.issuer.as_deref()
    }

    /// Returns `true` once the current time is strictly later than the expiry time.
    pub fn is_expired(&self) -> bool {
        Utc::now() > self.expires_at
    }

    /// Returns `true` if the token expires before `within` has elapsed from now.
    ///
    /// A window too large to be added to the current time counts as
    /// "expiring" rather than causing a panic.
    pub fn is_expiring(&self, within: Duration) -> bool {
        chrono::Duration::from_std(within)
            .ok()
            .and_then(|window| Utc::now().checked_add_signed(window))
            .map_or(true, |deadline| deadline > self.expires_at)
    }

    /// Returns `true` if the token has been revoked.
    pub fn is_revoked(&self) -> bool {
        self.metadata.revoked
    }

    /// Returns `true` if the token is neither expired nor revoked.
    pub fn is_valid(&self) -> bool {
        !self.is_expired() && !self.is_revoked()
    }

    /// Marks the token as revoked, recording the current time and the optional `reason`.
    pub fn revoke(&mut self, reason: Option<String>) {
        self.metadata.revoked = true;
        self.metadata.revoked_at = Some(Utc::now());
        self.metadata.revoked_reason = reason;
    }

    /// Records a use of the token: updates `last_used` and increments `use_count`.
    pub fn mark_used(&mut self) {
        self.metadata.last_used = Some(Utc::now());
        self.metadata.use_count += 1;
    }

    /// Adds `scope` unless it is already present.
    pub fn add_scope(&mut self, scope: impl Into<String>) {
        let scope = scope.into();
        if !self.scopes.contains(&scope) {
            self.scopes.push(scope);
        }
    }

    /// Returns `true` if the token carries exactly `scope`.
    pub fn has_scope(&self, scope: &str) -> bool {
        // Compare by reference; no temporary `String` is needed.
        self.scopes.iter().any(|held| held == scope)
    }

    /// Builder: sets the refresh token.
    pub fn with_refresh_token(mut self, refresh_token: impl Into<String>) -> Self {
        self.refresh_token = Some(refresh_token.into());
        self
    }

    /// Builder: sets the client id.
    pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
        self.client_id = Some(client_id.into());
        self
    }

    /// Builder: replaces the scope list.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }

    /// Builder: replaces the metadata.
    pub fn with_metadata(mut self, metadata: TokenMetadata) -> Self {
        self.metadata = metadata;
        self
    }

    /// Returns the time remaining until expiry, or zero if already expired.
    pub fn time_until_expiry(&self) -> Duration {
        // `to_std` rejects negative durations, which covers the expired case.
        (self.expires_at - Utc::now())
            .to_std()
            .unwrap_or(Duration::ZERO)
    }

    /// Stores `value` under `key` in the metadata's custom map, replacing any previous value.
    pub fn add_custom_claim(&mut self, key: impl Into<String>, value: serde_json::Value) {
        self.metadata.custom.insert(key.into(), value);
    }

    /// Looks up `key` in the metadata's custom map.
    pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.custom.get(key)
    }

    /// Returns `true` if the token carries exactly `permission`.
    pub fn has_permission(&self, permission: &str) -> bool {
        self.permissions.iter().any(|held| held == permission)
    }

    /// Adds `permission` unless it is already present.
    pub fn add_permission(&mut self, permission: impl Into<String>) {
        let permission = permission.into();
        if !self.permissions.contains(&permission) {
            self.permissions.push(permission);
        }
    }

    /// Adds `role` unless it is already present.
    pub fn add_role(&mut self, role: impl Into<String>) {
        let role = role.into();
        if !self.roles.contains(&role) {
            self.roles.push(role);
        }
    }

    /// Returns `true` if the token carries exactly `role`.
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|held| held == role)
    }

    /// Builder: replaces the permission list.
    pub fn with_permissions(mut self, permissions: Vec<String>) -> Self {
        self.permissions = permissions;
        self
    }

    /// Builder: replaces the role list.
    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
        self.roles = roles;
        self
    }
}
impl Clone for TokenManager {
    /// Clones the manager by re-deriving both keys from the stored key material.
    ///
    /// # Panics
    ///
    /// Panics if stored RSA PEM data no longer parses.
    fn clone(&self) -> Self {
        // Rebuild the key pair first, then assemble the struct once.
        let (encoding_key, decoding_key) = match &self.key_material {
            KeyMaterial::Hmac(secret) => (
                EncodingKey::from_secret(secret),
                DecodingKey::from_secret(secret),
            ),
            KeyMaterial::Rsa { private, public } => (
                EncodingKey::from_rsa_pem(private).expect("Valid RSA private key"),
                DecodingKey::from_rsa_pem(public).expect("Valid RSA public key"),
            ),
        };
        Self {
            encoding_key,
            decoding_key,
            key_material: self.key_material.clone(),
            algorithm: self.algorithm,
            issuer: self.issuer.clone(),
            audience: self.audience.clone(),
            default_lifetime: self.default_lifetime,
        }
    }
}
impl TokenManager {
    /// Creates a manager that signs and verifies with `HS256` using `secret`.
    ///
    /// The default token lifetime is one hour.
    pub fn new_hmac(secret: &[u8], issuer: impl Into<String>, audience: impl Into<String>) -> Self {
        Self {
            encoding_key: EncodingKey::from_secret(secret),
            decoding_key: DecodingKey::from_secret(secret),
            key_material: KeyMaterial::Hmac(secret.to_vec()),
            algorithm: Algorithm::HS256,
            issuer: issuer.into(),
            audience: audience.into(),
            default_lifetime: Duration::from_secs(3600),
        }
    }

    /// Creates a manager that signs with `RS256` using a PEM-encoded RSA key pair.
    ///
    /// The default token lifetime is one hour.
    ///
    /// # Errors
    ///
    /// Returns a crypto error if either PEM blob cannot be parsed.
    pub fn new_rsa(
        private_key: &[u8],
        public_key: &[u8],
        issuer: impl Into<String>,
        audience: impl Into<String>,
    ) -> Result<Self> {
        let encoding_key = EncodingKey::from_rsa_pem(private_key)
            .map_err(|e| AuthError::crypto(format!("Invalid RSA private key: {e}")))?;
        let decoding_key = DecodingKey::from_rsa_pem(public_key)
            .map_err(|e| AuthError::crypto(format!("Invalid RSA public key: {e}")))?;
        Ok(Self {
            encoding_key,
            decoding_key,
            key_material: KeyMaterial::Rsa {
                private: private_key.to_vec(),
                public: public_key.to_vec(),
            },
            algorithm: Algorithm::RS256,
            issuer: issuer.into(),
            audience: audience.into(),
            default_lifetime: Duration::from_secs(3600),
        })
    }

    /// Builder: overrides the lifetime used when callers do not pass one.
    pub fn with_default_lifetime(mut self, lifetime: Duration) -> Self {
        self.default_lifetime = lifetime;
        self
    }

    /// Creates a signed JWT for `user_id` carrying `scopes` as a
    /// space-separated `scope` claim.
    ///
    /// `lifetime` defaults to the manager's default lifetime. A lifetime that
    /// cannot be represented (too large for a chrono duration, or overflowing
    /// when added to the current time) falls back to one hour instead of
    /// panicking.
    ///
    /// # Errors
    ///
    /// Returns a token creation error if JWT encoding fails.
    pub fn create_jwt_token(
        &self,
        user_id: impl Into<String>,
        scopes: Vec<String>,
        lifetime: Option<Duration>,
    ) -> Result<String> {
        let user_id = user_id.into();
        let lifetime = lifetime.unwrap_or(self.default_lifetime);
        let now = Utc::now();
        // Checked arithmetic: a plain `now + duration` panics on overflow.
        let exp = chrono::Duration::from_std(lifetime)
            .ok()
            .and_then(|span| now.checked_add_signed(span))
            .unwrap_or_else(|| now + chrono::Duration::hours(1));
        let claims = JwtClaims {
            sub: user_id,
            iss: self.issuer.clone(),
            aud: self.audience.clone(),
            exp: exp.timestamp(),
            iat: now.timestamp(),
            nbf: now.timestamp(),
            jti: Uuid::new_v4().to_string(),
            scope: scopes.join(" "),
            permissions: None,
            roles: None,
            client_id: None,
            custom: HashMap::new(),
        };
        let header = Header::new(self.algorithm);
        encode(&header, &claims, &self.encoding_key)
            .map_err(|e| TokenError::creation_failed(format!("JWT encoding failed: {e}")).into())
    }

    /// Decodes `token`, checking it against the manager's algorithm, issuer
    /// and audience, and returns its claims.
    ///
    /// # Errors
    ///
    /// Returns `TokenError::Expired` for an expired signature. Every other
    /// decoding failure is reported as `TokenError::Invalid` with the message
    /// `"Invalid token format"`.
    pub fn validate_jwt_token(&self, token: &str) -> Result<JwtClaims> {
        let mut validation = Validation::new(self.algorithm);
        validation.set_issuer(&[&self.issuer]);
        validation.set_audience(&[&self.audience]);
        let token_data =
            decode::<JwtClaims>(token, &self.decoding_key, &validation).map_err(|e| {
                match e.kind() {
                    jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
                        AuthError::Token(TokenError::Expired)
                    }
                    _ => AuthError::Token(TokenError::Invalid {
                        message: "Invalid token format".to_string(),
                    }),
                }
            })?;
        Ok(token_data.claims)
    }

    /// Creates an `AuthToken` whose access token is a freshly signed JWT.
    ///
    /// `lifetime` defaults to the manager's default lifetime and is applied
    /// to both the JWT and the `AuthToken`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::create_jwt_token`].
    pub fn create_auth_token(
        &self,
        user_id: impl Into<String>,
        scopes: Vec<String>,
        auth_method: impl Into<String>,
        lifetime: Option<std::time::Duration>,
    ) -> Result<AuthToken> {
        let user_id_str = user_id.into();
        let lifetime = lifetime.unwrap_or(self.default_lifetime);
        let jwt_token = self.create_jwt_token(&user_id_str, scopes.clone(), Some(lifetime))?;
        let token =
            AuthToken::new(user_id_str, jwt_token, lifetime, auth_method).with_scopes(scopes);
        Ok(token)
    }

    /// Checks that `token` is not expired and not revoked, and validates its
    /// access token as a JWT when it looks like one.
    ///
    /// JWT validation runs when `auth_method` is `"jwt"` or the access token
    /// contains a `.`.
    ///
    /// # Errors
    ///
    /// Returns `TokenError::Expired`, `TokenError::Invalid` for a revoked
    /// token, or any error from [`Self::validate_jwt_token`].
    pub fn validate_auth_token(&self, token: &AuthToken) -> Result<()> {
        if token.is_expired() {
            return Err(TokenError::Expired.into());
        }
        if token.is_revoked() {
            return Err(TokenError::Invalid {
                message: "Token has been revoked".to_string(),
            }
            .into());
        }
        if token.auth_method == "jwt" || token.access_token.contains('.') {
            self.validate_jwt_token(&token.access_token)?;
        }
        Ok(())
    }

    /// Issues a new token for the same user, scopes and auth method as
    /// `token`, using the manager's default lifetime.
    ///
    /// Only those three values are carried over. Permissions, roles, client
    /// id and metadata are not copied to the new token.
    ///
    /// # Errors
    ///
    /// Returns `TokenError::Expired` or `TokenError::Invalid` if `token` is
    /// expired or revoked, or any error from [`Self::create_auth_token`].
    pub fn refresh_token(&self, token: &AuthToken) -> Result<AuthToken> {
        if token.is_expired() {
            return Err(TokenError::Expired.into());
        }
        if token.is_revoked() {
            return Err(TokenError::Invalid {
                message: "Cannot refresh revoked token".to_string(),
            }
            .into());
        }
        self.create_auth_token(
            &token.user_id,
            token.scopes.clone(),
            &token.auth_method,
            Some(self.default_lifetime),
        )
    }

    /// Validates `token` and summarizes its claims as a `TokenInfo`.
    ///
    /// `username`, `email` and `name` come from string-valued custom claims of
    /// the same name. `roles` comes from the `roles` claim, `permissions` from
    /// the whitespace-separated `scope` claim, and the remaining custom claims
    /// become `attributes`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::validate_jwt_token`].
    pub fn extract_token_info(&self, token: &str) -> Result<TokenInfo> {
        let JwtClaims {
            sub,
            scope,
            roles,
            custom,
            ..
        } = self.validate_jwt_token(token)?;

        // String-valued custom claim lookup shared by the optional fields.
        let string_claim = |key: &str| {
            custom
                .get(key)
                .and_then(|value| value.as_str())
                .map(str::to_owned)
        };
        let username = string_claim("username");
        let email = string_claim("email");
        let name = string_claim("name");

        Ok(TokenInfo {
            user_id: sub,
            username,
            email,
            name,
            // `roles` is a named field of `JwtClaims`, so serde never places a
            // "roles" key into the flattened `custom` map; read the typed field.
            roles: roles.unwrap_or_default(),
            permissions: scope.split_whitespace().map(str::to_owned).collect(),
            attributes: custom,
        })
    }
}
/// Resolves a [`UserProfile`] for a token via an OAuth provider.
#[async_trait::async_trait]
pub trait TokenToProfile {
/// Resolves the profile for this token from `provider`.
async fn to_profile(&self, provider: &OAuthProvider) -> Result<UserProfile>;
/// Resolves the profile for this token from `provider` using the supplied `extractor`.
async fn to_profile_with_extractor(
&self,
provider: &OAuthProvider,
extractor: &ProfileExtractor,
) -> Result<UserProfile>;
}
#[async_trait::async_trait]
impl TokenToProfile for AuthToken {
    /// Resolves the profile using a freshly constructed `ProfileExtractor`.
    async fn to_profile(&self, provider: &OAuthProvider) -> Result<UserProfile> {
        let default_extractor = ProfileExtractor::new();
        self.to_profile_with_extractor(provider, &default_extractor)
            .await
    }

    /// Resolves the profile by handing this token and `provider` to `extractor`.
    async fn to_profile_with_extractor(
        &self,
        provider: &OAuthProvider,
        extractor: &ProfileExtractor,
    ) -> Result<UserProfile> {
        let profile = extractor.extract_profile(self, provider).await;
        profile
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a password-authenticated token for the fixed test user.
    fn sample_token(lifetime: Duration) -> AuthToken {
        AuthToken::new("user123", "token123", lifetime, "password")
    }

    /// A freshly created long-lived token exposes its inputs and is valid.
    #[test]
    fn test_auth_token_creation() {
        let token = sample_token(Duration::from_secs(3600));

        assert_eq!(token.user_id(), "user123");
        assert_eq!(token.access_token(), "token123");
        assert!(!token.is_expired());
        assert!(!token.is_revoked());
        assert!(token.is_valid());
    }

    /// A token with a 1 ms lifetime is expired after a 10 ms wait.
    #[test]
    fn test_token_expiry() {
        let token = sample_token(Duration::from_millis(1));
        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(token.is_expired());
        assert!(!token.is_valid());
    }

    /// Revoking flips the revoked flag and invalidates the token.
    #[test]
    fn test_token_revocation() {
        let mut token = sample_token(Duration::from_secs(3600));
        assert!(!token.is_revoked());

        token.revoke(Some("User logout".to_string()));

        assert!(token.is_revoked());
        assert!(!token.is_valid());
        assert!(token.metadata.revoked);
    }
}