makepad_http/
websocket.rs1use std::convert::TryInto;
2use crate::digest::{Sha1, base64_encode};
3
/// The part of an incoming frame the parser is currently waiting for.
#[derive(Debug, PartialEq)]
enum State {
    /// First header byte: FIN flag and opcode.
    Opcode,
    /// Second header byte: MASK flag and 7-bit length (or extended-length marker).
    Len1,
    /// 16-bit big-endian extended payload length.
    Len2,
    /// 64-bit big-endian extended payload length.
    Len8,
    /// Payload bytes.
    Data,
    /// 4-byte masking key.
    Mask,
}

impl State {
    /// Number of header bytes that must be collected before this state can be acted on.
    fn head_expected(&self) -> usize {
        match *self {
            Self::Data => 0,
            Self::Opcode | Self::Len1 => 1,
            Self::Len2 => 2,
            Self::Mask => 4,
            Self::Len8 => 8,
        }
    }
}
26
27pub struct WebSocket {
28 head: [u8; 8],
29 head_expected: usize,
30 head_written: usize,
31 data: Vec<u8>,
32 data_len: usize,
33 input_read: usize,
34 mask_counter: usize,
35 is_ping: bool,
36 is_pong: bool,
37 is_partial: bool,
38 is_text: bool,
39 is_masked: bool,
40 state: State
41}
42
/// A complete frame delivered to the `WebSocket::parse` callback.
///
/// Payload slices borrow the parser's internal buffer and are only valid inside the callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketMessage<'a> {
    /// Ping control frame with its (unmasked) payload.
    Ping(&'a [u8]),
    /// Pong control frame with its (unmasked) payload.
    Pong(&'a [u8]),
    /// Text frame whose payload was validated as UTF-8.
    Text(&'a str),
    /// Binary frame, or a continuation frame, with its (unmasked) payload.
    Binary(&'a [u8]),
    /// A close frame was received; its status code and reason are not decoded.
    Close,
}
50
/// Errors reported through the `WebSocket::parse` callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSocketError<'a> {
    /// The frame used an opcode the parser does not handle; carries the 4-bit opcode.
    OpcodeNotSupported(u8),
    /// A text frame's payload was not valid UTF-8; carries the raw payload bytes.
    TextNotUTF8(&'a [u8]),
}

impl std::fmt::Display for WebSocketError<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WebSocketError::OpcodeNotSupported(opcode) => {
                write!(f, "unsupported websocket opcode {}", opcode)
            }
            WebSocketError::TextNotUTF8(bytes) => {
                write!(f, "websocket text frame of {} bytes is not valid UTF-8", bytes.len())
            }
        }
    }
}

impl std::error::Error for WebSocketError<'_> {}
56
/// Empty, unmasked ping frame: FIN | opcode 0x9, zero-length payload.
pub const PING_MESSAGE: [u8; 2] = [0x80 | 0x09, 0x00];
/// Empty, unmasked pong frame: FIN | opcode 0xA, zero-length payload.
pub const PONG_MESSAGE: [u8; 2] = [0x80 | 0x0A, 0x00];
59
/// Header of an unmasked, final (FIN) binary WebSocket frame.
///
/// Only the first `len` bytes of `data` are meaningful; use [`BinaryMessageHeader::as_slice`].
#[derive(Debug, Clone, Copy)]
pub struct BinaryMessageHeader {
    /// Number of valid bytes in `data`: 2, 4 or 10.
    len: usize,
    /// Encoded header: FIN/opcode byte, length byte, then an optional extended length.
    data: [u8; 10],
}

impl BinaryMessageHeader {
    /// Encodes the header for a binary frame carrying `len` payload bytes.
    ///
    /// The shortest length encoding is chosen:
    /// - `len < 126`: length stored directly in the second byte (2-byte header);
    /// - `len < 65536`: marker `126` followed by a 16-bit big-endian length (4-byte header);
    /// - otherwise: marker `127` followed by a 64-bit big-endian length (10-byte header).
    ///
    /// The MASK bit is never set.
    pub fn from_len(len: usize) -> Self {
        let mut data = [0u8; 10];
        // FIN bit | opcode 2 (binary).
        data[0] = 128 | 2;
        let header_len = if len < 126 {
            // Lossless: len < 126 fits in the 7-bit length field.
            data[1] = len as u8;
            2
        } else if len < 65536 {
            data[1] = 126;
            // Lossless: len < 65536 fits in u16.
            data[2..4].copy_from_slice(&(len as u16).to_be_bytes());
            4
        } else {
            data[1] = 127;
            // usize -> u64 is lossless on all supported (<= 64-bit) targets.
            data[2..10].copy_from_slice(&(len as u64).to_be_bytes());
            10
        };
        BinaryMessageHeader { len: header_len, data }
    }

