use crate::error::SaTokenError;
use async_trait::async_trait;
use sa_token_adapter::storage::SaStorage;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock;
/// A login session that can be shared between services through a
/// [`DistributedSessionStorage`] backend.
///
/// The type is serde-serializable; `SaStorageDistributedStorage` persists it
/// as a JSON string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedSession {
/// Identifier of the session and its lookup key in storage.
/// `DistributedSessionManager::create_session` generates a v4 UUID string.
pub session_id: String,
/// Login (user) identifier the session belongs to; storages index sessions by it.
pub login_id: String,
/// Token string supplied by the caller when the session was created.
pub token: String,
/// Identifier of the service whose manager created the session.
pub service_id: String,
/// UTC timestamp taken when the session was created.
pub create_time: DateTime<Utc>,
/// UTC timestamp of the last refresh or attribute mutation performed through the manager.
pub last_access: DateTime<Utc>,
/// Free-form string key/value data attached to the session.
pub attributes: HashMap<String, String>,
}
/// Credential record for a service that is allowed to talk to the session layer.
///
/// Stored and looked up by `service_id` through a [`DistributedSessionStorage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceCredential {
/// Identifier of the service; used as the storage key for the credential.
pub service_id: String,
/// Human-readable service name (not interpreted by the code in this file).
pub service_name: String,
/// Shared secret that `DistributedSessionManager::verify_service` compares
/// against the secret supplied by the caller. It is stored as given.
pub secret_key: String,
/// UTC timestamp supplied by whoever builds the credential.
pub created_at: DateTime<Utc>,
/// Permission names attached to the service. Nothing in this file evaluates
/// them — presumably enforced by callers; TODO confirm.
pub permissions: Vec<String>,
}
/// Persistence backend used by [`DistributedSessionManager`].
///
/// Implementations must be `Send + Sync` because the manager holds them behind
/// an `Arc<dyn DistributedSessionStorage>`.
#[async_trait]
pub trait DistributedSessionStorage: Send + Sync {
/// Stores `session` under its `session_id`, replacing any previous value, and
/// makes it discoverable through [`Self::get_sessions_by_login_id`].
///
/// `ttl` is the requested lifetime of the entry. How it is honoured is up to
/// the implementation (the in-memory implementation in this file ignores it);
/// `None` presumably means "no expiry" — confirm against the backend.
async fn save_session(&self, session: DistributedSession, ttl: Option<Duration>) -> Result<(), SaTokenError>;
/// Returns the session stored under `session_id`, or `Ok(None)` if there is none.
async fn get_session(&self, session_id: &str) -> Result<Option<DistributedSession>, SaTokenError>;
/// Removes the session stored under `session_id`. The implementations in this
/// file return `Ok(())` when the session does not exist.
async fn delete_session(&self, session_id: &str) -> Result<(), SaTokenError>;
/// Returns every stored session that is indexed under `login_id`
/// (an empty vector when there is none).
async fn get_sessions_by_login_id(&self, login_id: &str) -> Result<Vec<DistributedSession>, SaTokenError>;
/// Stores `credential` under its `service_id`, replacing any previous value.
async fn save_credential(&self, credential: ServiceCredential) -> Result<(), SaTokenError>;
/// Returns the credential stored for `service_id`, or `Ok(None)` if there is none.
async fn get_credential(&self, service_id: &str) -> Result<Option<ServiceCredential>, SaTokenError>;
}
/// High-level API for creating, reading and mutating [`DistributedSession`]s
/// and for registering/verifying [`ServiceCredential`]s on top of a storage backend.
pub struct DistributedSessionManager {
// Backend that persists sessions and credentials.
storage: Arc<dyn DistributedSessionStorage>,
// Identifier stamped onto every session created by this manager.
service_id: String,
// TTL handed to the storage on every session save.
session_timeout: Duration,
}
impl DistributedSessionManager {
pub fn new(
storage: Arc<dyn DistributedSessionStorage>,
service_id: String,
session_timeout: Duration,
) -> Self {
Self {
storage,
service_id,
session_timeout,
}
}
pub async fn register_service(&self, credential: ServiceCredential) -> Result<(), SaTokenError> {
self.storage.save_credential(credential).await
}
pub async fn verify_service(&self, service_id: &str, secret: &str) -> Result<ServiceCredential, SaTokenError> {
if let Some(cred) = self.storage.get_credential(service_id).await?
&& cred.secret_key == secret
{
return Ok(cred);
}
Err(SaTokenError::PermissionDenied)
}
pub async fn create_session(
&self,
login_id: String,
token: String,
) -> Result<DistributedSession, SaTokenError> {
let session = DistributedSession {
session_id: uuid::Uuid::new_v4().to_string(),
login_id,
token,
service_id: self.service_id.clone(),
create_time: Utc::now(),
last_access: Utc::now(),
attributes: HashMap::new(),
};
self.storage.save_session(session.clone(), Some(self.session_timeout)).await?;
Ok(session)
}
pub async fn get_session(&self, session_id: &str) -> Result<DistributedSession, SaTokenError> {
self.storage.get_session(session_id).await?
.ok_or(SaTokenError::SessionNotFound)
}
pub async fn update_session(&self, session: DistributedSession) -> Result<(), SaTokenError> {
self.storage.save_session(session, Some(self.session_timeout)).await
}
pub async fn delete_session(&self, session_id: &str) -> Result<(), SaTokenError> {
self.storage.delete_session(session_id).await
}
pub async fn refresh_session(&self, session_id: &str) -> Result<(), SaTokenError> {
let mut session = self.get_session(session_id).await?;
session.last_access = Utc::now();
self.update_session(session).await
}
pub async fn set_attribute(
&self,
session_id: &str,
key: String,
value: String,
) -> Result<(), SaTokenError> {
let mut session = self.get_session(session_id).await?;
session.attributes.insert(key, value);
session.last_access = Utc::now();
self.update_session(session).await
}
pub async fn get_attribute(
&self,
session_id: &str,
key: &str,
) -> Result<Option<String>, SaTokenError> {
let session = self.get_session(session_id).await?;
Ok(session.attributes.get(key).cloned())
}
pub async fn remove_attribute(
&self,
session_id: &str,
key: &str,
) -> Result<(), SaTokenError> {
let mut session = self.get_session(session_id).await?;
session.attributes.remove(key);
session.last_access = Utc::now();
self.update_session(session).await
}
pub async fn get_sessions_by_login_id(&self, login_id: &str) -> Result<Vec<DistributedSession>, SaTokenError> {
self.storage.get_sessions_by_login_id(login_id).await
}
pub async fn delete_all_sessions(&self, login_id: &str) -> Result<(), SaTokenError> {
let sessions = self.get_sessions_by_login_id(login_id).await?;
for session in sessions {
self.delete_session(&session.session_id).await?;
}
Ok(())
}
}
/// [`DistributedSessionStorage`] implementation that keeps everything in
/// process memory behind async read/write locks.
pub struct InMemoryDistributedStorage {
// Sessions keyed by `session_id`.
sessions: Arc<RwLock<HashMap<String, DistributedSession>>>,
// For each `login_id`, the ids of the sessions saved under that login.
login_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
// Service credentials keyed by `service_id`.
credentials: Arc<RwLock<HashMap<String, ServiceCredential>>>,
}
impl InMemoryDistributedStorage {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
login_index: Arc::new(RwLock::new(HashMap::new())),
credentials: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl Default for InMemoryDistributedStorage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl DistributedSessionStorage for InMemoryDistributedStorage {
async fn save_session(&self, session: DistributedSession, _ttl: Option<Duration>) -> Result<(), SaTokenError> {
let session_id = session.session_id.clone();
let login_id = session.login_id.clone();
let mut sessions = self.sessions.write().await;
sessions.insert(session_id.clone(), session);
let mut index = self.login_index.write().await;
let session_list = index.entry(login_id).or_insert_with(Vec::new);
if !session_list.contains(&session_id) {
session_list.push(session_id);
}
Ok(())
}
async fn get_session(&self, session_id: &str) -> Result<Option<DistributedSession>, SaTokenError> {
let sessions = self.sessions.read().await;
Ok(sessions.get(session_id).cloned())
}
async fn delete_session(&self, session_id: &str) -> Result<(), SaTokenError> {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.remove(session_id) {
let mut index = self.login_index.write().await;
if let Some(session_ids) = index.get_mut(&session.login_id) {
session_ids.retain(|id| id != session_id);
if session_ids.is_empty() {
index.remove(&session.login_id);
}
}
}
Ok(())
}
async fn get_sessions_by_login_id(&self, login_id: &str) -> Result<Vec<DistributedSession>, SaTokenError> {
let index = self.login_index.read().await;
let session_ids = index.get(login_id).cloned().unwrap_or_default();
let sessions = self.sessions.read().await;
let mut result = Vec::new();
for session_id in session_ids {
if let Some(session) = sessions.get(&session_id) {
result.push(session.clone());
}
}
Ok(result)
}
async fn save_credential(&self, credential: ServiceCredential) -> Result<(), SaTokenError> {
let mut creds = self.credentials.write().await;
creds.insert(credential.service_id.clone(), credential);
Ok(())
}
async fn get_credential(&self, service_id: &str) -> Result<Option<ServiceCredential>, SaTokenError> {
let creds = self.credentials.read().await;
Ok(creds.get(service_id).cloned())
}
}
/// [`DistributedSessionStorage`] implementation layered on a generic
/// [`SaStorage`] key/value backend; values are stored as JSON strings.
pub struct SaStorageDistributedStorage {
// Underlying key/value storage.
storage: Arc<dyn SaStorage>,
// Prefix prepended verbatim to every key this type reads or writes.
key_prefix: String,
}
impl SaStorageDistributedStorage {
    /// Wraps `storage`, prepending `key_prefix` verbatim to every key it uses.
    pub fn new(storage: Arc<dyn SaStorage>, key_prefix: impl Into<String>) -> Self {
        let key_prefix = key_prefix.into();
        Self { storage, key_prefix }
    }

