use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use crate::error::{SaTokenError, SaTokenResult};
/// Signing algorithms selectable for [`JwtManager`].
///
/// Each variant converts one-to-one into the identically named
/// `jsonwebtoken::Algorithm` variant via `From<JwtAlgorithm> for Algorithm`.
/// `HS256` is the default.
///
/// NOTE(review): `JwtManager` builds its keys with `from_secret` (HMAC), so the
/// `RS*`/`ES*` variants presumably cannot sign or verify through it — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (the default).
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
    /// ECDSA using P-256 and SHA-256.
    ES256,
    /// ECDSA using P-384 and SHA-384.
    ES384,
}
impl From<JwtAlgorithm> for Algorithm {
    /// Maps a crate-level algorithm onto the `jsonwebtoken` variant of the same name.
    fn from(alg: JwtAlgorithm) -> Self {
        match alg {
            JwtAlgorithm::HS256 => Self::HS256,
            JwtAlgorithm::HS384 => Self::HS384,
            JwtAlgorithm::HS512 => Self::HS512,
            JwtAlgorithm::RS256 => Self::RS256,
            JwtAlgorithm::RS384 => Self::RS384,
            JwtAlgorithm::RS512 => Self::RS512,
            JwtAlgorithm::ES256 => Self::ES256,
            JwtAlgorithm::ES384 => Self::ES384,
        }
    }
}
/// Payload carried inside a JWT.
///
/// `login_id` travels as the registered `sub` claim; the other registered claims
/// keep their RFC 7519 names. Unset optional claims, and an empty `extra` map, are
/// omitted when serializing. Timestamps (`exp`, `nbf`, `iat`) are Unix seconds.
///
/// Custom claims live under a nested `"extra"` object on the wire (the map is not
/// flattened into the top-level payload).
///
/// NOTE(review): `aud` is modelled as a single string; tokens whose `aud` is a JSON
/// array will presumably fail to deserialize — confirm if foreign tokens matter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject: the logged-in user's id (serialized as `sub`).
    #[serde(rename = "sub")]
    pub login_id: String,
    /// Issuer (`iss`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Audience (`aud`), single-valued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,
    /// Expiration time (`exp`), Unix seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<i64>,
    /// Not-before time (`nbf`), Unix seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<i64>,
    /// Issued-at time (`iat`), Unix seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<i64>,
    /// Token id (`jti`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
    /// Application-level login type label.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub login_type: Option<String>,
    /// Application-level device label.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub device: Option<String>,
    /// Custom claims; an absent `"extra"` key decodes to an empty map.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, Value>,
}
impl JwtClaims {
    /// Creates claims for `login_id` with `iat` set to the current Unix time (seconds).
    ///
    /// `login_type` starts as `"default"`. Every other optional claim starts unset,
    /// and there are no custom claims. No `exp` is set; call
    /// [`set_expiration`](Self::set_expiration) or
    /// [`set_expiration_at`](Self::set_expiration_at) for an expiring token.
    pub fn new(login_id: impl Into<String>) -> Self {
        let now = Utc::now().timestamp();
        Self {
            login_id: login_id.into(),
            iss: None,
            aud: None,
            exp: None,
            nbf: None,
            iat: Some(now),
            jti: None,
            login_type: Some("default".to_string()),
            device: None,
            extra: HashMap::new(),
        }
    }

    /// Sets `exp` to `seconds` from now; a negative value yields an already-expired token.
    ///
    /// # Panics
    ///
    /// NOTE(review): chrono's `Duration::seconds` and `DateTime + Duration` panic on
    /// out-of-range values. Extreme inputs (e.g. `i64::MAX` used as a "never expire"
    /// sentinel) will presumably panic here — confirm, and clamp if callers may do this.
    pub fn set_expiration(&mut self, seconds: i64) -> &mut Self {
        let exp_time = Utc::now() + Duration::seconds(seconds);
        self.exp = Some(exp_time.timestamp());
        self
    }

    /// Sets `exp` to `datetime`, dropping any sub-second precision.
    pub fn set_expiration_at(&mut self, datetime: DateTime<Utc>) -> &mut Self {
        self.exp = Some(datetime.timestamp());
        self
    }

    /// Sets the issuer (`iss`) claim.
    pub fn set_issuer(&mut self, issuer: impl Into<String>) -> &mut Self {
        self.iss = Some(issuer.into());
        self
    }

    /// Sets the audience (`aud`) claim.
    pub fn set_audience(&mut self, audience: impl Into<String>) -> &mut Self {
        self.aud = Some(audience.into());
        self
    }

    /// Sets the token id (`jti`) claim.
    pub fn set_jti(&mut self, jti: impl Into<String>) -> &mut Self {
        self.jti = Some(jti.into());
        self
    }

    /// Sets the application-level login type, replacing the `"default"` label.
    pub fn set_login_type(&mut self, login_type: impl Into<String>) -> &mut Self {
        self.login_type = Some(login_type.into());
        self
    }

