br_web_server/
websocket.rs

1use crate::request::Request;
2use crate::response::Response;
3use crate::{Handler};
4use json::{object, JsonValue};
5use std::sync::mpsc::{channel, Receiver, Sender};
6use std::sync::{Arc, Mutex};
7use std::{io, thread};
8use std::time::Duration;
9use dashmap::DashMap;
10use log::{debug};
11
12pub static USERS: std::sync::LazyLock<DashMap<String, Websocket>> = std::sync::LazyLock::new(DashMap::new);
13pub static WS_NOTICE: std::sync::LazyLock<Mutex<Vec<NoticeMsg>>> = std::sync::LazyLock::new(|| Mutex::new(Vec::new()));
14#[derive(Debug, Clone)]
15pub struct Websocket {
16    /// 发送
17    send: Option<Sender<Message>>,
18    /// 接收
19    receive: Option<Arc<Mutex<Receiver<Message>>>>,
20    pub key: String,
21    pub user_user: String,
22    pub org_org: String,
23    version: String,
24    request: Request,
25    response: Response,
26}
27
28
29impl Websocket {
30    pub fn http(request: Request, response: Response) -> Self {
31        Self {
32            send: None,
33            receive: None,
34            request,
35            key: String::new(),
36            user_user: "".to_string(),
37            org_org: "".to_string(),
38            version: String::new(),
39            response,
40        }
41    }
42    pub fn new(request: Request, response: Response) -> Self {
43        let (send, receive) = channel();
44        Self {
45            send: Some(send),
46            receive: Some(Arc::new(Mutex::new(receive))),
47            request,
48            key: String::new(),
49            user_user: "".to_string(),
50            org_org: "".to_string(),
51            version: String::new(),
52            response,
53        }
54    }
55    /// 发送数据
56    pub fn send(&mut self, data: JsonValue) {
57        let msg = Message {
58            mode: MessageMode::Server,
59            message_type: MessageType::Text,
60            payload: data.to_string().into_bytes(),
61            text: data.to_string(),
62            close: CloseCode::None,
63            error: ErrorCode::None,
64        };
65        match self.send.clone().unwrap().send(msg) {
66            Ok(()) => (),
67            Err(_) => self.on_error(ErrorCode::SendingDataFailed),
68        }
69    }
70    /// 关闭连接
71    pub fn close(&mut self, code: CloseCode, reason: &str) {
72        let msg = Message {
73            mode: MessageMode::Server,
74            message_type: MessageType::Close,
75            payload: reason.as_bytes().to_vec(),
76            text: reason.to_string(),
77            close: code,
78            error: ErrorCode::None,
79        };
80        match self.send.clone().unwrap().send(msg) {
81            Ok(()) => (),
82            Err(_) => self.on_error(ErrorCode::SendingDataFailed),
83        }
84    }
85    /// 发送给所有对象
86    pub fn send_all(&mut self, data: JsonValue) {
87        for mut user in USERS.iter_mut() {
88            user.send(data.clone());
89        }
90    }
91    /// 发送给指定对象
92    pub fn send_user(&mut self, user_user: &str, data: JsonValue) {
93        if USERS.get(user_user).is_some() {
94            for mut user in USERS.iter_mut() {
95                if user.user_user == user_user {
96                    user.send(data.clone());
97                    return;
98                }
99            }
100        }
101    }
102    /// 发送给指定的企业
103    pub fn send_org(&mut self, org_org: &str, data: JsonValue) {
104        if USERS.get(org_org).is_some() {
105            for mut user in USERS.iter_mut() {
106                if user.org_org == org_org {
107                    user.send(data.clone());
108                    return;
109                }
110            }
111        }
112    }
113    /// 在线人数
114    pub fn online_users(&mut self) -> usize {
115        USERS.len()
116    }
117    pub fn handle(&mut self, factory: fn(out: Websocket) -> Box<dyn Handler>) -> io::Result<()> {
118        let receive = self.receive.clone().unwrap();
119        self.on_frame()?;
120        let mut factory = factory(self.clone());
121        USERS.insert(self.key.to_string(), self.clone());
122        factory.on_open()?;
123        loop {
124            let msg = self.response.scheme.lock().unwrap().read_ws_msg()?;
125            match msg.message_type {
126                MessageType::TimeOut => {}
127                _ => match self.send.clone().unwrap().clone().send(msg) {
128                    Ok(_) => {}
129                    Err(e) => return Err(io::Error::new(io::ErrorKind::BrokenPipe, e))
130                }
131            }
132            let data = receive.lock().unwrap();
133            match data.try_recv() {
134                Ok(mut msg) => {
135                    match msg.mode {
136                        MessageMode::Client => {
137                            match msg.message_type {
138                                MessageType::Close => {
139                                    if USERS.get(&self.key.to_string()).is_some() {
140                                        USERS.remove(&self.key);
141                                    }
142                                    factory.on_close(msg.close.clone(), &msg.text);
143                                    self.response.scheme.lock().unwrap().write(&Message::send_close(CloseCode::ServerClose, "客户退出关闭"))?;
144                                    return Err(io::Error::new(io::ErrorKind::BrokenPipe, msg.text));
145                                }
146                                MessageType::Binary | MessageType::Text => {
147                                    factory.on_message(msg).expect("TODO: panic message");
148                                }
149                                MessageType::Continuation | MessageType::Ping | MessageType::Error => {
150                                    debug!("Client有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
151                                    continue;
152                                }
153                                MessageType::None => {
154                                    debug!("Client有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
155                                    continue;
156                                }
157                                MessageType::Pong => {
158                                    debug!("接收到一个Pong: {:?} {:?} {:?}", msg.mode, msg.message_type, msg.payload);
159                                    continue;
160                                }
161                                MessageType::TimeOut => continue
162                            }
163                        }
164                        MessageMode::Server => {
165                            match msg.message_type {
166                                MessageType::Close => {
167                                    if USERS.get(&self.key.to_string()).is_some() {
168                                        USERS.remove(&self.key);
169                                    }
170                                    factory.on_close(msg.close.clone(), &msg.text);
171                                    self.response.scheme.lock().unwrap().write(&Message::send_close(msg.close.clone(), &msg.text))?;
172                                    return Err(io::Error::new(io::ErrorKind::BrokenPipe, msg.text));
173                                }
174                                MessageType::Binary | MessageType::Text => {
175                                    let res = msg.send_message();
176                                    self.response.scheme.lock().unwrap().write(&res)?;
177                                }
178                                MessageType::Continuation | MessageType::Ping | MessageType::Pong | MessageType::None | MessageType::Error => {
179                                    debug!("服务器有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
180                                }
181                                MessageType::TimeOut => continue
182                            }
183                        }
184                    }
185                }
186                Err(std::sync::mpsc::TryRecvError::Empty) => {
187                    thread::sleep(Duration::from_millis(10));
188                    continue;
189                }
190                Err(e) => {
191                    return Err(io::Error::new(io::ErrorKind::BrokenPipe, e));
192                }
193            }
194        }
195    }
196}
197impl Handler for Websocket {
198    fn on_frame(&mut self) -> io::Result<()> {
199        self.key = self.request.header["sec-websocket-key"].as_str().unwrap_or("").to_string();
200        self.version = self.request.header["sec-websocket-version"].as_str().unwrap_or("").to_string();
201        self.response.status(101).websocket(self.key.as_str()).send()?;
202        Ok(())
203    }
204}
205#[derive(Debug, Clone)]
206pub struct Message {
207    pub mode: MessageMode,
208    pub message_type: MessageType,
209    pub payload: Vec<u8>, // 消息载荷,以字节向量形式表示
210    pub text: String,
211    pub close: CloseCode,
212    pub error: ErrorCode,
213}
214
215impl Message {
216    pub fn msg_error() -> Self {
217        Message {
218            mode: MessageMode::Client,
219            message_type: MessageType::Error,
220            payload: vec![],
221            text: "长度不够".to_string(),
222            close: CloseCode::None,
223            error: ErrorCode::SendingDataFailed,
224        }
225    }
226    // 解析WebSocket消息
227    pub fn parse_message(data: &mut Vec<u8>) -> Message {
228        // 检查数据是否足够长来包含消息类型和载荷长度
229        if data.len() < 2 {
230            return Message {
231                mode: MessageMode::Client,
232                message_type: MessageType::Error,
233                payload: vec![],
234                text: "长度不够".to_string(),
235                close: CloseCode::None,
236                error: ErrorCode::SendingDataFailed,
237            };
238        }
239
240        let header = data.drain(0..2).collect::<Vec<u8>>();
241
242        // 解析帧头
243        let _fin = (header[0] & 0b1000_0000) != 0;
244        let opcode = header[0] & 0b0000_1111;
245        let masked = (header[1] & 0b1000_0000) != 0;
246        let len_flag = header[1] & 0b0111_1111;
247        let mut payload_data = Vec::new();
248        let message_tpye = MessageType::from(opcode);
249        match message_tpye {
250            MessageType::Text => {
251                let payload_length = match len_flag {
252                    0..=125 => len_flag as usize,
253                    126 => {
254                        let ext = data.drain(..2).collect::<Vec<u8>>();
255                        u16::from_be_bytes([ext[0], ext[1]]) as usize
256                    }
257                    127 => {
258                        let ext = data.drain(..8).collect::<Vec<u8>>();
259                        u64::from_be_bytes([
260                            ext[0], ext[1], ext[2], ext[3],
261                            ext[4], ext[5], ext[6], ext[7],
262                        ]) as usize
263                    }
264                    _ => return Message {
265                        mode: MessageMode::Client,
266                        message_type: MessageType::Error,
267                        payload: vec![],
268                        text: "数据格式错误".to_string(),
269                        close: CloseCode::None,
270                        error: ErrorCode::SendingDataFailed,
271                    }
272                };
273                if masked {
274                    let mask_key = data.drain(..4).collect::<Vec<u8>>();
275                    if data.len() < payload_length {
276                        return Message {
277                            mode: MessageMode::Client,
278                            message_type: message_tpye,
279                            payload: payload_data,
280                            text: "继续加载".to_string(),
281                            close: CloseCode::None,
282                            error: ErrorCode::None,
283                        };
284                    }
285                    let payload = &data[..payload_length];
286                    for i in 0..payload.len() {
287                        payload_data.push(payload[i] ^ mask_key[i % 4]);
288                    }
289                } else {
290                    if data.len() < payload_length {
291                        return Message {
292                            mode: MessageMode::Client,
293                            message_type: message_tpye,
294                            payload: payload_data,
295                            text: "继续加载".to_string(),
296                            close: CloseCode::None,
297                            error: ErrorCode::None,
298                        };
299                    }
300                    let t = data.drain(..payload_length).collect::<Vec<u8>>();
301                    payload_data.extend_from_slice(&t);
302                }
303                let text = unsafe { String::from_utf8_unchecked(payload_data.clone()) };
304                Message {
305                    mode: MessageMode::Client,
306                    message_type: message_tpye,
307                    payload: payload_data,
308                    text: text.to_string(),
309                    close: CloseCode::None,
310                    error: ErrorCode::None,
311                }
312            }
313            MessageType::Binary => Message {
314                mode: MessageMode::Client,
315                message_type: message_tpye,
316                payload: payload_data,
317                text: String::new(),
318                close: CloseCode::None,
319                error: ErrorCode::None,
320            },
321            MessageType::Continuation => Message {
322                mode: MessageMode::Client,
323                message_type: message_tpye,
324                payload: payload_data,
325                text: "继续加载".to_string(),
326                close: CloseCode::None,
327                error: ErrorCode::None,
328            },
329            MessageType::Close => Message {
330                mode: MessageMode::Client,
331                message_type: message_tpye,
332                payload: payload_data,
333                text: "客户端关闭".to_string(),
334                close: CloseCode::ClientClose,
335                error: ErrorCode::None,
336            },
337            MessageType::Ping => Message {
338                mode: MessageMode::Client,
339                message_type: message_tpye,
340                payload: payload_data,
341                text: "Ping".to_string(),
342                close: CloseCode::None,
343                error: ErrorCode::None,
344            },
345            MessageType::Pong => Message {
346                mode: MessageMode::Client,
347                message_type: message_tpye,
348                payload: payload_data,
349                text: "Pong".to_string(),
350                close: CloseCode::None,
351                error: ErrorCode::None,
352            },
353            MessageType::Error => {
354                Message {
355                    mode: MessageMode::Client,
356                    message_type: message_tpye,
357                    payload: vec![],
358                    text: String::new(),
359                    close: CloseCode::None,
360                    error: ErrorCode::Unknown,
361                }
362            }
363            MessageType::None => Message {
364                mode: MessageMode::Client,
365                message_type: message_tpye,
366                payload: vec![],
367                text: String::new(),
368                close: CloseCode::None,
369                error: ErrorCode::None,
370            },
371            MessageType::TimeOut => Message {
372                mode: MessageMode::Client,
373                message_type: message_tpye,
374                payload: vec![],
375                text: String::new(),
376                close: CloseCode::None,
377                error: ErrorCode::TimeOut,
378            }
379        }
380    }
381    pub fn send_message(&mut self) -> Vec<u8> {
382        let mut frame = Vec::new();
383
384        // 第1字节:FIN + RSV1-3 + OPCODE
385        let opcode = self.clone().message_type.to_u8();
386        let mut byte1 = opcode & 0x0F;
387        byte1 |= 0x80; // FIN = 1
388
389        frame.push(byte1);
390
391        // 第2字节:MASK = 0(服务器发送),+ payload length
392        let payload_len = self.payload.len();
393        if payload_len < 126 {
394            frame.push(payload_len as u8);
395        } else if payload_len <= 65535 {
396            frame.push(126);
397            frame.extend_from_slice(&u16::try_from(payload_len).unwrap().to_be_bytes());
398        } else {
399            frame.push(127);
400            frame.extend_from_slice(&(payload_len as u64).to_be_bytes());
401        }
402        frame.extend_from_slice(&self.payload);
403        frame
404    }
405    pub fn send_close(code: CloseCode, reason: &str) -> Vec<u8> {
406        let mut frame = Vec::new();
407        frame.push(0x88);
408        let payload_len = code.clone().to_u16().to_be_bytes().len() + reason.len();
409        frame.push(u8::try_from(payload_len).unwrap());
410        frame.extend(&code.to_u16().to_be_bytes());
411        frame.extend(reason.as_bytes());
412        frame
413    }
414}
/// Kind of a [`Message`].
///
/// The first six variants mirror WebSocket frame opcodes; the rest are
/// internal states.
#[derive(Clone, Debug)]
pub enum MessageType {
    /// Text frame.
    Text,
    /// Continuation frame.
    Continuation,
    /// Close frame (originally documented as "client close").
    Close,
    /// Binary frame.
    Binary,
    /// Ping frame.
    Ping,
    /// Pong frame.
    Pong,
    /// Unrecognised opcode.
    None,
    /// A read that timed out.
    TimeOut,
    /// Parse or transport error.
    Error,
}
429
430impl MessageType {
431    #[must_use]
432    pub fn from(types: u8) -> Self {
433        match types {
434            0x0 => Self::Continuation,
435            0x1 => Self::Text,
436            0x2 => Self::Binary,
437            0x8 => Self::Close,
438            0x9 => Self::Ping,
439            0xa => Self::Pong,
440            _ => Self::None,
441        }
442    }
443    #[must_use]
444    pub fn to_u8(self) -> u8 {
445        match self {
446            MessageType::Text => 0x1,
447            MessageType::Continuation | MessageType::None => 0x0,
448            MessageType::Close => 0x8,
449            MessageType::Binary => 0x2,
450            MessageType::Ping => 0x9,
451            MessageType::Pong => 0xa,
452            MessageType::Error | MessageType::TimeOut => 0x0,
453        }
454    }
455}
/// Reason a connection is being closed.
#[derive(Clone, Debug)]
pub enum CloseCode {
    /// The client initiated the close.
    ClientClose,
    /// The server initiated the close.
    ServerClose,
    /// Normal closure.
    NormalClosure,
    /// The peer is going away.
    GoingAway,
    /// Protocol error.
    ProtocolError,
    /// Any other error.
    Other,
    /// Unknown / unspecified.
    None,
}
471impl CloseCode {
472    #[must_use]
473    pub fn from_err(_err: ErrorCode) -> CloseCode {
474        CloseCode::None
475    }
476    #[must_use]
477    pub fn str(&self) -> String {
478        match self {
479            CloseCode::ClientClose => "客户端主动关闭",
480            CloseCode::ServerClose => "服务端主动关闭",
481            CloseCode::None => "未知关闭",
482            CloseCode::NormalClosure => "正常关闭",
483            CloseCode::GoingAway => "对方离开",
484            CloseCode::ProtocolError => "协议错误",
485            CloseCode::Other => "其它错误",
486        }.to_string()
487    }
488    #[must_use]
489    pub fn to_u16(self) -> u16 {
490        match self {
491            CloseCode::NormalClosure => 1000,
492            CloseCode::GoingAway => 1001,
493            CloseCode::ProtocolError => 1002,
494            CloseCode::ClientClose => 1003,
495            CloseCode::ServerClose => 1004,
496            CloseCode::Other => 1005,
497            CloseCode::None => 1006,
498        }
499    }
500}
/// Error category attached to a [`Message`].
#[derive(Copy, Clone, Debug)]
pub enum ErrorCode {
    /// Sending data failed.
    SendingDataFailed,
    /// Unknown request error.
    Unknown,
    /// Thread exception.
    ThreadException,
    /// Timed out.
    TimeOut,
    /// Used when no error applies.
    None,
}
/// Which side produced a [`Message`].
#[derive(Clone, Debug)]
pub enum MessageMode {
    /// Parsed from data sent by the client.
    Client,
    /// Queued by the server, to be written to the client.
    Server,
}
518
519pub struct NoticeMsg {
520    /// 消息类型
521    pub types: Types,
522    /// 消息实体内容
523    pub msg: JsonValue,
524    /// 消息创建时间
525    pub timestamp: i64,
526    /// 监听通道名称
527    pub channel: String,
528    pub user: String,
529    pub org: String,
530}
531impl NoticeMsg {
532    pub fn json(&mut self) -> JsonValue {
533        object! {
534            type:"notice",
535            channel: self.channel.clone(),
536            msg: self.msg.clone(),
537            timestamp: self.timestamp,
538        }
539    }
540}
541
/// Kind of a [`NoticeMsg`], i.e. who it is addressed to.
pub enum Types {
    /// Message for everyone.
    All,
    /// Message for a specific user.
    User,
    /// Message for a specific organisation (enterprise).
    Org,
}