// makepad_http/websocket.rs

2use crate::digest::{Sha1, base64_encode};
3use std::time::{SystemTime, UNIX_EPOCH};
4
/// Phases of the incremental frame decoder driven by `ServerWebSocket::parse`.
///
/// Each header-reading phase buffers a fixed number of bytes (see
/// `State::head_expected`) before the decoder moves on.
#[derive(Debug, PartialEq)]
enum State {
    /// Waiting for the first header byte (flag bits + opcode).
    Opcode,
    /// Waiting for the second header byte (mask flag + 7-bit length).
    Len1,
    /// Waiting for a 2-byte big-endian extended payload length.
    Len2,
    /// Waiting for an 8-byte big-endian extended payload length.
    Len8,
    /// Accumulating payload bytes until the declared length is reached.
    Data,
    /// Waiting for the 4-byte masking key.
    Mask,
}
14
15impl State {
16    fn head_expected(&self) -> usize {
17        match self {
18            State::Opcode => 1,
19            State::Len1 => 1,
20            State::Len2 => 2,
21            State::Len8 => 8,
22            State::Data => 0,
23            State::Mask => 4
24        }
25    }
26}
27
28pub struct ServerWebSocket {
29    head: [u8; 8],
30    head_expected: usize,
31    head_written: usize,
32    data: Vec<u8>,
33    data_len: usize,
34    input_read: usize,
35    mask_counter: usize,
36    is_ping: bool,
37    is_pong: bool,
38    is_partial: bool,
39    is_text: bool,
40    is_masked: bool,
41    state: State
42}
43
/// A WebSocket message, either decoded by `ServerWebSocket::parse` or
/// handed to `ServerWebSocket::message_to_frame` for encoding.
///
/// Payloads are borrowed, so the type is cheap to copy. Messages produced by
/// the parser borrow its internal buffer and are only valid inside the
/// callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerWebSocketMessage<'a> {
    /// Ping control frame and its application data.
    Ping(&'a [u8]),
    /// Pong control frame and its application data.
    Pong(&'a [u8]),
    /// Text data frame whose payload is valid UTF-8.
    Text(&'a str),
    /// Binary data frame.
    Binary(&'a [u8]),
    /// Close control frame.
    Close,
}
51
/// Errors reported by `ServerWebSocket::parse` through its callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerWebSocketError<'a> {
    /// The frame used an opcode the parser does not handle; carries the opcode.
    OpcodeNotSupported(u8),
    /// A text frame's (unmasked) payload was not valid UTF-8; carries the raw bytes.
    TextNotUTF8(&'a [u8]),
}

impl<'a> std::fmt::Display for ServerWebSocketError<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            ServerWebSocketError::OpcodeNotSupported(opcode) => {
                write!(f, "websocket opcode {} is not supported", opcode)
            }
            ServerWebSocketError::TextNotUTF8(bytes) => write!(
                f,
                "websocket text frame payload ({} bytes) is not valid UTF-8",
                bytes.len()
            ),
        }
    }
}

