Skip to main content

fr_rust/ws/
ws.rs

1use crate::prelude::*;
2use deadpool_redis::redis::AsyncCommands;
3use chrono::Utc;
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
10// 1. Define the custom error enum using thiserror
11
12#[derive(thiserror::Error, Debug)]
13pub enum WsError {
14    #[error("Redis manager error: {0}")]
15    RedisManager(#[from] RedisManagerError),
16
17    #[error("Redis error: {0}")]
18    Redis(#[from] deadpool_redis::redis::RedisError),
19
20    #[error("Redis pool error: {0}")]
21    RedisPool(#[from] deadpool_redis::PoolError),
22
23    #[error("JSON error: {0}")]
24    Json(#[from] serde_json::Error),
25}
26
27// 2. Create a custom Result alias to clean up the function signatures
28pub type Result<T> = std::result::Result<T, WsError>;
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31pub struct UserMsg {
32    pub from: String, // user id
33    pub to: String,   // room_name
34    pub msg: String,  // message content
35    pub time: String, // timestamp
36}
37
38impl UserMsg {
39    pub fn new(from: String, to: String, msg: String) -> Self {
40        Self {
41            from,
42            to,
43            msg,
44            time: Utc::now().to_rfc3339(),
45        }
46    }
47}
48
49pub struct WsConfig {
50    pub server: u32,
51    pub redis: RedisManager,
52}
53
54#[derive(Clone)]
55pub struct WsManager {
56    pub server: u32,
57    pub redis: RedisManager,
58    // Local state: Maps uid -> Sender channel to the actual WebSocket stream
59    pub local_sessions: Arc<DashMap<String, mpsc::Sender<String>>>,
60}
61
62impl WsManager {
63    // 1. "new" create a new web socket service
64    pub fn new(config: WsConfig) -> Self {
65        Self {
66            server: config.server,
67            redis: config.redis,
68            local_sessions: Arc::new(DashMap::new()),
69        }
70    }
71
72    // 2. "register" save new in redis: user_id: server
73    pub async fn register(&self, uid: &str, tx: mpsc::Sender<String>) -> Result<()> {
74        let mut conn = self.redis.get_connection().await?;
75
76        // Fixed: Added explicit () return type via turbofish
77        conn.set::<_, _, ()>(format!("user:{}", uid), self.server.to_string()).await?;
78
79        self.local_sessions.insert(uid.to_string(), tx);
80        Ok(())
81    }
82
83    // 3. "join_room" add new user_id to room users.
84    pub async fn join_room(&self, room_name: &str, uid: &str) -> Result<()> {
85        let mut conn = self.redis.get_connection().await?;
86        // Fixed: Added explicit () return type via turbofish
87        conn.sadd::<_, _, ()>(format!("room:{}", room_name), uid).await?;
88        Ok(())
89    }
90    
91    // 4. "leave_room" 
92    pub async fn leave_room(&self, room_name: &str, uid: &str) -> Result<()> {
93        let mut conn = self.redis.get_connection().await?;
94        conn.srem::<_, _, ()>(format!("room:{}", room_name), uid).await?;
95        
96        Ok(())
97    }
98
99    // 5. "msg_room" loop in room_users, send msg, save msg in redis
100    pub async fn msg_room(&self, room_name: &str, msg_obj: UserMsg) -> Result<()> {
101        let mut conn = self.redis.get_connection().await?;
102        let msg_str = serde_json::to_string(&msg_obj)?;
103
104        // Fixed: Added explicit () return type via turbofish
105        conn.rpush::<_, _, ()>(format!("room_msgs:{}", room_name), &msg_str).await?;
106
107        // Get all users in the room
108        let users: Vec<String> = conn.smembers(format!("room:{}", room_name)).await?;
109
110        // Send to each user
111        for uid in users {
112            // We clone msg_str so we don't consume it
113            let _ = self.msg_user(&uid, msg_str.clone()).await; 
114        }
115        Ok(())
116    }
117
118    // 6. "msg_user" take user id, check server match -> send locally OR publish
119    pub async fn msg_user(&self, uid: &str, msg: String) -> Result<bool> {
120        let mut conn = self.redis.get_connection().await?;
121        
122        // Fetch user server data from Redis
123        let user_data: Option<String> = conn.get(format!("user:{}", uid)).await?;
124        
125        if let Some(data) = user_data {
126            let server_id = data.parse::<u32>().unwrap_or(0);
127
128            if server_id == self.server {
129                // Match! User is connected to THIS server instance. Send directly.
130                if let Some(sender) = self.local_sessions.get(uid) {
131                    let _ = sender.send(msg).await;
132                }
133            } else {
134                // Doesn't match. User is on another node. Publish to Redis.
135                // We wrap it so the receiving server knows who the target is.
136                let payload = json!({
137                    "target_uid": uid,
138                    "msg": msg
139                }).to_string();
140                
141                // Fixed: Added explicit () return type via turbofish
142                conn.publish::<_, _, ()>("fr-ws", payload).await?;
143            }
144        } else {
145            return Ok(false);
146        }
147
148        Ok(true)
149    }
150
151    // 7. "drop_user" remove user from redis and local sessions
152    pub async fn drop_user(&self, uid: &str) -> Result<()> {
153        let mut conn = self.redis.get_connection().await?;
154        // Fixed: Added explicit () return type via turbofish
155        conn.del::<_, ()>(format!("user:{}", uid)).await?;
156        self.local_sessions.remove(uid);
157        Ok(())
158    }
159
160    // 8. "drop_room" remove room and messages from redis
161    pub async fn drop_room(&self, room_name: &str) -> Result<()> {
162        let mut conn = self.redis.get_connection().await?;
163        // Fixed: Added explicit () return type via turbofish for all del calls
164        conn.del::<_, ()>(format!("room:{}", room_name)).await?;
165        conn.del::<_, ()>(format!("room_msgs:{}", room_name)).await?;
166        Ok(())
167    }
168
169    // 9. "broadcast" loop in all users & send them all msg
170    pub async fn broadcast(&self, msg: String) -> Result<()> {
171        let mut conn = self.redis.get_connection().await?;
172        
173        // 1. Publish to the global broadcast channel so ALL servers get it
174        // Fixed: Added explicit () return type via turbofish
175        conn.publish::<_, _, ()>("fr-ws-broadcast", &msg).await?;
176        
177        // 2. Send to all users connected to THIS local server immediately
178        for entry in self.local_sessions.iter() {
179            let _ = entry.value().send(msg.clone()).await;
180        }
181        Ok(())
182    }
183
184    // 10. "get_room_msgs" get all msgs that exist in room_name
185    pub async fn get_room_msgs(&self, room_name: &str) -> Result<Vec<UserMsg>> {
186        let mut conn = self.redis.get_connection().await?;
187        let msgs_str: Vec<String> = conn.lrange(format!("room_msgs:{}", room_name), 0, -1).await?;
188        
189        let mut msgs = Vec::new();
190        for m in msgs_str {
191            if let Ok(parsed) = serde_json::from_str(&m) {
192                msgs.push(parsed);
193            }
194        }
195        Ok(msgs)
196    }
197}