use hex;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
use crate::core::error::{AeroSyncError, Result};
/// Server-side record of one issued token.
#[derive(Debug, Clone)]
pub struct TokenInfo {
    /// The full token string.
    pub token: String,
    /// Issue time, in whole seconds since the Unix epoch.
    pub created_at: u64,
    /// Expiry time, in whole seconds since the Unix epoch. The token still
    /// counts as unexpired during this exact second.
    pub expires_at: u64,
    /// Set once the token has been explicitly revoked.
    pub revoked: bool,
}

impl TokenInfo {
    /// Returns `true` once the current Unix time (whole seconds) is strictly
    /// greater than `expires_at`.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    pub fn is_expired(&self) -> bool {
        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        self.expires_at < since_epoch.as_secs()
    }

    /// Returns `true` when the token is neither expired nor revoked.
    pub fn is_valid(&self) -> bool {
        !(self.is_expired() || self.revoked)
    }
}
/// Reasons a token can be rejected.
///
/// NOTE(review): not referenced elsewhere in this file — `TokenManager`
/// reports failures through `AeroSyncError::Auth`. Confirm whether other
/// modules use this type.
#[derive(Debug, thiserror::Error)]
pub enum TokenError {
    /// The token's expiry time has passed ("Token has expired").
    #[error("Token 已过期")]
    Expired,

    /// The token is malformed or its signature does not match ("Token is invalid").
    #[error("Token 无效")]
    Invalid,

    /// The token was explicitly revoked ("Token has been revoked").
    #[error("Token 已被撤销")]
    Revoked,

    /// No record exists for the token ("Token does not exist").
    #[error("Token 不存在")]
    NotFound,
}
/// Issues, verifies and revokes signed, time-limited bearer tokens.
///
/// A token has the form `uuid.created_at.expires_at.signature`:
/// - both timestamps are Unix seconds;
/// - `signature` is derived from the first three parts plus the manager's
///   secret (see `sign_token`).
///
/// Every issued token is also recorded in an in-memory table keyed by its
/// UUID, so that it can be revoked. Clones share that table.
#[derive(Clone)]
pub struct TokenManager {
    /// Secret mixed into every signature; `new` rejects an empty value.
    secret_key: String,
    /// Lifetime, in hours, applied to tokens issued by `generate_token`.
    lifetime_hours: u64,
    /// Issued and revoked tokens keyed by UUID, shared between clones.
    tokens: Arc<RwLock<HashMap<String, TokenInfo>>>,
}

impl TokenManager {
    /// Lifetime used by `new` until `with_lifetime` overrides it.
    const DEFAULT_LIFETIME_HOURS: u64 = 24;
    const SECONDS_PER_HOUR: u64 = 3600;
    /// Number of hex characters of the SHA-256 digest kept as the signature.
    const SIGNATURE_HEX_LEN: usize = 16;