    /// Returns the encoded header bytes, to be written before the payload.
    pub fn as_slice(&self) -> &[u8] {
        &self.data[..self.len]
    }
}
97
98impl WebSocket {
99 pub fn new() -> Self {
100 Self {
101 head: [0u8; 8],
102 head_expected: 1,
103 head_written: 0,
104 data: Vec::new(),
105 data_len: 0,
106 input_read: 0,
107 mask_counter: 0,
108 is_ping: false,
109 is_pong: false,
110 is_masked: false,
111 is_partial: false,
112 is_text: false,
113 state: State::Opcode
114 }
115 }
116
117 pub fn create_upgrade_response(key: &str) -> String {
118 let to_hash = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", key);
119 let mut sha1 = Sha1::new();
120 sha1.update(to_hash.as_bytes());
121 let out_bytes = sha1.finalise();
122 let base64 = base64_encode(&out_bytes);
123 let response_ack = format!(
124 "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {}\r\n\r\n",
125 base64
126 );
127 response_ack
128 }
129
130 fn parse_head(&mut self, input: &[u8]) -> bool {
131 while self.head_expected > 0
132 && self.input_read < input.len()
133 && self.head_written < self.head.len()
134 {
135 self.head[self.head_written] = input[self.input_read];
136 self.input_read += 1;
137 self.head_written += 1;
138 self.head_expected -= 1;
139 }
140 self.head_expected != 0
141 }
142
143 fn to_state(&mut self, state: State) {
144 match state {
145 State::Data => {
146 self.mask_counter = 0;
147 self.data.clear();
148 }
149 State::Opcode => {
150 self.is_ping = false;
151 self.is_pong = false;
152 self.is_partial = false;
153 self.is_text = false;
154 self.is_masked = false;
155 },
156 _ => ()
157 }
158 self.head_written = 0;
159 self.head_expected = state.head_expected();
160 self.state = state;
161 }
162
163 pub fn parse<F>(&mut self, input: &[u8], mut result: F) where F: FnMut(Result<WebSocketMessage, WebSocketError>){
164 self.input_read = 0;
165 loop {
167 match self.state {
168 State::Opcode => {
169 if self.parse_head(input) {
170 break;
171 }
172 let opcode = self.head[0] & 15;
173 if opcode <= 2 {
174 self.is_partial = (self.head[0] & 128) != 0;
175 self.is_text = opcode == 1;
176 self.to_state(State::Len1);
177 }
178 else if opcode == 8 {
179 result(Ok(WebSocketMessage::Close));
180 break;
181 }
182 else if opcode == 9 {
183 self.is_ping = true;
184 self.to_state(State::Len1);
185 }
186 else if opcode == 10 {
187 self.is_pong = true;
188 self.to_state(State::Len1);
189 }
190 else {
191 result(Err(WebSocketError::OpcodeNotSupported(opcode)));
192 break;
193 }
194 },
195 State::Len1 => {
196 if self.parse_head(input) {
197 break;
198 }
199 self.is_masked = (self.head[0] & 128) > 0;
200 let len_type = self.head[0] & 127;
201 if len_type < 126 {
202 self.data_len = len_type as usize;
203 if !self.is_masked {
204 self.to_state(State::Data);
205 }
206 else {
207 self.to_state(State::Mask);
208 }
209 }
210 else if len_type == 126 {
211 self.to_state(State::Len2);
212 }
213 else if len_type == 127 {
214 self.to_state(State::Len8);
215 }
216 },
217 State::Len2 => {
218 if self.parse_head(input) {
219 break;
220 }
221 self.data_len = u16::from_be_bytes(
222 self.head[0..2].try_into().unwrap()
223 ) as usize;
224 if self.is_masked {
225 self.to_state(State::Mask);
226 }
227 else {
228 self.to_state(State::Data);
229 }
230 },
231 State::Len8 => {
232 if self.parse_head(input) {
233 break;
234 }
235 self.data_len = u64::from_be_bytes(
236 self.head[0..8].try_into().unwrap()
237 ) as usize;
238 if self.is_masked {
239 self.to_state(State::Mask);
240 }
241 else {
242 self.to_state(State::Data);
243 }
244 },
245 State::Mask => {
246 if self.parse_head(input) {
247 break;
248 }
249 self.to_state(State::Data);
250 },
251 State::Data => {
252 if self.is_masked {
253 while self.data.len() < self.data_len && self.input_read < input.len() {
254 self.data.push(input[self.input_read] ^ self.head[self.mask_counter]);
255 self.mask_counter = (self.mask_counter + 1) & 3;
256 self.input_read += 1;
257 }
258 }
259 else {
260 while self.data.len() < self.data_len && self.input_read < input.len() {
261 self.data.push(input[self.input_read]);
262 self.input_read += 1;
263 }
264 }
265 if self.data.len() < self.data_len { break;
267 }
268 else {
269 if self.is_ping {
270 result(Ok(WebSocketMessage::Ping(&self.data)));
271 }
272 else if self.is_pong {
273 result(Ok(WebSocketMessage::Pong(&self.data)));
274 }
275 else if self.is_text{
276 if let Ok(text) = std::str::from_utf8(&self.data){
277 result(Ok(WebSocketMessage::Text(text)));
278 }
279 else{
280 result(Err(WebSocketError::TextNotUTF8(&self.data)))
281 }
282 }
283 else{
284 result(Ok(WebSocketMessage::Binary(&self.data)));
285 }
286
287 self.to_state(State::Opcode);
288 }
289 },
290 }
291 }
292 }
293
294}
295
296impl Default for WebSocket {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301