use crate::error::SaTokenError;
use crate::manager::SaTokenManager;
use crate::token::TokenValue;
use crate::event::SaTokenEvent;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
/// Authentication details describing one authenticated WebSocket connection.
///
/// In this file, values of this type are built by `WsAuthManager::authenticate`.
#[derive(Debug, Clone)]
pub struct WsAuthInfo {
/// Login id copied from the token info returned by the token manager.
pub login_id: String,
/// The token string exactly as returned by the configured token extractor.
pub token: String,
/// Per-connection identifier; `authenticate` formats it as `ws:{login_id}:{uuid-v4}`.
pub session_id: String,
/// UTC timestamp; `authenticate` sets it to the current time when it builds this value.
pub connect_time: chrono::DateTime<chrono::Utc>,
/// Free-form key/value data; `authenticate` initializes it to an empty map.
pub metadata: HashMap<String, String>,
}
/// Strategy for locating the authentication token of an incoming WebSocket request.
///
/// Implementations must be `Send + Sync`; `WsAuthManager` stores the extractor
/// behind an `Arc<dyn WsTokenExtractor>`.
#[async_trait]
pub trait WsTokenExtractor: Send + Sync {
/// Returns the token found in `headers` or `query`, or `None` when no token is present.
///
/// `WsAuthManager::authenticate` maps a `None` result to `SaTokenError::NotLogin`.
async fn extract_token(&self, headers: &HashMap<String, String>, query: &HashMap<String, String>) -> Option<String>;
}
/// Default [`WsTokenExtractor`] that reads the token from the request headers or query string.
///
/// Sources are tried in this order and the first one present wins:
/// 1. the `Authorization` header (a leading `Bearer ` is removed, exactly once),
/// 2. the `Sec-WebSocket-Protocol` header (value used verbatim),
/// 3. the `token` query parameter (value used verbatim).
pub struct DefaultWsTokenExtractor;

/// Looks up a header value by name, tolerating differences in ASCII case.
///
/// An exact key match is preferred so existing callers see unchanged results. The
/// case-insensitive scan is a fallback because HTTP field names are case-insensitive
/// and some HTTP stacks hand over lower-cased names. If several keys differ only by
/// case and none matches exactly, which one is returned is unspecified.
fn header_ignore_case<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    headers
        .get(name)
        .or_else(|| {
            headers
                .iter()
                .find(|(key, _)| key.eq_ignore_ascii_case(name))
                .map(|(_, value)| value)
        })
        .map(String::as_str)
}

#[async_trait]
impl WsTokenExtractor for DefaultWsTokenExtractor {
    /// Extracts the token from `headers` or `query`, returning `None` when none of the
    /// supported sources is present.
    async fn extract_token(
        &self,
        headers: &HashMap<String, String>,
        query: &HashMap<String, String>,
    ) -> Option<String> {
        if let Some(value) = header_ignore_case(headers, "Authorization") {
            // `strip_prefix` removes the scheme once; `trim_start_matches` would keep
            // stripping and could eat the start of a value that repeats the prefix.
            // A value without the prefix is passed through unchanged.
            let token = value.strip_prefix("Bearer ").unwrap_or(value);
            return Some(token.to_string());
        }

        if let Some(value) = header_ignore_case(headers, "Sec-WebSocket-Protocol") {
            return Some(value.to_string());
        }

        // Query parameter names are case-sensitive, so this stays an exact lookup.
        query.get("token").cloned()
    }
}
/// Authenticates WebSocket requests by combining a token extractor with the token manager.
pub struct WsAuthManager {
/// Token manager used to look up token info and to reach the event bus.
manager: Arc<SaTokenManager>,
/// Strategy used to pull the token out of the request headers / query parameters.
extractor: Arc<dyn WsTokenExtractor>,
}
impl WsAuthManager {
    /// Creates a manager that extracts tokens with [`DefaultWsTokenExtractor`].
    pub fn new(manager: Arc<SaTokenManager>) -> Self {
        Self::with_extractor(manager, Arc::new(DefaultWsTokenExtractor))
    }

    /// Creates a manager that extracts tokens with the supplied `extractor`.
    pub fn with_extractor(manager: Arc<SaTokenManager>, extractor: Arc<dyn WsTokenExtractor>) -> Self {
        Self { manager, extractor }
    }

