use std::sync::Arc;
use std::collections::HashMap;
use chrono::{DateTime, Duration, Utc};
use tokio::sync::RwLock;
use sa_token_adapter::storage::SaStorage;
use crate::config::SaTokenConfig;
use crate::error::{SaTokenError, SaTokenResult};
use crate::token::{TokenInfo, TokenValue, TokenGenerator};
use crate::session::SaSession;
use crate::event::{SaTokenEventBus, SaTokenEvent};
use crate::online::OnlineManager;
use crate::distributed::DistributedSessionManager;
/// Central entry point for issuing, validating and revoking tokens and for
/// reading/writing per-login sessions.
///
/// The storage backend, the permission/role caches and the optional managers
/// are held behind `Arc`, so clones of the manager share them.
/// NOTE(review): `config` and `event_bus` are cloned by value; whether clones of
/// `SaTokenEventBus` share subscribers depends on that type — confirm.
#[derive(Clone)]
pub struct SaTokenManager {
    /// Token/timeout configuration consulted on every login and renewal.
    pub config: SaTokenConfig,
    /// Key-value backend holding `sa:token:*`, `sa:login:token:*` and `sa:session:*` entries.
    pub(crate) storage: Arc<dyn SaStorage>,
    /// Bus on which login, logout and kick-out events are published.
    pub(crate) event_bus: SaTokenEventBus,

    /// Permission lists, presumably keyed by login id — verify against writers.
    pub(crate) user_permissions: Arc<RwLock<HashMap<String, Vec<String>>>>,
    /// Role lists, presumably keyed by login id — verify against writers.
    pub(crate) user_roles: Arc<RwLock<HashMap<String, Vec<String>>>>,

    /// Optional online-presence tracker notified on logout and kick-out.
    online_manager: Option<Arc<OnlineManager>>,
    /// Optional distributed-session manager attached via the builder.
    distributed_manager: Option<Arc<DistributedSessionManager>>,
}
impl SaTokenManager {
    /// Creates a manager backed by `storage` and configured by `config`.
    ///
    /// The permission/role caches start empty, a fresh event bus is created and
    /// no online or distributed-session manager is attached.
    pub fn new(storage: Arc<dyn SaStorage>, config: SaTokenConfig) -> Self {
        Self {
            storage,
            config,
            user_permissions: Arc::new(RwLock::new(HashMap::new())),
            user_roles: Arc::new(RwLock::new(HashMap::new())),
            event_bus: SaTokenEventBus::new(),
            online_manager: None,
            distributed_manager: None,
        }
    }

    /// Attaches an [`OnlineManager`] (builder style). `logout` and `kick_out` notify it.
    pub fn with_online_manager(mut self, manager: Arc<OnlineManager>) -> Self {
        self.online_manager = Some(manager);
        self
    }

    /// Attaches a [`DistributedSessionManager`] (builder style).
    pub fn with_distributed_manager(mut self, manager: Arc<DistributedSessionManager>) -> Self {
        self.distributed_manager = Some(manager);
        self
    }

    /// Returns the attached online manager, if any.
    pub fn online_manager(&self) -> Option<&Arc<OnlineManager>> {
        self.online_manager.as_ref()
    }

    /// Returns the attached distributed-session manager, if any.
    pub fn distributed_manager(&self) -> Option<&Arc<DistributedSessionManager>> {
        self.distributed_manager.as_ref()
    }

    /// Returns the bus on which login, logout and kick-out events are published.
    pub fn event_bus(&self) -> &SaTokenEventBus {
        &self.event_bus
    }

    /// Logs `login_id` in with default options and returns the issued token.
    ///
    /// # Errors
    /// Same as [`Self::login_with_token_info`].
    pub async fn login(&self, login_id: impl Into<String>) -> SaTokenResult<TokenValue> {
        self.login_with_options(login_id, None, None, None, None, None).await
    }

