br_web_server/
websocket.rs

1use crate::request::Request;
2use crate::response::Response;
3use crate::{Handler, HttpError};
4use json::{object, JsonValue};
5use std::sync::mpsc::{channel, Receiver, Sender};
6use std::sync::{Arc, Mutex};
7use std::{thread};
8use std::time::Duration;
9use dashmap::DashMap;
10use log::{debug};
11
12pub static USERS: std::sync::LazyLock<DashMap<String, Websocket>> = std::sync::LazyLock::new(DashMap::new);
13pub static WS_NOTICE: std::sync::LazyLock<Mutex<Vec<NoticeMsg>>> = std::sync::LazyLock::new(|| Mutex::new(Vec::new()));
14#[derive(Debug, Clone)]
15pub struct Websocket {
16    /// 发送
17    send: Option<Sender<Message>>,
18    /// 接收
19    receive: Option<Arc<Mutex<Receiver<Message>>>>,
20    pub key: String,
21    pub user_user: String,
22    pub org_org: String,
23    version: String,
24    request: Request,
25    response: Response,
26}
27
28
29impl Websocket {
30    #[must_use]
31    pub fn http(request: Request, response: Response) -> Self {
32        Self {
33            send: None,
34            receive: None,
35            request,
36            key: String::new(),
37            user_user: "".to_string(),
38            org_org: "".to_string(),
39            version: String::new(),
40            response,
41        }
42    }
43    pub fn new(request: Request, response: Response) -> Self {
44        let (send, receive) = channel();
45        Self {
46            send: Some(send),
47            receive: Some(Arc::new(Mutex::new(receive))),
48            request,
49            key: String::new(),
50            user_user: "".to_string(),
51            org_org: "".to_string(),
52            version: String::new(),
53            response,
54        }
55    }
56    /// 发送数据
57    pub fn send(&mut self, data: JsonValue) {
58        let msg = Message {
59            mode: MessageMode::Server,
60            message_type: MessageType::Text,
61            payload: data.to_string().into_bytes(),
62            text: data.to_string(),
63            close: CloseCode::None,
64            error: ErrorCode::None,
65        };
66        match self.send.clone().unwrap().send(msg) {
67            Ok(()) => (),
68            Err(_) => self.on_error(ErrorCode::SendingDataFailed),
69        }
70    }
71    /// 关闭连接
72    pub fn close(&mut self, code: CloseCode, reason: &str) {
73        let msg = Message {
74            mode: MessageMode::Server,
75            message_type: MessageType::Close,
76            payload: reason.as_bytes().to_vec(),
77            text: reason.to_string(),
78            close: code,
79            error: ErrorCode::None,
80        };
81        match self.send.clone().unwrap().send(msg) {
82            Ok(()) => (),
83            Err(_) => self.on_error(ErrorCode::SendingDataFailed),
84        }
85    }
86    /// 发送给所有对象
87    pub fn send_all(&mut self, data: JsonValue) {
88        for mut user in USERS.iter_mut() {
89            user.send(data.clone());
90        }
91    }
92    /// 发送给指定对象
93    pub fn send_user(&mut self, user_user: &str, data: JsonValue) {
94        if USERS.get(user_user).is_some() {
95            for mut user in USERS.iter_mut() {
96                if user.user_user == user_user {
97                    user.send(data.clone());
98                    return;
99                }
100            }
101        }
102    }
103    /// 发送给指定的企业
104    pub fn send_org(&mut self, org_org: &str, data: JsonValue) {
105        if USERS.get(org_org).is_some() {
106            for mut user in USERS.iter_mut() {
107                if user.org_org == org_org {
108                    user.send(data.clone());
109                    return;
110                }
111            }
112        }
113    }
114    /// 在线人数
115    pub fn online_users(&mut self) -> usize {
116        USERS.len()
117    }
118    pub fn handle(&mut self) -> Result<(), HttpError> {
119        let receive = self.receive.clone().unwrap();
120        self.on_frame()?;
121        let mut factory = (self.response.factory)(self.clone());
122        USERS.insert(self.key.to_string(), self.clone());
123        factory.on_open()?;
124        loop {
125            let msg = self.response.request.scheme.lock().unwrap().read_ws_msg()?;
126            match msg.message_type {
127                MessageType::TimeOut => {}
128                _ => match self.send.clone().unwrap().clone().send(msg) {
129                    Ok(()) => {}
130                    Err(e) => return Err(HttpError::new(500, e.to_string().as_str())),
131                }
132            }
133            let data = receive.lock().unwrap();
134            match data.try_recv() {
135                Ok(mut msg) => {
136                    match msg.mode {
137                        MessageMode::Client => {
138                            match msg.message_type {
139                                MessageType::Close => {
140                                    if USERS.get(&self.key.to_string()).is_some() {
141                                        USERS.remove(&self.key);
142                                    }
143                                    factory.on_close(msg.close.clone(), &msg.text);
144                                    self.response.request.scheme.lock().unwrap().write(&Message::send_close(CloseCode::ServerClose, "客户退出关闭"))?;
145                                    return Err(HttpError::new(500, msg.text.as_str()));
146                                }
147                                MessageType::Binary | MessageType::Text => {
148                                    factory.on_message(msg).expect("TODO: panic message");
149                                }
150                                MessageType::Continuation | MessageType::Ping | MessageType::Error | MessageType::None => {
151                                    debug!("Client有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
152                                }
153                                MessageType::Pong => {
154                                    debug!("接收到一个Pong: {:?} {:?} {:?}", msg.mode, msg.message_type, msg.payload);
155                                }
156                                MessageType::TimeOut => {}
157                            }
158                        }
159                        MessageMode::Server => {
160                            match msg.message_type {
161                                MessageType::Close => {
162                                    if USERS.get(&self.key.to_string()).is_some() {
163                                        USERS.remove(&self.key);
164                                    }
165                                    factory.on_close(msg.close.clone(), &msg.text);
166                                    self.response.request.scheme.lock().unwrap().write(&Message::send_close(msg.close.clone(), &msg.text))?;
167                                    return Err(HttpError::new(500, msg.text.as_str()));
168                                }
169                                MessageType::Binary | MessageType::Text => {
170                                    let res = msg.send_message();
171                                    self.response.request.scheme.lock().unwrap().write(&res)?;
172                                }
173                                MessageType::Continuation | MessageType::Ping | MessageType::Pong | MessageType::None | MessageType::Error => {
174                                    debug!("服务器有数据: {:?} {:?} {}", msg.mode, msg.message_type, msg.text.clone());
175                                }
176                                MessageType::TimeOut => {}
177                            }
178                        }
179                    }
180                }
181                Err(std::sync::mpsc::TryRecvError::Empty) => {
182                    thread::sleep(Duration::from_millis(10));
183                }
184                Err(e) => {
185                    return Err(HttpError::new(500, e.to_string().as_str()));
186                }
187            }
188        }
189    }
190}
191impl Handler for Websocket {
192    fn on_request(&mut self, _request: Request, _response: &mut Response) {
193        todo!()
194    }
195
196    fn on_options(&mut self, _response: &mut Response) {
197        todo!()
198    }
199
200    fn on_response(&mut self, _request: Request, _response: &mut Response) {
201        todo!()
202    }
203
204    fn on_frame(&mut self) -> Result<(), HttpError> {
205        self.key = self.request.header["sec-websocket-key"].as_str().unwrap_or("").to_string();
206        self.version = self.request.header["sec-websocket-version"].as_str().unwrap_or("").to_string();
207        self.response.header("Upgrade", "websocket");
208        self.response.header("Connection", "Upgrade");
209        let sec_websocket_accept = br_crypto::sha1::encrypt_base64(format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11",self.key).as_bytes());
210        self.response.header("Sec-WebSocket-Accept", sec_websocket_accept.as_str());
211        self.response.status(101).send()?;
212        Ok(())
213    }
214}
215#[derive(Debug, Clone)]
216pub struct Message {
217    pub mode: MessageMode,
218    pub message_type: MessageType,
219    pub payload: Vec<u8>, // 消息载荷,以字节向量形式表示
220    pub text: String,
221    pub close: CloseCode,
222    pub error: ErrorCode,
223}
224
225impl Message {
226    #[must_use]
227    pub fn msg_error() -> Self {
228        Message {
229            mode: MessageMode::Client,
230            message_type: MessageType::Error,
231            payload: vec![],
232            text: "长度不够".to_string(),
233            close: CloseCode::None,
234            error: ErrorCode::SendingDataFailed,
235        }
236    }
237    // 解析WebSocket消息
238    pub fn parse_message(data: &mut Vec<u8>) -> Message {
239        // 检查数据是否足够长来包含消息类型和载荷长度
240        if data.len() < 2 {
241            return Message {
242                mode: MessageMode::Client,
243                message_type: MessageType::Error,
244                payload: vec![],
245                text: "长度不够".to_string(),
246                close: CloseCode::None,
247                error: ErrorCode::SendingDataFailed,
248            };
249        }
250
251        let header = data.drain(0..2).collect::<Vec<u8>>();
252
253        // 解析帧头
254        let _fin = (header[0] & 0b1000_0000) != 0;
255        let opcode = header[0] & 0b0000_1111;
256        let masked = (header[1] & 0b1000_0000) != 0;
257        let len_flag = header[1] & 0b0111_1111;
258        let mut payload_data = Vec::new();
259        let message_tpye = MessageType::from(opcode);
260        match message_tpye {
261            MessageType::Text => {
262                let payload_length = match len_flag {
263                    0..=125 => len_flag as usize,
264                    126 => {
265                        let ext = data.drain(..2).collect::<Vec<u8>>();
266                        u16::from_be_bytes([ext[0], ext[1]]) as usize
267                    }
268                    127 => {
269                        let ext = data.drain(..8).collect::<Vec<u8>>();
270                        u64::from_be_bytes([
271                            ext[0], ext[1], ext[2], ext[3],
272                            ext[4], ext[5], ext[6], ext[7],
273                        ]) as usize
274                    }
275                    _ => return Message {
276                        mode: MessageMode::Client,
277                        message_type: MessageType::Error,
278                        payload: vec![],
279                        text: "数据格式错误".to_string(),
280                        close: CloseCode::None,
281                        error: ErrorCode::SendingDataFailed,
282                    }
283                };
284                if masked {
285                    let mask_key = data.drain(..4).collect::<Vec<u8>>();
286                    if data.len() < payload_length {
287                        return Message {
288                            mode: MessageMode::Client,
289                            message_type: message_tpye,
290                            payload: payload_data,
291                            text: "继续加载".to_string(),
292                            close: CloseCode::None,
293                            error: ErrorCode::None,
294                        };
295                    }
296                    let payload = &data[..payload_length];
297                    for i in 0..payload.len() {
298                        payload_data.push(payload[i] ^ mask_key[i % 4]);
299                    }
300                } else {
301                    if data.len() < payload_length {
302                        return Message {
303                            mode: MessageMode::Client,
304                            message_type: message_tpye,
305                            payload: payload_data,
306                            text: "继续加载".to_string(),
307                            close: CloseCode::None,
308                            error: ErrorCode::None,
309                        };
310                    }
311                    let t = data.drain(..payload_length).collect::<Vec<u8>>();
312                    payload_data.extend_from_slice(&t);
313                }
314                let text = unsafe { String::from_utf8_unchecked(payload_data.clone()) };
315                Message {
316                    mode: MessageMode::Client,
317                    message_type: message_tpye,
318                    payload: payload_data,
319                    text: text.to_string(),
320                    close: CloseCode::None,
321                    error: ErrorCode::None,
322                }
323            }
324            MessageType::Binary => Message {
325                mode: MessageMode::Client,
326                message_type: message_tpye,
327                payload: payload_data,
328                text: String::new(),
329                close: CloseCode::None,
330                error: ErrorCode::None,
331            },
332            MessageType::Continuation => Message {
333                mode: MessageMode::Client,
334                message_type: message_tpye,
335                payload: payload_data,
336                text: "继续加载".to_string(),
337                close: CloseCode::None,
338                error: ErrorCode::None,
339            },
340            MessageType::Close => Message {
341                mode: MessageMode::Client,
342                message_type: message_tpye,
343                payload: payload_data,
344                text: "客户端关闭".to_string(),
345                close: CloseCode::ClientClose,
346                error: ErrorCode::None,
347            },
348            MessageType::Ping => Message {
349                mode: MessageMode::Client,
350                message_type: message_tpye,
351                payload: payload_data,
352                text: "Ping".to_string(),
353                close: CloseCode::None,
354                error: ErrorCode::None,
355            },
356            MessageType::Pong => Message {
357                mode: MessageMode::Client,
358                message_type: message_tpye,
359                payload: payload_data,
360                text: "Pong".to_string(),
361                close: CloseCode::None,
362                error: ErrorCode::None,
363            },
364            MessageType::Error => {
365                Message {
366                    mode: MessageMode::Client,
367                    message_type: message_tpye,
368                    payload: vec![],
369                    text: String::new(),
370                    close: CloseCode::None,
371                    error: ErrorCode::Unknown,
372                }
373            }
374            MessageType::None => Message {
375                mode: MessageMode::Client,
376                message_type: message_tpye,
377                payload: vec![],
378                text: String::new(),
379                close: CloseCode::None,
380                error: ErrorCode::None,
381            },
382            MessageType::TimeOut => Message {
383                mode: MessageMode::Client,
384                message_type: message_tpye,
385                payload: vec![],
386                text: String::new(),
387                close: CloseCode::None,
388                error: ErrorCode::TimeOut,
389            }
390        }
391    }
392    pub fn send_message(&mut self) -> Vec<u8> {
393        let mut frame = Vec::new();
394
395        // 第1字节:FIN + RSV1-3 + OPCODE
396        let opcode = self.clone().message_type.to_u8();
397        let mut byte1 = opcode & 0x0F;
398        byte1 |= 0x80; // FIN = 1
399
400        frame.push(byte1);
401
402        // 第2字节:MASK = 0(服务器发送),+ payload length
403        let payload_len = self.payload.len();
404        if payload_len < 126 {
405            frame.push(payload_len as u8);
406        } else if payload_len <= 65535 {
407            frame.push(126);
408            frame.extend_from_slice(&u16::try_from(payload_len).unwrap().to_be_bytes());
409        } else {
410            frame.push(127);
411            frame.extend_from_slice(&(payload_len as u64).to_be_bytes());
412        }
413        frame.extend_from_slice(&self.payload);
414        frame
415    }
416    #[must_use]
417    pub fn send_close(code: CloseCode, reason: &str) -> Vec<u8> {
418        let mut frame = Vec::new();
419        frame.push(0x88);
420        let payload_len = code.clone().to_u16().to_be_bytes().len() + reason.len();
421        frame.push(u8::try_from(payload_len).unwrap());
422        frame.extend(&code.to_u16().to_be_bytes());
423        frame.extend(reason.as_bytes());
424        frame
425    }
426}
/// Kind of a WebSocket message.
///
/// `Text`, `Continuation`, `Close`, `Binary`, `Ping` and `Pong` correspond to RFC 6455
/// opcodes; `None`, `TimeOut` and `Error` are internal markers without an opcode of their own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Text frame (opcode 0x1).
    Text,
    /// Continuation frame (opcode 0x0).
    Continuation,
    /// Close frame (opcode 0x8).
    Close,
    /// Binary frame (opcode 0x2).
    Binary,
    /// Ping control frame (opcode 0x9).
    Ping,
    /// Pong control frame (opcode 0xA).
    Pong,
    /// Unrecognised / reserved opcode.
    None,
    /// No frame arrived in time (presumably produced by the socket reader — confirm).
    TimeOut,
    /// Parse or transport error marker.
    Error,
}
441
442impl MessageType {
443    #[must_use]
444    pub fn from(types: u8) -> Self {
445        match types {
446            0x0 => Self::Continuation,
447            0x1 => Self::Text,
448            0x2 => Self::Binary,
449            0x8 => Self::Close,
450            0x9 => Self::Ping,
451            0xa => Self::Pong,
452            _ => Self::None,
453        }
454    }
455    #[must_use]
456    pub fn to_u8(self) -> u8 {
457        match self {
458            MessageType::Text => 0x1,
459            MessageType::Continuation | MessageType::None | MessageType::Error | MessageType::TimeOut => 0x0,
460            MessageType::Close => 0x8,
461            MessageType::Binary => 0x2,
462            MessageType::Ping => 0x9,
463            MessageType::Pong => 0xa,
464            }
465    }
466}
/// Reason a WebSocket connection is being closed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CloseCode {
    /// The client initiated the close.
    ClientClose,
    /// The server initiated the close.
    ServerClose,
    /// Normal closure.
    NormalClosure,
    /// The peer is going away.
    GoingAway,
    /// Protocol error.
    ProtocolError,
    /// Any other error.
    Other,
    /// No specific / unknown close reason.
    None,
}
482impl CloseCode {
483    #[must_use]
484    pub fn from_err(_err: ErrorCode) -> CloseCode {
485        CloseCode::None
486    }
487    #[must_use]
488    pub fn str(&self) -> String {
489        match self {
490            CloseCode::ClientClose => "客户端主动关闭",
491            CloseCode::ServerClose => "服务端主动关闭",
492            CloseCode::None => "未知关闭",
493            CloseCode::NormalClosure => "正常关闭",
494            CloseCode::GoingAway => "对方离开",
495            CloseCode::ProtocolError => "协议错误",
496            CloseCode::Other => "其它错误",
497        }.to_string()
498    }
499    #[must_use]
500    pub fn to_u16(self) -> u16 {
501        match self {
502            CloseCode::NormalClosure => 1000,
503            CloseCode::GoingAway => 1001,
504            CloseCode::ProtocolError => 1002,
505            CloseCode::ClientClose => 1003,
506            CloseCode::ServerClose => 1004,
507            CloseCode::Other => 1005,
508            CloseCode::None => 1006,
509        }
510    }
511}
/// Error classification attached to a `Message`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCode {
    /// Sending data failed.
    SendingDataFailed,
    /// Unknown request error.
    Unknown,
    /// Thread exception.
    ThreadException,
    /// Timed out.
    TimeOut,
    /// No error.
    None,
}
/// Which side produced a `Message`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageMode {
    /// Parsed from a frame received from the peer.
    Client,
    /// Queued by this server for delivery to the peer.
    Server,
}
529
530pub struct NoticeMsg {
531    /// 消息类型
532    pub types: Types,
533    /// 消息实体内容
534    pub msg: JsonValue,
535    /// 消息创建时间
536    pub timestamp: i64,
537    /// 监听通道名称
538    pub channel: String,
539    pub user: String,
540    pub org: String,
541}
542impl NoticeMsg {
543    pub fn json(&mut self) -> JsonValue {
544        object! {
545            type:"notice",
546            channel: self.channel.clone(),
547            msg: self.msg.clone(),
548            timestamp: self.timestamp,
549        }
550    }
551}
552
/// Audience of a `NoticeMsg`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Types {
    /// Message for everyone.
    All,
    /// Message for a specific user.
    User,
    /// Message for a specific organisation (enterprise).
    Org,
}