    /// Authenticates a WebSocket request described by its `headers` and `query` parameters.
    ///
    /// On success a login event tagged with login type `"websocket"` is published on the
    /// manager's event bus, and a [`WsAuthInfo`] with a new `ws:{login_id}:{uuid}`
    /// session id is returned.
    ///
    /// # Errors
    ///
    /// * [`SaTokenError::NotLogin`] when the extractor finds no token.
    /// * Any error returned by the manager's token lookup.
    /// * [`SaTokenError::TokenExpired`] when the token's expiry time is in the past.
    pub async fn authenticate(
        &self,
        headers: &HashMap<String, String>,
        query: &HashMap<String, String>,
    ) -> Result<WsAuthInfo, SaTokenError> {
        let Some(raw_token) = self.extractor.extract_token(headers, query).await else {
            return Err(SaTokenError::NotLogin);
        };

        // Lookup and expiry validation are shared with `verify_token`.
        let login_id = self.verify_token(&raw_token).await?;

        let ws_session = format!("ws:{}:{}", login_id, uuid::Uuid::new_v4());
        let connected_at = chrono::Utc::now();

        let login_event = SaTokenEvent::login(login_id.clone(), &raw_token)
            .with_login_type("websocket");
        self.manager.event_bus().publish(login_event).await;

        Ok(WsAuthInfo {
            login_id,
            token: raw_token,
            session_id: ws_session,
            connect_time: connected_at,
            metadata: HashMap::new(),
        })
    }

    /// Validates `token` and returns the login id it belongs to.
    ///
    /// # Errors
    ///
    /// * Any error returned by the manager's token lookup.
    /// * [`SaTokenError::TokenExpired`] when the token has an expiry time that is
    ///   earlier than the current UTC time.
    pub async fn verify_token(&self, token: &str) -> Result<String, SaTokenError> {
        let lookup_key = TokenValue::new(token.to_string());
        let info = self.manager.get_token_info(&lookup_key).await?;

        // A token without an expiry time is never treated as expired here.
        let expired = info
            .expire_time
            .is_some_and(|deadline| chrono::Utc::now() > deadline);
        if expired {
            return Err(SaTokenError::TokenExpired);
        }

        Ok(info.login_id)
    }

    /// Re-validates the token stored in `auth_info`.
    ///
    /// This method only calls [`Self::verify_token`] and discards the login id;
    /// it does not itself modify `auth_info`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::verify_token`].
    pub async fn refresh_ws_session(&self, auth_info: &WsAuthInfo) -> Result<(), SaTokenError> {
        self.verify_token(&auth_info.token).await.map(|_| ())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SaTokenConfig;
    use sa_token_storage_memory::MemoryStorage;

    /// Builds a token manager on in-memory storage with the default config, together
    /// with a `WsAuthManager` that shares it.
    fn setup() -> (Arc<SaTokenManager>, WsAuthManager) {
        let config = SaTokenConfig::default();
        let storage = Arc::new(MemoryStorage::new());
        let core = Arc::new(SaTokenManager::new(storage, config));
        let ws = WsAuthManager::new(Arc::clone(&core));
        (core, ws)
    }

    /// A token sent as `Authorization: Bearer <token>` authenticates its owner.
    #[tokio::test]
    async fn test_ws_auth_manager() {
        let (core, ws) = setup();
        let issued = core.login("user123").await.unwrap();

        let headers = HashMap::from([(
            "Authorization".to_string(),
            format!("Bearer {}", issued.as_str()),
        )]);

        let info = ws.authenticate(&headers, &HashMap::new()).await.unwrap();
        assert_eq!(info.login_id, "user123");
    }

    /// A token sent as the `token` query parameter authenticates its owner.
    #[tokio::test]
    async fn test_token_extraction_from_query() {
        let (core, ws) = setup();
        let issued = core.login("user456").await.unwrap();

        let query = HashMap::from([("token".to_string(), issued.as_str().to_string())]);

        let info = ws.authenticate(&HashMap::new(), &query).await.unwrap();
        assert_eq!(info.login_id, "user456");
    }

    /// `verify_token` returns the login id of a freshly issued token.
    #[tokio::test]
    async fn test_verify_token() {
        let (core, ws) = setup();
        let issued = core.login("user789").await.unwrap();

        let owner = ws.verify_token(issued.as_str()).await.unwrap();
        assert_eq!(owner, "user789");
    }
}