1use crate::prelude::*;
2use deadpool_redis::redis::AsyncCommands;
3use chrono::Utc;
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
10#[derive(thiserror::Error, Debug)]
13pub enum WsError {
14 #[error("Redis manager error: {0}")]
15 RedisManager(#[from] RedisManagerError),
16
17 #[error("Redis error: {0}")]
18 Redis(#[from] deadpool_redis::redis::RedisError),
19
20 #[error("Redis pool error: {0}")]
21 RedisPool(#[from] deadpool_redis::PoolError),
22
23 #[error("JSON error: {0}")]
24 Json(#[from] serde_json::Error),
25}
26
27pub type Result<T> = std::result::Result<T, WsError>;
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31pub struct UserMsg {
32 pub from: String, pub to: String, pub msg: String, pub time: String, }
37
38impl UserMsg {
39 pub fn new(from: String, to: String, msg: String) -> Self {
40 Self {
41 from,
42 to,
43 msg,
44 time: Utc::now().to_rfc3339(),
45 }
46 }
47}
48
49pub struct WsConfig {
50 pub server: u32,
51 pub redis: RedisManager,
52}
53
54#[derive(Clone)]
55pub struct WsManager {
56 pub server: u32,
57 pub redis: RedisManager,
58 pub local_sessions: Arc<DashMap<String, mpsc::Sender<String>>>,
60}
61
62impl WsManager {
63 pub fn new(config: WsConfig) -> Self {
65 Self {
66 server: config.server,
67 redis: config.redis,
68 local_sessions: Arc::new(DashMap::new()),
69 }
70 }
71
72 pub async fn register(&self, uid: &str, tx: mpsc::Sender<String>) -> Result<()> {
74 let mut conn = self.redis.get_connection().await?;
75
76 conn.set::<_, _, ()>(format!("user:{}", uid), self.server.to_string()).await?;
78
79 self.local_sessions.insert(uid.to_string(), tx);
80 Ok(())
81 }
82
83 pub async fn join_room(&self, room_name: &str, uid: &str) -> Result<()> {
85 let mut conn = self.redis.get_connection().await?;
86 conn.sadd::<_, _, ()>(format!("room:{}", room_name), uid).await?;
88 Ok(())
89 }
90
91 pub async fn leave_room(&self, room_name: &str, uid: &str) -> Result<()> {
93 let mut conn = self.redis.get_connection().await?;
94 conn.srem::<_, _, ()>(format!("room:{}", room_name), uid).await?;
95
96 Ok(())
97 }
98
99 pub async fn msg_room(&self, room_name: &str, msg_obj: UserMsg) -> Result<()> {
101 let mut conn = self.redis.get_connection().await?;
102 let msg_str = serde_json::to_string(&msg_obj)?;
103
104 conn.rpush::<_, _, ()>(format!("room_msgs:{}", room_name), &msg_str).await?;
106
107 let users: Vec<String> = conn.smembers(format!("room:{}", room_name)).await?;
109
110 for uid in users {
112 let _ = self.msg_user(&uid, msg_str.clone()).await;
114 }
115 Ok(())
116 }
117
118 pub async fn msg_user(&self, uid: &str, msg: String) -> Result<bool> {
120 let mut conn = self.redis.get_connection().await?;
121
122 let user_data: Option<String> = conn.get(format!("user:{}", uid)).await?;
124
125 if let Some(data) = user_data {
126 let server_id = data.parse::<u32>().unwrap_or(0);
127
128 if server_id == self.server {
129 if let Some(sender) = self.local_sessions.get(uid) {
131 let _ = sender.send(msg).await;
132 }
133 } else {
134 let payload = json!({
137 "target_uid": uid,
138 "msg": msg
139 }).to_string();
140
141 conn.publish::<_, _, ()>("fr-ws", payload).await?;
143 }
144 } else {
145 return Ok(false);
146 }
147
148 Ok(true)
149 }
150
151 pub async fn drop_user(&self, uid: &str) -> Result<()> {
153 let mut conn = self.redis.get_connection().await?;
154 conn.del::<_, ()>(format!("user:{}", uid)).await?;
156 self.local_sessions.remove(uid);
157 Ok(())
158 }
159
160 pub async fn drop_room(&self, room_name: &str) -> Result<()> {
162 let mut conn = self.redis.get_connection().await?;
163 conn.del::<_, ()>(format!("room:{}", room_name)).await?;
165 conn.del::<_, ()>(format!("room_msgs:{}", room_name)).await?;
166 Ok(())
167 }
168
169 pub async fn broadcast(&self, msg: String) -> Result<()> {
171 let mut conn = self.redis.get_connection().await?;
172
173 conn.publish::<_, _, ()>("fr-ws-broadcast", &msg).await?;
176
177 for entry in self.local_sessions.iter() {
179 let _ = entry.value().send(msg.clone()).await;
180 }
181 Ok(())
182 }
183
184 pub async fn get_room_msgs(&self, room_name: &str) -> Result<Vec<UserMsg>> {
186 let mut conn = self.redis.get_connection().await?;
187 let msgs_str: Vec<String> = conn.lrange(format!("room_msgs:{}", room_name), 0, -1).await?;
188
189 let mut msgs = Vec::new();
190 for m in msgs_str {
191 if let Ok(parsed) = serde_json::from_str(&m) {
192 msgs.push(parsed);
193 }
194 }
195 Ok(msgs)
196 }
197}