// br_web_server/websocket.rs

1use crate::request::Request;
2use crate::response::Response;
3use crate::{Handler, HttpError};
4use json::{object, JsonValue};
5use std::sync::mpsc::{channel, Receiver, Sender};
6use std::sync::{Arc, Mutex};
7use std::{thread};
8use std::time::Duration;
9use dashmap::DashMap;
10use log::{debug};
11
12pub static USERS: std::sync::LazyLock<DashMap<String, Websocket>> = std::sync::LazyLock::new(DashMap::new);
13pub static WS_NOTICE: std::sync::LazyLock<Mutex<Vec<NoticeMsg>>> = std::sync::LazyLock::new(|| Mutex::new(Vec::new()));
14#[derive(Debug, Clone)]
15pub struct Websocket {
16    /// 发送
17    send: Option<Sender<Message>>,
18    /// 接收
19    receive: Option<Arc<Mutex<Receiver<Message>>>>,
20    pub key: String,
21    pub user_user: String,
22    pub org_org: String,
23    version: String,
24    request: Request,
25    response: Response,
26}
27
28
29impl Websocket {
30    #[must_use]
31    pub fn http(request: Request, response: Response) -> Self {
32        Self {
33            send: None,
34            receive: None,
35            request,
36            key: String::new(),
37            user_user: "".to_string(),
38            org_org: "".to_string(),
39            version: String::new(),
40            response,
41        }
42    }
43    pub fn new(request: Request, response: Response) -> Self {
44        let (send, receive) = channel();
45        Self {
46            send: Some(send),
47            receive: Some(Arc::new(Mutex::new(receive))),
48            request,
49            key: String::new(),
50            user_user: "".to_string(),
51            org_org: "".to_string(),
52            version: String::new(),
53            response,
54        }
55    }
56    /// 发送数据
57    pub fn send(&mut self, data: JsonValue) {
58        let msg = Message {
59            mode: MessageMode::Server,
60            message_type: MessageType::Text,
61            payload: data.to_string().into_bytes(),
62            text: data.to_string(),
63            close: CloseCode::None,
64            error: ErrorCode::None,
65        };
66        match self.send.clone().unwrap().send(msg) {
67            Ok(()) => (),
68            Err(_) => self.on_error(ErrorCode::SendingDataFailed),
69        }
70    }
71    /// 关闭连接
72    pub fn close(&mut self, code: CloseCode, reason: &str) {
73        let msg = Message {
74            mode: MessageMode::Server,
75            message_type: MessageType::Close,
76            payload: reason.as_bytes().to_vec(),
77            text: reason.to_string(),
78            close: code,
79            error: ErrorCode::None,
80        };
81        match self.send.clone().unwrap().send(msg) {
82            Ok(()) => (),
83            Err(_) => self.on_error(ErrorCode::SendingDataFailed),
84        }
85    }
86    /// 发送给所有对象
87    pub fn send_all(&mut self, data: JsonValue) {
88        for mut user in USERS.iter_mut() {
89            user.send(data.clone());
90        }
91    }
92    /// 发送给指定对象
93    pub fn send_user(&mut self, user_user: &str, data: JsonValue) {
94        if USERS.get(user_user).is_some() {
95            for mut user in USERS.iter_mut() {
96                if user.user_user == user_user {
97                    user.send(data.clone());
98                    return;
99                }
100            }
101        }
102    }
103    /// 发送给指定的企业
104    pub fn send_org(&mut self, org_org: &str, data: JsonValue) {
105        if USERS.get(org_org).is_some() {
106            for mut user in USERS.iter_mut() {
107                if user.org_org == org_org {
108                    user.send(data.clone());
109                    return;
110                }
111            }
112        }
113    }
114    /// 在线人数
115    pub fn online_users(&mut self) -> usize {
116        USERS.len()
117    }
118    pub fn handle(&mut self) -> Result<(), HttpError> {
119        let receive = self.receive.clone().unwrap();
120        self.on_frame()?;
121        let mut factory = (self.response.factory)(self.clone());
122        USERS.insert(self.key.to_string(), self.clone());
123        factory.on_open()?;
124        loop {
125            let msg = self.response.request.scheme.lock().unwrap().read_ws_data()?;
126            match msg.message_type {
127                MessageType::TimeOut => {}
128                _ => match self.send.clone().unwrap().send(msg) {
129                    Ok(()) => {}
130                    Err(e) => return Err(HttpError::new(500, format!("message_type: {}",e.to_string().as_str()).as_str())),
131                }
132            }
133            let data = receive.lock().unwrap();
134            match data.try_recv() {
135                Ok(mut msg) => {
136                    match msg.mode {
137                        MessageMode::Client => {
138                            match msg.message_type {
139                                MessageType::Close => {
140                                    if USERS.get(&self.key.to_string()).is_some() {
141                                        USERS.remove(&self.key);
142                                    }
143                                    factory.on_close(msg.close.clone(), &msg.text);
144                                    self.response.request.scheme.lock().unwrap().write(&Message::send_close(CloseCode::ServerClose, "客户退出关闭"))?;
145                                    return Err(HttpError::new(500, msg.text.as_str()));
146                                }
147                                MessageType::Binary | MessageType::Text => {
148                                    factory.on_message(msg).expect("TODO: panic message");
149                                }
150                                MessageType::Continuation | MessageType::Ping | MessageType::Error | MessageType::None => {
151                                    debug!("Client有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
152                                }
153                                MessageType::Pong => {
154                                    debug!("接收到一个Pong: {:?} {:?} {:?}", msg.mode, msg.message_type, msg.payload);
155                                }
156                                MessageType::TimeOut => {}
157                            }
158                        }
159                        MessageMode::Server => {
160                            match msg.message_type {
161                                MessageType::Close => {
162                                    if USERS.get(&self.key.to_string()).is_some() {
163                                        USERS.remove(&self.key);
164                                    }
165                                    factory.on_close(msg.close.clone(), &msg.text);
166                                    self.response.request.scheme.lock().unwrap().write(&Message::send_close(msg.close.clone(), &msg.text))?;
167                                    return Err(HttpError::new(500, msg.text.as_str()));
168                                }
169                                MessageType::Binary | MessageType::Text => {
170                                    let res = msg.send_message();
171                                    self.response.request.scheme.lock().unwrap().write(&res)?;
172                                }
173                                MessageType::Continuation | MessageType::Ping | MessageType::Pong | MessageType::None | MessageType::Error => {
174                                    debug!("服务器有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
175                                }
176                                MessageType::TimeOut => {}
177                            }
178                        }
179                    }
180                }
181                Err(std::sync::mpsc::TryRecvError::Empty) => {
182                    thread::sleep(Duration::from_millis(10));
183                }
184                Err(e) => {
185                    return Err(HttpError::new(500, e.to_string().as_str()));
186                }
187            }
188        }
189    }
190}
191impl Handler for Websocket {
192    fn on_request(&mut self, _request: Request, _response: &mut Response) {}
193    fn on_frame(&mut self) -> Result<(), HttpError> {
194        self.key = self.request.header["sec-websocket-key"].as_str().unwrap_or("").to_string();
195        self.version = self.request.header["sec-websocket-version"].as_str().unwrap_or("").to_string();
196        self.response.header("Upgrade", "websocket");
197        self.response.header("Connection", "Upgrade");
198        let sec_websocket_accept = br_crypto::sha1::encrypt_base64(format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11",self.key).as_bytes());
199        self.response.header("Sec-WebSocket-Accept", sec_websocket_accept.as_str());
200        self.response.status(101).send()?;
201        Ok(())
202    }
203}
204#[derive(Debug, Clone)]
205pub struct Message {
206    pub mode: MessageMode,
207    pub message_type: MessageType,
208    pub payload: Vec<u8>, // 消息载荷,以字节向量形式表示
209    pub text: String,
210    pub close: CloseCode,
211    pub error: ErrorCode,
212}
213
214impl Message {
215    #[must_use]
216    pub fn msg_error() -> Self {
217        Message {
218            mode: MessageMode::Client,
219            message_type: MessageType::Error,
220            payload: vec![],
221            text: "长度不够".to_string(),
222            close: CloseCode::None,
223            error: ErrorCode::SendingDataFailed,
224        }
225    }
226    // 解析WebSocket消息
227    pub fn parse_message(data: &mut Vec<u8>) -> Message {
228        // 检查数据是否足够长来包含消息类型和载荷长度
229        if data.len() < 2 {
230            return Message {
231                mode: MessageMode::Client,
232                message_type: MessageType::Error,
233                payload: vec![],
234                text: "长度不够".to_string(),
235                close: CloseCode::None,
236                error: ErrorCode::SendingDataFailed,
237            };
238        }
239
240        let header = data.drain(0..2).collect::<Vec<u8>>();
241
242        // 解析帧头
243        let _fin = (header[0] & 0b1000_0000) != 0;
244        let opcode = header[0] & 0b0000_1111;
245        let masked = (header[1] & 0b1000_0000) != 0;
246        let len_flag = header[1] & 0b0111_1111;
247        let mut payload_data = Vec::new();
248        let message_tpye = MessageType::from(opcode);
249        match message_tpye {
250            MessageType::Text => {
251                let payload_length = match len_flag {
252                    0..=125 => len_flag as usize,
253                    126 => {
254                        let ext = data.drain(..2).collect::<Vec<u8>>();
255                        u16::from_be_bytes([ext[0], ext[1]]) as usize
256                    }
257                    127 => {
258                        let ext = data.drain(..8).collect::<Vec<u8>>();
259                        u64::from_be_bytes([
260                            ext[0], ext[1], ext[2], ext[3],
261                            ext[4], ext[5], ext[6], ext[7],
262                        ]) as usize
263                    }
264                    _ => return Message {
265                        mode: MessageMode::Client,
266                        message_type: MessageType::Error,
267                        payload: vec![],
268                        text: "数据格式错误".to_string(),
269                        close: CloseCode::None,
270                        error: ErrorCode::SendingDataFailed,
271                    }
272                };
273                if masked {
274                    let mask_key = data.drain(..4).collect::<Vec<u8>>();
275                    if data.len() < payload_length {
276                        return Message {
277                            mode: MessageMode::Client,
278                            message_type: message_tpye,
279                            payload: payload_data,
280                            text: "继续加载".to_string(),
281                            close: CloseCode::None,
282                            error: ErrorCode::None,
283                        };
284                    }
285                    let payload = &data[..payload_length];
286                    for i in 0..payload.len() {
287                        payload_data.push(payload[i] ^ mask_key[i % 4]);
288                    }
289                } else {
290                    if data.len() < payload_length {
291                        return Message {
292                            mode: MessageMode::Client,
293                            message_type: message_tpye,
294                            payload: payload_data,
295                            text: "继续加载".to_string(),
296                            close: CloseCode::None,
297                            error: ErrorCode::None,
298                        };
299                    }
300                    let t = data.drain(..payload_length).collect::<Vec<u8>>();
301                    payload_data.extend_from_slice(&t);
302                }
303                let text = unsafe { String::from_utf8_unchecked(payload_data.clone()) };
304                Message {
305                    mode: MessageMode::Client,
306                    message_type: message_tpye,
307                    payload: payload_data,
308                    text: text.to_string(),
309                    close: CloseCode::None,
310                    error: ErrorCode::None,
311                }
312            }
313            MessageType::Binary => Message {
314                mode: MessageMode::Client,
315                message_type: message_tpye,
316                payload: payload_data,
317                text: String::new(),
318                close: CloseCode::None,
319                error: ErrorCode::None,
320            },
321            MessageType::Continuation => Message {
322                mode: MessageMode::Client,
323                message_type: message_tpye,
324                payload: payload_data,
325                text: "继续加载".to_string(),
326                close: CloseCode::None,
327                error: ErrorCode::None,
328            },
329            MessageType::Close => Message {
330                mode: MessageMode::Client,
331                message_type: message_tpye,
332                payload: payload_data,
333                text: "客户端关闭".to_string(),
334                close: CloseCode::ClientClose,
335                error: ErrorCode::None,
336            },
337            MessageType::Ping => Message {
338                mode: MessageMode::Client,
339                message_type: message_tpye,
340                payload: payload_data,
341                text: "Ping".to_string(),
342                close: CloseCode::None,
343                error: ErrorCode::None,
344            },
345            MessageType::Pong => Message {
346                mode: MessageMode::Client,
347                message_type: message_tpye,
348                payload: payload_data,
349                text: "Pong".to_string(),
350                close: CloseCode::None,
351                error: ErrorCode::None,
352            },
353            MessageType::Error => {
354                Message {
355                    mode: MessageMode::Client,
356                    message_type: message_tpye,
357                    payload: vec![],
358                    text: String::new(),
359                    close: CloseCode::None,
360                    error: ErrorCode::Unknown,
361                }
362            }
363            MessageType::None => Message {
364                mode: MessageMode::Client,
365                message_type: message_tpye,
366                payload: vec![],
367                text: String::new(),
368                close: CloseCode::None,
369                error: ErrorCode::None,
370            },
371            MessageType::TimeOut => Message {
372                mode: MessageMode::Client,
373                message_type: message_tpye,
374                payload: vec![],
375                text: String::new(),
376                close: CloseCode::None,
377                error: ErrorCode::TimeOut,
378            }
379        }
380    }
381    pub fn send_message(&mut self) -> Vec<u8> {
382        let mut frame = Vec::new();
383
384        // 第1字节:FIN + RSV1-3 + OPCODE
385        let opcode = self.clone().message_type.to_u8();
386        let mut byte1 = opcode & 0x0F;
387        byte1 |= 0x80; // FIN = 1
388
389        frame.push(byte1);
390
391        // 第2字节:MASK = 0(服务器发送),+ payload length
392        let payload_len = self.payload.len();
393        if payload_len < 126 {
394            frame.push(payload_len as u8);
395        } else if payload_len <= 65535 {
396            frame.push(126);
397            frame.extend_from_slice(&u16::try_from(payload_len).unwrap().to_be_bytes());
398        } else {
399            frame.push(127);
400            frame.extend_from_slice(&(payload_len as u64).to_be_bytes());
401        }
402        frame.extend_from_slice(&self.payload);
403        frame
404    }
405    #[must_use]
406    pub fn send_close(code: CloseCode, reason: &str) -> Vec<u8> {
407        let mut frame = Vec::new();
408        frame.push(0x88);
409        let payload_len = code.clone().to_u16().to_be_bytes().len() + reason.len();
410        frame.push(u8::try_from(payload_len).unwrap());
411        frame.extend(&code.to_u16().to_be_bytes());
412        frame.extend(reason.as_bytes());
413        frame
414    }
415}
/// Kind of a WebSocket frame, plus a few internal pseudo-kinds.
///
/// `Text` through `Pong` correspond to RFC 6455 opcodes. `None` marks an
/// unrecognised opcode. `TimeOut` and `Error` are never produced by
/// `MessageType::from`. The type is `Copy`, like `ErrorCode`, so `to_u8(self)`
/// needs no clones.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Text data frame (opcode 0x1).
    Text,
    /// Continuation of a fragmented message (opcode 0x0).
    Continuation,
    /// Close control frame (opcode 0x8).
    Close,
    /// Binary data frame (opcode 0x2).
    Binary,
    /// Ping control frame (opcode 0x9).
    Ping,
    /// Pong control frame (opcode 0xA).
    Pong,
    /// Unrecognised or reserved opcode.
    None,
    /// Internal marker for a read that timed out.
    TimeOut,
    /// Internal marker for an error.
    Error,
}

impl MessageType {
    /// Maps a frame opcode (low nibble of the first header byte) to a message type.
    ///
    /// Unknown or reserved opcodes map to `MessageType::None`.
    #[must_use]
    pub fn from(types: u8) -> Self {
        match types {
            0x0 => Self::Continuation,
            0x1 => Self::Text,
            0x2 => Self::Binary,
            0x8 => Self::Close,
            0x9 => Self::Ping,
            0xa => Self::Pong,
            _ => Self::None,
        }
    }

    /// Returns the wire opcode. Pseudo-kinds without an opcode map to 0x0.
    #[must_use]
    pub fn to_u8(self) -> u8 {
        match self {
            MessageType::Text => 0x1,
            MessageType::Continuation | MessageType::None | MessageType::Error | MessageType::TimeOut => 0x0,
            MessageType::Close => 0x8,
            MessageType::Binary => 0x2,
            MessageType::Ping => 0x9,
            MessageType::Pong => 0xa,
        }
    }
}
/// Reason attached to a connection close.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CloseCode {
    /// The client closed the connection.
    ClientClose,
    /// The server closed the connection.
    ServerClose,
    /// Normal closure.
    NormalClosure,
    /// The peer is going away.
    GoingAway,
    /// Protocol error.
    ProtocolError,
    /// Any other error.
    Other,
    /// No specific / unknown close reason.
    None,
}

impl CloseCode {
    /// Maps an `ErrorCode` to a close code; every error currently maps to `CloseCode::None`.
    #[must_use]
    pub fn from_err(_err: ErrorCode) -> CloseCode {
        CloseCode::None
    }

    /// Human-readable (Chinese) description of the close reason.
    #[must_use]
    pub fn str(&self) -> String {
        match self {
            CloseCode::ClientClose => "客户端主动关闭",
            CloseCode::ServerClose => "服务端主动关闭",
            CloseCode::None => "未知关闭",
            CloseCode::NormalClosure => "正常关闭",
            CloseCode::GoingAway => "对方离开",
            CloseCode::ProtocolError => "协议错误",
            CloseCode::Other => "其它错误",
        }
        .to_string()
    }

    /// Status code written into an outgoing Close frame.
    ///
    /// RFC 6455 §7.4.1 reserves 1004 and forbids sending 1005 and 1006 in a Close
    /// frame, and clients may reject such frames as invalid. The variants that used
    /// those values therefore use the private-use range 4000–4999 (§7.4.2). The
    /// last digits are kept (4004, 4005, 4006) so the codes stay recognisable.
    #[must_use]
    pub fn to_u16(self) -> u16 {
        match self {
            CloseCode::NormalClosure => 1000,
            CloseCode::GoingAway => 1001,
            CloseCode::ProtocolError => 1002,
            CloseCode::ClientClose => 1003,
            CloseCode::ServerClose => 4004,
            CloseCode::Other => 4005,
            CloseCode::None => 4006,
        }
    }
}

/// Error classification attached to a `Message`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCode {
    /// Sending data failed.
    SendingDataFailed,
    /// Unknown request error.
    Unknown,
    /// A thread failed.
    ThreadException,
    /// Timed out.
    TimeOut,
    /// No error.
    None,
}
/// Which side produced a `Message`.
///
/// Fieldless and `Copy`, consistent with `ErrorCode`, and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageMode {
    /// Parsed from a frame received from the client.
    Client,
    /// Created on the server side for delivery to the client.
    Server,
}
518
519pub struct NoticeMsg {
520    /// 消息类型
521    pub types: Types,
522    /// 消息实体内容
523    pub msg: JsonValue,
524    /// 消息创建时间
525    pub timestamp: i64,
526    /// 监听通道名称
527    pub channel: String,
528    pub user: String,
529    pub org: String,
530}
531impl NoticeMsg {
532    pub fn json(&mut self) -> JsonValue {
533        object! {
534            type:"notice",
535            channel: self.channel.clone(),
536            msg: self.msg.clone(),
537            timestamp: self.timestamp,
538        }
539    }
540}
541
/// Audience of a `NoticeMsg`.
///
/// Derives the standard traits every public fieldless enum should have, so it
/// can be logged, copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Types {
    /// Message for everyone.
    All,
    /// Message for a specific user.
    User,
    /// Message for a specific organisation.
    Org,
}