    /// Logs `login_id` in with optional login type, device, extra payload, nonce
    /// and explicit expiry, and returns the issued token.
    ///
    /// A missing `login_type` becomes `"default"`. When `extra_data` is given it is
    /// passed to the token generator as well as stored on the token. When
    /// `expire_time` is `None`, the configured timeout (if any) decides the expiry.
    ///
    /// # Errors
    /// Same as [`Self::login_with_token_info`].
    pub async fn login_with_options(
        &self,
        login_id: impl Into<String>,
        login_type: Option<String>,
        device: Option<String>,
        extra_data: Option<serde_json::Value>,
        nonce: Option<String>,
        expire_time: Option<DateTime<Utc>>,
    ) -> SaTokenResult<TokenValue> {
        let login_id = login_id.into();
        let token = match &extra_data {
            Some(extra) => {
                TokenGenerator::generate_with_login_id_and_extra(&self.config, &login_id, extra)
            }
            None => TokenGenerator::generate_with_login_id(&self.config, &login_id),
        };

        let mut token_info = TokenInfo::new(token, login_id);
        token_info.login_type = login_type.unwrap_or_else(|| "default".to_string());
        // Only overwrite fields that were actually supplied, leaving whatever
        // `TokenInfo::new` initialised otherwise.
        if let Some(device_str) = device {
            token_info.device = Some(device_str);
        }
        if let Some(extra) = extra_data {
            token_info.extra_data = Some(extra);
        }
        if let Some(nonce_str) = nonce {
            token_info.nonce = Some(nonce_str);
        }
        if let Some(custom_expire_time) = expire_time {
            token_info.expire_time = Some(custom_expire_time);
        }
        self.login_with_token_info(token_info).await
    }

    /// Persists `token_info` as a new login and returns its token.
    ///
    /// Steps:
    /// 1. If `token_info.token` is empty, a token is generated for the login id.
    /// 2. The active time is refreshed. A missing `expire_time` is derived from
    ///    the configured timeout (left `None` if there is none, or if it is out
    ///    of range). An empty login type becomes `"default"`.
    /// 3. If `config.is_concurrent` is false, every existing token of the login id
    ///    is logged out *before* the new token is stored.
    /// 4. The token info is stored under `sa:token:{token}`. The login-id → token
    ///    mapping is stored under `sa:login:token:{login_id}`, or under
    ///    `sa:login:token:{login_id}:{login_type}` for non-default types. Both use
    ///    a TTL that matches the token's expiry.
    /// 5. A login event is published.
    ///
    /// # Errors
    /// - `SerializationError` if the token info cannot be serialised.
    /// - `StorageError` if either write fails.
    pub async fn login_with_token_info(&self, mut token_info: TokenInfo) -> SaTokenResult<TokenValue> {
        let login_id = token_info.login_id.clone();
        if token_info.token.as_str().is_empty() {
            token_info.token = TokenGenerator::generate_with_login_id(&self.config, &login_id);
        }
        let token = token_info.token.clone();
        token_info.update_active_time();

        let now = Utc::now();
        if token_info.expire_time.is_none() {
            // Fallible conversions: an absurdly large configured timeout must not panic.
            token_info.expire_time = self
                .config
                .timeout_duration()
                .and_then(|timeout| Duration::from_std(timeout).ok())
                .and_then(|timeout| now.checked_add_signed(timeout));
        }
        if token_info.login_type.is_empty() {
            token_info.login_type = "default".to_string();
        }

        // Revoke previous logins first. Doing it after storing would also revoke
        // the token issued here, so every non-concurrent login would fail.
        if !self.config.is_concurrent {
            self.logout_by_login_id(&login_id).await?;
        }

        // Keep the storage TTL aligned with the token's own expiry. Otherwise a
        // custom `expire_time` beyond the configured timeout would be evicted early.
        let ttl = token_info
            .expire_time
            .and_then(|expire| (expire - now).to_std().ok())
            .or_else(|| self.config.timeout_duration());

        let key = format!("sa:token:{}", token.as_str());
        let value = serde_json::to_string(&token_info).map_err(SaTokenError::SerializationError)?;
        self.storage
            .set(&key, &value, ttl)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;

        let login_token_key = if token_info.login_type == "default" {
            format!("sa:login:token:{}", login_id)
        } else {
            format!("sa:login:token:{}:{}", login_id, token_info.login_type)
        };
        self.storage
            .set(&login_token_key, token.as_str(), ttl)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;

        let event = SaTokenEvent::login(login_id, token.as_str())
            .with_login_type(&token_info.login_type);
        self.event_bus.publish(event).await;
        Ok(token)
    }