    /// Creates a manager that signs with `secret_key` and issues 24-hour tokens.
    ///
    /// # Errors
    ///
    /// Returns `AeroSyncError::Config` if `secret_key` is empty.
    pub fn new(secret_key: String) -> Result<Self> {
        if secret_key.is_empty() {
            return Err(AeroSyncError::Config(
                "Secret key cannot be empty".to_string(),
            ));
        }
        Ok(Self {
            secret_key,
            lifetime_hours: Self::DEFAULT_LIFETIME_HOURS,
            tokens: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Sets the lifetime, in hours, of tokens issued from now on.
    ///
    /// With `0`, a token expires once the second it was issued in has passed.
    pub fn with_lifetime(mut self, hours: u64) -> Self {
        self.lifetime_hours = hours;
        self
    }

    /// Issues a new token, records it in the table and returns it.
    ///
    /// The expiry time saturates at `u64::MAX` rather than overflowing for
    /// extremely large lifetimes.
    pub fn generate_token(&self) -> Result<String> {
        let now = Self::now_secs();
        let lifetime_secs = self.lifetime_hours.saturating_mul(Self::SECONDS_PER_HOUR);
        let expires_at = now.saturating_add(lifetime_secs);
        let uuid = Uuid::new_v4().to_string();
        let signature = self.sign_token(&uuid, now, expires_at);
        let token = format!("{}.{}.{}.{}", uuid, now, expires_at, signature);
        let token_info = TokenInfo {
            token: token.clone(),
            created_at: now,
            expires_at,
            revoked: false,
        };
        self.write_tokens().insert(uuid, token_info);
        Ok(token)
    }

    /// Checks whether `token` is authentic and currently usable.
    ///
    /// Returns `Ok(false)` when any of these holds:
    /// - the token does not have exactly four `.`-separated parts;
    /// - its signature does not match;
    /// - its table entry is expired or revoked.
    ///
    /// A correctly signed token with no table entry (for example, one issued
    /// by another instance that uses the same secret) is accepted until its
    /// embedded expiry time.
    ///
    /// # Errors
    ///
    /// Returns `AeroSyncError::Auth` if either timestamp is not a valid `u64`.
    pub fn verify_token(&self, token: &str) -> Result<bool> {
        let (uuid, created_at, expires_at, signature) = match Self::parse_token(token)? {
            Some(parts) => parts,
            None => return Ok(false),
        };
        if !self.signature_matches(uuid, created_at, expires_at, signature) {
            return Ok(false);
        }
        if let Some(token_info) = self.read_tokens().get(uuid) {
            return Ok(token_info.is_valid());
        }
        Ok(Self::now_secs() <= expires_at)
    }

    /// Marks `token` as revoked so `verify_token` rejects it from now on.
    ///
    /// The token must carry a valid signature. If the token is correctly
    /// signed but has no table entry, a revoked entry is recorded for it.
    /// Otherwise `verify_token` would keep accepting it through its
    /// unknown-token fallback.
    ///
    /// # Errors
    ///
    /// Returns `AeroSyncError::Auth` if the token is malformed or its
    /// signature does not match.
    pub fn revoke_token(&self, token: &str) -> Result<()> {
        let (uuid, created_at, expires_at, signature) = Self::parse_token(token)?
            .ok_or_else(|| AeroSyncError::Auth("Invalid token format".to_string()))?;
        if !self.signature_matches(uuid, created_at, expires_at, signature) {
            return Err(AeroSyncError::Auth("Invalid token signature".to_string()));
        }
        let mut tokens = self.write_tokens();
        tokens
            .entry(uuid.to_string())
            .and_modify(|info| info.revoked = true)
            .or_insert_with(|| TokenInfo {
                token: token.to_string(),
                created_at,
                expires_at,
                revoked: true,
            });
        Ok(())
    }

    /// Removes table entries whose expiry time has passed.
    ///
    /// Revoked entries are kept until they expire. Dropping them earlier would
    /// send a still-unexpired revoked token down `verify_token`'s
    /// unknown-token path, which accepts any correctly signed, unexpired token.
    pub fn cleanup_expired_tokens(&self) {
        self.write_tokens().retain(|_, info| !info.is_expired());
    }

    /// Counts table entries that are neither expired nor revoked.
    pub fn active_token_count(&self) -> usize {
        self.read_tokens()
            .values()
            .filter(|info| info.is_valid())
            .count()
    }

    /// Splits `token` into `(uuid, created_at, expires_at, signature)`.
    ///
    /// Returns `Ok(None)` when the token does not have exactly four
    /// `.`-separated parts. Returns an `Auth` error when a timestamp is not a
    /// valid `u64`.
    fn parse_token(token: &str) -> Result<Option<(&str, u64, u64, &str)>> {
        let parts: Vec<&str> = token.split('.').collect();
        let (uuid, created_at, expires_at, signature) = match parts.as_slice() {
            &[uuid, created_at, expires_at, signature] => {
                (uuid, created_at, expires_at, signature)
            }
            _ => return Ok(None),
        };
        let parse_secs = |field: &str| {
            field
                .parse::<u64>()
                .map_err(|_| AeroSyncError::Auth("Invalid token format".to_string()))
        };
        Ok(Some((
            uuid,
            parse_secs(created_at)?,
            parse_secs(expires_at)?,
            signature,
        )))
    }

    /// Returns `true` if `signature` is the expected signature for the given
    /// fields.
    ///
    /// The bytes are compared without an early exit, so the time taken does
    /// not reveal how long a matching prefix an attacker has guessed.
    fn signature_matches(
        &self,
        uuid: &str,
        created_at: u64,
        expires_at: u64,
        signature: &str,
    ) -> bool {
        let expected = self.sign_token(uuid, created_at, expires_at);
        let (expected, given) = (expected.as_bytes(), signature.as_bytes());
        expected.len() == given.len()
            && expected
                .iter()
                .zip(given)
                .fold(0u8, |acc, (a, b)| acc | (a ^ b))
                == 0
    }

    /// Current Unix time in whole seconds.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_secs()
    }

    /// Read access to the token table.
    ///
    /// A poisoned lock is recovered rather than propagated. Every critical
    /// section is a single map operation, so the table stays structurally
    /// usable after a panic.
    fn read_tokens(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, TokenInfo>> {
        self.tokens
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Write access to the token table; recovers from poisoning like `read_tokens`.
    fn write_tokens(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, TokenInfo>> {
        self.tokens
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Computes the signature for a token's fields.
    ///
    /// The signature is the first 16 hex characters of
    /// `SHA-256("uuid.created_at.expires_at.secret")`.
    ///
    /// NOTE(review): a truncated hash over `data || secret` is not a standard
    /// MAC. Consider HMAC-SHA256 with a full-length tag. Changing this
    /// invalidates every outstanding token.
    fn sign_token(&self, uuid: &str, created_at: u64, expires_at: u64) -> String {
        let data = format!("{}.{}.{}.{}", uuid, created_at, expires_at, self.secret_key);
        let mut hasher = Sha256::new();
        hasher.update(data.as_bytes());
        let result = hasher.finalize();
        hex::encode(result)[..Self::SIGNATURE_HEX_LEN].to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;
    use std::time::Duration;

    /// Builds a manager with the shared test secret and the default lifetime.
    fn new_manager() -> TokenManager {
        TokenManager::new("test-secret".to_string()).unwrap()
    }

    #[test]
    fn test_empty_secret_rejected() {
        assert!(TokenManager::new(String::new()).is_err());
    }

    #[test]
    fn test_token_generation() {
        let token = new_manager().generate_token().unwrap();
        assert_eq!(token.split('.').count(), 4);
    }

    #[test]
    fn test_token_verification() {
        let manager = new_manager();
        let token = manager.generate_token().unwrap();
        assert!(manager.verify_token(&token).unwrap());
    }

    #[test]
    fn test_invalid_token() {
        let manager = new_manager();
        assert!(!manager.verify_token("invalid-token").unwrap());
        // Four parts with non-numeric timestamps is reported as an error.
        assert!(manager.verify_token("a.b.c.d").is_err());
        let token = manager.generate_token().unwrap();
        let parts: Vec<&str> = token.split('.').collect();
        let tampered = format!("{}.{}.{}.0000000000000000", parts[0], parts[1], parts[2]);
        assert!(!manager.verify_token(&tampered).unwrap());
    }

    #[test]
    fn test_token_from_instance_with_same_secret() {
        // Tokens missing from this manager's table fall back to the
        // signature + embedded-expiry check.
        let token = new_manager().generate_token().unwrap();
        assert!(new_manager().verify_token(&token).unwrap());
        let other = TokenManager::new("other-secret".to_string()).unwrap();
        assert!(!other.verify_token(&token).unwrap());
    }

    #[test]
    fn test_token_revocation() {
        let manager = new_manager();
        let token = manager.generate_token().unwrap();
        assert!(manager.verify_token(&token).unwrap());
        manager.revoke_token(&token).unwrap();
        assert!(!manager.verify_token(&token).unwrap());
    }

    #[test]
    fn test_token_expiration() {
        let manager = new_manager().with_lifetime(0);
        let token = manager.generate_token().unwrap();
        sleep(Duration::from_secs(1));
        assert!(!manager.verify_token(&token).unwrap());
    }

    #[test]
    fn test_cleanup_expired_tokens() {
        let manager = new_manager();
        for _ in 0..5 {
            manager.generate_token().unwrap();
        }
        let survivor = manager.generate_token().unwrap();
        assert_eq!(manager.active_token_count(), 6);

        // Backdate every entry except `survivor` instead of issuing
        // zero-lifetime tokens and sleeping. A zero-lifetime token can already
        // count as expired before the assertion above runs if a second
        // boundary is crossed, which made the old version of this test flaky.
        let survivor_uuid = survivor.split('.').next().unwrap();
        for (uuid, info) in manager.tokens.write().unwrap().iter_mut() {
            if uuid.as_str() != survivor_uuid {
                info.expires_at = 0;
            }
        }
        assert_eq!(manager.active_token_count(), 1);

        manager.cleanup_expired_tokens();
        assert_eq!(manager.tokens.read().unwrap().len(), 1);
        assert!(manager.verify_token(&survivor).unwrap());
    }
}