// actix_web_socket_io/session.rs

1use actix::prelude::*;
2
3use actix_web_actors::ws::{self, Message};
4use crossbeam::channel::{self, Receiver, Sender};
5use serde::{Deserialize, Serialize};
6use std::{collections::HashMap, sync::Arc, time::Duration};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use crate::{
11    socketio::{EngineIOPacketType, EventData, MessageType, OpenPacket, SocketIOPacketType},
12    SocketConfig,
13};
14
15/// 会话,每创建一个连接,生成一个会话
16pub struct Session {
17    pub id: Uuid,
18    session_store: Arc<RwLock<SessionStore>>,
19    sender: Sender<MessageType>,
20    receiver: Receiver<MessageType>,
21    pub heartbeat: bool,
22    socket_config: Arc<SocketConfig>,
23}
24
25impl Session {
26    pub fn new(socket_config: Arc<SocketConfig>, session_store: Arc<RwLock<SessionStore>>) -> Self {
27        let (sender, receiver) = channel::unbounded::<MessageType>();
28        Self {
29            id: Uuid::new_v4(),
30            session_store,
31            sender,
32            receiver,
33            heartbeat: true,
34            socket_config,
35        }
36    }
37
38    /// 注册消息处理逻辑
39    pub fn get_receiver(&self) -> Receiver<MessageType> {
40        self.receiver.clone()
41    }
42}
43
44impl Actor for Session {
45    type Context = ws::WebsocketContext<Self>;
46
47    /// 会话创建后
48    fn started(&mut self, ctx: &mut Self::Context) {
49        actix_web::rt::spawn({
50            let session_store = self.session_store.clone();
51            let id = self.id;
52            let address = ctx.address();
53            async move {
54                session_store.write().await.sessions.insert(id, address);
55            }
56        });
57
58        // 回应 engine.io
59        let ping_interval = self.socket_config.ping_interval;
60        let ping_timeout = self.socket_config.ping_timeout;
61        ctx.address().do_send(OpenPacket {
62            sid: self.id.to_string(),
63            upgrades: vec![],
64            ping_interval,
65            ping_timeout,
66            max_payload: self.socket_config.max_payload,
67        });
68
69        // 心跳
70        ctx.run_interval(
71            Duration::from_millis(ping_interval.into()),
72            move |session, ctx| {
73                // 发送 Ping
74                ctx.text((EngineIOPacketType::Ping as u8).to_string());
75                session.heartbeat = false;
76
77                ctx.run_later(
78                    Duration::from_millis(ping_timeout.into()),
79                    |session, ctx| {
80                        // 没有收到心跳回应,断开连接
81                        if !session.heartbeat {
82                            ctx.close(None);
83                        }
84                    },
85                );
86            },
87        );
88    }
89
90    /// 会话将要断开时
91    fn stopping(&mut self, _ctx: &mut Self::Context) -> Running {
92        let _ = self.sender.send(MessageType::Event(EventData(
93            "disconnect".to_string(),
94            serde_json::Value::Null,
95        )));
96
97        actix_web::rt::spawn({
98            let session_store = self.session_store.clone();
99            let id = self.id;
100            async move {
101                session_store.write().await.sessions.remove(&id);
102            }
103        });
104        Running::Stop
105    }
106}
107
108impl<T: Serialize> Handler<Arc<Emiter<T>>> for Session {
109    type Result = Result<(), &'static str>;
110    fn handle(&mut self, msg: Arc<Emiter<T>>, ctx: &mut Self::Context) -> Self::Result {
111        let Ok(json_str) = serde_json::to_string(&msg.data) else {
112            return Err("json 序列化失败");
113        };
114        ctx.text(format!(
115            "{}{}[\"{}\",{}]",
116            EngineIOPacketType::Message as u8,
117            SocketIOPacketType::Event as u8,
118            msg.event_name,
119            json_str
120        ));
121
122        Ok(())
123    }
124}
125
126/// 建立连接回应给客户端处理
127impl Handler<ConnectPacket> for Session {
128    type Result = Result<(), &'static str>;
129    fn handle(&mut self, msg: ConnectPacket, ctx: &mut Self::Context) -> Self::Result {
130        let Ok(json_str) = serde_json::to_string(&msg.data) else {
131            return Err("json 序列化失败");
132        };
133        ctx.text(format!("{}{}", msg.r#type as u8, json_str));
134
135        Ok(())
136    }
137}
138
139impl Handler<OpenPacket> for Session {
140    type Result = Result<(), &'static str>;
141    fn handle(&mut self, msg: OpenPacket, ctx: &mut Self::Context) -> Self::Result {
142        let Ok(json_str) = serde_json::to_string(&msg) else {
143            return Err("json 序列化失败");
144        };
145
146        ctx.text(format!("{}{}", EngineIOPacketType::Open as u8, json_str));
147
148        Ok(())
149    }
150}
151
152impl<T: Serialize> Handler<AuthSuccess<T>> for Session {
153    type Result = Result<(), &'static str>;
154    fn handle(&mut self, msg: AuthSuccess<T>, ctx: &mut Self::Context) -> Self::Result {
155        let Ok(json_str) = serde_json::to_string(&msg) else {
156            return Err("json 序列化失败");
157        };
158
159        ctx.text(format!(
160            "{}{}{}",
161            EngineIOPacketType::Message as u8,
162            SocketIOPacketType::Connect as u8,
163            json_str
164        ));
165
166        Ok(())
167    }
168}
169
170impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for Session {
171    /// 收到消息后的处理
172    fn handle(&mut self, item: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
173        // 提取消息
174        let msg = match item {
175            Err(_) => {
176                ctx.stop();
177                return;
178            }
179            Ok(msg) => msg,
180        };
181        match msg {
182            // 收到文本消息
183            Message::Text(byte_string) => {
184                let raw = byte_string.to_string();
185                let mut eg_type = None;
186                let mut sc_type = None;
187                let data_str = raw.get(2..);
188
189                eg_type = raw
190                    .get(0..1)
191                    .and_then(|f| f.parse::<u8>().ok())
192                    .and_then(|f| EngineIOPacketType::try_from(f).ok());
193
194                sc_type = raw
195                    .get(1..2)
196                    .and_then(|f| f.parse::<u8>().ok())
197                    .and_then(|f| SocketIOPacketType::try_from(f).ok());
198
199                if let Some(eg_type) = eg_type {
200                    match eg_type {
201                        EngineIOPacketType::Open => (),
202                        EngineIOPacketType::Close => (),
203                        EngineIOPacketType::Ping => (),
204                        EngineIOPacketType::Pong => {
205                            // 客户端心跳上报
206                            self.heartbeat = true;
207                        }
208                        EngineIOPacketType::Message => {
209                            if let Some(sc_type) = sc_type {
210                                if let Some(data_str) = data_str {
211                                    let sended = self.sender.send(match sc_type {
212                                        SocketIOPacketType::Connect => MessageType::None,
213                                        SocketIOPacketType::Disconnect => MessageType::None,
214                                        SocketIOPacketType::Event => {
215                                            serde_json::from_str::<EventData>(data_str)
216                                                .map_or(MessageType::None, |event| {
217                                                    MessageType::Event(event)
218                                                })
219                                        }
220                                        SocketIOPacketType::Ack => MessageType::None,
221                                        SocketIOPacketType::ConnectError => MessageType::None,
222                                        SocketIOPacketType::BinaryEvent => MessageType::None,
223                                        SocketIOPacketType::BinaryAck => MessageType::None,
224                                    });
225
226                                    if sended.is_err() {
227                                        log::error!("socket-io 发送数据失败{sended:?}");
228                                    }
229                                }
230                            }
231                        }
232                        EngineIOPacketType::Upgrade => (),
233                        EngineIOPacketType::Noop => (),
234                    }
235                }
236            }
237            // 收到二进制消息
238            Message::Binary(_bytes) => {
239                // data_binary = bytes;
240            }
241            _ => {}
242        }
243    }
244}
245
246/// 建立连接 header 头
247#[derive(Serialize, Deserialize, Clone)]
248struct Header {
249    sid: Option<String>,
250    token: Option<String>,
251}
252
253/// 建立连接结构体
254#[derive(Message)]
255#[rtype(result = "Result<(), &'static str>")]
256pub struct ConnectPacket {
257    r#type: SocketIOPacketType,
258    data: Header,
259}
260
261/// 鉴权响应数据
262#[derive(Message, Serialize)]
263#[rtype(result = "Result<(), &'static str>")]
264pub struct AuthSuccess<T: Serialize> {
265    pub data: T,
266}
267
268/// 发送客户端
269#[derive(Message)]
270#[rtype(result = "Result<(), &'static str>")]
271pub struct Emiter<T: Serialize> {
272    pub event_name: String,
273    pub data: T,
274}
275
276/// 存储所有客户端会话的 store
277pub struct SessionStore {
278    // 存储的客户端会话
279    pub sessions: HashMap<Uuid, Addr<Session>>,
280}
281impl SessionStore {
282    pub fn new() -> Self {
283        Self {
284            sessions: HashMap::new(),
285        }
286    }
287}