    /// Key under which the JSON of session `session_id` is stored.
    fn session_key(&self, session_id: &str) -> String {
        let prefix = &self.key_prefix;
        format!("{prefix}dsession:{session_id}")
    }

    /// Key under which the list of session ids for `login_id` is stored.
    fn index_key(&self, login_id: &str) -> String {
        let prefix = &self.key_prefix;
        format!("{prefix}dsession:index:{login_id}")
    }

    /// Key under which the credential of service `service_id` is stored.
    fn credential_key(&self, service_id: &str) -> String {
        let prefix = &self.key_prefix;
        format!("{prefix}dservice:{service_id}")
    }

    /// Reads the session-id list stored at `index_key`.
    ///
    /// A missing key yields an empty list. Backend failures map to
    /// `SaTokenError::StorageError`, undecodable JSON to
    /// `SaTokenError::SerializationError`.
    async fn load_index(&self, index_key: &str) -> Result<Vec<String>, SaTokenError> {
        let stored = self
            .storage
            .get(index_key)
            .await
            .map_err(|err| SaTokenError::StorageError(err.to_string()))?;
        let Some(json) = stored else {
            return Ok(Vec::new());
        };
        serde_json::from_str(&json).map_err(SaTokenError::SerializationError)
    }

    /// Writes `ids` as JSON to `index_key` without a TTL.
    ///
    /// Encoding failures map to `SaTokenError::SerializationError`, backend
    /// failures to `SaTokenError::StorageError`.
    async fn save_index(&self, index_key: &str, ids: &[String]) -> Result<(), SaTokenError> {
        let json = serde_json::to_string(ids).map_err(SaTokenError::SerializationError)?;
        let outcome = self.storage.set(index_key, &json, None).await;
        outcome.map_err(|err| SaTokenError::StorageError(err.to_string()))
    }
}
#[async_trait]
impl DistributedSessionStorage for SaStorageDistributedStorage {
async fn save_session(&self, session: DistributedSession, ttl: Option<Duration>) -> Result<(), SaTokenError> {
let session_key = self.session_key(&session.session_id);
let index_key = self.index_key(&session.login_id);
let session_id = session.session_id.clone();
let value = serde_json::to_string(&session).map_err(SaTokenError::SerializationError)?;
self.storage
.set(&session_key, &value, ttl)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
let mut ids = self.load_index(&index_key).await?;
if !ids.contains(&session_id) {
ids.push(session_id);
self.save_index(&index_key, &ids).await?;
}
Ok(())
}
async fn get_session(&self, session_id: &str) -> Result<Option<DistributedSession>, SaTokenError> {
match self
.storage
.get(&self.session_key(session_id))
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
{
Some(value) => Ok(Some(serde_json::from_str(&value).map_err(SaTokenError::SerializationError)?)),
None => Ok(None),
}
}
async fn delete_session(&self, session_id: &str) -> Result<(), SaTokenError> {
if let Some(session) = self.get_session(session_id).await? {
self.storage
.delete(&self.session_key(session_id))
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
let index_key = self.index_key(&session.login_id);
let mut ids = self.load_index(&index_key).await?;
let before = ids.len();
ids.retain(|id| id != session_id);
if ids.is_empty() {
self.storage
.delete(&index_key)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?;
} else if ids.len() != before {
self.save_index(&index_key, &ids).await?;
}
}
Ok(())
}
async fn get_sessions_by_login_id(&self, login_id: &str) -> Result<Vec<DistributedSession>, SaTokenError> {
let index_key = self.index_key(login_id);
let ids = self.load_index(&index_key).await?;
let original_len = ids.len();
let mut result = Vec::new();
let mut alive_ids = Vec::new();
for id in ids {
if let Some(session) = self.get_session(&id).await? {
result.push(session);
alive_ids.push(id);
}
}
if alive_ids.is_empty() {
let _ = self.storage.delete(&index_key).await;
} else if alive_ids.len() != original_len {
let _ = self.save_index(&index_key, &alive_ids).await;
}
Ok(result)
}
async fn save_credential(&self, credential: ServiceCredential) -> Result<(), SaTokenError> {
let key = self.credential_key(&credential.service_id);
let value = serde_json::to_string(&credential).map_err(SaTokenError::SerializationError)?;
self.storage
.set(&key, &value, None)
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))
}
async fn get_credential(&self, service_id: &str) -> Result<Option<ServiceCredential>, SaTokenError> {
match self
.storage
.get(&self.credential_key(service_id))
.await
.map_err(|e| SaTokenError::StorageError(e.to_string()))?
{
Some(value) => Ok(Some(serde_json::from_str(&value).map_err(SaTokenError::SerializationError)?)),
None => Ok(None),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager for `service1` on a fresh in-memory storage with a
    /// one-hour session timeout.
    fn new_manager() -> DistributedSessionManager {
        DistributedSessionManager::new(
            Arc::new(InMemoryDistributedStorage::new()),
            String::from("service1"),
            Duration::from_secs(3600),
        )
    }

    /// A created session can be read back by its id.
    #[tokio::test]
    async fn test_distributed_session_manager() {
        let manager = new_manager();

        let created = manager
            .create_session(String::from("user1"), String::from("token1"))
            .await
            .unwrap();

        let fetched = manager.get_session(&created.session_id).await.unwrap();
        assert_eq!(fetched.login_id, "user1");
    }

    /// An attribute written to a session can be read back.
    #[tokio::test]
    async fn test_session_attributes() {
        let manager = new_manager();

        let created = manager
            .create_session(String::from("user2"), String::from("token2"))
            .await
            .unwrap();
        let id = created.session_id.as_str();

        manager
            .set_attribute(id, String::from("key1"), String::from("value1"))
            .await
            .unwrap();

        let stored = manager.get_attribute(id, "key1").await.unwrap();
        assert_eq!(stored.as_deref(), Some("value1"));
    }

    /// The registered secret verifies; a wrong secret is rejected.
    #[tokio::test]
    async fn test_service_verification() {
        let manager = new_manager();

        let credential = ServiceCredential {
            service_id: String::from("service2"),
            service_name: String::from("Service 2"),
            secret_key: String::from("secret123"),
            created_at: Utc::now(),
            permissions: ["read", "write"].iter().map(|p| p.to_string()).collect(),
        };
        manager.register_service(credential).await.unwrap();

        let accepted = manager.verify_service("service2", "secret123").await.unwrap();
        assert_eq!(accepted.service_id, "service2");

        let rejected = manager.verify_service("service2", "wrong_secret").await;
        assert!(rejected.is_err());
    }

    /// Deleting all sessions of a login leaves none behind.
    #[tokio::test]
    async fn test_delete_all_sessions() {
        let manager = new_manager();

        for token in ["token1", "token2"] {
            manager
                .create_session(String::from("user3"), token.to_string())
                .await
                .unwrap();
        }

        let before = manager.get_sessions_by_login_id("user3").await.unwrap();
        assert_eq!(before.len(), 2);

        manager.delete_all_sessions("user3").await.unwrap();

        let after = manager.get_sessions_by_login_id("user3").await.unwrap();
        assert!(after.is_empty());
    }
}