use std::sync::Arc;
use alun_config::ConfigManager;
use crate::middleware::{TokenClaims, TokenType};
/// JWT helper that signs and validates tokens using the settings found under
/// `middleware.auth` of the shared configuration.
///
/// When the `cache` feature is enabled the helper can also hold a shared cache,
/// which the blacklist-related methods use to record revoked token ids.
#[derive(Clone)]
pub struct JWT {
/// Shared configuration manager; the secret and token lifetimes are read from it on demand.
config: Arc<ConfigManager>,
/// Optional cache backing the token blacklist; `None` turns blacklist writes into no-ops.
#[cfg(feature = "cache")]
cache: Option<alun_cache::SharedCache>,
}
impl JWT {
pub fn from_config() -> Self {
let config = crate::resources::config().clone();
#[cfg(feature = "cache")]
let cache = crate::resources::try_cache().cloned();
JWT {
config,
#[cfg(feature = "cache")]
cache,
}
}
pub fn with_config(config: Arc<ConfigManager>) -> Self {
JWT {
config,
#[cfg(feature = "cache")]
cache: None,
}
}
/// Builds a [`JWT`] from an explicit configuration handle and an explicit cache.
///
/// Only available with the `cache` feature; the cache is stored as `Some(cache)`.
#[cfg(feature = "cache")]
pub fn with_config_and_cache(config: Arc<ConfigManager>, cache: alun_cache::SharedCache) -> Self {
    let cache = Some(cache);
    Self { config, cache }
}
/// Returns the secret stored at `middleware.auth.jwt_secret` in the configuration.
///
/// `create_token` and `validate` use its bytes as the signing/verification key.
pub fn jwt_secret(&self) -> &str {
    let auth = &self.config.get().middleware.auth;
    &auth.jwt_secret
}
/// Returns `middleware.auth.access_token_expire_secs` from the configuration.
///
/// `create_access_token` adds this value, in seconds, to the issue time.
pub fn access_token_expire_secs(&self) -> u64 {
    let auth = &self.config.get().middleware.auth;
    auth.access_token_expire_secs
}
/// Returns `middleware.auth.refresh_token_expire_secs` from the configuration.
///
/// `create_refresh_token` adds this value, in seconds, to the issue time.
pub fn refresh_token_expire_secs(&self) -> u64 {
    let auth = &self.config.get().middleware.auth;
    auth.refresh_token_expire_secs
}
/// Issues a signed access token for `user_id`.
///
/// The token carries the given `username`, `roles` and `permissions`, is tagged
/// `TokenType::Access`, and expires `access_token_expire_secs()` seconds after
/// it is issued.
///
/// # Errors
/// Propagates the `String` error produced by `create_token`.
pub fn create_access_token(
    &self,
    user_id: &str,
    username: Option<&str>,
    roles: &[String],
    permissions: &[String],
) -> Result<String, String> {
    let lifetime_secs = self.access_token_expire_secs();
    self.create_token(
        user_id,
        username,
        roles,
        permissions,
        TokenType::Access,
        lifetime_secs,
    )
}
/// Issues a signed refresh token for `user_id`.
///
/// The token is tagged `TokenType::Refresh`, has no username and empty role and
/// permission lists, and expires `refresh_token_expire_secs()` seconds after it
/// is issued.
///
/// # Errors
/// Propagates the `String` error produced by `create_token`.
pub fn create_refresh_token(&self, user_id: &str) -> Result<String, String> {
    let lifetime_secs = self.refresh_token_expire_secs();
    let empty: &[String] = &[];
    self.create_token(
        user_id,
        None,
        empty,
        empty,
        TokenType::Refresh,
        lifetime_secs,
    )
}
/// Decodes `token` and returns the claims it carries.
///
/// The token is verified against the configured secret using
/// `Validation::default()`. The token type stored in the claims is not
/// inspected here; callers that need a specific type must check it themselves.
///
/// # Errors
/// Returns the decoder's error, formatted into a `String` behind a fixed prefix.
pub fn validate(&self, token: &str) -> Result<TokenClaims, String> {
    use jsonwebtoken::{decode, DecodingKey, Validation};
    let key = DecodingKey::from_secret(self.jwt_secret().as_bytes());
    let rules = Validation::default();
    decode::<TokenClaims>(token, &key, &rules)
        .map(|data| data.claims)
        .map_err(|e| format!("Token 验证失败: {}", e))
}
/// Records the token identified by `claims.jti` as revoked.
///
/// Does nothing when no cache is attached or the claims have no `jti`.
/// The entry is written under `token:blacklist:<jti>` with a time-to-live equal
/// to the seconds remaining until `claims.exp`, or 60 seconds when the token is
/// already at or past its expiry. If the system clock cannot be read relative to
/// the Unix epoch, the current time is treated as 0.
///
/// The result of the cache write is ignored, so a failed write is not reported.
#[cfg(feature = "cache")]
pub async fn blacklist(&self, claims: &TokenClaims) {
    if let (Some(cache), Some(jti)) = (self.cache.as_ref(), claims.jti.as_ref()) {
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs() as usize);
        // Keep the entry only as long as the token itself could still be accepted.
        let ttl = match claims.exp.checked_sub(now_secs) {
            Some(remaining) if remaining > 0 => remaining,
            _ => 60,
        };
        let key = format!("token:blacklist:{}", jti);
        // Best effort: the outcome of the write is intentionally discarded.
        let _ = alun_cache::Cache::set_ex(cache, &key, &serde_json::json!(true), ttl as u64).await;
    }
}
/// Reports whether the token identified by `claims.jti` has a blacklist entry.
///
/// Returns `false` when no cache is attached, when the claims have no `jti`,
/// when the key `token:blacklist:<jti>` is absent, and also when the cache
/// lookup itself fails — a cache error is therefore treated as "not revoked".
#[cfg(feature = "cache")]
pub async fn is_blacklisted(&self, claims: &TokenClaims) -> bool {
    match (self.cache.as_ref(), claims.jti.as_ref()) {
        (Some(cache), Some(jti)) => {
            let key = format!("token:blacklist:{}", jti);
            let lookup = alun_cache::Cache::get::<serde_json::Value>(cache, &key).await;
            matches!(lookup, Ok(Some(_)))
        }
        _ => false,
    }
}
/// Logs out the holder of `claims` by blacklisting that token.
///
/// This is a thin alias for `blacklist` and shares its no-op conditions.
#[cfg(feature = "cache")]
pub async fn logout(&self, claims: &TokenClaims) {
    Self::blacklist(self, claims).await
}
/// Exchanges a refresh token for a new `(access_token, refresh_token)` pair.
///
/// Steps: validate the token, require its type to be `TokenType::Refresh`,
/// reject it if it is blacklisted, blacklist it, then issue a new access token
/// (copying `sub`, `username`, `roles` and `permissions` from the refresh
/// token's claims) and a new refresh token for the same `sub`.
///
/// NOTE(review): tokens produced by `create_refresh_token` carry no username
/// and empty role/permission lists, so an access token obtained through this
/// method inherits those empty values — confirm this is intended.
///
/// NOTE(review): the blacklist check and the blacklist write are two separate
/// cache calls, and both are no-ops without a cache; concurrent calls with the
/// same token are not serialized here.
///
/// # Errors
/// Returns a `String` error when validation fails, the token is not a refresh
/// token, the token is blacklisted, or issuing either new token fails.
#[cfg(feature = "cache")]
pub async fn refresh(
    &self,
    refresh_token_str: &str,
) -> Result<(String, String), String> {
    let claims = self.validate(refresh_token_str)?;
    if !matches!(claims.token_type, Some(TokenType::Refresh)) {
        return Err("Token 类型不正确,需要 Refresh Token".into());
    }
    if self.is_blacklisted(&claims).await {
        return Err("Refresh Token 已被撤销".into());
    }
    // Revoke the presented token before issuing its replacements.
    self.blacklist(&claims).await;
    let subject = claims.sub.as_str();
    let access_token = self.create_access_token(
        subject,
        claims.username.as_deref(),
        &claims.roles,
        &claims.permissions,
    )?;
    let new_refresh_token = self.create_refresh_token(subject)?;
    Ok((access_token, new_refresh_token))
}
/// Builds the claims for a new token and signs them with the configured secret.
///
/// Every token gets a fresh UUID v4 as `jti`, `iat` set to the current Unix
/// time in seconds, and `exp` set to `iat + expire_secs`. The expiry is computed
/// with saturating arithmetic: an `expire_secs` too large to be represented
/// pins `exp` at `usize::MAX` instead of overflowing (which would panic in
/// debug builds and wrap to a timestamp in the past in release builds).
///
/// # Parameters
/// - `user_id`: stored as the `sub` claim.
/// - `username`, `roles`, `permissions`: copied into the claims as given.
/// - `token_type`: stored as `Some(token_type)`.
/// - `expire_secs`: token lifetime in seconds.
///
/// # Errors
/// Returns a `String` error when the system clock is before the Unix epoch or
/// when encoding the token fails.
fn create_token(
    &self,
    user_id: &str,
    username: Option<&str>,
    roles: &[String],
    permissions: &[String],
    token_type: TokenType,
    expire_secs: u64,
) -> Result<String, String> {
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::time::{SystemTime, UNIX_EPOCH};
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| format!("时间戳错误: {}", e))?
        .as_secs() as usize;
    // Clamp before narrowing so a large u64 cannot be truncated on 32-bit targets.
    let lifetime = expire_secs.min(usize::MAX as u64) as usize;
    let claims = TokenClaims {
        jti: Some(uuid::Uuid::new_v4().to_string()),
        sub: user_id.to_string(),
        username: username.map(|s| s.to_string()),
        roles: roles.to_vec(),
        permissions: permissions.to_vec(),
        token_type: Some(token_type),
        // Saturate instead of overflowing when the configured lifetime is huge.
        exp: now.saturating_add(lifetime),
        iat: now,
    };
    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(self.jwt_secret().as_bytes()),
    )
    .map_err(|e| format!("Token 生成失败: {}", e))
}
}
/// Manual `Debug` output for [`JWT`].
///
/// The configuration (which holds the signing secret) is not printed. With the
/// `cache` feature enabled the output contains a single `has_cache` flag;
/// without it the struct is printed with no fields.
impl std::fmt::Debug for JWT {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("JWT");
        #[cfg(feature = "cache")]
        {
            builder.field("has_cache", &self.cache.is_some());
        }
        builder.finish()
    }
}