    /// Sets the application-level device label.
    pub fn set_device(&mut self, device: impl Into<String>) -> &mut Self {
        self.device = Some(device.into());
        self
    }

    /// Inserts a custom claim into `extra`, overwriting any existing value for `key`.
    pub fn add_claim(&mut self, key: impl Into<String>, value: Value) -> &mut Self {
        self.extra.insert(key.into(), value);
        self
    }

    /// Looks up a custom claim in `extra`. Registered claims such as `sub` or `exp`
    /// are not searched.
    pub fn get_claim(&self, key: &str) -> Option<&Value> {
        self.extra.get(key)
    }

    /// Replaces the entire custom-claim map with `claims`.
    pub fn set_claims(&mut self, claims: HashMap<String, Value>) -> &mut Self {
        self.extra = claims;
        self
    }

    /// Returns all custom claims.
    pub fn get_claims(&self) -> &HashMap<String, Value> {
        &self.extra
    }

    /// Returns `true` once the current Unix time has reached `exp`.
    ///
    /// Claims without `exp` are never reported as expired.
    pub fn is_expired(&self) -> bool {
        matches!(self.exp, Some(exp) if Utc::now().timestamp() >= exp)
    }

    /// Returns the seconds left until `exp`, clamped at zero, or `None` when no
    /// `exp` is set.
    pub fn remaining_time(&self) -> Option<i64> {
        self.exp.map(|exp| {
            // `exp` can be any i64 (e.g. from an unverified token), so a plain
            // `exp - now` could overflow. That panics in debug builds; in release it
            // wraps to a huge positive value for a long-expired token. Saturate instead.
            exp.saturating_sub(Utc::now().timestamp()).max(0)
        })
    }
}
/// Issues and verifies JWTs signed with a shared secret.
///
/// The optional `issuer`/`audience` are written into generated tokens that lack
/// them, and are required to match during validation.
///
/// NOTE(review): there is no `Debug` impl. If one is added, make sure it redacts
/// `secret`.
#[derive(Clone)]
pub struct JwtManager {
    /// Shared secret used, as its raw UTF-8 bytes, for both signing and verification.
    secret: String,
    /// Algorithm written into the header when signing and expected when validating.
    algorithm: JwtAlgorithm,
    /// Default `iss` for generated tokens; when set, validation requires it.
    issuer: Option<String>,
    /// Default `aud` for generated tokens; when set, validation requires it.
    audience: Option<String>,
}
impl JwtManager {
    /// Creates a manager that signs with `secret` using the default HS256 algorithm.
    pub fn new(secret: impl Into<String>) -> Self {
        Self::with_algorithm(secret, JwtAlgorithm::HS256)
    }

    /// Creates a manager that signs with `secret` using `algorithm`.
    ///
    /// NOTE(review): keys are always built with `EncodingKey::from_secret` /
    /// `DecodingKey::from_secret` (HMAC), so the `RS*`/`ES*` algorithms presumably
    /// fail at `generate`/`validate` — confirm before relying on them.
    pub fn with_algorithm(secret: impl Into<String>, algorithm: JwtAlgorithm) -> Self {
        Self {
            secret: secret.into(),
            algorithm,
            issuer: None,
            audience: None,
        }
    }

    /// Returns this manager with a default issuer, which validation also requires.
    pub fn set_issuer(self, issuer: impl Into<String>) -> Self {
        Self {
            issuer: Some(issuer.into()),
            ..self
        }
    }

    /// Returns this manager with a default audience, which validation also requires.
    pub fn set_audience(self, audience: impl Into<String>) -> Self {
        Self {
            audience: Some(audience.into()),
            ..self
        }
    }

    /// Signs `claims` into a compact JWT string.
    ///
    /// The manager's issuer/audience fill in `iss`/`aud` only where the claims leave
    /// them unset. The caller's `claims` are not modified.
    ///
    /// # Errors
    ///
    /// Returns [`SaTokenError::InvalidToken`] if encoding fails.
    pub fn generate(&self, claims: &JwtClaims) -> SaTokenResult<String> {
        let mut payload = claims.clone();
        // Explicit claim values win; manager defaults only fill gaps.
        if payload.iss.is_none() {
            payload.iss.clone_from(&self.issuer);
        }
        if payload.aud.is_none() {
            payload.aud.clone_from(&self.audience);
        }
        let key = EncodingKey::from_secret(self.secret.as_bytes());
        encode(&Header::new(self.algorithm.into()), &payload, &key)
            .map_err(|e| SaTokenError::InvalidToken(format!("Failed to generate JWT: {}", e)))
    }

