1
2use crate::digest::{Sha1, base64_encode};
3use std::time::{SystemTime, UNIX_EPOCH};
4
/// The part of an incoming frame the parser is currently collecting.
#[derive(Debug, PartialEq)]
enum State {
    /// First byte: FIN flag and opcode.
    Opcode,
    /// Second byte: MASK flag and the 7-bit payload length code.
    Len1,
    /// 16-bit extended payload length.
    Len2,
    /// 64-bit extended payload length.
    Len8,
    /// Payload bytes.
    Data,
    /// 4-byte masking key.
    Mask,
}
14
15impl State {
16 fn head_expected(&self) -> usize {
17 match self {
18 State::Opcode => 1,
19 State::Len1 => 1,
20 State::Len2 => 2,
21 State::Len8 => 8,
22 State::Data => 0,
23 State::Mask => 4
24 }
25 }
26}
27
28pub struct ServerWebSocket {
29 head: [u8; 8],
30 head_expected: usize,
31 head_written: usize,
32 data: Vec<u8>,
33 data_len: usize,
34 input_read: usize,
35 mask_counter: usize,
36 is_ping: bool,
37 is_pong: bool,
38 is_partial: bool,
39 is_text: bool,
40 is_masked: bool,
41 state: State
42}
43
/// A message decoded by [`ServerWebSocket::parse`], or one to encode with
/// [`ServerWebSocket::message_to_frame`].
///
/// Payloads are borrowed for `'a` (from the parser's buffer when decoding).
pub enum ServerWebSocketMessage<'a> {
    /// Ping control frame and its application data.
    Ping(&'a [u8]),
    /// Pong control frame and its application data.
    Pong(&'a [u8]),
    /// Text frame whose payload is valid UTF-8.
    Text(&'a str),
    /// Binary frame payload.
    Binary(&'a [u8]),
    /// Close frame.
    Close,
}
51
/// Errors reported through the callback of [`ServerWebSocket::parse`].
#[derive(Debug)]
pub enum ServerWebSocketError<'a> {
    /// The frame's opcode (low four bits of the first byte) is not handled by the parser.
    OpcodeNotSupported(u8),
    /// A text frame's payload is not valid UTF-8; the raw payload is attached.
    TextNotUTF8(&'a [u8]),
}

impl<'a> std::fmt::Display for ServerWebSocketError<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            ServerWebSocketError::OpcodeNotSupported(opcode) => {
                write!(f, "unsupported WebSocket opcode {}", opcode)
            }
            ServerWebSocketError::TextNotUTF8(payload) => write!(
                f,
                "text frame payload of {} bytes is not valid UTF-8",
                payload.len()
            ),
        }
    }
}

