Skip to main content

actix_web_socket_io/
session.rs

1use actix::prelude::*;
2
3use actix_web_actors::ws::{self, Message};
4use serde::{Deserialize, Serialize};
5use std::{collections::HashMap, sync::Arc, time::Duration};
6use tokio::sync::{
7    broadcast::{self, Receiver, Sender},
8    RwLock,
9};
10use uuid::Uuid;
11
12use crate::{
13    socketio::{
14        ConnectSuccess, EngineIOPacketType, EventData, MessageType, OpenPacket, SocketIOPacketType,
15    },
16    SocketConfig,
17};
18
19/// 会话,每创建一个连接,生成一个会话
20pub struct Session {
21    pub id: Uuid,
22    session_store: Arc<RwLock<SessionStore>>,
23    sender: Sender<MessageType>,
24    pub heartbeat: bool,
25    socket_config: Arc<SocketConfig>,
26}
27
28impl Session {
29    pub fn new(socket_config: Arc<SocketConfig>, session_store: Arc<RwLock<SessionStore>>) -> Self {
30        let (sender, _) = broadcast::channel::<MessageType>(1024);
31        Self {
32            id: Uuid::new_v4(),
33            session_store,
34            sender,
35            heartbeat: true,
36            socket_config,
37        }
38    }
39
40    /// 注册消息处理逻辑
41    pub fn get_receiver(&self) -> Receiver<MessageType> {
42        self.sender.subscribe()
43    }
44}
45
46impl Actor for Session {
47    type Context = ws::WebsocketContext<Self>;
48
49    /// 会话创建后
50    fn started(&mut self, ctx: &mut Self::Context) {
51        actix_web::rt::spawn({
52            let session_store = self.session_store.clone();
53            let id = self.id;
54            let address = ctx.address();
55            async move {
56                session_store.write().await.sessions.insert(id, address);
57            }
58        });
59
60        // 回应 engine.io
61        let ping_interval = self.socket_config.ping_interval;
62        let ping_timeout = self.socket_config.ping_timeout;
63        ctx.address().do_send(OpenPacket {
64            sid: self.id.to_string(),
65            upgrades: vec![],
66            ping_interval,
67            ping_timeout,
68            max_payload: self.socket_config.max_payload,
69        });
70
71        // 心跳
72        ctx.run_interval(
73            Duration::from_millis(ping_interval.into()),
74            move |session, ctx| {
75                // 发送 Ping
76                ctx.text((EngineIOPacketType::Ping as u8).to_string());
77                session.heartbeat = false;
78
79                ctx.run_later(
80                    Duration::from_millis(ping_timeout.into()),
81                    |session, ctx| {
82                        // 没有收到心跳回应,断开连接
83                        if !session.heartbeat {
84                            ctx.close(None);
85                        }
86                    },
87                );
88            },
89        );
90    }
91
92    /// 会话将要断开时
93    fn stopping(&mut self, _ctx: &mut Self::Context) -> Running {
94        let _ = self.sender.send(MessageType::Event(EventData(
95            "disconnect".to_string(),
96            serde_json::Value::Null,
97        )));
98
99        actix_web::rt::spawn({
100            let session_store = self.session_store.clone();
101            let id = self.id;
102            async move {
103                session_store.write().await.sessions.remove(&id);
104            }
105        });
106        Running::Stop
107    }
108}
109
110impl<T: Serialize> Handler<ConnectSuccess<T>> for Session {
111    type Result = Result<(), &'static str>;
112    fn handle(&mut self, msg: ConnectSuccess<T>, ctx: &mut Self::Context) -> Self::Result {
113        let Ok(json_str) = serde_json::to_string(&msg.data) else {
114            return Err("json 序列化失败");
115        };
116        ctx.text(format!(
117            "{}{}{}",
118            EngineIOPacketType::Message as u8,
119            SocketIOPacketType::Connect as u8,
120            json_str
121        ));
122
123        Ok(())
124    }
125}
126
127impl<T: Serialize> Handler<Arc<Emiter<T>>> for Session {
128    type Result = Result<(), &'static str>;
129    fn handle(&mut self, msg: Arc<Emiter<T>>, ctx: &mut Self::Context) -> Self::Result {
130        let Ok(json_str) = serde_json::to_string(&msg.data) else {
131            return Err("json 序列化失败");
132        };
133        ctx.text(format!(
134            "{}{}[\"{}\",{}]",
135            EngineIOPacketType::Message as u8,
136            SocketIOPacketType::Event as u8,
137            msg.event_name,
138            json_str
139        ));
140
141        Ok(())
142    }
143}
144
145/// 建立连接回应给客户端处理
146impl Handler<ConnectPacket> for Session {
147    type Result = Result<(), &'static str>;
148    fn handle(&mut self, msg: ConnectPacket, ctx: &mut Self::Context) -> Self::Result {
149        let Ok(json_str) = serde_json::to_string(&msg.data) else {
150            return Err("json 序列化失败");
151        };
152        ctx.text(format!("{}{}", msg.r#type as u8, json_str));
153
154        Ok(())
155    }
156}
157
158impl Handler<OpenPacket> for Session {
159    type Result = Result<(), &'static str>;
160    fn handle(&mut self, msg: OpenPacket, ctx: &mut Self::Context) -> Self::Result {
161        let Ok(json_str) = serde_json::to_string(&msg) else {
162            return Err("json 序列化失败");
163        };
164
165        ctx.text(format!("{}{}", EngineIOPacketType::Open as u8, json_str));
166
167        Ok(())
168    }
169}
170
171impl<T: Serialize> Handler<AuthSuccess<T>> for Session {
172    type Result = Result<(), &'static str>;
173    fn handle(&mut self, msg: AuthSuccess<T>, ctx: &mut Self::Context) -> Self::Result {
174        let Ok(json_str) = serde_json::to_string(&msg) else {
175            return Err("json 序列化失败");
176        };
177
178        ctx.text(format!(
179            "{}{}{}",
180            EngineIOPacketType::Message as u8,
181            SocketIOPacketType::Connect as u8,
182            json_str
183        ));
184
185        Ok(())
186    }
187}
188
189impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for Session {
190    /// 收到消息后的处理
191    fn handle(&mut self, item: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
192        // 提取消息
193        let msg = match item {
194            Err(_) => {
195                ctx.stop();
196                return;
197            }
198            Ok(msg) => msg,
199        };
200        match msg {
201            // 收到文本消息
202            Message::Text(byte_string) => {
203                let raw = byte_string.to_string();
204                let data_str = raw.get(2..);
205
206                let eg_type = raw
207                    .get(0..1)
208                    .and_then(|f| f.parse::<u8>().ok())
209                    .and_then(|f| EngineIOPacketType::try_from(f).ok());
210
211                let sc_type = raw
212                    .get(1..2)
213                    .and_then(|f| f.parse::<u8>().ok())
214                    .and_then(|f| SocketIOPacketType::try_from(f).ok());
215
216                if let Some(eg_type) = eg_type {
217                    match eg_type {
218                        EngineIOPacketType::Open => (),
219                        EngineIOPacketType::Close => (),
220                        EngineIOPacketType::Ping => (),
221                        EngineIOPacketType::Pong => {
222                            // 客户端心跳上报
223                            self.heartbeat = true;
224                        }
225                        EngineIOPacketType::Message => {
226                            if let Some(sc_type) = sc_type {
227                                if let Some(data_str) = data_str {
228                                    let sended = self.sender.send(match sc_type {
229                                        SocketIOPacketType::Connect => MessageType::Connect,
230                                        SocketIOPacketType::Disconnect => MessageType::None,
231                                        SocketIOPacketType::Event => {
232                                            serde_json::from_str::<EventData>(data_str)
233                                                .map_or(MessageType::None, |event| {
234                                                    MessageType::Event(event)
235                                                })
236                                        }
237                                        SocketIOPacketType::Ack => MessageType::None,
238                                        SocketIOPacketType::ConnectError => MessageType::None,
239                                        SocketIOPacketType::BinaryEvent => MessageType::None,
240                                        SocketIOPacketType::BinaryAck => MessageType::None,
241                                    });
242
243                                    if sended.is_err() {
244                                        log::error!("socket-io 发送数据失败{sended:?}");
245                                    }
246                                }
247                            }
248                        }
249                        EngineIOPacketType::Upgrade => (),
250                        EngineIOPacketType::Noop => (),
251                    }
252                }
253            }
254            // 收到二进制消息
255            Message::Binary(_bytes) => {
256                // data_binary = bytes;
257            }
258            _ => {}
259        }
260    }
261}
262
263/// 建立连接 header 头
264#[derive(Serialize, Deserialize, Clone)]
265struct Header {
266    sid: Option<String>,
267    token: Option<String>,
268}
269
270/// 建立连接结构体
271#[derive(Message)]
272#[rtype(result = "Result<(), &'static str>")]
273pub struct ConnectPacket {
274    r#type: SocketIOPacketType,
275    data: Header,
276}
277
278/// 鉴权响应数据
279#[derive(Message, Serialize)]
280#[rtype(result = "Result<(), &'static str>")]
281pub struct AuthSuccess<T: Serialize> {
282    pub data: T,
283}
284
285/// 发送客户端
286#[derive(Message)]
287#[rtype(result = "Result<(), &'static str>")]
288pub struct Emiter<T: Serialize> {
289    pub event_name: String,
290    pub data: T,
291}
292
293/// 存储所有客户端会话的 store
294pub struct SessionStore {
295    // 存储的客户端会话
296    pub sessions: HashMap<Uuid, Addr<Session>>,
297}
298impl SessionStore {
299    pub fn new() -> Self {
300        Self {
301            sessions: HashMap::new(),
302        }
303    }
304}