// Lets the error be used with `?`, `Box<dyn Error>` and error-reporting crates.
impl<'a> std::error::Error for ServerWebSocketError<'a> {}
57
/// Ready-made unmasked ping frame: FIN bit (0x80) with opcode 0x9 and an empty payload.
pub const SERVER_WEB_SOCKET_PING_MESSAGE: [u8; 2] = [0x80 | 0x09, 0x00];
/// Ready-made unmasked pong frame: FIN bit (0x80) with opcode 0xA and an empty payload.
pub const SERVER_WEB_SOCKET_PONG_MESSAGE: [u8; 2] = [0x80 | 0x0A, 0x00];
60
/// Payload format of an outgoing data frame; selects its opcode
/// (text = 1, binary = 2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerWebSocketMessageFormat {
    /// Opaque bytes (opcode 2).
    Binary,
    /// UTF-8 text (opcode 1).
    Text,
}
65
66pub struct ServerWebSocketMessageHeader {
67    pub format: ServerWebSocketMessageFormat,
68    len: usize,
69    masked: bool,
70    data: [u8;14]
71}
72
73impl ServerWebSocketMessageHeader {
74    pub fn from_len(len: usize, format: ServerWebSocketMessageFormat, masked: bool)->Self{
75        let mut data = [0u8;14];
76        
77        match format {
78            ServerWebSocketMessageFormat::Binary => data[0] = 128 | 2,
79            ServerWebSocketMessageFormat::Text => data[0] = 128 | 1,
80        }
81
82        if masked {
83            data[1] = 128;
84        } else {
85            data[1] = 0;
86        }
87
88        let header_len;
89        if len < 126{
90            data[1] |= len as u8;
91            header_len = 2;
92        }
93        else if len < 65536{
94            data[1] |= 126;
95            let bytes = &(len as u16).to_be_bytes();
96            for (i, &byte) in bytes.iter().enumerate() {
97                data[i + 2] = byte;
98            }
99            header_len = 4;
100        }
101        else{
102            data[1] |= 127;
103            let bytes = &(len as u64).to_be_bytes();
104            for (i, &byte) in bytes.iter().enumerate() {
105                data[i + 2] = byte;
106            }
107            header_len = 10;
108        }
109
110        if masked {
111            for i in header_len..header_len + 4 {
112                data[i] = Self::random_byte();
113            }
114            return ServerWebSocketMessageHeader{len: header_len + 4, data, format, masked}
115        } else {
116            return ServerWebSocketMessageHeader{len: header_len, data, format, masked}
117        }
118    }
119    
120    pub fn as_slice(&self)->&[u8]{
121        &self.data[0..self.len]
122    }
123
124    pub fn mask(&mut self)->Option<&[u8]> {
125        if self.masked {
126            match self.len {
127                6 => Some(&self.data[2..6]),
128                10 => Some(&self.data[6..10]),
129                14 => Some(&self.data[10..14]),
130                _ => None
131            }
132        } else {
133            None
134        }
135    }
136
137    // TODO Improve this using a proper random number generator
138    fn random_byte() -> u8 {
139        let num = SystemTime::now().duration_since(UNIX_EPOCH).expect("duration_since failed").subsec_nanos();
140        num as u8
141    }
142}
143
144impl ServerWebSocket {
145    pub fn new() -> Self {
146        Self {
147            head: [0u8; 8],
148            head_expected: 1,
149            head_written: 0,
150            data: Vec::new(),
151            data_len: 0,
152            input_read: 0,
153            mask_counter: 0,
154            is_ping: false,
155            is_pong: false,
156            is_masked: false,
157            is_partial: false,
158            is_text: false,
159            state: State::Opcode
160        }
161    }
162    
163    pub fn message_to_frame(msg:ServerWebSocketMessage) ->Vec<u8>
164    {
165        match &msg{
166            ServerWebSocketMessage::Text(data)=>{
167                let header = ServerWebSocketMessageHeader::from_len(data.len(), ServerWebSocketMessageFormat::Text, false);
168                ServerWebSocket::build_message(header, &data.to_string().into_bytes())
169            }
170            ServerWebSocketMessage::Binary(data)=>{
171                let header = ServerWebSocketMessageHeader::from_len(data.len(), ServerWebSocketMessageFormat::Binary, false);
172                ServerWebSocket::build_message(header, &data)
173            }
174            _=>panic!()
175        }  
176    }
177
178    pub fn create_upgrade_response(key: &str) -> String {
179        let to_hash = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", key);
180        let mut sha1 = Sha1::new();
181        sha1.update(to_hash.as_bytes());
182        let out_bytes = sha1.finalise();
183        let base64 = base64_encode(&out_bytes);
184        let response_ack = format!(
185            "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {}\r\n\r\n",
186            base64
187        );
188        response_ack
189    }
190
191    pub fn build_message(mut header: ServerWebSocketMessageHeader, data: &[u8])->Vec<u8>{
192        let mut frame = header.as_slice().to_vec();
193        if let Some(mask) = header.mask(){
194            for (i, &byte) in data.iter().enumerate() {
195                frame.push(byte ^ mask[i % 4]);
196            }
197        } else {
198            frame.extend_from_slice(data);
199        }
200        frame
201    }
202    
203    fn parse_head(&mut self, input: &[u8]) -> bool {
204        while self.head_expected > 0
205            && self.input_read < input.len()
206            && self.head_written < self.head.len()
207        {
208            self.head[self.head_written] = input[self.input_read];
209            self.input_read += 1;
210            self.head_written += 1;
211            self.head_expected -= 1;
212        }
213        self.head_expected != 0
214    }
215    
216    fn to_state(&mut self, state: State) {
217        match state {
218            State::Data => {
219                self.mask_counter = 0;
220                self.data.clear();
221            }
222            State::Opcode => {
223                self.is_ping = false;
224                self.is_pong = false;
225                self.is_partial = false;
226                self.is_text = false;
227                self.is_masked = false;
228            },
229            _ => ()
230        }
231        self.head_written = 0;
232        self.head_expected = state.head_expected();
233        self.state = state;
234    }
235    
236    pub fn parse<F>(&mut self, input: &[u8], mut result: F) where F: FnMut(Result<ServerWebSocketMessage, ServerWebSocketError>){
237        self.input_read = 0;
238        // parse a header
239        loop {
240            match self.state {
241                State::Opcode => {
242                    if self.parse_head(input) {
243                        break;
244                    }
245                    let opcode = self.head[0] & 15;
246                    
247                    if opcode <= 2 {
248                        self.is_partial = (self.head[0] & 128) != 0;
249                        self.is_text = opcode == 1;
250                        self.to_state(State::Len1);
251                    }
252                    else if opcode == 8 {
253                        result(Ok(ServerWebSocketMessage::Close));
254                        break;
255                    }
256                    else if opcode == 9 {
257                        self.is_ping = true;
258                        self.to_state(State::Len1);
259                    }
260                    else if opcode == 10 {
261                        self.is_pong = true;
262                        self.to_state(State::Len1);
263                    }
264                    else {
265                        result(Err(ServerWebSocketError::OpcodeNotSupported(opcode)));
266                        break;
267                    }
268                },
269                State::Len1 => {
270                    if self.parse_head(input) {
271                        break;
272                    }
273                    self.is_masked = (self.head[0] & 128) > 0;
274                    let len_type = self.head[0] & 127;
275                    if len_type < 126 {
276                        self.data_len = len_type as usize;
277                        if !self.is_masked {
278                            self.to_state(State::Data);
279                        }
280                        else {
281                            self.to_state(State::Mask);
282                        }
283                    }
284                    else if len_type == 126 {
285                        self.to_state(State::Len2);
286                    }
287                    else if len_type == 127 {
288                        self.to_state(State::Len8);
289                    }
290                },
291                State::Len2 => {
292                    if self.parse_head(input) {
293                        break;
294                    }
295                    self.data_len = u16::from_be_bytes(
296                        self.head[0..2].try_into().unwrap()
297                    ) as usize;
298                    if self.is_masked {
299                        self.to_state(State::Mask);
300                    }
301                    else {
302                        self.to_state(State::Data);
303                    }
304                },
305                State::Len8 => {
306                    if self.parse_head(input) {
307                        break;
308                    }
309                    self.data_len = u64::from_be_bytes(
310                        self.head[0..8].try_into().unwrap()
311                    ) as usize;
312                    if self.is_masked {
313                        self.to_state(State::Mask);
314                    }
315                    else {
316                        self.to_state(State::Data);
317                    }
318                },
319                State::Mask => {
320                    if self.parse_head(input) {
321                        break;
322                    }
323                    self.to_state(State::Data);
324                },
325                State::Data => {
326                    if self.is_masked {
327                        while self.data.len() < self.data_len && self.input_read < input.len() {
328                            self.data.push(input[self.input_read] ^ self.head[self.mask_counter]);
329                            self.mask_counter = (self.mask_counter + 1) & 3;
330                            self.input_read += 1;
331                        }
332                    }
333                    else {
334                        while self.data.len() < self.data_len && self.input_read < input.len() {
335                            self.data.push(input[self.input_read]);
336                            self.input_read += 1;
337                        }
338                    }
339                    if self.data.len() < self.data_len { // not enough data yet
340                        break;
341                    }
342                    else {
343                        if self.is_ping {
344                            result(Ok(ServerWebSocketMessage::Ping(&self.data)));
345                        }
346                        else if self.is_pong {
347                            result(Ok(ServerWebSocketMessage::Pong(&self.data)));
348                        }
349                        else if self.is_text{
350                            if let Ok(text) = std::str::from_utf8(&self.data){
351                                result(Ok(ServerWebSocketMessage::Text(text)));
352                            }
353                            else{
354                                result(Err(ServerWebSocketError::TextNotUTF8(&self.data)))
355                            }
356                        }
357                        else{
358                            result(Ok(ServerWebSocketMessage::Binary(&self.data)));
359                        }
360                        
361                        self.to_state(State::Opcode);
362                    }
363                },
364            }
365        }
366    }
367    
368}
369
370impl Default for ServerWebSocket {
371    fn default() -> Self {
372        Self::new()
373    }
374}
375