// makepad_http/websocket.rs

1use std::convert::TryInto;
2use crate::digest::{Sha1, base64_encode};
3
/// Which part of an incoming frame the parser is currently waiting for.
///
/// Variants are listed in the order they occur on the wire. `Len2`, `Len8`
/// and `Mask` are only visited when the frame header calls for them.
#[derive(Debug, PartialEq)]
enum State {
    /// First header byte: FIN flag (0x80) and 4-bit opcode.
    Opcode,
    /// Second header byte: MASK flag (0x80) and 7-bit payload length.
    Len1,
    /// 16-bit big-endian extended payload length (7-bit length was 126).
    Len2,
    /// 64-bit big-endian extended payload length (7-bit length was 127).
    Len8,
    /// 4-byte masking key.
    Mask,
    /// Payload bytes.
    Data,
}
13
14impl State {
15    fn head_expected(&self) -> usize {
16        match self {
17            State::Opcode => 1,
18            State::Len1 => 1,
19            State::Len2 => 2,
20            State::Len8 => 8,
21            State::Data => 0,
22            State::Mask => 4
23        }
24    }
25}
26
27pub struct WebSocket {
28    head: [u8; 8],
29    head_expected: usize,
30    head_written: usize,
31    data: Vec<u8>,
32    data_len: usize,
33    input_read: usize,
34    mask_counter: usize,
35    is_ping: bool,
36    is_pong: bool,
37    is_partial: bool,
38    is_text: bool,
39    is_masked: bool,
40    state: State
41}
42
/// A complete message decoded by [`WebSocket::parse`].
///
/// Payloads borrow the parser's internal buffer, so they are only valid for
/// the duration of the callback that receives them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketMessage<'a> {
    /// Ping control frame (opcode 9) with its payload.
    Ping(&'a [u8]),
    /// Pong control frame (opcode 10) with its payload.
    Pong(&'a [u8]),
    /// Text frame (opcode 1) whose payload is valid UTF-8.
    Text(&'a str),
    /// Binary frame (opcode 2) or continuation frame (opcode 0).
    Binary(&'a [u8]),
    /// Close control frame (opcode 8). Its payload is not decoded.
    Close,
}
50
/// Errors reported by [`WebSocket::parse`] through its callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketError<'a> {
    /// The frame used an opcode this parser does not handle (anything other
    /// than 0, 1, 2, 8, 9 or 10). Carries the 4-bit opcode.
    OpcodeNotSupported(u8),
    /// A text frame's payload was not valid UTF-8. Carries the unmasked payload.
    TextNotUTF8(&'a [u8]),
}

impl<'a> std::fmt::Display for WebSocketError<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            WebSocketError::OpcodeNotSupported(opcode) => {
                write!(f, "websocket opcode {} is not supported", opcode)
            }
            WebSocketError::TextNotUTF8(bytes) => write!(
                f,
                "websocket text frame of {} bytes is not valid UTF-8",
                bytes.len()
            ),
        }
    }
}

