use crate::error::SaTokenError;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::{DateTime, Utc};
/// One live session of a logged-in user.
///
/// [`OnlineManager`] stores these grouped by `login_id`; several sessions may
/// exist for the same login, told apart by `token`.
#[derive(Debug, Clone)]
pub struct OnlineUser {
    /// Login identifier; used as the key under which the session is stored.
    pub login_id: String,
    /// Token of this session; used to select the session when it is marked
    /// offline or when its activity timestamp is refreshed.
    pub token: String,
    /// Device label supplied by the caller; not interpreted in this file.
    pub device: String,
    /// Connection timestamp supplied by the caller; never modified in this file.
    pub connect_time: DateTime<Utc>,
    /// Timestamp of the latest activity; overwritten with the current UTC time
    /// by `OnlineManager::update_activity`.
    pub last_activity: DateTime<Utc>,
    /// Free-form key/value pairs supplied by the caller; not read in this file.
    pub metadata: HashMap<String, String>,
}
/// A message handed to every registered [`MessagePusher`].
#[derive(Debug, Clone)]
pub struct PushMessage {
    /// Identifier of the message; the manager's convenience methods fill it
    /// with a freshly generated v4 UUID string.
    pub message_id: String,
    /// Message payload as text.
    pub content: String,
    /// Category of the message, see [`MessageType`].
    pub message_type: MessageType,
    /// Creation time of the message (UTC).
    pub timestamp: DateTime<Utc>,
    /// Free-form key/value pairs; left empty by the manager's convenience
    /// methods.
    pub metadata: HashMap<String, String>,
}
/// Category attached to a pushed message.
///
/// Equality is total (every payload is a `String` or absent), so the type
/// derives `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Plain text message; used by the manager's content-based push methods.
    Text,
    /// Binary message; not produced anywhere in this file.
    Binary,
    /// Notice that the user is being kicked out; used by
    /// `OnlineManager::kick_out_notify`.
    KickOut,
    /// Generic notification; not produced anywhere in this file.
    Notification,
    /// Caller-defined category identified by an arbitrary string.
    Custom(String),
}
/// Delivery backend for [`PushMessage`]s.
///
/// Implementations must be `Send + Sync` because they are stored behind
/// `Arc<dyn MessagePusher>` and shared by the manager.
#[async_trait]
pub trait MessagePusher: Send + Sync {
    /// Delivers `message` to the user identified by `login_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`SaTokenError`] when the implementation cannot deliver the
    /// message; the manager propagates that error to its caller.
    async fn push(&self, login_id: &str, message: PushMessage) -> Result<(), SaTokenError>;
}
/// Tracks which users are online and fans messages out to registered pushers.
pub struct OnlineManager {
    /// Live sessions grouped by login id. An entry is removed as soon as its
    /// session list becomes empty, so key presence means "online".
    online_users: Arc<RwLock<HashMap<String, Vec<OnlineUser>>>>,
    /// Delivery backends, in registration order.
    pushers: Arc<RwLock<Vec<Arc<dyn MessagePusher>>>>,
}
impl OnlineManager {
pub fn new() -> Self {
Self {
online_users: Arc::new(RwLock::new(HashMap::new())),
pushers: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn register_pusher(&self, pusher: Arc<dyn MessagePusher>) {
let mut pushers = self.pushers.write().await;
pushers.push(pusher);
}
pub async fn mark_online(&self, user: OnlineUser) {
let mut users = self.online_users.write().await;
users.entry(user.login_id.clone())
.or_insert_with(Vec::new)
.push(user);
}
pub async fn mark_offline(&self, login_id: &str, token: &str) {
let mut users = self.online_users.write().await;
if let Some(user_sessions) = users.get_mut(login_id) {
user_sessions.retain(|u| u.token != token);
if user_sessions.is_empty() {
users.remove(login_id);
}
}
}
pub async fn mark_offline_all(&self, login_id: &str) {
let mut users = self.online_users.write().await;
users.remove(login_id);
}
pub async fn is_online(&self, login_id: &str) -> bool {
let users = self.online_users.read().await;
users.contains_key(login_id)
}
pub async fn get_online_count(&self) -> usize {
let users = self.online_users.read().await;
users.len()
}
pub async fn get_online_users(&self) -> Vec<String> {
let users = self.online_users.read().await;
users.keys().cloned().collect()
}
pub async fn get_user_sessions(&self, login_id: &str) -> Vec<OnlineUser> {
let users = self.online_users.read().await;
users.get(login_id).cloned().unwrap_or_default()
}
pub async fn update_activity(&self, login_id: &str, token: &str) {
let mut users = self.online_users.write().await;
if let Some(user_sessions) = users.get_mut(login_id) {
for user in user_sessions.iter_mut() {
if user.token == token {
user.last_activity = Utc::now();
break;
}
}
}
}
pub async fn push_to_user(&self, login_id: &str, content: String) -> Result<(), SaTokenError> {
let message = PushMessage {
message_id: uuid::Uuid::new_v4().to_string(),
content,
message_type: MessageType::Text,
timestamp: Utc::now(),
metadata: HashMap::new(),
};
let pushers = self.pushers.read().await;
for pusher in pushers.iter() {
pusher.push(login_id, message.clone()).await?;
}
Ok(())
}
pub async fn push_to_users(&self, login_ids: Vec<String>, content: String) -> Result<(), SaTokenError> {
for login_id in login_ids {
self.push_to_user(&login_id, content.clone()).await?;
}
Ok(())
}
pub async fn broadcast(&self, content: String) -> Result<(), SaTokenError> {
let login_ids = self.get_online_users().await;
self.push_to_users(login_ids, content).await
}
pub async fn push_message_to_user(&self, login_id: &str, message: PushMessage) -> Result<(), SaTokenError> {
let pushers = self.pushers.read().await;
for pusher in pushers.iter() {
pusher.push(login_id, message.clone()).await?;
}
Ok(())
}
pub async fn kick_out_notify(&self, login_id: &str, reason: String) -> Result<(), SaTokenError> {
let message = PushMessage {
message_id: uuid::Uuid::new_v4().to_string(),
content: reason,
message_type: MessageType::KickOut,
timestamp: Utc::now(),
metadata: HashMap::new(),
};
self.push_message_to_user(login_id, message).await?;
self.mark_offline_all(login_id).await;
Ok(())
}
}
impl Default for OnlineManager {
fn default() -> Self {
Self::new()
}
}
/// [`MessagePusher`] that keeps every pushed message in memory, grouped by
/// the login id it was pushed to.
pub struct InMemoryPusher {
    /// Pushed messages per login id, in push order.
    messages: Arc<RwLock<HashMap<String, Vec<PushMessage>>>>,
}
impl InMemoryPusher {
    /// Creates a pusher that holds no messages.
    pub fn new() -> Self {
        let messages = Arc::new(RwLock::new(HashMap::new()));
        Self { messages }
    }

    /// Returns a copy of the messages stored for `login_id`, in push order,
    /// or an empty vector when none are stored.
    pub async fn get_messages(&self, login_id: &str) -> Vec<PushMessage> {
        let store = self.messages.read().await;
        match store.get(login_id) {
            Some(stored) => stored.clone(),
            None => Vec::new(),
        }
    }

    /// Discards every message stored for `login_id`; unknown ids are ignored.
    pub async fn clear_messages(&self, login_id: &str) {
        let mut store = self.messages.write().await;
        store.remove(login_id);
    }
}
impl Default for InMemoryPusher {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl MessagePusher for InMemoryPusher {
    /// Appends `message` to the list stored for `login_id`, creating the list
    /// on first use. Always returns `Ok(())`.
    async fn push(&self, login_id: &str, message: PushMessage) -> Result<(), SaTokenError> {
        let mut store = self.messages.write().await;
        store.entry(login_id.to_owned()).or_default().push(message);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session fixture with empty metadata and "now" timestamps.
    fn session(login_id: &str, token: &str, device: &str) -> OnlineUser {
        OnlineUser {
            login_id: login_id.to_string(),
            token: token.to_string(),
            device: device.to_string(),
            connect_time: Utc::now(),
            last_activity: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Builds a manager with a single in-memory pusher registered and returns
    /// both, so tests can inspect what was delivered.
    async fn manager_with_pusher() -> (OnlineManager, Arc<InMemoryPusher>) {
        let manager = OnlineManager::new();
        let pusher = Arc::new(InMemoryPusher::new());
        manager.register_pusher(pusher.clone()).await;
        (manager, pusher)
    }

    #[tokio::test]
    async fn test_online_manager() {
        let manager = OnlineManager::new();

        manager.mark_online(session("user1", "token1", "web")).await;

        assert!(manager.is_online("user1").await);
        assert_eq!(manager.get_online_count().await, 1);
    }

    #[tokio::test]
    async fn test_mark_offline() {
        let manager = OnlineManager::new();

        manager.mark_online(session("user2", "token2", "mobile")).await;
        assert!(manager.is_online("user2").await);

        manager.mark_offline("user2", "token2").await;
        assert!(!manager.is_online("user2").await);
    }

    #[tokio::test]
    async fn test_push_message() {
        let (manager, pusher) = manager_with_pusher().await;
        manager.mark_online(session("user3", "token3", "web")).await;

        manager.push_to_user("user3", "Hello".to_string()).await.unwrap();

        let delivered = pusher.get_messages("user3").await;
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].content, "Hello");
    }

    #[tokio::test]
    async fn test_broadcast() {
        let (manager, pusher) = manager_with_pusher().await;
        for i in 1..=3 {
            let login_id = format!("user{}", i);
            let token = format!("token{}", i);
            manager.mark_online(session(&login_id, &token, "web")).await;
        }

        manager.broadcast("Broadcast message".to_string()).await.unwrap();

        for i in 1..=3 {
            let delivered = pusher.get_messages(&format!("user{}", i)).await;
            assert_eq!(delivered.len(), 1);
        }
    }

    #[tokio::test]
    async fn test_kick_out_notify() {
        let (manager, pusher) = manager_with_pusher().await;
        manager.mark_online(session("user4", "token4", "web")).await;
        assert!(manager.is_online("user4").await);

        manager.kick_out_notify("user4", "Kicked out".to_string()).await.unwrap();

        assert!(!manager.is_online("user4").await);
        let delivered = pusher.get_messages("user4").await;
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].message_type, MessageType::KickOut);
    }
}