// Lets callers propagate parse errors with `?` / `Box<dyn Error>` like any other error.
impl<'a> std::error::Error for ServerWebSocketError<'a> {}
57
/// A complete ping frame: FIN set, opcode 0x9, unmasked, empty payload.
pub const SERVER_WEB_SOCKET_PING_MESSAGE: [u8; 2] = [0x80 | 0x9, 0x00];
/// A complete pong frame: FIN set, opcode 0xA, unmasked, empty payload.
pub const SERVER_WEB_SOCKET_PONG_MESSAGE: [u8; 2] = [0x80 | 0xA, 0x00];
60
/// Payload kind of an outgoing data frame, which selects its opcode.
pub enum ServerWebSocketMessageFormat {
    /// Arbitrary bytes.
    Binary,
    /// UTF-8 text.
    Text,
}
65
66pub struct ServerWebSocketMessageHeader {
67 pub format: ServerWebSocketMessageFormat,
68 len: usize,
69 masked: bool,
70 data: [u8;14]
71}
72
73impl ServerWebSocketMessageHeader {
74 pub fn from_len(len: usize, format: ServerWebSocketMessageFormat, masked: bool)->Self{
75 let mut data = [0u8;14];
76
77 match format {
78 ServerWebSocketMessageFormat::Binary => data[0] = 128 | 2,
79 ServerWebSocketMessageFormat::Text => data[0] = 128 | 1,
80 }
81
82 if masked {
83 data[1] = 128;
84 } else {
85 data[1] = 0;
86 }
87
88 let header_len;
89 if len < 126{
90 data[1] |= len as u8;
91 header_len = 2;
92 }
93 else if len < 65536{
94 data[1] |= 126;
95 let bytes = &(len as u16).to_be_bytes();
96 for (i, &byte) in bytes.iter().enumerate() {
97 data[i + 2] = byte;
98 }
99 header_len = 4;
100 }
101 else{
102 data[1] |= 127;
103 let bytes = &(len as u64).to_be_bytes();
104 for (i, &byte) in bytes.iter().enumerate() {
105 data[i + 2] = byte;
106 }
107 header_len = 10;
108 }
109
110 if masked {
111 for i in header_len..header_len + 4 {
112 data[i] = Self::random_byte();
113 }
114 return ServerWebSocketMessageHeader{len: header_len + 4, data, format, masked}
115 } else {
116 return ServerWebSocketMessageHeader{len: header_len, data, format, masked}
117 }
118 }
119
120 pub fn as_slice(&self)->&[u8]{
121 &self.data[0..self.len]
122 }
123
124 pub fn mask(&mut self)->Option<&[u8]> {
125 if self.masked {
126 match self.len {
127 6 => Some(&self.data[2..6]),
128 10 => Some(&self.data[6..10]),
129 14 => Some(&self.data[10..14]),
130 _ => None
131 }
132 } else {
133 None
134 }
135 }
136
137 fn random_byte() -> u8 {
139 let num = SystemTime::now().duration_since(UNIX_EPOCH).expect("duration_since failed").subsec_nanos();
140 num as u8
141 }
142}
143
144impl ServerWebSocket {
145 pub fn new() -> Self {
146 Self {
147 head: [0u8; 8],
148 head_expected: 1,
149 head_written: 0,
150 data: Vec::new(),
151 data_len: 0,
152 input_read: 0,
153 mask_counter: 0,
154 is_ping: false,
155 is_pong: false,
156 is_masked: false,
157 is_partial: false,
158 is_text: false,
159 state: State::Opcode
160 }
161 }
162
163 pub fn message_to_frame(msg:ServerWebSocketMessage) ->Vec<u8>
164 {
165 match &msg{
166 ServerWebSocketMessage::Text(data)=>{
167 let header = ServerWebSocketMessageHeader::from_len(data.len(), ServerWebSocketMessageFormat::Text, false);
168 ServerWebSocket::build_message(header, &data.to_string().into_bytes())
169 }
170 ServerWebSocketMessage::Binary(data)=>{
171 let header = ServerWebSocketMessageHeader::from_len(data.len(), ServerWebSocketMessageFormat::Binary, false);
172 ServerWebSocket::build_message(header, &data)
173 }
174 _=>panic!()
175 }
176 }
177
178 pub fn create_upgrade_response(key: &str) -> String {
179 let to_hash = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", key);
180 let mut sha1 = Sha1::new();
181 sha1.update(to_hash.as_bytes());
182 let out_bytes = sha1.finalise();
183 let base64 = base64_encode(&out_bytes);
184 let response_ack = format!(
185 "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {}\r\n\r\n",
186 base64
187 );
188 response_ack
189 }
190
191 pub fn build_message(mut header: ServerWebSocketMessageHeader, data: &[u8])->Vec<u8>{
192 let mut frame = header.as_slice().to_vec();
193 if let Some(mask) = header.mask(){
194 for (i, &byte) in data.iter().enumerate() {
195 frame.push(byte ^ mask[i % 4]);
196 }
197 } else {
198 frame.extend_from_slice(data);
199 }
200 frame
201 }
202
203 fn parse_head(&mut self, input: &[u8]) -> bool {
204 while self.head_expected > 0
205 && self.input_read < input.len()
206 && self.head_written < self.head.len()
207 {
208 self.head[self.head_written] = input[self.input_read];
209 self.input_read += 1;
210 self.head_written += 1;
211 self.head_expected -= 1;
212 }
213 self.head_expected != 0
214 }
215
216 fn to_state(&mut self, state: State) {
217 match state {
218 State::Data => {
219 self.mask_counter = 0;
220 self.data.clear();
221 }
222 State::Opcode => {
223 self.is_ping = false;
224 self.is_pong = false;
225 self.is_partial = false;
226 self.is_text = false;
227 self.is_masked = false;
228 },
229 _ => ()
230 }
231 self.head_written = 0;
232 self.head_expected = state.head_expected();
233 self.state = state;
234 }
235
236 pub fn parse<F>(&mut self, input: &[u8], mut result: F) where F: FnMut(Result<ServerWebSocketMessage, ServerWebSocketError>){
237 self.input_read = 0;
238 loop {
240 match self.state {
241 State::Opcode => {
242 if self.parse_head(input) {
243 break;
244 }
245 let opcode = self.head[0] & 15;
246
247 if opcode <= 2 {
248 self.is_partial = (self.head[0] & 128) != 0;
249 self.is_text = opcode == 1;
250 self.to_state(State::Len1);
251 }
252 else if opcode == 8 {
253 result(Ok(ServerWebSocketMessage::Close));
254 break;
255 }
256 else if opcode == 9 {
257 self.is_ping = true;
258 self.to_state(State::Len1);
259 }
260 else if opcode == 10 {
261 self.is_pong = true;
262 self.to_state(State::Len1);
263 }
264 else {
265 result(Err(ServerWebSocketError::OpcodeNotSupported(opcode)));
266 break;
267 }
268 },
269 State::Len1 => {
270 if self.parse_head(input) {
271 break;
272 }
273 self.is_masked = (self.head[0] & 128) > 0;
274 let len_type = self.head[0] & 127;
275 if len_type < 126 {
276 self.data_len = len_type as usize;
277 if !self.is_masked {
278 self.to_state(State::Data);
279 }
280 else {
281 self.to_state(State::Mask);
282 }
283 }
284 else if len_type == 126 {
285 self.to_state(State::Len2);
286 }
287 else if len_type == 127 {
288 self.to_state(State::Len8);
289 }
290 },
291 State::Len2 => {
292 if self.parse_head(input) {
293 break;
294 }
295 self.data_len = u16::from_be_bytes(
296 self.head[0..2].try_into().unwrap()
297 ) as usize;
298 if self.is_masked {
299 self.to_state(State::Mask);
300 }
301 else {
302 self.to_state(State::Data);
303 }
304 },
305 State::Len8 => {
306 if self.parse_head(input) {
307 break;
308 }
309 self.data_len = u64::from_be_bytes(
310 self.head[0..8].try_into().unwrap()
311 ) as usize;
312 if self.is_masked {
313 self.to_state(State::Mask);
314 }
315 else {
316 self.to_state(State::Data);
317 }
318 },
319 State::Mask => {
320 if self.parse_head(input) {
321 break;
322 }
323 self.to_state(State::Data);
324 },
325 State::Data => {
326 if self.is_masked {
327 while self.data.len() < self.data_len && self.input_read < input.len() {
328 self.data.push(input[self.input_read] ^ self.head[self.mask_counter]);
329 self.mask_counter = (self.mask_counter + 1) & 3;
330 self.input_read += 1;
331 }
332 }
333 else {
334 while self.data.len() < self.data_len && self.input_read < input.len() {
335 self.data.push(input[self.input_read]);
336 self.input_read += 1;
337 }
338 }
339 if self.data.len() < self.data_len { break;
341 }
342 else {
343 if self.is_ping {
344 result(Ok(ServerWebSocketMessage::Ping(&self.data)));
345 }
346 else if self.is_pong {
347 result(Ok(ServerWebSocketMessage::Pong(&self.data)));
348 }
349 else if self.is_text{
350 if let Ok(text) = std::str::from_utf8(&self.data){
351 result(Ok(ServerWebSocketMessage::Text(text)));
352 }
353 else{
354 result(Err(ServerWebSocketError::TextNotUTF8(&self.data)))
355 }
356 }
357 else{
358 result(Ok(ServerWebSocketMessage::Binary(&self.data)));
359 }
360
361 self.to_state(State::Opcode);
362 }
363 },
364 }
365 }
366 }
367
368}
369
370impl Default for ServerWebSocket {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375