    /// Revokes `token` by deleting `sa:token:{token}`.
    ///
    /// If the stored token info can be read and parsed, a logout event is
    /// published. When an online manager is attached, the user is also marked
    /// offline for this token. Logging out an unknown token is not an error.
    ///
    /// # Errors
    /// `StorageError` if reading or deleting the token entry fails.
    pub async fn logout(&self, token: &TokenValue) -> SaTokenResult<()> {
        tracing::debug!("Manager: 开始 logout,token: {}", token);
        let key = format!("sa:token:{}", token.as_str());
        tracing::debug!("Manager: 查询 token 信息,key: {}", key);
        let token_info_str = self
            .storage
            .get(&key)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
        let token_info = match token_info_str {
            Some(value) => {
                tracing::debug!("Manager: 找到 token 信息: {}", value);
                // Unparseable info is tolerated: the entry is still deleted below.
                serde_json::from_str::<TokenInfo>(&value).ok()
            }
            None => {
                tracing::debug!("Manager: 未找到 token 信息");
                None
            }
        };

        tracing::debug!("Manager: 删除 token,key: {}", key);
        self.storage
            .delete(&key)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
        tracing::debug!("Manager: token 已从存储中删除");

        if let Some(info) = token_info {
            tracing::debug!("Manager: 触发登出事件,login_id: {}, login_type: {}", info.login_id, info.login_type);
            let event = SaTokenEvent::logout(&info.login_id, token.as_str())
                .with_login_type(&info.login_type);
            self.event_bus.publish(event).await;
            if let Some(online_mgr) = &self.online_manager {
                tracing::debug!("Manager: 标记用户下线,login_id: {}", info.login_id);
                online_mgr.mark_offline(&info.login_id, token.as_str()).await;
            }
        }
        tracing::debug!("Manager: logout 完成,token: {}", token);
        Ok(())
    }

    /// Logs out every stored token whose info belongs to `login_id`.
    ///
    /// This scans all `sa:token:*` keys, so cost grows with the total number of
    /// stored tokens. It is best-effort: listing, read, parse and per-token logout
    /// failures are logged or skipped rather than returned. It always returns `Ok(())`.
    pub async fn logout_by_login_id(&self, login_id: &str) -> SaTokenResult<()> {
        const TOKEN_PREFIX: &str = "sa:token:";

        let keys = match self.storage.keys(&format!("{}*", TOKEN_PREFIX)).await {
            Ok(keys) => keys,
            Err(e) => {
                tracing::warn!("Manager: failed to list token keys for login_id {}: {}", login_id, e);
                return Ok(());
            }
        };

        for key in keys {
            // Ignore keys without the expected prefix instead of panicking on a short slice.
            let Some(token_str) = key.strip_prefix(TOKEN_PREFIX) else {
                continue;
            };
            let Ok(Some(token_info_str)) = self.storage.get(&key).await else {
                continue;
            };
            let Ok(token_info) = serde_json::from_str::<TokenInfo>(&token_info_str) else {
                continue;
            };
            if token_info.login_id != login_id {
                continue;
            }
            let token = TokenValue::new(token_str.to_string());
            if let Err(e) = self.logout(&token).await {
                tracing::warn!("Manager: failed to log out token for login_id {}: {}", login_id, e);
            }
        }
        Ok(())
    }