impl<'a> std::error::Error for WebSocketError<'a> {}
56
/// Complete unmasked ping frame: FIN set, opcode 9, empty payload.
pub const PING_MESSAGE: [u8; 2] = [0x80 | 0x09, 0x00];
/// Complete unmasked pong frame: FIN set, opcode 10, empty payload.
pub const PONG_MESSAGE: [u8; 2] = [0x80 | 0x0A, 0x00];
59
/// Encoded header of a single, unfragmented, unmasked binary frame.
///
/// Build one with `BinaryMessageHeader::from_len` and send the bytes from
/// `as_slice()` immediately before the payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BinaryMessageHeader {
    /// Number of meaningful bytes at the start of `data` (2, 4 or 10).
    len: usize,
    /// Header bytes. Only the first `len` entries are used.
    data: [u8; 10],
}
64
65impl BinaryMessageHeader{
66    pub fn from_len(len:usize)->Self{
67        let mut data = [0u8;10];
68        
69        data[0] = 128 | 2; // binary single message
70        
71        if len < 126{
72            data[1] = len as u8;
73            BinaryMessageHeader{len:2, data}
74        }
75        else if len < 65536{
76            data[1] = 126; 
77            let bytes = &(len as u16).to_be_bytes();
78            for (i, &byte) in bytes.iter().enumerate() {
79                data[i + 2] = byte;
80            }
81            return BinaryMessageHeader{len:4, data}
82        }
83        else{
84            data[1] = 127;
85            let bytes = &(len as u64).to_be_bytes();
86            for (i, &byte) in bytes.iter().enumerate() {
87                data[i + 2] = byte;
88            }
89            return BinaryMessageHeader{len:10, data}
90        }
91    }
92    
93    pub fn as_slice(&self)->&[u8]{
94        &self.data[0..self.len]
95    }
96}
97
98impl WebSocket {
99    pub fn new() -> Self {
100        Self {
101            head: [0u8; 8],
102            head_expected: 1,
103            head_written: 0,
104            data: Vec::new(),
105            data_len: 0,
106            input_read: 0,
107            mask_counter: 0,
108            is_ping: false,
109            is_pong: false,
110            is_masked: false,
111            is_partial: false,
112            is_text: false,
113            state: State::Opcode
114        }
115    }
116    
117    pub fn create_upgrade_response(key: &str) -> String {
118        let to_hash = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", key);
119        let mut sha1 = Sha1::new();
120        sha1.update(to_hash.as_bytes());
121        let out_bytes = sha1.finalise();
122        let base64 = base64_encode(&out_bytes);
123        let response_ack = format!(
124            "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {}\r\n\r\n",
125            base64
126        );
127        response_ack
128    }
129    
130    fn parse_head(&mut self, input: &[u8]) -> bool {
131        while self.head_expected > 0
132            && self.input_read < input.len()
133            && self.head_written < self.head.len()
134        {
135            self.head[self.head_written] = input[self.input_read];
136            self.input_read += 1;
137            self.head_written += 1;
138            self.head_expected -= 1;
139        }
140        self.head_expected != 0
141    }
142    
143    fn to_state(&mut self, state: State) {
144        match state {
145            State::Data => {
146                self.mask_counter = 0;
147                self.data.clear();
148            }
149            State::Opcode => {
150                self.is_ping = false;
151                self.is_pong = false;
152                self.is_partial = false;
153                self.is_text = false;
154                self.is_masked = false;
155            },
156            _ => ()
157        }
158        self.head_written = 0;
159        self.head_expected = state.head_expected();
160        self.state = state;
161    }
162    
163    pub fn parse<F>(&mut self, input: &[u8], mut result: F) where F: FnMut(Result<WebSocketMessage, WebSocketError>){
164        self.input_read = 0;
165        // parse a header
166        loop {
167            match self.state {
168                State::Opcode => {
169                    if self.parse_head(input) {
170                        break;
171                    }
172                    let opcode = self.head[0] & 15;
173                    if opcode <= 2 {
174                        self.is_partial = (self.head[0] & 128) != 0;
175                        self.is_text = opcode == 1;
176                        self.to_state(State::Len1);
177                    }
178                    else if opcode == 8 {
179                        result(Ok(WebSocketMessage::Close));
180                        break;
181                    }
182                    else if opcode == 9 {
183                        self.is_ping = true;
184                        self.to_state(State::Len1);
185                    }
186                    else if opcode == 10 {
187                        self.is_pong = true;
188                        self.to_state(State::Len1);
189                    }
190                    else {
191                        result(Err(WebSocketError::OpcodeNotSupported(opcode)));
192                        break;
193                    }
194                },
195                State::Len1 => {
196                    if self.parse_head(input) {
197                        break;
198                    }
199                    self.is_masked = (self.head[0] & 128) > 0;
200                    let len_type = self.head[0] & 127;
201                    if len_type < 126 {
202                        self.data_len = len_type as usize;
203                        if !self.is_masked {
204                            self.to_state(State::Data);
205                        }
206                        else {
207                            self.to_state(State::Mask);
208                        }
209                    }
210                    else if len_type == 126 {
211                        self.to_state(State::Len2);
212                    }
213                    else if len_type == 127 {
214                        self.to_state(State::Len8);
215                    }
216                },
217                State::Len2 => {
218                    if self.parse_head(input) {
219                        break;
220                    }
221                    self.data_len = u16::from_be_bytes(
222                        self.head[0..2].try_into().unwrap()
223                    ) as usize;
224                    if self.is_masked {
225                        self.to_state(State::Mask);
226                    }
227                    else {
228                        self.to_state(State::Data);
229                    }
230                },
231                State::Len8 => {
232                    if self.parse_head(input) {
233                        break;
234                    }
235                    self.data_len = u64::from_be_bytes(
236                        self.head[0..8].try_into().unwrap()
237                    ) as usize;
238                    if self.is_masked {
239                        self.to_state(State::Mask);
240                    }
241                    else {
242                        self.to_state(State::Data);
243                    }
244                },
245                State::Mask => {
246                    if self.parse_head(input) {
247                        break;
248                    }
249                    self.to_state(State::Data);
250                },
251                State::Data => {
252                    if self.is_masked {
253                        while self.data.len() < self.data_len && self.input_read < input.len() {
254                            self.data.push(input[self.input_read] ^ self.head[self.mask_counter]);
255                            self.mask_counter = (self.mask_counter + 1) & 3;
256                            self.input_read += 1;
257                        }
258                    }
259                    else {
260                        while self.data.len() < self.data_len && self.input_read < input.len() {
261                            self.data.push(input[self.input_read]);
262                            self.input_read += 1;
263                        }
264                    }
265                    if self.data.len() < self.data_len { // not enough data yet
266                        break;
267                    }
268                    else {
269                        if self.is_ping {
270                            result(Ok(WebSocketMessage::Ping(&self.data)));
271                        }
272                        else if self.is_pong {
273                            result(Ok(WebSocketMessage::Pong(&self.data)));
274                        }
275                        else if self.is_text{
276                            if let Ok(text) = std::str::from_utf8(&self.data){
277                                result(Ok(WebSocketMessage::Text(text)));
278                            }
279                            else{
280                                result(Err(WebSocketError::TextNotUTF8(&self.data)))
281                            }
282                        }
283                        else{
284                            result(Ok(WebSocketMessage::Binary(&self.data)));
285                        }
286                        
287                        self.to_state(State::Opcode);
288                    }
289                },
290            }
291        }
292    }
293    
294}
295
296impl Default for WebSocket {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301