    /// Verifies `token`'s signature and claims and returns its payload.
    ///
    /// Expiry is checked with zero leeway. The configured issuer/audience, when
    /// present, must match.
    ///
    /// NOTE(review): `jsonwebtoken`'s `Validation::new` requires an `exp` claim by
    /// default, so tokens generated without an expiration are likely rejected here —
    /// confirm that is intended.
    ///
    /// # Errors
    ///
    /// Returns [`SaTokenError::TokenExpired`] for an expired token. Any other failure
    /// becomes [`SaTokenError::InvalidToken`].
    pub fn validate(&self, token: &str) -> SaTokenResult<JwtClaims> {
        use jsonwebtoken::errors::ErrorKind;

        let mut rules = Validation::new(self.algorithm.into());
        rules.validate_exp = true;
        rules.leeway = 0;
        if let Some(iss) = self.issuer.as_deref() {
            rules.set_issuer(&[iss]);
        }
        if let Some(aud) = self.audience.as_deref() {
            rules.set_audience(&[aud]);
        }

        let key = DecodingKey::from_secret(self.secret.as_bytes());
        match decode::<JwtClaims>(token, &key, &rules) {
            Ok(data) => Ok(data.claims),
            Err(e) => Err(match e.kind() {
                ErrorKind::ExpiredSignature => SaTokenError::TokenExpired,
                _ => SaTokenError::InvalidToken(format!("JWT validation failed: {}", e)),
            }),
        }
    }

    /// Decodes `token`'s payload **without** checking its signature or any claim.
    ///
    /// Anyone can forge the result, so never use it for authentication or
    /// authorization decisions. Use [`validate`](Self::validate) for that.
    ///
    /// # Errors
    ///
    /// Returns [`SaTokenError::InvalidToken`] if the token cannot be decoded.
    pub fn decode_without_validation(&self, token: &str) -> SaTokenResult<JwtClaims> {
        jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(token)
            .map(|data| data.claims)
            .map_err(|e| SaTokenError::InvalidToken(format!("Failed to decode JWT: {}", e)))
    }

    /// Re-issues a still-valid `token` with `exp` set to `extend_seconds` from now.
    ///
    /// `iat` is reset to now, and all other claims are carried over.
    ///
    /// # Errors
    ///
    /// Propagates any [`validate`](Self::validate) or [`generate`](Self::generate)
    /// error, so an already-expired token cannot be refreshed.
    pub fn refresh(&self, token: &str, extend_seconds: i64) -> SaTokenResult<String> {
        let mut renewed = self.validate(token)?;
        renewed.set_expiration(extend_seconds);
        renewed.iat = Some(Utc::now().timestamp());
        self.generate(&renewed)
    }

    /// Returns the `sub` (login id) of `token` **without** verifying it.
    ///
    /// Neither the signature nor the expiry is checked, so the id may be forged or
    /// stale. Only use it where that is acceptable (e.g. logging).
    pub fn extract_login_id(&self, token: &str) -> SaTokenResult<String> {
        self.decode_without_validation(token)
            .map(|claims| claims.login_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const SECRET: &str = "test-secret-key";

    /// Claims for `user_123` that expire one hour from now.
    fn hour_claims() -> JwtClaims {
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration(3600);
        claims
    }

    #[test]
    fn test_jwt_claims_creation() {
        let mut claims = JwtClaims::new("user_123");
        claims
            .set_expiration(3600)
            .set_issuer("sa-token")
            .add_claim("role", serde_json::json!("admin"));

        assert_eq!(claims.login_id, "user_123");
        assert!(claims.exp.is_some());
        assert_eq!(claims.iss.as_deref(), Some("sa-token"));
        assert_eq!(claims.get_claim("role"), Some(&serde_json::json!("admin")));
    }

    #[test]
    fn test_jwt_generate_and_validate() {
        let manager = JwtManager::new(SECRET);
        let token = manager.generate(&hour_claims()).unwrap();
        assert!(!token.is_empty());

        let decoded = manager.validate(&token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_jwt_expired() {
        let manager = JwtManager::new(SECRET);
        let mut claims = JwtClaims::new("user_123");
        claims.set_expiration_at(Utc::now() - Duration::seconds(10));
        let token = manager.generate(&claims).unwrap();

        let result = manager.validate(&token);
        assert!(
            matches!(result, Err(SaTokenError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[test]
    fn test_jwt_refresh() {
        let manager = JwtManager::new(SECRET);
        let original_token = manager.generate(&hour_claims()).unwrap();
        let new_token = manager.refresh(&original_token, 7200).unwrap();
        assert_ne!(original_token, new_token);

        let decoded = manager.validate(&new_token).unwrap();
        assert_eq!(decoded.login_id, "user_123");
    }

    #[test]
    fn test_jwt_custom_claims() {
        let manager = JwtManager::new(SECRET);
        let mut claims = hour_claims();
        claims
            .add_claim("role", serde_json::json!("admin"))
            .add_claim("permissions", serde_json::json!(["read", "write"]));

        let token = manager.generate(&claims).unwrap();
        let decoded = manager.validate(&token).unwrap();
        assert_eq!(decoded.get_claim("role"), Some(&serde_json::json!("admin")));
        assert_eq!(
            decoded.get_claim("permissions"),
            Some(&serde_json::json!(["read", "write"]))
        );
    }

    #[test]
    fn test_extract_login_id() {
        let manager = JwtManager::new(SECRET);
        let token = manager.generate(&hour_claims()).unwrap();
        assert_eq!(manager.extract_login_id(&token).unwrap(), "user_123");
    }
}