    /// Loads and validates the info stored for `token`.
    ///
    /// An expired token is logged out and reported as `TokenExpired`. When
    /// `config.auto_renew` is set, the stored expiry is pushed forward by
    /// `active_timeout` if it is positive, otherwise by `timeout`. Renewal is
    /// best-effort, and the returned info is the one read *before* renewal.
    ///
    /// # Errors
    /// - `StorageError` if the read fails.
    /// - `TokenNotFound` if no entry exists.
    /// - `SerializationError` if the entry cannot be parsed.
    /// - `TokenExpired` if the token has expired.
    pub async fn get_token_info(&self, token: &TokenValue) -> SaTokenResult<TokenInfo> {
        let key = format!("sa:token:{}", token.as_str());
        let value = self
            .storage
            .get(&key)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))?
            .ok_or(SaTokenError::TokenNotFound)?;
        let token_info: TokenInfo =
            serde_json::from_str(&value).map_err(SaTokenError::SerializationError)?;

        if token_info.is_expired() {
            self.logout(token).await?;
            return Err(SaTokenError::TokenExpired);
        }

        if self.config.auto_renew {
            let renew_timeout = if self.config.active_timeout > 0 {
                self.config.active_timeout
            } else {
                self.config.timeout
            };
            if let Err(e) = self.renew_timeout_internal(token, renew_timeout, &token_info).await {
                tracing::warn!("Manager: auto-renew failed for token {}: {}", token, e);
            }
        }
        Ok(token_info)
    }

    /// Returns whether [`Self::get_token_info`] succeeds for `token`.
    ///
    /// Note that this has the same side effects: expired tokens are logged out
    /// and, with auto-renew enabled, valid tokens are renewed.
    pub async fn is_valid(&self, token: &TokenValue) -> bool {
        self.get_token_info(token).await.is_ok()
    }

    /// Loads the session stored under `sa:session:{login_id}`.
    ///
    /// If none exists, a fresh, unsaved `SaSession::new(login_id)` is returned.
    ///
    /// # Errors
    /// `StorageError` on read failure; `SerializationError` if the stored JSON is invalid.
    pub async fn get_session(&self, login_id: &str) -> SaTokenResult<SaSession> {
        let key = format!("sa:session:{}", login_id);
        let value = self
            .storage
            .get(&key)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))?;
        match value {
            Some(value) => serde_json::from_str(&value).map_err(SaTokenError::SerializationError),
            None => Ok(SaSession::new(login_id)),
        }
    }

    /// Writes `session` to `sa:session:{session.id}` with no TTL (`None`).
    ///
    /// # Errors
    /// `SerializationError` if serialisation fails; `StorageError` if the write fails.
    pub async fn save_session(&self, session: &SaSession) -> SaTokenResult<()> {
        let key = format!("sa:session:{}", session.id);
        let value = serde_json::to_string(session).map_err(SaTokenError::SerializationError)?;
        self.storage
            .set(&key, &value, None)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))
    }

    /// Deletes the session stored under `sa:session:{login_id}`.
    ///
    /// # Errors
    /// `StorageError` if the delete fails.
    pub async fn delete_session(&self, login_id: &str) -> SaTokenResult<()> {
        let key = format!("sa:session:{}", login_id);
        self.storage
            .delete(&key)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))
    }

    /// Validates `token`, then sets its expiry to `timeout_seconds` from now.
    ///
    /// A negative `timeout_seconds` (e.g. `-1`) removes the expiry and stores the
    /// entry without a TTL.
    ///
    /// # Errors
    /// Any error from [`Self::get_token_info`], plus serialisation/storage errors
    /// from the rewrite.
    pub async fn renew_timeout(&self, token: &TokenValue, timeout_seconds: i64) -> SaTokenResult<()> {
        let token_info = self.get_token_info(token).await?;
        self.renew_timeout_internal(token, timeout_seconds, &token_info).await
    }

    /// Rewrites `sa:token:{token}` with a fresh expiry derived from `timeout_seconds`.
    ///
    /// A negative value yields no expiry and no storage TTL. Casting it to `u64`
    /// would give a huge TTL, and adding it to `now` would give an already-past
    /// expiry that invalidates the token on its next use.
    async fn renew_timeout_internal(
        &self,
        token: &TokenValue,
        timeout_seconds: i64,
        token_info: &TokenInfo,
    ) -> SaTokenResult<()> {
        let ttl = u64::try_from(timeout_seconds)
            .ok()
            .map(std::time::Duration::from_secs);

        let mut renewed = token_info.clone();
        renewed.expire_time = ttl
            .and_then(|ttl| Duration::from_std(ttl).ok())
            .and_then(|delta| Utc::now().checked_add_signed(delta));

        let key = format!("sa:token:{}", token.as_str());
        let value = serde_json::to_string(&renewed).map_err(SaTokenError::SerializationError)?;
        self.storage
            .set(&key, &value, ttl)
            .await
            .map_err(|e| SaTokenError::StorageError(e.to_string()))
    }

    /// Forcibly logs `login_id` out everywhere and deletes its session.
    ///
    /// Steps:
    /// 1. The default-type mapping `sa:login:token:{login_id}` is read first.
    /// 2. The online manager (if attached) is notified; this is best-effort.
    /// 3. All of the user's tokens are logged out and the session is deleted.
    /// 4. If the mapping read in step 1 returned a token, a kick-out event is published.
    ///
    /// # Errors
    /// Errors from [`Self::logout_by_login_id`] or [`Self::delete_session`].
    pub async fn kick_out(&self, login_id: &str) -> SaTokenResult<()> {
        // Read before logout so the event can still name the token being revoked.
        let token_result = self.storage.get(&format!("sa:login:token:{}", login_id)).await;
        if let Some(online_mgr) = &self.online_manager {
            let _ = online_mgr.kick_out_notify(login_id, "Account kicked out".to_string()).await;
        }
        self.logout_by_login_id(login_id).await?;
        self.delete_session(login_id).await?;
        if let Ok(Some(token_str)) = token_result {
            let event = SaTokenEvent::kick_out(login_id, token_str);
            self.event_bus.publish(event).await;
        }
        Ok(())
    }
}